// nyx_space/md/opti/multipleshooting/multishoot.rs
use snafu::ResultExt;
20
21pub use super::CostFunction;
22use super::{MultipleShootingError, TargetingSnafu};
23use crate::linalg::{DMatrix, DVector, SVector};
24use crate::md::opti::solution::TargeterSolution;
25use crate::md::targeter::Targeter;
26use crate::md::{prelude::*, TargetingError};
27use crate::pseudo_inverse;
28use crate::{Orbit, Spacecraft};
29
30use std::fmt;
31
32pub trait MultishootNode<const O: usize>: Copy + Into<[Objective; O]> {
33 fn epoch(&self) -> Epoch;
34 fn update_component(&mut self, component: usize, add_val: f64);
35}
36
37pub struct MultipleShooting<'a, T: MultishootNode<OT>, const VT: usize, const OT: usize> {
42 pub prop: &'a Propagator<SpacecraftDynamics>,
44 pub targets: Vec<T>,
46 pub x0: Spacecraft,
48 pub xf: Orbit,
50 pub current_iteration: usize,
51 pub max_iterations: usize,
53 pub improvement_threshold: f64,
57 pub variables: [Variable; VT],
59 pub all_dvs: Vec<SVector<f64, VT>>,
60}
61
62impl<T: MultishootNode<OT>, const VT: usize, const OT: usize> MultipleShooting<'_, T, VT, OT> {
63 pub fn solve(
65 &mut self,
66 cost: CostFunction,
67 almanac: Arc<Almanac>,
68 ) -> Result<MultipleShootingSolution<T, OT>, MultipleShootingError> {
69 let mut prev_cost = 1e12; for it in 0..self.max_iterations {
71 let mut initial_states = Vec::with_capacity(self.targets.len());
72 initial_states.push(self.x0);
73 let mut outer_jacobian =
74 DMatrix::from_element(3 * self.targets.len(), OT * (self.targets.len() - 1), 0.0);
75 let mut cost_vec = DVector::from_element(3 * self.targets.len(), 0.0);
76
77 self.all_dvs = Vec::with_capacity(self.all_dvs.len());
79
80 for i in 0..self.targets.len() {
81 let tgt = Targeter {
85 prop: self.prop,
86 objectives: self.targets[i].into(),
87 variables: self.variables,
88 iterations: 100,
89 objective_frame: None,
90 correction_frame: None,
91 };
92 let sol = tgt
93 .try_achieve_dual(
94 initial_states[i],
95 initial_states[i].epoch(),
96 self.targets[i].epoch(),
97 almanac.clone(),
98 )
99 .context(TargetingSnafu { segment: i })?;
100
101 let nominal_delta_v = sol.correction;
102
103 self.all_dvs.push(nominal_delta_v);
104 initial_states.push(sol.achieved_state);
106 }
107 for i in 0..(self.targets.len() - 1) {
110 for axis in 0..OT {
118 let mut next_node = self.targets[i].into();
122 next_node[axis].desired_value += next_node[axis].tolerance;
123 let inner_tgt_a = Targeter::delta_v(self.prop, next_node);
129 let inner_sol_a = inner_tgt_a
130 .try_achieve_dual(
131 initial_states[i],
132 initial_states[i].epoch(),
133 self.targets[i].epoch(),
134 almanac.clone(),
135 )
136 .context(TargetingSnafu { segment: i })?;
137
138 outer_jacobian[(3 * i, OT * i + axis)] = (inner_sol_a.correction[0]
140 - self.all_dvs[i][0])
141 / next_node[axis].tolerance;
142 outer_jacobian[(3 * i + 1, OT * i + axis)] = (inner_sol_a.correction[1]
144 - self.all_dvs[i][1])
145 / next_node[axis].tolerance;
146 outer_jacobian[(3 * i + 2, OT * i + axis)] = (inner_sol_a.correction[2]
148 - self.all_dvs[i][2])
149 / next_node[axis].tolerance;
150
151 let inner_tgt_b = Targeter::delta_v(self.prop, self.targets[i + 1].into());
155 let inner_sol_b = inner_tgt_b
156 .try_achieve_dual(
157 inner_sol_a.achieved_state,
158 inner_sol_a.achieved_state.epoch(),
159 self.targets[i + 1].epoch(),
160 almanac.clone(),
161 )
162 .context(TargetingSnafu { segment: i })?;
163
164 outer_jacobian[(3 * (i + 1), OT * i + axis)] = (inner_sol_b.correction[0]
167 - self.all_dvs[i + 1][0])
168 / next_node[axis].tolerance;
169 outer_jacobian[(3 * (i + 1) + 1, OT * i + axis)] = (inner_sol_b.correction[1]
171 - self.all_dvs[i + 1][1])
172 / next_node[axis].tolerance;
173 outer_jacobian[(3 * (i + 1) + 2, OT * i + axis)] = (inner_sol_b.correction[2]
175 - self.all_dvs[i + 1][2])
176 / next_node[axis].tolerance;
177
178 if i < self.targets.len() - 3 {
182 let dv_ip1 = inner_sol_b.achieved_state.orbit.velocity_km_s
183 - initial_states[i + 2].orbit.velocity_km_s;
184 outer_jacobian[(3 * (i + 2), OT * i + axis)] =
186 dv_ip1[0] / next_node[axis].tolerance;
187 outer_jacobian[(3 * (i + 2) + 1, OT * i + axis)] =
189 dv_ip1[1] / next_node[axis].tolerance;
190 outer_jacobian[(3 * (i + 2) + 2, OT * i + axis)] =
192 dv_ip1[2] / next_node[axis].tolerance;
193 }
194 }
195 }
196
197 for i in 0..self.targets.len() {
199 for j in 0..3 {
200 cost_vec[3 * i + j] = self.all_dvs[i][j];
201 }
202 }
203
204 let new_cost = match cost {
206 CostFunction::MinimumEnergy => cost_vec.dot(&cost_vec),
207 CostFunction::MinimumFuel => cost_vec.dot(&cost_vec).sqrt(),
208 };
209
210 let cost_improvmt = (prev_cost - new_cost) / new_cost.abs();
212 match cost {
214 CostFunction::MinimumEnergy => info!(
215 "Multiple shooting iteration #{}\t\tCost = {:.3} km^2/s^2\timprovement = {:.2}%",
216 it,
217 new_cost,
218 100.0 * cost_improvmt
219 ),
220 CostFunction::MinimumFuel => info!(
221 "Multiple shooting iteration #{}\t\tCost = {:.3} km/s\timprovement = {:.2}%",
222 it,
223 new_cost,
224 100.0 * cost_improvmt
225 ),
226 };
227 if cost_improvmt.abs() < self.improvement_threshold {
228 info!("Improvement below desired threshold. Running targeter on computed nodes.");
229
230 let mut ms_sol = MultipleShootingSolution {
234 x0: self.x0,
235 xf: self.xf,
236 nodes: self.targets.clone(),
237 solutions: Vec::with_capacity(self.targets.len()),
238 };
239 let mut initial_states = Vec::with_capacity(self.targets.len());
240 initial_states.push(self.x0);
241
242 for (i, node) in self.targets.iter().enumerate() {
243 let tgt = Targeter::delta_v(self.prop, (*node).into());
245 let sol = tgt
246 .try_achieve_dual(
247 initial_states[i],
248 initial_states[i].epoch(),
249 node.epoch(),
250 almanac.clone(),
251 )
252 .context(TargetingSnafu { segment: i })?;
253 initial_states.push(sol.achieved_state);
254 ms_sol.solutions.push(sol);
255 }
256
257 return Ok(ms_sol);
258 }
259
260 prev_cost = new_cost;
261 let inv_jac =
263 pseudo_inverse!(&outer_jacobian).context(TargetingSnafu { segment: 0_usize })?;
264 let delta_r = inv_jac * cost_vec;
265 let node_vector = -delta_r;
267 for (i, val) in node_vector.iter().enumerate() {
268 let node_no = i / 3;
269 let component_no = i % OT;
270 self.targets[node_no].update_component(component_no, *val);
271 }
272 self.current_iteration += 1;
273 }
274 Err(MultipleShootingError::TargetingError {
275 segment: 0_usize,
276 source: TargetingError::TooManyIterations,
277 })
278 }
279}
280
281impl<T: MultishootNode<OT>, const VT: usize, const OT: usize> fmt::Display
282 for MultipleShooting<'_, T, VT, OT>
283{
284 #[allow(clippy::or_fun_call, clippy::clone_on_copy)]
285 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
286 let mut nodemsg = String::from("");
287 nodemsg.push_str(&format!(
289 "[{:.3}, {:.3}, {:.3}, {}, {}, {}, {}, {}, {}],\n",
290 self.x0.orbit.radius_km.x,
291 self.x0.orbit.radius_km.y,
292 self.x0.orbit.radius_km.z,
293 self.current_iteration,
294 0.0,
295 0.0,
296 0.0,
297 0.0,
298 0
299 ));
300
301 for (i, node) in self.targets.iter().enumerate() {
302 let objectives: [Objective; OT] = (*node).into();
303 let mut this_nodemsg = String::from("");
304 for obj in &objectives {
305 this_nodemsg.push_str(&format!("{:.3}, ", obj.desired_value));
306 }
307 let mut this_costmsg = String::from("");
308 let dv = match self.all_dvs.get(i) {
309 Some(dv) => dv.clone(),
310 None => SVector::<f64, VT>::zeros(),
311 };
312 for val in &dv {
313 this_costmsg.push_str(&format!("{val}, "));
314 }
315 if VT == 3 {
316 this_costmsg.push_str(&format!("{}, ", dv.norm()));
318 }
319 nodemsg.push_str(&format!(
320 "[{}{}, {}{}],\n",
321 this_nodemsg,
322 self.current_iteration,
323 this_nodemsg,
324 i + 1
325 ));
326 }
327 write!(f, "{nodemsg}")
328 }
329}
330
331#[derive(Clone, Debug)]
332pub struct MultipleShootingSolution<T: MultishootNode<O>, const O: usize> {
333 pub x0: Spacecraft,
334 pub xf: Orbit,
335 pub nodes: Vec<T>,
336 pub solutions: Vec<TargeterSolution<3, O>>,
337}
338
339impl<T: MultishootNode<O>, const O: usize> fmt::Display for MultipleShootingSolution<T, O> {
340 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
341 for sol in &self.solutions {
342 write!(f, "{sol}")?;
343 }
344 Ok(())
345 }
346}
347
348impl<T: MultishootNode<O>, const O: usize> MultipleShootingSolution<T, O> {
349 pub fn build_trajectories(
352 &self,
353 prop: &Propagator<SpacecraftDynamics>,
354 almanac: Arc<Almanac>,
355 ) -> Result<Vec<Trajectory>, MultipleShootingError> {
356 let mut trajz = Vec::with_capacity(self.nodes.len());
357
358 for (i, node) in self.nodes.iter().copied().enumerate() {
359 let (_, traj) = Targeter::delta_v(prop, node.into())
360 .apply_with_traj(&self.solutions[i], almanac.clone())
361 .context(TargetingSnafu { segment: i })?;
362 trajz.push(traj);
363 }
364
365 Ok(trajz)
366 }
367}