1use super::traj_it::TrajIterator;
20use super::{ExportCfg, InterpolationSnafu, INTERPOLATION_SAMPLES};
21use super::{Interpolatable, TrajError};
22use crate::errors::NyxError;
23use crate::io::watermark::pq_writer;
24use crate::io::InputOutputError;
25use crate::linalg::allocator::Allocator;
26use crate::linalg::DefaultAllocator;
27use crate::md::prelude::{GuidanceMode, StateParameter};
28use crate::md::trajectory::smooth_state_diff_in_place;
29use crate::md::EventEvaluator;
30use crate::time::{Duration, Epoch, TimeSeries, TimeUnits};
31use anise::almanac::Almanac;
32use arrow::array::{Array, Float64Builder, StringBuilder};
33use arrow::datatypes::{DataType, Field, Schema};
34use arrow::record_batch::RecordBatch;
35use hifitime::TimeScale;
36use log::{info, warn};
37use parquet::arrow::ArrowWriter;
38use snafu::ResultExt;
39use std::collections::HashMap;
40use std::error::Error;
41use std::fmt;
42use std::fs::File;
43use std::iter::Iterator;
44use std::ops;
45use std::ops::Bound::{Excluded, Included, Unbounded};
46use std::path::{Path, PathBuf};
47use std::sync::Arc;
48
/// A trajectory: an optionally named sequence of states of type `S`.
///
/// `at` locates states with a binary search and `first`/`last` read the ends of
/// `states`, so `states` is expected to be sorted by epoch without duplicates.
/// Call `finalize` after pushing states directly into `states`.
#[derive(Clone, PartialEq)]
pub struct Traj<S: Interpolatable>
where
    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
    /// Optional name of this trajectory, shown by the `Display` implementation.
    pub name: Option<String>,
    /// The states of this trajectory (expected sorted by epoch, see `finalize`).
    pub states: Vec<S>,
}
60
61impl<S: Interpolatable> Traj<S>
62where
63 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
64{
65 pub fn new() -> Self {
66 Self {
67 name: None,
68 states: Vec::new(),
69 }
70 }
71 pub fn finalize(&mut self) {
73 self.states.dedup_by(|a, b| a.epoch().eq(&b.epoch()));
75 self.states.sort_by_key(|a| a.epoch());
77 }
78
79 pub fn at(&self, epoch: Epoch) -> Result<S, TrajError> {
81 if self.states.is_empty() || self.first().epoch() > epoch || self.last().epoch() < epoch {
82 return Err(TrajError::NoInterpolationData { epoch });
83 }
84 match self
85 .states
86 .binary_search_by(|state| state.epoch().cmp(&epoch))
87 {
88 Ok(idx) => {
89 Ok(self.states[idx])
91 }
92 Err(idx) => {
93 if idx == 0 || idx >= self.states.len() {
94 return Err(TrajError::NoInterpolationData { epoch });
97 }
98 let num_left = INTERPOLATION_SAMPLES / 2;
103
104 let mut first_idx = idx.saturating_sub(num_left);
106 let last_idx = self.states.len().min(first_idx + INTERPOLATION_SAMPLES);
107
108 if last_idx == self.states.len() {
110 first_idx = last_idx.saturating_sub(2 * num_left);
111 }
112
113 let mut states = Vec::with_capacity(last_idx - first_idx);
114 for idx in first_idx..last_idx {
115 states.push(self.states[idx]);
116 }
117
118 self.states[idx]
119 .interpolate(epoch, &states)
120 .context(InterpolationSnafu)
121 }
122 }
123 }
124
125 pub fn first(&self) -> &S {
127 self.states.first().unwrap()
129 }
130
131 pub fn last(&self) -> &S {
133 self.states.last().unwrap()
134 }
135
136 pub fn every(&self, step: Duration) -> TrajIterator<'_, S> {
138 self.every_between(step, self.first().epoch(), self.last().epoch())
139 }
140
141 pub fn every_between(&self, step: Duration, start: Epoch, end: Epoch) -> TrajIterator<'_, S> {
143 TrajIterator {
144 time_series: TimeSeries::inclusive(
145 start.max(self.first().epoch()),
146 end.min(self.last().epoch()),
147 step,
148 ),
149 traj: self,
150 }
151 }
152
153 pub fn filter_by_epoch<R: ops::RangeBounds<Epoch>>(mut self, bound: R) -> Self {
155 self.states = self
156 .states
157 .iter()
158 .copied()
159 .filter(|s| bound.contains(&s.epoch()))
160 .collect::<Vec<_>>();
161 self
162 }
163
164 pub fn filter_by_offset<R: ops::RangeBounds<Duration>>(self, bound: R) -> Self {
167 if self.states.is_empty() {
168 return self;
169 }
170 let start = match bound.start_bound() {
172 Unbounded => self.states.first().unwrap().epoch(),
173 Included(offset) | Excluded(offset) => self.states.first().unwrap().epoch() + *offset,
174 };
175
176 let end = match bound.end_bound() {
177 Unbounded => self.states.last().unwrap().epoch(),
178 Included(offset) | Excluded(offset) => self.states.first().unwrap().epoch() + *offset,
179 };
180
181 self.filter_by_epoch(start..=end)
182 }
183 pub fn to_parquet_simple<P: AsRef<Path>>(
185 &self,
186 path: P,
187 almanac: Arc<Almanac>,
188 ) -> Result<PathBuf, Box<dyn Error>> {
189 self.to_parquet(path, None, ExportCfg::default(), almanac)
190 }
191
192 pub fn to_parquet_with_cfg<P: AsRef<Path>>(
194 &self,
195 path: P,
196 cfg: ExportCfg,
197 almanac: Arc<Almanac>,
198 ) -> Result<PathBuf, Box<dyn Error>> {
199 self.to_parquet(path, None, cfg, almanac)
200 }
201
202 pub fn to_parquet_with_step<P: AsRef<Path>>(
204 &self,
205 path: P,
206 step: Duration,
207 almanac: Arc<Almanac>,
208 ) -> Result<(), Box<dyn Error>> {
209 self.to_parquet_with_cfg(
210 path,
211 ExportCfg {
212 step: Some(step),
213 ..Default::default()
214 },
215 almanac,
216 )?;
217
218 Ok(())
219 }
220
221 pub fn to_parquet<P: AsRef<Path>>(
223 &self,
224 path: P,
225 events: Option<Vec<&dyn EventEvaluator<S>>>,
226 cfg: ExportCfg,
227 almanac: Arc<Almanac>,
228 ) -> Result<PathBuf, Box<dyn Error>> {
229 let tick = Epoch::now().unwrap();
230 info!("Exporting trajectory to parquet file...");
231
232 let path_buf = cfg.actual_path(path);
234
235 let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
237
238 let frame = self.states[0].frame();
239 let more_meta = Some(vec![(
240 "Frame".to_string(),
241 serde_dhall::serialize(&frame)
242 .static_type_annotation()
243 .to_string()
244 .map_err(|e| {
245 Box::new(InputOutputError::SerializeDhall {
246 what: format!("frame `{frame}`"),
247 err: e.to_string(),
248 })
249 })?,
250 )]);
251
252 let mut fields = match cfg.fields {
253 Some(fields) => fields,
254 None => S::export_params(),
255 };
256
257 fields.retain(|param| self.first().value(*param).is_ok());
259
260 for field in &fields {
261 hdrs.push(field.to_field(more_meta.clone()));
262 }
263
264 if let Some(events) = events.as_ref() {
265 for event in events {
266 let field = Field::new(format!("{event}"), DataType::Float64, false);
267 hdrs.push(field);
268 }
269 }
270
271 let schema = Arc::new(Schema::new(hdrs));
273 let mut record: Vec<Arc<dyn Array>> = Vec::new();
274
275 let states = if cfg.start_epoch.is_some() || cfg.end_epoch.is_some() || cfg.step.is_some() {
277 let start = cfg.start_epoch.unwrap_or_else(|| self.first().epoch());
279 let end = cfg.end_epoch.unwrap_or_else(|| self.last().epoch());
280 let step = cfg.step.unwrap_or_else(|| 1.minutes());
281 self.every_between(step, start, end).collect::<Vec<S>>()
282 } else {
283 self.states.to_vec()
284 };
285
286 let mut utc_epoch = StringBuilder::new();
290 for s in &states {
291 utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
292 }
293 record.push(Arc::new(utc_epoch.finish()));
294
295 for field in fields {
297 if field == StateParameter::GuidanceMode {
298 let mut guid_mode = StringBuilder::new();
299 for s in &states {
300 guid_mode
301 .append_value(format!("{:?}", GuidanceMode::from(s.value(field).unwrap())));
302 }
303 record.push(Arc::new(guid_mode.finish()));
304 } else {
305 let mut data = Float64Builder::new();
306 for s in &states {
307 data.append_value(s.value(field).unwrap());
308 }
309 record.push(Arc::new(data.finish()));
310 }
311 }
312
313 info!(
314 "Serialized {} states from {} to {}",
315 states.len(),
316 states.first().unwrap().epoch(),
317 states.last().unwrap().epoch()
318 );
319
320 if let Some(events) = events {
322 info!("Evaluating {} event(s)", events.len());
323 for event in events {
324 let mut data = Float64Builder::new();
325 for s in &states {
326 data.append_value(event.eval(s, almanac.clone()).map_err(Box::new)?);
327 }
328 record.push(Arc::new(data.finish()));
329 }
330 }
331
332 let mut metadata = HashMap::new();
334 metadata.insert("Purpose".to_string(), "Trajectory data".to_string());
335 if let Some(add_meta) = cfg.metadata {
336 for (k, v) in add_meta {
337 metadata.insert(k, v);
338 }
339 }
340
341 let props = pq_writer(Some(metadata));
342
343 let file = File::create(&path_buf)?;
344 let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
345
346 let batch = RecordBatch::try_new(schema, record)?;
347 writer.write(&batch)?;
348 writer.close()?;
349
350 let tock_time = Epoch::now().unwrap() - tick;
352 info!(
353 "Trajectory written to {} in {tock_time}",
354 path_buf.display()
355 );
356 Ok(path_buf)
357 }
358
359 pub fn resample(&self, step: Duration) -> Result<Self, NyxError> {
362 if self.states.is_empty() {
363 return Err(NyxError::Trajectory {
364 source: TrajError::CreationError {
365 msg: "No trajectory to convert".to_string(),
366 },
367 });
368 }
369
370 let mut traj = Self::new();
371 for state in self.every(step) {
372 traj.states.push(state);
373 }
374
375 traj.finalize();
376
377 Ok(traj)
378 }
379
380 pub fn rebuild(&self, epochs: &[Epoch]) -> Result<Self, NyxError> {
383 if self.states.is_empty() {
384 return Err(NyxError::Trajectory {
385 source: TrajError::CreationError {
386 msg: "No trajectory to convert".to_string(),
387 },
388 });
389 }
390
391 let mut traj = Self::new();
392 for epoch in epochs {
393 traj.states.push(self.at(*epoch)?);
394 }
395
396 traj.finalize();
397
398 Ok(traj)
399 }
400
401 pub fn ric_diff_to_parquet<P: AsRef<Path>>(
406 &self,
407 other: &Self,
408 path: P,
409 cfg: ExportCfg,
410 ) -> Result<PathBuf, Box<dyn Error>> {
411 let tick = Epoch::now().unwrap();
412 info!("Exporting trajectory to parquet file...");
413
414 let path_buf = cfg.actual_path(path);
416
417 let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
419
420 for coord in ["X", "Y", "Z"] {
422 let mut meta = HashMap::new();
423 meta.insert("unit".to_string(), "km".to_string());
424
425 let field = Field::new(
426 format!("Delta {coord} (RIC) (km)"),
427 DataType::Float64,
428 false,
429 )
430 .with_metadata(meta);
431
432 hdrs.push(field);
433 }
434
435 for coord in ["x", "y", "z"] {
436 let mut meta = HashMap::new();
437 meta.insert("unit".to_string(), "km/s".to_string());
438
439 let field = Field::new(
440 format!("Delta V{coord} (RIC) (km/s)"),
441 DataType::Float64,
442 false,
443 )
444 .with_metadata(meta);
445
446 hdrs.push(field);
447 }
448
449 let frame = self.states[0].frame();
450 let more_meta = Some(vec![(
451 "Frame".to_string(),
452 serde_dhall::serialize(&frame)
453 .static_type_annotation()
454 .to_string()
455 .unwrap_or(frame.to_string()),
456 )]);
457
458 let mut cfg = cfg;
459
460 let mut fields = match cfg.fields {
461 Some(fields) => fields,
462 None => S::export_params(),
463 };
464
465 fields.retain(|param| {
467 param != &StateParameter::GuidanceMode && self.first().value(*param).is_ok()
468 });
469
470 for field in &fields {
471 hdrs.push(field.to_field(more_meta.clone()));
472 }
473
474 let schema = Arc::new(Schema::new(hdrs));
476 let mut record: Vec<Arc<dyn Array>> = Vec::new();
477
478 cfg.start_epoch = if self.first().epoch() > other.first().epoch() {
480 Some(self.first().epoch())
481 } else {
482 Some(other.first().epoch())
483 };
484
485 cfg.end_epoch = if self.last().epoch() > other.last().epoch() {
486 Some(other.last().epoch())
487 } else {
488 Some(self.last().epoch())
489 };
490
491 let step = cfg.step.unwrap_or_else(|| 1.minutes());
493 let self_states = self
494 .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
495 .collect::<Vec<S>>();
496
497 let other_states = other
498 .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
499 .collect::<Vec<S>>();
500
501 let mut ric_diff = Vec::with_capacity(other_states.len());
503 for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
504 let self_orbit = self_state.orbit();
505 let other_orbit = other_state.orbit();
506
507 let this_ric_diff = self_orbit.ric_difference(&other_orbit).map_err(Box::new)?;
508
509 ric_diff.push(this_ric_diff);
510 }
511
512 smooth_state_diff_in_place(&mut ric_diff, if other_states.len() > 5 { 5 } else { 1 });
513
514 let mut utc_epoch = StringBuilder::new();
518 for s in &self_states {
519 utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
520 }
521 record.push(Arc::new(utc_epoch.finish()));
522
523 for coord_no in 0..6 {
525 let mut data = Float64Builder::new();
526 for this_ric_dff in &ric_diff {
527 data.append_value(this_ric_dff.to_cartesian_pos_vel()[coord_no]);
528 }
529 record.push(Arc::new(data.finish()));
530 }
531
532 for field in fields {
534 let mut data = Float64Builder::new();
535 for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
536 let self_val = self_state.value(field).map_err(Box::new)?;
537 let other_val = other_state.value(field).map_err(Box::new)?;
538
539 data.append_value(self_val - other_val);
540 }
541
542 record.push(Arc::new(data.finish()));
543 }
544
545 info!("Serialized {} states differences", self_states.len());
546
547 let mut metadata = HashMap::new();
549 metadata.insert(
550 "Purpose".to_string(),
551 "Trajectory difference data".to_string(),
552 );
553 if let Some(add_meta) = cfg.metadata {
554 for (k, v) in add_meta {
555 metadata.insert(k, v);
556 }
557 }
558
559 let props = pq_writer(Some(metadata));
560
561 let file = File::create(&path_buf)?;
562 let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
563
564 let batch = RecordBatch::try_new(schema, record)?;
565 writer.write(&batch)?;
566 writer.close()?;
567
568 let tock_time = Epoch::now().unwrap() - tick;
570 info!(
571 "Trajectory written to {} in {tock_time}",
572 path_buf.display()
573 );
574 Ok(path_buf)
575 }
576}
577
578impl<S: Interpolatable> ops::Add for Traj<S>
579where
580 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
581{
582 type Output = Result<Traj<S>, NyxError>;
583
584 fn add(self, other: Traj<S>) -> Self::Output {
586 &self + &other
587 }
588}
589
590impl<S: Interpolatable> ops::Add<&Traj<S>> for &Traj<S>
591where
592 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
593{
594 type Output = Result<Traj<S>, NyxError>;
595
596 fn add(self, other: &Traj<S>) -> Self::Output {
598 if self.first().frame() != other.first().frame() {
599 Err(NyxError::Trajectory {
600 source: TrajError::CreationError {
601 msg: format!(
602 "Frame mismatch in add operation: {} != {}",
603 self.first().frame(),
604 other.first().frame()
605 ),
606 },
607 })
608 } else {
609 if self.last().epoch() < other.first().epoch() {
610 let gap = other.first().epoch() - self.last().epoch();
611 warn!(
612 "Resulting merged trajectory will have a time-gap of {} starting at {}",
613 gap,
614 self.last().epoch()
615 );
616 }
617
618 let mut me = self.clone();
619 for state in &other
621 .states
622 .iter()
623 .copied()
624 .filter(|s| s.epoch() > self.last().epoch())
625 .collect::<Vec<S>>()
626 {
627 me.states.push(*state);
628 }
629 me.finalize();
630
631 Ok(me)
632 }
633 }
634}
635
636impl<S: Interpolatable> ops::AddAssign<&Traj<S>> for Traj<S>
637where
638 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
639{
640 fn add_assign(&mut self, rhs: &Self) {
646 *self = (self.clone() + rhs.clone()).unwrap();
647 }
648}
649
650impl<S: Interpolatable> fmt::Display for Traj<S>
651where
652 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
653{
654 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
655 if self.states.is_empty() {
656 write!(f, "Empty Trajectory!")
657 } else {
658 let dur = self.last().epoch() - self.first().epoch();
659 write!(
660 f,
661 "Trajectory {}in {} from {} to {} ({}, or {:.3} s) [{} states]",
662 match &self.name {
663 Some(name) => format!("of {name} "),
664 None => String::new(),
665 },
666 self.first().frame(),
667 self.first().epoch(),
668 self.last().epoch(),
669 dur,
670 dur.to_seconds(),
671 self.states.len()
672 )
673 }
674 }
675}
676
677impl<S: Interpolatable> fmt::Debug for Traj<S>
678where
679 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
680{
681 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
682 write!(f, "{self}",)
683 }
684}
685
686impl<S: Interpolatable> Default for Traj<S>
687where
688 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
689{
690 fn default() -> Self {
691 Self::new()
692 }
693}