nyx_space/polyfit/polynomial.rs

1/*
2    Nyx, blazing fast astrodynamics
3    Copyright (C) 2018-onwards Christopher Rabotin <christopher.rabotin@gmail.com>
4
5    This program is free software: you can redistribute it and/or modify
6    it under the terms of the GNU Affero General Public License as published
7    by the Free Software Foundation, either version 3 of the License, or
8    (at your option) any later version.
9
10    This program is distributed in the hope that it will be useful,
11    but WITHOUT ANY WARRANTY; without even the implied warranty of
12    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13    GNU Affero General Public License for more details.
14
15    You should have received a copy of the GNU Affero General Public License
16    along with this program.  If not, see <https://www.gnu.org/licenses/>.
17*/
18
19/* NOTE: This code is effectively a clone of bacon-sci, MIT License, by Wyatt Campbell. */
20use serde_derive::{Deserialize, Serialize};
21use std::fmt;
22use std::ops;
23
24use crate::NyxError;
25
/// A fixed-size polynomial whose coefficients are stored inline in an array (no heap allocation).
///
/// `SIZE` is the number of coefficients, so the polynomial has order `SIZE - 1`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct Polynomial<const SIZE: usize> {
    /// Coefficients are ordered by their power: index 0 is the constant term (power 0),
    /// index 1 is the linear term, index 2 is the quadratic term, etc.
    pub coefficients: [f64; SIZE],
}

impl<const SIZE: usize> Polynomial<SIZE> {
    /// Builds a polynomial from coefficients listed from the highest power down to the constant.
    ///
    /// For example, `[3.0, -2.0, 101.0]` represents `3x^2 - 2x + 101`.
    pub fn from_most_significant(mut coeffs: [f64; SIZE]) -> Self {
        coeffs.reverse();
        Self {
            coefficients: coeffs,
        }
    }

    /// Get the order of the polynomial, i.e. `SIZE - 1`.
    ///
    /// `SIZE` must be at least 1: for a polynomial without coefficients the subtraction underflows.
    pub const fn order(&self) -> usize {
        SIZE - 1
    }

    /// Evaluate the polynomial at the provided position
    pub fn eval(&self, x: f64) -> f64 {
        self.eval_n_deriv(x).0
    }

    /// Evaluate the first derivative at the provided position
    pub fn deriv(&self, x: f64) -> f64 {
        self.eval_n_deriv(x).1
    }

    /// Evaluate the polynomial and its first derivative at `x`, returned as `(P(x), P'(x))`.
    ///
    /// Uses Horner's scheme, updating the derivative accumulator alongside the value
    /// accumulator so both are obtained in a single pass over the coefficients.
    /// A polynomial with no coefficients (`SIZE == 0`) is treated as the zero polynomial.
    pub fn eval_n_deriv(&self, x: f64) -> (f64, f64) {
        match SIZE {
            // No coefficients: the empty sum and its derivative are both zero.
            0 => return (0.0, 0.0),
            // Constant polynomial: the derivative is identically zero.
            1 => return (self.coefficients[0], 0.0),
            _ => {}
        }

        // Start with the highest-power coefficient.
        let mut acc_eval = self.coefficients[SIZE - 1];
        let mut acc_deriv = acc_eval;
        // Fold in every coefficient except the constant and the highest power, from high to low.
        for val in self.coefficients[1..SIZE - 1].iter().rev() {
            acc_eval = acc_eval * x + *val;
            acc_deriv = acc_deriv * x + acc_eval;
        }
        // The constant term only contributes to the value, not to the derivative.
        acc_eval = x * acc_eval + self.coefficients[0];

        (acc_eval, acc_deriv)
    }

    /// Initializes a Polynomial with only zeros
    pub fn zeros() -> Self {
        Self {
            coefficients: [0.0; SIZE],
        }
    }

    /// Set the i-th power of this polynomial to zero (e.g. if i=0, set the x^0 coefficient to zero,
    /// i.e. the constant part goes to zero). Indices at or beyond `SIZE` are ignored.
    pub fn zero_power(&mut self, i: usize) {
        if i < SIZE {
            self.coefficients[i] = 0.0;
        }
    }

    /// Set every coefficient whose magnitude is strictly below `tol` to zero.
    ///
    /// NaN coefficients are left untouched (the comparison is false for NaN).
    pub fn zero_below_tolerance(&mut self, tol: f64) {
        // Iterate the array directly so that a zero-sized polynomial is a no-op
        // instead of underflowing `order()`.
        for c in self.coefficients.iter_mut() {
            if c.abs() < tol {
                *c = 0.0;
            }
        }
    }

    /// Returns true if any of the coefficients are NaN
    pub fn is_nan(&self) -> bool {
        self.coefficients.iter().any(|c| c.is_nan())
    }

    /// Writes `P(var) = ...` to `f`, listing the terms from the highest power down to the constant.
    ///
    /// Coefficients whose magnitude is at most `f64::EPSILON` are omitted; if every coefficient
    /// is omitted, the polynomial is printed as `0`. Coefficients with a magnitude above 100 or
    /// below 0.01 use scientific notation, and positive coefficients carry an explicit `+`.
    fn fmt_with_var(&self, f: &mut fmt::Formatter, var: String) -> fmt::Result {
        write!(f, "P({var}) = ")?;
        let mut data = Vec::with_capacity(SIZE);

        for (i, c) in self.coefficients.iter().enumerate().rev() {
            if c.abs() <= f64::EPSILON {
                continue;
            }

            let sign = if *c > 0.0 { "+" } else { "" };
            let coeff = if c.abs() > 100.0 || c.abs() < 0.01 {
                // Use scientific notation
                format!("{sign}{c:e}")
            } else {
                format!("{sign}{c}")
            };
            // Append the variable and its power; the constant term shows no variable.
            let term = match i {
                0 => coeff,
                1 => format!("{coeff}{var}"),
                _ => format!("{coeff}{var}^{i}"),
            };
            data.push(term);
        }

        if data.is_empty() {
            // Every coefficient was negligible: print the zero polynomial explicitly.
            write!(f, "0")
        } else {
            write!(f, "{}", data.join(" "))
        }
    }
}
143
144/// In-place multiplication of a polynomial with an f64
145impl<const SIZE: usize> ops::Mul<f64> for Polynomial<SIZE> {
146    type Output = Polynomial<SIZE>;
147
148    fn mul(mut self, rhs: f64) -> Self::Output {
149        for val in &mut self.coefficients {
150            *val *= rhs;
151        }
152        self
153    }
154}
155
156/// Clone current polynomial and then multiply it with an f64
157impl<const SIZE: usize> ops::Mul<f64> for &Polynomial<SIZE> {
158    type Output = Polynomial<SIZE>;
159
160    fn mul(self, rhs: f64) -> Self::Output {
161        *self * rhs
162    }
163}
164
165/// In-place multiplication of a polynomial with an f64
166impl<const SIZE: usize> ops::Mul<Polynomial<SIZE>> for f64 {
167    type Output = Polynomial<SIZE>;
168
169    fn mul(self, rhs: Polynomial<SIZE>) -> Self::Output {
170        let mut me = rhs;
171        for val in &mut me.coefficients {
172            *val *= self;
173        }
174        me
175    }
176}
177
178impl<const SIZE: usize> ops::AddAssign<f64> for Polynomial<SIZE> {
179    fn add_assign(&mut self, rhs: f64) {
180        self.coefficients[0] += rhs;
181    }
182}
183
184impl<const SIZE: usize> fmt::Display for Polynomial<SIZE> {
185    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
186        self.fmt_with_var(f, "t".to_string())
187    }
188}
189
190impl<const SIZE: usize> fmt::LowerHex for Polynomial<SIZE> {
191    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
192        self.fmt_with_var(f, "x".to_string())
193    }
194}
195
196pub(crate) fn add<const S1: usize, const S2: usize>(
197    p1: Polynomial<S1>,
198    p2: Polynomial<S2>,
199) -> Polynomial<S1> {
200    if S1 < S2 {
201        panic!();
202    }
203    let mut rtn = Polynomial::zeros();
204    for (i, c1) in p1.coefficients.iter().enumerate() {
205        rtn.coefficients[i] = match p2.coefficients.get(i) {
206            Some(c2) => c1 + c2,
207            None => *c1,
208        };
209    }
210    rtn
211}
212
213impl<const S1: usize, const S2: usize> ops::Add<Polynomial<S1>> for Polynomial<S2> {
214    type Output = Polynomial<S1>;
215    /// Add Self and Other, _IF_ S2 >= S1 (else panic!)
216    fn add(self, other: Polynomial<S1>) -> Self::Output {
217        add(other, self)
218    }
219}
220
221/// Subtracts p1 from p2 (p3 = p1 - p2)
222pub(crate) fn sub<const S1: usize, const S2: usize>(
223    p1: Polynomial<S1>,
224    p2: Polynomial<S2>,
225) -> Polynomial<S1> {
226    if S1 < S2 {
227        panic!();
228    }
229    let mut rtn = Polynomial::zeros();
230    for (i, c1) in p1.coefficients.iter().enumerate() {
231        rtn.coefficients[i] = match p2.coefficients.get(i) {
232            Some(c2) => c1 - c2,
233            None => *c1,
234        };
235    }
236    rtn
237}
238
239impl<const S1: usize, const S2: usize> ops::Sub<Polynomial<S2>> for Polynomial<S1> {
240    type Output = Polynomial<S1>;
241    fn sub(self, other: Polynomial<S2>) -> Self::Output {
242        sub(self, other)
243    }
244}
245
246#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
247pub enum CommonPolynomial {
248    Constant(f64),
249    /// Linear(a, b) <=> f(x) = ax + b (order is FLIPPED from Polynomial<N> structure)
250    Linear(f64, f64),
251    /// Quadratic(a, b, c) <=> f(x) = ax^2 + bx + c (order is FLIPPED from Polynomial<N> structure)
252    Quadratic(f64, f64, f64),
253}
254
255impl CommonPolynomial {
256    pub fn eval(&self, x: f64) -> f64 {
257        match *self {
258            Self::Constant(a) => Polynomial::<1> { coefficients: [a] }.eval(x),
259            Self::Linear(a, b) => Polynomial::<2> {
260                coefficients: [b, a],
261            }
262            .eval(x),
263            Self::Quadratic(a, b, c) => Polynomial::<3> {
264                coefficients: [c, b, a],
265            }
266            .eval(x),
267        }
268    }
269
270    pub fn deriv(&self, x: f64) -> f64 {
271        match *self {
272            Self::Constant(a) => Polynomial::<1> { coefficients: [a] }.deriv(x),
273            Self::Linear(a, b) => Polynomial::<2> {
274                coefficients: [b, a],
275            }
276            .deriv(x),
277            Self::Quadratic(a, b, c) => Polynomial::<3> {
278                coefficients: [c, b, a],
279            }
280            .deriv(x),
281        }
282    }
283
284    pub fn coeff_in_order(&self, order: usize) -> Result<f64, NyxError> {
285        match *self {
286            Self::Constant(a) => {
287                if order == 0 {
288                    Ok(a)
289                } else {
290                    Err(NyxError::PolynomialOrderError { order })
291                }
292            }
293            Self::Linear(a, b) => match order {
294                0 => Ok(b),
295                1 => Ok(a),
296                _ => Err(NyxError::PolynomialOrderError { order }),
297            },
298            Self::Quadratic(a, b, c) => match order {
299                0 => Ok(c),
300                1 => Ok(b),
301                2 => Ok(a),
302                _ => Err(NyxError::PolynomialOrderError { order }),
303            },
304        }
305    }
306
307    pub fn with_val_in_order(self, new_val: f64, order: usize) -> Result<Self, NyxError> {
308        match self {
309            Self::Constant(_) => {
310                if order != 0 {
311                    Err(NyxError::PolynomialOrderError { order })
312                } else {
313                    Ok(Self::Constant(new_val))
314                }
315            }
316            Self::Linear(x, y) => match order {
317                0 => Ok(Self::Linear(new_val, y)),
318                1 => Ok(Self::Linear(x, new_val)),
319                _ => Err(NyxError::PolynomialOrderError { order }),
320            },
321            Self::Quadratic(x, y, z) => match order {
322                0 => Ok(Self::Quadratic(new_val, y, z)),
323                1 => Ok(Self::Quadratic(x, new_val, z)),
324                2 => Ok(Self::Quadratic(x, y, new_val)),
325                _ => Err(NyxError::PolynomialOrderError { order }),
326            },
327        }
328    }
329
330    pub fn add_val_in_order(self, new_val: f64, order: usize) -> Result<Self, NyxError> {
331        match self {
332            Self::Constant(x) => {
333                if order != 0 {
334                    Err(NyxError::PolynomialOrderError { order })
335                } else {
336                    Ok(Self::Constant(new_val + x))
337                }
338            }
339            Self::Linear(x, y) => match order {
340                0 => Ok(Self::Linear(new_val + x, y)),
341                1 => Ok(Self::Linear(x, new_val + y)),
342                _ => Err(NyxError::PolynomialOrderError { order }),
343            },
344            Self::Quadratic(x, y, z) => match order {
345                0 => Ok(Self::Quadratic(new_val + x, y, z)),
346                1 => Ok(Self::Quadratic(x, new_val + y, z)),
347                2 => Ok(Self::Quadratic(x, y, new_val + z)),
348                _ => Err(NyxError::PolynomialOrderError { order }),
349            },
350        }
351    }
352}
353
354impl fmt::Display for CommonPolynomial {
355    /// Prints the polynomial with the least significant coefficients first
356    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
357        match *self {
358            Self::Constant(a) => write!(f, "{}", Polynomial::<1> { coefficients: [a] }),
359            Self::Linear(a, b) => write!(
360                f,
361                "{}",
362                Polynomial::<2> {
363                    coefficients: [b, a],
364                }
365            ),
366            Self::Quadratic(a, b, c) => write!(
367                f,
368                "{}",
369                Polynomial::<3> {
370                    coefficients: [c, b, a],
371                }
372            ),
373        }
374    }
375}
376
#[cfg(test)]
mod ut_poly {
    use crate::polyfit::{CommonPolynomial, Polynomial};

    #[test]
    fn poly_constant() {
        let c = CommonPolynomial::Constant(10.0);
        for i in -100..=100 {
            assert!(
                (c.eval(i as f64) - 10.0).abs() < f64::EPSILON,
                "Constant polynomial returned wrong value"
            );
        }
    }

    #[test]
    fn poly_linear() {
        let c = CommonPolynomial::Linear(2.0, 10.0);
        for i in -100..=100 {
            let x = i as f64;
            let expect = 2.0 * x + 10.0;
            assert!(
                (c.eval(x) - expect).abs() < f64::EPSILON,
                "Linear polynomial returned wrong value"
            );
            assert!(
                (c.deriv(x) - 2.0).abs() < f64::EPSILON,
                "Linear polynomial derivative returned wrong value"
            );
        }
    }

    #[test]
    fn poly_quadratic() {
        let p = Polynomial {
            coefficients: [101.0, -2.0, 3.0],
        };
        let p2 = 2.0 * p;
        let c = CommonPolynomial::Quadratic(3.0, -2.0, 101.0);
        for i in -100..=100 {
            let x = i as f64;
            let expect = 3.0 * x.powi(2) - 2.0 * x + 101.0;
            let expect_deriv = 6.0 * x - 2.0;
            assert!(
                (c.eval(x) - expect).abs() < f64::EPSILON,
                "Polynomial returned wrong value"
            );
            assert!(
                (p.deriv(x) - expect_deriv).abs() < f64::EPSILON,
                "Polynomial derivative returned wrong value"
            );

            assert!(
                (p.eval(x) - expect).abs() < f64::EPSILON,
                "Polynomial returned wrong value"
            );
            assert!(
                (p2.eval(x) - 2.0 * expect).abs() < f64::EPSILON,
                "Scaled polynomial returned wrong value"
            );
        }
    }

    #[test]
    fn poly_print() {
        let p = Polynomial {
            coefficients: [101.0, -2.0, 3.0],
        };
        println!("{}", p);
        assert_eq!(
            format!("{}", p),
            format!("{}", CommonPolynomial::Quadratic(3.0, -2.0, 101.0))
        );
    }

    #[test]
    fn poly_add() {
        let p1 = Polynomial {
            coefficients: [4.0, -2.0, 3.0],
        };
        let p2 = Polynomial {
            coefficients: [0.0, -5.0, 0.0, 2.0],
        };
        //      P(x) = (3x^2 - 2x + 4) + (2x^3 - 5x)
        // <=>  P(x) = 2x^3 + 3x^2 -7x + 4
        let p_expected = Polynomial {
            coefficients: [4.0, -7.0, 3.0, 2.0],
        };

        let p3 = p1 + p2;
        println!("p3 = {:x}\npe = {:x}", p3, p_expected);
        assert_eq!(p3, p_expected);
        // Check this is correct
        for i in -100..=100 {
            let x = i as f64;
            let expect = p1.eval(x) + p2.eval(x);
            assert!(
                (p3.eval(x) - expect).abs() < f64::EPSILON,
                "Polynomial sum returned wrong value"
            );
        }
    }

    #[test]
    fn poly_sub() {
        let p2 = Polynomial {
            coefficients: [4.0, -2.0, 3.0],
        };
        let p1 = Polynomial {
            coefficients: [0.0, -5.0, 0.0, 2.0],
        };
        //      P(x) = (2x^3 - 5x) - (3x^2 - 2x + 4)
        // <=>  P(x) = 2x^3 - 3x^2 - 3x - 4
        let p_expected = Polynomial {
            coefficients: [-4.0, -3.0, -3.0, 2.0],
        };

        let p3 = p1 - p2;
        println!("p3 = {:x}\npe = {:x}", p3, p_expected);
        assert_eq!(p3, p_expected);
        // Check this is correct
        for i in -100..=100 {
            let x = i as f64;
            let expect = p1.eval(x) - p2.eval(x);
            assert!(
                (p3.eval(x) - expect).abs() < f64::EPSILON,
                "Polynomial difference returned wrong value"
            );
        }
    }

    #[test]
    fn poly_serde() {
        let c = CommonPolynomial::Quadratic(3.0, -2.0, 101.0);
        let c_yml = serde_yml::to_string(&c).unwrap();
        println!("{c_yml}");
        let c2 = serde_yml::from_str(&c_yml).unwrap();
        assert_eq!(c, c2);
    }
}