// nyx_space/md/opti/multipleshooting/multishoot.rs

/*
    Nyx, blazing fast astrodynamics
    Copyright (C) 2018-onwards Christopher Rabotin <christopher.rabotin@gmail.com>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as published
    by the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
*/
18
19use snafu::ResultExt;
20
21pub use super::CostFunction;
22use super::{MultipleShootingError, TargetingSnafu};
23use crate::linalg::{DMatrix, DVector, SVector};
24use crate::md::opti::solution::TargeterSolution;
25use crate::md::targeter::Targeter;
26use crate::md::{prelude::*, TargetingError};
27use crate::pseudo_inverse;
28use crate::{Orbit, Spacecraft};
29
30use std::fmt;
31
32pub trait MultishootNode<const O: usize>: Copy + Into<[Objective; O]> {
33    fn epoch(&self) -> Epoch;
34    fn update_component(&mut self, component: usize, add_val: f64);
35}
36
37/// Multiple shooting is an optimization method.
38/// Source of implementation: "Low Thrust Optimization in Cislunar and Translunar space", 2018 Nathan Re (Parrish)
39/// OT: size of the objectives for each node (e.g. 3 if the objectives are X, Y, Z).
40/// VT: size of the variables for targeter node (e.g. 4 if the objectives are thrust direction (x,y,z) and thrust level).
41pub struct MultipleShooting<'a, T: MultishootNode<OT>, const VT: usize, const OT: usize> {
42    /// The propagator setup (kind, stages, etc.)
43    pub prop: &'a Propagator<SpacecraftDynamics>,
44    /// List of nodes of the optimal trajectory
45    pub targets: Vec<T>,
46    /// Starting point, must be a spacecraft equipped with a thruster
47    pub x0: Spacecraft,
48    /// Destination (Should this be the final node?)
49    pub xf: Orbit,
50    pub current_iteration: usize,
51    /// The maximum number of iterations allowed
52    pub max_iterations: usize,
53    /// Threshold after which the outer loop is considered to have converged,
54    /// e.g. 0.01 means that a 1% of less improvement in case between two iterations
55    /// will stop the iterations.
56    pub improvement_threshold: f64,
57    /// The kind of correction to apply to achieve the objectives
58    pub variables: [Variable; VT],
59    pub all_dvs: Vec<SVector<f64, VT>>,
60}
61
62impl<T: MultishootNode<OT>, const VT: usize, const OT: usize> MultipleShooting<'_, T, VT, OT> {
63    /// Solve the multiple shooting problem by finding the arrangement of nodes to minimize the cost function.
64    pub fn solve(
65        &mut self,
66        cost: CostFunction,
67        almanac: Arc<Almanac>,
68    ) -> Result<MultipleShootingSolution<T, OT>, MultipleShootingError> {
69        let mut prev_cost = 1e12; // We don't use infinity because we compare a ratio of cost
70        for it in 0..self.max_iterations {
71            let mut initial_states = Vec::with_capacity(self.targets.len());
72            initial_states.push(self.x0);
73            let mut outer_jacobian =
74                DMatrix::from_element(3 * self.targets.len(), OT * (self.targets.len() - 1), 0.0);
75            let mut cost_vec = DVector::from_element(3 * self.targets.len(), 0.0);
76
77            // Reset the all_dvs
78            self.all_dvs = Vec::with_capacity(self.all_dvs.len());
79
80            for i in 0..self.targets.len() {
81                /* ***
82                 ** 1. Solve the delta-v differential corrector between each node
83                 ** *** */
84                let tgt = Targeter {
85                    prop: self.prop,
86                    objectives: self.targets[i].into(),
87                    variables: self.variables,
88                    iterations: 100,
89                    objective_frame: None,
90                    correction_frame: None,
91                };
92                let sol = tgt
93                    .try_achieve_dual(
94                        initial_states[i],
95                        initial_states[i].epoch(),
96                        self.targets[i].epoch(),
97                        almanac.clone(),
98                    )
99                    .context(TargetingSnafu { segment: i })?;
100
101                let nominal_delta_v = sol.correction;
102
103                self.all_dvs.push(nominal_delta_v);
104                // Store the Δv and the initial state for the next targeter.
105                initial_states.push(sol.achieved_state);
106            }
107            // NOTE: We have two separate loops because we need the initial state of node i+2 for the dv computation
108            // of the third entry to the outer jacobian.
109            for i in 0..(self.targets.len() - 1) {
110                /* ***
111                 ** 2. Perturb each node and compute the partial of the Δv for the (i-1), i, and (i+1) nodes
112                 ** where the partial on the i+1 -th node is just the difference between the velocity at the
113                 ** achieved state and the initial state at that node.
114                 ** We don't perturb the endpoint node
115                 ** *** */
116
117                for axis in 0..OT {
118                    /* ***
119                     ** 2.A. Perturb the i-th node
120                     ** *** */
121                    let mut next_node = self.targets[i].into();
122                    next_node[axis].desired_value += next_node[axis].tolerance;
123                    /* ***
124                     ** 2.b. Rerun the targeter from the previous node to this one
125                     ** Note that because the first initial_state is x0, the i-th "initial state"
126                     ** is the initial state to reach the i-th node.
127                     ** *** */
128                    let inner_tgt_a = Targeter::delta_v(self.prop, next_node);
129                    let inner_sol_a = inner_tgt_a
130                        .try_achieve_dual(
131                            initial_states[i],
132                            initial_states[i].epoch(),
133                            self.targets[i].epoch(),
134                            almanac.clone(),
135                        )
136                        .context(TargetingSnafu { segment: i })?;
137
138                    // ∂Δv_x / ∂r_x
139                    outer_jacobian[(3 * i, OT * i + axis)] = (inner_sol_a.correction[0]
140                        - self.all_dvs[i][0])
141                        / next_node[axis].tolerance;
142                    // ∂Δv_y / ∂r_x
143                    outer_jacobian[(3 * i + 1, OT * i + axis)] = (inner_sol_a.correction[1]
144                        - self.all_dvs[i][1])
145                        / next_node[axis].tolerance;
146                    // ∂Δv_z / ∂r_x
147                    outer_jacobian[(3 * i + 2, OT * i + axis)] = (inner_sol_a.correction[2]
148                        - self.all_dvs[i][2])
149                        / next_node[axis].tolerance;
150
151                    /* ***
152                     ** 2.C. Rerun the targeter from the new state at the perturbed node to the next unpertubed node
153                     ** *** */
154                    let inner_tgt_b = Targeter::delta_v(self.prop, self.targets[i + 1].into());
155                    let inner_sol_b = inner_tgt_b
156                        .try_achieve_dual(
157                            inner_sol_a.achieved_state,
158                            inner_sol_a.achieved_state.epoch(),
159                            self.targets[i + 1].epoch(),
160                            almanac.clone(),
161                        )
162                        .context(TargetingSnafu { segment: i })?;
163
164                    // Compute the partials wrt the next Δv
165                    // ∂Δv_x / ∂r_x
166                    outer_jacobian[(3 * (i + 1), OT * i + axis)] = (inner_sol_b.correction[0]
167                        - self.all_dvs[i + 1][0])
168                        / next_node[axis].tolerance;
169                    // ∂Δv_y / ∂r_x
170                    outer_jacobian[(3 * (i + 1) + 1, OT * i + axis)] = (inner_sol_b.correction[1]
171                        - self.all_dvs[i + 1][1])
172                        / next_node[axis].tolerance;
173                    // ∂Δv_z / ∂r_x
174                    outer_jacobian[(3 * (i + 1) + 2, OT * i + axis)] = (inner_sol_b.correction[2]
175                        - self.all_dvs[i + 1][2])
176                        / next_node[axis].tolerance;
177
178                    /* ***
179                     ** 2.D. Compute the difference between the arrival and departure velocities and node i+1
180                     ** *** */
181                    if i < self.targets.len() - 3 {
182                        let dv_ip1 = inner_sol_b.achieved_state.orbit.velocity_km_s
183                            - initial_states[i + 2].orbit.velocity_km_s;
184                        // ∂Δv_x / ∂r_x
185                        outer_jacobian[(3 * (i + 2), OT * i + axis)] =
186                            dv_ip1[0] / next_node[axis].tolerance;
187                        // ∂Δv_y / ∂r_x
188                        outer_jacobian[(3 * (i + 2) + 1, OT * i + axis)] =
189                            dv_ip1[1] / next_node[axis].tolerance;
190                        // ∂Δv_z / ∂r_x
191                        outer_jacobian[(3 * (i + 2) + 2, OT * i + axis)] =
192                            dv_ip1[2] / next_node[axis].tolerance;
193                    }
194                }
195            }
196
197            // Build the cost vector
198            for i in 0..self.targets.len() {
199                for j in 0..3 {
200                    cost_vec[3 * i + j] = self.all_dvs[i][j];
201                }
202            }
203
204            // Compute the cost -- used to stop the algorithm if it does not change much.
205            let new_cost = match cost {
206                CostFunction::MinimumEnergy => cost_vec.dot(&cost_vec),
207                CostFunction::MinimumFuel => cost_vec.dot(&cost_vec).sqrt(),
208            };
209
210            // If the new cost is greater than the previous one, then the cost improvement is negative.
211            let cost_improvmt = (prev_cost - new_cost) / new_cost.abs();
212            // If the cost does not improve by more than threshold stop iteration
213            match cost {
214                CostFunction::MinimumEnergy => info!(
215                    "Multiple shooting iteration #{}\t\tCost = {:.3} km^2/s^2\timprovement = {:.2}%",
216                    it,
217                    new_cost,
218                    100.0 * cost_improvmt
219                ),
220                CostFunction::MinimumFuel => info!(
221                    "Multiple shooting iteration #{}\t\tCost = {:.3} km/s\timprovement = {:.2}%",
222                    it,
223                    new_cost,
224                    100.0 * cost_improvmt
225                ),
226            };
227            if cost_improvmt.abs() < self.improvement_threshold {
228                info!("Improvement below desired threshold. Running targeter on computed nodes.");
229
230                /* ***
231                 ** FIN -- Check the impulsive burns work and return all targeter solutions
232                 ** *** */
233                let mut ms_sol = MultipleShootingSolution {
234                    x0: self.x0,
235                    xf: self.xf,
236                    nodes: self.targets.clone(),
237                    solutions: Vec::with_capacity(self.targets.len()),
238                };
239                let mut initial_states = Vec::with_capacity(self.targets.len());
240                initial_states.push(self.x0);
241
242                for (i, node) in self.targets.iter().enumerate() {
243                    // Run the unpertubed targeter
244                    let tgt = Targeter::delta_v(self.prop, (*node).into());
245                    let sol = tgt
246                        .try_achieve_dual(
247                            initial_states[i],
248                            initial_states[i].epoch(),
249                            node.epoch(),
250                            almanac.clone(),
251                        )
252                        .context(TargetingSnafu { segment: i })?;
253                    initial_states.push(sol.achieved_state);
254                    ms_sol.solutions.push(sol);
255                }
256
257                return Ok(ms_sol);
258            }
259
260            prev_cost = new_cost;
261            // 2. Solve for the next position of the nodes using a pseudo inverse.
262            let inv_jac =
263                pseudo_inverse!(&outer_jacobian).context(TargetingSnafu { segment: 0_usize })?;
264            let delta_r = inv_jac * cost_vec;
265            // 3. Apply the correction to the node positions and iterator
266            let node_vector = -delta_r;
267            for (i, val) in node_vector.iter().enumerate() {
268                let node_no = i / 3;
269                let component_no = i % OT;
270                self.targets[node_no].update_component(component_no, *val);
271            }
272            self.current_iteration += 1;
273        }
274        Err(MultipleShootingError::TargetingError {
275            segment: 0_usize,
276            source: TargetingError::TooManyIterations,
277        })
278    }
279}
280
281impl<T: MultishootNode<OT>, const VT: usize, const OT: usize> fmt::Display
282    for MultipleShooting<'_, T, VT, OT>
283{
284    #[allow(clippy::or_fun_call, clippy::clone_on_copy)]
285    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
286        let mut nodemsg = String::from("");
287        // Add the starting point too
288        nodemsg.push_str(&format!(
289            "[{:.3}, {:.3}, {:.3}, {}, {}, {}, {}, {}, {}],\n",
290            self.x0.orbit.radius_km.x,
291            self.x0.orbit.radius_km.y,
292            self.x0.orbit.radius_km.z,
293            self.current_iteration,
294            0.0,
295            0.0,
296            0.0,
297            0.0,
298            0
299        ));
300
301        for (i, node) in self.targets.iter().enumerate() {
302            let objectives: [Objective; OT] = (*node).into();
303            let mut this_nodemsg = String::from("");
304            for obj in &objectives {
305                this_nodemsg.push_str(&format!("{:.3}, ", obj.desired_value));
306            }
307            let mut this_costmsg = String::from("");
308            let dv = match self.all_dvs.get(i) {
309                Some(dv) => dv.clone(),
310                None => SVector::<f64, VT>::zeros(),
311            };
312            for val in &dv {
313                this_costmsg.push_str(&format!("{val}, "));
314            }
315            if VT == 3 {
316                // Add the norm of the control
317                this_costmsg.push_str(&format!("{}, ", dv.norm()));
318            }
319            nodemsg.push_str(&format!(
320                "[{}{}, {}{}],\n",
321                this_nodemsg,
322                self.current_iteration,
323                this_nodemsg,
324                i + 1
325            ));
326        }
327        write!(f, "{nodemsg}")
328    }
329}
330
331#[derive(Clone, Debug)]
332pub struct MultipleShootingSolution<T: MultishootNode<O>, const O: usize> {
333    pub x0: Spacecraft,
334    pub xf: Orbit,
335    pub nodes: Vec<T>,
336    pub solutions: Vec<TargeterSolution<3, O>>,
337}
338
339impl<T: MultishootNode<O>, const O: usize> fmt::Display for MultipleShootingSolution<T, O> {
340    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
341        for sol in &self.solutions {
342            write!(f, "{sol}")?;
343        }
344        Ok(())
345    }
346}
347
348impl<T: MultishootNode<O>, const O: usize> MultipleShootingSolution<T, O> {
349    /// Allows building the trajectories between different nodes
350    /// This will rebuild the targeters and apply the solutions sequentially
351    pub fn build_trajectories(
352        &self,
353        prop: &Propagator<SpacecraftDynamics>,
354        almanac: Arc<Almanac>,
355    ) -> Result<Vec<Trajectory>, MultipleShootingError> {
356        let mut trajz = Vec::with_capacity(self.nodes.len());
357
358        for (i, node) in self.nodes.iter().copied().enumerate() {
359            let (_, traj) = Targeter::delta_v(prop, node.into())
360                .apply_with_traj(&self.solutions[i], almanac.clone())
361                .context(TargetingSnafu { segment: i })?;
362            trajz.push(traj);
363        }
364
365        Ok(trajz)
366    }
367}