use super::traj_it::TrajIterator;
use super::{ExportCfg, InterpolationSnafu, INTERPOLATION_SAMPLES};
use super::{Interpolatable, TrajError};
use crate::errors::NyxError;
use crate::io::watermark::pq_writer;
use crate::io::InputOutputError;
use crate::linalg::allocator::Allocator;
use crate::linalg::DefaultAllocator;
use crate::md::prelude::{GuidanceMode, StateParameter};
use crate::md::EventEvaluator;
use crate::time::{Duration, Epoch, TimeSeries, TimeUnits};
use anise::almanac::Almanac;
use arrow::array::{Array, Float64Builder, StringBuilder};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use hifitime::TimeScale;
use parquet::arrow::ArrowWriter;
use snafu::ResultExt;
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
use std::fs::File;
use std::iter::Iterator;
use std::ops;
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// A trajectory: an optionally-named list of states of type `S`.
///
/// Lookups such as [`Traj::at`] binary-search `states` by epoch, so the list is expected to be
/// sorted by epoch; [`Traj::finalize`] establishes that ordering.
#[derive(PartialEq, Clone)]
pub struct Traj<S: Interpolatable>
where
    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
    /// Optional name, shown by the `Display` implementation.
    pub name: Option<String>,
    /// The states making up this trajectory.
    pub states: Vec<S>,
}
impl<S: Interpolatable> Traj<S>
where
DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
/// Creates an unnamed trajectory that holds no states yet.
pub fn new() -> Self {
    let states = Vec::new();
    Self { name: None, states }
}
/// Sorts the states by epoch and removes states that share an epoch.
///
/// When several states have the same epoch, the one inserted first is kept.
pub fn finalize(&mut self) {
    // `sort_by_key` is stable: among states with equal epochs, insertion order is preserved,
    // so the first-inserted one ends up first in its group.
    self.states.sort_by_key(|state| state.epoch());
    // `dedup_by` only removes *consecutive* duplicates, so it must run after the sort. It
    // drops the later element of each equal pair, keeping the first of every group.
    self.states.dedup_by(|later, earlier| later.epoch() == earlier.epoch());
}
/// Returns the state at `epoch`, interpolating between stored states when needed.
///
/// If a stored state has exactly this epoch, it is returned as is. Otherwise a window of up to
/// `INTERPOLATION_SAMPLES` consecutive states around the insertion point is handed to
/// `interpolate`. Near either end of the trajectory the window is shifted inwards instead of
/// being shrunk. `states` must be sorted by epoch (see `finalize`) for the binary search to be
/// valid.
///
/// # Errors
/// Returns `TrajError::NoInterpolationData` if the trajectory is empty or `epoch` lies outside
/// the `[first, last]` epoch span. Interpolation failures are wrapped via `InterpolationSnafu`.
pub fn at(&self, epoch: Epoch) -> Result<S, TrajError> {
    if self.states.is_empty() || self.first().epoch() > epoch || self.last().epoch() < epoch {
        return Err(TrajError::NoInterpolationData { epoch });
    }

    match self
        .states
        .binary_search_by(|state| state.epoch().cmp(&epoch))
    {
        Ok(idx) => Ok(self.states[idx]),
        Err(idx) => {
            // The insertion point cannot be at either end after the bounds check above; keep
            // the guard anyway so the indexing below can never go out of range.
            if idx == 0 || idx >= self.states.len() {
                return Err(TrajError::NoInterpolationData { epoch });
            }

            let num_left = INTERPOLATION_SAMPLES / 2;
            let mut first_idx = idx.saturating_sub(num_left);
            let last_idx = self.states.len().min(first_idx + INTERPOLATION_SAMPLES);

            // Near the end, slide the window back so it still holds a full
            // `INTERPOLATION_SAMPLES` states (previously `2 * num_left`, which dropped one
            // sample whenever `INTERPOLATION_SAMPLES` is odd).
            if last_idx == self.states.len() {
                first_idx = last_idx.saturating_sub(INTERPOLATION_SAMPLES);
            }

            let window = &self.states[first_idx..last_idx];
            self.states[idx]
                .interpolate(epoch, window)
                .context(InterpolationSnafu)
        }
    }
}
/// Returns the first state in `states` (the earliest one once the trajectory is finalized).
///
/// # Panics
/// Panics if the trajectory has no states.
pub fn first(&self) -> &S {
    self.states
        .first()
        .expect("trajectory has no states: `first()` requires at least one state")
}
/// Returns the last state in `states` (the latest one once the trajectory is finalized).
///
/// # Panics
/// Panics if the trajectory has no states.
pub fn last(&self) -> &S {
    self.states
        .last()
        .expect("trajectory has no states: `last()` requires at least one state")
}
/// Iterates over the whole trajectory every `step`, from its first to its last epoch.
///
/// # Panics
/// Panics if the trajectory has no states.
pub fn every(&self, step: Duration) -> TrajIterator<S> {
    let start = self.first().epoch();
    let end = self.last().epoch();
    self.every_between(step, start, end)
}
/// Iterates over the trajectory every `step` between `start` and `end`.
///
/// Both bounds are clamped to the trajectory's own span before building the inclusive time
/// series.
///
/// # Panics
/// Panics if the trajectory has no states.
pub fn every_between(&self, step: Duration, start: Epoch, end: Epoch) -> TrajIterator<S> {
    let lower = start.max(self.first().epoch());
    let upper = end.min(self.last().epoch());
    let time_series = TimeSeries::inclusive(lower, upper, step);
    TrajIterator {
        time_series,
        traj: self,
    }
}
/// Exports the trajectory to parquet with the default `ExportCfg` and no event columns.
///
/// Returns the path of the written file.
pub fn to_parquet_simple<P: AsRef<Path>>(
    &self,
    path: P,
    almanac: Arc<Almanac>,
) -> Result<PathBuf, Box<dyn Error>> {
    let default_cfg = ExportCfg::default();
    self.to_parquet_with_cfg(path, default_cfg, almanac)
}
/// Exports the trajectory to parquet with the provided `cfg` and no event columns.
///
/// Returns the path of the written file.
pub fn to_parquet_with_cfg<P: AsRef<Path>>(
    &self,
    path: P,
    cfg: ExportCfg,
    almanac: Arc<Almanac>,
) -> Result<PathBuf, Box<dyn Error>> {
    // No events are requested through this entry point.
    let events = None;
    self.to_parquet(path, events, cfg, almanac)
}
pub fn to_parquet<P: AsRef<Path>>(
&self,
path: P,
events: Option<Vec<&dyn EventEvaluator<S>>>,
cfg: ExportCfg,
almanac: Arc<Almanac>,
) -> Result<PathBuf, Box<dyn Error>> {
let tick = Epoch::now().unwrap();
info!("Exporting trajectory to parquet file...");
let path_buf = cfg.actual_path(path);
let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
let frame = self.states[0].frame();
let more_meta = Some(vec![(
"Frame".to_string(),
serde_dhall::serialize(&frame).to_string().map_err(|e| {
Box::new(InputOutputError::SerializeDhall {
what: format!("frame `{frame}`"),
err: e.to_string(),
})
})?,
)]);
let mut fields = match cfg.fields {
Some(fields) => fields,
None => S::export_params(),
};
fields.retain(|param| self.first().value(*param).is_ok());
for field in &fields {
hdrs.push(field.to_field(more_meta.clone()));
}
if let Some(events) = events.as_ref() {
for event in events {
let field = Field::new(format!("{event}"), DataType::Float64, false);
hdrs.push(field);
}
}
let schema = Arc::new(Schema::new(hdrs));
let mut record: Vec<Arc<dyn Array>> = Vec::new();
let states = if cfg.start_epoch.is_some() || cfg.end_epoch.is_some() || cfg.step.is_some() {
let start = cfg.start_epoch.unwrap_or_else(|| self.first().epoch());
let end = cfg.end_epoch.unwrap_or_else(|| self.last().epoch());
let step = cfg.step.unwrap_or_else(|| 1.minutes());
self.every_between(step, start, end).collect::<Vec<S>>()
} else {
self.states.to_vec()
};
let mut utc_epoch = StringBuilder::new();
for s in &states {
utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
}
record.push(Arc::new(utc_epoch.finish()));
for field in fields {
if field == StateParameter::GuidanceMode {
let mut guid_mode = StringBuilder::new();
for s in &states {
guid_mode
.append_value(format!("{:?}", GuidanceMode::from(s.value(field).unwrap())));
}
record.push(Arc::new(guid_mode.finish()));
} else {
let mut data = Float64Builder::new();
for s in &states {
data.append_value(s.value(field).unwrap());
}
record.push(Arc::new(data.finish()));
}
}
info!(
"Serialized {} states from {} to {}",
states.len(),
states.first().unwrap().epoch(),
states.last().unwrap().epoch()
);
if let Some(events) = events {
info!("Evaluating {} event(s)", events.len());
for event in events {
let mut data = Float64Builder::new();
for s in &states {
data.append_value(event.eval(s, almanac.clone()).map_err(Box::new)?);
}
record.push(Arc::new(data.finish()));
}
}
let mut metadata = HashMap::new();
metadata.insert("Purpose".to_string(), "Trajectory data".to_string());
if let Some(add_meta) = cfg.metadata {
for (k, v) in add_meta {
metadata.insert(k, v);
}
}
let props = pq_writer(Some(metadata));
let file = File::create(&path_buf)?;
let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
let batch = RecordBatch::try_new(schema, record)?;
writer.write(&batch)?;
writer.close()?;
let tock_time = Epoch::now().unwrap() - tick;
info!(
"Trajectory written to {} in {tock_time}",
path_buf.display()
);
Ok(path_buf)
}
/// Builds a new, unnamed trajectory by sampling this one every `step` over its whole span.
///
/// The result has been passed through `finalize`.
///
/// # Errors
/// Returns `NyxError::Trajectory` (wrapping `TrajError::CreationError`) if this trajectory has
/// no states.
pub fn resample(&self, step: Duration) -> Result<Self, NyxError> {
    if self.states.is_empty() {
        let source = TrajError::CreationError {
            msg: "No trajectory to convert".to_string(),
        };
        return Err(NyxError::Trajectory { source });
    }

    let mut resampled = Self::new();
    resampled.states.extend(self.every(step));
    resampled.finalize();
    Ok(resampled)
}
pub fn rebuild(&self, epochs: &[Epoch]) -> Result<Self, NyxError> {
if self.states.is_empty() {
return Err(NyxError::Trajectory {
source: TrajError::CreationError {
msg: "No trajectory to convert".to_string(),
},
});
}
let mut traj = Self::new();
for epoch in epochs {
traj.states.push(self.at(*epoch)?);
}
traj.finalize();
Ok(traj)
}
/// Writes the RIC difference between `self` and `other` to a parquet file and returns the
/// path actually written.
///
/// Both trajectories are sampled every `cfg.step` (default 1 minute) over their overlapping
/// span only. Any `cfg.start_epoch` / `cfg.end_epoch` is not used.
///
/// The columns are written in this order:
/// - the UTC epoch;
/// - the six RIC position (km) and velocity (km/s) differences;
/// - the difference `self - other` for each exported `StateParameter`. These come from
///   `cfg.fields`, or from `S::export_params()` when none are given, excluding the guidance
///   mode and any parameter the state cannot evaluate.
///
/// # Errors
/// Returns an error in any of these cases:
/// - either trajectory has no states;
/// - a RIC difference cannot be computed;
/// - creating or writing the parquet file fails.
pub fn ric_diff_to_parquet<P: AsRef<Path>>(
    &self,
    other: &Self,
    path: P,
    cfg: ExportCfg,
) -> Result<PathBuf, Box<dyn Error>> {
    // Both trajectories' first/last states are needed below; report emptiness as an error
    // rather than panicking.
    if self.states.is_empty() || other.states.is_empty() {
        return Err("cannot compute the RIC difference with an empty trajectory".into());
    }

    let tick = Epoch::now().unwrap();
    info!("Exporting trajectory to parquet file...");

    // Resolve the output path before `cfg.fields` / `cfg.metadata` are moved out of `cfg`.
    let path_buf = cfg.actual_path(path);

    let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
    for coord in ["X", "Y", "Z"] {
        let mut meta = HashMap::new();
        meta.insert("unit".to_string(), "km".to_string());
        let field = Field::new(
            format!("Delta {coord} (RIC) (km)"),
            DataType::Float64,
            false,
        )
        .with_metadata(meta);
        hdrs.push(field);
    }
    for coord in ["x", "y", "z"] {
        let mut meta = HashMap::new();
        meta.insert("unit".to_string(), "km/s".to_string());
        let field = Field::new(
            format!("Delta V{coord} (RIC) (km/s)"),
            DataType::Float64,
            false,
        )
        .with_metadata(meta);
        hdrs.push(field);
    }

    let frame = self.first().frame();
    let more_meta = Some(vec![(
        "Frame".to_string(),
        // Fall back to the frame's Display form; only build it if serialization fails.
        serde_dhall::serialize(&frame)
            .to_string()
            .unwrap_or_else(|_| frame.to_string()),
    )]);

    let mut fields = match cfg.fields {
        Some(fields) => fields,
        None => S::export_params(),
    };
    // The guidance mode is categorical, so a numerical difference is meaningless.
    fields.retain(|param| {
        param != &StateParameter::GuidanceMode && self.first().value(*param).is_ok()
    });
    for field in &fields {
        hdrs.push(field.to_field(more_meta.clone()));
    }

    let schema = Arc::new(Schema::new(hdrs));
    let mut record: Vec<Arc<dyn Array>> = Vec::new();

    // Only compare over the span covered by both trajectories.
    let start = self.first().epoch().max(other.first().epoch());
    let end = self.last().epoch().min(other.last().epoch());
    let step = cfg.step.unwrap_or_else(|| 1.minutes());

    let self_states = self.every_between(step, start, end).collect::<Vec<S>>();
    let other_states = other.every_between(step, start, end).collect::<Vec<S>>();

    let ric_diff = self_states
        .iter()
        .zip(&other_states)
        .map(|(self_state, other_state)| {
            self_state
                .orbit()
                .ric_difference(&other_state.orbit())
                .map_err(Box::new)
        })
        .collect::<Result<Vec<_>, _>>()?;

    let mut utc_epoch = StringBuilder::new();
    for s in &self_states {
        utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
    }
    record.push(Arc::new(utc_epoch.finish()));

    // Convert each RIC difference once instead of once per output column.
    let pos_vel = ric_diff
        .iter()
        .map(|diff| diff.to_cartesian_pos_vel())
        .collect::<Vec<_>>();
    for coord_no in 0..6 {
        let mut data = Float64Builder::new();
        for this_pos_vel in &pos_vel {
            data.append_value(this_pos_vel[coord_no]);
        }
        record.push(Arc::new(data.finish()));
    }

    for field in fields {
        let mut data = Float64Builder::new();
        for (self_state, other_state) in self_states.iter().zip(&other_states) {
            // `unwrap` is safe: `fields` was filtered to parameters the state type provides.
            let self_val = self_state.value(field).unwrap();
            let other_val = other_state.value(field).unwrap();
            data.append_value(self_val - other_val);
        }
        record.push(Arc::new(data.finish()));
    }

    info!("Serialized {} states differences", self_states.len());

    let mut metadata = HashMap::new();
    metadata.insert(
        "Purpose".to_string(),
        "Trajectory difference data".to_string(),
    );
    if let Some(add_meta) = cfg.metadata {
        for (k, v) in add_meta {
            metadata.insert(k, v);
        }
    }

    let props = pq_writer(Some(metadata));
    let file = File::create(&path_buf)?;
    // Propagate writer construction failures instead of panicking.
    let mut writer = ArrowWriter::try_new(file, schema.clone(), props)?;
    let batch = RecordBatch::try_new(schema, record)?;
    writer.write(&batch)?;
    writer.close()?;

    let tock_time = Epoch::now().unwrap() - tick;
    info!(
        "Trajectory written to {} in {tock_time}",
        path_buf.display()
    );
    Ok(path_buf)
}
}
impl<S: Interpolatable> ops::Add for Traj<S>
where
    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
    type Output = Result<Traj<S>, NyxError>;

    /// Merges two owned trajectories by delegating to the by-reference `Add` implementation.
    fn add(self, other: Traj<S>) -> Self::Output {
        let (lhs, rhs) = (&self, &other);
        lhs + rhs
    }
}
impl<S: Interpolatable> ops::Add<&Traj<S>> for &Traj<S>
where
    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
    type Output = Result<Traj<S>, NyxError>;

    /// Returns a copy of `self` extended with the states of `other` that come strictly after
    /// the last state of `self`, then finalized. The result keeps the name of `self`.
    ///
    /// An empty operand contributes nothing: if `self` is empty, all states of `other` are
    /// taken. A warning is logged when the merge leaves a time gap between the two.
    ///
    /// # Errors
    /// Returns `NyxError::Trajectory` if both operands are non-empty and their first states
    /// are in different frames.
    fn add(self, other: &Traj<S>) -> Self::Output {
        // Frames can only be compared when both sides have at least one state.
        if let (Some(lhs), Some(rhs)) = (self.states.first(), other.states.first()) {
            if lhs.frame() != rhs.frame() {
                return Err(NyxError::Trajectory {
                    source: TrajError::CreationError {
                        msg: format!(
                            "Frame mismatch in add operation: {} != {}",
                            lhs.frame(),
                            rhs.frame()
                        ),
                    },
                });
            }
        }

        let mut me = self.clone();
        match self.states.last().map(|s| s.epoch()) {
            // Nothing on the left-hand side: every state of `other` is new.
            None => me.states.extend(other.states.iter().copied()),
            Some(self_end) => {
                if let Some(other_start) = other.states.first().map(|s| s.epoch()) {
                    if self_end < other_start {
                        let gap = other_start - self_end;
                        warn!(
                            "Resulting merged trajectory will have a time-gap of {} starting at {}",
                            gap, self_end
                        );
                    }
                }
                // Only append states past the end of `self`; overlapping ones are dropped.
                me.states.extend(
                    other
                        .states
                        .iter()
                        .copied()
                        .filter(|s| s.epoch() > self_end),
                );
            }
        }
        me.finalize();
        Ok(me)
    }
}
impl<S: Interpolatable> ops::AddAssign<&Traj<S>> for Traj<S>
where
    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
    /// Appends `rhs` to this trajectory in place, with the same semantics as `&self + rhs`.
    ///
    /// # Panics
    /// Panics if the addition fails, i.e. if the two trajectories are in different frames.
    fn add_assign(&mut self, rhs: &Self) {
        // Borrow instead of cloning both trajectories just to call the owned `Add`.
        *self = (&*self + rhs).expect("cannot add trajectories expressed in different frames");
    }
}
impl<S: Interpolatable> fmt::Display for Traj<S>
where
    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
    /// Writes a one-line summary of the trajectory, or `Empty Trajectory!` when it has no
    /// states. The summary covers the optional name, the frame, the first and last epochs,
    /// the duration and the number of states.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.states.is_empty() {
            return write!(f, "Empty Trajectory!");
        }

        let first = self.first();
        let last = self.last();
        let span = last.epoch() - first.epoch();
        let name_part = self
            .name
            .as_ref()
            .map(|name| format!("of {name} "))
            .unwrap_or_default();

        write!(
            f,
            "Trajectory {}in {} from {} to {} ({}, or {:.3} s) [{} states]",
            name_part,
            first.frame(),
            first.epoch(),
            last.epoch(),
            span,
            span.to_seconds(),
            self.states.len()
        )
    }
}
impl<S: Interpolatable> fmt::Debug for Traj<S>
where
    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
    /// Debug output is the same one-line summary as the `Display` implementation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}
impl<S: Interpolatable> Default for Traj<S>
where
DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
fn default() -> Self {
Self::new()
}
}