1use super::{MvnSpacecraft, StateDispersion};
20use crate::Spacecraft;
21use crate::md::StateParameter;
22use nalgebra::{SMatrix, SVector};
23use pyo3::exceptions::PyValueError;
24use pyo3::prelude::*;
25use pyo3::types::PyType;
26use rand::SeedableRng;
27use rand_distr::Distribution;
28use rand_pcg::Pcg64Mcg;
29
30#[pymethods]
31impl MvnSpacecraft {
32 #[new]
33 fn py_new(template: Spacecraft, dispersions: Vec<StateDispersion>) -> PyResult<Self> {
34 MvnSpacecraft::new(template, dispersions).map_err(|e| PyValueError::new_err(e.to_string()))
35 }
36
37 #[classmethod]
38 #[pyo3(name = "from_spacecraft_cov")]
39 fn py_from_spacecraft_cov(
40 _cls: &Bound<'_, PyType>,
41 template: Spacecraft,
42 cov: Vec<Vec<f64>>,
43 mean: Vec<f64>,
44 ) -> PyResult<Self> {
45 if cov.len() != 9 || cov.iter().any(|row| row.len() != 9) {
46 return Err(PyValueError::new_err(
47 "Covariance matrix must be 9x9 (rows and columns)",
48 ));
49 }
50 if mean.len() != 9 {
51 return Err(PyValueError::new_err("Mean vector must be length 9"));
52 }
53
54 let cov_mat = SMatrix::<f64, 9, 9>::from_fn(|r, c| cov[r][c]);
55 let mean_vec = SVector::<f64, 9>::from_vec(mean);
56
57 MvnSpacecraft::from_spacecraft_cov(template, cov_mat, mean_vec)
58 .map_err(|e| PyValueError::new_err(e.to_string()))
59 }
60
61 #[classmethod]
62 #[pyo3(name = "zero_mean")]
63 fn py_zero_mean(
64 _cls: &Bound<'_, PyType>,
65 template: Spacecraft,
66 dispersions: Vec<StateDispersion>,
67 ) -> PyResult<Self> {
68 MvnSpacecraft::zero_mean(template, dispersions)
69 .map_err(|e| PyValueError::new_err(e.to_string()))
70 }
71
72 #[pyo3(signature = (count, seed=None))]
81 fn sample(&self, count: usize, seed: Option<u64>) -> Vec<Spacecraft> {
82 let mut rng = match seed {
83 Some(s) => Pcg64Mcg::seed_from_u64(s),
84 None => Pcg64Mcg::from_rng(&mut rand::rng()),
85 };
86
87 let mut samples = Vec::with_capacity(count.min(100_000));
88 for _ in 0..count {
89 let dispersed_state = Distribution::sample(self, &mut rng);
90 samples.push(dispersed_state.state);
91 }
92 samples
93 }
94}
95
96#[pymethods]
97impl StateDispersion {
98 #[new]
99 fn py_new(param: StateParameter, std_dev: Option<f64>, mean: Option<f64>) -> PyResult<Self> {
100 let builder = StateDispersion::builder().param(param);
101 Ok(match (std_dev, mean) {
102 (Some(s), Some(m)) => builder.std_dev(s).mean(m).build(),
103 (Some(s), None) => builder.std_dev(s).build(),
104 (None, Some(m)) => builder.mean(m).build(),
105 (None, None) => builder.build(),
106 })
107 }
108
109 #[classmethod]
110 #[pyo3(name = "zero_mean")]
111 fn py_zero_mean(
112 _cls: &Bound<'_, PyType>,
113 param: StateParameter,
114 std_dev: f64,
115 ) -> PyResult<Self> {
116 Ok(StateDispersion::zero_mean(param, std_dev))
117 }
118}