1use super::traj_it::TrajIterator;
20use super::{ExportCfg, InterpolationSnafu, INTERPOLATION_SAMPLES};
21use super::{Interpolatable, TrajError};
22use crate::errors::NyxError;
23use crate::io::watermark::pq_writer;
24use crate::io::InputOutputError;
25use crate::linalg::allocator::Allocator;
26use crate::linalg::DefaultAllocator;
27use crate::md::prelude::{GuidanceMode, StateParameter};
28use crate::md::trajectory::smooth_state_diff_in_place;
29use crate::md::EventEvaluator;
30use crate::time::{Duration, Epoch, TimeSeries, TimeUnits};
31use anise::almanac::Almanac;
32use arrow::array::{Array, Float64Builder, StringBuilder};
33use arrow::datatypes::{DataType, Field, Schema};
34use arrow::record_batch::RecordBatch;
35use hifitime::TimeScale;
36use parquet::arrow::ArrowWriter;
37use snafu::ResultExt;
38use std::collections::HashMap;
39use std::error::Error;
40use std::fmt;
41use std::fs::File;
42use std::iter::Iterator;
43use std::ops;
44use std::ops::Bound::{Excluded, Included, Unbounded};
45use std::path::{Path, PathBuf};
46use std::sync::Arc;
47
48#[derive(Clone, PartialEq)]
50pub struct Traj<S: Interpolatable>
51where
52 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
53{
54 pub name: Option<String>,
56 pub states: Vec<S>,
58}
59
60impl<S: Interpolatable> Traj<S>
61where
62 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
63{
64 pub fn new() -> Self {
65 Self {
66 name: None,
67 states: Vec::new(),
68 }
69 }
70 pub fn finalize(&mut self) {
72 self.states.dedup_by(|a, b| a.epoch().eq(&b.epoch()));
74 self.states.sort_by_key(|a| a.epoch());
76 }
77
78 pub fn at(&self, epoch: Epoch) -> Result<S, TrajError> {
80 if self.states.is_empty() || self.first().epoch() > epoch || self.last().epoch() < epoch {
81 return Err(TrajError::NoInterpolationData { epoch });
82 }
83 match self
84 .states
85 .binary_search_by(|state| state.epoch().cmp(&epoch))
86 {
87 Ok(idx) => {
88 Ok(self.states[idx])
90 }
91 Err(idx) => {
92 if idx == 0 || idx >= self.states.len() {
93 return Err(TrajError::NoInterpolationData { epoch });
96 }
97 let num_left = INTERPOLATION_SAMPLES / 2;
102
103 let mut first_idx = idx.saturating_sub(num_left);
105 let last_idx = self.states.len().min(first_idx + INTERPOLATION_SAMPLES);
106
107 if last_idx == self.states.len() {
109 first_idx = last_idx.saturating_sub(2 * num_left);
110 }
111
112 let mut states = Vec::with_capacity(last_idx - first_idx);
113 for idx in first_idx..last_idx {
114 states.push(self.states[idx]);
115 }
116
117 self.states[idx]
118 .interpolate(epoch, &states)
119 .context(InterpolationSnafu)
120 }
121 }
122 }
123
124 pub fn first(&self) -> &S {
126 self.states.first().unwrap()
128 }
129
130 pub fn last(&self) -> &S {
132 self.states.last().unwrap()
133 }
134
135 pub fn every(&self, step: Duration) -> TrajIterator<'_, S> {
137 self.every_between(step, self.first().epoch(), self.last().epoch())
138 }
139
140 pub fn every_between(&self, step: Duration, start: Epoch, end: Epoch) -> TrajIterator<'_, S> {
142 TrajIterator {
143 time_series: TimeSeries::inclusive(
144 start.max(self.first().epoch()),
145 end.min(self.last().epoch()),
146 step,
147 ),
148 traj: self,
149 }
150 }
151
152 pub fn filter_by_epoch<R: ops::RangeBounds<Epoch>>(mut self, bound: R) -> Self {
154 self.states = self
155 .states
156 .iter()
157 .copied()
158 .filter(|s| bound.contains(&s.epoch()))
159 .collect::<Vec<_>>();
160 self
161 }
162
163 pub fn filter_by_offset<R: ops::RangeBounds<Duration>>(self, bound: R) -> Self {
166 if self.states.is_empty() {
167 return self;
168 }
169 let start = match bound.start_bound() {
171 Unbounded => self.states.first().unwrap().epoch(),
172 Included(offset) | Excluded(offset) => self.states.first().unwrap().epoch() + *offset,
173 };
174
175 let end = match bound.end_bound() {
176 Unbounded => self.states.last().unwrap().epoch(),
177 Included(offset) | Excluded(offset) => self.states.first().unwrap().epoch() + *offset,
178 };
179
180 self.filter_by_epoch(start..=end)
181 }
182 pub fn to_parquet_simple<P: AsRef<Path>>(
184 &self,
185 path: P,
186 almanac: Arc<Almanac>,
187 ) -> Result<PathBuf, Box<dyn Error>> {
188 self.to_parquet(path, None, ExportCfg::default(), almanac)
189 }
190
191 pub fn to_parquet_with_cfg<P: AsRef<Path>>(
193 &self,
194 path: P,
195 cfg: ExportCfg,
196 almanac: Arc<Almanac>,
197 ) -> Result<PathBuf, Box<dyn Error>> {
198 self.to_parquet(path, None, cfg, almanac)
199 }
200
201 pub fn to_parquet_with_step<P: AsRef<Path>>(
203 &self,
204 path: P,
205 step: Duration,
206 almanac: Arc<Almanac>,
207 ) -> Result<(), Box<dyn Error>> {
208 self.to_parquet_with_cfg(
209 path,
210 ExportCfg {
211 step: Some(step),
212 ..Default::default()
213 },
214 almanac,
215 )?;
216
217 Ok(())
218 }
219
220 pub fn to_parquet<P: AsRef<Path>>(
222 &self,
223 path: P,
224 events: Option<Vec<&dyn EventEvaluator<S>>>,
225 cfg: ExportCfg,
226 almanac: Arc<Almanac>,
227 ) -> Result<PathBuf, Box<dyn Error>> {
228 let tick = Epoch::now().unwrap();
229 info!("Exporting trajectory to parquet file...");
230
231 let path_buf = cfg.actual_path(path);
233
234 let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
236
237 let frame = self.states[0].frame();
238 let more_meta = Some(vec![(
239 "Frame".to_string(),
240 serde_dhall::serialize(&frame)
241 .static_type_annotation()
242 .to_string()
243 .map_err(|e| {
244 Box::new(InputOutputError::SerializeDhall {
245 what: format!("frame `{frame}`"),
246 err: e.to_string(),
247 })
248 })?,
249 )]);
250
251 let mut fields = match cfg.fields {
252 Some(fields) => fields,
253 None => S::export_params(),
254 };
255
256 fields.retain(|param| self.first().value(*param).is_ok());
258
259 for field in &fields {
260 hdrs.push(field.to_field(more_meta.clone()));
261 }
262
263 if let Some(events) = events.as_ref() {
264 for event in events {
265 let field = Field::new(format!("{event}"), DataType::Float64, false);
266 hdrs.push(field);
267 }
268 }
269
270 let schema = Arc::new(Schema::new(hdrs));
272 let mut record: Vec<Arc<dyn Array>> = Vec::new();
273
274 let states = if cfg.start_epoch.is_some() || cfg.end_epoch.is_some() || cfg.step.is_some() {
276 let start = cfg.start_epoch.unwrap_or_else(|| self.first().epoch());
278 let end = cfg.end_epoch.unwrap_or_else(|| self.last().epoch());
279 let step = cfg.step.unwrap_or_else(|| 1.minutes());
280 self.every_between(step, start, end).collect::<Vec<S>>()
281 } else {
282 self.states.to_vec()
283 };
284
285 let mut utc_epoch = StringBuilder::new();
289 for s in &states {
290 utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
291 }
292 record.push(Arc::new(utc_epoch.finish()));
293
294 for field in fields {
296 if field == StateParameter::GuidanceMode {
297 let mut guid_mode = StringBuilder::new();
298 for s in &states {
299 guid_mode
300 .append_value(format!("{:?}", GuidanceMode::from(s.value(field).unwrap())));
301 }
302 record.push(Arc::new(guid_mode.finish()));
303 } else {
304 let mut data = Float64Builder::new();
305 for s in &states {
306 data.append_value(s.value(field).unwrap());
307 }
308 record.push(Arc::new(data.finish()));
309 }
310 }
311
312 info!(
313 "Serialized {} states from {} to {}",
314 states.len(),
315 states.first().unwrap().epoch(),
316 states.last().unwrap().epoch()
317 );
318
319 if let Some(events) = events {
321 info!("Evaluating {} event(s)", events.len());
322 for event in events {
323 let mut data = Float64Builder::new();
324 for s in &states {
325 data.append_value(event.eval(s, almanac.clone()).map_err(Box::new)?);
326 }
327 record.push(Arc::new(data.finish()));
328 }
329 }
330
331 let mut metadata = HashMap::new();
333 metadata.insert("Purpose".to_string(), "Trajectory data".to_string());
334 if let Some(add_meta) = cfg.metadata {
335 for (k, v) in add_meta {
336 metadata.insert(k, v);
337 }
338 }
339
340 let props = pq_writer(Some(metadata));
341
342 let file = File::create(&path_buf)?;
343 let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
344
345 let batch = RecordBatch::try_new(schema, record)?;
346 writer.write(&batch)?;
347 writer.close()?;
348
349 let tock_time = Epoch::now().unwrap() - tick;
351 info!(
352 "Trajectory written to {} in {tock_time}",
353 path_buf.display()
354 );
355 Ok(path_buf)
356 }
357
358 pub fn resample(&self, step: Duration) -> Result<Self, NyxError> {
361 if self.states.is_empty() {
362 return Err(NyxError::Trajectory {
363 source: TrajError::CreationError {
364 msg: "No trajectory to convert".to_string(),
365 },
366 });
367 }
368
369 let mut traj = Self::new();
370 for state in self.every(step) {
371 traj.states.push(state);
372 }
373
374 traj.finalize();
375
376 Ok(traj)
377 }
378
379 pub fn rebuild(&self, epochs: &[Epoch]) -> Result<Self, NyxError> {
382 if self.states.is_empty() {
383 return Err(NyxError::Trajectory {
384 source: TrajError::CreationError {
385 msg: "No trajectory to convert".to_string(),
386 },
387 });
388 }
389
390 let mut traj = Self::new();
391 for epoch in epochs {
392 traj.states.push(self.at(*epoch)?);
393 }
394
395 traj.finalize();
396
397 Ok(traj)
398 }
399
400 pub fn ric_diff_to_parquet<P: AsRef<Path>>(
405 &self,
406 other: &Self,
407 path: P,
408 cfg: ExportCfg,
409 ) -> Result<PathBuf, Box<dyn Error>> {
410 let tick = Epoch::now().unwrap();
411 info!("Exporting trajectory to parquet file...");
412
413 let path_buf = cfg.actual_path(path);
415
416 let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
418
419 for coord in ["X", "Y", "Z"] {
421 let mut meta = HashMap::new();
422 meta.insert("unit".to_string(), "km".to_string());
423
424 let field = Field::new(
425 format!("Delta {coord} (RIC) (km)"),
426 DataType::Float64,
427 false,
428 )
429 .with_metadata(meta);
430
431 hdrs.push(field);
432 }
433
434 for coord in ["x", "y", "z"] {
435 let mut meta = HashMap::new();
436 meta.insert("unit".to_string(), "km/s".to_string());
437
438 let field = Field::new(
439 format!("Delta V{coord} (RIC) (km/s)"),
440 DataType::Float64,
441 false,
442 )
443 .with_metadata(meta);
444
445 hdrs.push(field);
446 }
447
448 let frame = self.states[0].frame();
449 let more_meta = Some(vec![(
450 "Frame".to_string(),
451 serde_dhall::serialize(&frame)
452 .static_type_annotation()
453 .to_string()
454 .unwrap_or(frame.to_string()),
455 )]);
456
457 let mut cfg = cfg;
458
459 let mut fields = match cfg.fields {
460 Some(fields) => fields,
461 None => S::export_params(),
462 };
463
464 fields.retain(|param| {
466 param != &StateParameter::GuidanceMode && self.first().value(*param).is_ok()
467 });
468
469 for field in &fields {
470 hdrs.push(field.to_field(more_meta.clone()));
471 }
472
473 let schema = Arc::new(Schema::new(hdrs));
475 let mut record: Vec<Arc<dyn Array>> = Vec::new();
476
477 cfg.start_epoch = if self.first().epoch() > other.first().epoch() {
479 Some(self.first().epoch())
480 } else {
481 Some(other.first().epoch())
482 };
483
484 cfg.end_epoch = if self.last().epoch() > other.last().epoch() {
485 Some(other.last().epoch())
486 } else {
487 Some(self.last().epoch())
488 };
489
490 let step = cfg.step.unwrap_or_else(|| 1.minutes());
492 let self_states = self
493 .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
494 .collect::<Vec<S>>();
495
496 let other_states = other
497 .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
498 .collect::<Vec<S>>();
499
500 let mut ric_diff = Vec::with_capacity(other_states.len());
502 for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
503 let self_orbit = self_state.orbit();
504 let other_orbit = other_state.orbit();
505
506 let this_ric_diff = self_orbit.ric_difference(&other_orbit).map_err(Box::new)?;
507
508 ric_diff.push(this_ric_diff);
509 }
510
511 smooth_state_diff_in_place(&mut ric_diff, if other_states.len() > 5 { 5 } else { 1 });
512
513 let mut utc_epoch = StringBuilder::new();
517 for s in &self_states {
518 utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
519 }
520 record.push(Arc::new(utc_epoch.finish()));
521
522 for coord_no in 0..6 {
524 let mut data = Float64Builder::new();
525 for this_ric_dff in &ric_diff {
526 data.append_value(this_ric_dff.to_cartesian_pos_vel()[coord_no]);
527 }
528 record.push(Arc::new(data.finish()));
529 }
530
531 for field in fields {
533 let mut data = Float64Builder::new();
534 for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
535 let self_val = self_state.value(field).map_err(Box::new)?;
536 let other_val = other_state.value(field).map_err(Box::new)?;
537
538 data.append_value(self_val - other_val);
539 }
540
541 record.push(Arc::new(data.finish()));
542 }
543
544 info!("Serialized {} states differences", self_states.len());
545
546 let mut metadata = HashMap::new();
548 metadata.insert(
549 "Purpose".to_string(),
550 "Trajectory difference data".to_string(),
551 );
552 if let Some(add_meta) = cfg.metadata {
553 for (k, v) in add_meta {
554 metadata.insert(k, v);
555 }
556 }
557
558 let props = pq_writer(Some(metadata));
559
560 let file = File::create(&path_buf)?;
561 let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
562
563 let batch = RecordBatch::try_new(schema, record)?;
564 writer.write(&batch)?;
565 writer.close()?;
566
567 let tock_time = Epoch::now().unwrap() - tick;
569 info!(
570 "Trajectory written to {} in {tock_time}",
571 path_buf.display()
572 );
573 Ok(path_buf)
574 }
575}
576
577impl<S: Interpolatable> ops::Add for Traj<S>
578where
579 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
580{
581 type Output = Result<Traj<S>, NyxError>;
582
583 fn add(self, other: Traj<S>) -> Self::Output {
585 &self + &other
586 }
587}
588
589impl<S: Interpolatable> ops::Add<&Traj<S>> for &Traj<S>
590where
591 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
592{
593 type Output = Result<Traj<S>, NyxError>;
594
595 fn add(self, other: &Traj<S>) -> Self::Output {
597 if self.first().frame() != other.first().frame() {
598 Err(NyxError::Trajectory {
599 source: TrajError::CreationError {
600 msg: format!(
601 "Frame mismatch in add operation: {} != {}",
602 self.first().frame(),
603 other.first().frame()
604 ),
605 },
606 })
607 } else {
608 if self.last().epoch() < other.first().epoch() {
609 let gap = other.first().epoch() - self.last().epoch();
610 warn!(
611 "Resulting merged trajectory will have a time-gap of {} starting at {}",
612 gap,
613 self.last().epoch()
614 );
615 }
616
617 let mut me = self.clone();
618 for state in &other
620 .states
621 .iter()
622 .copied()
623 .filter(|s| s.epoch() > self.last().epoch())
624 .collect::<Vec<S>>()
625 {
626 me.states.push(*state);
627 }
628 me.finalize();
629
630 Ok(me)
631 }
632 }
633}
634
635impl<S: Interpolatable> ops::AddAssign<&Traj<S>> for Traj<S>
636where
637 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
638{
639 fn add_assign(&mut self, rhs: &Self) {
645 *self = (self.clone() + rhs.clone()).unwrap();
646 }
647}
648
649impl<S: Interpolatable> fmt::Display for Traj<S>
650where
651 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
652{
653 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
654 if self.states.is_empty() {
655 write!(f, "Empty Trajectory!")
656 } else {
657 let dur = self.last().epoch() - self.first().epoch();
658 write!(
659 f,
660 "Trajectory {}in {} from {} to {} ({}, or {:.3} s) [{} states]",
661 match &self.name {
662 Some(name) => format!("of {name} "),
663 None => String::new(),
664 },
665 self.first().frame(),
666 self.first().epoch(),
667 self.last().epoch(),
668 dur,
669 dur.to_seconds(),
670 self.states.len()
671 )
672 }
673 }
674}
675
676impl<S: Interpolatable> fmt::Debug for Traj<S>
677where
678 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
679{
680 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
681 write!(f, "{self}",)
682 }
683}
684
685impl<S: Interpolatable> Default for Traj<S>
686where
687 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
688{
689 fn default() -> Self {
690 Self::new()
691 }
692}