1 /*
2 * Copyright (C) 2015 Alberto Irurueta Carro (alberto@irurueta.com)
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 package com.irurueta.numerical.fitting;
17
18 import com.irurueta.algebra.AlgebraException;
19 import com.irurueta.algebra.Matrix;
20 import com.irurueta.algebra.SingularValueDecomposer;
21 import com.irurueta.numerical.EvaluationException;
22 import com.irurueta.numerical.NotReadyException;
23
24 /**
25 * Fits provided data (x,y) to a function made of a linear combination of
26 * functions used as a basis (i.e. f(x1, x2, ...) = a * f0(x1, x2, ...) +
27 * b * f1(x1, x2, ...) + ...).
28 * Where f0, f1, ... is the function basis which ideally should be formed by
29 * orthogonal function.
30 * This class is based on the implementation available at Numerical Recipes
31 * 3rd Ed, page 795.
32 */
33 public class SvdMultiDimensionLinearFitter extends MultiDimensionLinearFitter {
34
35 /**
36 * Default tolerance.
37 */
38 public static final double DEFAULT_TOL = 1e-12;
39
40 /**
41 * Tolerance to define convergence threshold for SVD.
42 */
43 private double tol;
44
45 /**
46 * Constructor.
47 *
48 * @param x input points x where a linear multi-dimensional function
49 * f(x1, x2, ...) = a * f0(x1, x2, ...) + b * f1(x1, x2, ...) + ...
50 * @param y result of evaluation of linear multi-dimensional function
51 * f(x1, x2, ...) at provided x points.
52 * @param sig standard deviations of each pair of points (x, y).
53 * @throws IllegalArgumentException if provided matrix rows and arrays
54 * don't have the same length.
55 */
56 public SvdMultiDimensionLinearFitter(final Matrix x, final double[] y, final double[] sig) {
57 super(x, y, sig);
58 tol = DEFAULT_TOL;
59 }
60
61 /**
62 * Constructor.
63 *
64 * @param x input points x where a linear multi-dimensional function
65 * f(x1, x2, ...) = a * f0(x1, x2, ...) + b * f1(x1, x2, ...) + ...
66 * @param y result of evaluation of linear multi-dimensional function
67 * f(x1, x2, ...) at provided x points.
68 * @param sig standard deviation of all pair of points assuming that
69 * standard deviations are constant.
70 * @throws IllegalArgumentException if provided matrix rows and arrays
71 * don't have the same length.
72 */
73 public SvdMultiDimensionLinearFitter(final Matrix x, final double[] y, final double sig) {
74 super(x, y, sig);
75 tol = DEFAULT_TOL;
76 }
77
78 /**
79 * Constructor.
80 *
81 * @param evaluator evaluator to evaluate function at provided point and
82 * obtain the evaluation of function basis at such point.
83 * @throws FittingException if evaluation fails.
84 */
85 public SvdMultiDimensionLinearFitter(final LinearFitterMultiDimensionFunctionEvaluator evaluator)
86 throws FittingException {
87 super(evaluator);
88 tol = DEFAULT_TOL;
89 }
90
91 /**
92 * Constructor.
93 *
94 * @param evaluator evaluator to evaluate function at provided point and
95 * obtain the evaluation of function basis at such point.
96 * @param x input points x where a linear multi-dimensional function
97 * f(x1, x2, ...) = a * f0(x1, x2, ...) + b * f1(x1, x2, ...) + ...
98 * @param y result of evaluation of linear multi-dimensional function
99 * f(x1, x2, ...) at provided x points.
100 * @param sig standard deviations of each pair of points (x, y).
101 * @throws FittingException if evaluation fails.
102 * @throws IllegalArgumentException if provided matrix rows and arrays
103 * don't have the same length.
104 */
105 public SvdMultiDimensionLinearFitter(
106 final LinearFitterMultiDimensionFunctionEvaluator evaluator, final Matrix x, final double[] y,
107 final double[] sig) throws FittingException {
108 super(evaluator, x, y, sig);
109 tol = DEFAULT_TOL;
110 }
111
112 /**
113 * Constructor.
114 *
115 * @param evaluator evaluator to evaluate function at provided point and
116 * obtain the evaluation of function basis at such point.
117 * @param x input points x where a linear multi-dimensional function
118 * f(x1, x2, ...) = a * f0(x1, x2, ...) + b * f1(x1, x2, ...) + ...
119 * @param y result of evaluation of linear multi-dimensional function
120 * f(x1, x2, ...) at provided x points.
121 * @param sig standard deviation of all pair of points assuming that
122 * standard deviations are constant.
123 * @throws FittingException if evaluation fails.
124 * @throws IllegalArgumentException if provided matrix rows and arrays
125 * don't have the same length.
126 */
127 public SvdMultiDimensionLinearFitter(
128 final LinearFitterMultiDimensionFunctionEvaluator evaluator, final Matrix x, final double[] y,
129 final double sig) throws FittingException {
130 super(evaluator, x, y, sig);
131 tol = DEFAULT_TOL;
132 }
133
134 /**
135 * Constructor.
136 */
137 SvdMultiDimensionLinearFitter() {
138 super();
139 tol = DEFAULT_TOL;
140 }
141
142 /**
143 * Returns tolerance to define convergence threshold for SVD.
144 *
145 * @return tolerance to define convergence threshold for SVD.
146 */
147 public double getTol() {
148 return tol;
149 }
150
151 /**
152 * Sets tolerance to define convergence threshold for SVD.
153 *
154 * @param tol tolerance to define convergence threshold for SVD.
155 */
156 public void setTol(final double tol) {
157 this.tol = tol;
158 }
159
160 /**
161 * Fits a function to provided data so that parameters associated to that
162 * function can be estimated along with their covariance matrix and chi
163 * square value.
164 *
165 * @throws FittingException if fitting fails.
166 * @throws NotReadyException if enough input data has not yet been provided.
167 */
168 @SuppressWarnings("DuplicatedCode")
169 @Override
170 public void fit() throws FittingException, NotReadyException {
171 if (!isReady()) {
172 throw new NotReadyException();
173 }
174
175 final var xRow = new double[x.getColumns()];
176 final var xCols = evaluator.getNumberOfDimensions();
177
178 try {
179 resultAvailable = false;
180
181 int i;
182 int j;
183 int k;
184 double tmp;
185 final double thresh;
186 double sum;
187 final var aa = new Matrix(ndat, ma);
188 final var b = new double[ndat];
189 for (i = 0; i < ndat; i++) {
190 x.getSubmatrixAsArray(i, 0, i, xCols - 1, xRow);
191 evaluator.evaluate(xRow, afunc);
192 tmp = 1.0 / sig[i];
193 for (j = 0; j < ma; j++) {
194 aa.setElementAt(i, j, afunc[j] * tmp);
195 }
196 b[i] = y[i] * tmp;
197 }
198
199 final var svd = new SingularValueDecomposer(aa);
200 svd.decompose();
201 thresh = (tol > 0. ? tol * svd.getSingularValues()[0] : -1.0);
202 svd.solve(b, thresh, a);
203 chisq = 0.0;
204 for (i = 0; i < ndat; i++) {
205 sum = 0.0;
206 for (j = 0; j < ma; j++) {
207 sum += aa.getElementAt(i, j) * a[j];
208 }
209 chisq += Math.pow(sum - b[i], 2.0);
210 }
211 for (i = 0; i < ma; i++) {
212 for (j = 0; j < i + 1; j++) {
213 sum = 0.0;
214 final var w = svd.getSingularValues();
215 final var tsh = svd.getNegligibleSingularValueThreshold();
216 final var v = svd.getV();
217 for (k = 0; k < ma; k++) {
218 if (w[k] > tsh) {
219 sum += v.getElementAt(i, k) * v.getElementAt(j, k) / Math.pow(w[k], 2.0);
220 }
221 }
222 covar.setElementAt(j, i, sum);
223 covar.setElementAt(i, j, sum);
224 }
225 }
226
227 resultAvailable = true;
228
229 } catch (final AlgebraException | EvaluationException e) {
230 throw new FittingException(e);
231 }
232 }
233 }