/*
 * Copyright (C) 2015 Alberto Irurueta Carro (alberto@irurueta.com)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16  package com.irurueta.numerical.fitting;
17  
18  import com.irurueta.numerical.NotReadyException;
19  import com.irurueta.statistics.Gamma;
20  import com.irurueta.statistics.MaxIterationsExceededException;
21  
22  /**
23   * Fits provided data (x,y) to a straight line following equation y = a + b*x,
24   * estimates parameters a and b their variances, covariance and their chi square
25   * value.
26   * This class is based on the implementation available at Numerical Recipes
27   * 3rd Ed, page 784.
28   */
29  public class StraightLineFitter extends Fitter {
30  
31      /**
32       * Array containing x coordinates of input data to be fitted to a straight
33       * line.
34       */
35      private double[] x;
36  
37      /**
38       * Array containing y coordinates of input data to be fitted to a straight
39       * line.
40       */
41      private double[] y;
42  
43      /**
44       * Standard deviations of each pair of points (x,y). This is optional, if
45       * not provided, variances of a and b will be estimated assuming equal
46       * error for all input points.
47       */
48      private double[] sig;
49  
50      /**
51       * Estimated "a" parameter of line following equation y = a + b*x
52       */
53      private double a;
54  
55      /**
56       * Estimated "b" parameter of line following equation y = a + b*X
57       */
58      private double b;
59  
60      /**
61       * Estimated standard deviation of parameter "a".
62       */
63      private double siga;
64  
65      /**
66       * Estimated standard deviation of parameter "b".
67       */
68      private double sigb;
69  
70      /**
71       * Estimated chi square value.
72       */
73      private double chi2;
74  
75      /**
76       * Estimated goodness-of-fit probability (i.e. that the fit would have a
77       * chi square value equal or larger than the estimated one).
78       */
79      private double q;
80  
81      /**
82       * Estimated standard deviation of provided input data. This is only
83       * estimated if array of standard deviations of input points is not provided.
84       */
85      private double sigdat;
86  
87      /**
88       * Constructor.
89       */
90      public StraightLineFitter() {
91          q = 1.0;
92          chi2 = sigdat = 0.0;
93      }
94  
95      /**
96       * Constructor.
97       *
98       * @param x x coordinates of input data to be fitted to a straight line.
99       * @param y y coordinates of input data to be fitted to a straight line.
100      * @throws IllegalArgumentException if provided arrays don't have the same
101      *                                  length.
102      */
103     public StraightLineFitter(final double[] x, final double[] y) {
104         this();
105         setInputData(x, y);
106     }
107 
108     /**
109      * Constructor.
110      *
111      * @param x   x coordinates of input data to be fitted to a straight line.
112      * @param y   y coordinates of input data to be fitted to a straight line.
113      * @param sig standard deviation (i.e. errors) of provided data. This is
114      *            optional, if not provided, variances of a and b will be estimated
115      *            assuming equal error for all input points.
116      * @throws IllegalArgumentException if provided arrays don't have the same
117      *                                  length.
118      */
119     public StraightLineFitter(final double[] x, final double[] y, final double[] sig) {
120         this();
121         setInputDataAndStandardDeviations(x, y, sig);
122     }
123 
124     /**
125      * Returns array containing x coordinates of input data to be fitted to a
126      * straight line.
127      *
128      * @return array containing x coordinates of input data to be fitted to a
129      * straight line.
130      */
131     public double[] getX() {
132         return x;
133     }
134 
135     /**
136      * Returns array containing y coordinates of input data to be fitted to a
137      * straight line.
138      *
139      * @return array containing y coordinates of input data to be fitted to a
140      * straight line.
141      */
142     public double[] getY() {
143         return y;
144     }
145 
146     /**
147      * Returns standard deviations of each pair of points (x,y). This is
148      * optional, if not provided, variances of a and b will be estimated
149      * assuming equal error for all input points.
150      *
151      * @return standard deviations of each pair of points (x,y).
152      */
153     public double[] getSig() {
154         return sig;
155     }
156 
157     /**
158      * Sets input data to fit a straight line to.
159      *
160      * @param x x coordinates.
161      * @param y y coordinates.
162      * @throws IllegalArgumentException if arrays don't have the same length.
163      */
164     public final void setInputData(final double[] x, final double[] y) {
165         if (x.length != y.length) {
166             throw new IllegalArgumentException();
167         }
168 
169         this.x = x;
170         this.y = y;
171         this.sig = null;
172     }
173 
174     /**
175      * Sets input data and standard deviations of input data to fit a straight
176      * line to.
177      *
178      * @param x   x coordinates.
179      * @param y   y coordinates.
180      * @param sig standard deviations of each pair of points (x,y). This is
181      *            optional, if not provided, variances of a and b will be estimated
182      *            assuming equal error for all input points.
183      * @throws IllegalArgumentException if arrays don't have the same length.
184      */
185     public final void setInputDataAndStandardDeviations(
186             final double[] x, final double[] y, final double[] sig) {
187         if (sig != null) {
188             if (x.length != y.length || y.length != sig.length) {
189                 throw new IllegalArgumentException();
190             }
191 
192             this.x = x;
193             this.y = y;
194             this.sig = sig;
195         } else {
196             setInputData(x, y);
197         }
198     }
199 
200 
201     /**
202      * Indicates whether this instance is ready because enough input data has
203      * been provided to start the fitting process.
204      *
205      * @return true if this fitter is ready, false otherwise.
206      */
207     @Override
208     public boolean isReady() {
209         return x != null && y != null && x.length == y.length && (sig == null || sig.length == y.length);
210     }
211 
212     /**
213      * Returns estimated "a" parameter of line following equation y = a + b*x
214      *
215      * @return estimated "a" parameter.
216      */
217     public double getA() {
218         return a;
219     }
220 
221     /**
222      * Returns estimated "b" parameter of line following equation y = a + b*x
223      *
224      * @return estimated "b" parameter
225      */
226     public double getB() {
227         return b;
228     }
229 
230     /**
231      * Returns estimated standard deviation of parameter "a".
232      *
233      * @return estimated standard deviation of parameter "a".
234      */
235     public double getSigA() {
236         return siga;
237     }
238 
239     /**
240      * Returns estimated standard deviation of parameter "b".
241      *
242      * @return estimated standard deviation of parameter "b".
243      */
244     public double getSigB() {
245         return sigb;
246     }
247 
248     /**
249      * Returns estimated chi square value.
250      *
251      * @return estimated chi square value.
252      */
253     public double getChi2() {
254         return chi2;
255     }
256 
257     /**
258      * Returns estimated goodness-of-fit probability (i.e. that the fit would
259      * have a chi square value equal or larger than the estimated one).
260      *
261      * @return estimated goodness-of-fit probability.
262      */
263     public double getQ() {
264         return q;
265     }
266 
267     /**
268      * Returns estimated standard deviation of provided input data. This is only
269      * estimated if array of standard deviations of input points is not provided.
270      *
271      * @return estimated standard deviation of provided input data.
272      */
273     public double getSigdat() {
274         return sigdat;
275     }
276 
277     /**
278      * Fits a straight line following equation y = a + b*x to provided data
279      * (x, y) so that parameters associated a, b can be estimated along with
280      * their variances, covariance and chi square value.
281      *
282      * @throws FittingException  if fitting fails.
283      * @throws NotReadyException if enough input data has not yet been provided.
284      */
285     @Override
286     public void fit() throws FittingException, NotReadyException {
287         if (!isReady()) {
288             throw new NotReadyException();
289         }
290 
291         resultAvailable = false;
292 
293         if (sig != null) {
294             fitWithSig();
295         } else {
296             fitWithoutSig();
297         }
298 
299         resultAvailable = true;
300     }
301 
302     /**
303      * Fits data when standard deviations of input data is provided.
304      *
305      * @throws FittingException if fitting fails.
306      */
307     private void fitWithSig() throws FittingException {
308         final var gam = new Gamma();
309         int i;
310         double ss = 0.0;
311         double sx = 0.0;
312         double sy = 0.0;
313         double st2 = 0.0;
314         double t;
315         double wt;
316         final double sxoss;
317         final var ndata = x.length;
318         b = 0.0;
319         for (i = 0; i < ndata; i++) {
320             wt = 1.0 / Math.pow(sig[i], 2.0);
321             ss += wt;
322             sx += x[i] * wt;
323             sy += y[i] * wt;
324         }
325         sxoss = sx / ss;
326         for (i = 0; i < ndata; i++) {
327             t = (x[i] - sxoss) / sig[i];
328             st2 += t * t;
329             b += t * y[i] / sig[i];
330         }
331         b /= st2;
332         a = (sy - sx * b) / ss;
333         siga = Math.sqrt((1.0 + sx * sx / (ss * st2)) / ss);
334         sigb = Math.sqrt(1.0 / st2);
335         for (i = 0; i < ndata; i++) {
336             chi2 += Math.pow((y[i] - a - b * x[i]) / sig[i], 2.0);
337         }
338         try {
339             if (ndata > 2) {
340                 q = gam.gammq(0.5 * (ndata - 2), 0.5 * chi2);
341             }
342         } catch (final MaxIterationsExceededException e) {
343             throw new FittingException(e);
344         }
345     }
346 
347     /**
348      * Fits data when standard deviations of input data is not provided.
349      */
350     private void fitWithoutSig() {
351         int i;
352         final double ss;
353         var sx = 0.0;
354         var sy = 0.0;
355         var st2 = 0.0;
356         double t;
357         final double sxoss;
358         final var ndata = x.length;
359         b = 0.0;
360         for (i = 0; i < ndata; i++) {
361             sx += x[i];
362             sy += y[i];
363         }
364         ss = ndata;
365         sxoss = sx / ss;
366         for (i = 0; i < ndata; i++) {
367             t = x[i] - sxoss;
368             st2 += t * t;
369             b += t * y[i];
370         }
371         b /= st2;
372         a = (sy - sx * b) / ss;
373         siga = Math.sqrt((1.0 + sx * sx / (ss * st2)) / ss);
374         sigb = Math.sqrt(1.0 / st2);
375         for (i = 0; i < ndata; i++) {
376             chi2 += Math.pow(y[i] - a - b * x[i], 2.0);
377         }
378         if (ndata > 2) {
379             sigdat = Math.sqrt(chi2 / (ndata - 2));
380         }
381         siga *= sigdat;
382         sigb *= sigdat;
383     }
384 }