1#![allow(clippy::type_complexity)] #![allow(unused_imports)] use crate::linalg::allocator::Allocator;
22use crate::linalg::{Const, DefaultAllocator, DimName, OMatrix, OVector, U1}; use crate::md::trajectory::{Interpolatable, Traj}; pub use crate::od::estimate::*;
25pub use crate::od::ground_station::*;
26pub use crate::od::snc::*; pub use crate::od::*;
28use crate::propagators::Propagator;
29pub use crate::time::{Duration, Epoch, Unit};
30use anise::prelude::Almanac;
31use indexmap::IndexSet;
32use log::error;
33use log::{debug, info, trace, warn};
34use msr::sensitivity::TrackerSensitivity; use nalgebra::{Cholesky, Dyn, Matrix, VecStorage};
36use snafu::prelude::*;
37use solution::msr::MeasurementType;
38use std::collections::BTreeMap;
39use std::marker::PhantomData;
40use std::ops::Add;
41use std::sync::Arc;
42use typed_builder::TypedBuilder;
43
44mod solution;
45
46pub use solution::BLSSolution;
47
48use self::msr::TrackingDataArc;
49
/// Selects how the state correction is solved for at each batch least squares iteration.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BLSSolver {
    /// Solve the accumulated normal equations directly through a Cholesky factorization.
    NormalEquations,
    /// Solve a damped system in which `lambda` times a scaling matrix is added to the
    /// information matrix before factorization, with `lambda` adapted between iterations.
    LevenbergMarquardt,
}
58
/// Configuration of a batch least squares estimator, built through the generated `TypedBuilder`.
///
/// `D` is the dynamics model whose state is estimated and `Trk` is the tracking device type
/// providing the measurement sensitivities.
#[derive(Clone, TypedBuilder)]
#[builder(doc)]
pub struct BatchLeastSquares<
    D: Dynamics,
    Trk: TrackerSensitivity<D::StateType, D::StateType>,
> where
    D::StateType:
        Interpolatable + Add<OVector<f64, <D::StateType as State>::Size>, Output = D::StateType>,
    <D::StateType as State>::Size: DimName,
    <DefaultAllocator as Allocator<<D::StateType as State>::VecLength>>::Buffer<f64>: Send,
    <DefaultAllocator as Allocator<<D::StateType as State>::Size>>::Buffer<f64>: Copy,
    <DefaultAllocator as Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>>::Buffer<f64>: Copy,
    DefaultAllocator: Allocator<<D::StateType as State>::Size>
        + Allocator<<D::StateType as State>::VecLength>
        + Allocator<U1>
        + Allocator<U1, <D::StateType as State>::Size>
        + Allocator<<D::StateType as State>::Size, U1>
        + Allocator<U1, U1>
        + Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>,
{
    /// Propagator used to carry the current estimate from one measurement epoch to the next.
    pub prop: Propagator<D>,
    /// Tracking devices keyed by name; looked up with the `tracker` field of each measurement.
    pub devices: BTreeMap<String, Trk>,
    /// Solver used to compute the state correction (defaults to the normal equations).
    #[builder(default = BLSSolver::NormalEquations)]
    pub solver: BLSSolver,
    /// Convergence threshold on the norm of the first three components of the state
    /// correction (default: 1e-4).
    #[builder(default = 1e-4)]
    pub tolerance_pos_km: f64,
    /// Maximum number of iterations before giving up on convergence (default: 10).
    #[builder(default = 10)]
    pub max_iterations: usize,
    /// Upper bound on any single propagation step taken between measurements (default: 30 s).
    #[builder(default_code = "30 * Unit::Second")]
    pub max_step: Duration,
    /// Largest difference between the propagated epoch and a measurement epoch for the
    /// measurement to be processed (default: 1 microsecond).
    #[builder(default_code = "1 * Unit::Microsecond")]
    pub epoch_precision: Duration,
    /// Initial Levenberg-Marquardt damping factor `lambda` (default: 10.0).
    #[builder(default = 10.0)]
    pub lm_lambda_init: f64,
    /// Divisor applied to `lambda` when an LM iteration lowers the RMS (default: 10.0).
    #[builder(default = 10.0)]
    pub lm_lambda_decrease: f64,
    /// Multiplier applied to `lambda` when an LM iteration does not lower the RMS
    /// (default: 10.0). Applied with an extra factor of 10 when the damped matrix cannot be
    /// factorized.
    #[builder(default = 10.0)]
    pub lm_lambda_increase: f64,
    /// Lower clamp on `lambda` (default: 1e-12).
    #[builder(default = 1e-12)]
    pub lm_lambda_min: f64,
    /// Upper clamp on `lambda` (default: 1e12).
    #[builder(default = 1e12)]
    pub lm_lambda_max: f64,
    /// When true, the LM damping is scaled by diagonal entries of the information matrix
    /// instead of the identity (default: true).
    #[builder(default = true)]
    pub lm_use_diag_scaling: bool,
    /// Shared almanac handed to the propagator and to the tracking devices.
    pub almanac: Arc<Almanac>,
}
121
/// Square `f64` matrix whose dimension is the state size of the dynamics `D`.
// The lint is silenced because the `D: Dynamics` bound is what lets the `D::StateType`
// shorthand below resolve, even though bounds on type aliases are not otherwise enforced.
#[allow(type_alias_bounds)]
type StateMatrix<D: Dynamics> =
    OMatrix<f64, <D::StateType as State>::Size, <D::StateType as State>::Size>;
125
126impl<D, Trk> BatchLeastSquares<D, Trk>
127where
128 D: Dynamics,
129 Trk: TrackerSensitivity<D::StateType, D::StateType> + Clone, D::StateType: Interpolatable
131 + Add<OVector<f64, <D::StateType as State>::Size>, Output = D::StateType>
132 + std::fmt::Debug, <D::StateType as State>::Size: DimName,
134 <DefaultAllocator as Allocator<<D::StateType as State>::VecLength>>::Buffer<f64>: Send,
135 <DefaultAllocator as Allocator<<D::StateType as State>::Size>>::Buffer<f64>: Copy,
136 <DefaultAllocator as Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>>::Buffer<f64>: Copy,
137 DefaultAllocator: Allocator<<D::StateType as State>::Size>
138 + Allocator<<D::StateType as State>::VecLength>
139 + Allocator<U1>
140 + Allocator<U1, <D::StateType as State>::Size>
141 + Allocator<<D::StateType as State>::Size, U1>
142 + Allocator<U1, U1>
143 + Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>,
144{
/// Iteratively refines `initial_guess` with a batch least squares fit over the non-rejected
/// measurements of `arc`.
///
/// Each iteration propagates the current estimate (with its STM) through the measurement
/// epochs, accumulates the information matrix and right-hand side from one scalar observation
/// at a time, solves for a state correction with the configured [`BLSSolver`], and adds that
/// correction to the estimate. Iteration stops once the norm of the first three correction
/// components drops below `tolerance_pos_km`, or after `max_iterations`.
///
/// # Arguments
/// * `initial_guess` - state to start from; corrections are applied at its epoch.
/// * `arc` - tracking data; measurements flagged `rejected` are ignored.
///
/// # Returns
/// A [`BLSSolution`] with the final estimate, covariance, iteration count, last accepted RMS,
/// last position correction norm and convergence flag. A non-converged run still returns `Ok`.
///
/// # Errors
/// * `TooFewMeasurements` if fewer than two measurements are not rejected.
/// * A propagation error (wrapped with `ODPropSnafu`) if the propagator fails.
/// * `InvalidMeasurement` if a real observation is not finite.
/// * `SingularNoiseRk` if a measurement variance is not strictly positive.
/// * `SingularInformationMatrix` if the Cholesky factorization fails with `NormalEquations`.
/// * Any error returned by the device sensitivity, measurement or noise calls.
///
/// # Panics
/// Panics if a propagated state does not carry an STM. The fixed indices `0..6` and the
/// `fixed_rows::<3>` calls also assume the state has at least six components.
pub fn estimate(
    &self,
    initial_guess: D::StateType,
    arc: &TrackingDataArc,
) -> Result<BLSSolution<D::StateType>, ODError> {
    let measurements = &arc.measurements;
    let num_measurements = measurements.values().filter(|m| !m.rejected).count();
    // Devices are accessed through `get_mut` below, so work on a local clone.
    let mut devices = self.devices.clone();

    ensure!(
        num_measurements >= 2,
        TooFewMeasurementsSnafu {
            need: 2_usize,
            action: "BLSE"
        }
    );

    info!(
        "Using {:?} in the Batch Least Squares estimation with {num_measurements} measurements",
        self.solver
    );
    info!("Initial guess: {}", initial_guess.orbit());

    let mut current_estimate = initial_guess;
    let mut current_covariance = StateMatrix::<D>::zeros();
    let mut converged = false;
    let mut corr_pos_km = f64::MAX;
    // Levenberg-Marquardt damping factor, carried across iterations.
    let mut lambda = self.lm_lambda_init;
    // RMS of the last accepted iteration; starts at MAX so the first finite RMS is lower.
    let mut current_rms = f64::MAX;
    let mut iter: usize = 0;

    // Tracker names already reported as unknown, so each is logged only once per call.
    let mut unknown_trackers = IndexSet::new();

    while iter < self.max_iterations {
        // Incremented first, so `iter` is 1-based for the rest of the loop body.
        iter += 1;
        info!("[{iter}/{}] Current estimate: {}", self.max_iterations, current_estimate.orbit());

        // Accumulator for H^T W H.
        // NOTE(review): seeded with identity rather than zeros, which behaves like a unit
        // a-priori information term — confirm this is intended.
        let mut info_matrix = StateMatrix::<D>::identity();
        // Accumulator for H^T W y (a vector, despite the name).
        let mut normal_matrix = OVector::<f64, <D::StateType as State>::Size>::zeros();
        let mut sum_sq_weighted_residuals = 0.0;

        // Fresh propagator instance started from the current estimate with STM enabled.
        let mut prop_inst = self.prop.with(current_estimate.with_stm(), self.almanac.clone()).quiet();
        let mut epoch = current_estimate.epoch();

        // STM accumulated from the estimate epoch to the current propagation epoch.
        let mut stm = StateMatrix::<D>::identity();

        for (epoch_ref, msr) in measurements.iter() {
            if msr.rejected {
                continue;
            }
            let msr_epoch = *epoch_ref;

            // Step the propagator until the measurement epoch is reached.
            loop {
                let delta_t = msr_epoch - epoch;
                if delta_t <= Duration::ZERO {
                    // NOTE(review): a measurement at or before the current propagation epoch
                    // exits here without being processed, yet it is still counted in
                    // `num_measurements` — confirm this is acceptable.
                    break;
                }

                // Never step past the measurement, nor beyond the configured maximum step.
                let next_step = delta_t.min(prop_inst.step_size).min(self.max_step);

                let this_state = prop_inst.for_duration(next_step).context(ODPropSnafu)?;
                epoch = this_state.epoch();

                // Presumably the STM returned here covers only this step — TODO confirm;
                // the product below chains it onto the accumulated STM.
                let step_stm = this_state.stm().expect("STM unavailable");
                stm = step_stm * stm;

                if (epoch - msr_epoch).abs() < self.epoch_precision {
                    let device = match devices.get_mut(&msr.tracker) {
                        Some(d) => d,
                        None => {
                            if !unknown_trackers.contains(&msr.tracker) {
                                error!(
                                    "Tracker {} is not in the list of configured devices",
                                    msr.tracker
                                );
                            }
                            unknown_trackers.insert(msr.tracker.clone());
                            // Continues the propagation `loop`, not the measurement `for`.
                            continue;
                        }
                    };

                    // Process every measurement type of this epoch as its own scalar observation.
                    for msr_type in msr.data.keys().copied() {
                        let mut msr_types = IndexSet::new();
                        msr_types.insert(msr_type);

                        // Sensitivity of this observation with respect to the state at this epoch.
                        let h_tilde = device
                            .h_tilde::<U1>(msr, &msr_types, &this_state, self.almanac.clone())?;

                        let computed_meas_opt = device
                            .measure_instantaneous(this_state, None, self.almanac.clone())?;

                        let computed_meas = match computed_meas_opt {
                            Some(cm) => cm,
                            None => {
                                debug!("Device {} does not expect measurement at epoch {msr_epoch}, skipping", msr.tracker);
                                continue;
                            }
                        };

                        let computed_obs = computed_meas.observation::<U1>(&msr_types)[0];

                        let real_obs = msr.observation::<U1>(&msr_types)[0];

                        ensure!(
                            real_obs.is_finite(),
                            InvalidMeasurementSnafu {
                                epoch: msr_epoch,
                                val: real_obs
                            }
                        );

                        // Observed minus computed.
                        let residual = real_obs - computed_obs;

                        let r_matrix = device
                            .measurement_covar_matrix::<U1>(&msr_types, msr_epoch)?;
                        let r_variance = r_matrix[(0, 0)];

                        ensure!(r_variance > 0.0, SingularNoiseRkSnafu);
                        // Inverse-variance weighting.
                        let weight = 1.0 / r_variance;

                        // Map the sensitivity back to the estimate epoch through the STM.
                        let h_matrix = h_tilde * stm;

                        info_matrix += h_matrix.transpose() * &h_matrix * weight;

                        normal_matrix += h_matrix.transpose() * residual * weight;

                        sum_sq_weighted_residuals += weight * residual * residual;
                    }
                }
            }
        }

        let state_correction: OVector<f64, <D::StateType as State>::Size>;
        let iteration_cost_decreased;
        // Normalized by the count of non-rejected measurement entries, not by the number of
        // scalar observations accumulated above.
        let current_iter_rms = (sum_sq_weighted_residuals / num_measurements as f64).sqrt();

        match self.solver {
            BLSSolver::NormalEquations => {
                let info_matrix_chol = match info_matrix.cholesky() {
                    Some(chol) => chol,
                    None => return Err(ODError::SingularInformationMatrix)
                };
                state_correction = info_matrix_chol.solve(&normal_matrix);
                // The normal-equations step is always accepted.
                iteration_cost_decreased = true;
                current_rms = current_iter_rms;
            }
            BLSSolver::LevenbergMarquardt => {
                // Damping scale matrix: identity, optionally replaced on its first six
                // diagonal entries by the diagonal of the information matrix.
                let mut d_sq = StateMatrix::<D>::identity();
                if self.lm_use_diag_scaling {
                    // NOTE(review): the bound 6 is hard-coded; for larger states the remaining
                    // diagonal entries keep the identity value — confirm intended.
                    for i in 0..6 {
                        d_sq[(i, i)] = info_matrix.diagonal()[i];
                    }
                    // Floor non-positive scale entries so the damping term stays positive.
                    for i in 0..6 {
                        if d_sq[(i, i)] <= 0.0 {
                            d_sq[(i, i)] = 1e-6;
                            warn!("LM Scaling: Found non-positive diagonal element {} in H^TWH, using floor.", info_matrix[(i, i)]);
                        }
                    }
                }
                let augmented_matrix = info_matrix + d_sq * lambda;

                if let Some(aug_chol) = augmented_matrix.cholesky() {
                    state_correction = aug_chol.solve(&normal_matrix);

                    // `current_iter_rms` is the RMS at the current estimate, compared with the
                    // RMS of the last accepted iteration.
                    // NOTE(review): `iter` was already incremented above, so `iter == 0` can
                    // never be true here.
                    if current_iter_rms < current_rms || iter == 0 {
                        iteration_cost_decreased = true;
                        lambda /= self.lm_lambda_decrease;
                        lambda = lambda.max(self.lm_lambda_min);
                        debug!("LM: Cost decreased (RMS {current_rms} -> {current_iter_rms}). Decreasing lambda to {lambda}");
                        current_rms = current_iter_rms;
                    } else {
                        // NOTE(review): neither `current_estimate` nor `current_rms` changes on
                        // this path, so the next iteration appears to recompute the same RMS
                        // from the same state — confirm a rejected step can recover.
                        iteration_cost_decreased = false;
                        lambda *= self.lm_lambda_increase;
                        lambda = lambda.min(self.lm_lambda_max);
                        debug!("LM: Cost increased/stalled (RMS {current_rms} -> {current_iter_rms}). Increasing lambda to {lambda}");
                    }

                } else {
                    warn!("LM: Augmented matrix (H^TWH + lambda*D^2) singular with lambda={lambda}. Increasing lambda.");
                    // Stronger bump than a plain rejection, then restart the outer `while`
                    // (this still consumes one iteration).
                    lambda *= self.lm_lambda_increase * 10.0;
                    lambda = lambda.min(self.lm_lambda_max);
                    continue;
                }
            }
        }

        if iteration_cost_decreased {
            current_estimate = current_estimate + state_correction;
            // Rows 0..3 are treated as the position correction, rows 3..6 as velocity.
            corr_pos_km = state_correction.fixed_rows::<3>(0).norm();

            let corr_vel_km_s = state_correction.fixed_rows::<3>(3).norm();
            info!(
                "[{iter}/{}] RMS: {current_iter_rms:.3}; corrections: {:.3} m\t{:.3} m/s",
                self.max_iterations,
                corr_pos_km * 1e3,
                corr_vel_km_s * 1e3
            );

            // Covariance as the inverse of the (undamped) information matrix, computed
            // through its UDU factorization: (U D U^T)^-1 = U^-T D^-1 U^-1.
            // Falls back to identity, with a warning, when the factorization or the
            // inversion of U is unavailable.
            current_covariance = match info_matrix.udu() {
                Some(info_udu) => {
                    match info_udu.u.try_inverse() {
                        None =>{
                            warn!("Information matrix H^TWH is singular.");
                            StateMatrix::<D>::identity()
                        },
                        Some(u_inv) => {
                            let d_inv_v = OVector::<f64,<D::StateType as State>::Size>::from_iterator(info_udu.d.iter().map(|d_ii| 1.0 / d_ii));
                            let d_inv = OMatrix::from_diagonal(&d_inv_v);
                            let y = d_inv * u_inv;
                            u_inv.transpose() * y
                        }
                    }
                }
                None => {
                    warn!("Information matrix H^TWH is singular.");
                    StateMatrix::<D>::identity()
                }
            };

            if corr_pos_km < self.tolerance_pos_km {
                info!("Converged in {iter} iterations.");
                converged = true;
                break;
            }
        } else if self.solver == BLSSolver::LevenbergMarquardt {
            info!("[{iter}/{}] LM: Step rejected, increasing lambda.", self.max_iterations);
            // Reset so a rejected step is never reported as a small correction.
            corr_pos_km = f64::MAX;
        }
    }

    if !converged {
        warn!("Not converged after {} iterations.", self.max_iterations);
    }

    info!("Batch Least Squares estimation completed.");
    Ok(BLSSolution {
        estimated_state: current_estimate,
        covariance: current_covariance,
        num_iterations: iter,
        final_rms: current_rms,
        final_corr_pos_km: corr_pos_km,
        converged,
    })
}
447
448 pub fn evaluate(
451 &self,
452 state: D::StateType,
453 arc: &TrackingDataArc,
454 ) -> Result<f64, ODError> {
455 let measurements = &arc.measurements;
456 let num_measurements = measurements.values().filter(|m| !m.rejected).count();
457 let mut devices = self.devices.clone();
458
459 ensure!(
460 num_measurements >= 1,
461 TooFewMeasurementsSnafu {
462 need: 1_usize,
463 action: "BLSE Evaluate"
464 }
465 );
466
467 let mut sum_sq_weighted_residuals = 0.0;
468 let mut unknown_trackers = IndexSet::new();
469
470 let mut prop_inst = self.prop.with(state.with_stm(), self.almanac.clone()).quiet();
471 let mut epoch = state.epoch();
472
473 for (epoch_ref, msr) in measurements.iter() {
474 if msr.rejected {
475 continue;
476 }
477 let msr_epoch = *epoch_ref;
478
479 loop {
480 let delta_t = msr_epoch - epoch;
481 if delta_t <= Duration::ZERO {
482 break;
483 }
484
485 let next_step = delta_t.min(prop_inst.step_size).min(self.max_step);
486 let this_state = prop_inst.for_duration(next_step).context(ODPropSnafu)?;
487 epoch = this_state.epoch();
488
489 if (epoch - msr_epoch).abs() < self.epoch_precision {
490 let device = match devices.get_mut(&msr.tracker) {
491 Some(d) => d,
492 None => {
493 if !unknown_trackers.contains(&msr.tracker) {
494 error!(
495 "Tracker {} is not in the list of configured devices",
496 msr.tracker
497 );
498 }
499 unknown_trackers.insert(msr.tracker.clone());
500 continue;
501 }
502 };
503
504 for msr_type in msr.data.keys().copied() {
505 let mut msr_types = IndexSet::new();
506 msr_types.insert(msr_type);
507
508 let computed_meas_opt = device
509 .measure_instantaneous(this_state, None, self.almanac.clone())?;
510
511 let computed_meas = match computed_meas_opt {
512 Some(cm) => cm,
513 None => continue,
514 };
515
516 let computed_obs = computed_meas.observation::<U1>(&msr_types)[0];
517 let real_obs = msr.observation::<U1>(&msr_types)[0];
518
519 ensure!(
520 real_obs.is_finite(),
521 InvalidMeasurementSnafu {
522 epoch: msr_epoch,
523 val: real_obs
524 }
525 );
526
527 let residual = real_obs - computed_obs;
528 let r_matrix = device.measurement_covar_matrix::<U1>(&msr_types, msr_epoch)?;
529 let r_variance = r_matrix[(0, 0)];
530
531 ensure!(r_variance > 0.0, SingularNoiseRkSnafu);
532 let weight = 1.0 / r_variance;
533
534 sum_sq_weighted_residuals += weight * residual * residual;
535 }
536 }
537 }
538 }
539
540 Ok((sum_sq_weighted_residuals / num_measurements as f64).sqrt())
541 }
542}