Skip to main content

nyx_space/polyfit/
polynomial.rs

1/*
2    Nyx, blazing fast astrodynamics
3    Copyright (C) 2018-onwards Christopher Rabotin <christopher.rabotin@gmail.com>
4
5    This program is free software: you can redistribute it and/or modify
6    it under the terms of the GNU Affero General Public License as published
7    by the Free Software Foundation, either version 3 of the License, or
8    (at your option) any later version.
9
10    This program is distributed in the hope that it will be useful,
11    but WITHOUT ANY WARRANTY; without even the implied warranty of
12    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13    GNU Affero General Public License for more details.
14
15    You should have received a copy of the GNU Affero General Public License
16    along with this program.  If not, see <https://www.gnu.org/licenses/>.
17*/
18
19/* NOTE: This code is effectively a clone of bacon-sci, MIT License, by Wyatt Campbell. */
20use serde::{Deserialize, Serialize};
21use serde_dhall::StaticType;
22use std::fmt;
23use std::ops;
24
25use crate::NyxError;
26
/// Polynomial is a statically allocated polynomial.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct Polynomial<const SIZE: usize> {
    /// Coefficients are ordered by their power, e.g. index 0 is to the power 0 (the constant
    /// term), 1 is linear, 2 is quadratic, etc.
    pub coefficients: [f64; SIZE],
}

impl<const SIZE: usize> Polynomial<SIZE> {
    /// Builds a polynomial from coefficients listed from the highest power down to the constant
    /// term, e.g. `[3.0, -2.0, 101.0]` is `3x^2 - 2x + 101`.
    pub fn from_most_significant(mut coeffs: [f64; SIZE]) -> Self {
        coeffs.reverse();
        Self {
            coefficients: coeffs,
        }
    }

    /// Get the order of the polynomial, i.e. `SIZE - 1`.
    ///
    /// Not meaningful for an empty polynomial: `SIZE - 1` underflows when `SIZE == 0`.
    pub const fn order(&self) -> usize {
        SIZE - 1
    }

    /// Evaluate the polynomial at the provided position
    pub fn eval(&self, x: f64) -> f64 {
        self.eval_n_deriv(x).0
    }

    /// Evaluate the derivative at the provided position
    pub fn deriv(&self, x: f64) -> f64 {
        self.eval_n_deriv(x).1
    }

    /// Evaluate the polynomial and its first derivative at the provided position, returned as
    /// `(value, derivative)`.
    ///
    /// An empty polynomial (`SIZE == 0`) is the empty sum and evaluates to `(0.0, 0.0)`.
    pub fn eval_n_deriv(&self, x: f64) -> (f64, f64) {
        let (constant, leading) = match (self.coefficients.first(), self.coefficients.last()) {
            (Some(&c0), Some(&cn)) => (c0, cn),
            // No coefficients at all: zero everywhere, and so is the derivative.
            _ => return (0.0, 0.0),
        };
        if SIZE == 1 {
            // A constant polynomial has a zero derivative everywhere.
            return (constant, 0.0);
        }

        // Horner's scheme from the highest power down: `acc_eval` accumulates P(x) while
        // `acc_deriv` accumulates P'(x) from the partial values of P.
        let mut acc_eval = leading;
        let mut acc_deriv = leading;
        // Every coefficient except the constant (folded in last) and the leading one (the seed).
        for val in self.coefficients.iter().skip(1).rev().skip(1) {
            acc_eval = acc_eval * x + *val;
            acc_deriv = acc_deriv * x + acc_eval;
        }
        // The constant term only affects the value, not the derivative.
        acc_eval = x * acc_eval + constant;

        (acc_eval, acc_deriv)
    }

    /// Initializes a Polynomial with only zeros
    pub fn zeros() -> Self {
        Self {
            coefficients: [0.0; SIZE],
        }
    }

    /// Set the i-th power of this polynomial to zero (e.g. if i=0, set the x^0 coefficient to zero, i.e. the constant part goes to zero).
    /// Out-of-range powers are ignored.
    pub fn zero_power(&mut self, i: usize) {
        if i < SIZE {
            self.coefficients[i] = 0.0;
        }
    }

    /// Set all of the coefficients whose magnitude is strictly below `tol` to zero.
    /// NaN coefficients are left untouched.
    pub fn zero_below_tolerance(&mut self, tol: f64) {
        // Iterating the array directly (rather than `0..=self.order()`) is also safe for SIZE == 0.
        for c in self.coefficients.iter_mut() {
            if c.abs() < tol {
                *c = 0.0;
            }
        }
    }

    /// Returns true if any of the coefficients are NaN
    pub fn is_nan(&self) -> bool {
        self.coefficients.iter().any(|c| c.is_nan())
    }

    /// Writes `P(var) = ...` with the highest power first, e.g. `P(t) = +3t^2 -2t +1.01e2`.
    ///
    /// Coefficients with magnitude at most `f64::EPSILON` are omitted; if all are omitted the
    /// polynomial prints as `0`. Magnitudes above 100 or below 0.01 use scientific notation.
    fn fmt_with_var(&self, f: &mut fmt::Formatter, var: String) -> fmt::Result {
        write!(f, "P({var}) = ")?;
        let mut terms = Vec::with_capacity(SIZE);

        for (power, c) in self.coefficients.iter().enumerate().rev() {
            if c.abs() <= f64::EPSILON {
                continue;
            }

            // Positive coefficients carry an explicit `+` so the terms read as a sum.
            let sign = if *c > 0.0 { "+" } else { "" };
            let coeff = if c.abs() > 100.0 || c.abs() < 0.01 {
                format!("{sign}{c:e}")
            } else {
                format!("{sign}{c}")
            };
            let term = match power {
                0 => coeff, // Show nothing for zero
                1 => format!("{coeff}{var}"),
                _ => format!("{coeff}{var}^{power}"),
            };
            terms.push(term);
        }

        if terms.is_empty() {
            // Every coefficient was negligible: print an explicit zero rather than nothing.
            write!(f, "0")
        } else {
            write!(f, "{}", terms.join(" "))
        }
    }
}
144
145/// In-place multiplication of a polynomial with an f64
146impl<const SIZE: usize> ops::Mul<f64> for Polynomial<SIZE> {
147    type Output = Polynomial<SIZE>;
148
149    fn mul(mut self, rhs: f64) -> Self::Output {
150        for val in &mut self.coefficients {
151            *val *= rhs;
152        }
153        self
154    }
155}
156
157/// Clone current polynomial and then multiply it with an f64
158impl<const SIZE: usize> ops::Mul<f64> for &Polynomial<SIZE> {
159    type Output = Polynomial<SIZE>;
160
161    fn mul(self, rhs: f64) -> Self::Output {
162        *self * rhs
163    }
164}
165
166/// In-place multiplication of a polynomial with an f64
167impl<const SIZE: usize> ops::Mul<Polynomial<SIZE>> for f64 {
168    type Output = Polynomial<SIZE>;
169
170    fn mul(self, rhs: Polynomial<SIZE>) -> Self::Output {
171        let mut me = rhs;
172        for val in &mut me.coefficients {
173            *val *= self;
174        }
175        me
176    }
177}
178
179impl<const SIZE: usize> ops::AddAssign<f64> for Polynomial<SIZE> {
180    fn add_assign(&mut self, rhs: f64) {
181        self.coefficients[0] += rhs;
182    }
183}
184
185impl<const SIZE: usize> fmt::Display for Polynomial<SIZE> {
186    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
187        self.fmt_with_var(f, "t".to_string())
188    }
189}
190
191impl<const SIZE: usize> fmt::LowerHex for Polynomial<SIZE> {
192    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
193        self.fmt_with_var(f, "x".to_string())
194    }
195}
196
197pub(crate) fn add<const S1: usize, const S2: usize>(
198    p1: Polynomial<S1>,
199    p2: Polynomial<S2>,
200) -> Polynomial<S1> {
201    if S1 < S2 {
202        panic!();
203    }
204    let mut rtn = Polynomial::zeros();
205    for (i, c1) in p1.coefficients.iter().enumerate() {
206        rtn.coefficients[i] = match p2.coefficients.get(i) {
207            Some(c2) => c1 + c2,
208            None => *c1,
209        };
210    }
211    rtn
212}
213
214impl<const S1: usize, const S2: usize> ops::Add<Polynomial<S1>> for Polynomial<S2> {
215    type Output = Polynomial<S1>;
216    /// Add Self and Other, _IF_ S2 >= S1 (else panic!)
217    fn add(self, other: Polynomial<S1>) -> Self::Output {
218        add(other, self)
219    }
220}
221
222/// Subtracts p1 from p2 (p3 = p1 - p2)
223pub(crate) fn sub<const S1: usize, const S2: usize>(
224    p1: Polynomial<S1>,
225    p2: Polynomial<S2>,
226) -> Polynomial<S1> {
227    if S1 < S2 {
228        panic!();
229    }
230    let mut rtn = Polynomial::zeros();
231    for (i, c1) in p1.coefficients.iter().enumerate() {
232        rtn.coefficients[i] = match p2.coefficients.get(i) {
233            Some(c2) => c1 - c2,
234            None => *c1,
235        };
236    }
237    rtn
238}
239
240impl<const S1: usize, const S2: usize> ops::Sub<Polynomial<S2>> for Polynomial<S1> {
241    type Output = Polynomial<S1>;
242    fn sub(self, other: Polynomial<S2>) -> Self::Output {
243        sub(self, other)
244    }
245}
246
/// A constant, linear or quadratic polynomial with named coefficients.
///
/// Unlike [`Polynomial`], whose coefficient array starts at the constant term, the fields here
/// start at the highest power: `a` is always the leading coefficient. The type derives serde's
/// `Serialize`/`Deserialize` and serde_dhall's `StaticType`, so the variant and field names are
/// part of its serialized form.
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize, StaticType)]
pub enum CommonPolynomial {
    /// Constant { a } <=> f(x) = a
    Constant {
        a: f64,
    },
    /// Linear { a, b } <=> f(x) = ax + b (order is FLIPPED from Polynomial<N> structure)
    Linear {
        a: f64,
        b: f64,
    },
    /// Quadratic { a, b, c } <=> f(x) = ax^2 + bx + c (order is FLIPPED from Polynomial<N> structure)
    Quadratic {
        a: f64,
        b: f64,
        c: f64,
    },
}
264
265impl CommonPolynomial {
266    pub fn eval(&self, x: f64) -> f64 {
267        match *self {
268            Self::Constant { a } => Polynomial::<1> { coefficients: [a] }.eval(x),
269            Self::Linear { a, b } => Polynomial::<2> {
270                coefficients: [b, a],
271            }
272            .eval(x),
273            Self::Quadratic { a, b, c } => Polynomial::<3> {
274                coefficients: [c, b, a],
275            }
276            .eval(x),
277        }
278    }
279
280    pub fn deriv(&self, x: f64) -> f64 {
281        match *self {
282            Self::Constant { a } => Polynomial::<1> { coefficients: [a] }.deriv(x),
283            Self::Linear { a, b } => Polynomial::<2> {
284                coefficients: [b, a],
285            }
286            .deriv(x),
287            Self::Quadratic { a, b, c } => Polynomial::<3> {
288                coefficients: [c, b, a],
289            }
290            .deriv(x),
291        }
292    }
293
294    pub fn coeff_in_order(&self, order: usize) -> Result<f64, NyxError> {
295        match *self {
296            Self::Constant { a } => {
297                if order == 0 {
298                    Ok(a)
299                } else {
300                    Err(NyxError::PolynomialOrderError { order })
301                }
302            }
303            Self::Linear { a, b } => match order {
304                0 => Ok(b),
305                1 => Ok(a),
306                _ => Err(NyxError::PolynomialOrderError { order }),
307            },
308            Self::Quadratic { a, b, c } => match order {
309                0 => Ok(c),
310                1 => Ok(b),
311                2 => Ok(a),
312                _ => Err(NyxError::PolynomialOrderError { order }),
313            },
314        }
315    }
316
317    pub fn with_val_in_order(self, new_val: f64, order: usize) -> Result<Self, NyxError> {
318        match self {
319            Self::Constant { .. } => {
320                if order != 0 {
321                    Err(NyxError::PolynomialOrderError { order })
322                } else {
323                    Ok(Self::Constant { a: new_val })
324                }
325            }
326            Self::Linear { a, b } => match order {
327                0 => Ok(Self::Linear { a: new_val, b }),
328                1 => Ok(Self::Linear { a, b: new_val }),
329
330                _ => Err(NyxError::PolynomialOrderError { order }),
331            },
332            Self::Quadratic { a, b, c } => match order {
333                0 => Ok(Self::Quadratic { a: new_val, b, c }),
334                1 => Ok(Self::Quadratic { a, b: new_val, c }),
335                2 => Ok(Self::Quadratic { a, b, c: new_val }),
336                _ => Err(NyxError::PolynomialOrderError { order }),
337            },
338        }
339    }
340
341    pub fn add_val_in_order(self, new_val: f64, order: usize) -> Result<Self, NyxError> {
342        match self {
343            Self::Constant { a } => {
344                if order != 0 {
345                    Err(NyxError::PolynomialOrderError { order })
346                } else {
347                    Ok(Self::Constant { a: new_val + a })
348                }
349            }
350            Self::Linear { a, b } => match order {
351                0 => Ok(Self::Linear { a: new_val + a, b }),
352
353                1 => Ok(Self::Linear { a, b: new_val + b }),
354
355                _ => Err(NyxError::PolynomialOrderError { order }),
356            },
357            Self::Quadratic { a, b, c } => match order {
358                0 => Ok(Self::Quadratic {
359                    a: new_val + a,
360                    b,
361                    c,
362                }),
363                1 => Ok(Self::Quadratic {
364                    a,
365                    b: new_val + b,
366                    c,
367                }),
368                2 => Ok(Self::Quadratic {
369                    a,
370                    b,
371                    c: new_val + c,
372                }),
373                _ => Err(NyxError::PolynomialOrderError { order }),
374            },
375        }
376    }
377}
378
379impl fmt::Display for CommonPolynomial {
380    /// Prints the polynomial with the least significant coefficients first
381    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
382        match *self {
383            Self::Constant { a } => write!(f, "{}", Polynomial::<1> { coefficients: [a] }),
384            Self::Linear { a, b } => write!(
385                f,
386                "{}",
387                Polynomial::<2> {
388                    coefficients: [b, a],
389                }
390            ),
391            Self::Quadratic { a, b, c } => write!(
392                f,
393                "{}",
394                Polynomial::<3> {
395                    coefficients: [c, b, a],
396                }
397            ),
398        }
399    }
400}
401
#[cfg(test)]
mod ut_poly {
    use crate::polyfit::{CommonPolynomial, Polynomial};

    #[test]
    fn poly_constant() {
        let c = CommonPolynomial::Constant { a: 10.0 };
        for i in -100..=100 {
            assert!(
                (c.eval(i as f64) - 10.0).abs() < f64::EPSILON,
                "Constant polynomial returned wrong value"
            );
        }
    }

    #[test]
    fn poly_linear() {
        let c = CommonPolynomial::Linear { a: 2.0, b: 10.0 };
        for i in -100..=100 {
            let x = i as f64;
            let expect = 2.0 * x + 10.0;
            assert!(
                (c.eval(x) - expect).abs() < f64::EPSILON,
                "Linear polynomial returned wrong value"
            );
        }
    }

    #[test]
    fn poly_quadratic() {
        let p = Polynomial {
            coefficients: [101.0, -2.0, 3.0],
        };
        let p2 = 2.0 * p;
        let c = CommonPolynomial::Quadratic {
            a: 3.0,
            b: -2.0,
            c: 101.0,
        };
        for i in -100..=100 {
            let x = i as f64;
            let expect = 3.0 * x.powi(2) - 2.0 * x + 101.0;
            let expect_deriv = 6.0 * x - 2.0;
            assert!(
                (c.eval(x) - expect).abs() < f64::EPSILON,
                "CommonPolynomial returned wrong value"
            );
            assert!(
                (c.deriv(x) - expect_deriv).abs() < f64::EPSILON,
                "CommonPolynomial derivative returned wrong value"
            );
            assert!(
                (p.deriv(x) - expect_deriv).abs() < f64::EPSILON,
                "Polynomial derivative returned wrong value"
            );

            assert!(
                (p.eval(x) - expect).abs() < f64::EPSILON,
                "Polynomial returned wrong value"
            );
            assert!(
                (p2.eval(x) - 2.0 * expect).abs() < f64::EPSILON,
                "Scaled polynomial returned wrong value"
            );
        }
    }

    #[test]
    fn poly_coeff_in_order() {
        let lin = CommonPolynomial::Linear { a: 2.0, b: 10.0 };
        assert_eq!(lin.coeff_in_order(0).unwrap(), 10.0);
        assert_eq!(lin.coeff_in_order(1).unwrap(), 2.0);
        assert!(lin.coeff_in_order(2).is_err());

        let quad = CommonPolynomial::Quadratic {
            a: 3.0,
            b: -2.0,
            c: 101.0,
        };
        assert_eq!(quad.coeff_in_order(0).unwrap(), 101.0);
        assert_eq!(quad.coeff_in_order(2).unwrap(), 3.0);
        assert!(quad.coeff_in_order(3).is_err());
    }

    #[test]
    fn poly_print() {
        let p = Polynomial {
            coefficients: [101.0, -2.0, 3.0],
        };
        println!("{p}");
        assert_eq!(
            format!("{p}"),
            format!(
                "{}",
                CommonPolynomial::Quadratic {
                    a: 3.0,
                    b: -2.0,
                    c: 101.0
                }
            )
        );
    }

    #[test]
    fn poly_add() {
        let p1 = Polynomial {
            coefficients: [4.0, -2.0, 3.0],
        };
        let p2 = Polynomial {
            coefficients: [0.0, -5.0, 0.0, 2.0],
        };
        //      P(x) = (3x^2 - 2x + 4) + (2x^3 - 5x)
        // <=>  P(x) = 2x^3 + 3x^2 -7x + 4
        let p_expected = Polynomial {
            coefficients: [4.0, -7.0, 3.0, 2.0],
        };

        // The right-hand side must be at least as large as the left-hand side.
        let p3 = p1 + p2;
        println!("p3 = {p3:x}\npe = {p_expected:x}");
        assert_eq!(p3, p_expected);
        // Check this is correct
        for i in -100..=100 {
            let x = i as f64;
            let expect = p1.eval(x) + p2.eval(x);
            assert!(
                (p3.eval(x) - expect).abs() < f64::EPSILON,
                "Polynomial sum returned wrong value"
            );
        }
    }

    #[test]
    fn poly_sub() {
        let p2 = Polynomial {
            coefficients: [4.0, -2.0, 3.0],
        };
        let p1 = Polynomial {
            coefficients: [0.0, -5.0, 0.0, 2.0],
        };
        //      P(x) = (2x^3 - 5x) - (3x^2 - 2x + 4)
        // <=>  P(x) = 2x^3 - 3x^2 - 3x - 4
        let p_expected = Polynomial {
            coefficients: [-4.0, -3.0, -3.0, 2.0],
        };

        let p3 = p1 - p2;
        println!("p3 = {p3:x}\npe = {p_expected:x}");
        assert_eq!(p3, p_expected);
        // Check this is correct
        for i in -100..=100 {
            let x = i as f64;
            let expect = p1.eval(x) - p2.eval(x);
            assert!(
                (p3.eval(x) - expect).abs() < f64::EPSILON,
                "Polynomial difference returned wrong value"
            );
        }
    }

    #[test]
    fn poly_serde() {
        let c = CommonPolynomial::Quadratic {
            a: 3.0,
            b: -2.0,
            c: 101.0,
        };
        let c_yml = serde_yml::to_string(&c).unwrap();
        println!("{c_yml}");
        let c2 = serde_yml::from_str(&c_yml).unwrap();
        assert_eq!(c, c2);
    }
}