1#![allow(clippy::type_complexity)] #![allow(unused_imports)] use crate::linalg::allocator::Allocator;
23use crate::linalg::{Const, DefaultAllocator, DimName, OMatrix, OVector, U1}; use crate::md::trajectory::{Interpolatable, Traj}; pub use crate::od::estimate::*;
26pub use crate::od::ground_station::*;
27pub use crate::od::snc::*; pub use crate::od::*;
29use crate::propagators::Propagator;
30pub use crate::time::{Duration, Epoch, Unit};
31use anise::prelude::Almanac;
32use indexmap::IndexSet;
33use log::{debug, info, trace, warn};
34use msr::sensitivity::TrackerSensitivity; use nalgebra::{Cholesky, Dyn, Matrix, VecStorage};
36use snafu::prelude::*;
37use solution::msr::MeasurementType;
38use std::collections::BTreeMap;
39use std::marker::PhantomData;
40use std::ops::Add;
41use std::sync::Arc;
42use typed_builder::TypedBuilder;
43
44mod solution;
45
46pub use solution::BLSSolution;
47
48use self::msr::TrackingDataArc;
49
/// Linear solver used by the Batch Least Squares estimator to compute each state correction.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BLSSolver {
    /// Solve the accumulated information matrix against the weighted residual vector
    /// directly, by Cholesky factorization.
    NormalEquations,
    /// Levenberg-Marquardt: damp the information matrix with `lambda * D^2` before solving,
    /// and adapt `lambda` depending on whether the weighted RMS decreases.
    LevenbergMarquardt,
}
58
/// Batch Least Squares (BLS) estimator: iteratively corrects an initial state guess so that it
/// best fits a whole arc of tracking measurements.
#[derive(Clone, TypedBuilder)]
#[builder(doc)]
pub struct BatchLeastSquares<
    D: Dynamics,
    Trk: TrackerSensitivity<D::StateType, D::StateType>, > where
    D::StateType:
        Interpolatable + Add<OVector<f64, <D::StateType as State>::Size>, Output = D::StateType>,
    <D::StateType as State>::Size: DimName,
    <DefaultAllocator as Allocator<<D::StateType as State>::VecLength>>::Buffer<f64>: Send,
    <DefaultAllocator as Allocator<<D::StateType as State>::Size>>::Buffer<f64>: Copy,
    <DefaultAllocator as Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>>::Buffer<f64>: Copy,
    DefaultAllocator: Allocator<<D::StateType as State>::Size>
        + Allocator<<D::StateType as State>::VecLength>
        + Allocator<U1>
        + Allocator<U1, <D::StateType as State>::Size>
        + Allocator<<D::StateType as State>::Size, U1>
        + Allocator<U1, U1>
        + Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>,
{
    /// Propagator used to carry each candidate estimate, together with its state transition
    /// matrix, from the estimate epoch to every measurement epoch.
    pub prop: Propagator<D>,
    /// Tracking devices keyed by tracker name. A measurement whose tracker is not in this map
    /// is skipped.
    pub devices: BTreeMap<String, Trk>,
    /// Solver used to compute each state correction (default: normal equations).
    #[builder(default = BLSSolver::NormalEquations)]
    pub solver: BLSSolver,
    /// Convergence threshold on the norm of the position part (first three components) of a
    /// state correction, in km (default: 1e-4 km).
    #[builder(default = 1e-4)]
    pub tolerance_pos_km: f64,
    /// Maximum number of iterations before giving up (default: 10).
    #[builder(default = 10)]
    pub max_iterations: usize,
    /// Upper bound on each propagation step taken between measurements (default: 30 s).
    #[builder(default_code = "30 * Unit::Second")]
    pub max_step: Duration,
    /// Tolerance used to decide that the propagated epoch coincides with a measurement epoch
    /// (default: 1 µs).
    #[builder(default_code = "1 * Unit::Microsecond")]
    pub epoch_precision: Duration,
    /// Initial Levenberg-Marquardt damping factor `lambda` (default: 10).
    #[builder(default = 10.0)]
    pub lm_lambda_init: f64,
    /// Factor by which `lambda` is divided after the weighted RMS decreases (default: 10).
    #[builder(default = 10.0)] pub lm_lambda_decrease: f64,
    /// Factor by which `lambda` is multiplied after the weighted RMS increases or stalls
    /// (default: 10).
    #[builder(default = 10.0)] pub lm_lambda_increase: f64,
    /// Lower clamp applied to `lambda` (default: 1e-12).
    #[builder(default = 1e-12)]
    pub lm_lambda_min: f64,
    /// Upper clamp applied to `lambda` (default: 1e12).
    #[builder(default = 1e12)]
    pub lm_lambda_max: f64,
    /// When true, the LM damping matrix `D^2` is the diagonal of the information matrix, with
    /// non-positive entries floored. Otherwise it is the identity (default: true).
    #[builder(default = true)]
    pub lm_use_diag_scaling: bool,
    /// Almanac handed to the propagator and to the tracking devices.
    pub almanac: Arc<Almanac>,
}
121
122#[allow(type_alias_bounds)]
123type StateMatrix<D: Dynamics> =
124 OMatrix<f64, <D::StateType as State>::Size, <D::StateType as State>::Size>;
125
126impl<D, Trk> BatchLeastSquares<D, Trk>
127where
128 D: Dynamics,
129 Trk: TrackerSensitivity<D::StateType, D::StateType> + Clone, D::StateType: Interpolatable
131 + Add<OVector<f64, <D::StateType as State>::Size>, Output = D::StateType>
132 + std::fmt::Debug, <D::StateType as State>::Size: DimName,
134 <DefaultAllocator as Allocator<<D::StateType as State>::VecLength>>::Buffer<f64>: Send,
135 <DefaultAllocator as Allocator<<D::StateType as State>::Size>>::Buffer<f64>: Copy,
136 <DefaultAllocator as Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>>::Buffer<f64>: Copy,
137 DefaultAllocator: Allocator<<D::StateType as State>::Size>
138 + Allocator<<D::StateType as State>::VecLength>
139 + Allocator<U1>
140 + Allocator<U1, <D::StateType as State>::Size>
141 + Allocator<<D::StateType as State>::Size, U1>
142 + Allocator<U1, U1>
143 + Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>,
144{
145 pub fn estimate(
147 &self,
148 initial_guess: D::StateType,
149 arc: &TrackingDataArc,
150 ) -> Result<BLSSolution<D::StateType>, ODError> {
151 let measurements = &arc.measurements;
152 let num_measurements = measurements.len();
153 let mut devices = self.devices.clone();
154
155 ensure!(
156 num_measurements >= 2,
157 TooFewMeasurementsSnafu {
158 need: 2_usize,
159 action: "BLSE"
160 }
161 );
162
163 info!(
164 "Using {:?} in the Batch Least Squares estimation with {num_measurements} measurements",
165 self.solver
166 );
167 info!("Initial guess: {}", initial_guess.orbit());
168
169 let mut current_estimate = initial_guess;
170 let mut current_covariance = StateMatrix::<D>::zeros();
171 let mut converged = false;
172 let mut corr_pos_km = f64::MAX;
173 let mut lambda = self.lm_lambda_init;
174 let mut current_rms = f64::MAX;
175 let mut iter: usize = 0;
176
177 while iter < self.max_iterations {
179 iter += 1;
180 info!("[{iter}/{}] Current estimate: {}", self.max_iterations, current_estimate.orbit());
181
182 let mut info_matrix = StateMatrix::<D>::identity();
185 let mut normal_matrix = OVector::<f64, <D::StateType as State>::Size>::zeros();
187 let mut sum_sq_weighted_residuals = 0.0;
189
190 let mut prop_inst = self.prop.with(current_estimate.with_stm(), self.almanac.clone()).quiet();
192 let mut epoch = current_estimate.epoch();
193
194 let mut stm = StateMatrix::<D>::identity();
196
197 for (epoch_ref, msr) in measurements.iter() {
198 let msr_epoch = *epoch_ref;
199
200 loop {
201 let delta_t = msr_epoch - epoch;
202 if delta_t <= Duration::ZERO {
203 break;
205 }
206
207 let next_step = delta_t.min(prop_inst.step_size).min(self.max_step);
209
210 let this_state = prop_inst.for_duration(next_step).context(ODPropSnafu)?;
212 epoch = this_state.epoch();
213
214 let step_stm = this_state.stm().expect("STM unavailable");
216 stm = step_stm * stm;
218
219 if (epoch - msr_epoch).abs() < self.epoch_precision {
220 let device = match devices.get_mut(&msr.tracker) {
222 Some(d) => d,
223 None => {
224 error!(
225 "Tracker {} is not in the list of configured devices",
226 msr.tracker
227 );
228 continue;
229 }
230 };
231
232 for msr_type in msr.data.keys().copied() {
233 let mut msr_types = IndexSet::new();
234 msr_types.insert(msr_type);
235
236 let h_tilde = device
237 .h_tilde::<U1>(msr, &msr_types, &this_state, self.almanac.clone())?;
238
239 let computed_meas_opt = device
241 .measure_instantaneous(this_state, None, self.almanac.clone())?;
242
243 let computed_meas = match computed_meas_opt {
244 Some(cm) => cm,
245 None => {
246 debug!("Device {} does not expect measurement at epoch {msr_epoch}, skipping", msr.tracker);
247 continue;
248 }
249 };
250
251 let computed_obs = computed_meas.observation::<U1>(&msr_types)[0];
253
254 let real_obs = msr.observation::<U1>(&msr_types)[0];
256
257 ensure!(
259 real_obs.is_finite(),
260 InvalidMeasurementSnafu {
261 epoch: msr_epoch,
262 val: real_obs
263 }
264 );
265
266 let residual = real_obs - computed_obs;
268
269 let r_matrix = device
271 .measurement_covar_matrix::<U1>(&msr_types, msr_epoch)?;
272 let r_variance = r_matrix[(0, 0)];
273
274 ensure!(r_variance > 0.0, SingularNoiseRkSnafu);
275 let weight = 1.0 / r_variance;
276
277 let h_matrix = h_tilde * stm;
279
280 info_matrix += h_matrix.transpose() * &h_matrix * weight;
283
284 normal_matrix += h_matrix.transpose() * residual * weight;
286
287 sum_sq_weighted_residuals += weight * residual * residual;
289 }
290 }
291 }
292 }
293
294 let state_correction: OVector<f64, <D::StateType as State>::Size>;
296 let iteration_cost_decreased; let current_iter_rms = (sum_sq_weighted_residuals / num_measurements as f64).sqrt();
300
301 match self.solver {
302 BLSSolver::NormalEquations => {
303 let info_matrix_chol = match info_matrix.cholesky() {
305 Some(chol) => chol,
306 None => return Err(ODError::SingularInformationMatrix)
307 };
308 state_correction = info_matrix_chol.solve(&normal_matrix);
309 iteration_cost_decreased = true;
311 current_rms = current_iter_rms;
312 }
313 BLSSolver::LevenbergMarquardt => {
314 let mut d_sq = StateMatrix::<D>::identity();
318 if self.lm_use_diag_scaling {
319 for i in 0..6 {
321 d_sq[(i, i)] = info_matrix.diagonal()[i];
322 }
323 for i in 0..6 {
325 if d_sq[(i, i)] <= 0.0 {
326 d_sq[(i, i)] = 1e-6; warn!("LM Scaling: Found non-positive diagonal element {} in H^TWH, using floor.", info_matrix[(i, i)]);
328 }
329 }
330 } let augmented_matrix = info_matrix + d_sq * lambda;
334
335 if let Some(aug_chol) = augmented_matrix.cholesky() {
336 state_correction = aug_chol.solve(&normal_matrix);
337
338 if current_iter_rms < current_rms || iter == 0 {
342 iteration_cost_decreased = true;
344 lambda /= self.lm_lambda_decrease;
346 lambda = lambda.max(self.lm_lambda_min);
348 debug!("LM: Cost decreased (RMS {current_rms} -> {current_iter_rms}). Decreasing lambda to {lambda}");
349 current_rms = current_iter_rms;
350 } else {
351 iteration_cost_decreased = false;
353 lambda *= self.lm_lambda_increase;
355 lambda = lambda.min(self.lm_lambda_max);
357 debug!("LM: Cost increased/stalled (RMS {current_rms} -> {current_iter_rms}). Increasing lambda to {lambda}");
358 }
360
361 } else {
362 warn!("LM: Augmented matrix (H^TWH + lambda*D^2) singular with lambda={lambda}. Increasing lambda.");
364 lambda *= self.lm_lambda_increase * 10.0; lambda = lambda.min(self.lm_lambda_max);
366 continue;
369 }
370 }
371 }
372
373 if iteration_cost_decreased {
377 current_estimate = current_estimate + state_correction;
378 corr_pos_km = state_correction.fixed_rows::<3>(0).norm();
379
380 let corr_vel_km_s = state_correction.fixed_rows::<3>(3).norm();
381 info!(
382 "[{iter}/{}] RMS: {current_iter_rms:.3}; corrections: {:.3} m\t{:.3} m/s",
383 self.max_iterations,
384 corr_pos_km * 1e3,
385 corr_vel_km_s * 1e3
386 );
387
388 current_covariance = match info_matrix.try_inverse() {
390 Some(cov) => cov,
391 None => {
392 warn!("Information matrix H^TWH is singular.");
393 StateMatrix::<D>::identity()
394 }
395 };
396
397 if corr_pos_km < self.tolerance_pos_km {
399 info!("Converged in {iter} iterations.");
400 converged = true;
401 break;
402 }
403 } else if self.solver == BLSSolver::LevenbergMarquardt {
404 info!("[{iter}/{}] LM: Step rejected, increasing lambda.", self.max_iterations);
406 corr_pos_km = f64::MAX;
408 }
410 }
411
412 if !converged {
413 warn!("Not converged after {} iterations.", self.max_iterations);
414 }
415
416 info!("Batch Least Squares estimation completed.");
417 Ok(BLSSolution {
418 estimated_state: current_estimate,
419 covariance: current_covariance,
420 num_iterations: iter,
421 final_rms: current_rms,
422 final_corr_pos_km: corr_pos_km,
423 converged,
424 })
425 }
426}