nyx_space/md/trajectory/
traj.rs

1/*
2    Nyx, blazing fast astrodynamics
3    Copyright (C) 2018-onwards Christopher Rabotin <christopher.rabotin@gmail.com>
4
5    This program is free software: you can redistribute it and/or modify
6    it under the terms of the GNU Affero General Public License as published
7    by the Free Software Foundation, either version 3 of the License, or
8    (at your option) any later version.
9
10    This program is distributed in the hope that it will be useful,
11    but WITHOUT ANY WARRANTY; without even the implied warranty of
12    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13    GNU Affero General Public License for more details.
14
15    You should have received a copy of the GNU Affero General Public License
16    along with this program.  If not, see <https://www.gnu.org/licenses/>.
17*/
18
19use super::traj_it::TrajIterator;
20use super::{ExportCfg, InterpolationSnafu, INTERPOLATION_SAMPLES};
21use super::{Interpolatable, TrajError};
22use crate::errors::NyxError;
23use crate::io::watermark::pq_writer;
24use crate::io::InputOutputError;
25use crate::linalg::allocator::Allocator;
26use crate::linalg::DefaultAllocator;
27use crate::md::prelude::{GuidanceMode, StateParameter};
28use crate::md::trajectory::smooth_state_diff_in_place;
29use crate::md::EventEvaluator;
30use crate::time::{Duration, Epoch, TimeSeries, TimeUnits};
31use anise::almanac::Almanac;
32use arrow::array::{Array, Float64Builder, StringBuilder};
33use arrow::datatypes::{DataType, Field, Schema};
34use arrow::record_batch::RecordBatch;
35use hifitime::TimeScale;
36use log::{info, warn};
37use parquet::arrow::ArrowWriter;
38use snafu::ResultExt;
39use std::collections::HashMap;
40use std::error::Error;
41use std::fmt;
42use std::fs::File;
43use std::iter::Iterator;
44use std::ops;
45use std::ops::Bound::{Excluded, Included, Unbounded};
46use std::path::{Path, PathBuf};
47use std::sync::Arc;
48
49/// Store a trajectory of any State.
50#[derive(Clone, PartialEq)]
51pub struct Traj<S: Interpolatable>
52where
53    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
54{
55    /// Optionally name this trajectory
56    pub name: Option<String>,
57    /// We use a vector because we know that the states are produced in a chronological manner (the direction does not matter).
58    pub states: Vec<S>,
59}
60
61impl<S: Interpolatable> Traj<S>
62where
63    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
64{
65    pub fn new() -> Self {
66        Self {
67            name: None,
68            states: Vec::new(),
69        }
70    }
71    /// Orders the states, can be used to store the states out of order
72    pub fn finalize(&mut self) {
73        // Remove duplicate epochs
74        self.states.dedup_by(|a, b| a.epoch().eq(&b.epoch()));
75        // And sort
76        self.states.sort_by_key(|a| a.epoch());
77    }
78
79    /// Evaluate the trajectory at this specific epoch.
80    pub fn at(&self, epoch: Epoch) -> Result<S, TrajError> {
81        if self.states.is_empty() || self.first().epoch() > epoch || self.last().epoch() < epoch {
82            return Err(TrajError::NoInterpolationData { epoch });
83        }
84        match self
85            .states
86            .binary_search_by(|state| state.epoch().cmp(&epoch))
87        {
88            Ok(idx) => {
89                // Oh wow, we actually had this exact state!
90                Ok(self.states[idx])
91            }
92            Err(idx) => {
93                if idx == 0 || idx >= self.states.len() {
94                    // The binary search returns where we should insert the data, so if it's at either end of the list, then we're out of bounds.
95                    // This condition should have been handled by the check at the start of this function.
96                    return Err(TrajError::NoInterpolationData { epoch });
97                }
98                // This is the closest index, so let's grab the items around it.
99                // NOTE: This is essentially the same code as in ANISE for the Hermite SPK type 13
100
101                // We didn't find it, so let's build an interpolation here.
102                let num_left = INTERPOLATION_SAMPLES / 2;
103
104                // Ensure that we aren't fetching out of the window
105                let mut first_idx = idx.saturating_sub(num_left);
106                let last_idx = self.states.len().min(first_idx + INTERPOLATION_SAMPLES);
107
108                // Check that we have enough samples
109                if last_idx == self.states.len() {
110                    first_idx = last_idx.saturating_sub(2 * num_left);
111                }
112
113                let mut states = Vec::with_capacity(last_idx - first_idx);
114                for idx in first_idx..last_idx {
115                    states.push(self.states[idx]);
116                }
117
118                self.states[idx]
119                    .interpolate(epoch, &states)
120                    .context(InterpolationSnafu)
121            }
122        }
123    }
124
125    /// Returns the first state in this ephemeris
126    pub fn first(&self) -> &S {
127        // This is done after we've ordered the states we received, so we can just return the first state.
128        self.states.first().unwrap()
129    }
130
131    /// Returns the last state in this ephemeris
132    pub fn last(&self) -> &S {
133        self.states.last().unwrap()
134    }
135
136    /// Creates an iterator through the trajectory by the provided step size
137    pub fn every(&self, step: Duration) -> TrajIterator<'_, S> {
138        self.every_between(step, self.first().epoch(), self.last().epoch())
139    }
140
141    /// Creates an iterator through the trajectory by the provided step size between the provided bounds
142    pub fn every_between(&self, step: Duration, start: Epoch, end: Epoch) -> TrajIterator<'_, S> {
143        TrajIterator {
144            time_series: TimeSeries::inclusive(
145                start.max(self.first().epoch()),
146                end.min(self.last().epoch()),
147                step,
148            ),
149            traj: self,
150        }
151    }
152
153    /// Returns a new trajectory that only contains states that fall within the given epoch range.
154    pub fn filter_by_epoch<R: ops::RangeBounds<Epoch>>(mut self, bound: R) -> Self {
155        self.states = self
156            .states
157            .iter()
158            .copied()
159            .filter(|s| bound.contains(&s.epoch()))
160            .collect::<Vec<_>>();
161        self
162    }
163
164    /// Returns a new trajectory that only contains states that fall within the given offset from the first epoch.
165    /// For example, a bound of 30.minutes()..90.minutes() will only return states from the start of the trajectory + 30 minutes until start + 90 minutes.
166    pub fn filter_by_offset<R: ops::RangeBounds<Duration>>(self, bound: R) -> Self {
167        if self.states.is_empty() {
168            return self;
169        }
170        // Rebuild an epoch bound.
171        let start = match bound.start_bound() {
172            Unbounded => self.states.first().unwrap().epoch(),
173            Included(offset) | Excluded(offset) => self.states.first().unwrap().epoch() + *offset,
174        };
175
176        let end = match bound.end_bound() {
177            Unbounded => self.states.last().unwrap().epoch(),
178            Included(offset) | Excluded(offset) => self.states.first().unwrap().epoch() + *offset,
179        };
180
181        self.filter_by_epoch(start..=end)
182    }
183    /// Store this trajectory arc to a parquet file with the default configuration (depends on the state type, search for `export_params` in the documentation for details).
184    pub fn to_parquet_simple<P: AsRef<Path>>(
185        &self,
186        path: P,
187        almanac: Arc<Almanac>,
188    ) -> Result<PathBuf, Box<dyn Error>> {
189        self.to_parquet(path, None, ExportCfg::default(), almanac)
190    }
191
192    /// Store this trajectory arc to a parquet file with the provided configuration
193    pub fn to_parquet_with_cfg<P: AsRef<Path>>(
194        &self,
195        path: P,
196        cfg: ExportCfg,
197        almanac: Arc<Almanac>,
198    ) -> Result<PathBuf, Box<dyn Error>> {
199        self.to_parquet(path, None, cfg, almanac)
200    }
201
202    /// A shortcut to `to_parquet_with_cfg`
203    pub fn to_parquet_with_step<P: AsRef<Path>>(
204        &self,
205        path: P,
206        step: Duration,
207        almanac: Arc<Almanac>,
208    ) -> Result<(), Box<dyn Error>> {
209        self.to_parquet_with_cfg(
210            path,
211            ExportCfg {
212                step: Some(step),
213                ..Default::default()
214            },
215            almanac,
216        )?;
217
218        Ok(())
219    }
220
221    /// Store this trajectory arc to a parquet file with the provided configuration and event evaluators
222    pub fn to_parquet<P: AsRef<Path>>(
223        &self,
224        path: P,
225        events: Option<Vec<&dyn EventEvaluator<S>>>,
226        cfg: ExportCfg,
227        almanac: Arc<Almanac>,
228    ) -> Result<PathBuf, Box<dyn Error>> {
229        let tick = Epoch::now().unwrap();
230        info!("Exporting trajectory to parquet file...");
231
232        // Grab the path here before we move stuff.
233        let path_buf = cfg.actual_path(path);
234
235        // Build the schema
236        let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
237
238        let frame = self.states[0].frame();
239        let more_meta = Some(vec![(
240            "Frame".to_string(),
241            serde_dhall::serialize(&frame)
242                .static_type_annotation()
243                .to_string()
244                .map_err(|e| {
245                    Box::new(InputOutputError::SerializeDhall {
246                        what: format!("frame `{frame}`"),
247                        err: e.to_string(),
248                    })
249                })?,
250        )]);
251
252        let mut fields = match cfg.fields {
253            Some(fields) => fields,
254            None => S::export_params(),
255        };
256
257        // Check that we can retrieve this information
258        fields.retain(|param| self.first().value(*param).is_ok());
259
260        for field in &fields {
261            hdrs.push(field.to_field(more_meta.clone()));
262        }
263
264        if let Some(events) = events.as_ref() {
265            for event in events {
266                let field = Field::new(format!("{event}"), DataType::Float64, false);
267                hdrs.push(field);
268            }
269        }
270
271        // Build the schema
272        let schema = Arc::new(Schema::new(hdrs));
273        let mut record: Vec<Arc<dyn Array>> = Vec::new();
274
275        // Build the states iterator -- this does require copying the current states but I can't either get a reference or a copy of all the states.
276        let states = if cfg.start_epoch.is_some() || cfg.end_epoch.is_some() || cfg.step.is_some() {
277            // Must interpolate the data!
278            let start = cfg.start_epoch.unwrap_or_else(|| self.first().epoch());
279            let end = cfg.end_epoch.unwrap_or_else(|| self.last().epoch());
280            let step = cfg.step.unwrap_or_else(|| 1.minutes());
281            self.every_between(step, start, end).collect::<Vec<S>>()
282        } else {
283            self.states.to_vec()
284        };
285
286        // Build all of the records
287
288        // Epochs
289        let mut utc_epoch = StringBuilder::new();
290        for s in &states {
291            utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
292        }
293        record.push(Arc::new(utc_epoch.finish()));
294
295        // Add all of the fields
296        for field in fields {
297            if field == StateParameter::GuidanceMode {
298                let mut guid_mode = StringBuilder::new();
299                for s in &states {
300                    guid_mode
301                        .append_value(format!("{:?}", GuidanceMode::from(s.value(field).unwrap())));
302                }
303                record.push(Arc::new(guid_mode.finish()));
304            } else {
305                let mut data = Float64Builder::new();
306                for s in &states {
307                    data.append_value(s.value(field).unwrap());
308                }
309                record.push(Arc::new(data.finish()));
310            }
311        }
312
313        info!(
314            "Serialized {} states from {} to {}",
315            states.len(),
316            states.first().unwrap().epoch(),
317            states.last().unwrap().epoch()
318        );
319
320        // Add all of the evaluated events
321        if let Some(events) = events {
322            info!("Evaluating {} event(s)", events.len());
323            for event in events {
324                let mut data = Float64Builder::new();
325                for s in &states {
326                    data.append_value(event.eval(s, almanac.clone()).map_err(Box::new)?);
327                }
328                record.push(Arc::new(data.finish()));
329            }
330        }
331
332        // Serialize all of the devices and add that to the parquet file too.
333        let mut metadata = HashMap::new();
334        metadata.insert("Purpose".to_string(), "Trajectory data".to_string());
335        if let Some(add_meta) = cfg.metadata {
336            for (k, v) in add_meta {
337                metadata.insert(k, v);
338            }
339        }
340
341        let props = pq_writer(Some(metadata));
342
343        let file = File::create(&path_buf)?;
344        let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
345
346        let batch = RecordBatch::try_new(schema, record)?;
347        writer.write(&batch)?;
348        writer.close()?;
349
350        // Return the path this was written to
351        let tock_time = Epoch::now().unwrap() - tick;
352        info!(
353            "Trajectory written to {} in {tock_time}",
354            path_buf.display()
355        );
356        Ok(path_buf)
357    }
358
359    /// Allows resampling this trajectory at a fixed interval instead of using the propagator step size.
360    /// This may lead to aliasing due to the Nyquist–Shannon sampling theorem.
361    pub fn resample(&self, step: Duration) -> Result<Self, NyxError> {
362        if self.states.is_empty() {
363            return Err(NyxError::Trajectory {
364                source: TrajError::CreationError {
365                    msg: "No trajectory to convert".to_string(),
366                },
367            });
368        }
369
370        let mut traj = Self::new();
371        for state in self.every(step) {
372            traj.states.push(state);
373        }
374
375        traj.finalize();
376
377        Ok(traj)
378    }
379
380    /// Rebuilds this trajectory with the provided epochs.
381    /// This may lead to aliasing due to the Nyquist–Shannon sampling theorem.
382    pub fn rebuild(&self, epochs: &[Epoch]) -> Result<Self, NyxError> {
383        if self.states.is_empty() {
384            return Err(NyxError::Trajectory {
385                source: TrajError::CreationError {
386                    msg: "No trajectory to convert".to_string(),
387                },
388            });
389        }
390
391        let mut traj = Self::new();
392        for epoch in epochs {
393            traj.states.push(self.at(*epoch)?);
394        }
395
396        traj.finalize();
397
398        Ok(traj)
399    }
400
401    /// Export the difference in RIC from of this trajectory compare to the "other" trajectory in parquet format.
402    ///
403    /// # Notes
404    /// + The RIC frame accounts for the transport theorem by performing a finite differencing of the RIC frame.
405    pub fn ric_diff_to_parquet<P: AsRef<Path>>(
406        &self,
407        other: &Self,
408        path: P,
409        cfg: ExportCfg,
410    ) -> Result<PathBuf, Box<dyn Error>> {
411        let tick = Epoch::now().unwrap();
412        info!("Exporting trajectory to parquet file...");
413
414        // Grab the path here before we move stuff.
415        let path_buf = cfg.actual_path(path);
416
417        // Build the schema
418        let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
419
420        // Add the RIC headers
421        for coord in ["X", "Y", "Z"] {
422            let mut meta = HashMap::new();
423            meta.insert("unit".to_string(), "km".to_string());
424
425            let field = Field::new(
426                format!("Delta {coord} (RIC) (km)"),
427                DataType::Float64,
428                false,
429            )
430            .with_metadata(meta);
431
432            hdrs.push(field);
433        }
434
435        for coord in ["x", "y", "z"] {
436            let mut meta = HashMap::new();
437            meta.insert("unit".to_string(), "km/s".to_string());
438
439            let field = Field::new(
440                format!("Delta V{coord} (RIC) (km/s)"),
441                DataType::Float64,
442                false,
443            )
444            .with_metadata(meta);
445
446            hdrs.push(field);
447        }
448
449        let frame = self.states[0].frame();
450        let more_meta = Some(vec![(
451            "Frame".to_string(),
452            serde_dhall::serialize(&frame)
453                .static_type_annotation()
454                .to_string()
455                .unwrap_or(frame.to_string()),
456        )]);
457
458        let mut cfg = cfg;
459
460        let mut fields = match cfg.fields {
461            Some(fields) => fields,
462            None => S::export_params(),
463        };
464
465        // Remove disallowed field and check that we can retrieve this information
466        fields.retain(|param| {
467            param != &StateParameter::GuidanceMode && self.first().value(*param).is_ok()
468        });
469
470        for field in &fields {
471            hdrs.push(field.to_field(more_meta.clone()));
472        }
473
474        // Build the schema
475        let schema = Arc::new(Schema::new(hdrs));
476        let mut record: Vec<Arc<dyn Array>> = Vec::new();
477
478        // Ensure the times match.
479        cfg.start_epoch = if self.first().epoch() > other.first().epoch() {
480            Some(self.first().epoch())
481        } else {
482            Some(other.first().epoch())
483        };
484
485        cfg.end_epoch = if self.last().epoch() > other.last().epoch() {
486            Some(other.last().epoch())
487        } else {
488            Some(self.last().epoch())
489        };
490
491        // Build the states iterator
492        let step = cfg.step.unwrap_or_else(|| 1.minutes());
493        let self_states = self
494            .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
495            .collect::<Vec<S>>();
496
497        let other_states = other
498            .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
499            .collect::<Vec<S>>();
500
501        // Build an array of all the RIC differences
502        let mut ric_diff = Vec::with_capacity(other_states.len());
503        for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
504            let self_orbit = self_state.orbit();
505            let other_orbit = other_state.orbit();
506
507            let this_ric_diff = self_orbit.ric_difference(&other_orbit).map_err(Box::new)?;
508
509            ric_diff.push(this_ric_diff);
510        }
511
512        smooth_state_diff_in_place(&mut ric_diff, if other_states.len() > 5 { 5 } else { 1 });
513
514        // Build all of the records
515
516        // Epochs (both match for self and others)
517        let mut utc_epoch = StringBuilder::new();
518        for s in &self_states {
519            utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
520        }
521        record.push(Arc::new(utc_epoch.finish()));
522
523        // Add the RIC data
524        for coord_no in 0..6 {
525            let mut data = Float64Builder::new();
526            for this_ric_dff in &ric_diff {
527                data.append_value(this_ric_dff.to_cartesian_pos_vel()[coord_no]);
528            }
529            record.push(Arc::new(data.finish()));
530        }
531
532        // Add all of the fields
533        for field in fields {
534            let mut data = Float64Builder::new();
535            for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
536                let self_val = self_state.value(field).map_err(Box::new)?;
537                let other_val = other_state.value(field).map_err(Box::new)?;
538
539                data.append_value(self_val - other_val);
540            }
541
542            record.push(Arc::new(data.finish()));
543        }
544
545        info!("Serialized {} states differences", self_states.len());
546
547        // Serialize all of the devices and add that to the parquet file too.
548        let mut metadata = HashMap::new();
549        metadata.insert(
550            "Purpose".to_string(),
551            "Trajectory difference data".to_string(),
552        );
553        if let Some(add_meta) = cfg.metadata {
554            for (k, v) in add_meta {
555                metadata.insert(k, v);
556            }
557        }
558
559        let props = pq_writer(Some(metadata));
560
561        let file = File::create(&path_buf)?;
562        let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
563
564        let batch = RecordBatch::try_new(schema, record)?;
565        writer.write(&batch)?;
566        writer.close()?;
567
568        // Return the path this was written to
569        let tock_time = Epoch::now().unwrap() - tick;
570        info!(
571            "Trajectory written to {} in {tock_time}",
572            path_buf.display()
573        );
574        Ok(path_buf)
575    }
576}
577
578impl<S: Interpolatable> ops::Add for Traj<S>
579where
580    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
581{
582    type Output = Result<Traj<S>, NyxError>;
583
584    /// Add one trajectory to another. If they do not overlap to within 10ms, a warning will be printed.
585    fn add(self, other: Traj<S>) -> Self::Output {
586        &self + &other
587    }
588}
589
590impl<S: Interpolatable> ops::Add<&Traj<S>> for &Traj<S>
591where
592    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
593{
594    type Output = Result<Traj<S>, NyxError>;
595
596    /// Add one trajectory to another, returns an error if the frames don't match
597    fn add(self, other: &Traj<S>) -> Self::Output {
598        if self.first().frame() != other.first().frame() {
599            Err(NyxError::Trajectory {
600                source: TrajError::CreationError {
601                    msg: format!(
602                        "Frame mismatch in add operation: {} != {}",
603                        self.first().frame(),
604                        other.first().frame()
605                    ),
606                },
607            })
608        } else {
609            if self.last().epoch() < other.first().epoch() {
610                let gap = other.first().epoch() - self.last().epoch();
611                warn!(
612                    "Resulting merged trajectory will have a time-gap of {} starting at {}",
613                    gap,
614                    self.last().epoch()
615                );
616            }
617
618            let mut me = self.clone();
619            // Now start adding the other segments while correcting the index
620            for state in &other
621                .states
622                .iter()
623                .copied()
624                .filter(|s| s.epoch() > self.last().epoch())
625                .collect::<Vec<S>>()
626            {
627                me.states.push(*state);
628            }
629            me.finalize();
630
631            Ok(me)
632        }
633    }
634}
635
636impl<S: Interpolatable> ops::AddAssign<&Traj<S>> for Traj<S>
637where
638    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
639{
640    /// Attempt to add two trajectories together and assign it to `self`
641    ///
642    /// # Warnings
643    /// 1. This will panic if the frames mismatch!
644    /// 2. This is inefficient because both `self` and `rhs` are cloned.
645    fn add_assign(&mut self, rhs: &Self) {
646        *self = (self.clone() + rhs.clone()).unwrap();
647    }
648}
649
650impl<S: Interpolatable> fmt::Display for Traj<S>
651where
652    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
653{
654    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
655        if self.states.is_empty() {
656            write!(f, "Empty Trajectory!")
657        } else {
658            let dur = self.last().epoch() - self.first().epoch();
659            write!(
660                f,
661                "Trajectory {}in {} from {} to {} ({}, or {:.3} s) [{} states]",
662                match &self.name {
663                    Some(name) => format!("of {name} "),
664                    None => String::new(),
665                },
666                self.first().frame(),
667                self.first().epoch(),
668                self.last().epoch(),
669                dur,
670                dur.to_seconds(),
671                self.states.len()
672            )
673        }
674    }
675}
676
677impl<S: Interpolatable> fmt::Debug for Traj<S>
678where
679    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
680{
681    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
682        write!(f, "{self}",)
683    }
684}
685
686impl<S: Interpolatable> Default for Traj<S>
687where
688    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
689{
690    fn default() -> Self {
691        Self::new()
692    }
693}