use std::collections::HashMap;
use std::error::Error;
use std::fs::File;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use crate::errors::{MonteCarloError, NoSuccessfulRunsSnafu, StateError};
use crate::io::watermark::pq_writer;
use crate::io::{ExportCfg, InputOutputError};
use crate::linalg::allocator::Allocator;
use crate::linalg::DefaultAllocator;
use crate::md::prelude::GuidanceMode;
use crate::md::trajectory::{Interpolatable, Traj};
use crate::md::{EventEvaluator, StateParameter};
use crate::propagators::PropagationError;
use crate::time::{Duration, Epoch, TimeUnits};
use anise::almanac::Almanac;
use anise::constants::frames::EARTH_J2000;
use arrow::array::{Array, Float64Builder, Int32Builder, StringBuilder};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use hifitime::TimeScale;
use parquet::arrow::ArrowWriter;
pub use rstats::Stats;
use snafu::ensure;
use super::DispersedState;
/// Outcome of a single Monte Carlo run: which run it was, the dispersed state it used,
/// and whether it succeeded.
///
/// NOTE(review): the allocator bounds below presumably mirror those required by
/// `DispersedState<S>`; keep them in sync.
pub struct Run<S: Interpolatable, R>
where
DefaultAllocator: Allocator<S::Size> + Allocator<S::Size, S::Size> + Allocator<S::VecLength>,
<DefaultAllocator as Allocator<S::VecLength>>::Buffer<f64>: Send,
{
/// Index of this run; used to label the run in report warnings and in exported data.
pub index: usize,
/// The dispersed state used for this run, including the dispersions that were actually
/// applied (`actual_dispersions`). NOTE(review): presumably the run's initial state — confirm.
pub dispersed_state: DispersedState<S>,
/// `Ok(R)` if the run completed, otherwise the propagation error that caused it to fail.
pub result: Result<R, PropagationError>,
}
/// Collection of all runs of a Monte Carlo scenario, both successful and failed.
///
/// NOTE(review): the allocator bounds below presumably mirror those required by `Run<S, R>`;
/// keep them in sync.
pub struct Results<S: Interpolatable, R>
where
DefaultAllocator: Allocator<S::Size> + Allocator<S::Size, S::Size> + Allocator<S::VecLength>,
<DefaultAllocator as Allocator<S::VecLength>>::Buffer<f64>: Send,
{
/// Every run of the scenario; failed runs are kept (with an `Err` result) rather than dropped.
pub runs: Vec<Run<S, R>>,
/// Label of the scenario these results belong to.
pub scenario: String,
}
/// Successful output of a propagation-based Monte Carlo run.
///
/// NOTE(review): the allocator bounds below presumably mirror those required by `Traj<S>`;
/// keep them in sync.
pub struct PropResult<S: Interpolatable>
where
DefaultAllocator: Allocator<S::Size> + Allocator<S::Size, S::Size> + Allocator<S::VecLength>,
<DefaultAllocator as Allocator<S::VecLength>>::Buffer<f64>: Send,
{
/// State returned by the propagation. NOTE(review): presumably the final state; its epoch
/// is used as the default end of the export window.
pub state: S,
/// Trajectory recorded during the propagation.
pub traj: Traj<S>,
}
impl<S: Interpolatable> Results<S, PropResult<S>>
where
DefaultAllocator: Allocator<S::Size> + Allocator<S::Size, S::Size> + Allocator<S::VecLength>,
<DefaultAllocator as Allocator<S::VecLength>>::Buffer<f64>: Send,
{
pub fn every_value_of_between(
&self,
param: StateParameter,
step: Duration,
start: Epoch,
end: Epoch,
value_if_run_failed: Option<f64>,
) -> Vec<f64> {
let mut report = Vec::with_capacity(self.runs.len());
for run in &self.runs {
match &run.result {
Ok(r) => {
for state in r.traj.every_between(step, start, end) {
match state.value(param) {
Ok(val) => report.push(val),
Err(e) => match value_if_run_failed {
Some(val) => report.push(val),
None => {
warn!("run #{}: {}, skipping {} in report", run.index, e, param)
}
},
}
}
}
Err(e) => match value_if_run_failed {
Some(val) => report.push(val),
None => warn!(
"run #{} failed with {}, skipping {} in report",
run.index, e, param
),
},
}
}
report
}
pub fn every_value_of(
&self,
param: StateParameter,
step: Duration,
value_if_run_failed: Option<f64>,
) -> Vec<f64> {
let mut report = Vec::with_capacity(self.runs.len());
for run in &self.runs {
match &run.result {
Ok(r) => {
for state in r.traj.every(step) {
match state.value(param) {
Ok(val) => report.push(val),
Err(e) => match value_if_run_failed {
Some(val) => report.push(val),
None => {
warn!("run #{}: {}, skipping {} in report", run.index, e, param)
}
},
}
}
}
Err(e) => match value_if_run_failed {
Some(val) => report.push(val),
None => warn!(
"run #{} failed with {}, skipping {} in report",
run.index, e, param
),
},
}
}
report
}
pub fn first_values_of(
&self,
param: StateParameter,
value_if_run_failed: Option<f64>,
) -> Vec<f64> {
let mut report = Vec::with_capacity(self.runs.len());
for run in &self.runs {
match &run.result {
Ok(r) => match r.traj.first().value(param) {
Ok(val) => report.push(val),
Err(e) => match value_if_run_failed {
Some(val) => report.push(val),
None => {
warn!("run #{}: {}, skipping {} in report", run.index, e, param)
}
},
},
Err(e) => match value_if_run_failed {
Some(val) => report.push(val),
None => warn!(
"run #{} failed with {}, skipping {} in report",
run.index, e, param
),
},
}
}
report
}
pub fn last_values_of(
&self,
param: StateParameter,
value_if_run_failed: Option<f64>,
) -> Vec<f64> {
let mut report = Vec::with_capacity(self.runs.len());
for run in &self.runs {
match &run.result {
Ok(r) => match r.traj.last().value(param) {
Ok(val) => report.push(val),
Err(e) => match value_if_run_failed {
Some(val) => report.push(val),
None => {
warn!("run #{}: {}, skipping {} in report", run.index, e, param)
}
},
},
Err(e) => match value_if_run_failed {
Some(val) => report.push(val),
None => warn!(
"run #{} failed with {}, skipping {} in report",
run.index, e, param
),
},
}
}
report
}
pub fn dispersion_values_of(&self, param: StateParameter) -> Result<Vec<f64>, MonteCarloError> {
let mut report = Vec::with_capacity(self.runs.len());
'run_loop: for run in &self.runs {
for (dparam, val) in &run.dispersed_state.actual_dispersions {
if dparam == ¶m {
report.push(*val);
continue 'run_loop;
}
}
return Err(MonteCarloError::StateError {
source: StateError::Unavailable { param },
});
}
Ok(report)
}
pub fn to_parquet<P: AsRef<Path>>(
&self,
path: P,
events: Option<Vec<&dyn EventEvaluator<S>>>,
cfg: ExportCfg,
almanac: Arc<Almanac>,
) -> Result<PathBuf, Box<dyn Error>> {
let tick = Epoch::now().unwrap();
info!("Exporting Monte Carlo results to parquet file...");
let path_buf = cfg.actual_path(path);
let mut hdrs = vec![
Field::new("Epoch (UTC)", DataType::Utf8, false),
Field::new("Monte Carlo Run Index", DataType::Int32, false),
];
let mut frame = EARTH_J2000;
let mut fields = match cfg.fields {
Some(fields) => fields,
None => S::export_params(),
};
let mut start = None;
let mut end = None;
let mut all_states: Vec<S> = vec![];
let mut run_indexes: Vec<i32> = vec![];
for run in &self.runs {
if let Ok(success) = &run.result {
if start.is_none() {
frame = success.state.frame();
fields.retain(|param| success.state.value(*param).is_ok());
start = Some(success.traj.first().epoch());
end = Some(success.state.epoch());
}
let states =
if cfg.start_epoch.is_some() || cfg.end_epoch.is_some() || cfg.step.is_some() {
let start = cfg.start_epoch.unwrap_or_else(|| start.unwrap());
let end = cfg.end_epoch.unwrap_or_else(|| end.unwrap());
let step = cfg.step.unwrap_or_else(|| 1.minutes());
success
.traj
.every_between(step, start, end)
.collect::<Vec<S>>()
} else {
success.traj.states.to_vec()
};
for _ in 0..states.len() {
run_indexes.push(run.index as i32);
}
all_states.extend(states.iter());
}
}
ensure!(
start.is_some(),
NoSuccessfulRunsSnafu {
action: "export",
num_runs: self.runs.len()
}
);
let more_meta = Some(vec![(
"Frame".to_string(),
serde_dhall::serialize(&frame).to_string().map_err(|e| {
Box::new(InputOutputError::SerializeDhall {
what: format!("frame `{frame}`"),
err: e.to_string(),
})
})?,
)]);
for field in &fields {
hdrs.push(field.to_field(more_meta.clone()));
}
if let Some(events) = events.as_ref() {
for event in events {
let field = Field::new(format!("{event}"), DataType::Float64, false);
hdrs.push(field);
}
}
let schema = Arc::new(Schema::new(hdrs));
let mut record: Vec<Arc<dyn Array>> = Vec::new();
let mut utc_epoch = StringBuilder::new();
let mut idx_col = Int32Builder::new();
for (sno, s) in all_states.iter().enumerate() {
utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
idx_col.append_value(run_indexes[sno]);
}
record.push(Arc::new(utc_epoch.finish()));
record.push(Arc::new(idx_col.finish()));
for field in fields {
if field == StateParameter::GuidanceMode {
let mut guid_mode = StringBuilder::new();
for s in &all_states {
guid_mode
.append_value(format!("{:?}", GuidanceMode::from(s.value(field).unwrap())));
}
record.push(Arc::new(guid_mode.finish()));
} else {
let mut data = Float64Builder::new();
for s in &all_states {
data.append_value(s.value(field).unwrap());
}
record.push(Arc::new(data.finish()));
}
}
info!(
"Serialized {} states from {} to {}",
all_states.len(),
start.unwrap(),
end.unwrap()
);
if let Some(events) = events {
info!("Evaluating {} event(s)", events.len());
for event in events {
let mut data = Float64Builder::new();
for s in &all_states {
data.append_value(event.eval(s, almanac.clone()).map_err(Box::new)?);
}
record.push(Arc::new(data.finish()));
}
}
let mut metadata = HashMap::new();
metadata.insert(
"Purpose".to_string(),
"Monte Carlo Trajectory data".to_string(),
);
if let Some(add_meta) = cfg.metadata {
for (k, v) in add_meta {
metadata.insert(k, v);
}
}
let props = pq_writer(Some(metadata));
let file = File::create(&path_buf)?;
let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
let batch = RecordBatch::try_new(schema, record)?;
writer.write(&batch)?;
writer.close()?;
let tock_time = Epoch::now().unwrap() - tick;
info!(
"Trajectory written to {} in {tock_time}",
path_buf.display()
);
Ok(path_buf)
}
}