nyx_space/md/trajectory/
traj.rs

1/*
2    Nyx, blazing fast astrodynamics
3    Copyright (C) 2018-onwards Christopher Rabotin <christopher.rabotin@gmail.com>
4
5    This program is free software: you can redistribute it and/or modify
6    it under the terms of the GNU Affero General Public License as published
7    by the Free Software Foundation, either version 3 of the License, or
8    (at your option) any later version.
9
10    This program is distributed in the hope that it will be useful,
11    but WITHOUT ANY WARRANTY; without even the implied warranty of
12    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13    GNU Affero General Public License for more details.
14
15    You should have received a copy of the GNU Affero General Public License
16    along with this program.  If not, see <https://www.gnu.org/licenses/>.
17*/
18
19use super::traj_it::TrajIterator;
20use super::{ExportCfg, InterpolationSnafu, INTERPOLATION_SAMPLES};
21use super::{Interpolatable, TrajError};
22use crate::errors::NyxError;
23use crate::io::watermark::pq_writer;
24use crate::io::InputOutputError;
25use crate::linalg::allocator::Allocator;
26use crate::linalg::DefaultAllocator;
27use crate::md::prelude::{GuidanceMode, StateParameter};
28use crate::md::trajectory::smooth_state_diff_in_place;
29use crate::md::EventEvaluator;
30use crate::time::{Duration, Epoch, TimeSeries, TimeUnits};
31use anise::almanac::Almanac;
32use arrow::array::{Array, Float64Builder, StringBuilder};
33use arrow::datatypes::{DataType, Field, Schema};
34use arrow::record_batch::RecordBatch;
35use hifitime::TimeScale;
36use parquet::arrow::ArrowWriter;
37use snafu::ResultExt;
38use std::collections::HashMap;
39use std::error::Error;
40use std::fmt;
41use std::fs::File;
42use std::iter::Iterator;
43use std::ops;
44use std::ops::Bound::{Excluded, Included, Unbounded};
45use std::path::{Path, PathBuf};
46use std::sync::Arc;
47
/// Store a trajectory of any State.
///
/// A trajectory is an optionally named list of states of type `S`. The states are expected to be
/// ordered by epoch: `at` performs a binary search on the epochs, and `finalize` sorts the states.
#[derive(Clone, PartialEq)]
pub struct Traj<S: Interpolatable>
where
    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
    /// Optionally name this trajectory
    pub name: Option<String>,
    /// We use a vector because we know that the states are produced in a chronological manner (the direction does not matter).
    pub states: Vec<S>,
}
59
60impl<S: Interpolatable> Traj<S>
61where
62    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
63{
64    pub fn new() -> Self {
65        Self {
66            name: None,
67            states: Vec::new(),
68        }
69    }
70    /// Orders the states, can be used to store the states out of order
71    pub fn finalize(&mut self) {
72        // Remove duplicate epochs
73        self.states.dedup_by(|a, b| a.epoch().eq(&b.epoch()));
74        // And sort
75        self.states.sort_by_key(|a| a.epoch());
76    }
77
78    /// Evaluate the trajectory at this specific epoch.
79    pub fn at(&self, epoch: Epoch) -> Result<S, TrajError> {
80        if self.states.is_empty() || self.first().epoch() > epoch || self.last().epoch() < epoch {
81            return Err(TrajError::NoInterpolationData { epoch });
82        }
83        match self
84            .states
85            .binary_search_by(|state| state.epoch().cmp(&epoch))
86        {
87            Ok(idx) => {
88                // Oh wow, we actually had this exact state!
89                Ok(self.states[idx])
90            }
91            Err(idx) => {
92                if idx == 0 || idx >= self.states.len() {
93                    // The binary search returns where we should insert the data, so if it's at either end of the list, then we're out of bounds.
94                    // This condition should have been handled by the check at the start of this function.
95                    return Err(TrajError::NoInterpolationData { epoch });
96                }
97                // This is the closest index, so let's grab the items around it.
98                // NOTE: This is essentially the same code as in ANISE for the Hermite SPK type 13
99
100                // We didn't find it, so let's build an interpolation here.
101                let num_left = INTERPOLATION_SAMPLES / 2;
102
103                // Ensure that we aren't fetching out of the window
104                let mut first_idx = idx.saturating_sub(num_left);
105                let last_idx = self.states.len().min(first_idx + INTERPOLATION_SAMPLES);
106
107                // Check that we have enough samples
108                if last_idx == self.states.len() {
109                    first_idx = last_idx.saturating_sub(2 * num_left);
110                }
111
112                let mut states = Vec::with_capacity(last_idx - first_idx);
113                for idx in first_idx..last_idx {
114                    states.push(self.states[idx]);
115                }
116
117                self.states[idx]
118                    .interpolate(epoch, &states)
119                    .context(InterpolationSnafu)
120            }
121        }
122    }
123
124    /// Returns the first state in this ephemeris
125    pub fn first(&self) -> &S {
126        // This is done after we've ordered the states we received, so we can just return the first state.
127        self.states.first().unwrap()
128    }
129
130    /// Returns the last state in this ephemeris
131    pub fn last(&self) -> &S {
132        self.states.last().unwrap()
133    }
134
135    /// Creates an iterator through the trajectory by the provided step size
136    pub fn every(&self, step: Duration) -> TrajIterator<'_, S> {
137        self.every_between(step, self.first().epoch(), self.last().epoch())
138    }
139
140    /// Creates an iterator through the trajectory by the provided step size between the provided bounds
141    pub fn every_between(&self, step: Duration, start: Epoch, end: Epoch) -> TrajIterator<'_, S> {
142        TrajIterator {
143            time_series: TimeSeries::inclusive(
144                start.max(self.first().epoch()),
145                end.min(self.last().epoch()),
146                step,
147            ),
148            traj: self,
149        }
150    }
151
152    /// Returns a new trajectory that only contains states that fall within the given epoch range.
153    pub fn filter_by_epoch<R: ops::RangeBounds<Epoch>>(mut self, bound: R) -> Self {
154        self.states = self
155            .states
156            .iter()
157            .copied()
158            .filter(|s| bound.contains(&s.epoch()))
159            .collect::<Vec<_>>();
160        self
161    }
162
163    /// Returns a new trajectory that only contains states that fall within the given offset from the first epoch.
164    /// For example, a bound of 30.minutes()..90.minutes() will only return states from the start of the trajectory + 30 minutes until start + 90 minutes.
165    pub fn filter_by_offset<R: ops::RangeBounds<Duration>>(self, bound: R) -> Self {
166        if self.states.is_empty() {
167            return self;
168        }
169        // Rebuild an epoch bound.
170        let start = match bound.start_bound() {
171            Unbounded => self.states.first().unwrap().epoch(),
172            Included(offset) | Excluded(offset) => self.states.first().unwrap().epoch() + *offset,
173        };
174
175        let end = match bound.end_bound() {
176            Unbounded => self.states.last().unwrap().epoch(),
177            Included(offset) | Excluded(offset) => self.states.first().unwrap().epoch() + *offset,
178        };
179
180        self.filter_by_epoch(start..=end)
181    }
182    /// Store this trajectory arc to a parquet file with the default configuration (depends on the state type, search for `export_params` in the documentation for details).
183    pub fn to_parquet_simple<P: AsRef<Path>>(
184        &self,
185        path: P,
186        almanac: Arc<Almanac>,
187    ) -> Result<PathBuf, Box<dyn Error>> {
188        self.to_parquet(path, None, ExportCfg::default(), almanac)
189    }
190
191    /// Store this trajectory arc to a parquet file with the provided configuration
192    pub fn to_parquet_with_cfg<P: AsRef<Path>>(
193        &self,
194        path: P,
195        cfg: ExportCfg,
196        almanac: Arc<Almanac>,
197    ) -> Result<PathBuf, Box<dyn Error>> {
198        self.to_parquet(path, None, cfg, almanac)
199    }
200
201    /// A shortcut to `to_parquet_with_cfg`
202    pub fn to_parquet_with_step<P: AsRef<Path>>(
203        &self,
204        path: P,
205        step: Duration,
206        almanac: Arc<Almanac>,
207    ) -> Result<(), Box<dyn Error>> {
208        self.to_parquet_with_cfg(
209            path,
210            ExportCfg {
211                step: Some(step),
212                ..Default::default()
213            },
214            almanac,
215        )?;
216
217        Ok(())
218    }
219
220    /// Store this trajectory arc to a parquet file with the provided configuration and event evaluators
221    pub fn to_parquet<P: AsRef<Path>>(
222        &self,
223        path: P,
224        events: Option<Vec<&dyn EventEvaluator<S>>>,
225        cfg: ExportCfg,
226        almanac: Arc<Almanac>,
227    ) -> Result<PathBuf, Box<dyn Error>> {
228        let tick = Epoch::now().unwrap();
229        info!("Exporting trajectory to parquet file...");
230
231        // Grab the path here before we move stuff.
232        let path_buf = cfg.actual_path(path);
233
234        // Build the schema
235        let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
236
237        let frame = self.states[0].frame();
238        let more_meta = Some(vec![(
239            "Frame".to_string(),
240            serde_dhall::serialize(&frame)
241                .static_type_annotation()
242                .to_string()
243                .map_err(|e| {
244                    Box::new(InputOutputError::SerializeDhall {
245                        what: format!("frame `{frame}`"),
246                        err: e.to_string(),
247                    })
248                })?,
249        )]);
250
251        let mut fields = match cfg.fields {
252            Some(fields) => fields,
253            None => S::export_params(),
254        };
255
256        // Check that we can retrieve this information
257        fields.retain(|param| self.first().value(*param).is_ok());
258
259        for field in &fields {
260            hdrs.push(field.to_field(more_meta.clone()));
261        }
262
263        if let Some(events) = events.as_ref() {
264            for event in events {
265                let field = Field::new(format!("{event}"), DataType::Float64, false);
266                hdrs.push(field);
267            }
268        }
269
270        // Build the schema
271        let schema = Arc::new(Schema::new(hdrs));
272        let mut record: Vec<Arc<dyn Array>> = Vec::new();
273
274        // Build the states iterator -- this does require copying the current states but I can't either get a reference or a copy of all the states.
275        let states = if cfg.start_epoch.is_some() || cfg.end_epoch.is_some() || cfg.step.is_some() {
276            // Must interpolate the data!
277            let start = cfg.start_epoch.unwrap_or_else(|| self.first().epoch());
278            let end = cfg.end_epoch.unwrap_or_else(|| self.last().epoch());
279            let step = cfg.step.unwrap_or_else(|| 1.minutes());
280            self.every_between(step, start, end).collect::<Vec<S>>()
281        } else {
282            self.states.to_vec()
283        };
284
285        // Build all of the records
286
287        // Epochs
288        let mut utc_epoch = StringBuilder::new();
289        for s in &states {
290            utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
291        }
292        record.push(Arc::new(utc_epoch.finish()));
293
294        // Add all of the fields
295        for field in fields {
296            if field == StateParameter::GuidanceMode {
297                let mut guid_mode = StringBuilder::new();
298                for s in &states {
299                    guid_mode
300                        .append_value(format!("{:?}", GuidanceMode::from(s.value(field).unwrap())));
301                }
302                record.push(Arc::new(guid_mode.finish()));
303            } else {
304                let mut data = Float64Builder::new();
305                for s in &states {
306                    data.append_value(s.value(field).unwrap());
307                }
308                record.push(Arc::new(data.finish()));
309            }
310        }
311
312        info!(
313            "Serialized {} states from {} to {}",
314            states.len(),
315            states.first().unwrap().epoch(),
316            states.last().unwrap().epoch()
317        );
318
319        // Add all of the evaluated events
320        if let Some(events) = events {
321            info!("Evaluating {} event(s)", events.len());
322            for event in events {
323                let mut data = Float64Builder::new();
324                for s in &states {
325                    data.append_value(event.eval(s, almanac.clone()).map_err(Box::new)?);
326                }
327                record.push(Arc::new(data.finish()));
328            }
329        }
330
331        // Serialize all of the devices and add that to the parquet file too.
332        let mut metadata = HashMap::new();
333        metadata.insert("Purpose".to_string(), "Trajectory data".to_string());
334        if let Some(add_meta) = cfg.metadata {
335            for (k, v) in add_meta {
336                metadata.insert(k, v);
337            }
338        }
339
340        let props = pq_writer(Some(metadata));
341
342        let file = File::create(&path_buf)?;
343        let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
344
345        let batch = RecordBatch::try_new(schema, record)?;
346        writer.write(&batch)?;
347        writer.close()?;
348
349        // Return the path this was written to
350        let tock_time = Epoch::now().unwrap() - tick;
351        info!(
352            "Trajectory written to {} in {tock_time}",
353            path_buf.display()
354        );
355        Ok(path_buf)
356    }
357
358    /// Allows resampling this trajectory at a fixed interval instead of using the propagator step size.
359    /// This may lead to aliasing due to the Nyquist–Shannon sampling theorem.
360    pub fn resample(&self, step: Duration) -> Result<Self, NyxError> {
361        if self.states.is_empty() {
362            return Err(NyxError::Trajectory {
363                source: TrajError::CreationError {
364                    msg: "No trajectory to convert".to_string(),
365                },
366            });
367        }
368
369        let mut traj = Self::new();
370        for state in self.every(step) {
371            traj.states.push(state);
372        }
373
374        traj.finalize();
375
376        Ok(traj)
377    }
378
379    /// Rebuilds this trajectory with the provided epochs.
380    /// This may lead to aliasing due to the Nyquist–Shannon sampling theorem.
381    pub fn rebuild(&self, epochs: &[Epoch]) -> Result<Self, NyxError> {
382        if self.states.is_empty() {
383            return Err(NyxError::Trajectory {
384                source: TrajError::CreationError {
385                    msg: "No trajectory to convert".to_string(),
386                },
387            });
388        }
389
390        let mut traj = Self::new();
391        for epoch in epochs {
392            traj.states.push(self.at(*epoch)?);
393        }
394
395        traj.finalize();
396
397        Ok(traj)
398    }
399
400    /// Export the difference in RIC from of this trajectory compare to the "other" trajectory in parquet format.
401    ///
402    /// # Notes
403    /// + The RIC frame accounts for the transport theorem by performing a finite differencing of the RIC frame.
404    pub fn ric_diff_to_parquet<P: AsRef<Path>>(
405        &self,
406        other: &Self,
407        path: P,
408        cfg: ExportCfg,
409    ) -> Result<PathBuf, Box<dyn Error>> {
410        let tick = Epoch::now().unwrap();
411        info!("Exporting trajectory to parquet file...");
412
413        // Grab the path here before we move stuff.
414        let path_buf = cfg.actual_path(path);
415
416        // Build the schema
417        let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
418
419        // Add the RIC headers
420        for coord in ["X", "Y", "Z"] {
421            let mut meta = HashMap::new();
422            meta.insert("unit".to_string(), "km".to_string());
423
424            let field = Field::new(
425                format!("Delta {coord} (RIC) (km)"),
426                DataType::Float64,
427                false,
428            )
429            .with_metadata(meta);
430
431            hdrs.push(field);
432        }
433
434        for coord in ["x", "y", "z"] {
435            let mut meta = HashMap::new();
436            meta.insert("unit".to_string(), "km/s".to_string());
437
438            let field = Field::new(
439                format!("Delta V{coord} (RIC) (km/s)"),
440                DataType::Float64,
441                false,
442            )
443            .with_metadata(meta);
444
445            hdrs.push(field);
446        }
447
448        let frame = self.states[0].frame();
449        let more_meta = Some(vec![(
450            "Frame".to_string(),
451            serde_dhall::serialize(&frame)
452                .static_type_annotation()
453                .to_string()
454                .unwrap_or(frame.to_string()),
455        )]);
456
457        let mut cfg = cfg;
458
459        let mut fields = match cfg.fields {
460            Some(fields) => fields,
461            None => S::export_params(),
462        };
463
464        // Remove disallowed field and check that we can retrieve this information
465        fields.retain(|param| {
466            param != &StateParameter::GuidanceMode && self.first().value(*param).is_ok()
467        });
468
469        for field in &fields {
470            hdrs.push(field.to_field(more_meta.clone()));
471        }
472
473        // Build the schema
474        let schema = Arc::new(Schema::new(hdrs));
475        let mut record: Vec<Arc<dyn Array>> = Vec::new();
476
477        // Ensure the times match.
478        cfg.start_epoch = if self.first().epoch() > other.first().epoch() {
479            Some(self.first().epoch())
480        } else {
481            Some(other.first().epoch())
482        };
483
484        cfg.end_epoch = if self.last().epoch() > other.last().epoch() {
485            Some(other.last().epoch())
486        } else {
487            Some(self.last().epoch())
488        };
489
490        // Build the states iterator
491        let step = cfg.step.unwrap_or_else(|| 1.minutes());
492        let self_states = self
493            .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
494            .collect::<Vec<S>>();
495
496        let other_states = other
497            .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
498            .collect::<Vec<S>>();
499
500        // Build an array of all the RIC differences
501        let mut ric_diff = Vec::with_capacity(other_states.len());
502        for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
503            let self_orbit = self_state.orbit();
504            let other_orbit = other_state.orbit();
505
506            let this_ric_diff = self_orbit.ric_difference(&other_orbit).map_err(Box::new)?;
507
508            ric_diff.push(this_ric_diff);
509        }
510
511        smooth_state_diff_in_place(&mut ric_diff, if other_states.len() > 5 { 5 } else { 1 });
512
513        // Build all of the records
514
515        // Epochs (both match for self and others)
516        let mut utc_epoch = StringBuilder::new();
517        for s in &self_states {
518            utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
519        }
520        record.push(Arc::new(utc_epoch.finish()));
521
522        // Add the RIC data
523        for coord_no in 0..6 {
524            let mut data = Float64Builder::new();
525            for this_ric_dff in &ric_diff {
526                data.append_value(this_ric_dff.to_cartesian_pos_vel()[coord_no]);
527            }
528            record.push(Arc::new(data.finish()));
529        }
530
531        // Add all of the fields
532        for field in fields {
533            let mut data = Float64Builder::new();
534            for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
535                let self_val = self_state.value(field).map_err(Box::new)?;
536                let other_val = other_state.value(field).map_err(Box::new)?;
537
538                data.append_value(self_val - other_val);
539            }
540
541            record.push(Arc::new(data.finish()));
542        }
543
544        info!("Serialized {} states differences", self_states.len());
545
546        // Serialize all of the devices and add that to the parquet file too.
547        let mut metadata = HashMap::new();
548        metadata.insert(
549            "Purpose".to_string(),
550            "Trajectory difference data".to_string(),
551        );
552        if let Some(add_meta) = cfg.metadata {
553            for (k, v) in add_meta {
554                metadata.insert(k, v);
555            }
556        }
557
558        let props = pq_writer(Some(metadata));
559
560        let file = File::create(&path_buf)?;
561        let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
562
563        let batch = RecordBatch::try_new(schema, record)?;
564        writer.write(&batch)?;
565        writer.close()?;
566
567        // Return the path this was written to
568        let tock_time = Epoch::now().unwrap() - tick;
569        info!(
570            "Trajectory written to {} in {tock_time}",
571            path_buf.display()
572        );
573        Ok(path_buf)
574    }
575}
576
577impl<S: Interpolatable> ops::Add for Traj<S>
578where
579    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
580{
581    type Output = Result<Traj<S>, NyxError>;
582
583    /// Add one trajectory to another. If they do not overlap to within 10ms, a warning will be printed.
584    fn add(self, other: Traj<S>) -> Self::Output {
585        &self + &other
586    }
587}
588
589impl<S: Interpolatable> ops::Add<&Traj<S>> for &Traj<S>
590where
591    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
592{
593    type Output = Result<Traj<S>, NyxError>;
594
595    /// Add one trajectory to another, returns an error if the frames don't match
596    fn add(self, other: &Traj<S>) -> Self::Output {
597        if self.first().frame() != other.first().frame() {
598            Err(NyxError::Trajectory {
599                source: TrajError::CreationError {
600                    msg: format!(
601                        "Frame mismatch in add operation: {} != {}",
602                        self.first().frame(),
603                        other.first().frame()
604                    ),
605                },
606            })
607        } else {
608            if self.last().epoch() < other.first().epoch() {
609                let gap = other.first().epoch() - self.last().epoch();
610                warn!(
611                    "Resulting merged trajectory will have a time-gap of {} starting at {}",
612                    gap,
613                    self.last().epoch()
614                );
615            }
616
617            let mut me = self.clone();
618            // Now start adding the other segments while correcting the index
619            for state in &other
620                .states
621                .iter()
622                .copied()
623                .filter(|s| s.epoch() > self.last().epoch())
624                .collect::<Vec<S>>()
625            {
626                me.states.push(*state);
627            }
628            me.finalize();
629
630            Ok(me)
631        }
632    }
633}
634
635impl<S: Interpolatable> ops::AddAssign<&Traj<S>> for Traj<S>
636where
637    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
638{
639    /// Attempt to add two trajectories together and assign it to `self`
640    ///
641    /// # Warnings
642    /// 1. This will panic if the frames mismatch!
643    /// 2. This is inefficient because both `self` and `rhs` are cloned.
644    fn add_assign(&mut self, rhs: &Self) {
645        *self = (self.clone() + rhs.clone()).unwrap();
646    }
647}
648
649impl<S: Interpolatable> fmt::Display for Traj<S>
650where
651    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
652{
653    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
654        if self.states.is_empty() {
655            write!(f, "Empty Trajectory!")
656        } else {
657            let dur = self.last().epoch() - self.first().epoch();
658            write!(
659                f,
660                "Trajectory {}in {} from {} to {} ({}, or {:.3} s) [{} states]",
661                match &self.name {
662                    Some(name) => format!("of {name} "),
663                    None => String::new(),
664                },
665                self.first().frame(),
666                self.first().epoch(),
667                self.last().epoch(),
668                dur,
669                dur.to_seconds(),
670                self.states.len()
671            )
672        }
673    }
674}
675
676impl<S: Interpolatable> fmt::Debug for Traj<S>
677where
678    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
679{
680    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
681        write!(f, "{self}",)
682    }
683}
684
685impl<S: Interpolatable> Default for Traj<S>
686where
687    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
688{
689    fn default() -> Self {
690        Self::new()
691    }
692}