nyx_space/md/trajectory/
traj.rs

1/*
2    Nyx, blazing fast astrodynamics
3    Copyright (C) 2018-onwards Christopher Rabotin <christopher.rabotin@gmail.com>
4
5    This program is free software: you can redistribute it and/or modify
6    it under the terms of the GNU Affero General Public License as published
7    by the Free Software Foundation, either version 3 of the License, or
8    (at your option) any later version.
9
10    This program is distributed in the hope that it will be useful,
11    but WITHOUT ANY WARRANTY; without even the implied warranty of
12    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13    GNU Affero General Public License for more details.
14
15    You should have received a copy of the GNU Affero General Public License
16    along with this program.  If not, see <https://www.gnu.org/licenses/>.
17*/
18
19use super::traj_it::TrajIterator;
20use super::{ExportCfg, InterpolationSnafu, INTERPOLATION_SAMPLES};
21use super::{Interpolatable, TrajError};
22use crate::errors::NyxError;
23use crate::io::watermark::pq_writer;
24use crate::io::InputOutputError;
25use crate::linalg::allocator::Allocator;
26use crate::linalg::DefaultAllocator;
27use crate::md::prelude::{GuidanceMode, StateParameter};
28use crate::md::trajectory::smooth_state_diff_in_place;
29use crate::md::EventEvaluator;
30use crate::time::{Duration, Epoch, TimeSeries, TimeUnits};
31use anise::almanac::Almanac;
32use arrow::array::{Array, Float64Builder, StringBuilder};
33use arrow::datatypes::{DataType, Field, Schema};
34use arrow::record_batch::RecordBatch;
35use hifitime::TimeScale;
36use parquet::arrow::ArrowWriter;
37use snafu::ResultExt;
38use std::collections::HashMap;
39use std::error::Error;
40use std::fmt;
41use std::fs::File;
42use std::iter::Iterator;
43use std::ops;
44use std::path::{Path, PathBuf};
45use std::sync::Arc;
46
/// Store a trajectory of any State.
///
/// A trajectory is an optionally named list of states of type `S`.
/// Evaluating the trajectory with `at` performs a binary search on the epochs of `states`,
/// so states pushed manually (possibly out of order) should be followed by a call to `finalize`.
#[derive(Clone, PartialEq)]
pub struct Traj<S: Interpolatable>
where
    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
    /// Optionally name this trajectory (the name is included in the `Display` output when set)
    pub name: Option<String>,
    /// We use a vector because we know that the states are produced in a chronological manner (the direction does not matter).
    pub states: Vec<S>,
}
58
59impl<S: Interpolatable> Traj<S>
60where
61    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
62{
63    pub fn new() -> Self {
64        Self {
65            name: None,
66            states: Vec::new(),
67        }
68    }
69    /// Orders the states, can be used to store the states out of order
70    pub fn finalize(&mut self) {
71        // Remove duplicate epochs
72        self.states.dedup_by(|a, b| a.epoch().eq(&b.epoch()));
73        // And sort
74        self.states.sort_by_key(|a| a.epoch());
75    }
76
77    /// Evaluate the trajectory at this specific epoch.
78    pub fn at(&self, epoch: Epoch) -> Result<S, TrajError> {
79        if self.states.is_empty() || self.first().epoch() > epoch || self.last().epoch() < epoch {
80            return Err(TrajError::NoInterpolationData { epoch });
81        }
82        match self
83            .states
84            .binary_search_by(|state| state.epoch().cmp(&epoch))
85        {
86            Ok(idx) => {
87                // Oh wow, we actually had this exact state!
88                Ok(self.states[idx])
89            }
90            Err(idx) => {
91                if idx == 0 || idx >= self.states.len() {
92                    // The binary search returns where we should insert the data, so if it's at either end of the list, then we're out of bounds.
93                    // This condition should have been handled by the check at the start of this function.
94                    return Err(TrajError::NoInterpolationData { epoch });
95                }
96                // This is the closest index, so let's grab the items around it.
97                // NOTE: This is essentially the same code as in ANISE for the Hermite SPK type 13
98
99                // We didn't find it, so let's build an interpolation here.
100                let num_left = INTERPOLATION_SAMPLES / 2;
101
102                // Ensure that we aren't fetching out of the window
103                let mut first_idx = idx.saturating_sub(num_left);
104                let last_idx = self.states.len().min(first_idx + INTERPOLATION_SAMPLES);
105
106                // Check that we have enough samples
107                if last_idx == self.states.len() {
108                    first_idx = last_idx.saturating_sub(2 * num_left);
109                }
110
111                let mut states = Vec::with_capacity(last_idx - first_idx);
112                for idx in first_idx..last_idx {
113                    states.push(self.states[idx]);
114                }
115
116                self.states[idx]
117                    .interpolate(epoch, &states)
118                    .context(InterpolationSnafu)
119            }
120        }
121    }
122
123    /// Returns the first state in this ephemeris
124    pub fn first(&self) -> &S {
125        // This is done after we've ordered the states we received, so we can just return the first state.
126        self.states.first().unwrap()
127    }
128
129    /// Returns the last state in this ephemeris
130    pub fn last(&self) -> &S {
131        self.states.last().unwrap()
132    }
133
134    /// Creates an iterator through the trajectory by the provided step size
135    pub fn every(&self, step: Duration) -> TrajIterator<'_, S> {
136        self.every_between(step, self.first().epoch(), self.last().epoch())
137    }
138
139    /// Creates an iterator through the trajectory by the provided step size between the provided bounds
140    pub fn every_between(&self, step: Duration, start: Epoch, end: Epoch) -> TrajIterator<'_, S> {
141        TrajIterator {
142            time_series: TimeSeries::inclusive(
143                start.max(self.first().epoch()),
144                end.min(self.last().epoch()),
145                step,
146            ),
147            traj: self,
148        }
149    }
150
151    /// Store this trajectory arc to a parquet file with the default configuration (depends on the state type, search for `export_params` in the documentation for details).
152    pub fn to_parquet_simple<P: AsRef<Path>>(
153        &self,
154        path: P,
155        almanac: Arc<Almanac>,
156    ) -> Result<PathBuf, Box<dyn Error>> {
157        self.to_parquet(path, None, ExportCfg::default(), almanac)
158    }
159
160    /// Store this trajectory arc to a parquet file with the provided configuration
161    pub fn to_parquet_with_cfg<P: AsRef<Path>>(
162        &self,
163        path: P,
164        cfg: ExportCfg,
165        almanac: Arc<Almanac>,
166    ) -> Result<PathBuf, Box<dyn Error>> {
167        self.to_parquet(path, None, cfg, almanac)
168    }
169
170    /// A shortcut to `to_parquet_with_cfg`
171    pub fn to_parquet_with_step<P: AsRef<Path>>(
172        &self,
173        path: P,
174        step: Duration,
175        almanac: Arc<Almanac>,
176    ) -> Result<(), Box<dyn Error>> {
177        self.to_parquet_with_cfg(
178            path,
179            ExportCfg {
180                step: Some(step),
181                ..Default::default()
182            },
183            almanac,
184        )?;
185
186        Ok(())
187    }
188
189    /// Store this trajectory arc to a parquet file with the provided configuration and event evaluators
190    pub fn to_parquet<P: AsRef<Path>>(
191        &self,
192        path: P,
193        events: Option<Vec<&dyn EventEvaluator<S>>>,
194        cfg: ExportCfg,
195        almanac: Arc<Almanac>,
196    ) -> Result<PathBuf, Box<dyn Error>> {
197        let tick = Epoch::now().unwrap();
198        info!("Exporting trajectory to parquet file...");
199
200        // Grab the path here before we move stuff.
201        let path_buf = cfg.actual_path(path);
202
203        // Build the schema
204        let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
205
206        let frame = self.states[0].frame();
207        let more_meta = Some(vec![(
208            "Frame".to_string(),
209            serde_dhall::serialize(&frame)
210                .static_type_annotation()
211                .to_string()
212                .map_err(|e| {
213                    Box::new(InputOutputError::SerializeDhall {
214                        what: format!("frame `{frame}`"),
215                        err: e.to_string(),
216                    })
217                })?,
218        )]);
219
220        let mut fields = match cfg.fields {
221            Some(fields) => fields,
222            None => S::export_params(),
223        };
224
225        // Check that we can retrieve this information
226        fields.retain(|param| self.first().value(*param).is_ok());
227
228        for field in &fields {
229            hdrs.push(field.to_field(more_meta.clone()));
230        }
231
232        if let Some(events) = events.as_ref() {
233            for event in events {
234                let field = Field::new(format!("{event}"), DataType::Float64, false);
235                hdrs.push(field);
236            }
237        }
238
239        // Build the schema
240        let schema = Arc::new(Schema::new(hdrs));
241        let mut record: Vec<Arc<dyn Array>> = Vec::new();
242
243        // Build the states iterator -- this does require copying the current states but I can't either get a reference or a copy of all the states.
244        let states = if cfg.start_epoch.is_some() || cfg.end_epoch.is_some() || cfg.step.is_some() {
245            // Must interpolate the data!
246            let start = cfg.start_epoch.unwrap_or_else(|| self.first().epoch());
247            let end = cfg.end_epoch.unwrap_or_else(|| self.last().epoch());
248            let step = cfg.step.unwrap_or_else(|| 1.minutes());
249            self.every_between(step, start, end).collect::<Vec<S>>()
250        } else {
251            self.states.to_vec()
252        };
253
254        // Build all of the records
255
256        // Epochs
257        let mut utc_epoch = StringBuilder::new();
258        for s in &states {
259            utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
260        }
261        record.push(Arc::new(utc_epoch.finish()));
262
263        // Add all of the fields
264        for field in fields {
265            if field == StateParameter::GuidanceMode {
266                let mut guid_mode = StringBuilder::new();
267                for s in &states {
268                    guid_mode
269                        .append_value(format!("{:?}", GuidanceMode::from(s.value(field).unwrap())));
270                }
271                record.push(Arc::new(guid_mode.finish()));
272            } else {
273                let mut data = Float64Builder::new();
274                for s in &states {
275                    data.append_value(s.value(field).unwrap());
276                }
277                record.push(Arc::new(data.finish()));
278            }
279        }
280
281        info!(
282            "Serialized {} states from {} to {}",
283            states.len(),
284            states.first().unwrap().epoch(),
285            states.last().unwrap().epoch()
286        );
287
288        // Add all of the evaluated events
289        if let Some(events) = events {
290            info!("Evaluating {} event(s)", events.len());
291            for event in events {
292                let mut data = Float64Builder::new();
293                for s in &states {
294                    data.append_value(event.eval(s, almanac.clone()).map_err(Box::new)?);
295                }
296                record.push(Arc::new(data.finish()));
297            }
298        }
299
300        // Serialize all of the devices and add that to the parquet file too.
301        let mut metadata = HashMap::new();
302        metadata.insert("Purpose".to_string(), "Trajectory data".to_string());
303        if let Some(add_meta) = cfg.metadata {
304            for (k, v) in add_meta {
305                metadata.insert(k, v);
306            }
307        }
308
309        let props = pq_writer(Some(metadata));
310
311        let file = File::create(&path_buf)?;
312        let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
313
314        let batch = RecordBatch::try_new(schema, record)?;
315        writer.write(&batch)?;
316        writer.close()?;
317
318        // Return the path this was written to
319        let tock_time = Epoch::now().unwrap() - tick;
320        info!(
321            "Trajectory written to {} in {tock_time}",
322            path_buf.display()
323        );
324        Ok(path_buf)
325    }
326
327    /// Allows resampling this trajectory at a fixed interval instead of using the propagator step size.
328    /// This may lead to aliasing due to the Nyquist–Shannon sampling theorem.
329    pub fn resample(&self, step: Duration) -> Result<Self, NyxError> {
330        if self.states.is_empty() {
331            return Err(NyxError::Trajectory {
332                source: TrajError::CreationError {
333                    msg: "No trajectory to convert".to_string(),
334                },
335            });
336        }
337
338        let mut traj = Self::new();
339        for state in self.every(step) {
340            traj.states.push(state);
341        }
342
343        traj.finalize();
344
345        Ok(traj)
346    }
347
348    /// Rebuilds this trajectory with the provided epochs.
349    /// This may lead to aliasing due to the Nyquist–Shannon sampling theorem.
350    pub fn rebuild(&self, epochs: &[Epoch]) -> Result<Self, NyxError> {
351        if self.states.is_empty() {
352            return Err(NyxError::Trajectory {
353                source: TrajError::CreationError {
354                    msg: "No trajectory to convert".to_string(),
355                },
356            });
357        }
358
359        let mut traj = Self::new();
360        for epoch in epochs {
361            traj.states.push(self.at(*epoch)?);
362        }
363
364        traj.finalize();
365
366        Ok(traj)
367    }
368
369    /// Export the difference in RIC from of this trajectory compare to the "other" trajectory in parquet format.
370    ///
371    /// # Notes
372    /// + The RIC frame accounts for the transport theorem by performing a finite differencing of the RIC frame.
373    pub fn ric_diff_to_parquet<P: AsRef<Path>>(
374        &self,
375        other: &Self,
376        path: P,
377        cfg: ExportCfg,
378    ) -> Result<PathBuf, Box<dyn Error>> {
379        let tick = Epoch::now().unwrap();
380        info!("Exporting trajectory to parquet file...");
381
382        // Grab the path here before we move stuff.
383        let path_buf = cfg.actual_path(path);
384
385        // Build the schema
386        let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
387
388        // Add the RIC headers
389        for coord in ["X", "Y", "Z"] {
390            let mut meta = HashMap::new();
391            meta.insert("unit".to_string(), "km".to_string());
392
393            let field = Field::new(
394                format!("Delta {coord} (RIC) (km)"),
395                DataType::Float64,
396                false,
397            )
398            .with_metadata(meta);
399
400            hdrs.push(field);
401        }
402
403        for coord in ["x", "y", "z"] {
404            let mut meta = HashMap::new();
405            meta.insert("unit".to_string(), "km/s".to_string());
406
407            let field = Field::new(
408                format!("Delta V{coord} (RIC) (km/s)"),
409                DataType::Float64,
410                false,
411            )
412            .with_metadata(meta);
413
414            hdrs.push(field);
415        }
416
417        let frame = self.states[0].frame();
418        let more_meta = Some(vec![(
419            "Frame".to_string(),
420            serde_dhall::serialize(&frame)
421                .static_type_annotation()
422                .to_string()
423                .unwrap_or(frame.to_string()),
424        )]);
425
426        let mut cfg = cfg;
427
428        let mut fields = match cfg.fields {
429            Some(fields) => fields,
430            None => S::export_params(),
431        };
432
433        // Remove disallowed field and check that we can retrieve this information
434        fields.retain(|param| {
435            param != &StateParameter::GuidanceMode && self.first().value(*param).is_ok()
436        });
437
438        for field in &fields {
439            hdrs.push(field.to_field(more_meta.clone()));
440        }
441
442        // Build the schema
443        let schema = Arc::new(Schema::new(hdrs));
444        let mut record: Vec<Arc<dyn Array>> = Vec::new();
445
446        // Ensure the times match.
447        cfg.start_epoch = if self.first().epoch() > other.first().epoch() {
448            Some(self.first().epoch())
449        } else {
450            Some(other.first().epoch())
451        };
452
453        cfg.end_epoch = if self.last().epoch() > other.last().epoch() {
454            Some(other.last().epoch())
455        } else {
456            Some(self.last().epoch())
457        };
458
459        // Build the states iterator
460        let step = cfg.step.unwrap_or_else(|| 1.minutes());
461        let self_states = self
462            .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
463            .collect::<Vec<S>>();
464
465        let other_states = other
466            .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
467            .collect::<Vec<S>>();
468
469        // Build an array of all the RIC differences
470        let mut ric_diff = Vec::with_capacity(other_states.len());
471        for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
472            let self_orbit = self_state.orbit();
473            let other_orbit = other_state.orbit();
474
475            let this_ric_diff = self_orbit.ric_difference(&other_orbit).map_err(Box::new)?;
476
477            ric_diff.push(this_ric_diff);
478        }
479
480        smooth_state_diff_in_place(&mut ric_diff, if other_states.len() > 5 { 5 } else { 1 });
481
482        // Build all of the records
483
484        // Epochs (both match for self and others)
485        let mut utc_epoch = StringBuilder::new();
486        for s in &self_states {
487            utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
488        }
489        record.push(Arc::new(utc_epoch.finish()));
490
491        // Add the RIC data
492        for coord_no in 0..6 {
493            let mut data = Float64Builder::new();
494            for this_ric_dff in &ric_diff {
495                data.append_value(this_ric_dff.to_cartesian_pos_vel()[coord_no]);
496            }
497            record.push(Arc::new(data.finish()));
498        }
499
500        // Add all of the fields
501        for field in fields {
502            let mut data = Float64Builder::new();
503            for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
504                let self_val = self_state.value(field).map_err(Box::new)?;
505                let other_val = other_state.value(field).map_err(Box::new)?;
506
507                data.append_value(self_val - other_val);
508            }
509
510            record.push(Arc::new(data.finish()));
511        }
512
513        info!("Serialized {} states differences", self_states.len());
514
515        // Serialize all of the devices and add that to the parquet file too.
516        let mut metadata = HashMap::new();
517        metadata.insert(
518            "Purpose".to_string(),
519            "Trajectory difference data".to_string(),
520        );
521        if let Some(add_meta) = cfg.metadata {
522            for (k, v) in add_meta {
523                metadata.insert(k, v);
524            }
525        }
526
527        let props = pq_writer(Some(metadata));
528
529        let file = File::create(&path_buf)?;
530        let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
531
532        let batch = RecordBatch::try_new(schema, record)?;
533        writer.write(&batch)?;
534        writer.close()?;
535
536        // Return the path this was written to
537        let tock_time = Epoch::now().unwrap() - tick;
538        info!(
539            "Trajectory written to {} in {tock_time}",
540            path_buf.display()
541        );
542        Ok(path_buf)
543    }
544}
545
546impl<S: Interpolatable> ops::Add for Traj<S>
547where
548    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
549{
550    type Output = Result<Traj<S>, NyxError>;
551
552    /// Add one trajectory to another. If they do not overlap to within 10ms, a warning will be printed.
553    fn add(self, other: Traj<S>) -> Self::Output {
554        &self + &other
555    }
556}
557
558impl<S: Interpolatable> ops::Add<&Traj<S>> for &Traj<S>
559where
560    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
561{
562    type Output = Result<Traj<S>, NyxError>;
563
564    /// Add one trajectory to another, returns an error if the frames don't match
565    fn add(self, other: &Traj<S>) -> Self::Output {
566        if self.first().frame() != other.first().frame() {
567            Err(NyxError::Trajectory {
568                source: TrajError::CreationError {
569                    msg: format!(
570                        "Frame mismatch in add operation: {} != {}",
571                        self.first().frame(),
572                        other.first().frame()
573                    ),
574                },
575            })
576        } else {
577            if self.last().epoch() < other.first().epoch() {
578                let gap = other.first().epoch() - self.last().epoch();
579                warn!(
580                    "Resulting merged trajectory will have a time-gap of {} starting at {}",
581                    gap,
582                    self.last().epoch()
583                );
584            }
585
586            let mut me = self.clone();
587            // Now start adding the other segments while correcting the index
588            for state in &other
589                .states
590                .iter()
591                .copied()
592                .filter(|s| s.epoch() > self.last().epoch())
593                .collect::<Vec<S>>()
594            {
595                me.states.push(*state);
596            }
597            me.finalize();
598
599            Ok(me)
600        }
601    }
602}
603
604impl<S: Interpolatable> ops::AddAssign<&Traj<S>> for Traj<S>
605where
606    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
607{
608    /// Attempt to add two trajectories together and assign it to `self`
609    ///
610    /// # Warnings
611    /// 1. This will panic if the frames mismatch!
612    /// 2. This is inefficient because both `self` and `rhs` are cloned.
613    fn add_assign(&mut self, rhs: &Self) {
614        *self = (self.clone() + rhs.clone()).unwrap();
615    }
616}
617
618impl<S: Interpolatable> fmt::Display for Traj<S>
619where
620    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
621{
622    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
623        if self.states.is_empty() {
624            write!(f, "Empty Trajectory!")
625        } else {
626            let dur = self.last().epoch() - self.first().epoch();
627            write!(
628                f,
629                "Trajectory {}in {} from {} to {} ({}, or {:.3} s) [{} states]",
630                match &self.name {
631                    Some(name) => format!("of {name} "),
632                    None => String::new(),
633                },
634                self.first().frame(),
635                self.first().epoch(),
636                self.last().epoch(),
637                dur,
638                dur.to_seconds(),
639                self.states.len()
640            )
641        }
642    }
643}
644
645impl<S: Interpolatable> fmt::Debug for Traj<S>
646where
647    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
648{
649    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
650        write!(f, "{self}",)
651    }
652}
653
654impl<S: Interpolatable> Default for Traj<S>
655where
656    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
657{
658    fn default() -> Self {
659        Self::new()
660    }
661}