1use super::details::{EventArc, EventDetails, EventEdge};
20use crate::errors::{EventError, EventTrajSnafu};
21use crate::linalg::allocator::Allocator;
22use crate::linalg::DefaultAllocator;
23use crate::md::prelude::{Interpolatable, Traj};
24use crate::md::EventEvaluator;
25use crate::time::{Duration, Epoch, TimeSeries, Unit};
26use anise::almanac::Almanac;
27use rayon::prelude::*;
28use snafu::ResultExt;
29use std::iter::Iterator;
30use std::sync::mpsc::channel;
31use std::sync::Arc;
32
33impl<S: Interpolatable> Traj<S>
34where
35 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
36{
37 #[allow(clippy::identity_op)]
39 pub fn find_bracketed<E>(
40 &self,
41 start: Epoch,
42 end: Epoch,
43 event: &E,
44 almanac: Arc<Almanac>,
45 ) -> Result<EventDetails<S>, EventError>
46 where
47 E: EventEvaluator<S>,
48 {
49 let max_iter = 50;
50
51 let has_converged =
53 |xa: f64, xb: f64| (xa - xb).abs() <= event.epoch_precision().to_seconds();
54 let arrange = |a: f64, ya: f64, b: f64, yb: f64| {
55 if ya.abs() > yb.abs() {
56 (a, ya, b, yb)
57 } else {
58 (b, yb, a, ya)
59 }
60 };
61
62 let xa_e = start;
63 let xb_e = end;
64
65 let mut xa = 0.0;
67 let mut xb = (xb_e - xa_e).to_seconds();
68 let ya_state = self.at(xa_e).context(EventTrajSnafu {})?;
70 let yb_state = self.at(xb_e).context(EventTrajSnafu {})?;
71 let mut ya = event.eval(&ya_state, almanac.clone())?;
72 let mut yb = event.eval(&yb_state, almanac.clone())?;
73
74 if ya.abs() <= event.value_precision().abs() {
76 debug!(
77 "{event} -- found with |{ya}| < {} @ {xa_e}",
78 event.value_precision().abs()
79 );
80 return EventDetails::new(ya_state, ya, event, self, almanac.clone());
81 } else if yb.abs() <= event.value_precision().abs() {
82 debug!(
83 "{event} -- found with |{yb}| < {} @ {xb_e}",
84 event.value_precision().abs()
85 );
86 return EventDetails::new(yb_state, yb, event, self, almanac.clone());
87 }
88
89 let (mut xc, mut yc, mut xd) = (xa, ya, xa);
93 let mut flag = true;
94
95 for _ in 0..max_iter {
96 if ya.abs() < event.value_precision().abs() {
97 let state = self.at(xa_e + xa * Unit::Second).unwrap();
99 debug!(
100 "{event} -- found with |{ya}| < {} @ {}",
101 event.value_precision().abs(),
102 state.epoch(),
103 );
104 return EventDetails::new(state, ya, event, self, almanac.clone());
105 }
106 if yb.abs() < event.value_precision().abs() {
107 let state = self.at(xa_e + xb * Unit::Second).unwrap();
109 debug!(
110 "{event} -- found with |{yb}| < {} @ {}",
111 event.value_precision().abs(),
112 state.epoch()
113 );
114 return EventDetails::new(state, yb, event, self, almanac.clone());
115 }
116 if has_converged(xa, xb) {
117 return Err(EventError::NotFound {
119 start,
120 end,
121 event: format!("{event}"),
122 });
123 }
124 let mut s = if (ya - yc).abs() > f64::EPSILON && (yb - yc).abs() > f64::EPSILON {
125 xa * yb * yc / ((ya - yb) * (ya - yc))
126 + xb * ya * yc / ((yb - ya) * (yb - yc))
127 + xc * ya * yb / ((yc - ya) * (yc - yb))
128 } else {
129 xb - yb * (xb - xa) / (yb - ya)
130 };
131 let cond1 = (s - xb) * (s - (3.0 * xa + xb) / 4.0) > 0.0;
132 let cond2 = flag && (s - xb).abs() >= (xb - xc).abs() / 2.0;
133 let cond3 = !flag && (s - xb).abs() >= (xc - xd).abs() / 2.0;
134 let cond4 = flag && has_converged(xb, xc);
135 let cond5 = !flag && has_converged(xc, xd);
136 if cond1 || cond2 || cond3 || cond4 || cond5 {
137 s = (xa + xb) / 2.0;
138 flag = true;
139 } else {
140 flag = false;
141 }
142 let next_try = self
143 .at(xa_e + s * Unit::Second)
144 .context(EventTrajSnafu {})?;
145 let ys = event.eval(&next_try, almanac.clone())?;
146 xd = xc;
147 xc = xb;
148 yc = yb;
149 if ya * ys < 0.0 {
150 let next_try = self
152 .at(xa_e + xa * Unit::Second)
153 .context(EventTrajSnafu {})?;
154 let ya_p = event.eval(&next_try, almanac.clone())?;
155 let (_a, _ya, _b, _yb) = arrange(xa, ya_p, s, ys);
156 {
157 xa = _a;
158 ya = _ya;
159 xb = _b;
160 yb = _yb;
161 }
162 } else {
163 let next_try = self
165 .at(xa_e + xb * Unit::Second)
166 .context(EventTrajSnafu {})?;
167 let yb_p = event.eval(&next_try, almanac.clone())?;
168 let (_a, _ya, _b, _yb) = arrange(s, ys, xb, yb_p);
169 {
170 xa = _a;
171 ya = _ya;
172 xb = _b;
173 yb = _yb;
174 }
175 }
176 }
177 error!("Brent solver failed after {max_iter} iterations");
178 Err(EventError::NotFound {
179 start,
180 end,
181 event: format!("{event}"),
182 })
183 }
184
185 #[allow(clippy::identity_op)]
199 pub fn find<E>(
200 &self,
201 event: &E,
202 almanac: Arc<Almanac>,
203 ) -> Result<Vec<EventDetails<S>>, EventError>
204 where
205 E: EventEvaluator<S>,
206 {
207 let start_epoch = self.first().epoch();
208 let end_epoch = self.last().epoch();
209 if start_epoch == end_epoch {
210 return Err(EventError::NotFound {
211 start: start_epoch,
212 end: end_epoch,
213 event: format!("{event}"),
214 });
215 }
216 let heuristic = (end_epoch - start_epoch) / 100;
217 info!("Searching for {event} with initial heuristic of {heuristic}");
218
219 let (sender, receiver) = channel();
220
221 let epochs: Vec<Epoch> = TimeSeries::inclusive(start_epoch, end_epoch, heuristic).collect();
222 epochs.into_par_iter().for_each_with(sender, |s, epoch| {
223 if let Ok(event_state) =
224 self.find_bracketed(epoch, epoch + heuristic, event, almanac.clone())
225 {
226 s.send(event_state).unwrap()
227 };
228 });
229
230 let mut states: Vec<_> = receiver.iter().collect();
231
232 if states.is_empty() {
233 warn!("Heuristic failed to find any {event} event, using slower approach");
234 match self.find_minmax(event, Unit::Second, almanac.clone()) {
237 Ok((min_event, max_event)) => {
238 let lower_min_epoch =
239 if min_event.epoch() - 1 * Unit::Millisecond < self.first().epoch() {
240 self.first().epoch()
241 } else {
242 min_event.epoch() - 1 * Unit::Millisecond
243 };
244
245 let lower_max_epoch =
246 if min_event.epoch() + 1 * Unit::Millisecond > self.last().epoch() {
247 self.last().epoch()
248 } else {
249 min_event.epoch() + 1 * Unit::Millisecond
250 };
251
252 let upper_min_epoch =
253 if max_event.epoch() - 1 * Unit::Millisecond < self.first().epoch() {
254 self.first().epoch()
255 } else {
256 max_event.epoch() - 1 * Unit::Millisecond
257 };
258
259 let upper_max_epoch =
260 if max_event.epoch() + 1 * Unit::Millisecond > self.last().epoch() {
261 self.last().epoch()
262 } else {
263 max_event.epoch() + 1 * Unit::Millisecond
264 };
265
266 if let Ok(event_state) = self.find_bracketed(
268 lower_min_epoch,
269 lower_max_epoch,
270 event,
271 almanac.clone(),
272 ) {
273 states.push(event_state);
274 };
275
276 if let Ok(event_state) = self.find_bracketed(
278 upper_min_epoch,
279 upper_max_epoch,
280 event,
281 almanac.clone(),
282 ) {
283 states.push(event_state);
284 };
285
286 if states.is_empty() {
288 return Err(EventError::NotFound {
289 start: start_epoch,
290 end: end_epoch,
291 event: format!("{event}"),
292 });
293 }
294 }
295 Err(_) => {
296 return Err(EventError::NotFound {
297 start: start_epoch,
298 end: end_epoch,
299 event: format!("{event}"),
300 });
301 }
302 };
303 }
304 states.sort_by(|s1, s2| s1.state.epoch().partial_cmp(&s2.state.epoch()).unwrap());
306 states.dedup();
307
308 match states.len() {
309 0 => info!("Event {event} not found"),
310 1 => info!("Event {event} found once on {}", states[0].state.epoch()),
311 _ => {
312 info!(
313 "Event {event} found {} times from {} until {}",
314 states.len(),
315 states.first().unwrap().state.epoch(),
316 states.last().unwrap().state.epoch()
317 )
318 }
319 };
320
321 Ok(states)
322 }
323
324 #[allow(clippy::identity_op)]
326 pub fn find_minmax<E>(
327 &self,
328 event: &E,
329 precision: Unit,
330 almanac: Arc<Almanac>,
331 ) -> Result<(S, S), EventError>
332 where
333 E: EventEvaluator<S>,
334 {
335 let step: Duration = 1 * precision;
336 let mut min_val = f64::INFINITY;
337 let mut max_val = f64::NEG_INFINITY;
338 let mut min_state = S::zeros();
339 let mut max_state = S::zeros();
340
341 let (sender, receiver) = channel();
342
343 let epochs: Vec<Epoch> =
344 TimeSeries::inclusive(self.first().epoch(), self.last().epoch(), step).collect();
345
346 epochs.into_par_iter().for_each_with(sender, |s, epoch| {
347 let state = self.at(epoch).unwrap();
349 if let Ok(this_eval) = event.eval(&state, almanac.clone()) {
350 s.send((this_eval, state)).unwrap();
351 }
352 });
353
354 let evald_states: Vec<_> = receiver.iter().collect();
355 for (this_eval, state) in evald_states {
356 if this_eval < min_val {
357 min_val = this_eval;
358 min_state = state;
359 }
360 if this_eval > max_val {
361 max_val = this_eval;
362 max_state = state;
363 }
364 }
365
366 Ok((min_state, max_state))
367 }
368
369 pub fn find_arcs<E>(
394 &self,
395 event: &E,
396 almanac: Arc<Almanac>,
397 ) -> Result<Vec<EventArc<S>>, EventError>
398 where
399 E: EventEvaluator<S>,
400 {
401 let mut events = match self.find(event, almanac.clone()) {
402 Ok(events) => events,
403 Err(_) => {
404 let first_eval = event.eval(self.first(), almanac.clone())?;
407 let last_eval = event.eval(self.last(), almanac.clone())?;
408 if first_eval > 0.0 && last_eval > 0.0 {
409 vec![
413 EventDetails::new(*self.first(), first_eval, event, self, almanac.clone())?,
414 EventDetails::new(*self.last(), last_eval, event, self, almanac.clone())?,
415 ]
416 } else {
417 return Err(EventError::NotFound {
418 start: self.first().epoch(),
419 end: self.last().epoch(),
420 event: format!("{event}"),
421 });
422 }
423 }
424 };
425 events.sort_by_key(|event| event.state.epoch());
426
427 let mut arcs = Vec::new();
429
430 if events.is_empty() {
431 return Ok(arcs);
432 }
433
434 let mut prev_rise = if events[0].edge != EventEdge::Rising {
436 let value = event.eval(self.first(), almanac.clone())?;
437 Some(EventDetails::new(
438 *self.first(),
439 value,
440 event,
441 self,
442 almanac.clone(),
443 )?)
444 } else {
445 Some(events[0].clone())
446 };
447
448 let mut prev_fall = if events[0].edge == EventEdge::Falling {
449 Some(events[0].clone())
450 } else {
451 None
452 };
453
454 for event in events {
455 if event.edge == EventEdge::Rising {
456 if prev_rise.is_none() && prev_fall.is_none() {
457 prev_rise = Some(event.clone());
459 } else if prev_fall.is_some() {
460 if prev_rise.is_some() {
462 let arc = EventArc {
463 rise: prev_rise.clone().unwrap(),
464 fall: prev_fall.clone().unwrap(),
465 };
466 arcs.push(arc);
467 } else {
468 let arc = EventArc {
469 rise: event.clone(),
470 fall: prev_fall.clone().unwrap(),
471 };
472 arcs.push(arc);
473 }
474 prev_fall = None;
475 prev_rise = Some(event.clone());
477 }
478 } else if event.edge == EventEdge::Falling {
479 prev_fall = Some(event.clone());
480 }
481 }
482
483 if prev_rise.is_some() {
485 if prev_fall.is_some() {
486 let arc = EventArc {
487 rise: prev_rise.clone().unwrap(),
488 fall: prev_fall.clone().unwrap(),
489 };
490 arcs.push(arc);
491 } else {
492 let value = event.eval(self.last(), almanac.clone())?;
494 let fall = EventDetails::new(*self.last(), value, event, self, almanac.clone())?;
495 let arc = EventArc {
496 rise: prev_rise.clone().unwrap(),
497 fall,
498 };
499 arcs.push(arc);
500 }
501 }
502
503 Ok(arcs)
504 }
505}