nyx_space/md/opti/multipleshooting/
multishoot.rs1use log::info;
20use snafu::ResultExt;
21
22pub use super::CostFunction;
23use super::{MultipleShootingError, TargetingSnafu};
24use crate::linalg::{DMatrix, DVector, SVector};
25use crate::md::opti::solution::TargeterSolution;
26use crate::md::targeter::Targeter;
27use crate::md::{prelude::*, TargetingError};
28use crate::pseudo_inverse;
29use crate::{Orbit, Spacecraft};
30
31use std::fmt;
32
33pub trait MultishootNode<const O: usize>: Copy + Into<[Objective; O]> {
34 fn epoch(&self) -> Epoch;
35 fn update_component(&mut self, component: usize, add_val: f64);
36}
37
38pub struct MultipleShooting<'a, T: MultishootNode<OT>, const VT: usize, const OT: usize> {
43 pub prop: &'a Propagator<SpacecraftDynamics>,
45 pub targets: Vec<T>,
47 pub x0: Spacecraft,
49 pub xf: Orbit,
51 pub current_iteration: usize,
52 pub max_iterations: usize,
54 pub improvement_threshold: f64,
58 pub variables: [Variable; VT],
60 pub all_dvs: Vec<SVector<f64, VT>>,
61}
62
63impl<T: MultishootNode<OT>, const VT: usize, const OT: usize> MultipleShooting<'_, T, VT, OT> {
64 pub fn solve(
66 &mut self,
67 cost: CostFunction,
68 almanac: Arc<Almanac>,
69 ) -> Result<MultipleShootingSolution<T, OT>, MultipleShootingError> {
70 let mut prev_cost = 1e12; for it in 0..self.max_iterations {
72 let mut initial_states = Vec::with_capacity(self.targets.len());
73 initial_states.push(self.x0);
74 let mut outer_jacobian =
75 DMatrix::from_element(3 * self.targets.len(), OT * (self.targets.len() - 1), 0.0);
76 let mut cost_vec = DVector::from_element(3 * self.targets.len(), 0.0);
77
78 self.all_dvs = Vec::with_capacity(self.all_dvs.len());
80
81 for i in 0..self.targets.len() {
82 let tgt = Targeter {
86 prop: self.prop,
87 objectives: self.targets[i].into(),
88 variables: self.variables,
89 iterations: 100,
90 objective_frame: None,
91 correction_frame: None,
92 };
93 let sol = tgt
94 .try_achieve_dual(
95 initial_states[i],
96 initial_states[i].epoch(),
97 self.targets[i].epoch(),
98 almanac.clone(),
99 )
100 .context(TargetingSnafu { segment: i })?;
101
102 let nominal_delta_v = sol.correction;
103
104 self.all_dvs.push(nominal_delta_v);
105 initial_states.push(sol.achieved_state);
107 }
108 for i in 0..(self.targets.len() - 1) {
111 for axis in 0..OT {
119 let mut next_node = self.targets[i].into();
123 next_node[axis].desired_value += next_node[axis].tolerance;
124 let inner_tgt_a = Targeter::delta_v(self.prop, next_node);
130 let inner_sol_a = inner_tgt_a
131 .try_achieve_dual(
132 initial_states[i],
133 initial_states[i].epoch(),
134 self.targets[i].epoch(),
135 almanac.clone(),
136 )
137 .context(TargetingSnafu { segment: i })?;
138
139 outer_jacobian[(3 * i, OT * i + axis)] = (inner_sol_a.correction[0]
141 - self.all_dvs[i][0])
142 / next_node[axis].tolerance;
143 outer_jacobian[(3 * i + 1, OT * i + axis)] = (inner_sol_a.correction[1]
145 - self.all_dvs[i][1])
146 / next_node[axis].tolerance;
147 outer_jacobian[(3 * i + 2, OT * i + axis)] = (inner_sol_a.correction[2]
149 - self.all_dvs[i][2])
150 / next_node[axis].tolerance;
151
152 let inner_tgt_b = Targeter::delta_v(self.prop, self.targets[i + 1].into());
156 let inner_sol_b = inner_tgt_b
157 .try_achieve_dual(
158 inner_sol_a.achieved_state,
159 inner_sol_a.achieved_state.epoch(),
160 self.targets[i + 1].epoch(),
161 almanac.clone(),
162 )
163 .context(TargetingSnafu { segment: i })?;
164
165 outer_jacobian[(3 * (i + 1), OT * i + axis)] = (inner_sol_b.correction[0]
168 - self.all_dvs[i + 1][0])
169 / next_node[axis].tolerance;
170 outer_jacobian[(3 * (i + 1) + 1, OT * i + axis)] = (inner_sol_b.correction[1]
172 - self.all_dvs[i + 1][1])
173 / next_node[axis].tolerance;
174 outer_jacobian[(3 * (i + 1) + 2, OT * i + axis)] = (inner_sol_b.correction[2]
176 - self.all_dvs[i + 1][2])
177 / next_node[axis].tolerance;
178
179 if i < self.targets.len() - 3 {
183 let dv_ip1 = inner_sol_b.achieved_state.orbit.velocity_km_s
184 - initial_states[i + 2].orbit.velocity_km_s;
185 outer_jacobian[(3 * (i + 2), OT * i + axis)] =
187 dv_ip1[0] / next_node[axis].tolerance;
188 outer_jacobian[(3 * (i + 2) + 1, OT * i + axis)] =
190 dv_ip1[1] / next_node[axis].tolerance;
191 outer_jacobian[(3 * (i + 2) + 2, OT * i + axis)] =
193 dv_ip1[2] / next_node[axis].tolerance;
194 }
195 }
196 }
197
198 for i in 0..self.targets.len() {
200 for j in 0..3 {
201 cost_vec[3 * i + j] = self.all_dvs[i][j];
202 }
203 }
204
205 let new_cost = match cost {
207 CostFunction::MinimumEnergy => cost_vec.dot(&cost_vec),
208 CostFunction::MinimumFuel => cost_vec.dot(&cost_vec).sqrt(),
209 };
210
211 let cost_improvmt = (prev_cost - new_cost) / new_cost.abs();
213 match cost {
215 CostFunction::MinimumEnergy => info!(
216 "Multiple shooting iteration #{}\t\tCost = {:.3} km^2/s^2\timprovement = {:.2}%",
217 it,
218 new_cost,
219 100.0 * cost_improvmt
220 ),
221 CostFunction::MinimumFuel => info!(
222 "Multiple shooting iteration #{}\t\tCost = {:.3} km/s\timprovement = {:.2}%",
223 it,
224 new_cost,
225 100.0 * cost_improvmt
226 ),
227 };
228 if cost_improvmt.abs() < self.improvement_threshold {
229 info!("Improvement below desired threshold. Running targeter on computed nodes.");
230
231 let mut ms_sol = MultipleShootingSolution {
235 x0: self.x0,
236 xf: self.xf,
237 nodes: self.targets.clone(),
238 solutions: Vec::with_capacity(self.targets.len()),
239 };
240 let mut initial_states = Vec::with_capacity(self.targets.len());
241 initial_states.push(self.x0);
242
243 for (i, node) in self.targets.iter().enumerate() {
244 let tgt = Targeter::delta_v(self.prop, (*node).into());
246 let sol = tgt
247 .try_achieve_dual(
248 initial_states[i],
249 initial_states[i].epoch(),
250 node.epoch(),
251 almanac.clone(),
252 )
253 .context(TargetingSnafu { segment: i })?;
254 initial_states.push(sol.achieved_state);
255 ms_sol.solutions.push(sol);
256 }
257
258 return Ok(ms_sol);
259 }
260
261 prev_cost = new_cost;
262 let inv_jac =
264 pseudo_inverse!(&outer_jacobian).context(TargetingSnafu { segment: 0_usize })?;
265 let delta_r = inv_jac * cost_vec;
266 let node_vector = -delta_r;
268 for (i, val) in node_vector.iter().enumerate() {
269 let node_no = i / 3;
270 let component_no = i % OT;
271 self.targets[node_no].update_component(component_no, *val);
272 }
273 self.current_iteration += 1;
274 }
275 Err(MultipleShootingError::TargetingError {
276 segment: 0_usize,
277 source: Box::new(TargetingError::TooManyIterations),
278 })
279 }
280}
281
282impl<T: MultishootNode<OT>, const VT: usize, const OT: usize> fmt::Display
283 for MultipleShooting<'_, T, VT, OT>
284{
285 #[allow(clippy::or_fun_call, clippy::clone_on_copy)]
286 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
287 let mut nodemsg = String::from("");
288 nodemsg.push_str(&format!(
290 "[{:.3}, {:.3}, {:.3}, {}, {}, {}, {}, {}, {}],\n",
291 self.x0.orbit.radius_km.x,
292 self.x0.orbit.radius_km.y,
293 self.x0.orbit.radius_km.z,
294 self.current_iteration,
295 0.0,
296 0.0,
297 0.0,
298 0.0,
299 0
300 ));
301
302 for (i, node) in self.targets.iter().enumerate() {
303 let objectives: [Objective; OT] = (*node).into();
304 let mut this_nodemsg = String::from("");
305 for obj in &objectives {
306 this_nodemsg.push_str(&format!("{:.3}, ", obj.desired_value));
307 }
308 let mut this_costmsg = String::from("");
309 let dv = match self.all_dvs.get(i) {
310 Some(dv) => dv.clone(),
311 None => SVector::<f64, VT>::zeros(),
312 };
313 for val in &dv {
314 this_costmsg.push_str(&format!("{val}, "));
315 }
316 if VT == 3 {
317 this_costmsg.push_str(&format!("{}, ", dv.norm()));
319 }
320 nodemsg.push_str(&format!(
321 "[{}{}, {}{}],\n",
322 this_nodemsg,
323 self.current_iteration,
324 this_nodemsg,
325 i + 1
326 ));
327 }
328 write!(f, "{nodemsg}")
329 }
330}
331
332#[derive(Clone, Debug)]
333pub struct MultipleShootingSolution<T: MultishootNode<O>, const O: usize> {
334 pub x0: Spacecraft,
335 pub xf: Orbit,
336 pub nodes: Vec<T>,
337 pub solutions: Vec<TargeterSolution<3, O>>,
338}
339
340impl<T: MultishootNode<O>, const O: usize> fmt::Display for MultipleShootingSolution<T, O> {
341 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
342 for sol in &self.solutions {
343 write!(f, "{sol}")?;
344 }
345 Ok(())
346 }
347}
348
349impl<T: MultishootNode<O>, const O: usize> MultipleShootingSolution<T, O> {
350 pub fn build_trajectories(
353 &self,
354 prop: &Propagator<SpacecraftDynamics>,
355 almanac: Arc<Almanac>,
356 ) -> Result<Vec<Trajectory>, MultipleShootingError> {
357 let mut trajz = Vec::with_capacity(self.nodes.len());
358
359 for (i, node) in self.nodes.iter().copied().enumerate() {
360 let (_, traj) = Targeter::delta_v(prop, node.into())
361 .apply_with_traj(&self.solutions[i], almanac.clone())
362 .context(TargetingSnafu { segment: i })?;
363 trajz.push(traj);
364 }
365
366 Ok(trajz)
367 }
368}