1use super::traj_it::TrajIterator;
20use super::{ExportCfg, InterpolationSnafu, INTERPOLATION_SAMPLES};
21use super::{Interpolatable, TrajError};
22use crate::errors::NyxError;
23use crate::io::watermark::pq_writer;
24use crate::io::InputOutputError;
25use crate::linalg::allocator::Allocator;
26use crate::linalg::DefaultAllocator;
27use crate::md::prelude::{GuidanceMode, StateParameter};
28use crate::md::trajectory::smooth_state_diff_in_place;
29use crate::md::EventEvaluator;
30use crate::time::{Duration, Epoch, TimeSeries, TimeUnits};
31use anise::almanac::Almanac;
32use arrow::array::{Array, Float64Builder, StringBuilder};
33use arrow::datatypes::{DataType, Field, Schema};
34use arrow::record_batch::RecordBatch;
35use hifitime::TimeScale;
36use parquet::arrow::ArrowWriter;
37use snafu::ResultExt;
38use std::collections::HashMap;
39use std::error::Error;
40use std::fmt;
41use std::fs::File;
42use std::iter::Iterator;
43use std::ops;
44use std::path::{Path, PathBuf};
45use std::sync::Arc;
46
47#[derive(Clone, PartialEq)]
49pub struct Traj<S: Interpolatable>
50where
51 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
52{
53 pub name: Option<String>,
55 pub states: Vec<S>,
57}
58
59impl<S: Interpolatable> Traj<S>
60where
61 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
62{
63 pub fn new() -> Self {
64 Self {
65 name: None,
66 states: Vec::new(),
67 }
68 }
69 pub fn finalize(&mut self) {
71 self.states.dedup_by(|a, b| a.epoch().eq(&b.epoch()));
73 self.states.sort_by_key(|a| a.epoch());
75 }
76
77 pub fn at(&self, epoch: Epoch) -> Result<S, TrajError> {
79 if self.states.is_empty() || self.first().epoch() > epoch || self.last().epoch() < epoch {
80 return Err(TrajError::NoInterpolationData { epoch });
81 }
82 match self
83 .states
84 .binary_search_by(|state| state.epoch().cmp(&epoch))
85 {
86 Ok(idx) => {
87 Ok(self.states[idx])
89 }
90 Err(idx) => {
91 if idx == 0 || idx >= self.states.len() {
92 return Err(TrajError::NoInterpolationData { epoch });
95 }
96 let num_left = INTERPOLATION_SAMPLES / 2;
101
102 let mut first_idx = idx.saturating_sub(num_left);
104 let last_idx = self.states.len().min(first_idx + INTERPOLATION_SAMPLES);
105
106 if last_idx == self.states.len() {
108 first_idx = last_idx.saturating_sub(2 * num_left);
109 }
110
111 let mut states = Vec::with_capacity(last_idx - first_idx);
112 for idx in first_idx..last_idx {
113 states.push(self.states[idx]);
114 }
115
116 self.states[idx]
117 .interpolate(epoch, &states)
118 .context(InterpolationSnafu)
119 }
120 }
121 }
122
123 pub fn first(&self) -> &S {
125 self.states.first().unwrap()
127 }
128
129 pub fn last(&self) -> &S {
131 self.states.last().unwrap()
132 }
133
134 pub fn every(&self, step: Duration) -> TrajIterator<'_, S> {
136 self.every_between(step, self.first().epoch(), self.last().epoch())
137 }
138
139 pub fn every_between(&self, step: Duration, start: Epoch, end: Epoch) -> TrajIterator<'_, S> {
141 TrajIterator {
142 time_series: TimeSeries::inclusive(
143 start.max(self.first().epoch()),
144 end.min(self.last().epoch()),
145 step,
146 ),
147 traj: self,
148 }
149 }
150
151 pub fn to_parquet_simple<P: AsRef<Path>>(
153 &self,
154 path: P,
155 almanac: Arc<Almanac>,
156 ) -> Result<PathBuf, Box<dyn Error>> {
157 self.to_parquet(path, None, ExportCfg::default(), almanac)
158 }
159
160 pub fn to_parquet_with_cfg<P: AsRef<Path>>(
162 &self,
163 path: P,
164 cfg: ExportCfg,
165 almanac: Arc<Almanac>,
166 ) -> Result<PathBuf, Box<dyn Error>> {
167 self.to_parquet(path, None, cfg, almanac)
168 }
169
170 pub fn to_parquet_with_step<P: AsRef<Path>>(
172 &self,
173 path: P,
174 step: Duration,
175 almanac: Arc<Almanac>,
176 ) -> Result<(), Box<dyn Error>> {
177 self.to_parquet_with_cfg(
178 path,
179 ExportCfg {
180 step: Some(step),
181 ..Default::default()
182 },
183 almanac,
184 )?;
185
186 Ok(())
187 }
188
189 pub fn to_parquet<P: AsRef<Path>>(
191 &self,
192 path: P,
193 events: Option<Vec<&dyn EventEvaluator<S>>>,
194 cfg: ExportCfg,
195 almanac: Arc<Almanac>,
196 ) -> Result<PathBuf, Box<dyn Error>> {
197 let tick = Epoch::now().unwrap();
198 info!("Exporting trajectory to parquet file...");
199
200 let path_buf = cfg.actual_path(path);
202
203 let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
205
206 let frame = self.states[0].frame();
207 let more_meta = Some(vec![(
208 "Frame".to_string(),
209 serde_dhall::serialize(&frame)
210 .static_type_annotation()
211 .to_string()
212 .map_err(|e| {
213 Box::new(InputOutputError::SerializeDhall {
214 what: format!("frame `{frame}`"),
215 err: e.to_string(),
216 })
217 })?,
218 )]);
219
220 let mut fields = match cfg.fields {
221 Some(fields) => fields,
222 None => S::export_params(),
223 };
224
225 fields.retain(|param| self.first().value(*param).is_ok());
227
228 for field in &fields {
229 hdrs.push(field.to_field(more_meta.clone()));
230 }
231
232 if let Some(events) = events.as_ref() {
233 for event in events {
234 let field = Field::new(format!("{event}"), DataType::Float64, false);
235 hdrs.push(field);
236 }
237 }
238
239 let schema = Arc::new(Schema::new(hdrs));
241 let mut record: Vec<Arc<dyn Array>> = Vec::new();
242
243 let states = if cfg.start_epoch.is_some() || cfg.end_epoch.is_some() || cfg.step.is_some() {
245 let start = cfg.start_epoch.unwrap_or_else(|| self.first().epoch());
247 let end = cfg.end_epoch.unwrap_or_else(|| self.last().epoch());
248 let step = cfg.step.unwrap_or_else(|| 1.minutes());
249 self.every_between(step, start, end).collect::<Vec<S>>()
250 } else {
251 self.states.to_vec()
252 };
253
254 let mut utc_epoch = StringBuilder::new();
258 for s in &states {
259 utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
260 }
261 record.push(Arc::new(utc_epoch.finish()));
262
263 for field in fields {
265 if field == StateParameter::GuidanceMode {
266 let mut guid_mode = StringBuilder::new();
267 for s in &states {
268 guid_mode
269 .append_value(format!("{:?}", GuidanceMode::from(s.value(field).unwrap())));
270 }
271 record.push(Arc::new(guid_mode.finish()));
272 } else {
273 let mut data = Float64Builder::new();
274 for s in &states {
275 data.append_value(s.value(field).unwrap());
276 }
277 record.push(Arc::new(data.finish()));
278 }
279 }
280
281 info!(
282 "Serialized {} states from {} to {}",
283 states.len(),
284 states.first().unwrap().epoch(),
285 states.last().unwrap().epoch()
286 );
287
288 if let Some(events) = events {
290 info!("Evaluating {} event(s)", events.len());
291 for event in events {
292 let mut data = Float64Builder::new();
293 for s in &states {
294 data.append_value(event.eval(s, almanac.clone()).map_err(Box::new)?);
295 }
296 record.push(Arc::new(data.finish()));
297 }
298 }
299
300 let mut metadata = HashMap::new();
302 metadata.insert("Purpose".to_string(), "Trajectory data".to_string());
303 if let Some(add_meta) = cfg.metadata {
304 for (k, v) in add_meta {
305 metadata.insert(k, v);
306 }
307 }
308
309 let props = pq_writer(Some(metadata));
310
311 let file = File::create(&path_buf)?;
312 let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
313
314 let batch = RecordBatch::try_new(schema, record)?;
315 writer.write(&batch)?;
316 writer.close()?;
317
318 let tock_time = Epoch::now().unwrap() - tick;
320 info!(
321 "Trajectory written to {} in {tock_time}",
322 path_buf.display()
323 );
324 Ok(path_buf)
325 }
326
327 pub fn resample(&self, step: Duration) -> Result<Self, NyxError> {
330 if self.states.is_empty() {
331 return Err(NyxError::Trajectory {
332 source: TrajError::CreationError {
333 msg: "No trajectory to convert".to_string(),
334 },
335 });
336 }
337
338 let mut traj = Self::new();
339 for state in self.every(step) {
340 traj.states.push(state);
341 }
342
343 traj.finalize();
344
345 Ok(traj)
346 }
347
348 pub fn rebuild(&self, epochs: &[Epoch]) -> Result<Self, NyxError> {
351 if self.states.is_empty() {
352 return Err(NyxError::Trajectory {
353 source: TrajError::CreationError {
354 msg: "No trajectory to convert".to_string(),
355 },
356 });
357 }
358
359 let mut traj = Self::new();
360 for epoch in epochs {
361 traj.states.push(self.at(*epoch)?);
362 }
363
364 traj.finalize();
365
366 Ok(traj)
367 }
368
369 pub fn ric_diff_to_parquet<P: AsRef<Path>>(
374 &self,
375 other: &Self,
376 path: P,
377 cfg: ExportCfg,
378 ) -> Result<PathBuf, Box<dyn Error>> {
379 let tick = Epoch::now().unwrap();
380 info!("Exporting trajectory to parquet file...");
381
382 let path_buf = cfg.actual_path(path);
384
385 let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
387
388 for coord in ["X", "Y", "Z"] {
390 let mut meta = HashMap::new();
391 meta.insert("unit".to_string(), "km".to_string());
392
393 let field = Field::new(
394 format!("Delta {coord} (RIC) (km)"),
395 DataType::Float64,
396 false,
397 )
398 .with_metadata(meta);
399
400 hdrs.push(field);
401 }
402
403 for coord in ["x", "y", "z"] {
404 let mut meta = HashMap::new();
405 meta.insert("unit".to_string(), "km/s".to_string());
406
407 let field = Field::new(
408 format!("Delta V{coord} (RIC) (km/s)"),
409 DataType::Float64,
410 false,
411 )
412 .with_metadata(meta);
413
414 hdrs.push(field);
415 }
416
417 let frame = self.states[0].frame();
418 let more_meta = Some(vec![(
419 "Frame".to_string(),
420 serde_dhall::serialize(&frame)
421 .static_type_annotation()
422 .to_string()
423 .unwrap_or(frame.to_string()),
424 )]);
425
426 let mut cfg = cfg;
427
428 let mut fields = match cfg.fields {
429 Some(fields) => fields,
430 None => S::export_params(),
431 };
432
433 fields.retain(|param| {
435 param != &StateParameter::GuidanceMode && self.first().value(*param).is_ok()
436 });
437
438 for field in &fields {
439 hdrs.push(field.to_field(more_meta.clone()));
440 }
441
442 let schema = Arc::new(Schema::new(hdrs));
444 let mut record: Vec<Arc<dyn Array>> = Vec::new();
445
446 cfg.start_epoch = if self.first().epoch() > other.first().epoch() {
448 Some(self.first().epoch())
449 } else {
450 Some(other.first().epoch())
451 };
452
453 cfg.end_epoch = if self.last().epoch() > other.last().epoch() {
454 Some(other.last().epoch())
455 } else {
456 Some(self.last().epoch())
457 };
458
459 let step = cfg.step.unwrap_or_else(|| 1.minutes());
461 let self_states = self
462 .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
463 .collect::<Vec<S>>();
464
465 let other_states = other
466 .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
467 .collect::<Vec<S>>();
468
469 let mut ric_diff = Vec::with_capacity(other_states.len());
471 for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
472 let self_orbit = self_state.orbit();
473 let other_orbit = other_state.orbit();
474
475 let this_ric_diff = self_orbit.ric_difference(&other_orbit).map_err(Box::new)?;
476
477 ric_diff.push(this_ric_diff);
478 }
479
480 smooth_state_diff_in_place(&mut ric_diff, if other_states.len() > 5 { 5 } else { 1 });
481
482 let mut utc_epoch = StringBuilder::new();
486 for s in &self_states {
487 utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
488 }
489 record.push(Arc::new(utc_epoch.finish()));
490
491 for coord_no in 0..6 {
493 let mut data = Float64Builder::new();
494 for this_ric_dff in &ric_diff {
495 data.append_value(this_ric_dff.to_cartesian_pos_vel()[coord_no]);
496 }
497 record.push(Arc::new(data.finish()));
498 }
499
500 for field in fields {
502 let mut data = Float64Builder::new();
503 for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
504 let self_val = self_state.value(field).map_err(Box::new)?;
505 let other_val = other_state.value(field).map_err(Box::new)?;
506
507 data.append_value(self_val - other_val);
508 }
509
510 record.push(Arc::new(data.finish()));
511 }
512
513 info!("Serialized {} states differences", self_states.len());
514
515 let mut metadata = HashMap::new();
517 metadata.insert(
518 "Purpose".to_string(),
519 "Trajectory difference data".to_string(),
520 );
521 if let Some(add_meta) = cfg.metadata {
522 for (k, v) in add_meta {
523 metadata.insert(k, v);
524 }
525 }
526
527 let props = pq_writer(Some(metadata));
528
529 let file = File::create(&path_buf)?;
530 let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
531
532 let batch = RecordBatch::try_new(schema, record)?;
533 writer.write(&batch)?;
534 writer.close()?;
535
536 let tock_time = Epoch::now().unwrap() - tick;
538 info!(
539 "Trajectory written to {} in {tock_time}",
540 path_buf.display()
541 );
542 Ok(path_buf)
543 }
544}
545
546impl<S: Interpolatable> ops::Add for Traj<S>
547where
548 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
549{
550 type Output = Result<Traj<S>, NyxError>;
551
552 fn add(self, other: Traj<S>) -> Self::Output {
554 &self + &other
555 }
556}
557
558impl<S: Interpolatable> ops::Add<&Traj<S>> for &Traj<S>
559where
560 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
561{
562 type Output = Result<Traj<S>, NyxError>;
563
564 fn add(self, other: &Traj<S>) -> Self::Output {
566 if self.first().frame() != other.first().frame() {
567 Err(NyxError::Trajectory {
568 source: TrajError::CreationError {
569 msg: format!(
570 "Frame mismatch in add operation: {} != {}",
571 self.first().frame(),
572 other.first().frame()
573 ),
574 },
575 })
576 } else {
577 if self.last().epoch() < other.first().epoch() {
578 let gap = other.first().epoch() - self.last().epoch();
579 warn!(
580 "Resulting merged trajectory will have a time-gap of {} starting at {}",
581 gap,
582 self.last().epoch()
583 );
584 }
585
586 let mut me = self.clone();
587 for state in &other
589 .states
590 .iter()
591 .copied()
592 .filter(|s| s.epoch() > self.last().epoch())
593 .collect::<Vec<S>>()
594 {
595 me.states.push(*state);
596 }
597 me.finalize();
598
599 Ok(me)
600 }
601 }
602}
603
604impl<S: Interpolatable> ops::AddAssign<&Traj<S>> for Traj<S>
605where
606 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
607{
608 fn add_assign(&mut self, rhs: &Self) {
614 *self = (self.clone() + rhs.clone()).unwrap();
615 }
616}
617
618impl<S: Interpolatable> fmt::Display for Traj<S>
619where
620 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
621{
622 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
623 if self.states.is_empty() {
624 write!(f, "Empty Trajectory!")
625 } else {
626 let dur = self.last().epoch() - self.first().epoch();
627 write!(
628 f,
629 "Trajectory {}in {} from {} to {} ({}, or {:.3} s) [{} states]",
630 match &self.name {
631 Some(name) => format!("of {name} "),
632 None => String::new(),
633 },
634 self.first().frame(),
635 self.first().epoch(),
636 self.last().epoch(),
637 dur,
638 dur.to_seconds(),
639 self.states.len()
640 )
641 }
642 }
643}
644
645impl<S: Interpolatable> fmt::Debug for Traj<S>
646where
647 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
648{
649 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
650 write!(f, "{self}",)
651 }
652}
653
654impl<S: Interpolatable> Default for Traj<S>
655where
656 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
657{
658 fn default() -> Self {
659 Self::new()
660 }
661}