1use super::details::{EventArc, EventDetails, EventEdge};
20use crate::errors::{EventError, EventTrajSnafu};
21use crate::linalg::allocator::Allocator;
22use crate::linalg::DefaultAllocator;
23use crate::md::prelude::{Interpolatable, Traj};
24use crate::md::EventEvaluator;
25use crate::time::{Duration, Epoch, TimeSeries, Unit};
26use anise::almanac::Almanac;
27use rayon::prelude::*;
28use snafu::ResultExt;
29use std::iter::Iterator;
30use std::sync::mpsc::channel;
31use std::sync::Arc;
32
33impl<S: Interpolatable> Traj<S>
34where
35 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
36{
37 #[allow(clippy::identity_op)]
39 pub fn find_bracketed<E>(
40 &self,
41 start: Epoch,
42 end: Epoch,
43 event: &E,
44 almanac: Arc<Almanac>,
45 ) -> Result<EventDetails<S>, EventError>
46 where
47 E: EventEvaluator<S>,
48 {
49 let max_iter = 50;
50
51 let has_converged =
53 |xa: f64, xb: f64| (xa - xb).abs() <= event.epoch_precision().to_seconds();
54 let arrange = |a: f64, ya: f64, b: f64, yb: f64| {
55 if ya.abs() > yb.abs() {
56 (a, ya, b, yb)
57 } else {
58 (b, yb, a, ya)
59 }
60 };
61
62 let xa_e = start;
63 let xb_e = end;
64
65 let mut xa = 0.0;
67 let mut xb = (xb_e - xa_e).to_seconds();
68 let ya_state = self.at(xa_e).context(EventTrajSnafu {})?;
70 let yb_state = self.at(xb_e).context(EventTrajSnafu {})?;
71 let mut ya = event.eval(&ya_state, almanac.clone())?;
72 let mut yb = event.eval(&yb_state, almanac.clone())?;
73
74 if ya.abs() <= event.value_precision().abs() {
76 debug!(
77 "{event} -- found with |{ya}| < {} @ {xa_e}",
78 event.value_precision().abs()
79 );
80 return EventDetails::new(ya_state, ya, event, self, almanac.clone());
81 } else if yb.abs() <= event.value_precision().abs() {
82 debug!(
83 "{event} -- found with |{yb}| < {} @ {xb_e}",
84 event.value_precision().abs()
85 );
86 return EventDetails::new(yb_state, yb, event, self, almanac.clone());
87 }
88
89 let (mut xc, mut yc, mut xd) = (xa, ya, xa);
93 let mut flag = true;
94
95 for _ in 0..max_iter {
96 if ya.abs() < event.value_precision().abs() {
97 let state = self.at(xa_e + xa * Unit::Second).unwrap();
99 debug!(
100 "{event} -- found with |{ya}| < {} @ {}",
101 event.value_precision().abs(),
102 state.epoch(),
103 );
104 return EventDetails::new(state, ya, event, self, almanac.clone());
105 }
106 if yb.abs() < event.value_precision().abs() {
107 let state = self.at(xa_e + xb * Unit::Second).unwrap();
109 debug!(
110 "{event} -- found with |{yb}| < {} @ {}",
111 event.value_precision().abs(),
112 state.epoch()
113 );
114 return EventDetails::new(state, yb, event, self, almanac.clone());
115 }
116 if has_converged(xa, xb) {
117 return Err(EventError::NotFound {
119 start,
120 end,
121 event: format!("{event}"),
122 });
123 }
124 let mut s = if (ya - yc).abs() > f64::EPSILON && (yb - yc).abs() > f64::EPSILON {
125 xa * yb * yc / ((ya - yb) * (ya - yc))
126 + xb * ya * yc / ((yb - ya) * (yb - yc))
127 + xc * ya * yb / ((yc - ya) * (yc - yb))
128 } else {
129 xb - yb * (xb - xa) / (yb - ya)
130 };
131 let cond1 = (s - xb) * (s - (3.0 * xa + xb) / 4.0) > 0.0;
132 let cond2 = flag && (s - xb).abs() >= (xb - xc).abs() / 2.0;
133 let cond3 = !flag && (s - xb).abs() >= (xc - xd).abs() / 2.0;
134 let cond4 = flag && has_converged(xb, xc);
135 let cond5 = !flag && has_converged(xc, xd);
136 if cond1 || cond2 || cond3 || cond4 || cond5 {
137 s = (xa + xb) / 2.0;
138 flag = true;
139 } else {
140 flag = false;
141 }
142 let next_try = self
143 .at(xa_e + s * Unit::Second)
144 .context(EventTrajSnafu {})?;
145 let ys = event.eval(&next_try, almanac.clone())?;
146 xd = xc;
147 xc = xb;
148 yc = yb;
149 if ya * ys < 0.0 {
150 let next_try = self
152 .at(xa_e + xa * Unit::Second)
153 .context(EventTrajSnafu {})?;
154 let ya_p = event.eval(&next_try, almanac.clone())?;
155 let (_a, _ya, _b, _yb) = arrange(xa, ya_p, s, ys);
156 {
157 xa = _a;
158 ya = _ya;
159 xb = _b;
160 yb = _yb;
161 }
162 } else {
163 let next_try = self
165 .at(xa_e + xb * Unit::Second)
166 .context(EventTrajSnafu {})?;
167 let yb_p = event.eval(&next_try, almanac.clone())?;
168 let (_a, _ya, _b, _yb) = arrange(s, ys, xb, yb_p);
169 {
170 xa = _a;
171 ya = _ya;
172 xb = _b;
173 yb = _yb;
174 }
175 }
176 }
177 error!("Brent solver failed after {max_iter} iterations");
178 Err(EventError::NotFound {
179 start,
180 end,
181 event: format!("{event}"),
182 })
183 }
184
185 #[allow(clippy::identity_op)]
199 pub fn find<E>(
200 &self,
201 event: &E,
202 heuristic: Option<Duration>,
203 almanac: Arc<Almanac>,
204 ) -> Result<Vec<EventDetails<S>>, EventError>
205 where
206 E: EventEvaluator<S>,
207 {
208 let start_epoch = self.first().epoch();
209 let end_epoch = self.last().epoch();
210 if start_epoch == end_epoch {
211 return Err(EventError::NotFound {
212 start: start_epoch,
213 end: end_epoch,
214 event: format!("{event}"),
215 });
216 }
217 let heuristic = heuristic.unwrap_or((end_epoch - start_epoch) / 100);
218 info!("Searching for {event} with initial heuristic of {heuristic}");
219
220 let (sender, receiver) = channel();
221
222 let epochs: Vec<Epoch> = TimeSeries::inclusive(start_epoch, end_epoch, heuristic).collect();
223 epochs.into_par_iter().for_each_with(sender, |s, epoch| {
224 if let Ok(event_state) =
225 self.find_bracketed(epoch, epoch + heuristic, event, almanac.clone())
226 {
227 s.send(event_state).unwrap()
228 };
229 });
230
231 let mut states: Vec<_> = receiver.iter().collect();
232
233 if states.is_empty() {
234 warn!("Heuristic failed to find any {event} event, using slower approach");
235 match self.find_minmax(event, Unit::Second, almanac.clone()) {
238 Ok((min_event, max_event)) => {
239 let lower_min_epoch =
240 if min_event.epoch() - 1 * Unit::Millisecond < self.first().epoch() {
241 self.first().epoch()
242 } else {
243 min_event.epoch() - 1 * Unit::Millisecond
244 };
245
246 let lower_max_epoch =
247 if min_event.epoch() + 1 * Unit::Millisecond > self.last().epoch() {
248 self.last().epoch()
249 } else {
250 min_event.epoch() + 1 * Unit::Millisecond
251 };
252
253 let upper_min_epoch =
254 if max_event.epoch() - 1 * Unit::Millisecond < self.first().epoch() {
255 self.first().epoch()
256 } else {
257 max_event.epoch() - 1 * Unit::Millisecond
258 };
259
260 let upper_max_epoch =
261 if max_event.epoch() + 1 * Unit::Millisecond > self.last().epoch() {
262 self.last().epoch()
263 } else {
264 max_event.epoch() + 1 * Unit::Millisecond
265 };
266
267 if let Ok(event_state) = self.find_bracketed(
269 lower_min_epoch,
270 lower_max_epoch,
271 event,
272 almanac.clone(),
273 ) {
274 states.push(event_state);
275 };
276
277 if let Ok(event_state) = self.find_bracketed(
279 upper_min_epoch,
280 upper_max_epoch,
281 event,
282 almanac.clone(),
283 ) {
284 states.push(event_state);
285 };
286
287 if states.is_empty() {
289 return Err(EventError::NotFound {
290 start: start_epoch,
291 end: end_epoch,
292 event: format!("{event}"),
293 });
294 }
295 }
296 Err(_) => {
297 return Err(EventError::NotFound {
298 start: start_epoch,
299 end: end_epoch,
300 event: format!("{event}"),
301 });
302 }
303 };
304 }
305 states.sort_by(|s1, s2| s1.state.epoch().partial_cmp(&s2.state.epoch()).unwrap());
307 states.dedup();
308
309 match states.len() {
310 0 => info!("Event {event} not found"),
311 1 => info!("Event {event} found once on {}", states[0].state.epoch()),
312 _ => {
313 info!(
314 "Event {event} found {} times from {} until {}",
315 states.len(),
316 states.first().unwrap().state.epoch(),
317 states.last().unwrap().state.epoch()
318 )
319 }
320 };
321
322 Ok(states)
323 }
324
325 #[allow(clippy::identity_op)]
327 pub fn find_minmax<E>(
328 &self,
329 event: &E,
330 precision: Unit,
331 almanac: Arc<Almanac>,
332 ) -> Result<(S, S), EventError>
333 where
334 E: EventEvaluator<S>,
335 {
336 let step: Duration = 1 * precision;
337 let mut min_val = f64::INFINITY;
338 let mut max_val = f64::NEG_INFINITY;
339 let mut min_state = S::zeros();
340 let mut max_state = S::zeros();
341
342 let (sender, receiver) = channel();
343
344 let epochs: Vec<Epoch> =
345 TimeSeries::inclusive(self.first().epoch(), self.last().epoch(), step).collect();
346
347 epochs.into_par_iter().for_each_with(sender, |s, epoch| {
348 let state = self.at(epoch).unwrap();
350 if let Ok(this_eval) = event.eval(&state, almanac.clone()) {
351 s.send((this_eval, state)).unwrap();
352 }
353 });
354
355 let evald_states: Vec<_> = receiver.iter().collect();
356 for (this_eval, state) in evald_states {
357 if this_eval < min_val {
358 min_val = this_eval;
359 min_state = state;
360 }
361 if this_eval > max_val {
362 max_val = this_eval;
363 max_state = state;
364 }
365 }
366
367 Ok((min_state, max_state))
368 }
369
370 pub fn find_arcs<E>(
395 &self,
396 event: &E,
397 heuristic: Option<Duration>,
398 almanac: Arc<Almanac>,
399 ) -> Result<Vec<EventArc<S>>, EventError>
400 where
401 E: EventEvaluator<S>,
402 {
403 let mut events = match self.find(event, heuristic, almanac.clone()) {
404 Ok(events) => events,
405 Err(_) => {
406 let first_eval = event.eval(self.first(), almanac.clone())?;
409 let last_eval = event.eval(self.last(), almanac.clone())?;
410 if first_eval > 0.0 && last_eval > 0.0 {
411 vec![
415 EventDetails::new(*self.first(), first_eval, event, self, almanac.clone())?,
416 EventDetails::new(*self.last(), last_eval, event, self, almanac.clone())?,
417 ]
418 } else {
419 return Err(EventError::NotFound {
420 start: self.first().epoch(),
421 end: self.last().epoch(),
422 event: format!("{event}"),
423 });
424 }
425 }
426 };
427 events.sort_by_key(|event| event.state.epoch());
428
429 let mut arcs = Vec::new();
431
432 if events.is_empty() {
433 return Ok(arcs);
434 }
435
436 let mut prev_rise = if events[0].edge != EventEdge::Rising {
438 let value = event.eval(self.first(), almanac.clone())?;
439 Some(EventDetails::new(
440 *self.first(),
441 value,
442 event,
443 self,
444 almanac.clone(),
445 )?)
446 } else {
447 Some(events[0].clone())
448 };
449
450 let mut prev_fall = if events[0].edge == EventEdge::Falling {
451 Some(events[0].clone())
452 } else {
453 None
454 };
455
456 for event in events {
457 if event.edge == EventEdge::Rising {
458 if prev_rise.is_none() && prev_fall.is_none() {
459 prev_rise = Some(event.clone());
461 } else if prev_fall.is_some() {
462 if prev_rise.is_some() {
464 let arc = EventArc {
465 rise: prev_rise.clone().unwrap(),
466 fall: prev_fall.clone().unwrap(),
467 };
468 arcs.push(arc);
469 } else {
470 let arc = EventArc {
471 rise: event.clone(),
472 fall: prev_fall.clone().unwrap(),
473 };
474 arcs.push(arc);
475 }
476 prev_fall = None;
477 prev_rise = Some(event.clone());
479 }
480 } else if event.edge == EventEdge::Falling {
481 prev_fall = Some(event.clone());
482 }
483 }
484
485 if prev_rise.is_some() {
487 if prev_fall.is_some() {
488 let arc = EventArc {
489 rise: prev_rise.clone().unwrap(),
490 fall: prev_fall.clone().unwrap(),
491 };
492 arcs.push(arc);
493 } else {
494 let value = event.eval(self.last(), almanac.clone())?;
496 let fall = EventDetails::new(*self.last(), value, event, self, almanac.clone())?;
497 let arc = EventArc {
498 rise: prev_rise.clone().unwrap(),
499 fall,
500 };
501 arcs.push(arc);
502 }
503 }
504
505 Ok(arcs)
506 }
507}