nyx_space/md/trajectory/
traj.rs

1/*
2    Nyx, blazing fast astrodynamics
3    Copyright (C) 2018-onwards Christopher Rabotin <christopher.rabotin@gmail.com>
4
5    This program is free software: you can redistribute it and/or modify
6    it under the terms of the GNU Affero General Public License as published
7    by the Free Software Foundation, either version 3 of the License, or
8    (at your option) any later version.
9
10    This program is distributed in the hope that it will be useful,
11    but WITHOUT ANY WARRANTY; without even the implied warranty of
12    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13    GNU Affero General Public License for more details.
14
15    You should have received a copy of the GNU Affero General Public License
16    along with this program.  If not, see <https://www.gnu.org/licenses/>.
17*/
18
19use super::traj_it::TrajIterator;
20use super::{ExportCfg, InterpolationSnafu, INTERPOLATION_SAMPLES};
21use super::{Interpolatable, TrajError};
22use crate::errors::NyxError;
23use crate::io::watermark::pq_writer;
24use crate::io::InputOutputError;
25use crate::linalg::allocator::Allocator;
26use crate::linalg::DefaultAllocator;
27use crate::md::prelude::{GuidanceMode, StateParameter};
28use crate::md::EventEvaluator;
29use crate::time::{Duration, Epoch, TimeSeries, TimeUnits};
30use anise::almanac::Almanac;
31use arrow::array::{Array, Float64Builder, StringBuilder};
32use arrow::datatypes::{DataType, Field, Schema};
33use arrow::record_batch::RecordBatch;
34use hifitime::TimeScale;
35use parquet::arrow::ArrowWriter;
36use snafu::ResultExt;
37use std::collections::HashMap;
38use std::error::Error;
39use std::fmt;
40use std::fs::File;
41use std::iter::Iterator;
42use std::ops;
43use std::path::{Path, PathBuf};
44use std::sync::Arc;
45
46/// Store a trajectory of any State.
47#[derive(Clone, PartialEq)]
48pub struct Traj<S: Interpolatable>
49where
50    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
51{
52    /// Optionally name this trajectory
53    pub name: Option<String>,
54    /// We use a vector because we know that the states are produced in a chronological manner (the direction does not matter).
55    pub states: Vec<S>,
56}
57
58impl<S: Interpolatable> Traj<S>
59where
60    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
61{
62    pub fn new() -> Self {
63        Self {
64            name: None,
65            states: Vec::new(),
66        }
67    }
68    /// Orders the states, can be used to store the states out of order
69    pub fn finalize(&mut self) {
70        // Remove duplicate epochs
71        self.states.dedup_by(|a, b| a.epoch().eq(&b.epoch()));
72        // And sort
73        self.states.sort_by_key(|a| a.epoch());
74    }
75
76    /// Evaluate the trajectory at this specific epoch.
77    pub fn at(&self, epoch: Epoch) -> Result<S, TrajError> {
78        if self.states.is_empty() || self.first().epoch() > epoch || self.last().epoch() < epoch {
79            return Err(TrajError::NoInterpolationData { epoch });
80        }
81        match self
82            .states
83            .binary_search_by(|state| state.epoch().cmp(&epoch))
84        {
85            Ok(idx) => {
86                // Oh wow, we actually had this exact state!
87                Ok(self.states[idx])
88            }
89            Err(idx) => {
90                if idx == 0 || idx >= self.states.len() {
91                    // The binary search returns where we should insert the data, so if it's at either end of the list, then we're out of bounds.
92                    // This condition should have been handled by the check at the start of this function.
93                    return Err(TrajError::NoInterpolationData { epoch });
94                }
95                // This is the closest index, so let's grab the items around it.
96                // NOTE: This is essentially the same code as in ANISE for the Hermite SPK type 13
97
98                // We didn't find it, so let's build an interpolation here.
99                let num_left = INTERPOLATION_SAMPLES / 2;
100
101                // Ensure that we aren't fetching out of the window
102                let mut first_idx = idx.saturating_sub(num_left);
103                let last_idx = self.states.len().min(first_idx + INTERPOLATION_SAMPLES);
104
105                // Check that we have enough samples
106                if last_idx == self.states.len() {
107                    first_idx = last_idx.saturating_sub(2 * num_left);
108                }
109
110                let mut states = Vec::with_capacity(last_idx - first_idx);
111                for idx in first_idx..last_idx {
112                    states.push(self.states[idx]);
113                }
114
115                self.states[idx]
116                    .interpolate(epoch, &states)
117                    .context(InterpolationSnafu)
118            }
119        }
120    }
121
122    /// Returns the first state in this ephemeris
123    pub fn first(&self) -> &S {
124        // This is done after we've ordered the states we received, so we can just return the first state.
125        self.states.first().unwrap()
126    }
127
128    /// Returns the last state in this ephemeris
129    pub fn last(&self) -> &S {
130        self.states.last().unwrap()
131    }
132
133    /// Creates an iterator through the trajectory by the provided step size
134    pub fn every(&self, step: Duration) -> TrajIterator<S> {
135        self.every_between(step, self.first().epoch(), self.last().epoch())
136    }
137
138    /// Creates an iterator through the trajectory by the provided step size between the provided bounds
139    pub fn every_between(&self, step: Duration, start: Epoch, end: Epoch) -> TrajIterator<S> {
140        TrajIterator {
141            time_series: TimeSeries::inclusive(
142                start.max(self.first().epoch()),
143                end.min(self.last().epoch()),
144                step,
145            ),
146            traj: self,
147        }
148    }
149
150    /// Store this trajectory arc to a parquet file with the default configuration (depends on the state type, search for `export_params` in the documentation for details).
151    pub fn to_parquet_simple<P: AsRef<Path>>(
152        &self,
153        path: P,
154        almanac: Arc<Almanac>,
155    ) -> Result<PathBuf, Box<dyn Error>> {
156        self.to_parquet(path, None, ExportCfg::default(), almanac)
157    }
158
159    /// Store this trajectory arc to a parquet file with the provided configuration
160    pub fn to_parquet_with_cfg<P: AsRef<Path>>(
161        &self,
162        path: P,
163        cfg: ExportCfg,
164        almanac: Arc<Almanac>,
165    ) -> Result<PathBuf, Box<dyn Error>> {
166        self.to_parquet(path, None, cfg, almanac)
167    }
168
169    /// A shortcut to `to_parquet_with_cfg`
170    pub fn to_parquet_with_step<P: AsRef<Path>>(
171        &self,
172        path: P,
173        step: Duration,
174        almanac: Arc<Almanac>,
175    ) -> Result<(), Box<dyn Error>> {
176        self.to_parquet_with_cfg(
177            path,
178            ExportCfg {
179                step: Some(step),
180                ..Default::default()
181            },
182            almanac,
183        )?;
184
185        Ok(())
186    }
187
188    /// Store this trajectory arc to a parquet file with the provided configuration and event evaluators
189    pub fn to_parquet<P: AsRef<Path>>(
190        &self,
191        path: P,
192        events: Option<Vec<&dyn EventEvaluator<S>>>,
193        cfg: ExportCfg,
194        almanac: Arc<Almanac>,
195    ) -> Result<PathBuf, Box<dyn Error>> {
196        let tick = Epoch::now().unwrap();
197        info!("Exporting trajectory to parquet file...");
198
199        // Grab the path here before we move stuff.
200        let path_buf = cfg.actual_path(path);
201
202        // Build the schema
203        let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
204
205        let frame = self.states[0].frame();
206        let more_meta = Some(vec![(
207            "Frame".to_string(),
208            serde_dhall::serialize(&frame).to_string().map_err(|e| {
209                Box::new(InputOutputError::SerializeDhall {
210                    what: format!("frame `{frame}`"),
211                    err: e.to_string(),
212                })
213            })?,
214        )]);
215
216        let mut fields = match cfg.fields {
217            Some(fields) => fields,
218            None => S::export_params(),
219        };
220
221        // Check that we can retrieve this information
222        fields.retain(|param| self.first().value(*param).is_ok());
223
224        for field in &fields {
225            hdrs.push(field.to_field(more_meta.clone()));
226        }
227
228        if let Some(events) = events.as_ref() {
229            for event in events {
230                let field = Field::new(format!("{event}"), DataType::Float64, false);
231                hdrs.push(field);
232            }
233        }
234
235        // Build the schema
236        let schema = Arc::new(Schema::new(hdrs));
237        let mut record: Vec<Arc<dyn Array>> = Vec::new();
238
239        // Build the states iterator -- this does require copying the current states but I can't either get a reference or a copy of all the states.
240        let states = if cfg.start_epoch.is_some() || cfg.end_epoch.is_some() || cfg.step.is_some() {
241            // Must interpolate the data!
242            let start = cfg.start_epoch.unwrap_or_else(|| self.first().epoch());
243            let end = cfg.end_epoch.unwrap_or_else(|| self.last().epoch());
244            let step = cfg.step.unwrap_or_else(|| 1.minutes());
245            self.every_between(step, start, end).collect::<Vec<S>>()
246        } else {
247            self.states.to_vec()
248        };
249
250        // Build all of the records
251
252        // Epochs
253        let mut utc_epoch = StringBuilder::new();
254        for s in &states {
255            utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
256        }
257        record.push(Arc::new(utc_epoch.finish()));
258
259        // Add all of the fields
260        for field in fields {
261            if field == StateParameter::GuidanceMode {
262                let mut guid_mode = StringBuilder::new();
263                for s in &states {
264                    guid_mode
265                        .append_value(format!("{:?}", GuidanceMode::from(s.value(field).unwrap())));
266                }
267                record.push(Arc::new(guid_mode.finish()));
268            } else {
269                let mut data = Float64Builder::new();
270                for s in &states {
271                    data.append_value(s.value(field).unwrap());
272                }
273                record.push(Arc::new(data.finish()));
274            }
275        }
276
277        info!(
278            "Serialized {} states from {} to {}",
279            states.len(),
280            states.first().unwrap().epoch(),
281            states.last().unwrap().epoch()
282        );
283
284        // Add all of the evaluated events
285        if let Some(events) = events {
286            info!("Evaluating {} event(s)", events.len());
287            for event in events {
288                let mut data = Float64Builder::new();
289                for s in &states {
290                    data.append_value(event.eval(s, almanac.clone()).map_err(Box::new)?);
291                }
292                record.push(Arc::new(data.finish()));
293            }
294        }
295
296        // Serialize all of the devices and add that to the parquet file too.
297        let mut metadata = HashMap::new();
298        metadata.insert("Purpose".to_string(), "Trajectory data".to_string());
299        if let Some(add_meta) = cfg.metadata {
300            for (k, v) in add_meta {
301                metadata.insert(k, v);
302            }
303        }
304
305        let props = pq_writer(Some(metadata));
306
307        let file = File::create(&path_buf)?;
308        let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
309
310        let batch = RecordBatch::try_new(schema, record)?;
311        writer.write(&batch)?;
312        writer.close()?;
313
314        // Return the path this was written to
315        let tock_time = Epoch::now().unwrap() - tick;
316        info!(
317            "Trajectory written to {} in {tock_time}",
318            path_buf.display()
319        );
320        Ok(path_buf)
321    }
322
323    /// Allows resampling this trajectory at a fixed interval instead of using the propagator step size.
324    /// This may lead to aliasing due to the Nyquist–Shannon sampling theorem.
325    pub fn resample(&self, step: Duration) -> Result<Self, NyxError> {
326        if self.states.is_empty() {
327            return Err(NyxError::Trajectory {
328                source: TrajError::CreationError {
329                    msg: "No trajectory to convert".to_string(),
330                },
331            });
332        }
333
334        let mut traj = Self::new();
335        for state in self.every(step) {
336            traj.states.push(state);
337        }
338
339        traj.finalize();
340
341        Ok(traj)
342    }
343
344    /// Rebuilds this trajectory with the provided epochs.
345    /// This may lead to aliasing due to the Nyquist–Shannon sampling theorem.
346    pub fn rebuild(&self, epochs: &[Epoch]) -> Result<Self, NyxError> {
347        if self.states.is_empty() {
348            return Err(NyxError::Trajectory {
349                source: TrajError::CreationError {
350                    msg: "No trajectory to convert".to_string(),
351                },
352            });
353        }
354
355        let mut traj = Self::new();
356        for epoch in epochs {
357            traj.states.push(self.at(*epoch)?);
358        }
359
360        traj.finalize();
361
362        Ok(traj)
363    }
364
365    /// Export the difference in RIC from of this trajectory compare to the "other" trajectory in parquet format.
366    ///
367    /// # Notes
368    /// + The RIC frame accounts for the transport theorem by performing a finite differencing of the RIC frame.
369    pub fn ric_diff_to_parquet<P: AsRef<Path>>(
370        &self,
371        other: &Self,
372        path: P,
373        cfg: ExportCfg,
374    ) -> Result<PathBuf, Box<dyn Error>> {
375        let tick = Epoch::now().unwrap();
376        info!("Exporting trajectory to parquet file...");
377
378        // Grab the path here before we move stuff.
379        let path_buf = cfg.actual_path(path);
380
381        // Build the schema
382        let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
383
384        // Add the RIC headers
385        for coord in ["X", "Y", "Z"] {
386            let mut meta = HashMap::new();
387            meta.insert("unit".to_string(), "km".to_string());
388
389            let field = Field::new(
390                format!("Delta {coord} (RIC) (km)"),
391                DataType::Float64,
392                false,
393            )
394            .with_metadata(meta);
395
396            hdrs.push(field);
397        }
398
399        for coord in ["x", "y", "z"] {
400            let mut meta = HashMap::new();
401            meta.insert("unit".to_string(), "km/s".to_string());
402
403            let field = Field::new(
404                format!("Delta V{coord} (RIC) (km/s)"),
405                DataType::Float64,
406                false,
407            )
408            .with_metadata(meta);
409
410            hdrs.push(field);
411        }
412
413        let frame = self.states[0].frame();
414        let more_meta = Some(vec![(
415            "Frame".to_string(),
416            serde_dhall::serialize(&frame)
417                .to_string()
418                .unwrap_or(frame.to_string()),
419        )]);
420
421        let mut cfg = cfg;
422
423        let mut fields = match cfg.fields {
424            Some(fields) => fields,
425            None => S::export_params(),
426        };
427
428        // Remove disallowed field and check that we can retrieve this information
429        fields.retain(|param| {
430            param != &StateParameter::GuidanceMode && self.first().value(*param).is_ok()
431        });
432
433        for field in &fields {
434            hdrs.push(field.to_field(more_meta.clone()));
435        }
436
437        // Build the schema
438        let schema = Arc::new(Schema::new(hdrs));
439        let mut record: Vec<Arc<dyn Array>> = Vec::new();
440
441        // Ensure the times match.
442        cfg.start_epoch = if self.first().epoch() > other.first().epoch() {
443            Some(self.first().epoch())
444        } else {
445            Some(other.first().epoch())
446        };
447
448        cfg.end_epoch = if self.last().epoch() > other.last().epoch() {
449            Some(other.last().epoch())
450        } else {
451            Some(self.last().epoch())
452        };
453
454        // Build the states iterator
455        let step = cfg.step.unwrap_or_else(|| 1.minutes());
456        let self_states = self
457            .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
458            .collect::<Vec<S>>();
459
460        let other_states = other
461            .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
462            .collect::<Vec<S>>();
463
464        // Build an array of all the RIC differences
465        let mut ric_diff = Vec::with_capacity(other_states.len());
466        for (ii, other_state) in other_states.iter().enumerate() {
467            let self_orbit = self_states[ii].orbit();
468            let other_orbit = other_state.orbit();
469
470            let this_ric_diff = self_orbit.ric_difference(&other_orbit).map_err(Box::new)?;
471
472            ric_diff.push(this_ric_diff);
473        }
474
475        // Build all of the records
476
477        // Epochs (both match for self and others)
478        let mut utc_epoch = StringBuilder::new();
479        for s in &self_states {
480            utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
481        }
482        record.push(Arc::new(utc_epoch.finish()));
483
484        // Add the RIC data
485        for coord_no in 0..6 {
486            let mut data = Float64Builder::new();
487            for this_ric_dff in &ric_diff {
488                data.append_value(this_ric_dff.to_cartesian_pos_vel()[coord_no]);
489            }
490            record.push(Arc::new(data.finish()));
491        }
492
493        // Add all of the fields
494        for field in fields {
495            let mut data = Float64Builder::new();
496            for (ii, self_state) in self_states.iter().enumerate() {
497                let self_val = self_state.value(field).unwrap();
498                let other_val = other_states[ii].value(field).unwrap();
499                data.append_value(self_val - other_val);
500            }
501            record.push(Arc::new(data.finish()));
502        }
503
504        info!("Serialized {} states differences", self_states.len());
505
506        // Serialize all of the devices and add that to the parquet file too.
507        let mut metadata = HashMap::new();
508        metadata.insert(
509            "Purpose".to_string(),
510            "Trajectory difference data".to_string(),
511        );
512        if let Some(add_meta) = cfg.metadata {
513            for (k, v) in add_meta {
514                metadata.insert(k, v);
515            }
516        }
517
518        let props = pq_writer(Some(metadata));
519
520        let file = File::create(&path_buf)?;
521        let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
522
523        let batch = RecordBatch::try_new(schema, record)?;
524        writer.write(&batch)?;
525        writer.close()?;
526
527        // Return the path this was written to
528        let tock_time = Epoch::now().unwrap() - tick;
529        info!(
530            "Trajectory written to {} in {tock_time}",
531            path_buf.display()
532        );
533        Ok(path_buf)
534    }
535}
536
537impl<S: Interpolatable> ops::Add for Traj<S>
538where
539    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
540{
541    type Output = Result<Traj<S>, NyxError>;
542
543    /// Add one trajectory to another. If they do not overlap to within 10ms, a warning will be printed.
544    fn add(self, other: Traj<S>) -> Self::Output {
545        &self + &other
546    }
547}
548
549impl<S: Interpolatable> ops::Add<&Traj<S>> for &Traj<S>
550where
551    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
552{
553    type Output = Result<Traj<S>, NyxError>;
554
555    /// Add one trajectory to another, returns an error if the frames don't match
556    fn add(self, other: &Traj<S>) -> Self::Output {
557        if self.first().frame() != other.first().frame() {
558            Err(NyxError::Trajectory {
559                source: TrajError::CreationError {
560                    msg: format!(
561                        "Frame mismatch in add operation: {} != {}",
562                        self.first().frame(),
563                        other.first().frame()
564                    ),
565                },
566            })
567        } else {
568            if self.last().epoch() < other.first().epoch() {
569                let gap = other.first().epoch() - self.last().epoch();
570                warn!(
571                    "Resulting merged trajectory will have a time-gap of {} starting at {}",
572                    gap,
573                    self.last().epoch()
574                );
575            }
576
577            let mut me = self.clone();
578            // Now start adding the other segments while correcting the index
579            for state in &other
580                .states
581                .iter()
582                .copied()
583                .filter(|s| s.epoch() > self.last().epoch())
584                .collect::<Vec<S>>()
585            {
586                me.states.push(*state);
587            }
588            me.finalize();
589
590            Ok(me)
591        }
592    }
593}
594
595impl<S: Interpolatable> ops::AddAssign<&Traj<S>> for Traj<S>
596where
597    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
598{
599    /// Attempt to add two trajectories together and assign it to `self`
600    ///
601    /// # Warnings
602    /// 1. This will panic if the frames mismatch!
603    /// 2. This is inefficient because both `self` and `rhs` are cloned.
604    fn add_assign(&mut self, rhs: &Self) {
605        *self = (self.clone() + rhs.clone()).unwrap();
606    }
607}
608
609impl<S: Interpolatable> fmt::Display for Traj<S>
610where
611    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
612{
613    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
614        if self.states.is_empty() {
615            write!(f, "Empty Trajectory!")
616        } else {
617            let dur = self.last().epoch() - self.first().epoch();
618            write!(
619                f,
620                "Trajectory {}in {} from {} to {} ({}, or {:.3} s) [{} states]",
621                match &self.name {
622                    Some(name) => format!("of {name} "),
623                    None => String::new(),
624                },
625                self.first().frame(),
626                self.first().epoch(),
627                self.last().epoch(),
628                dur,
629                dur.to_seconds(),
630                self.states.len()
631            )
632        }
633    }
634}
635
636impl<S: Interpolatable> fmt::Debug for Traj<S>
637where
638    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
639{
640    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
641        write!(f, "{self}",)
642    }
643}
644
645impl<S: Interpolatable> Default for Traj<S>
646where
647    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
648{
649    fn default() -> Self {
650        Self::new()
651    }
652}