1use super::traj_it::TrajIterator;
20use super::{ExportCfg, InterpolationSnafu, INTERPOLATION_SAMPLES};
21use super::{Interpolatable, TrajError};
22use crate::errors::NyxError;
23use crate::io::watermark::pq_writer;
24use crate::io::InputOutputError;
25use crate::linalg::allocator::Allocator;
26use crate::linalg::DefaultAllocator;
27use crate::md::prelude::{GuidanceMode, StateParameter};
28use crate::md::trajectory::smooth_state_diff_in_place;
29use crate::time::{Duration, Epoch, TimeSeries, TimeUnits};
30use anise::analysis::specs::StateSpecTrait;
31use anise::analysis::AnalysisError;
32use anise::astro::orbit::Orbit;
33use anise::prelude::{Aberration, Almanac};
34use arrow::array::{Array, Float64Builder, StringBuilder};
35use arrow::datatypes::{DataType, Field, Schema};
36use arrow::record_batch::RecordBatch;
37use hifitime::TimeScale;
38use log::{info, warn};
39use parquet::arrow::ArrowWriter;
40use snafu::ResultExt;
41use std::collections::HashMap;
42use std::error::Error;
43use std::fmt;
44use std::fs::File;
45use std::iter::Iterator;
46use std::ops;
47use std::ops::Bound::{Excluded, Included, Unbounded};
48use std::path::{Path, PathBuf};
49use std::sync::Arc;
50
/// A trajectory: an optionally named list of states of type `S`.
#[derive(Clone, PartialEq)]
pub struct Traj<S: Interpolatable>
where
    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
    /// Optional name of this trajectory, shown in its `Display` output.
    pub name: Option<String>,
    /// States of this trajectory. `Traj::finalize` sorts them by epoch; lookups such as
    /// `Traj::at` binary-search this vector and therefore expect it to be sorted.
    pub states: Vec<S>,
}
62
63impl<S: Interpolatable> Traj<S>
64where
65 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
66{
67 pub fn new() -> Self {
68 Self {
69 name: None,
70 states: Vec::new(),
71 }
72 }
73 pub fn finalize(&mut self) {
75 self.states.dedup_by(|a, b| a.epoch().eq(&b.epoch()));
77 self.states.sort_by_key(|a| a.epoch());
79 }
80
81 pub fn at(&self, epoch: Epoch) -> Result<S, TrajError> {
83 if self.states.is_empty() || self.first().epoch() > epoch || self.last().epoch() < epoch {
84 return Err(TrajError::NoInterpolationData { epoch });
85 }
86 match self
87 .states
88 .binary_search_by(|state| state.epoch().cmp(&epoch))
89 {
90 Ok(idx) => {
91 Ok(self.states[idx])
93 }
94 Err(idx) => {
95 if idx == 0 || idx >= self.states.len() {
96 return Err(TrajError::NoInterpolationData { epoch });
99 }
100 let num_left = INTERPOLATION_SAMPLES / 2;
105
106 let mut first_idx = idx.saturating_sub(num_left);
108 let last_idx = self.states.len().min(first_idx + INTERPOLATION_SAMPLES);
109
110 if last_idx == self.states.len() {
112 first_idx = last_idx.saturating_sub(2 * num_left);
113 }
114
115 let mut states = Vec::with_capacity(last_idx - first_idx);
116 for idx in first_idx..last_idx {
117 states.push(self.states[idx]);
118 }
119
120 self.states[idx]
121 .interpolate(epoch, &states)
122 .context(InterpolationSnafu)
123 }
124 }
125 }
126
127 pub fn first(&self) -> &S {
129 self.states.first().unwrap()
131 }
132
133 pub fn last(&self) -> &S {
135 self.states.last().unwrap()
136 }
137
138 pub fn start_epoch(&self) -> Epoch {
139 self.first().epoch()
140 }
141
142 pub fn end_epoch(&self) -> Epoch {
143 self.last().epoch()
144 }
145
146 pub fn every(&self, step: Duration) -> TrajIterator<'_, S> {
148 self.every_between(step, self.first().epoch(), self.last().epoch())
149 }
150
151 pub fn every_between(&self, step: Duration, start: Epoch, end: Epoch) -> TrajIterator<'_, S> {
153 TrajIterator {
154 time_series: TimeSeries::inclusive(
155 start.max(self.first().epoch()),
156 end.min(self.last().epoch()),
157 step,
158 ),
159 traj: self,
160 }
161 }
162
163 pub fn filter_by_epoch<R: ops::RangeBounds<Epoch>>(mut self, bound: R) -> Self {
165 self.states = self
166 .states
167 .iter()
168 .copied()
169 .filter(|s| bound.contains(&s.epoch()))
170 .collect::<Vec<_>>();
171 self
172 }
173
174 pub fn filter_by_offset<R: ops::RangeBounds<Duration>>(self, bound: R) -> Self {
177 if self.states.is_empty() {
178 return self;
179 }
180 let start = match bound.start_bound() {
182 Unbounded => self.states.first().unwrap().epoch(),
183 Included(offset) | Excluded(offset) => self.states.first().unwrap().epoch() + *offset,
184 };
185
186 let end = match bound.end_bound() {
187 Unbounded => self.states.last().unwrap().epoch(),
188 Included(offset) | Excluded(offset) => self.states.first().unwrap().epoch() + *offset,
189 };
190
191 self.filter_by_epoch(start..=end)
192 }
193 pub fn to_parquet_simple<P: AsRef<Path>>(&self, path: P) -> Result<PathBuf, Box<dyn Error>> {
195 self.to_parquet(path, ExportCfg::default())
196 }
197
198 pub fn to_parquet_with_cfg<P: AsRef<Path>>(
200 &self,
201 path: P,
202 cfg: ExportCfg,
203 ) -> Result<PathBuf, Box<dyn Error>> {
204 self.to_parquet(path, cfg)
205 }
206
207 pub fn to_parquet_with_step<P: AsRef<Path>>(
209 &self,
210 path: P,
211 step: Duration,
212 ) -> Result<(), Box<dyn Error>> {
213 self.to_parquet_with_cfg(
214 path,
215 ExportCfg {
216 step: Some(step),
217 ..Default::default()
218 },
219 )?;
220
221 Ok(())
222 }
223
224 pub fn to_parquet<P: AsRef<Path>>(
226 &self,
227 path: P,
228 cfg: ExportCfg,
229 ) -> Result<PathBuf, Box<dyn Error>> {
230 let tick = Epoch::now().unwrap();
231 info!("Exporting trajectory to parquet file...");
232
233 let path_buf = cfg.actual_path(path);
235
236 let states = if cfg.start_epoch.is_some() || cfg.end_epoch.is_some() || cfg.step.is_some() {
238 let start = cfg.start_epoch.unwrap_or_else(|| self.first().epoch());
240 let end = cfg.end_epoch.unwrap_or_else(|| self.last().epoch());
241 let step = cfg.step.unwrap_or_else(|| 1.minutes());
242 self.every_between(step, start, end).collect::<Vec<S>>()
243 } else {
244 self.states.to_vec()
245 };
246
247 let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
249
250 let frame = self.states[0].frame();
251 let more_meta = Some(vec![(
252 "Frame".to_string(),
253 serde_dhall::serialize(&frame)
254 .static_type_annotation()
255 .to_string()
256 .map_err(|e| {
257 Box::new(InputOutputError::SerializeDhall {
258 what: format!("frame `{frame}`"),
259 err: e.to_string(),
260 })
261 })?,
262 )]);
263
264 let requested_fields = match cfg.fields {
265 Some(fields) => fields,
266 None => S::export_params(),
267 };
268
269 let mut fields = Vec::new();
270 let mut field_nullable = Vec::new();
271 for field in requested_fields {
272 let mut any_ok = false;
273 let mut any_err = false;
274 for state in &states {
275 if state.value(field).is_ok() {
276 any_ok = true;
277 } else {
278 any_err = true;
279 }
280 }
281
282 if any_ok {
283 fields.push(field);
284 field_nullable.push(any_err);
285 }
286 }
287
288 for (field, nullable) in fields.iter().zip(field_nullable.iter().copied()) {
289 hdrs.push(field.to_field(more_meta.clone()).with_nullable(nullable));
290 }
291
292 let schema = Arc::new(Schema::new(hdrs));
294 let mut record: Vec<Arc<dyn Array>> = Vec::new();
295
296 let mut utc_epoch = StringBuilder::new();
300 for s in &states {
301 utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
302 }
303 record.push(Arc::new(utc_epoch.finish()));
304
305 for field in fields {
307 if field == StateParameter::GuidanceMode() {
308 let mut guid_mode = StringBuilder::new();
309 for s in &states {
310 match s.value(field) {
311 Ok(value) => {
312 guid_mode.append_value(format!("{:?}", GuidanceMode::from(value)));
313 }
314 Err(_) => guid_mode.append_null(),
315 }
316 }
317 record.push(Arc::new(guid_mode.finish()));
318 } else {
319 let mut data = Float64Builder::new();
320 for s in &states {
321 match s.value(field) {
322 Ok(value) => data.append_value(value),
323 Err(_) => data.append_null(),
324 };
325 }
326 record.push(Arc::new(data.finish()));
327 }
328 }
329
330 info!(
331 "Serialized {} states from {} to {}",
332 states.len(),
333 states.first().unwrap().epoch(),
334 states.last().unwrap().epoch()
335 );
336
337 let mut metadata = HashMap::new();
339 metadata.insert("Purpose".to_string(), "Trajectory data".to_string());
340 if let Some(add_meta) = cfg.metadata {
341 for (k, v) in add_meta {
342 metadata.insert(k, v);
343 }
344 }
345
346 let props = pq_writer(Some(metadata));
347
348 let file = File::create(&path_buf)?;
349 let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
350
351 let batch = RecordBatch::try_new(schema, record)?;
352 writer.write(&batch)?;
353 writer.close()?;
354
355 let tock_time = Epoch::now().unwrap() - tick;
357 info!(
358 "Trajectory written to {} in {tock_time}",
359 path_buf.display()
360 );
361 Ok(path_buf)
362 }
363
364 pub fn resample(&self, step: Duration) -> Result<Self, NyxError> {
367 if self.states.is_empty() {
368 return Err(NyxError::Trajectory {
369 source: TrajError::CreationError {
370 msg: "No trajectory to convert".to_string(),
371 },
372 });
373 }
374
375 let mut traj = Self::new();
376 for state in self.every(step) {
377 traj.states.push(state);
378 }
379
380 traj.finalize();
381
382 Ok(traj)
383 }
384
385 pub fn rebuild(&self, epochs: &[Epoch]) -> Result<Self, NyxError> {
388 if self.states.is_empty() {
389 return Err(NyxError::Trajectory {
390 source: TrajError::CreationError {
391 msg: "No trajectory to convert".to_string(),
392 },
393 });
394 }
395
396 let mut traj = Self::new();
397 for epoch in epochs {
398 traj.states.push(self.at(*epoch)?);
399 }
400
401 traj.finalize();
402
403 Ok(traj)
404 }
405
406 pub fn ric_diff_to_parquet<P: AsRef<Path>>(
411 &self,
412 other: &Self,
413 path: P,
414 cfg: ExportCfg,
415 ) -> Result<PathBuf, Box<dyn Error>> {
416 let tick = Epoch::now().unwrap();
417 info!("Exporting trajectory to parquet file...");
418
419 let path_buf = cfg.actual_path(path);
421
422 let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
424
425 for coord in ["X", "Y", "Z"] {
427 let mut meta = HashMap::new();
428 meta.insert("unit".to_string(), "km".to_string());
429
430 let field = Field::new(
431 format!("Delta {coord} (RIC) (km)"),
432 DataType::Float64,
433 false,
434 )
435 .with_metadata(meta);
436
437 hdrs.push(field);
438 }
439
440 for coord in ["x", "y", "z"] {
441 let mut meta = HashMap::new();
442 meta.insert("unit".to_string(), "km/s".to_string());
443
444 let field = Field::new(
445 format!("Delta V{coord} (RIC) (km/s)"),
446 DataType::Float64,
447 false,
448 )
449 .with_metadata(meta);
450
451 hdrs.push(field);
452 }
453
454 let frame = self.states[0].frame();
455 let more_meta = Some(vec![(
456 "Frame".to_string(),
457 serde_dhall::serialize(&frame)
458 .static_type_annotation()
459 .to_string()
460 .unwrap_or(frame.to_string()),
461 )]);
462
463 let mut cfg = cfg;
464
465 let mut fields = match cfg.fields {
466 Some(fields) => fields,
467 None => S::export_params(),
468 };
469
470 fields.retain(|param| {
472 param != &StateParameter::GuidanceMode() && self.first().value(*param).is_ok()
473 });
474
475 for field in &fields {
476 hdrs.push(field.to_field(more_meta.clone()));
477 }
478
479 let schema = Arc::new(Schema::new(hdrs));
481 let mut record: Vec<Arc<dyn Array>> = Vec::new();
482
483 cfg.start_epoch = if self.first().epoch() > other.first().epoch() {
485 Some(self.first().epoch())
486 } else {
487 Some(other.first().epoch())
488 };
489
490 cfg.end_epoch = if self.last().epoch() > other.last().epoch() {
491 Some(other.last().epoch())
492 } else {
493 Some(self.last().epoch())
494 };
495
496 let step = cfg.step.unwrap_or_else(|| 1.minutes());
498 let self_states = self
499 .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
500 .collect::<Vec<S>>();
501
502 let other_states = other
503 .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
504 .collect::<Vec<S>>();
505
506 let mut ric_diff = Vec::with_capacity(other_states.len());
508 for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
509 let self_orbit = self_state.orbit();
510 let other_orbit = other_state.orbit();
511
512 let this_ric_diff = self_orbit.ric_difference(&other_orbit).map_err(Box::new)?;
513
514 ric_diff.push(this_ric_diff);
515 }
516
517 smooth_state_diff_in_place(&mut ric_diff, if other_states.len() > 5 { 5 } else { 1 });
518
519 let mut utc_epoch = StringBuilder::new();
523 for s in &self_states {
524 utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
525 }
526 record.push(Arc::new(utc_epoch.finish()));
527
528 for coord_no in 0..6 {
530 let mut data = Float64Builder::new();
531 for this_ric_dff in &ric_diff {
532 data.append_value(this_ric_dff.to_cartesian_pos_vel()[coord_no]);
533 }
534 record.push(Arc::new(data.finish()));
535 }
536
537 for field in fields {
539 let mut data = Float64Builder::new();
540 for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
541 let self_val = self_state.value(field).map_err(Box::new)?;
542 let other_val = other_state.value(field).map_err(Box::new)?;
543
544 data.append_value(self_val - other_val);
545 }
546
547 record.push(Arc::new(data.finish()));
548 }
549
550 info!("Serialized {} states differences", self_states.len());
551
552 let mut metadata = HashMap::new();
554 metadata.insert(
555 "Purpose".to_string(),
556 "Trajectory difference data".to_string(),
557 );
558 if let Some(add_meta) = cfg.metadata {
559 for (k, v) in add_meta {
560 metadata.insert(k, v);
561 }
562 }
563
564 let props = pq_writer(Some(metadata));
565
566 let file = File::create(&path_buf)?;
567 let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
568
569 let batch = RecordBatch::try_new(schema, record)?;
570 writer.write(&batch)?;
571 writer.close()?;
572
573 let tock_time = Epoch::now().unwrap() - tick;
575 info!(
576 "Trajectory written to {} in {tock_time}",
577 path_buf.display()
578 );
579 Ok(path_buf)
580 }
581}
582
583impl<S: Interpolatable> ops::Add for Traj<S>
584where
585 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
586{
587 type Output = Result<Traj<S>, NyxError>;
588
589 fn add(self, other: Traj<S>) -> Self::Output {
591 &self + &other
592 }
593}
594
595impl<S: Interpolatable> ops::Add<&Traj<S>> for &Traj<S>
596where
597 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
598{
599 type Output = Result<Traj<S>, NyxError>;
600
601 fn add(self, other: &Traj<S>) -> Self::Output {
603 if self.first().frame() != other.first().frame() {
604 Err(NyxError::Trajectory {
605 source: TrajError::CreationError {
606 msg: format!(
607 "Frame mismatch in add operation: {} != {}",
608 self.first().frame(),
609 other.first().frame()
610 ),
611 },
612 })
613 } else {
614 if self.last().epoch() < other.first().epoch() {
615 let gap = other.first().epoch() - self.last().epoch();
616 warn!(
617 "Resulting merged trajectory will have a time-gap of {} starting at {}",
618 gap,
619 self.last().epoch()
620 );
621 }
622
623 let mut me = self.clone();
624 for state in &other
626 .states
627 .iter()
628 .copied()
629 .filter(|s| s.epoch() > self.last().epoch())
630 .collect::<Vec<S>>()
631 {
632 me.states.push(*state);
633 }
634 me.finalize();
635
636 Ok(me)
637 }
638 }
639}
640
641impl<S: Interpolatable> ops::AddAssign<&Traj<S>> for Traj<S>
642where
643 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
644{
645 fn add_assign(&mut self, rhs: &Self) {
651 *self = (self.clone() + rhs.clone()).unwrap();
652 }
653}
654
655impl<S: Interpolatable> fmt::Display for Traj<S>
656where
657 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
658{
659 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
660 if self.states.is_empty() {
661 write!(f, "Empty Trajectory!")
662 } else {
663 let dur = self.last().epoch() - self.first().epoch();
664 write!(
665 f,
666 "Trajectory {}in {} from {} to {} ({}, or {:.3} s) [{} states]",
667 match &self.name {
668 Some(name) => format!("of {name} "),
669 None => String::new(),
670 },
671 self.first().frame(),
672 self.first().epoch(),
673 self.last().epoch(),
674 dur,
675 dur.to_seconds(),
676 self.states.len()
677 )
678 }
679 }
680}
681
682impl<S: Interpolatable> fmt::Debug for Traj<S>
683where
684 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
685{
686 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
687 write!(f, "{self}",)
688 }
689}
690
691impl<S: Interpolatable> Default for Traj<S>
692where
693 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
694{
695 fn default() -> Self {
696 Self::new()
697 }
698}
699
700impl<S: Interpolatable> StateSpecTrait for Traj<S>
701where
702 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
703{
704 fn ab_corr(&self) -> Option<Aberration> {
705 None
706 }
707
708 fn evaluate(&self, epoch: Epoch, _almanac: &Almanac) -> Result<Orbit, AnalysisError> {
709 self.at(epoch)
710 .map(|state| state.orbit())
711 .map_err(|e| AnalysisError::GenericAnalysisError {
712 err: format!("{e}"),
713 })
714 }
715}