nyx_space/md/opti/multipleshooting/
multishoot.rs

/*
    Nyx, blazing fast astrodynamics
    Copyright (C) 2018-onwards Christopher Rabotin <christopher.rabotin@gmail.com>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as published
    by the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
*/
18
19use log::info;
20use snafu::ResultExt;
21
22pub use super::CostFunction;
23use super::{MultipleShootingError, TargetingSnafu};
24use crate::linalg::{DMatrix, DVector, SVector};
25use crate::md::opti::solution::TargeterSolution;
26use crate::md::targeter::Targeter;
27use crate::md::{prelude::*, TargetingError};
28use crate::pseudo_inverse;
29use crate::{Orbit, Spacecraft};
30
31use std::fmt;
32
33pub trait MultishootNode<const O: usize>: Copy + Into<[Objective; O]> {
34    fn epoch(&self) -> Epoch;
35    fn update_component(&mut self, component: usize, add_val: f64);
36}
37
38/// Multiple shooting is an optimization method.
39/// Source of implementation: "Low Thrust Optimization in Cislunar and Translunar space", 2018 Nathan Re (Parrish)
40/// OT: size of the objectives for each node (e.g. 3 if the objectives are X, Y, Z).
41/// VT: size of the variables for targeter node (e.g. 4 if the objectives are thrust direction (x,y,z) and thrust level).
42pub struct MultipleShooting<'a, T: MultishootNode<OT>, const VT: usize, const OT: usize> {
43    /// The propagator setup (kind, stages, etc.)
44    pub prop: &'a Propagator<SpacecraftDynamics>,
45    /// List of nodes of the optimal trajectory
46    pub targets: Vec<T>,
47    /// Starting point, must be a spacecraft equipped with a thruster
48    pub x0: Spacecraft,
49    /// Destination (Should this be the final node?)
50    pub xf: Orbit,
51    pub current_iteration: usize,
52    /// The maximum number of iterations allowed
53    pub max_iterations: usize,
54    /// Threshold after which the outer loop is considered to have converged,
55    /// e.g. 0.01 means that a 1% of less improvement in case between two iterations
56    /// will stop the iterations.
57    pub improvement_threshold: f64,
58    /// The kind of correction to apply to achieve the objectives
59    pub variables: [Variable; VT],
60    pub all_dvs: Vec<SVector<f64, VT>>,
61}
62
63impl<T: MultishootNode<OT>, const VT: usize, const OT: usize> MultipleShooting<'_, T, VT, OT> {
64    /// Solve the multiple shooting problem by finding the arrangement of nodes to minimize the cost function.
65    pub fn solve(
66        &mut self,
67        cost: CostFunction,
68        almanac: Arc<Almanac>,
69    ) -> Result<MultipleShootingSolution<T, OT>, MultipleShootingError> {
70        let mut prev_cost = 1e12; // We don't use infinity because we compare a ratio of cost
71        for it in 0..self.max_iterations {
72            let mut initial_states = Vec::with_capacity(self.targets.len());
73            initial_states.push(self.x0);
74            let mut outer_jacobian =
75                DMatrix::from_element(3 * self.targets.len(), OT * (self.targets.len() - 1), 0.0);
76            let mut cost_vec = DVector::from_element(3 * self.targets.len(), 0.0);
77
78            // Reset the all_dvs
79            self.all_dvs = Vec::with_capacity(self.all_dvs.len());
80
81            for i in 0..self.targets.len() {
82                /* ***
83                 ** 1. Solve the delta-v differential corrector between each node
84                 ** *** */
85                let tgt = Targeter {
86                    prop: self.prop,
87                    objectives: self.targets[i].into(),
88                    variables: self.variables,
89                    iterations: 100,
90                    objective_frame: None,
91                    correction_frame: None,
92                };
93                let sol = tgt
94                    .try_achieve_dual(
95                        initial_states[i],
96                        initial_states[i].epoch(),
97                        self.targets[i].epoch(),
98                        almanac.clone(),
99                    )
100                    .context(TargetingSnafu { segment: i })?;
101
102                let nominal_delta_v = sol.correction;
103
104                self.all_dvs.push(nominal_delta_v);
105                // Store the Δv and the initial state for the next targeter.
106                initial_states.push(sol.achieved_state);
107            }
108            // NOTE: We have two separate loops because we need the initial state of node i+2 for the dv computation
109            // of the third entry to the outer jacobian.
110            for i in 0..(self.targets.len() - 1) {
111                /* ***
112                 ** 2. Perturb each node and compute the partial of the Δv for the (i-1), i, and (i+1) nodes
113                 ** where the partial on the i+1 -th node is just the difference between the velocity at the
114                 ** achieved state and the initial state at that node.
115                 ** We don't perturb the endpoint node
116                 ** *** */
117
118                for axis in 0..OT {
119                    /* ***
120                     ** 2.A. Perturb the i-th node
121                     ** *** */
122                    let mut next_node = self.targets[i].into();
123                    next_node[axis].desired_value += next_node[axis].tolerance;
124                    /* ***
125                     ** 2.b. Rerun the targeter from the previous node to this one
126                     ** Note that because the first initial_state is x0, the i-th "initial state"
127                     ** is the initial state to reach the i-th node.
128                     ** *** */
129                    let inner_tgt_a = Targeter::delta_v(self.prop, next_node);
130                    let inner_sol_a = inner_tgt_a
131                        .try_achieve_dual(
132                            initial_states[i],
133                            initial_states[i].epoch(),
134                            self.targets[i].epoch(),
135                            almanac.clone(),
136                        )
137                        .context(TargetingSnafu { segment: i })?;
138
139                    // ∂Δv_x / ∂r_x
140                    outer_jacobian[(3 * i, OT * i + axis)] = (inner_sol_a.correction[0]
141                        - self.all_dvs[i][0])
142                        / next_node[axis].tolerance;
143                    // ∂Δv_y / ∂r_x
144                    outer_jacobian[(3 * i + 1, OT * i + axis)] = (inner_sol_a.correction[1]
145                        - self.all_dvs[i][1])
146                        / next_node[axis].tolerance;
147                    // ∂Δv_z / ∂r_x
148                    outer_jacobian[(3 * i + 2, OT * i + axis)] = (inner_sol_a.correction[2]
149                        - self.all_dvs[i][2])
150                        / next_node[axis].tolerance;
151
152                    /* ***
153                     ** 2.C. Rerun the targeter from the new state at the perturbed node to the next unpertubed node
154                     ** *** */
155                    let inner_tgt_b = Targeter::delta_v(self.prop, self.targets[i + 1].into());
156                    let inner_sol_b = inner_tgt_b
157                        .try_achieve_dual(
158                            inner_sol_a.achieved_state,
159                            inner_sol_a.achieved_state.epoch(),
160                            self.targets[i + 1].epoch(),
161                            almanac.clone(),
162                        )
163                        .context(TargetingSnafu { segment: i })?;
164
165                    // Compute the partials wrt the next Δv
166                    // ∂Δv_x / ∂r_x
167                    outer_jacobian[(3 * (i + 1), OT * i + axis)] = (inner_sol_b.correction[0]
168                        - self.all_dvs[i + 1][0])
169                        / next_node[axis].tolerance;
170                    // ∂Δv_y / ∂r_x
171                    outer_jacobian[(3 * (i + 1) + 1, OT * i + axis)] = (inner_sol_b.correction[1]
172                        - self.all_dvs[i + 1][1])
173                        / next_node[axis].tolerance;
174                    // ∂Δv_z / ∂r_x
175                    outer_jacobian[(3 * (i + 1) + 2, OT * i + axis)] = (inner_sol_b.correction[2]
176                        - self.all_dvs[i + 1][2])
177                        / next_node[axis].tolerance;
178
179                    /* ***
180                     ** 2.D. Compute the difference between the arrival and departure velocities and node i+1
181                     ** *** */
182                    if i < self.targets.len() - 3 {
183                        let dv_ip1 = inner_sol_b.achieved_state.orbit.velocity_km_s
184                            - initial_states[i + 2].orbit.velocity_km_s;
185                        // ∂Δv_x / ∂r_x
186                        outer_jacobian[(3 * (i + 2), OT * i + axis)] =
187                            dv_ip1[0] / next_node[axis].tolerance;
188                        // ∂Δv_y / ∂r_x
189                        outer_jacobian[(3 * (i + 2) + 1, OT * i + axis)] =
190                            dv_ip1[1] / next_node[axis].tolerance;
191                        // ∂Δv_z / ∂r_x
192                        outer_jacobian[(3 * (i + 2) + 2, OT * i + axis)] =
193                            dv_ip1[2] / next_node[axis].tolerance;
194                    }
195                }
196            }
197
198            // Build the cost vector
199            for i in 0..self.targets.len() {
200                for j in 0..3 {
201                    cost_vec[3 * i + j] = self.all_dvs[i][j];
202                }
203            }
204
205            // Compute the cost -- used to stop the algorithm if it does not change much.
206            let new_cost = match cost {
207                CostFunction::MinimumEnergy => cost_vec.dot(&cost_vec),
208                CostFunction::MinimumFuel => cost_vec.dot(&cost_vec).sqrt(),
209            };
210
211            // If the new cost is greater than the previous one, then the cost improvement is negative.
212            let cost_improvmt = (prev_cost - new_cost) / new_cost.abs();
213            // If the cost does not improve by more than threshold stop iteration
214            match cost {
215                CostFunction::MinimumEnergy => info!(
216                    "Multiple shooting iteration #{}\t\tCost = {:.3} km^2/s^2\timprovement = {:.2}%",
217                    it,
218                    new_cost,
219                    100.0 * cost_improvmt
220                ),
221                CostFunction::MinimumFuel => info!(
222                    "Multiple shooting iteration #{}\t\tCost = {:.3} km/s\timprovement = {:.2}%",
223                    it,
224                    new_cost,
225                    100.0 * cost_improvmt
226                ),
227            };
228            if cost_improvmt.abs() < self.improvement_threshold {
229                info!("Improvement below desired threshold. Running targeter on computed nodes.");
230
231                /* ***
232                 ** FIN -- Check the impulsive burns work and return all targeter solutions
233                 ** *** */
234                let mut ms_sol = MultipleShootingSolution {
235                    x0: self.x0,
236                    xf: self.xf,
237                    nodes: self.targets.clone(),
238                    solutions: Vec::with_capacity(self.targets.len()),
239                };
240                let mut initial_states = Vec::with_capacity(self.targets.len());
241                initial_states.push(self.x0);
242
243                for (i, node) in self.targets.iter().enumerate() {
244                    // Run the unpertubed targeter
245                    let tgt = Targeter::delta_v(self.prop, (*node).into());
246                    let sol = tgt
247                        .try_achieve_dual(
248                            initial_states[i],
249                            initial_states[i].epoch(),
250                            node.epoch(),
251                            almanac.clone(),
252                        )
253                        .context(TargetingSnafu { segment: i })?;
254                    initial_states.push(sol.achieved_state);
255                    ms_sol.solutions.push(sol);
256                }
257
258                return Ok(ms_sol);
259            }
260
261            prev_cost = new_cost;
262            // 2. Solve for the next position of the nodes using a pseudo inverse.
263            let inv_jac =
264                pseudo_inverse!(&outer_jacobian).context(TargetingSnafu { segment: 0_usize })?;
265            let delta_r = inv_jac * cost_vec;
266            // 3. Apply the correction to the node positions and iterator
267            let node_vector = -delta_r;
268            for (i, val) in node_vector.iter().enumerate() {
269                let node_no = i / 3;
270                let component_no = i % OT;
271                self.targets[node_no].update_component(component_no, *val);
272            }
273            self.current_iteration += 1;
274        }
275        Err(MultipleShootingError::TargetingError {
276            segment: 0_usize,
277            source: Box::new(TargetingError::TooManyIterations),
278        })
279    }
280}
281
282impl<T: MultishootNode<OT>, const VT: usize, const OT: usize> fmt::Display
283    for MultipleShooting<'_, T, VT, OT>
284{
285    #[allow(clippy::or_fun_call, clippy::clone_on_copy)]
286    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
287        let mut nodemsg = String::from("");
288        // Add the starting point too
289        nodemsg.push_str(&format!(
290            "[{:.3}, {:.3}, {:.3}, {}, {}, {}, {}, {}, {}],\n",
291            self.x0.orbit.radius_km.x,
292            self.x0.orbit.radius_km.y,
293            self.x0.orbit.radius_km.z,
294            self.current_iteration,
295            0.0,
296            0.0,
297            0.0,
298            0.0,
299            0
300        ));
301
302        for (i, node) in self.targets.iter().enumerate() {
303            let objectives: [Objective; OT] = (*node).into();
304            let mut this_nodemsg = String::from("");
305            for obj in &objectives {
306                this_nodemsg.push_str(&format!("{:.3}, ", obj.desired_value));
307            }
308            let mut this_costmsg = String::from("");
309            let dv = match self.all_dvs.get(i) {
310                Some(dv) => dv.clone(),
311                None => SVector::<f64, VT>::zeros(),
312            };
313            for val in &dv {
314                this_costmsg.push_str(&format!("{val}, "));
315            }
316            if VT == 3 {
317                // Add the norm of the control
318                this_costmsg.push_str(&format!("{}, ", dv.norm()));
319            }
320            nodemsg.push_str(&format!(
321                "[{}{}, {}{}],\n",
322                this_nodemsg,
323                self.current_iteration,
324                this_nodemsg,
325                i + 1
326            ));
327        }
328        write!(f, "{nodemsg}")
329    }
330}
331
332#[derive(Clone, Debug)]
333pub struct MultipleShootingSolution<T: MultishootNode<O>, const O: usize> {
334    pub x0: Spacecraft,
335    pub xf: Orbit,
336    pub nodes: Vec<T>,
337    pub solutions: Vec<TargeterSolution<3, O>>,
338}
339
340impl<T: MultishootNode<O>, const O: usize> fmt::Display for MultipleShootingSolution<T, O> {
341    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
342        for sol in &self.solutions {
343            write!(f, "{sol}")?;
344        }
345        Ok(())
346    }
347}
348
349impl<T: MultishootNode<O>, const O: usize> MultipleShootingSolution<T, O> {
350    /// Allows building the trajectories between different nodes
351    /// This will rebuild the targeters and apply the solutions sequentially
352    pub fn build_trajectories(
353        &self,
354        prop: &Propagator<SpacecraftDynamics>,
355        almanac: Arc<Almanac>,
356    ) -> Result<Vec<Trajectory>, MultipleShootingError> {
357        let mut trajz = Vec::with_capacity(self.nodes.len());
358
359        for (i, node) in self.nodes.iter().copied().enumerate() {
360            let (_, traj) = Targeter::delta_v(prop, node.into())
361                .apply_with_traj(&self.solutions[i], almanac.clone())
362                .context(TargetingSnafu { segment: i })?;
363            trajz.push(traj);
364        }
365
366        Ok(trajz)
367    }
368}