nyx_space/md/opti/
raphson_hyperdual.rs

1/*
2    Nyx, blazing fast astrodynamics
3    Copyright (C) 2018-onwards Christopher Rabotin <christopher.rabotin@gmail.com>
4
5    This program is free software: you can redistribute it and/or modify
6    it under the terms of the GNU Affero General Public License as published
7    by the Free Software Foundation, either version 3 of the License, or
8    (at your option) any later version.
9
10    This program is distributed in the hope that it will be useful,
11    but WITHOUT ANY WARRANTY; without even the implied warranty of
12    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13    GNU Affero General Public License for more details.
14
15    You should have received a copy of the GNU Affero General Public License
16    along with this program.  If not, see <https://www.gnu.org/licenses/>.
17*/
18
19use snafu::{ensure, ResultExt};
20
21use super::solution::TargeterSolution;
22use crate::cosmic::AstroAlmanacSnafu;
23use crate::errors::TargetingError;
24use crate::linalg::{DMatrix, SVector};
25use crate::md::{prelude::*, PropSnafu, UnderdeterminedProblemSnafu};
26use crate::md::{AstroSnafu, StateParameter};
27pub use crate::md::{Variable, Vary};
28use crate::pseudo_inverse;
29use crate::utils::are_eigenvalues_stable;
30#[cfg(not(target_arch = "wasm32"))]
31use std::time::Instant;
32
33impl<const V: usize, const O: usize> Targeter<'_, V, O> {
34    /// Differential correction using hyperdual numbers for the objectives
35    #[allow(clippy::comparison_chain)]
36    pub fn try_achieve_dual(
37        &self,
38        initial_state: Spacecraft,
39        correction_epoch: Epoch,
40        achievement_epoch: Epoch,
41        almanac: Arc<Almanac>,
42    ) -> Result<TargeterSolution<V, O>, TargetingError> {
43        ensure!(!self.objectives.is_empty(), UnderdeterminedProblemSnafu);
44
45        let mut is_bplane_tgt = false;
46        for obj in &self.objectives {
47            if obj.parameter.is_b_plane() {
48                is_bplane_tgt = true;
49                break;
50            }
51        }
52
53        // Now we know that the problem is correctly defined, so let's propagate as is to the epoch
54        // where the correction should be applied.
55        let xi_start = self
56            .prop
57            .with(initial_state, almanac.clone())
58            .until_epoch(correction_epoch)
59            .context(PropSnafu)?;
60
61        debug!("initial_state = {initial_state:?}");
62        debug!("xi_start = {xi_start:?}");
63
64        let mut xi = xi_start;
65
66        // Store the total correction in a static vector
67        let mut total_correction = SVector::<f64, V>::zeros();
68
69        // Apply the initial guess
70        for (i, var) in self.variables.iter().enumerate() {
71            match var.component {
72                Vary::PositionX => {
73                    xi.orbit.radius_km.x += var.init_guess;
74                }
75                Vary::PositionY => {
76                    xi.orbit.radius_km.y += var.init_guess;
77                }
78                Vary::PositionZ => {
79                    xi.orbit.radius_km.z += var.init_guess;
80                }
81                Vary::VelocityX => {
82                    xi.orbit.velocity_km_s.x += var.init_guess;
83                }
84                Vary::VelocityY => {
85                    xi.orbit.velocity_km_s.y += var.init_guess;
86                }
87                Vary::VelocityZ => {
88                    xi.orbit.velocity_km_s.z += var.init_guess;
89                }
90                _ => {
91                    return Err(TargetingError::UnsupportedVariable {
92                        var: var.to_string(),
93                    });
94                }
95            }
96            total_correction[i] += var.init_guess;
97        }
98
99        let mut prev_err_norm = f64::INFINITY;
100
101        // Determine padding in debugging info
102        // For the width, we find the largest desired values and multiply it by the order of magnitude of its tolerance
103        let max_obj_val = self
104            .objectives
105            .iter()
106            .map(|obj| {
107                obj.desired_value.abs().ceil() as i32
108                    * 10_i32.pow(obj.tolerance.abs().log10().ceil() as u32)
109            })
110            .max()
111            .unwrap();
112
113        let max_obj_tol = self
114            .objectives
115            .iter()
116            .map(|obj| obj.tolerance.log10().abs().ceil() as usize)
117            .max()
118            .unwrap();
119
120        let width = f64::from(max_obj_val).log10() as usize + 2 + max_obj_tol;
121
122        #[cfg(not(target_arch = "wasm32"))]
123        let start_instant = Instant::now();
124
125        for it in 0..=self.iterations {
126            // Now, enable the trajectory STM for this state so we can apply the correction
127            xi.enable_stm();
128
129            // Full propagation for a half period duration is slightly more precise than a step by step one with multiplications in between.
130            let xf = self
131                .prop
132                .with(xi, almanac.clone())
133                .until_epoch(achievement_epoch)
134                .context(PropSnafu)?;
135
136            // Check linearization
137            if !are_eigenvalues_stable(xf.stm().unwrap().complex_eigenvalues()) {
138                warn!(
139                    "STM linearization assumption is wrong for a time step of {}",
140                    achievement_epoch - correction_epoch
141                );
142            }
143
144            let xf_dual_obj_frame = match &self.objective_frame {
145                Some(frame) => {
146                    let orbit_obj_frame = almanac
147                        .transform_to(xf.orbit, *frame, None)
148                        .context(AstroAlmanacSnafu)
149                        .context(AstroSnafu)?;
150
151                    OrbitDual::from(orbit_obj_frame)
152                }
153                None => OrbitDual::from(xf.orbit),
154            };
155
156            // Build the error vector
157            let mut err_vector = SVector::<f64, O>::zeros();
158            let mut converged = true;
159
160            // Build the B-Plane once, if needed, and always in the objective frame
161            let b_plane = if is_bplane_tgt {
162                Some(BPlane::from_dual(xf_dual_obj_frame).context(AstroSnafu)?)
163            } else {
164                None
165            };
166
167            // Build debugging information
168            let mut objmsg = Vec::new();
169
170            // The Jacobian includes the sensitivity of each objective with respect to each variable for the whole trajectory.
171            // As such, it includes the STM of that variable for the whole propagation arc.
172            let mut jac = DMatrix::from_element(self.objectives.len(), self.variables.len(), 0.0);
173
174            for (i, obj) in self.objectives.iter().enumerate() {
175                let xf_partial = if obj.parameter.is_b_plane() {
176                    match obj.parameter {
177                        StateParameter::BdotR => b_plane.unwrap().b_r,
178                        StateParameter::BdotT => b_plane.unwrap().b_t,
179                        StateParameter::BLTOF => b_plane.unwrap().ltof_s,
180                        _ => unreachable!(),
181                    }
182                } else {
183                    xf_dual_obj_frame
184                        .partial_for(obj.parameter)
185                        .context(AstroSnafu)?
186                };
187
188                let achieved = xf_partial.real();
189
190                let (ok, param_err) = obj.assess_value(achieved);
191                if !ok {
192                    converged = false;
193                }
194                err_vector[i] = param_err;
195
196                objmsg.push(format!(
197                    "\t{:?}: achieved = {:>width$.prec$}\t desired = {:>width$.prec$}\t scaled error = {:>width$.prec$}",
198                    obj.parameter,
199                    achieved,
200                    obj.desired_value,
201                    param_err, width=width, prec=max_obj_tol
202                ));
203
204                // Build the Jacobian with the partials of the objectives with respect to all of the final state parameters
205                // We localize the problem in the STM.
206                // TODO: VNC (how?!)
207                let mut partial_vec = DMatrix::from_element(1, 6, 0.0);
208                for (i, val) in [
209                    xf_partial.wtr_x(),
210                    xf_partial.wtr_y(),
211                    xf_partial.wtr_z(),
212                    xf_partial.wtr_vx(),
213                    xf_partial.wtr_vy(),
214                    xf_partial.wtr_vz(),
215                ]
216                .iter()
217                .enumerate()
218                {
219                    partial_vec[(0, i)] = *val;
220                }
221
222                for (j, var) in self.variables.iter().enumerate() {
223                    // Grab the STM first.
224                    let sc_stm = xf.stm().unwrap();
225                    let stm = sc_stm.fixed_view::<6, 6>(0, 0);
226                    let idx = var.component.vec_index();
227                    // Compute the partial of the objective over all components wrt to all of the components in the STM of the control variable.
228                    let rslt = &partial_vec * stm.fixed_columns::<1>(idx);
229                    jac[(i, j)] = rslt[(0, 0)];
230                }
231            }
232
233            if converged {
234                #[cfg(not(target_arch = "wasm32"))]
235                let conv_dur = Instant::now() - start_instant;
236                #[cfg(target_arch = "wasm32")]
237                let conv_dur = Duration::ZERO.into();
238                let mut state = xi_start;
239                // Convert the total correction from VNC back to integration frame in case that's needed.
240                for (i, var) in self.variables.iter().enumerate() {
241                    match var.component {
242                        Vary::PositionX => state.orbit.radius_km.x += total_correction[i],
243                        Vary::PositionY => state.orbit.radius_km.y += total_correction[i],
244                        Vary::PositionZ => state.orbit.radius_km.z += total_correction[i],
245                        Vary::VelocityX => state.orbit.velocity_km_s.x += total_correction[i],
246                        Vary::VelocityY => state.orbit.velocity_km_s.y += total_correction[i],
247                        Vary::VelocityZ => state.orbit.velocity_km_s.z += total_correction[i],
248                        _ => {
249                            return Err(TargetingError::UnsupportedVariable {
250                                var: var.to_string(),
251                            })
252                        }
253                    }
254                }
255
256                let sol = TargeterSolution {
257                    corrected_state: state,
258                    achieved_state: xf,
259                    correction: total_correction,
260                    computation_dur: conv_dur,
261                    variables: self.variables,
262                    achieved_errors: err_vector,
263                    achieved_objectives: self.objectives,
264                    iterations: it,
265                };
266                info!("Targeter -- CONVERGED in {} iterations", it);
267                for obj in &objmsg {
268                    info!("{}", obj);
269                }
270                return Ok(sol);
271            }
272
273            // We haven't converged yet, so let's build the error vector
274            if (err_vector.norm() - prev_err_norm).abs() < 1e-10 {
275                return Err(TargetingError::CorrectionIneffective {
276                    cur_val: err_vector.norm(),
277                    prev_val: prev_err_norm,
278                    action: "No change in objective errors",
279                });
280            }
281            prev_err_norm = err_vector.norm();
282
283            debug!("Jacobian {}", jac);
284
285            // Perform the pseudo-inverse if needed, else just inverse
286            let jac_inv = pseudo_inverse!(&jac)?;
287
288            debug!("Inverse Jacobian {}", jac_inv);
289
290            let mut delta = jac_inv * err_vector;
291
292            debug!("Error vector: {}\nRaw correction: {}", err_vector, delta);
293
294            // And finally apply it to the xi
295            for (i, var) in self.variables.iter().enumerate() {
296                // Choose the minimum step between the provided max step and the correction.
297                if delta[i].abs() > var.max_step {
298                    delta[i] = var.max_step * delta[i].signum();
299                } else if delta[i] > var.max_value {
300                    delta[i] = var.max_value;
301                } else if delta[i] < var.min_value {
302                    delta[i] = var.min_value;
303                }
304
305                info!(
306                    "Correction {:?} (element {}): {}",
307                    var.component, i, delta[i]
308                );
309
310                match var.component {
311                    Vary::PositionX => {
312                        xi.orbit.radius_km.x += delta[i];
313                    }
314                    Vary::PositionY => {
315                        xi.orbit.radius_km.y += delta[i];
316                    }
317                    Vary::PositionZ => {
318                        xi.orbit.radius_km.z += delta[i];
319                    }
320                    Vary::VelocityX => {
321                        xi.orbit.velocity_km_s.x += delta[i];
322                    }
323                    Vary::VelocityY => {
324                        xi.orbit.velocity_km_s.y += delta[i];
325                    }
326                    Vary::VelocityZ => {
327                        xi.orbit.velocity_km_s.z += delta[i];
328                    }
329                    _ => {
330                        return Err(TargetingError::UnsupportedVariable {
331                            var: var.to_string(),
332                        });
333                    }
334                }
335            }
336            total_correction += delta;
337            debug!("Total correction: {:e}", total_correction);
338
339            // Log progress
340            info!("Targeter -- Iteration #{} -- {}", it, achievement_epoch);
341            for obj in &objmsg {
342                info!("{}", obj);
343            }
344        }
345
346        Err(TargetingError::TooManyIterations)
347    }
348}