View Javadoc
1   /*
2    * Copyright (C) 2015 Alberto Irurueta Carro (alberto@irurueta.com)
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *         http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  package com.irurueta.numerical.fitting;
17  
18  import com.irurueta.algebra.AlgebraException;
19  import com.irurueta.algebra.Matrix;
20  import com.irurueta.algebra.SingularValueDecomposer;
21  import com.irurueta.numerical.EvaluationException;
22  import com.irurueta.numerical.NotReadyException;
23  
24  /**
25   * Fits provided data (x,y) to a function made of a linear combination of
26   * functions used as a basis (i.e. f(x1, x2, ...) = a * f0(x1, x2, ...) +
27   * b * f1(x1, x2, ...) + ...).
28   * Where f0, f1, ... is the function basis which ideally should be formed by
29   * orthogonal function.
30   * This class is based on the implementation available at Numerical Recipes
31   * 3rd Ed, page 795.
32   */
33  public class SvdMultiDimensionLinearFitter extends MultiDimensionLinearFitter {
34  
35      /**
36       * Default tolerance.
37       */
38      public static final double DEFAULT_TOL = 1e-12;
39  
40      /**
41       * Tolerance to define convergence threshold for SVD.
42       */
43      private double tol;
44  
45      /**
46       * Constructor.
47       *
48       * @param x   input points x where a linear multi-dimensional function
49       *            f(x1, x2, ...) = a * f0(x1, x2, ...) + b * f1(x1, x2, ...) + ...
50       * @param y   result of evaluation of linear multi-dimensional function
51       *            f(x1, x2, ...) at provided x points.
52       * @param sig standard deviations of each pair of points (x, y).
53       * @throws IllegalArgumentException if provided matrix rows and arrays
54       *                                  don't have the same length.
55       */
56      public SvdMultiDimensionLinearFitter(final Matrix x, final double[] y, final double[] sig) {
57          super(x, y, sig);
58          tol = DEFAULT_TOL;
59      }
60  
61      /**
62       * Constructor.
63       *
64       * @param x   input points x where a linear multi-dimensional function
65       *            f(x1, x2, ...) = a * f0(x1, x2, ...) + b * f1(x1, x2, ...) + ...
66       * @param y   result of evaluation of linear multi-dimensional function
67       *            f(x1, x2, ...) at provided x points.
68       * @param sig standard deviation of all pair of points assuming that
69       *            standard deviations are constant.
70       * @throws IllegalArgumentException if provided matrix rows and arrays
71       *                                  don't have the same length.
72       */
73      public SvdMultiDimensionLinearFitter(final Matrix x, final double[] y, final double sig) {
74          super(x, y, sig);
75          tol = DEFAULT_TOL;
76      }
77  
78      /**
79       * Constructor.
80       *
81       * @param evaluator evaluator to evaluate function at provided point and
82       *                  obtain the evaluation of function basis at such point.
83       * @throws FittingException if evaluation fails.
84       */
85      public SvdMultiDimensionLinearFitter(final LinearFitterMultiDimensionFunctionEvaluator evaluator)
86              throws FittingException {
87          super(evaluator);
88          tol = DEFAULT_TOL;
89      }
90  
91      /**
92       * Constructor.
93       *
94       * @param evaluator evaluator to evaluate function at provided point and
95       *                  obtain the evaluation of function basis at such point.
96       * @param x         input points x where a linear multi-dimensional function
97       *                  f(x1, x2, ...) = a * f0(x1, x2, ...) + b * f1(x1, x2, ...) + ...
98       * @param y         result of evaluation of linear multi-dimensional function
99       *                  f(x1, x2, ...) at provided x points.
100      * @param sig       standard deviations of each pair of points (x, y).
101      * @throws FittingException         if evaluation fails.
102      * @throws IllegalArgumentException if provided matrix rows and arrays
103      *                                  don't have the same length.
104      */
105     public SvdMultiDimensionLinearFitter(
106             final LinearFitterMultiDimensionFunctionEvaluator evaluator, final Matrix x, final double[] y,
107             final double[] sig) throws FittingException {
108         super(evaluator, x, y, sig);
109         tol = DEFAULT_TOL;
110     }
111 
112     /**
113      * Constructor.
114      *
115      * @param evaluator evaluator to evaluate function at provided point and
116      *                  obtain the evaluation of function basis at such point.
117      * @param x         input points x where a linear multi-dimensional function
118      *                  f(x1, x2, ...) = a * f0(x1, x2, ...) + b * f1(x1, x2, ...) + ...
119      * @param y         result of evaluation of linear multi-dimensional function
120      *                  f(x1, x2, ...) at provided x points.
121      * @param sig       standard deviation of all pair of points assuming that
122      *                  standard deviations are constant.
123      * @throws FittingException         if evaluation fails.
124      * @throws IllegalArgumentException if provided matrix rows and arrays
125      *                                  don't have the same length.
126      */
127     public SvdMultiDimensionLinearFitter(
128             final LinearFitterMultiDimensionFunctionEvaluator evaluator, final Matrix x, final double[] y,
129             final double sig) throws FittingException {
130         super(evaluator, x, y, sig);
131         tol = DEFAULT_TOL;
132     }
133 
134     /**
135      * Constructor.
136      */
137     SvdMultiDimensionLinearFitter() {
138         super();
139         tol = DEFAULT_TOL;
140     }
141 
142     /**
143      * Returns tolerance to define convergence threshold for SVD.
144      *
145      * @return tolerance to define convergence threshold for SVD.
146      */
147     public double getTol() {
148         return tol;
149     }
150 
151     /**
152      * Sets tolerance to define convergence threshold for SVD.
153      *
154      * @param tol tolerance to define convergence threshold for SVD.
155      */
156     public void setTol(final double tol) {
157         this.tol = tol;
158     }
159 
160     /**
161      * Fits a function to provided data so that parameters associated to that
162      * function can be estimated along with their covariance matrix and chi
163      * square value.
164      *
165      * @throws FittingException  if fitting fails.
166      * @throws NotReadyException if enough input data has not yet been provided.
167      */
168     @SuppressWarnings("DuplicatedCode")
169     @Override
170     public void fit() throws FittingException, NotReadyException {
171         if (!isReady()) {
172             throw new NotReadyException();
173         }
174 
175         final var xRow = new double[x.getColumns()];
176         final var xCols = evaluator.getNumberOfDimensions();
177 
178         try {
179             resultAvailable = false;
180 
181             int i;
182             int j;
183             int k;
184             double tmp;
185             final double thresh;
186             double sum;
187             final var aa = new Matrix(ndat, ma);
188             final var b = new double[ndat];
189             for (i = 0; i < ndat; i++) {
190                 x.getSubmatrixAsArray(i, 0, i, xCols - 1, xRow);
191                 evaluator.evaluate(xRow, afunc);
192                 tmp = 1.0 / sig[i];
193                 for (j = 0; j < ma; j++) {
194                     aa.setElementAt(i, j, afunc[j] * tmp);
195                 }
196                 b[i] = y[i] * tmp;
197             }
198 
199             final var svd = new SingularValueDecomposer(aa);
200             svd.decompose();
201             thresh = (tol > 0. ? tol * svd.getSingularValues()[0] : -1.0);
202             svd.solve(b, thresh, a);
203             chisq = 0.0;
204             for (i = 0; i < ndat; i++) {
205                 sum = 0.0;
206                 for (j = 0; j < ma; j++) {
207                     sum += aa.getElementAt(i, j) * a[j];
208                 }
209                 chisq += Math.pow(sum - b[i], 2.0);
210             }
211             for (i = 0; i < ma; i++) {
212                 for (j = 0; j < i + 1; j++) {
213                     sum = 0.0;
214                     final var w = svd.getSingularValues();
215                     final var tsh = svd.getNegligibleSingularValueThreshold();
216                     final var v = svd.getV();
217                     for (k = 0; k < ma; k++) {
218                         if (w[k] > tsh) {
219                             sum += v.getElementAt(i, k) * v.getElementAt(j, k) / Math.pow(w[k], 2.0);
220                         }
221                     }
222                     covar.setElementAt(j, i, sum);
223                     covar.setElementAt(i, j, sum);
224                 }
225             }
226 
227             resultAvailable = true;
228 
229         } catch (final AlgebraException | EvaluationException e) {
230             throw new FittingException(e);
231         }
232     }
233 }