1use serde::{Deserialize, Serialize};
21use serde_dhall::StaticType;
22use std::fmt;
23use std::ops;
24
25use crate::NyxError;
26
/// A polynomial with `SIZE` coefficients, i.e. of order `SIZE - 1`.
///
/// Coefficients are stored from the least significant power to the most significant one:
/// `coefficients[i]` multiplies `x^i`, so `coefficients[0]` is the constant term.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Polynomial<const SIZE: usize> {
    /// Coefficients indexed by power, lowest power first.
    pub coefficients: [f64; SIZE],
}
33
34impl<const SIZE: usize> Polynomial<SIZE> {
35 pub fn from_most_significant(mut coeffs: [f64; SIZE]) -> Self {
36 coeffs.reverse();
37 Self {
38 coefficients: coeffs,
39 }
40 }
41
42 pub const fn order(&self) -> usize {
44 SIZE - 1
45 }
46
47 pub fn eval(&self, x: f64) -> f64 {
49 self.eval_n_deriv(x).0
50 }
51
52 pub fn deriv(&self, x: f64) -> f64 {
54 self.eval_n_deriv(x).1
55 }
56
57 pub fn eval_n_deriv(&self, x: f64) -> (f64, f64) {
59 if SIZE == 1 {
60 return (self.coefficients[0], 0.0);
61 }
62
63 let mut acc_eval = *self.coefficients.last().unwrap();
65 let mut acc_deriv = *self.coefficients.last().unwrap();
66 for val in self.coefficients.iter().skip(1).rev().skip(1) {
68 acc_eval = acc_eval * x + *val;
69 acc_deriv = acc_deriv * x + acc_eval;
70 }
71 acc_eval = x * acc_eval + self.coefficients[0];
73
74 (acc_eval, acc_deriv)
75 }
76
77 pub fn zeros() -> Self {
79 Self {
80 coefficients: [0.0; SIZE],
81 }
82 }
83
84 pub fn zero_power(&mut self, i: usize) {
86 if i < SIZE {
87 self.coefficients[i] = 0.0;
88 }
89 }
90
91 pub fn zero_below_tolerance(&mut self, tol: f64) {
93 for i in 0..=self.order() {
94 if self.coefficients[i].abs() < tol {
95 self.zero_power(i);
96 }
97 }
98 }
99
100 pub fn is_nan(&self) -> bool {
102 for c in self.coefficients {
103 if c.is_nan() {
104 return true;
105 }
106 }
107 false
108 }
109
110 fn fmt_with_var(&self, f: &mut fmt::Formatter, var: String) -> fmt::Result {
111 write!(f, "P({var}) = ")?;
112 let mut data = Vec::with_capacity(SIZE);
113
114 for (i, c) in self.coefficients.iter().enumerate().rev() {
115 if c.abs() <= f64::EPSILON {
116 continue;
117 }
118
119 let mut d;
120 if c.abs() > 100.0 || c.abs() < 0.01 {
121 if c > &0.0 {
123 d = format!("+{c:e}");
124 } else {
125 d = format!("{c:e}");
126 }
127 } else if c > &0.0 {
128 d = format!("+{c}");
129 } else {
130 d = format!("{c}");
131 }
132 let p = i;
134 match p {
135 0 => {} 1 => d = format!("{d}{var}"),
137 _ => d = format!("{d}{var}^{p}"),
138 }
139 data.push(d);
140 }
141 write!(f, "{}", data.join(" "))
142 }
143}
144
145impl<const SIZE: usize> ops::Mul<f64> for Polynomial<SIZE> {
147 type Output = Polynomial<SIZE>;
148
149 fn mul(mut self, rhs: f64) -> Self::Output {
150 for val in &mut self.coefficients {
151 *val *= rhs;
152 }
153 self
154 }
155}
156
157impl<const SIZE: usize> ops::Mul<f64> for &Polynomial<SIZE> {
159 type Output = Polynomial<SIZE>;
160
161 fn mul(self, rhs: f64) -> Self::Output {
162 *self * rhs
163 }
164}
165
166impl<const SIZE: usize> ops::Mul<Polynomial<SIZE>> for f64 {
168 type Output = Polynomial<SIZE>;
169
170 fn mul(self, rhs: Polynomial<SIZE>) -> Self::Output {
171 let mut me = rhs;
172 for val in &mut me.coefficients {
173 *val *= self;
174 }
175 me
176 }
177}
178
179impl<const SIZE: usize> ops::AddAssign<f64> for Polynomial<SIZE> {
180 fn add_assign(&mut self, rhs: f64) {
181 self.coefficients[0] += rhs;
182 }
183}
184
185impl<const SIZE: usize> fmt::Display for Polynomial<SIZE> {
186 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
187 self.fmt_with_var(f, "t".to_string())
188 }
189}
190
191impl<const SIZE: usize> fmt::LowerHex for Polynomial<SIZE> {
192 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
193 self.fmt_with_var(f, "x".to_string())
194 }
195}
196
197pub(crate) fn add<const S1: usize, const S2: usize>(
198 p1: Polynomial<S1>,
199 p2: Polynomial<S2>,
200) -> Polynomial<S1> {
201 if S1 < S2 {
202 panic!();
203 }
204 let mut rtn = Polynomial::zeros();
205 for (i, c1) in p1.coefficients.iter().enumerate() {
206 rtn.coefficients[i] = match p2.coefficients.get(i) {
207 Some(c2) => c1 + c2,
208 None => *c1,
209 };
210 }
211 rtn
212}
213
214impl<const S1: usize, const S2: usize> ops::Add<Polynomial<S1>> for Polynomial<S2> {
215 type Output = Polynomial<S1>;
216 fn add(self, other: Polynomial<S1>) -> Self::Output {
218 add(other, self)
219 }
220}
221
222pub(crate) fn sub<const S1: usize, const S2: usize>(
224 p1: Polynomial<S1>,
225 p2: Polynomial<S2>,
226) -> Polynomial<S1> {
227 if S1 < S2 {
228 panic!();
229 }
230 let mut rtn = Polynomial::zeros();
231 for (i, c1) in p1.coefficients.iter().enumerate() {
232 rtn.coefficients[i] = match p2.coefficients.get(i) {
233 Some(c2) => c1 - c2,
234 None => *c1,
235 };
236 }
237 rtn
238}
239
240impl<const S1: usize, const S2: usize> ops::Sub<Polynomial<S2>> for Polynomial<S1> {
241 type Output = Polynomial<S1>;
242 fn sub(self, other: Polynomial<S2>) -> Self::Output {
243 sub(self, other)
244 }
245}
246
247#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize, StaticType)]
248pub enum CommonPolynomial {
249 Constant {
250 a: f64,
251 },
252 Linear {
254 a: f64,
255 b: f64,
256 },
257 Quadratic {
259 a: f64,
260 b: f64,
261 c: f64,
262 },
263}
264
265impl CommonPolynomial {
266 pub fn eval(&self, x: f64) -> f64 {
267 match *self {
268 Self::Constant { a } => Polynomial::<1> { coefficients: [a] }.eval(x),
269 Self::Linear { a, b } => Polynomial::<2> {
270 coefficients: [b, a],
271 }
272 .eval(x),
273 Self::Quadratic { a, b, c } => Polynomial::<3> {
274 coefficients: [c, b, a],
275 }
276 .eval(x),
277 }
278 }
279
280 pub fn deriv(&self, x: f64) -> f64 {
281 match *self {
282 Self::Constant { a } => Polynomial::<1> { coefficients: [a] }.deriv(x),
283 Self::Linear { a, b } => Polynomial::<2> {
284 coefficients: [b, a],
285 }
286 .deriv(x),
287 Self::Quadratic { a, b, c } => Polynomial::<3> {
288 coefficients: [c, b, a],
289 }
290 .deriv(x),
291 }
292 }
293
294 pub fn coeff_in_order(&self, order: usize) -> Result<f64, NyxError> {
295 match *self {
296 Self::Constant { a } => {
297 if order == 0 {
298 Ok(a)
299 } else {
300 Err(NyxError::PolynomialOrderError { order })
301 }
302 }
303 Self::Linear { a, b } => match order {
304 0 => Ok(b),
305 1 => Ok(a),
306 _ => Err(NyxError::PolynomialOrderError { order }),
307 },
308 Self::Quadratic { a, b, c } => match order {
309 0 => Ok(c),
310 1 => Ok(b),
311 2 => Ok(a),
312 _ => Err(NyxError::PolynomialOrderError { order }),
313 },
314 }
315 }
316
317 pub fn with_val_in_order(self, new_val: f64, order: usize) -> Result<Self, NyxError> {
318 match self {
319 Self::Constant { .. } => {
320 if order != 0 {
321 Err(NyxError::PolynomialOrderError { order })
322 } else {
323 Ok(Self::Constant { a: new_val })
324 }
325 }
326 Self::Linear { a, b } => match order {
327 0 => Ok(Self::Linear { a: new_val, b }),
328 1 => Ok(Self::Linear { a, b: new_val }),
329
330 _ => Err(NyxError::PolynomialOrderError { order }),
331 },
332 Self::Quadratic { a, b, c } => match order {
333 0 => Ok(Self::Quadratic { a: new_val, b, c }),
334 1 => Ok(Self::Quadratic { a, b: new_val, c }),
335 2 => Ok(Self::Quadratic { a, b, c: new_val }),
336 _ => Err(NyxError::PolynomialOrderError { order }),
337 },
338 }
339 }
340
341 pub fn add_val_in_order(self, new_val: f64, order: usize) -> Result<Self, NyxError> {
342 match self {
343 Self::Constant { a } => {
344 if order != 0 {
345 Err(NyxError::PolynomialOrderError { order })
346 } else {
347 Ok(Self::Constant { a: new_val + a })
348 }
349 }
350 Self::Linear { a, b } => match order {
351 0 => Ok(Self::Linear { a: new_val + a, b }),
352
353 1 => Ok(Self::Linear { a, b: new_val + b }),
354
355 _ => Err(NyxError::PolynomialOrderError { order }),
356 },
357 Self::Quadratic { a, b, c } => match order {
358 0 => Ok(Self::Quadratic {
359 a: new_val + a,
360 b,
361 c,
362 }),
363 1 => Ok(Self::Quadratic {
364 a,
365 b: new_val + b,
366 c,
367 }),
368 2 => Ok(Self::Quadratic {
369 a,
370 b,
371 c: new_val + c,
372 }),
373 _ => Err(NyxError::PolynomialOrderError { order }),
374 },
375 }
376 }
377}
378
379impl fmt::Display for CommonPolynomial {
380 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
382 match *self {
383 Self::Constant { a } => write!(f, "{}", Polynomial::<1> { coefficients: [a] }),
384 Self::Linear { a, b } => write!(
385 f,
386 "{}",
387 Polynomial::<2> {
388 coefficients: [b, a],
389 }
390 ),
391 Self::Quadratic { a, b, c } => write!(
392 f,
393 "{}",
394 Polynomial::<3> {
395 coefficients: [c, b, a],
396 }
397 ),
398 }
399 }
400}
401
#[cfg(test)]
mod ut_poly {
    use crate::polyfit::{CommonPolynomial, Polynomial};

    #[test]
    fn poly_constant() {
        let c = CommonPolynomial::Constant { a: 10.0 };
        for i in -100..=100 {
            let x = i as f64;
            assert!(
                (c.eval(x) - 10.0).abs() < f64::EPSILON,
                "Constant polynomial returned wrong value"
            );
            assert!(
                c.deriv(x).abs() < f64::EPSILON,
                "Constant polynomial derivative should be zero"
            );
        }
    }

    #[test]
    fn poly_linear() {
        let c = CommonPolynomial::Linear { a: 2.0, b: 10.0 };
        for i in -100..=100 {
            let x = i as f64;
            let expect = 2.0 * x + 10.0;
            assert!(
                (c.eval(x) - expect).abs() < f64::EPSILON,
                "Linear polynomial returned wrong value"
            );
            assert!(
                (c.deriv(x) - 2.0).abs() < f64::EPSILON,
                "Linear polynomial derivative returned wrong value"
            );
        }
    }

    #[test]
    fn poly_quadratic() {
        let p = Polynomial {
            coefficients: [101.0, -2.0, 3.0],
        };
        let p2 = 2.0 * p;
        let c = CommonPolynomial::Quadratic {
            a: 3.0,
            b: -2.0,
            c: 101.0,
        };
        for i in -100..=100 {
            let x = i as f64;
            let expect = 3.0 * x.powi(2) - 2.0 * x + 101.0;
            let expect_deriv = 6.0 * x - 2.0;
            assert!(
                (c.eval(x) - expect).abs() < f64::EPSILON,
                "Quadratic polynomial returned wrong value"
            );
            assert!(
                (c.deriv(x) - expect_deriv).abs() < f64::EPSILON,
                "Quadratic polynomial derivative returned wrong value"
            );
            assert!(
                (p.deriv(x) - expect_deriv).abs() < f64::EPSILON,
                "Polynomial derivative returned wrong value"
            );

            assert!(
                (p.eval(x) - expect).abs() < f64::EPSILON,
                "Polynomial returned wrong value"
            );
            assert!(
                (p2.eval(x) - 2.0 * expect).abs() < f64::EPSILON,
                "Scaled polynomial returned wrong value"
            );
        }
    }

    #[test]
    fn poly_coeff_in_order() {
        let c = CommonPolynomial::Quadratic {
            a: 3.0,
            b: -2.0,
            c: 101.0,
        };
        // The order is the power of the variable: `c` is the constant term.
        assert_eq!(c.coeff_in_order(0).ok(), Some(101.0));
        assert_eq!(c.coeff_in_order(1).ok(), Some(-2.0));
        assert_eq!(c.coeff_in_order(2).ok(), Some(3.0));
        assert!(c.coeff_in_order(3).is_err());
    }

    #[test]
    fn poly_print() {
        let p = Polynomial {
            coefficients: [101.0, -2.0, 3.0],
        };
        println!("{p}");
        assert_eq!(
            format!("{p}"),
            format!(
                "{}",
                CommonPolynomial::Quadratic {
                    a: 3.0,
                    b: -2.0,
                    c: 101.0
                }
            )
        );
    }

    #[test]
    fn poly_add() {
        let p1 = Polynomial {
            coefficients: [4.0, -2.0, 3.0],
        };
        let p2 = Polynomial {
            coefficients: [0.0, -5.0, 0.0, 2.0],
        };
        let p_expected = Polynomial {
            coefficients: [4.0, -7.0, 3.0, 2.0],
        };

        let p3 = p1 + p2;
        println!("p3 = {p3:x}\npe = {p_expected:x}");
        assert_eq!(p3, p_expected);
        for i in -100..=100 {
            let x = i as f64;
            let expect = p1.eval(x) + p2.eval(x);
            assert!(
                (p3.eval(x) - expect).abs() < f64::EPSILON,
                "Polynomial sum returned wrong value"
            );
        }
    }

    #[test]
    fn poly_sub() {
        let p2 = Polynomial {
            coefficients: [4.0, -2.0, 3.0],
        };
        let p1 = Polynomial {
            coefficients: [0.0, -5.0, 0.0, 2.0],
        };
        let p_expected = Polynomial {
            coefficients: [-4.0, -3.0, -3.0, 2.0],
        };

        let p3 = p1 - p2;
        println!("p3 = {p3:x}\npe = {p_expected:x}");
        assert_eq!(p3, p_expected);
        for i in -100..=100 {
            let x = i as f64;
            let expect = p1.eval(x) - p2.eval(x);
            assert!(
                (p3.eval(x) - expect).abs() < f64::EPSILON,
                "Polynomial difference returned wrong value"
            );
        }
    }

    #[test]
    fn poly_serde() {
        let c = CommonPolynomial::Quadratic {
            a: 3.0,
            b: -2.0,
            c: 101.0,
        };
        let c_yml = serde_yml::to_string(&c).unwrap();
        println!("{c_yml}");
        let c2 = serde_yml::from_str(&c_yml).unwrap();
        assert_eq!(c, c2);
    }
}