use either::Either;
use nalgebra::{Matrix2, Matrix6, Vector2, Vector6};
use serde::{Deserialize, Serialize};
/// Serde-facing description of a 2x2 matrix of `f64`.
///
/// The value is stored as an `Either` of two shapes: a `Diag2` (two numbers)
/// or a `Mat2` (two rows of two numbers). Because the struct is
/// `#[serde(transparent)]`, it serializes exactly like its single field, so a
/// document such as `[1.0, 2.0]` loads as the `Diag2` shape and a list of two
/// two-element rows loads as the `Mat2` shape.
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Matrix2Serde {
// Handled by `either::serde_untagged`, so the serialized data carries no
// marker naming which of the two shapes was used.
#[serde(with = "either::serde_untagged")]
inner: Either<Diag2, Mat2>,
}
impl Matrix2Serde {
    /// Builds the dense 2x2 `nalgebra` matrix described by this value.
    ///
    /// * A `Diag2` places its two entries on the main diagonal via
    ///   `Matrix2::from_diagonal`.
    /// * A `Mat2` is read row by row: each inner array is one row of the
    ///   resulting matrix.
    pub fn to_matrix(&self) -> Matrix2<f64> {
        match self.inner {
            Either::Left(diag) => Matrix2::from_diagonal(&Vector2::from_iterator(diag.0)),
            Either::Right(mat2) => {
                // Flatten the rows in row-major order. Walking the rows with an
                // iterator avoids a hand-written stride: the previous `4 * i + j`
                // indexing overran this 4-element buffer on the second row.
                let mut flat = [0.0_f64; 4];
                for (slot, val) in flat.iter_mut().zip(mat2.0.iter().flatten()) {
                    *slot = *val;
                }
                Matrix2::from_row_slice(&flat)
            }
        }
    }
}
/// Two `f64` values, used by `Matrix2Serde` as the diagonal-only form of a
/// 2x2 matrix.
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Diag2([f64; 2]);
/// Two arrays of two `f64` values each, used by `Matrix2Serde` as the full
/// form of a 2x2 matrix (one inner array per row).
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Mat2([[f64; 2]; 2]);
/// Serde-facing description of a 6x6 matrix of `f64`.
///
/// The value is stored as an `Either` of two shapes: a `Diag6` (six numbers)
/// or a `Mat6` (six rows of six numbers). Because the struct is
/// `#[serde(transparent)]`, it serializes exactly like its single field, so a
/// six-element list loads as the `Diag6` shape and a list of six six-element
/// rows loads as the `Mat6` shape.
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Matrix6Serde {
// Handled by `either::serde_untagged`, so the serialized data carries no
// marker naming which of the two shapes was used.
#[serde(with = "either::serde_untagged")]
inner: Either<Diag6, Mat6>,
}
impl Matrix6Serde {
    /// Builds the dense 6x6 `nalgebra` matrix described by this value.
    ///
    /// * A `Diag6` places its six entries on the main diagonal via
    ///   `Matrix6::from_diagonal`.
    /// * A `Mat6` is read row by row: each inner array is one row of the
    ///   resulting matrix.
    pub fn to_matrix(&self) -> Matrix6<f64> {
        match self.inner {
            Either::Left(diag) => Matrix6::from_diagonal(&Vector6::from_iterator(diag.0)),
            Either::Right(mat6) => {
                // Lay the rows out back to back (row-major) so the buffer can
                // be handed to `from_row_slice`.
                let mut row_major = [0.0_f64; 36];
                for (dst, src_row) in row_major.chunks_exact_mut(6).zip(mat6.0.iter()) {
                    dst.copy_from_slice(src_row);
                }
                Matrix6::from_row_slice(&row_major)
            }
        }
    }
}
/// Six `f64` values, used by `Matrix6Serde` as the diagonal-only form of a
/// 6x6 matrix.
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Diag6([f64; 6]);
/// Six arrays of six `f64` values each, used by `Matrix6Serde` as the full
/// form of a 6x6 matrix (one inner array per row).
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Mat6([[f64; 6]; 6]);
/// Checks that both YAML shapes of `Matrix2Serde` deserialize to the expected
/// values: a flat two-element list and a list of two rows.
#[test]
fn test_serde2() {
    use serde_yaml;

    // Diagonal form: a flat list of two numbers.
    let expected_diag = Matrix2Serde {
        inner: Either::Left(Diag2([1.0, 2.0])),
    };
    let diag_dump = serde_yaml::to_string(&expected_diag).unwrap();
    println!("Diag -- \n{}", diag_dump);
    let parsed_diag: Matrix2Serde = serde_yaml::from_str("[1.0, 2.0]").unwrap();
    assert_eq!(parsed_diag, expected_diag);

    // Full form: one YAML sequence entry per row.
    let expected_full = Matrix2Serde {
        inner: Either::Right(Mat2([[1.0, 2.0]; 2])),
    };
    let full_dump = serde_yaml::to_string(&expected_full).unwrap();
    println!("Full -- \n{}", full_dump);
    let full_yaml = r#"
- [1.0, 2.0] # Row 1
- [1.0, 2.0] # Row 2
"#;
    let parsed_full: Matrix2Serde = serde_yaml::from_str(full_yaml).unwrap();
    assert_eq!(parsed_full, expected_full);
}
/// Checks that both YAML shapes of `Matrix6Serde` deserialize to the expected
/// values: a flat six-element list and a list of six rows.
#[test]
fn test_serde6() {
    use serde_yaml;

    // Diagonal form: a flat list of six numbers.
    let expected_diag = Matrix6Serde {
        inner: Either::Left(Diag6([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])),
    };
    let diag_dump = serde_yaml::to_string(&expected_diag).unwrap();
    println!("Diag -- \n{}", diag_dump);
    let parsed_diag: Matrix6Serde =
        serde_yaml::from_str("[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]").unwrap();
    assert_eq!(parsed_diag, expected_diag);

    // Full form: one YAML sequence entry per row.
    let expected_full = Matrix6Serde {
        inner: Either::Right(Mat6([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; 6])),
    };
    let full_dump = serde_yaml::to_string(&expected_full).unwrap();
    println!("Full -- \n{}", full_dump);
    let full_yaml = r#"
- [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] # Row 1
- [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] # Row 2
- [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] # Row 3
- [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] # Row 4
- [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] # Row 5
- [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] # Row 6
"#;
    let parsed_full: Matrix6Serde = serde_yaml::from_str(full_yaml).unwrap();
    assert_eq!(parsed_full, expected_full);
}