1use serde_derive::{Deserialize, Serialize};
21use std::fmt;
22use std::ops;
23
24use crate::NyxError;
25
/// A polynomial with `SIZE` coefficients, i.e. of order `SIZE - 1`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct Polynomial<const SIZE: usize> {
    /// Coefficients stored constant term first: `coefficients[i]` multiplies `x^i`.
    pub coefficients: [f64; SIZE],
}
32
33impl<const SIZE: usize> Polynomial<SIZE> {
34 pub fn from_most_significant(mut coeffs: [f64; SIZE]) -> Self {
35 coeffs.reverse();
36 Self {
37 coefficients: coeffs,
38 }
39 }
40
41 pub const fn order(&self) -> usize {
43 SIZE - 1
44 }
45
46 pub fn eval(&self, x: f64) -> f64 {
48 self.eval_n_deriv(x).0
49 }
50
51 pub fn deriv(&self, x: f64) -> f64 {
53 self.eval_n_deriv(x).1
54 }
55
56 pub fn eval_n_deriv(&self, x: f64) -> (f64, f64) {
58 if SIZE == 1 {
59 return (self.coefficients[0], 0.0);
60 }
61
62 let mut acc_eval = *self.coefficients.last().unwrap();
64 let mut acc_deriv = *self.coefficients.last().unwrap();
65 for val in self.coefficients.iter().skip(1).rev().skip(1) {
67 acc_eval = acc_eval * x + *val;
68 acc_deriv = acc_deriv * x + acc_eval;
69 }
70 acc_eval = x * acc_eval + self.coefficients[0];
72
73 (acc_eval, acc_deriv)
74 }
75
76 pub fn zeros() -> Self {
78 Self {
79 coefficients: [0.0; SIZE],
80 }
81 }
82
83 pub fn zero_power(&mut self, i: usize) {
85 if i < SIZE {
86 self.coefficients[i] = 0.0;
87 }
88 }
89
90 pub fn zero_below_tolerance(&mut self, tol: f64) {
92 for i in 0..=self.order() {
93 if self.coefficients[i].abs() < tol {
94 self.zero_power(i);
95 }
96 }
97 }
98
99 pub fn is_nan(&self) -> bool {
101 for c in self.coefficients {
102 if c.is_nan() {
103 return true;
104 }
105 }
106 false
107 }
108
109 fn fmt_with_var(&self, f: &mut fmt::Formatter, var: String) -> fmt::Result {
110 write!(f, "P({var}) = ")?;
111 let mut data = Vec::with_capacity(SIZE);
112
113 for (i, c) in self.coefficients.iter().enumerate().rev() {
114 if c.abs() <= f64::EPSILON {
115 continue;
116 }
117
118 let mut d;
119 if c.abs() > 100.0 || c.abs() < 0.01 {
120 if c > &0.0 {
122 d = format!("+{c:e}");
123 } else {
124 d = format!("{c:e}");
125 }
126 } else if c > &0.0 {
127 d = format!("+{c}");
128 } else {
129 d = format!("{c}");
130 }
131 let p = i;
133 match p {
134 0 => {} 1 => d = format!("{d}{var}"),
136 _ => d = format!("{d}{var}^{p}"),
137 }
138 data.push(d);
139 }
140 write!(f, "{}", data.join(" "))
141 }
142}
143
144impl<const SIZE: usize> ops::Mul<f64> for Polynomial<SIZE> {
146 type Output = Polynomial<SIZE>;
147
148 fn mul(mut self, rhs: f64) -> Self::Output {
149 for val in &mut self.coefficients {
150 *val *= rhs;
151 }
152 self
153 }
154}
155
156impl<const SIZE: usize> ops::Mul<f64> for &Polynomial<SIZE> {
158 type Output = Polynomial<SIZE>;
159
160 fn mul(self, rhs: f64) -> Self::Output {
161 *self * rhs
162 }
163}
164
165impl<const SIZE: usize> ops::Mul<Polynomial<SIZE>> for f64 {
167 type Output = Polynomial<SIZE>;
168
169 fn mul(self, rhs: Polynomial<SIZE>) -> Self::Output {
170 let mut me = rhs;
171 for val in &mut me.coefficients {
172 *val *= self;
173 }
174 me
175 }
176}
177
178impl<const SIZE: usize> ops::AddAssign<f64> for Polynomial<SIZE> {
179 fn add_assign(&mut self, rhs: f64) {
180 self.coefficients[0] += rhs;
181 }
182}
183
184impl<const SIZE: usize> fmt::Display for Polynomial<SIZE> {
185 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
186 self.fmt_with_var(f, "t".to_string())
187 }
188}
189
190impl<const SIZE: usize> fmt::LowerHex for Polynomial<SIZE> {
191 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
192 self.fmt_with_var(f, "x".to_string())
193 }
194}
195
196pub(crate) fn add<const S1: usize, const S2: usize>(
197 p1: Polynomial<S1>,
198 p2: Polynomial<S2>,
199) -> Polynomial<S1> {
200 if S1 < S2 {
201 panic!();
202 }
203 let mut rtn = Polynomial::zeros();
204 for (i, c1) in p1.coefficients.iter().enumerate() {
205 rtn.coefficients[i] = match p2.coefficients.get(i) {
206 Some(c2) => c1 + c2,
207 None => *c1,
208 };
209 }
210 rtn
211}
212
213impl<const S1: usize, const S2: usize> ops::Add<Polynomial<S1>> for Polynomial<S2> {
214 type Output = Polynomial<S1>;
215 fn add(self, other: Polynomial<S1>) -> Self::Output {
217 add(other, self)
218 }
219}
220
221pub(crate) fn sub<const S1: usize, const S2: usize>(
223 p1: Polynomial<S1>,
224 p2: Polynomial<S2>,
225) -> Polynomial<S1> {
226 if S1 < S2 {
227 panic!();
228 }
229 let mut rtn = Polynomial::zeros();
230 for (i, c1) in p1.coefficients.iter().enumerate() {
231 rtn.coefficients[i] = match p2.coefficients.get(i) {
232 Some(c2) => c1 - c2,
233 None => *c1,
234 };
235 }
236 rtn
237}
238
239impl<const S1: usize, const S2: usize> ops::Sub<Polynomial<S2>> for Polynomial<S1> {
240 type Output = Polynomial<S1>;
241 fn sub(self, other: Polynomial<S2>) -> Self::Output {
242 sub(self, other)
243 }
244}
245
/// A polynomial of order at most two. Coefficients are listed from the highest power down
/// to the constant term.
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub enum CommonPolynomial {
    /// Constant polynomial: `a`.
    Constant(f64),
    /// Linear polynomial `Linear(a, b)`: `a*x + b`.
    Linear(f64, f64),
    /// Quadratic polynomial `Quadratic(a, b, c)`: `a*x^2 + b*x + c`.
    Quadratic(f64, f64, f64),
}
254
255impl CommonPolynomial {
256 pub fn eval(&self, x: f64) -> f64 {
257 match *self {
258 Self::Constant(a) => Polynomial::<1> { coefficients: [a] }.eval(x),
259 Self::Linear(a, b) => Polynomial::<2> {
260 coefficients: [b, a],
261 }
262 .eval(x),
263 Self::Quadratic(a, b, c) => Polynomial::<3> {
264 coefficients: [c, b, a],
265 }
266 .eval(x),
267 }
268 }
269
270 pub fn deriv(&self, x: f64) -> f64 {
271 match *self {
272 Self::Constant(a) => Polynomial::<1> { coefficients: [a] }.deriv(x),
273 Self::Linear(a, b) => Polynomial::<2> {
274 coefficients: [b, a],
275 }
276 .deriv(x),
277 Self::Quadratic(a, b, c) => Polynomial::<3> {
278 coefficients: [c, b, a],
279 }
280 .deriv(x),
281 }
282 }
283
284 pub fn coeff_in_order(&self, order: usize) -> Result<f64, NyxError> {
285 match *self {
286 Self::Constant(a) => {
287 if order == 0 {
288 Ok(a)
289 } else {
290 Err(NyxError::PolynomialOrderError { order })
291 }
292 }
293 Self::Linear(a, b) => match order {
294 0 => Ok(b),
295 1 => Ok(a),
296 _ => Err(NyxError::PolynomialOrderError { order }),
297 },
298 Self::Quadratic(a, b, c) => match order {
299 0 => Ok(c),
300 1 => Ok(b),
301 2 => Ok(a),
302 _ => Err(NyxError::PolynomialOrderError { order }),
303 },
304 }
305 }
306
307 pub fn with_val_in_order(self, new_val: f64, order: usize) -> Result<Self, NyxError> {
308 match self {
309 Self::Constant(_) => {
310 if order != 0 {
311 Err(NyxError::PolynomialOrderError { order })
312 } else {
313 Ok(Self::Constant(new_val))
314 }
315 }
316 Self::Linear(x, y) => match order {
317 0 => Ok(Self::Linear(new_val, y)),
318 1 => Ok(Self::Linear(x, new_val)),
319 _ => Err(NyxError::PolynomialOrderError { order }),
320 },
321 Self::Quadratic(x, y, z) => match order {
322 0 => Ok(Self::Quadratic(new_val, y, z)),
323 1 => Ok(Self::Quadratic(x, new_val, z)),
324 2 => Ok(Self::Quadratic(x, y, new_val)),
325 _ => Err(NyxError::PolynomialOrderError { order }),
326 },
327 }
328 }
329
330 pub fn add_val_in_order(self, new_val: f64, order: usize) -> Result<Self, NyxError> {
331 match self {
332 Self::Constant(x) => {
333 if order != 0 {
334 Err(NyxError::PolynomialOrderError { order })
335 } else {
336 Ok(Self::Constant(new_val + x))
337 }
338 }
339 Self::Linear(x, y) => match order {
340 0 => Ok(Self::Linear(new_val + x, y)),
341 1 => Ok(Self::Linear(x, new_val + y)),
342 _ => Err(NyxError::PolynomialOrderError { order }),
343 },
344 Self::Quadratic(x, y, z) => match order {
345 0 => Ok(Self::Quadratic(new_val + x, y, z)),
346 1 => Ok(Self::Quadratic(x, new_val + y, z)),
347 2 => Ok(Self::Quadratic(x, y, new_val + z)),
348 _ => Err(NyxError::PolynomialOrderError { order }),
349 },
350 }
351 }
352}
353
354impl fmt::Display for CommonPolynomial {
355 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
357 match *self {
358 Self::Constant(a) => write!(f, "{}", Polynomial::<1> { coefficients: [a] }),
359 Self::Linear(a, b) => write!(
360 f,
361 "{}",
362 Polynomial::<2> {
363 coefficients: [b, a],
364 }
365 ),
366 Self::Quadratic(a, b, c) => write!(
367 f,
368 "{}",
369 Polynomial::<3> {
370 coefficients: [c, b, a],
371 }
372 ),
373 }
374 }
375}
376
#[cfg(test)]
mod ut_poly {
    use crate::polyfit::{CommonPolynomial, Polynomial};

    #[test]
    fn poly_constant() {
        let c = CommonPolynomial::Constant(10.0);
        for i in -100..=100 {
            assert!(
                (c.eval(i as f64) - 10.0).abs() < f64::EPSILON,
                "Constant polynomial returned wrong value"
            );
        }
    }

    #[test]
    fn poly_linear() {
        let c = CommonPolynomial::Linear(2.0, 10.0);
        for i in -100..=100 {
            let x = i as f64;
            let expect = 2.0 * x + 10.0;
            assert!(
                (c.eval(x) - expect).abs() < f64::EPSILON,
                "Linear polynomial returned wrong value"
            );
        }
    }

    #[test]
    fn poly_quadratic() {
        let p = Polynomial {
            coefficients: [101.0, -2.0, 3.0],
        };
        let p2 = 2.0 * p;
        let c = CommonPolynomial::Quadratic(3.0, -2.0, 101.0);
        for i in -100..=100 {
            let x = i as f64;
            let expect = 3.0 * x.powi(2) - 2.0 * x + 101.0;
            let expect_deriv = 6.0 * x - 2.0;
            assert!(
                (c.eval(x) - expect).abs() < f64::EPSILON,
                "Polynomial returned wrong value"
            );
            assert!(
                (p.deriv(x) - expect_deriv).abs() < f64::EPSILON,
                "Polynomial derivative returned wrong value"
            );

            assert!(
                (p.eval(x) - expect).abs() < f64::EPSILON,
                "Polynomial returned wrong value"
            );
            assert!(
                (p2.eval(x) - 2.0 * expect).abs() < f64::EPSILON,
                "Polynomial returned wrong value"
            );
        }
    }

    #[test]
    fn poly_print() {
        let p = Polynomial {
            coefficients: [101.0, -2.0, 3.0],
        };
        println!("{}", p);
        // Both representations must print identically for the same coefficients.
        assert_eq!(
            format!("{}", p),
            format!("{}", CommonPolynomial::Quadratic(3.0, -2.0, 101.0))
        );
    }

    #[test]
    fn poly_add() {
        let p1 = Polynomial {
            coefficients: [4.0, -2.0, 3.0],
        };
        let p2 = Polynomial {
            coefficients: [0.0, -5.0, 0.0, 2.0],
        };
        let p_expected = Polynomial {
            coefficients: [4.0, -7.0, 3.0, 2.0],
        };

        // The sum takes the size of the right-hand operand (the longer one here).
        let p3 = p1 + p2;
        println!("p3 = {:x}\npe = {:x}", p3, p_expected);
        assert_eq!(p3, p_expected);
        for i in -100..=100 {
            let x = i as f64;
            let expect = p1.eval(x) + p2.eval(x);
            assert!(
                (p3.eval(x) - expect).abs() < f64::EPSILON,
                "Sum polynomial returned wrong value"
            );
        }
    }

    #[test]
    fn poly_sub() {
        let p2 = Polynomial {
            coefficients: [4.0, -2.0, 3.0],
        };
        let p1 = Polynomial {
            coefficients: [0.0, -5.0, 0.0, 2.0],
        };
        let p_expected = Polynomial {
            coefficients: [-4.0, -3.0, -3.0, 2.0],
        };

        // The difference takes the size of the left-hand operand (the longer one here).
        let p3 = p1 - p2;
        println!("p3 = {:x}\npe = {:x}", p3, p_expected);
        assert_eq!(p3, p_expected);
        for i in -100..=100 {
            let x = i as f64;
            let expect = p1.eval(x) - p2.eval(x);
            assert!(
                (p3.eval(x) - expect).abs() < f64::EPSILON,
                "Difference polynomial returned wrong value"
            );
        }
    }

    #[test]
    fn poly_serde() {
        let c = CommonPolynomial::Quadratic(3.0, -2.0, 101.0);
        let c_yml = serde_yml::to_string(&c).unwrap();
        println!("{c_yml}");
        let c2 = serde_yml::from_str(&c_yml).unwrap();
        assert_eq!(c, c2);
    }
}