1#![allow(clippy::type_complexity)] #![allow(unused_imports)] use crate::linalg::allocator::Allocator;
23use crate::linalg::{Const, DefaultAllocator, DimName, OMatrix, OVector, U1}; use crate::md::trajectory::{Interpolatable, Traj}; pub use crate::od::estimate::*;
26pub use crate::od::ground_station::*;
27pub use crate::od::snc::*; pub use crate::od::*;
29use crate::propagators::Propagator;
30pub use crate::time::{Duration, Epoch, Unit};
31use anise::prelude::Almanac;
32use indexmap::IndexSet;
33use log::{debug, info, trace, warn};
34use msr::sensitivity::TrackerSensitivity; use nalgebra::{Cholesky, Dyn, Matrix, VecStorage};
36use snafu::prelude::*;
37use solution::msr::MeasurementType;
38use std::collections::BTreeMap;
39use std::marker::PhantomData;
40use std::ops::Add;
41use std::sync::Arc;
42use typed_builder::TypedBuilder;
43
44mod solution;
45
46pub use solution::BLSSolution;
47
48use self::msr::TrackingDataArc;
49
/// Selects how the linear system of each batch least squares iteration is solved.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BLSSolver {
    /// Plain normal equations: the information matrix is Cholesky-factorized and solved
    /// directly for the state correction.
    NormalEquations,
    /// Levenberg-Marquardt: a damping term `lambda * D` is added to the information
    /// matrix before it is factorized, and `lambda` is adapted from one iteration to the next.
    LevenbergMarquardt,
}
58
59#[derive(Clone, TypedBuilder)]
61#[builder(doc)]
62pub struct BatchLeastSquares<
63 D: Dynamics,
64 Trk: TrackerSensitivity<D::StateType, D::StateType>, > where
66 D::StateType:
67 Interpolatable + Add<OVector<f64, <D::StateType as State>::Size>, Output = D::StateType>,
68 <D::StateType as State>::Size: DimName, <DefaultAllocator as Allocator<<D::StateType as State>::VecLength>>::Buffer<f64>: Send,
71 <DefaultAllocator as Allocator<<D::StateType as State>::Size>>::Buffer<f64>: Copy,
72 <DefaultAllocator as Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>>::Buffer<f64>: Copy,
73 DefaultAllocator: Allocator<<D::StateType as State>::Size>
74 + Allocator<<D::StateType as State>::VecLength>
75 + Allocator<U1> + Allocator<U1, <D::StateType as State>::Size>
77 + Allocator<<D::StateType as State>::Size, U1>
78 + Allocator<U1, U1>
79 + Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>,
80{
81 pub prop: Propagator<D>,
83 pub devices: BTreeMap<String, Trk>,
85 #[builder(default = BLSSolver::NormalEquations)]
87 pub solver: BLSSolver,
88 #[builder(default = 1e-4)]
90 pub tolerance_pos_km: f64,
91 #[builder(default = 10)]
93 pub max_iterations: usize,
94 #[builder(default_code = "30 * Unit::Second")]
97 pub max_step: Duration,
98 #[builder(default_code = "1 * Unit::Microsecond")]
100 pub epoch_precision: Duration,
101 #[builder(default = 10.0)]
103 pub lm_lambda_init: f64,
104 #[builder(default = 10.0)] pub lm_lambda_decrease: f64,
107 #[builder(default = 10.0)] pub lm_lambda_increase: f64,
110 #[builder(default = 1e-12)]
112 pub lm_lambda_min: f64,
113 #[builder(default = 1e12)]
115 pub lm_lambda_max: f64,
116 #[builder(default = true)]
118 pub lm_use_diag_scaling: bool,
119 pub almanac: Arc<Almanac>,
120}
121
122#[allow(type_alias_bounds)]
123type StateMatrix<D: Dynamics> =
124 OMatrix<f64, <D::StateType as State>::Size, <D::StateType as State>::Size>;
125
126impl<D, Trk> BatchLeastSquares<D, Trk>
127where
128 D: Dynamics,
129 Trk: TrackerSensitivity<D::StateType, D::StateType> + Clone, D::StateType: Interpolatable
131 + Add<OVector<f64, <D::StateType as State>::Size>, Output = D::StateType>
132 + std::fmt::Debug, <D::StateType as State>::Size: DimName,
134 <DefaultAllocator as Allocator<<D::StateType as State>::VecLength>>::Buffer<f64>: Send,
135 <DefaultAllocator as Allocator<<D::StateType as State>::Size>>::Buffer<f64>: Copy,
136 <DefaultAllocator as Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>>::Buffer<f64>: Copy,
137 DefaultAllocator: Allocator<<D::StateType as State>::Size>
138 + Allocator<<D::StateType as State>::VecLength>
139 + Allocator<U1>
140 + Allocator<U1, <D::StateType as State>::Size>
141 + Allocator<<D::StateType as State>::Size, U1>
142 + Allocator<U1, U1>
143 + Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>,
144{
145 pub fn estimate(
147 &self,
148 initial_guess: D::StateType,
149 arc: &TrackingDataArc,
150 ) -> Result<BLSSolution<D::StateType>, ODError> {
151 let measurements = &arc.measurements;
152 let num_measurements = measurements.len();
153 let mut devices = self.devices.clone();
154
155 ensure!(
156 num_measurements >= 2,
157 TooFewMeasurementsSnafu {
158 need: 2_usize,
159 action: "BLSE"
160 }
161 );
162
163 info!(
164 "Using {:?} in the Batch Least Squares estimation with {num_measurements} measurements",
165 self.solver
166 );
167 info!("Initial guess: {}", initial_guess.orbit());
168
169 let mut current_estimate = initial_guess;
170 let mut current_covariance = StateMatrix::<D>::zeros();
171 let mut converged = false;
172 let mut corr_pos_km = f64::MAX;
173 let mut lambda = self.lm_lambda_init;
174 let mut current_rms = f64::MAX;
175 let mut iter: usize = 0;
176
177 let mut unknown_trackers = IndexSet::new();
178
179 while iter < self.max_iterations {
181 iter += 1;
182 info!("[{iter}/{}] Current estimate: {}", self.max_iterations, current_estimate.orbit());
183
184 let mut info_matrix = StateMatrix::<D>::identity();
187 let mut normal_matrix = OVector::<f64, <D::StateType as State>::Size>::zeros();
189 let mut sum_sq_weighted_residuals = 0.0;
191
192 let mut prop_inst = self.prop.with(current_estimate.with_stm(), self.almanac.clone()).quiet();
194 let mut epoch = current_estimate.epoch();
195
196 let mut stm = StateMatrix::<D>::identity();
198
199 for (epoch_ref, msr) in measurements.iter() {
200 let msr_epoch = *epoch_ref;
201
202 loop {
203 let delta_t = msr_epoch - epoch;
204 if delta_t <= Duration::ZERO {
205 break;
207 }
208
209 let next_step = delta_t.min(prop_inst.step_size).min(self.max_step);
211
212 let this_state = prop_inst.for_duration(next_step).context(ODPropSnafu)?;
214 epoch = this_state.epoch();
215
216 let step_stm = this_state.stm().expect("STM unavailable");
218 stm = step_stm * stm;
220
221 if (epoch - msr_epoch).abs() < self.epoch_precision {
222 let device = match devices.get_mut(&msr.tracker) {
224 Some(d) => d,
225 None => {
226 if !unknown_trackers.contains(&msr.tracker) {
227 error!(
228 "Tracker {} is not in the list of configured devices",
229 msr.tracker
230 );
231 }
232 unknown_trackers.insert(msr.tracker.clone());
233 continue;
234 }
235 };
236
237 for msr_type in msr.data.keys().copied() {
238 let mut msr_types = IndexSet::new();
239 msr_types.insert(msr_type);
240
241 let h_tilde = device
242 .h_tilde::<U1>(msr, &msr_types, &this_state, self.almanac.clone())?;
243
244 let computed_meas_opt = device
246 .measure_instantaneous(this_state, None, self.almanac.clone())?;
247
248 let computed_meas = match computed_meas_opt {
249 Some(cm) => cm,
250 None => {
251 debug!("Device {} does not expect measurement at epoch {msr_epoch}, skipping", msr.tracker);
252 continue;
253 }
254 };
255
256 let computed_obs = computed_meas.observation::<U1>(&msr_types)[0];
258
259 let real_obs = msr.observation::<U1>(&msr_types)[0];
261
262 ensure!(
264 real_obs.is_finite(),
265 InvalidMeasurementSnafu {
266 epoch: msr_epoch,
267 val: real_obs
268 }
269 );
270
271 let residual = real_obs - computed_obs;
273
274 let r_matrix = device
276 .measurement_covar_matrix::<U1>(&msr_types, msr_epoch)?;
277 let r_variance = r_matrix[(0, 0)];
278
279 ensure!(r_variance > 0.0, SingularNoiseRkSnafu);
280 let weight = 1.0 / r_variance;
281
282 let h_matrix = h_tilde * stm;
284
285 info_matrix += h_matrix.transpose() * &h_matrix * weight;
288
289 normal_matrix += h_matrix.transpose() * residual * weight;
291
292 sum_sq_weighted_residuals += weight * residual * residual;
294 }
295 }
296 }
297 }
298
299 let state_correction: OVector<f64, <D::StateType as State>::Size>;
301 let iteration_cost_decreased; let current_iter_rms = (sum_sq_weighted_residuals / num_measurements as f64).sqrt();
305
306 match self.solver {
307 BLSSolver::NormalEquations => {
308 let info_matrix_chol = match info_matrix.cholesky() {
310 Some(chol) => chol,
311 None => return Err(ODError::SingularInformationMatrix)
312 };
313 state_correction = info_matrix_chol.solve(&normal_matrix);
314 iteration_cost_decreased = true;
316 current_rms = current_iter_rms;
317 }
318 BLSSolver::LevenbergMarquardt => {
319 let mut d_sq = StateMatrix::<D>::identity();
323 if self.lm_use_diag_scaling {
324 for i in 0..6 {
326 d_sq[(i, i)] = info_matrix.diagonal()[i];
327 }
328 for i in 0..6 {
330 if d_sq[(i, i)] <= 0.0 {
331 d_sq[(i, i)] = 1e-6; warn!("LM Scaling: Found non-positive diagonal element {} in H^TWH, using floor.", info_matrix[(i, i)]);
333 }
334 }
335 } let augmented_matrix = info_matrix + d_sq * lambda;
339
340 if let Some(aug_chol) = augmented_matrix.cholesky() {
341 state_correction = aug_chol.solve(&normal_matrix);
342
343 if current_iter_rms < current_rms || iter == 0 {
347 iteration_cost_decreased = true;
349 lambda /= self.lm_lambda_decrease;
351 lambda = lambda.max(self.lm_lambda_min);
353 debug!("LM: Cost decreased (RMS {current_rms} -> {current_iter_rms}). Decreasing lambda to {lambda}");
354 current_rms = current_iter_rms;
355 } else {
356 iteration_cost_decreased = false;
358 lambda *= self.lm_lambda_increase;
360 lambda = lambda.min(self.lm_lambda_max);
362 debug!("LM: Cost increased/stalled (RMS {current_rms} -> {current_iter_rms}). Increasing lambda to {lambda}");
363 }
365
366 } else {
367 warn!("LM: Augmented matrix (H^TWH + lambda*D^2) singular with lambda={lambda}. Increasing lambda.");
369 lambda *= self.lm_lambda_increase * 10.0; lambda = lambda.min(self.lm_lambda_max);
371 continue;
374 }
375 }
376 }
377
378 if iteration_cost_decreased {
382 current_estimate = current_estimate + state_correction;
383 corr_pos_km = state_correction.fixed_rows::<3>(0).norm();
384
385 let corr_vel_km_s = state_correction.fixed_rows::<3>(3).norm();
386 info!(
387 "[{iter}/{}] RMS: {current_iter_rms:.3}; corrections: {:.3} m\t{:.3} m/s",
388 self.max_iterations,
389 corr_pos_km * 1e3,
390 corr_vel_km_s * 1e3
391 );
392
393 current_covariance = match info_matrix.try_inverse() {
395 Some(cov) => cov,
396 None => {
397 warn!("Information matrix H^TWH is singular.");
398 StateMatrix::<D>::identity()
399 }
400 };
401
402 if corr_pos_km < self.tolerance_pos_km {
404 info!("Converged in {iter} iterations.");
405 converged = true;
406 break;
407 }
408 } else if self.solver == BLSSolver::LevenbergMarquardt {
409 info!("[{iter}/{}] LM: Step rejected, increasing lambda.", self.max_iterations);
411 corr_pos_km = f64::MAX;
413 }
415 }
416
417 if !converged {
418 warn!("Not converged after {} iterations.", self.max_iterations);
419 }
420
421 info!("Batch Least Squares estimation completed.");
422 Ok(BLSSolution {
423 estimated_state: current_estimate,
424 covariance: current_covariance,
425 num_iterations: iter,
426 final_rms: current_rms,
427 final_corr_pos_km: corr_pos_km,
428 converged,
429 })
430 }
431}