Skip to main content

nyx_space/md/trajectory/
traj.rs

1/*
2    Nyx, blazing fast astrodynamics
3    Copyright (C) 2018-onwards Christopher Rabotin <christopher.rabotin@gmail.com>
4
5    This program is free software: you can redistribute it and/or modify
6    it under the terms of the GNU Affero General Public License as published
7    by the Free Software Foundation, either version 3 of the License, or
8    (at your option) any later version.
9
10    This program is distributed in the hope that it will be useful,
11    but WITHOUT ANY WARRANTY; without even the implied warranty of
12    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13    GNU Affero General Public License for more details.
14
15    You should have received a copy of the GNU Affero General Public License
16    along with this program.  If not, see <https://www.gnu.org/licenses/>.
17*/
18
19use super::traj_it::TrajIterator;
20use super::{ExportCfg, InterpolationSnafu, INTERPOLATION_SAMPLES};
21use super::{Interpolatable, TrajError};
22use crate::errors::NyxError;
23use crate::io::watermark::pq_writer;
24use crate::io::InputOutputError;
25use crate::linalg::allocator::Allocator;
26use crate::linalg::DefaultAllocator;
27use crate::md::prelude::{GuidanceMode, StateParameter};
28use crate::md::trajectory::smooth_state_diff_in_place;
29use crate::time::{Duration, Epoch, TimeSeries, TimeUnits};
30use anise::analysis::specs::StateSpecTrait;
31use anise::analysis::AnalysisError;
32use anise::astro::orbit::Orbit;
33use anise::prelude::{Aberration, Almanac};
34use arrow::array::{Array, Float64Builder, StringBuilder};
35use arrow::datatypes::{DataType, Field, Schema};
36use arrow::record_batch::RecordBatch;
37use hifitime::TimeScale;
38use log::{info, warn};
39use parquet::arrow::ArrowWriter;
40use snafu::ResultExt;
41use std::collections::HashMap;
42use std::error::Error;
43use std::fmt;
44use std::fs::File;
45use std::iter::Iterator;
46use std::ops;
47use std::ops::Bound::{Excluded, Included, Unbounded};
48use std::path::{Path, PathBuf};
49use std::sync::Arc;
50
/// Store a trajectory of any State.
///
/// A trajectory is an optionally named list of states of type `S`, which can be queried at
/// arbitrary epochs within its time span (cf. `at`), resampled, merged, and exported.
#[derive(Clone, PartialEq)]
pub struct Traj<S: Interpolatable>
where
    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
    /// Optionally name this trajectory
    pub name: Option<String>,
    /// We use a vector because we know that the states are produced in a chronological manner (the direction does not matter).
    /// Call `finalize` after pushing states manually: `at` performs a binary search on the epochs of this vector.
    pub states: Vec<S>,
}
62
63impl<S: Interpolatable> Traj<S>
64where
65    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
66{
67    pub fn new() -> Self {
68        Self {
69            name: None,
70            states: Vec::new(),
71        }
72    }
73    /// Orders the states, can be used to store the states out of order
74    pub fn finalize(&mut self) {
75        // Remove duplicate epochs
76        self.states.dedup_by(|a, b| a.epoch().eq(&b.epoch()));
77        // And sort
78        self.states.sort_by_key(|a| a.epoch());
79    }
80
81    /// Evaluate the trajectory at this specific epoch.
82    pub fn at(&self, epoch: Epoch) -> Result<S, TrajError> {
83        if self.states.is_empty() || self.first().epoch() > epoch || self.last().epoch() < epoch {
84            return Err(TrajError::NoInterpolationData { epoch });
85        }
86        match self
87            .states
88            .binary_search_by(|state| state.epoch().cmp(&epoch))
89        {
90            Ok(idx) => {
91                // Oh wow, we actually had this exact state!
92                Ok(self.states[idx])
93            }
94            Err(idx) => {
95                if idx == 0 || idx >= self.states.len() {
96                    // The binary search returns where we should insert the data, so if it's at either end of the list, then we're out of bounds.
97                    // This condition should have been handled by the check at the start of this function.
98                    return Err(TrajError::NoInterpolationData { epoch });
99                }
100                // This is the closest index, so let's grab the items around it.
101                // NOTE: This is essentially the same code as in ANISE for the Hermite SPK type 13
102
103                // We didn't find it, so let's build an interpolation here.
104                let num_left = INTERPOLATION_SAMPLES / 2;
105
106                // Ensure that we aren't fetching out of the window
107                let mut first_idx = idx.saturating_sub(num_left);
108                let last_idx = self.states.len().min(first_idx + INTERPOLATION_SAMPLES);
109
110                // Check that we have enough samples
111                if last_idx == self.states.len() {
112                    first_idx = last_idx.saturating_sub(2 * num_left);
113                }
114
115                let mut states = Vec::with_capacity(last_idx - first_idx);
116                for idx in first_idx..last_idx {
117                    states.push(self.states[idx]);
118                }
119
120                self.states[idx]
121                    .interpolate(epoch, &states)
122                    .context(InterpolationSnafu)
123            }
124        }
125    }
126
127    /// Returns the first state in this ephemeris
128    pub fn first(&self) -> &S {
129        // This is done after we've ordered the states we received, so we can just return the first state.
130        self.states.first().unwrap()
131    }
132
133    /// Returns the last state in this ephemeris
134    pub fn last(&self) -> &S {
135        self.states.last().unwrap()
136    }
137
138    pub fn start_epoch(&self) -> Epoch {
139        self.first().epoch()
140    }
141
142    pub fn end_epoch(&self) -> Epoch {
143        self.last().epoch()
144    }
145
146    /// Creates an iterator through the trajectory by the provided step size
147    pub fn every(&self, step: Duration) -> TrajIterator<'_, S> {
148        self.every_between(step, self.first().epoch(), self.last().epoch())
149    }
150
151    /// Creates an iterator through the trajectory by the provided step size between the provided bounds
152    pub fn every_between(&self, step: Duration, start: Epoch, end: Epoch) -> TrajIterator<'_, S> {
153        TrajIterator {
154            time_series: TimeSeries::inclusive(
155                start.max(self.first().epoch()),
156                end.min(self.last().epoch()),
157                step,
158            ),
159            traj: self,
160        }
161    }
162
163    /// Returns a new trajectory that only contains states that fall within the given epoch range.
164    pub fn filter_by_epoch<R: ops::RangeBounds<Epoch>>(mut self, bound: R) -> Self {
165        self.states = self
166            .states
167            .iter()
168            .copied()
169            .filter(|s| bound.contains(&s.epoch()))
170            .collect::<Vec<_>>();
171        self
172    }
173
174    /// Returns a new trajectory that only contains states that fall within the given offset from the first epoch.
175    /// For example, a bound of 30.minutes()..90.minutes() will only return states from the start of the trajectory + 30 minutes until start + 90 minutes.
176    pub fn filter_by_offset<R: ops::RangeBounds<Duration>>(self, bound: R) -> Self {
177        if self.states.is_empty() {
178            return self;
179        }
180        // Rebuild an epoch bound.
181        let start = match bound.start_bound() {
182            Unbounded => self.states.first().unwrap().epoch(),
183            Included(offset) | Excluded(offset) => self.states.first().unwrap().epoch() + *offset,
184        };
185
186        let end = match bound.end_bound() {
187            Unbounded => self.states.last().unwrap().epoch(),
188            Included(offset) | Excluded(offset) => self.states.first().unwrap().epoch() + *offset,
189        };
190
191        self.filter_by_epoch(start..=end)
192    }
193    /// Store this trajectory arc to a parquet file with the default configuration (depends on the state type, search for `export_params` in the documentation for details).
194    pub fn to_parquet_simple<P: AsRef<Path>>(&self, path: P) -> Result<PathBuf, Box<dyn Error>> {
195        self.to_parquet(path, ExportCfg::default())
196    }
197
198    /// Store this trajectory arc to a parquet file with the provided configuration
199    pub fn to_parquet_with_cfg<P: AsRef<Path>>(
200        &self,
201        path: P,
202        cfg: ExportCfg,
203    ) -> Result<PathBuf, Box<dyn Error>> {
204        self.to_parquet(path, cfg)
205    }
206
207    /// A shortcut to `to_parquet_with_cfg`
208    pub fn to_parquet_with_step<P: AsRef<Path>>(
209        &self,
210        path: P,
211        step: Duration,
212    ) -> Result<(), Box<dyn Error>> {
213        self.to_parquet_with_cfg(
214            path,
215            ExportCfg {
216                step: Some(step),
217                ..Default::default()
218            },
219        )?;
220
221        Ok(())
222    }
223
224    /// Store this trajectory arc to a parquet file with the provided configuration
225    pub fn to_parquet<P: AsRef<Path>>(
226        &self,
227        path: P,
228        cfg: ExportCfg,
229    ) -> Result<PathBuf, Box<dyn Error>> {
230        let tick = Epoch::now().unwrap();
231        info!("Exporting trajectory to parquet file...");
232
233        // Grab the path here before we move stuff.
234        let path_buf = cfg.actual_path(path);
235
236        // Build the states iterator -- this does require copying the current states but I can't either get a reference or a copy of all the states.
237        let states = if cfg.start_epoch.is_some() || cfg.end_epoch.is_some() || cfg.step.is_some() {
238            // Must interpolate the data!
239            let start = cfg.start_epoch.unwrap_or_else(|| self.first().epoch());
240            let end = cfg.end_epoch.unwrap_or_else(|| self.last().epoch());
241            let step = cfg.step.unwrap_or_else(|| 1.minutes());
242            self.every_between(step, start, end).collect::<Vec<S>>()
243        } else {
244            self.states.to_vec()
245        };
246
247        // Build the schema
248        let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
249
250        let frame = self.states[0].frame();
251        let more_meta = Some(vec![(
252            "Frame".to_string(),
253            serde_dhall::serialize(&frame)
254                .static_type_annotation()
255                .to_string()
256                .map_err(|e| {
257                    Box::new(InputOutputError::SerializeDhall {
258                        what: format!("frame `{frame}`"),
259                        err: e.to_string(),
260                    })
261                })?,
262        )]);
263
264        let requested_fields = match cfg.fields {
265            Some(fields) => fields,
266            None => S::export_params(),
267        };
268
269        let mut fields = Vec::new();
270        let mut field_nullable = Vec::new();
271        for field in requested_fields {
272            let mut any_ok = false;
273            let mut any_err = false;
274            for state in &states {
275                if state.value(field).is_ok() {
276                    any_ok = true;
277                } else {
278                    any_err = true;
279                }
280            }
281
282            if any_ok {
283                fields.push(field);
284                field_nullable.push(any_err);
285            }
286        }
287
288        for (field, nullable) in fields.iter().zip(field_nullable.iter().copied()) {
289            hdrs.push(field.to_field(more_meta.clone()).with_nullable(nullable));
290        }
291
292        // Build the schema
293        let schema = Arc::new(Schema::new(hdrs));
294        let mut record: Vec<Arc<dyn Array>> = Vec::new();
295
296        // Build all of the records
297
298        // Epochs
299        let mut utc_epoch = StringBuilder::new();
300        for s in &states {
301            utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
302        }
303        record.push(Arc::new(utc_epoch.finish()));
304
305        // Add all of the fields
306        for field in fields {
307            if field == StateParameter::GuidanceMode() {
308                let mut guid_mode = StringBuilder::new();
309                for s in &states {
310                    match s.value(field) {
311                        Ok(value) => {
312                            guid_mode.append_value(format!("{:?}", GuidanceMode::from(value)));
313                        }
314                        Err(_) => guid_mode.append_null(),
315                    }
316                }
317                record.push(Arc::new(guid_mode.finish()));
318            } else {
319                let mut data = Float64Builder::new();
320                for s in &states {
321                    match s.value(field) {
322                        Ok(value) => data.append_value(value),
323                        Err(_) => data.append_null(),
324                    };
325                }
326                record.push(Arc::new(data.finish()));
327            }
328        }
329
330        info!(
331            "Serialized {} states from {} to {}",
332            states.len(),
333            states.first().unwrap().epoch(),
334            states.last().unwrap().epoch()
335        );
336
337        // Serialize all of the devices and add that to the parquet file too.
338        let mut metadata = HashMap::new();
339        metadata.insert("Purpose".to_string(), "Trajectory data".to_string());
340        if let Some(add_meta) = cfg.metadata {
341            for (k, v) in add_meta {
342                metadata.insert(k, v);
343            }
344        }
345
346        let props = pq_writer(Some(metadata));
347
348        let file = File::create(&path_buf)?;
349        let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
350
351        let batch = RecordBatch::try_new(schema, record)?;
352        writer.write(&batch)?;
353        writer.close()?;
354
355        // Return the path this was written to
356        let tock_time = Epoch::now().unwrap() - tick;
357        info!(
358            "Trajectory written to {} in {tock_time}",
359            path_buf.display()
360        );
361        Ok(path_buf)
362    }
363
364    /// Allows resampling this trajectory at a fixed interval instead of using the propagator step size.
365    /// This may lead to aliasing due to the Nyquist–Shannon sampling theorem.
366    pub fn resample(&self, step: Duration) -> Result<Self, NyxError> {
367        if self.states.is_empty() {
368            return Err(NyxError::Trajectory {
369                source: TrajError::CreationError {
370                    msg: "No trajectory to convert".to_string(),
371                },
372            });
373        }
374
375        let mut traj = Self::new();
376        for state in self.every(step) {
377            traj.states.push(state);
378        }
379
380        traj.finalize();
381
382        Ok(traj)
383    }
384
385    /// Rebuilds this trajectory with the provided epochs.
386    /// This may lead to aliasing due to the Nyquist–Shannon sampling theorem.
387    pub fn rebuild(&self, epochs: &[Epoch]) -> Result<Self, NyxError> {
388        if self.states.is_empty() {
389            return Err(NyxError::Trajectory {
390                source: TrajError::CreationError {
391                    msg: "No trajectory to convert".to_string(),
392                },
393            });
394        }
395
396        let mut traj = Self::new();
397        for epoch in epochs {
398            traj.states.push(self.at(*epoch)?);
399        }
400
401        traj.finalize();
402
403        Ok(traj)
404    }
405
406    /// Export the difference in RIC from of this trajectory compare to the "other" trajectory in parquet format.
407    ///
408    /// # Notes
409    /// + The RIC frame accounts for the transport theorem by performing a finite differencing of the RIC frame.
410    pub fn ric_diff_to_parquet<P: AsRef<Path>>(
411        &self,
412        other: &Self,
413        path: P,
414        cfg: ExportCfg,
415    ) -> Result<PathBuf, Box<dyn Error>> {
416        let tick = Epoch::now().unwrap();
417        info!("Exporting trajectory to parquet file...");
418
419        // Grab the path here before we move stuff.
420        let path_buf = cfg.actual_path(path);
421
422        // Build the schema
423        let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
424
425        // Add the RIC headers
426        for coord in ["X", "Y", "Z"] {
427            let mut meta = HashMap::new();
428            meta.insert("unit".to_string(), "km".to_string());
429
430            let field = Field::new(
431                format!("Delta {coord} (RIC) (km)"),
432                DataType::Float64,
433                false,
434            )
435            .with_metadata(meta);
436
437            hdrs.push(field);
438        }
439
440        for coord in ["x", "y", "z"] {
441            let mut meta = HashMap::new();
442            meta.insert("unit".to_string(), "km/s".to_string());
443
444            let field = Field::new(
445                format!("Delta V{coord} (RIC) (km/s)"),
446                DataType::Float64,
447                false,
448            )
449            .with_metadata(meta);
450
451            hdrs.push(field);
452        }
453
454        let frame = self.states[0].frame();
455        let more_meta = Some(vec![(
456            "Frame".to_string(),
457            serde_dhall::serialize(&frame)
458                .static_type_annotation()
459                .to_string()
460                .unwrap_or(frame.to_string()),
461        )]);
462
463        let mut cfg = cfg;
464
465        let mut fields = match cfg.fields {
466            Some(fields) => fields,
467            None => S::export_params(),
468        };
469
470        // Remove disallowed field and check that we can retrieve this information
471        fields.retain(|param| {
472            param != &StateParameter::GuidanceMode() && self.first().value(*param).is_ok()
473        });
474
475        for field in &fields {
476            hdrs.push(field.to_field(more_meta.clone()));
477        }
478
479        // Build the schema
480        let schema = Arc::new(Schema::new(hdrs));
481        let mut record: Vec<Arc<dyn Array>> = Vec::new();
482
483        // Ensure the times match.
484        cfg.start_epoch = if self.first().epoch() > other.first().epoch() {
485            Some(self.first().epoch())
486        } else {
487            Some(other.first().epoch())
488        };
489
490        cfg.end_epoch = if self.last().epoch() > other.last().epoch() {
491            Some(other.last().epoch())
492        } else {
493            Some(self.last().epoch())
494        };
495
496        // Build the states iterator
497        let step = cfg.step.unwrap_or_else(|| 1.minutes());
498        let self_states = self
499            .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
500            .collect::<Vec<S>>();
501
502        let other_states = other
503            .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
504            .collect::<Vec<S>>();
505
506        // Build an array of all the RIC differences
507        let mut ric_diff = Vec::with_capacity(other_states.len());
508        for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
509            let self_orbit = self_state.orbit();
510            let other_orbit = other_state.orbit();
511
512            let this_ric_diff = self_orbit.ric_difference(&other_orbit).map_err(Box::new)?;
513
514            ric_diff.push(this_ric_diff);
515        }
516
517        smooth_state_diff_in_place(&mut ric_diff, if other_states.len() > 5 { 5 } else { 1 });
518
519        // Build all of the records
520
521        // Epochs (both match for self and others)
522        let mut utc_epoch = StringBuilder::new();
523        for s in &self_states {
524            utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
525        }
526        record.push(Arc::new(utc_epoch.finish()));
527
528        // Add the RIC data
529        for coord_no in 0..6 {
530            let mut data = Float64Builder::new();
531            for this_ric_dff in &ric_diff {
532                data.append_value(this_ric_dff.to_cartesian_pos_vel()[coord_no]);
533            }
534            record.push(Arc::new(data.finish()));
535        }
536
537        // Add all of the fields
538        for field in fields {
539            let mut data = Float64Builder::new();
540            for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
541                let self_val = self_state.value(field).map_err(Box::new)?;
542                let other_val = other_state.value(field).map_err(Box::new)?;
543
544                data.append_value(self_val - other_val);
545            }
546
547            record.push(Arc::new(data.finish()));
548        }
549
550        info!("Serialized {} states differences", self_states.len());
551
552        // Serialize all of the devices and add that to the parquet file too.
553        let mut metadata = HashMap::new();
554        metadata.insert(
555            "Purpose".to_string(),
556            "Trajectory difference data".to_string(),
557        );
558        if let Some(add_meta) = cfg.metadata {
559            for (k, v) in add_meta {
560                metadata.insert(k, v);
561            }
562        }
563
564        let props = pq_writer(Some(metadata));
565
566        let file = File::create(&path_buf)?;
567        let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
568
569        let batch = RecordBatch::try_new(schema, record)?;
570        writer.write(&batch)?;
571        writer.close()?;
572
573        // Return the path this was written to
574        let tock_time = Epoch::now().unwrap() - tick;
575        info!(
576            "Trajectory written to {} in {tock_time}",
577            path_buf.display()
578        );
579        Ok(path_buf)
580    }
581}
582
583impl<S: Interpolatable> ops::Add for Traj<S>
584where
585    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
586{
587    type Output = Result<Traj<S>, NyxError>;
588
589    /// Add one trajectory to another. If they do not overlap to within 10ms, a warning will be printed.
590    fn add(self, other: Traj<S>) -> Self::Output {
591        &self + &other
592    }
593}
594
595impl<S: Interpolatable> ops::Add<&Traj<S>> for &Traj<S>
596where
597    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
598{
599    type Output = Result<Traj<S>, NyxError>;
600
601    /// Add one trajectory to another, returns an error if the frames don't match
602    fn add(self, other: &Traj<S>) -> Self::Output {
603        if self.first().frame() != other.first().frame() {
604            Err(NyxError::Trajectory {
605                source: TrajError::CreationError {
606                    msg: format!(
607                        "Frame mismatch in add operation: {} != {}",
608                        self.first().frame(),
609                        other.first().frame()
610                    ),
611                },
612            })
613        } else {
614            if self.last().epoch() < other.first().epoch() {
615                let gap = other.first().epoch() - self.last().epoch();
616                warn!(
617                    "Resulting merged trajectory will have a time-gap of {} starting at {}",
618                    gap,
619                    self.last().epoch()
620                );
621            }
622
623            let mut me = self.clone();
624            // Now start adding the other segments while correcting the index
625            for state in &other
626                .states
627                .iter()
628                .copied()
629                .filter(|s| s.epoch() > self.last().epoch())
630                .collect::<Vec<S>>()
631            {
632                me.states.push(*state);
633            }
634            me.finalize();
635
636            Ok(me)
637        }
638    }
639}
640
641impl<S: Interpolatable> ops::AddAssign<&Traj<S>> for Traj<S>
642where
643    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
644{
645    /// Attempt to add two trajectories together and assign it to `self`
646    ///
647    /// # Warnings
648    /// 1. This will panic if the frames mismatch!
649    /// 2. This is inefficient because both `self` and `rhs` are cloned.
650    fn add_assign(&mut self, rhs: &Self) {
651        *self = (self.clone() + rhs.clone()).unwrap();
652    }
653}
654
655impl<S: Interpolatable> fmt::Display for Traj<S>
656where
657    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
658{
659    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
660        if self.states.is_empty() {
661            write!(f, "Empty Trajectory!")
662        } else {
663            let dur = self.last().epoch() - self.first().epoch();
664            write!(
665                f,
666                "Trajectory {}in {} from {} to {} ({}, or {:.3} s) [{} states]",
667                match &self.name {
668                    Some(name) => format!("of {name} "),
669                    None => String::new(),
670                },
671                self.first().frame(),
672                self.first().epoch(),
673                self.last().epoch(),
674                dur,
675                dur.to_seconds(),
676                self.states.len()
677            )
678        }
679    }
680}
681
682impl<S: Interpolatable> fmt::Debug for Traj<S>
683where
684    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
685{
686    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
687        write!(f, "{self}",)
688    }
689}
690
691impl<S: Interpolatable> Default for Traj<S>
692where
693    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
694{
695    fn default() -> Self {
696        Self::new()
697    }
698}
699
700impl<S: Interpolatable> StateSpecTrait for Traj<S>
701where
702    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
703{
704    fn ab_corr(&self) -> Option<Aberration> {
705        None
706    }
707
708    fn evaluate(&self, epoch: Epoch, _almanac: &Almanac) -> Result<Orbit, AnalysisError> {
709        self.at(epoch)
710            .map(|state| state.orbit())
711            .map_err(|e| AnalysisError::GenericAnalysisError {
712                err: format!("{e}"),
713            })
714    }
715}