nyx_space/od/blse/
mod.rs

1/*
2    Nyx, blazing fast astrodynamics
3    Copyright (C) 2018-onwards Christopher Rabotin <christopher.rabotin@gmail.com>
4
5    This program is free software: you can redistribute it and/or modify
6    it under the terms of the GNU Affero General Public License as published
7    by the Free Software Foundation, either version 3 of the License, or
8    (at your option) any later version.
9
10    This program is distributed in the hope that it will be useful,
11    but WITHOUT ANY WARRANTY; without even the implied warranty of
12    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13    GNU Affero General Public License for more details.
14
15    You should have received a copy of the GNU Affero General Public License
16    along with this program.  If not, see <https://www.gnu.org/licenses/>.
17*/
18
19#![allow(clippy::type_complexity)] // Allow complex types for generics
20#![allow(unused_imports)] // Keep imports for context even if slightly unused in snippet
21
22use crate::linalg::allocator::Allocator;
23use crate::linalg::{Const, DefaultAllocator, DimName, OMatrix, OVector, U1}; // Use U1 for MsrSize
24use crate::md::trajectory::{Interpolatable, Traj}; // May not need Traj if we propagate point-to-point
25pub use crate::od::estimate::*;
26pub use crate::od::ground_station::*;
27pub use crate::od::snc::*; // SNC not typically used in BLS, but keep context
28pub use crate::od::*;
29use crate::propagators::Propagator;
30pub use crate::time::{Duration, Epoch, Unit};
31use anise::prelude::Almanac;
32use indexmap::IndexSet;
33use log::{debug, info, trace, warn};
34use msr::sensitivity::TrackerSensitivity; // Assuming this is the correct path
35use nalgebra::{Cholesky, Dyn, Matrix, VecStorage};
36use snafu::prelude::*;
37use solution::msr::MeasurementType;
38use std::collections::BTreeMap;
39use std::marker::PhantomData;
40use std::ops::Add;
41use std::sync::Arc;
42use typed_builder::TypedBuilder;
43
44mod solution;
45
46pub use solution::BLSSolution;
47
48use self::msr::TrackingDataArc;
49
/// Solver choice for the Batch Least Squares estimator.
///
/// Selects how the state correction `dx` is obtained from the accumulated
/// information matrix and right-hand side at each iteration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BLSSolver {
    /// Standard Normal Equations: `(H^T W H) dx = H^T W dy`
    NormalEquations,
    /// Levenberg-Marquardt (damped normal equations):
    /// `(H^T W H + lambda * D^T D) dx = H^T W dy`
    LevenbergMarquardt,
}
58
/// Configuration for the Batch Least Squares estimator.
///
/// Built through the `TypedBuilder`-generated builder; every field carrying a
/// `#[builder(default ...)]` attribute is optional at construction time.
#[derive(Clone, TypedBuilder)]
#[builder(doc)]
pub struct BatchLeastSquares<
    D: Dynamics,
    Trk: TrackerSensitivity<D::StateType, D::StateType>, // Tracker observes and solves for the same state type
> where
    D::StateType:
        Interpolatable + Add<OVector<f64, <D::StateType as State>::Size>, Output = D::StateType>,
    <D::StateType as State>::Size: DimName, // State size must be known at compile time
    // Allocator constraints for the state-sized vectors/matrices and the scalar (U1) measurement
    <DefaultAllocator as Allocator<<D::StateType as State>::VecLength>>::Buffer<f64>: Send,
    <DefaultAllocator as Allocator<<D::StateType as State>::Size>>::Buffer<f64>: Copy,
    <DefaultAllocator as Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>>::Buffer<f64>: Copy,
    DefaultAllocator: Allocator<<D::StateType as State>::Size>
        + Allocator<<D::StateType as State>::VecLength>
        + Allocator<U1> // Measurements are processed one scalar at a time
        + Allocator<U1, <D::StateType as State>::Size>
        + Allocator<<D::StateType as State>::Size, U1>
        + Allocator<U1, U1>
        + Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>,
{
    /// Propagator used for the estimation reference trajectory
    pub prop: Propagator<D>,
    /// Tracking devices, keyed by the tracker name referenced in each measurement
    pub devices: BTreeMap<String, Trk>,
    /// Solver method (defaults to the normal equations)
    #[builder(default = BLSSolver::NormalEquations)]
    pub solver: BLSSolver,
    /// Convergence tolerance on the norm of the correction on position, in kilometers
    #[builder(default = 1e-4)]
    pub tolerance_pos_km: f64,
    /// Maximum number of iterations
    #[builder(default = 10)]
    pub max_iterations: usize,
    /// Maximum step size where the STM linearization is assumed correct
    /// (30 seconds is usually fine, but too large and info matrix could be singular)
    #[builder(default_code = "30 * Unit::Second")]
    pub max_step: Duration,
    /// Precision of the measurement epoch when processing measurements:
    /// a measurement is processed once the propagated epoch is within this
    /// duration of the measurement epoch.
    #[builder(default_code = "1 * Unit::Microsecond")]
    pub epoch_precision: Duration,
    /// Initial damping factor for Levenberg-Marquardt
    #[builder(default = 10.0)]
    pub lm_lambda_init: f64,
    /// Factor to decrease lambda by in LM (lambda is divided by this value)
    #[builder(default = 10.0)] // Decrease aggressively if step is good
    pub lm_lambda_decrease: f64,
    /// Factor to increase lambda by in LM (lambda is multiplied by this value)
    #[builder(default = 10.0)] // Increase aggressively if step is bad
    pub lm_lambda_increase: f64,
    /// Minimum value for LM lambda
    #[builder(default = 1e-12)]
    pub lm_lambda_min: f64,
    /// Maximum value for LM lambda
    #[builder(default = 1e12)]
    pub lm_lambda_max: f64,
    /// Use diagonal scaling (D = sqrt(diag(H^T W H))) in LM
    #[builder(default = true)]
    pub lm_use_diag_scaling: bool,
    /// Almanac shared with the propagator and with the tracking devices'
    /// measurement and sensitivity computations
    pub almanac: Arc<Almanac>,
}
121
// Square matrix sized by the estimated state (used for the STM, the information
// matrix and the covariance).
// The lint is silenced because bounds on type aliases are not enforced by the
// compiler, yet `D: Dynamics` is needed here to name `D::StateType`.
#[allow(type_alias_bounds)]
type StateMatrix<D: Dynamics> =
    OMatrix<f64, <D::StateType as State>::Size, <D::StateType as State>::Size>;
125
126impl<D, Trk> BatchLeastSquares<D, Trk>
127where
128    D: Dynamics,
129    Trk: TrackerSensitivity<D::StateType, D::StateType> + Clone, // Add Clone requirement for Trk
130    D::StateType: Interpolatable
131        + Add<OVector<f64, <D::StateType as State>::Size>, Output = D::StateType>
132        + std::fmt::Debug, // Add Debug for logging
133    <D::StateType as State>::Size: DimName,
134    <DefaultAllocator as Allocator<<D::StateType as State>::VecLength>>::Buffer<f64>: Send,
135    <DefaultAllocator as Allocator<<D::StateType as State>::Size>>::Buffer<f64>: Copy,
136    <DefaultAllocator as Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>>::Buffer<f64>: Copy,
137    DefaultAllocator: Allocator<<D::StateType as State>::Size>
138        + Allocator<<D::StateType as State>::VecLength>
139        + Allocator<U1>
140        + Allocator<U1, <D::StateType as State>::Size>
141        + Allocator<<D::StateType as State>::Size, U1>
142        + Allocator<U1, U1>
143        + Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>,
144{
145    /// Processes a tracking data arc to estimate the state using Batch Least Squares.
146    pub fn estimate(
147        &self,
148        initial_guess: D::StateType,
149        arc: &TrackingDataArc,
150    ) -> Result<BLSSolution<D::StateType>, ODError> {
151        let measurements = &arc.measurements;
152        let num_measurements = measurements.len();
153        let mut devices = self.devices.clone();
154
155        ensure!(
156            num_measurements >= 2,
157            TooFewMeasurementsSnafu {
158                need: 2_usize,
159                action: "BLSE"
160            }
161        );
162
163        info!(
164            "Using {:?} in the Batch Least Squares estimation with {num_measurements} measurements",
165            self.solver
166        );
167        info!("Initial guess: {}", initial_guess.orbit());
168
169        let mut current_estimate = initial_guess;
170        let mut current_covariance = StateMatrix::<D>::zeros();
171        let mut converged = false;
172        let mut corr_pos_km = f64::MAX;
173        let mut lambda = self.lm_lambda_init;
174        let mut current_rms = f64::MAX;
175        let mut iter: usize = 0;
176
177        // --- Iteration Loop ---
178        while iter < self.max_iterations {
179            iter += 1;
180            info!("[{iter}/{}] Current estimate: {}", self.max_iterations, current_estimate.orbit());
181
182            // Re-initialize matrices for this iteration
183            // Information Matrix: Lambda = H^T * W * H
184            let mut info_matrix = StateMatrix::<D>::identity();
185            // Normal Matrix: N = H^T * W * dy
186            let mut normal_matrix = OVector::<f64, <D::StateType as State>::Size>::zeros();
187            // Sum of squares of weighted residuals for RMS calculation and LM cost
188            let mut sum_sq_weighted_residuals = 0.0;
189
190            // Set up a single propagator for the whole iteration.
191            let mut prop_inst = self.prop.with(current_estimate.with_stm(), self.almanac.clone()).quiet();
192            let mut epoch = current_estimate.epoch();
193
194            // Store the STM to the start of the batch.
195            let mut stm = StateMatrix::<D>::identity();
196
197            for (epoch_ref, msr) in measurements.iter() {
198                let msr_epoch = *epoch_ref;
199
200                loop {
201                    let delta_t = msr_epoch - epoch;
202                    if delta_t <= Duration::ZERO {
203                        // Move onto the next measurement.
204                        break;
205                    }
206
207                    // Propagate for the minimum time between the maximum step size, the next step size, and the duration to the next measurement.
208                    let next_step = delta_t.min(prop_inst.step_size).min(self.max_step);
209
210                    // Propagate reference state from the previous state to msr_epoch
211                    let this_state = prop_inst.for_duration(next_step).context(ODPropSnafu)?;
212                    epoch = this_state.epoch();
213
214                    // Grab the STM Phi(t_{i+1}, t_i) from the propagated state's STM.
215                    let step_stm = this_state.stm().expect("STM unavailable");
216                    // Compute the STM Phi(t_{i+1}, t_0) = Phi(t_{i+1}, t_i) * Phi(t_i, t_0)
217                    stm = step_stm * stm;
218
219                    if (epoch - msr_epoch).abs() < self.epoch_precision {
220                        // Get the correct tracking device
221                        let device = match devices.get_mut(&msr.tracker) {
222                            Some(d) => d,
223                            None => {
224                                error!(
225                                    "Tracker {} is not in the list of configured devices",
226                                    msr.tracker
227                                );
228                                continue;
229                            }
230                        };
231
232                        for msr_type in msr.data.keys().copied() {
233                            let mut msr_types = IndexSet::new();
234                            msr_types.insert(msr_type);
235
236                            let h_tilde = device
237                            .h_tilde::<U1>(msr, &msr_types, &this_state, self.almanac.clone())?;
238
239                            // Compute expected measurement H(X(t_i))
240                            let computed_meas_opt = device
241                                .measure_instantaneous(this_state, None, self.almanac.clone())?;
242
243                            let computed_meas = match computed_meas_opt {
244                                Some(cm) => cm,
245                                None => {
246                                    debug!("Device {} does not expect measurement at epoch {msr_epoch}, skipping", msr.tracker);
247                                    continue;
248                                }
249                            };
250
251                            // Get the computed observation value
252                            let computed_obs = computed_meas.observation::<U1>(&msr_types)[0];
253
254                            // Get real observation y_i
255                            let real_obs = msr.observation::<U1>(&msr_types)[0];
256
257                            // Sanity check measurement value
258                            ensure!(
259                                real_obs.is_finite(),
260                                InvalidMeasurementSnafu {
261                                    epoch: msr_epoch,
262                                    val: real_obs
263                                }
264                            );
265
266                            // Compute residual dy = y_i - H(X(t_i))
267                            let residual = real_obs - computed_obs;
268
269                            // Get measurement variance R (assuming 1x1 matrix) and weight W = 1/R
270                            let r_matrix = device
271                                .measurement_covar_matrix::<U1>(&msr_types, msr_epoch)?;
272                            let r_variance = r_matrix[(0, 0)];
273
274                            ensure!(r_variance > 0.0, SingularNoiseRkSnafu);
275                            let weight = 1.0 / r_variance;
276
277                            // Compute H_matrix = H_tilde * Phi(t_i, t_0) (sensitivity wrt initial state X_0)
278                            let h_matrix = h_tilde * stm;
279
280                            // Accumulate Information Matrix: info_matrix += H^T * W * H
281                            // Recall that the weight is a scalar, so we can move it to the end of the operation.
282                            info_matrix += h_matrix.transpose() * &h_matrix * weight;
283
284                            // Accumulate Normal Matrix: normal_matrix += H^T * W * y
285                            normal_matrix += h_matrix.transpose() * residual * weight;
286
287                            // Accumulate sum of squares of weighted residuals
288                            sum_sq_weighted_residuals += weight * residual * residual;
289                        }
290                    }
291                }
292            }
293
294            // --- Solve for State Correction dx ---
295            let state_correction: OVector<f64, <D::StateType as State>::Size>;
296            let iteration_cost_decreased; // For LM logic
297
298            // Use num_measurements for consistency
299            let current_iter_rms = (sum_sq_weighted_residuals / num_measurements as f64).sqrt();
300
301            match self.solver {
302                BLSSolver::NormalEquations => {
303                    // Solve Lambda * dx = N => dx = Lambda^-1 * N
304                    let info_matrix_chol = match info_matrix.cholesky() {
305                         Some(chol) => chol,
306                         None => return Err(ODError::SingularInformationMatrix)
307                    };
308                    state_correction = info_matrix_chol.solve(&normal_matrix);
309                    // Assume NE always decreases cost locally
310                    iteration_cost_decreased = true;
311                    current_rms = current_iter_rms;
312                }
313                BLSSolver::LevenbergMarquardt => {
314                     // Solve (Lambda + lambda * D^T D) * dx = N
315                     // D^T D is a diagonal scaling matrix.
316                     // Common choices: D^T D = I or D^T D = diag(Lambda)
317                    let mut d_sq = StateMatrix::<D>::identity();
318                    if self.lm_use_diag_scaling {
319                        // Use D^T D = diag(Lambda)
320                        for i in 0..6 {
321                            d_sq[(i, i)] = info_matrix.diagonal()[i];
322                        }
323                        // Ensure diagonal elements are positive for stability
324                        for i in 0..6 {
325                            if d_sq[(i, i)] <= 0.0 {
326                                d_sq[(i, i)] = 1e-6; // Set a small positive floor
327                                warn!("LM Scaling: Found non-positive diagonal element {} in H^TWH, using floor.", info_matrix[(i, i)]);
328                            }
329                        }
330                    } // else d_sq remains Identity
331
332                    // Inner LM loop to find suitable lambda
333                    let augmented_matrix = info_matrix + d_sq * lambda;
334
335                    if let Some(aug_chol) = augmented_matrix.cholesky() {
336                        state_correction = aug_chol.solve(&normal_matrix);
337
338                        // --- LM Lambda Update Logic ---
339                        // Simple strategy: Check if RMS decreased. More robust methods exist.
340                        // For a simple check, we compare current_iter_rms with the previous iteration's RMS.
341                        if current_iter_rms < current_rms || iter == 0 {
342                            // Cost (approximated by RMS) decreased or first iteration
343                            iteration_cost_decreased = true;
344                            // Decrease damping
345                            lambda /= self.lm_lambda_decrease;
346                            // Clamp to min
347                            lambda = lambda.max(self.lm_lambda_min);
348                            debug!("LM: Cost decreased (RMS {current_rms} -> {current_iter_rms}). Decreasing lambda to {lambda}");
349                            current_rms = current_iter_rms;
350                        } else {
351                             // Cost increased or stalled
352                             iteration_cost_decreased = false;
353                             // Increase damping
354                             lambda *= self.lm_lambda_increase;
355                             // Clamp to max
356                             lambda = lambda.min(self.lm_lambda_max);
357                             debug!("LM: Cost increased/stalled (RMS {current_rms} -> {current_iter_rms}). Increasing lambda to {lambda}");
358                             // Don't update current_rms baseline if cost increased
359                        }
360
361                    } else {
362                        // Augmented matrix is singular, increase lambda significantly and retry
363                        warn!("LM: Augmented matrix (H^TWH + lambda*D^2) singular with lambda={lambda}. Increasing lambda.");
364                        lambda *= self.lm_lambda_increase * 10.0; // Increase more aggressively
365                        lambda = lambda.min(self.lm_lambda_max);
366                        // Skip update in this iteration, force retry with larger lambda next time if possible
367                        // Skip the rest of the loop and go to next iteration
368                        continue;
369                    }
370                }
371            }
372
373            // --- Update State Estimate ---
374            // Only update if the step is considered successful (esp. for LM)
375            // Also hit if using normal equations because iteration_cost_decreased is forced to true
376            if iteration_cost_decreased {
377                current_estimate = current_estimate + state_correction;
378                corr_pos_km = state_correction.fixed_rows::<3>(0).norm();
379
380                let corr_vel_km_s = state_correction.fixed_rows::<3>(3).norm();
381                info!(
382                    "[{iter}/{}] RMS: {current_iter_rms:.3}; corrections: {:.3} m\t{:.3} m/s",
383                    self.max_iterations,
384                    corr_pos_km * 1e3,
385                    corr_vel_km_s * 1e3
386                );
387
388                // Update the covariance
389                current_covariance = match info_matrix.try_inverse() {
390                    Some(cov) => cov,
391                    None => {
392                        warn!("Information matrix H^TWH is singular.");
393                        StateMatrix::<D>::identity()
394                    }
395               };
396
397                // --- Check Convergence ---
398                if corr_pos_km < self.tolerance_pos_km {
399                    info!("Converged in {iter} iterations.");
400                    converged = true;
401                    break;
402                }
403            } else if self.solver == BLSSolver::LevenbergMarquardt {
404                 // LM step was rejected (cost increased)
405                 info!("[{iter}/{}] LM: Step rejected, increasing lambda.", self.max_iterations);
406                 // Reset correction norm as step was bad
407                 corr_pos_km = f64::MAX;
408                 // The loop will continue with the increased lambda
409            }
410        }
411
412        if !converged {
413            warn!("Not converged after {} iterations.", self.max_iterations);
414        }
415
416        info!("Batch Least Squares estimation completed.");
417        Ok(BLSSolution {
418            estimated_state: current_estimate,
419            covariance: current_covariance,
420            num_iterations: iter,
421            final_rms: current_rms,
422            final_corr_pos_km: corr_pos_km,
423            converged,
424        })
425    }
426}