Skip to main content

nyx_space/md/opti/
raphson_hyperdual.rs

1/*
2    Nyx, blazing fast astrodynamics
3    Copyright (C) 2018-onwards Christopher Rabotin <christopher.rabotin@gmail.com>
4
5    This program is free software: you can redistribute it and/or modify
6    it under the terms of the GNU Affero General Public License as published
7    by the Free Software Foundation, either version 3 of the License, or
8    (at your option) any later version.
9
10    This program is distributed in the hope that it will be useful,
11    but WITHOUT ANY WARRANTY; without even the implied warranty of
12    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13    GNU Affero General Public License for more details.
14
15    You should have received a copy of the GNU Affero General Public License
16    along with this program.  If not, see <https://www.gnu.org/licenses/>.
17*/
18
19use super::solution::TargeterSolution;
20use crate::cosmic::{AstroAlmanacSnafu, AstroPhysicsSnafu};
21use crate::errors::TargetingError;
22use crate::linalg::{DMatrix, SVector};
23use crate::md::{prelude::*, PropSnafu, UnderdeterminedProblemSnafu};
24use crate::md::{AstroSnafu, StateParameter};
25pub use crate::md::{Variable, Vary};
26use crate::pseudo_inverse;
27use crate::utils::are_eigenvalues_stable;
28use anise::astro::orbit_gradient::OrbitGrad;
29use log::{debug, info, warn};
30use snafu::{ensure, ResultExt};
31#[cfg(not(target_arch = "wasm32"))]
32use std::time::Instant;
33
34impl<const V: usize, const O: usize> Targeter<'_, V, O> {
35    /// Differential correction using hyperdual numbers for the objectives
36    #[allow(clippy::comparison_chain)]
37    pub fn try_achieve_dual(
38        &self,
39        initial_state: Spacecraft,
40        correction_epoch: Epoch,
41        achievement_epoch: Epoch,
42        almanac: Arc<Almanac>,
43    ) -> Result<TargeterSolution<V, O>, TargetingError> {
44        ensure!(!self.objectives.is_empty(), UnderdeterminedProblemSnafu);
45
46        let mut is_bplane_tgt = false;
47        for obj in &self.objectives {
48            if obj.parameter.is_b_plane() {
49                is_bplane_tgt = true;
50                break;
51            }
52        }
53
54        // Now we know that the problem is correctly defined, so let's propagate as is to the epoch
55        // where the correction should be applied.
56        let xi_start = self
57            .prop
58            .with(initial_state, almanac.clone())
59            .until_epoch(correction_epoch)
60            .context(PropSnafu)?;
61
62        debug!("initial_state = {initial_state:?}");
63        debug!("xi_start = {xi_start:?}");
64
65        let mut xi = xi_start;
66
67        // Store the total correction in a static vector
68        let mut total_correction = SVector::<f64, V>::zeros();
69
70        // Apply the initial guess
71        for (i, var) in self.variables.iter().enumerate() {
72            match var.component {
73                Vary::PositionX => {
74                    xi.orbit.radius_km.x += var.init_guess;
75                }
76                Vary::PositionY => {
77                    xi.orbit.radius_km.y += var.init_guess;
78                }
79                Vary::PositionZ => {
80                    xi.orbit.radius_km.z += var.init_guess;
81                }
82                Vary::VelocityX => {
83                    xi.orbit.velocity_km_s.x += var.init_guess;
84                }
85                Vary::VelocityY => {
86                    xi.orbit.velocity_km_s.y += var.init_guess;
87                }
88                Vary::VelocityZ => {
89                    xi.orbit.velocity_km_s.z += var.init_guess;
90                }
91                _ => {
92                    return Err(TargetingError::UnsupportedVariable {
93                        var: var.to_string(),
94                    });
95                }
96            }
97            total_correction[i] += var.init_guess;
98        }
99
100        let mut prev_err_norm = f64::INFINITY;
101
102        // Determine padding in debugging info
103        // For the width, we find the largest desired values and multiply it by the order of magnitude of its tolerance
104        let max_obj_val = self
105            .objectives
106            .iter()
107            .map(|obj| {
108                obj.desired_value.abs().ceil() as i32
109                    * 10_i32.pow(obj.tolerance.abs().log10().ceil() as u32)
110            })
111            .max()
112            .unwrap();
113
114        let max_obj_tol = self
115            .objectives
116            .iter()
117            .map(|obj| obj.tolerance.log10().abs().ceil() as usize)
118            .max()
119            .unwrap();
120
121        let width = f64::from(max_obj_val).log10() as usize + 2 + max_obj_tol;
122
123        #[cfg(not(target_arch = "wasm32"))]
124        let start_instant = Instant::now();
125
126        for it in 0..=self.iterations {
127            // Now, enable the trajectory STM for this state so we can apply the correction
128            xi.enable_stm();
129
130            // Full propagation for a half period duration is slightly more precise than a step by step one with multiplications in between.
131            let xf = self
132                .prop
133                .with(xi, almanac.clone())
134                .until_epoch(achievement_epoch)
135                .context(PropSnafu)?;
136
137            // Check linearization
138            if !are_eigenvalues_stable(xf.stm().unwrap().complex_eigenvalues()) {
139                warn!(
140                    "STM linearization assumption is wrong for a time step of {}",
141                    achievement_epoch - correction_epoch
142                );
143            }
144
145            let xf_dual_obj_frame = match &self.objective_frame {
146                Some(frame) => {
147                    let orbit_obj_frame = almanac
148                        .transform_to(xf.orbit, *frame, None)
149                        .context(AstroAlmanacSnafu)
150                        .context(AstroSnafu)?;
151
152                    OrbitGrad::from(orbit_obj_frame)
153                }
154                None => OrbitGrad::from(xf.orbit),
155            };
156
157            // Build the error vector
158            let mut err_vector = SVector::<f64, O>::zeros();
159            let mut converged = true;
160
161            // Build the B-Plane once, if needed, and always in the objective frame
162            let b_plane = if is_bplane_tgt {
163                Some(BPlane::from_dual(xf_dual_obj_frame).context(AstroSnafu)?)
164            } else {
165                None
166            };
167
168            // Build debugging information
169            let mut objmsg = Vec::new();
170
171            // The Jacobian includes the sensitivity of each objective with respect to each variable for the whole trajectory.
172            // As such, it includes the STM of that variable for the whole propagation arc.
173            let mut jac = DMatrix::from_element(self.objectives.len(), self.variables.len(), 0.0);
174
175            for (i, obj) in self.objectives.iter().enumerate() {
176                let xf_partial = if obj.parameter.is_b_plane() {
177                    match obj.parameter {
178                        StateParameter::BdotR => b_plane.unwrap().b_r_km,
179                        StateParameter::BdotT => b_plane.unwrap().b_t_km,
180                        StateParameter::BLTOF => b_plane.unwrap().ltof_s,
181                        _ => unreachable!(),
182                    }
183                } else if let StateParameter::Element(oe) = obj.parameter {
184                    xf_dual_obj_frame
185                        .partial_for(oe)
186                        .context(AstroPhysicsSnafu)
187                        .context(AstroSnafu)?
188                } else {
189                    unreachable!()
190                };
191
192                let achieved = xf_partial.real();
193
194                let (ok, param_err) = obj.assess_value(achieved);
195                if !ok {
196                    converged = false;
197                }
198                err_vector[i] = param_err;
199
200                objmsg.push(format!(
201                    "\t{:?}: achieved = {:>width$.prec$}\t desired = {:>width$.prec$}\t scaled error = {:>width$.prec$}",
202                    obj.parameter,
203                    achieved,
204                    obj.desired_value,
205                    param_err, width=width, prec=max_obj_tol
206                ));
207
208                // Build the Jacobian with the partials of the objectives with respect to all of the final state parameters
209                // We localize the problem in the STM.
210                // TODO: VNC (how?!)
211                let mut partial_vec = DMatrix::from_element(1, 6, 0.0);
212                for (i, val) in [
213                    xf_partial.wrt_x(),
214                    xf_partial.wrt_y(),
215                    xf_partial.wrt_z(),
216                    xf_partial.wrt_vx(),
217                    xf_partial.wrt_vy(),
218                    xf_partial.wrt_vz(),
219                ]
220                .iter()
221                .enumerate()
222                {
223                    partial_vec[(0, i)] = *val;
224                }
225
226                for (j, var) in self.variables.iter().enumerate() {
227                    // Grab the STM first.
228                    let sc_stm = xf.stm().unwrap();
229                    let stm = sc_stm.fixed_view::<6, 6>(0, 0);
230                    let idx = var.component.vec_index();
231                    // Compute the partial of the objective over all components wrt to all of the components in the STM of the control variable.
232                    let rslt = &partial_vec * stm.fixed_columns::<1>(idx);
233                    jac[(i, j)] = rslt[(0, 0)];
234                }
235            }
236
237            if converged {
238                #[cfg(not(target_arch = "wasm32"))]
239                let conv_dur = Instant::now() - start_instant;
240                #[cfg(target_arch = "wasm32")]
241                let conv_dur = Duration::ZERO.into();
242                let mut state = xi_start;
243                // Convert the total correction from VNC back to integration frame in case that's needed.
244                for (i, var) in self.variables.iter().enumerate() {
245                    match var.component {
246                        Vary::PositionX => state.orbit.radius_km.x += total_correction[i],
247                        Vary::PositionY => state.orbit.radius_km.y += total_correction[i],
248                        Vary::PositionZ => state.orbit.radius_km.z += total_correction[i],
249                        Vary::VelocityX => state.orbit.velocity_km_s.x += total_correction[i],
250                        Vary::VelocityY => state.orbit.velocity_km_s.y += total_correction[i],
251                        Vary::VelocityZ => state.orbit.velocity_km_s.z += total_correction[i],
252                        _ => {
253                            return Err(TargetingError::UnsupportedVariable {
254                                var: var.to_string(),
255                            })
256                        }
257                    }
258                }
259
260                let sol = TargeterSolution {
261                    corrected_state: state,
262                    achieved_state: xf,
263                    correction: total_correction,
264                    computation_dur: conv_dur,
265                    variables: self.variables,
266                    achieved_errors: err_vector,
267                    achieved_objectives: self.objectives,
268                    iterations: it,
269                };
270                info!("Targeter -- CONVERGED in {it} iterations");
271                for obj in &objmsg {
272                    info!("{obj}");
273                }
274                return Ok(sol);
275            }
276
277            // We haven't converged yet, so let's build the error vector
278            if (err_vector.norm() - prev_err_norm).abs() < 1e-10 {
279                return Err(TargetingError::CorrectionIneffective {
280                    cur_val: err_vector.norm(),
281                    prev_val: prev_err_norm,
282                    action: "No change in objective errors",
283                });
284            }
285            prev_err_norm = err_vector.norm();
286
287            debug!("Jacobian {jac}");
288
289            // Perform the pseudo-inverse if needed, else just inverse
290            let jac_inv = pseudo_inverse!(&jac)?;
291
292            debug!("Inverse Jacobian {jac_inv}");
293
294            let mut delta = jac_inv * err_vector;
295
296            debug!("Error vector: {err_vector}\nRaw correction: {delta}");
297
298            // And finally apply it to the xi
299            for (i, var) in self.variables.iter().enumerate() {
300                // Choose the minimum step between the provided max step and the correction.
301                if delta[i].abs() > var.max_step {
302                    delta[i] = var.max_step * delta[i].signum();
303                } else if delta[i] > var.max_value {
304                    delta[i] = var.max_value;
305                } else if delta[i] < var.min_value {
306                    delta[i] = var.min_value;
307                }
308
309                info!(
310                    "Correction {:?} (element {}): {}",
311                    var.component, i, delta[i]
312                );
313
314                match var.component {
315                    Vary::PositionX => {
316                        xi.orbit.radius_km.x += delta[i];
317                    }
318                    Vary::PositionY => {
319                        xi.orbit.radius_km.y += delta[i];
320                    }
321                    Vary::PositionZ => {
322                        xi.orbit.radius_km.z += delta[i];
323                    }
324                    Vary::VelocityX => {
325                        xi.orbit.velocity_km_s.x += delta[i];
326                    }
327                    Vary::VelocityY => {
328                        xi.orbit.velocity_km_s.y += delta[i];
329                    }
330                    Vary::VelocityZ => {
331                        xi.orbit.velocity_km_s.z += delta[i];
332                    }
333                    _ => {
334                        return Err(TargetingError::UnsupportedVariable {
335                            var: var.to_string(),
336                        });
337                    }
338                }
339            }
340            total_correction += delta;
341            debug!("Total correction: {total_correction:e}");
342
343            // Log progress
344            info!("Targeter -- Iteration #{it} -- {achievement_epoch}");
345            for obj in &objmsg {
346                info!("{obj}");
347            }
348        }
349
350        Err(TargetingError::TooManyIterations)
351    }
352}