use serde_derive::{Deserialize, Serialize};
use std::fmt;
use std::ops;
use crate::NyxError;
/// A polynomial in one variable with a compile-time number of coefficients.
///
/// Coefficients are stored least significant first: `coefficients[i]` multiplies `x^i`, so
/// `[c0, c1, c2]` represents `c0 + c1*x + c2*x^2`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct Polynomial<const SIZE: usize> {
    /// Coefficients ordered by increasing power of the variable.
    pub coefficients: [f64; SIZE],
}

impl<const SIZE: usize> Polynomial<SIZE> {
    /// Builds a polynomial from coefficients listed most significant first.
    ///
    /// For example, `[a, b, c]` yields `a*x^2 + b*x + c`.
    pub fn from_most_significant(mut coeffs: [f64; SIZE]) -> Self {
        // Internal storage is least significant first.
        coeffs.reverse();
        Self {
            coefficients: coeffs,
        }
    }

    /// Returns the order of the polynomial, i.e. the highest power it can represent
    /// (`SIZE - 1`).
    ///
    /// # Panics
    /// `SIZE - 1` overflows when `SIZE == 0` (a panic when overflow checks are enabled).
    pub const fn order(&self) -> usize {
        SIZE - 1
    }

    /// Evaluates the polynomial at `x`.
    pub fn eval(&self, x: f64) -> f64 {
        self.eval_n_deriv(x).0
    }

    /// Evaluates the first derivative of the polynomial at `x`.
    pub fn deriv(&self, x: f64) -> f64 {
        self.eval_n_deriv(x).1
    }

    /// Evaluates the polynomial and its first derivative at `x` in a single Horner pass.
    ///
    /// Returns `(p(x), p'(x))`. A polynomial without coefficients (`SIZE == 0`) is treated as
    /// the zero polynomial and yields `(0.0, 0.0)` instead of panicking.
    pub fn eval_n_deriv(&self, x: f64) -> (f64, f64) {
        match &self.coefficients[..] {
            [] => (0.0, 0.0),
            // A constant has an identically zero derivative.
            [constant] => (*constant, 0.0),
            [constant, middle @ .., leading] => {
                // Horner's scheme for p and p' together. Seeding the derivative accumulator with
                // the leading coefficient accounts for its first Horner step, so each middle
                // coefficient advances both accumulators.
                let mut acc_eval = *leading;
                let mut acc_deriv = *leading;
                for &c in middle.iter().rev() {
                    acc_eval = acc_eval * x + c;
                    acc_deriv = acc_deriv * x + acc_eval;
                }
                // The constant term only contributes to the value, never to the derivative.
                (acc_eval * x + *constant, acc_deriv)
            }
        }
    }

    /// Returns the polynomial whose coefficients are all zero.
    pub fn zeros() -> Self {
        Self {
            coefficients: [0.0; SIZE],
        }
    }

    /// Sets the coefficient of `x^i` to zero; indices past the highest power are ignored.
    pub fn zero_power(&mut self, i: usize) {
        if let Some(coeff) = self.coefficients.get_mut(i) {
            *coeff = 0.0;
        }
    }

    /// Zeroes every coefficient whose magnitude is strictly below `tol`.
    ///
    /// NaN coefficients are left untouched because the comparison with `tol` is false for them.
    pub fn zero_below_tolerance(&mut self, tol: f64) {
        for coeff in self.coefficients.iter_mut() {
            if coeff.abs() < tol {
                *coeff = 0.0;
            }
        }
    }

    /// Returns `true` if any coefficient is NaN.
    pub fn is_nan(&self) -> bool {
        self.coefficients.iter().any(|c| c.is_nan())
    }

    /// Writes `P(var) = ...` with terms ordered from the highest power down.
    ///
    /// Terms whose coefficient magnitude is at most `f64::EPSILON` are omitted; when every term
    /// is omitted, the zero polynomial is written as `0`. Coefficients with magnitude above 100
    /// or below 0.01 use scientific notation, and every term carries an explicit sign.
    fn fmt_with_var(&self, f: &mut fmt::Formatter, var: String) -> fmt::Result {
        write!(f, "P({var}) = ")?;
        let terms: Vec<String> = self
            .coefficients
            .iter()
            .enumerate()
            .rev()
            // NaN coefficients are kept so that they stay visible in the output.
            .filter(|(_, c)| c.is_nan() || c.abs() > f64::EPSILON)
            .map(|(power, c)| {
                // The `+` flag forces a sign on positive values; NaN is printed unsigned.
                let coeff = if c.abs() > 100.0 || c.abs() < 0.01 {
                    format!("{c:+e}")
                } else {
                    format!("{c:+}")
                };
                match power {
                    0 => coeff,
                    1 => format!("{coeff}{var}"),
                    _ => format!("{coeff}{var}^{power}"),
                }
            })
            .collect();
        if terms.is_empty() {
            f.write_str("0")
        } else {
            write!(f, "{}", terms.join(" "))
        }
    }
}
impl<const SIZE: usize> ops::Mul<f64> for Polynomial<SIZE> {
type Output = Polynomial<SIZE>;
fn mul(mut self, rhs: f64) -> Self::Output {
for val in &mut self.coefficients {
*val *= rhs;
}
self
}
}
impl<const SIZE: usize> ops::Mul<f64> for &Polynomial<SIZE> {
type Output = Polynomial<SIZE>;
fn mul(self, rhs: f64) -> Self::Output {
*self * rhs
}
}
impl<const SIZE: usize> ops::Mul<Polynomial<SIZE>> for f64 {
type Output = Polynomial<SIZE>;
fn mul(self, rhs: Polynomial<SIZE>) -> Self::Output {
let mut me = rhs;
for val in &mut me.coefficients {
*val *= self;
}
me
}
}
/// Adds `rhs` to the constant (zeroth-order) coefficient.
///
/// # Panics
/// Panics with an out-of-bounds index when `SIZE == 0`, since there is no constant coefficient.
impl<const SIZE: usize> ops::AddAssign<f64> for Polynomial<SIZE> {
    fn add_assign(&mut self, rhs: f64) {
        let constant = &mut self.coefficients[0];
        *constant += rhs;
    }
}
/// Displays the polynomial using `t` as the variable name.
impl<const SIZE: usize> fmt::Display for Polynomial<SIZE> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let variable = String::from("t");
        self.fmt_with_var(f, variable)
    }
}
/// Formats the polynomial using `x` as the variable name (selected with `{:x}`).
impl<const SIZE: usize> fmt::LowerHex for Polynomial<SIZE> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let variable = String::from("x");
        self.fmt_with_var(f, variable)
    }
}
pub(crate) fn add<const S1: usize, const S2: usize>(
p1: Polynomial<S1>,
p2: Polynomial<S2>,
) -> Polynomial<S1> {
if S1 < S2 {
panic!();
}
let mut rtn = Polynomial::zeros();
for (i, c1) in p1.coefficients.iter().enumerate() {
rtn.coefficients[i] = match p2.coefficients.get(i) {
Some(c2) => c1 + c2,
None => *c1,
};
}
rtn
}
/// `lhs + rhs`, where the right-hand operand determines the size of the result.
///
/// The left-hand polynomial must not have more coefficients than the right-hand one
/// (write `short + long`); otherwise the free function `add` panics.
impl<const S1: usize, const S2: usize> ops::Add<Polynomial<S1>> for Polynomial<S2> {
    type Output = Polynomial<S1>;
    fn add(self, rhs: Polynomial<S1>) -> Self::Output {
        // The free `add` expects the larger polynomial first.
        let (larger, smaller) = (rhs, self);
        add(larger, smaller)
    }
}
pub(crate) fn sub<const S1: usize, const S2: usize>(
p1: Polynomial<S1>,
p2: Polynomial<S2>,
) -> Polynomial<S1> {
if S1 < S2 {
panic!();
}
let mut rtn = Polynomial::zeros();
for (i, c1) in p1.coefficients.iter().enumerate() {
rtn.coefficients[i] = match p2.coefficients.get(i) {
Some(c2) => c1 - c2,
None => *c1,
};
}
rtn
}
/// `lhs - rhs`, where the left-hand operand determines the size of the result.
///
/// The right-hand polynomial must not have more coefficients than the left-hand one;
/// otherwise the free function `sub` panics.
impl<const S1: usize, const S2: usize> ops::Sub<Polynomial<S2>> for Polynomial<S1> {
    type Output = Polynomial<S1>;
    fn sub(self, rhs: Polynomial<S2>) -> Self::Output {
        let (minuend, subtrahend) = (self, rhs);
        sub(minuend, subtrahend)
    }
}
/// A low-order polynomial whose fields list coefficients most significant first.
///
/// `Linear(a, b)` is `a*x + b` and `Quadratic(a, b, c)` is `a*x^2 + b*x + c`.
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub enum CommonPolynomial {
    /// A constant value.
    Constant(f64),
    /// `Linear(a, b)`: coefficient of `x`, then the constant term.
    Linear(f64, f64),
    /// `Quadratic(a, b, c)`: coefficients of `x^2`, of `x`, then the constant term.
    Quadratic(f64, f64, f64),
}
impl CommonPolynomial {
    /// Evaluates the polynomial at `x`.
    pub fn eval(&self, x: f64) -> f64 {
        // `Polynomial` stores coefficients least significant first, hence the reversed order.
        match *self {
            Self::Constant(a) => Polynomial::<1> { coefficients: [a] }.eval(x),
            Self::Linear(a, b) => Polynomial::<2> {
                coefficients: [b, a],
            }
            .eval(x),
            Self::Quadratic(a, b, c) => Polynomial::<3> {
                coefficients: [c, b, a],
            }
            .eval(x),
        }
    }

    /// Evaluates the first derivative of the polynomial at `x`.
    pub fn deriv(&self, x: f64) -> f64 {
        match *self {
            Self::Constant(a) => Polynomial::<1> { coefficients: [a] }.deriv(x),
            Self::Linear(a, b) => Polynomial::<2> {
                coefficients: [b, a],
            }
            .deriv(x),
            Self::Quadratic(a, b, c) => Polynomial::<3> {
                coefficients: [c, b, a],
            }
            .deriv(x),
        }
    }

    /// Returns the coefficient that multiplies `x^order`.
    ///
    /// For `Linear(a, b)` order 0 is `b` and order 1 is `a`; for `Quadratic(a, b, c)` order 0
    /// is `c`, order 1 is `b` and order 2 is `a`.
    ///
    /// # Errors
    /// Returns `NyxError::PolynomialOrderError` when `order` exceeds the degree of the variant.
    pub fn coeff_in_order(&self, order: usize) -> Result<f64, NyxError> {
        match *self {
            Self::Constant(a) => {
                if order == 0 {
                    Ok(a)
                } else {
                    Err(NyxError::PolynomialOrderError { order })
                }
            }
            Self::Linear(a, b) => match order {
                0 => Ok(b),
                1 => Ok(a),
                _ => Err(NyxError::PolynomialOrderError { order }),
            },
            Self::Quadratic(a, b, c) => match order {
                0 => Ok(c),
                1 => Ok(b),
                2 => Ok(a),
                _ => Err(NyxError::PolynomialOrderError { order }),
            },
        }
    }

    /// Returns a copy with the coefficient of `x^order` replaced by `new_val`.
    ///
    /// Uses the same order convention as [`CommonPolynomial::coeff_in_order`], so order 0 is
    /// always the constant term (the last tuple field).
    ///
    /// # Errors
    /// Returns `NyxError::PolynomialOrderError` when `order` exceeds the degree of the variant.
    pub fn with_val_in_order(self, new_val: f64, order: usize) -> Result<Self, NyxError> {
        match self {
            Self::Constant(_) => match order {
                0 => Ok(Self::Constant(new_val)),
                _ => Err(NyxError::PolynomialOrderError { order }),
            },
            Self::Linear(a, b) => match order {
                0 => Ok(Self::Linear(a, new_val)),
                1 => Ok(Self::Linear(new_val, b)),
                _ => Err(NyxError::PolynomialOrderError { order }),
            },
            Self::Quadratic(a, b, c) => match order {
                0 => Ok(Self::Quadratic(a, b, new_val)),
                1 => Ok(Self::Quadratic(a, new_val, c)),
                2 => Ok(Self::Quadratic(new_val, b, c)),
                _ => Err(NyxError::PolynomialOrderError { order }),
            },
        }
    }

    /// Returns a copy with `new_val` added to the coefficient of `x^order`.
    ///
    /// Uses the same order convention as [`CommonPolynomial::coeff_in_order`].
    ///
    /// # Errors
    /// Returns `NyxError::PolynomialOrderError` when `order` exceeds the degree of the variant.
    pub fn add_val_in_order(self, new_val: f64, order: usize) -> Result<Self, NyxError> {
        // Reusing the getter/setter pair keeps all three accessors on one order convention.
        let current = self.coeff_in_order(order)?;
        self.with_val_in_order(new_val + current, order)
    }
}
impl fmt::Display for CommonPolynomial {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Self::Constant(a) => write!(f, "{}", Polynomial::<1> { coefficients: [a] }),
Self::Linear(a, b) => write!(
f,
"{}",
Polynomial::<2> {
coefficients: [b, a],
}
),
Self::Quadratic(a, b, c) => write!(
f,
"{}",
Polynomial::<3> {
coefficients: [c, b, a],
}
),
}
}
}
#[cfg(test)]
mod ut_poly {
    // Unit tests for `Polynomial` and `CommonPolynomial`.
    use crate::polyfit::{CommonPolynomial, Polynomial};

    #[test]
    fn poly_constant() {
        let c = CommonPolynomial::Constant(10.0);
        for i in -100..=100 {
            let x = i as f64;
            assert!(
                (c.eval(x) - 10.0).abs() < f64::EPSILON,
                "Constant polynomial returned wrong value"
            );
            assert!(
                c.deriv(x).abs() < f64::EPSILON,
                "Constant polynomial derivative should be zero"
            );
        }
    }

    #[test]
    fn poly_linear() {
        let c = CommonPolynomial::Linear(2.0, 10.0);
        for i in -100..=100 {
            let x = i as f64;
            let expect = 2.0 * x + 10.0;
            assert!(
                (c.eval(x) - expect).abs() < f64::EPSILON,
                "Linear polynomial returned wrong value"
            );
            assert!(
                (c.deriv(x) - 2.0).abs() < f64::EPSILON,
                "Linear polynomial derivative returned wrong value"
            );
        }
    }

    #[test]
    fn poly_quadratic() {
        let p = Polynomial {
            coefficients: [101.0, -2.0, 3.0],
        };
        let p2 = 2.0 * p;
        let c = CommonPolynomial::Quadratic(3.0, -2.0, 101.0);
        for i in -100..=100 {
            let x = i as f64;
            let expect = 3.0 * x.powi(2) - 2.0 * x + 101.0;
            let expect_deriv = 6.0 * x - 2.0;
            assert!(
                (c.eval(x) - expect).abs() < f64::EPSILON,
                "Quadratic CommonPolynomial returned wrong value"
            );
            assert!(
                (c.deriv(x) - expect_deriv).abs() < f64::EPSILON,
                "Quadratic CommonPolynomial derivative returned wrong value"
            );
            assert!(
                (p.deriv(x) - expect_deriv).abs() < f64::EPSILON,
                "Polynomial derivative returned wrong value"
            );
            assert!(
                (p.eval(x) - expect).abs() < f64::EPSILON,
                "Polynomial returned wrong value"
            );
            assert!(
                (p2.eval(x) - 2.0 * expect).abs() < f64::EPSILON,
                "Scaled polynomial returned wrong value"
            );
        }
    }

    #[test]
    fn poly_cubic_eval_n_deriv() {
        // 2x^3 - x + 5, given most significant first.
        let p = Polynomial::from_most_significant([2.0, 0.0, -1.0, 5.0]);
        for i in -100..=100 {
            let x = i as f64;
            let (val, der) = p.eval_n_deriv(x);
            assert!(
                (val - (2.0 * x.powi(3) - x + 5.0)).abs() < f64::EPSILON,
                "Cubic polynomial returned wrong value"
            );
            assert!(
                (der - (6.0 * x.powi(2) - 1.0)).abs() < f64::EPSILON,
                "Cubic polynomial derivative returned wrong value"
            );
        }
    }

    #[test]
    fn poly_coeff_in_order() {
        let c = CommonPolynomial::Quadratic(3.0, -2.0, 101.0);
        assert_eq!(c.coeff_in_order(0).unwrap(), 101.0);
        assert_eq!(c.coeff_in_order(1).unwrap(), -2.0);
        assert_eq!(c.coeff_in_order(2).unwrap(), 3.0);
        assert!(c.coeff_in_order(3).is_err());
    }

    #[test]
    fn poly_print() {
        let p = Polynomial {
            coefficients: [101.0, -2.0, 3.0],
        };
        println!("{}", p);
        assert_eq!(
            format!("{}", p),
            format!("{}", CommonPolynomial::Quadratic(3.0, -2.0, 101.0))
        );
    }

    #[test]
    fn poly_add() {
        let p1 = Polynomial {
            coefficients: [4.0, -2.0, 3.0],
        };
        let p2 = Polynomial {
            coefficients: [0.0, -5.0, 0.0, 2.0],
        };
        let p_expected = Polynomial {
            coefficients: [4.0, -7.0, 3.0, 2.0],
        };
        let p3 = p1 + p2;
        println!("p3 = {:x}\npe = {:x}", p3, p_expected);
        assert_eq!(p3, p_expected);
        for i in -100..=100 {
            let x = i as f64;
            let expect = p1.eval(x) + p2.eval(x);
            assert!(
                (p3.eval(x) - expect).abs() < f64::EPSILON,
                "Sum of polynomials returned wrong value"
            );
        }
    }

    #[test]
    fn poly_sub() {
        let p2 = Polynomial {
            coefficients: [4.0, -2.0, 3.0],
        };
        let p1 = Polynomial {
            coefficients: [0.0, -5.0, 0.0, 2.0],
        };
        let p_expected = Polynomial {
            coefficients: [-4.0, -3.0, -3.0, 2.0],
        };
        let p3 = p1 - p2;
        println!("p3 = {:x}\npe = {:x}", p3, p_expected);
        assert_eq!(p3, p_expected);
        for i in -100..=100 {
            let x = i as f64;
            let expect = p1.eval(x) - p2.eval(x);
            assert!(
                (p3.eval(x) - expect).abs() < f64::EPSILON,
                "Difference of polynomials returned wrong value"
            );
        }
    }

    #[test]
    fn poly_serde() {
        let c = CommonPolynomial::Quadratic(3.0, -2.0, 101.0);
        let c_yml = serde_yml::to_string(&c).unwrap();
        println!("{c_yml}");
        let c2: CommonPolynomial = serde_yml::from_str(&c_yml).unwrap();
        assert_eq!(c, c2);
    }
}