Skip to main content

nyx_space/mc/
montecarlo.rs

1/*
2    Nyx, blazing fast astrodynamics
3    Copyright (C) 2018-onwards Christopher Rabotin <christopher.rabotin@gmail.com>
4
5    This program is free software: you can redistribute it and/or modify
6    it under the terms of the GNU Affero General Public License as published
7    by the Free Software Foundation, either version 3 of the License, or
8    (at your option) any later version.
9
10    This program is distributed in the hope that it will be useful,
11    but WITHOUT ANY WARRANTY; without even the implied warranty of
12    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13    GNU Affero General Public License for more details.
14
15    You should have received a copy of the GNU Affero General Public License
16    along with this program.  If not, see <https://www.gnu.org/licenses/>.
17*/
18
19use super::Pcg64Mcg;
20use crate::dynamics::Dynamics;
21use crate::linalg::allocator::Allocator;
22use crate::linalg::DefaultAllocator;
23use crate::mc::results::{PropResult, Results, Run};
24use crate::mc::DispersedState;
25use crate::md::trajectory::Interpolatable;
26use crate::propagators::Propagator;
27#[cfg(not(target_arch = "wasm32"))]
28use crate::time::Unit;
29use crate::time::{Duration, Epoch};
30use crate::State;
31use anise::almanac::Almanac;
32use anise::analysis::event::Event;
33use indicatif::{ParallelProgressIterator, ProgressBar, ProgressStyle};
34use log::info;
35use rand::SeedableRng;
36use rand_distr::Distribution;
37use rayon::prelude::ParallelIterator;
38use rayon::prelude::*;
39use std::fmt;
40use std::sync::mpsc::channel;
41use std::sync::Arc;
42#[cfg(not(target_arch = "wasm32"))]
43use std::time::Instant as StdInstant;
44
45/// A Monte Carlo framework, automatically running on all threads via a thread pool. This framework is targeted toward analysis of time-continuous variables.
46/// One caveat of the design is that the trajectory is used for post processing, not each individual state. This may prevent some event switching from being shown in GNC simulations.
47pub struct MonteCarlo<S: Interpolatable, Distr: Distribution<DispersedState<S>>>
48where
49    DefaultAllocator: Allocator<S::Size> + Allocator<S::Size, S::Size> + Allocator<S::VecLength>,
50{
51    /// Seed of the [64bit PCG random number generator](https://www.pcg-random.org/index.html)
52    pub seed: Option<u128>,
53    /// Generator of states for the Monte Carlo run
54    pub random_state: Distr,
55    /// Name of this run, will be reflected in the progress bar and in the output structure
56    pub scenario: String,
57    pub nominal_state: S,
58}
59
60impl<S: Interpolatable, Distr: Distribution<DispersedState<S>>> MonteCarlo<S, Distr>
61where
62    DefaultAllocator: Allocator<S::Size> + Allocator<S::Size, S::Size> + Allocator<S::VecLength>,
63{
64    pub fn new(
65        nominal_state: S,
66        random_variable: Distr,
67        scenario: String,
68        seed: Option<u128>,
69    ) -> Self {
70        Self {
71            random_state: random_variable,
72            seed,
73            scenario,
74            nominal_state,
75        }
76    }
77    // Just the template for the progress bar
78    fn progress_bar(&self, num_runs: usize) -> ProgressBar {
79        let pb = ProgressBar::new(num_runs.try_into().unwrap());
80        pb.set_style(
81            ProgressStyle::default_bar()
82                .template("[{elapsed_precise}] {bar:100.cyan/blue} {pos:>7}/{len:7} {msg}")
83                .unwrap()
84                .progress_chars("##-"),
85        );
86        pb.set_message(format!("{self}"));
87        pb
88    }
89
90    /// Generate states and propagate each independently until a specific event is found `trigger` times.
91    #[allow(clippy::needless_lifetimes)]
92    pub fn run_until_nth_event<D, F>(
93        self,
94        prop: Propagator<D>,
95        almanac: Arc<Almanac>,
96        max_duration: Duration,
97        event: &Event,
98        trigger: usize,
99        num_runs: usize,
100    ) -> Results<S, PropResult<S>>
101    where
102        D: Dynamics<StateType = S>,
103        DefaultAllocator: Allocator<<D::StateType as State>::Size>
104            + Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>
105            + Allocator<<D::StateType as State>::VecLength>,
106        <DefaultAllocator as Allocator<<D::StateType as State>::VecLength>>::Buffer<f64>: Send,
107    {
108        self.resume_run_until_nth_event(prop, almanac, 0, max_duration, event, trigger, num_runs)
109    }
110
111    /// Generate states and propagate each independently until a specific event is found `trigger` times.
112    #[must_use = "Monte Carlo result must be used"]
113    #[allow(clippy::needless_lifetimes)]
114    pub fn resume_run_until_nth_event<D>(
115        &self,
116        prop: Propagator<D>,
117        almanac: Arc<Almanac>,
118        skip: usize,
119        max_duration: Duration,
120        event: &Event,
121        trigger: usize,
122        num_runs: usize,
123    ) -> Results<S, PropResult<S>>
124    where
125        D: Dynamics<StateType = S>,
126        DefaultAllocator: Allocator<<D::StateType as State>::Size>
127            + Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>
128            + Allocator<<D::StateType as State>::VecLength>,
129        <DefaultAllocator as Allocator<<D::StateType as State>::VecLength>>::Buffer<f64>: Send,
130    {
131        // Generate the initial states
132        let init_states = self.generate_states(skip, num_runs, self.seed);
133        // Setup the progress bar
134        let pb = self.progress_bar(num_runs);
135        // Setup the thread friendly communication
136        let (tx, rx) = channel();
137
138        // Generate all states (must be done separately because the rng is not thread safe)
139        #[cfg(not(target_arch = "wasm32"))]
140        let start = StdInstant::now();
141
142        init_states.par_iter().progress_with(pb).for_each_with(
143            (prop, tx),
144            |(prop, tx), (index, dispersed_state)| {
145                let result = prop
146                    .with(dispersed_state.state, almanac.clone())
147                    .until_nth_event(max_duration, event, None, trigger);
148
149                // Build a single run result
150                let run = Run {
151                    index: *index,
152                    dispersed_state: dispersed_state.clone(),
153                    result: result.map(|r| PropResult {
154                        state: r.0,
155                        traj: r.1,
156                    }),
157                };
158                tx.send(run).unwrap();
159            },
160        );
161
162        #[cfg(not(target_arch = "wasm32"))]
163        {
164            let clock_time = StdInstant::now() - start;
165            info!(
166                "Propagated {} states in {}",
167                num_runs,
168                clock_time.as_secs_f64() * Unit::Second
169            );
170        }
171
172        // Collect all of the results and sort them by run index
173        let mut runs = rx
174            .iter()
175            .collect::<Vec<Run<D::StateType, PropResult<D::StateType>>>>();
176        runs.par_sort_by_key(|run| run.index);
177
178        Results {
179            runs,
180            scenario: self.scenario.clone(),
181        }
182    }
183
184    /// Generate states and propagate each independently until a specific event is found `trigger` times.
185    #[must_use = "Monte Carlo result must be used"]
186    #[allow(clippy::needless_lifetimes)]
187    pub fn run_until_epoch<D>(
188        self,
189        prop: Propagator<D>,
190        almanac: Arc<Almanac>,
191        end_epoch: Epoch,
192        num_runs: usize,
193    ) -> Results<S, PropResult<S>>
194    where
195        D: Dynamics<StateType = S>,
196
197        DefaultAllocator: Allocator<<D::StateType as State>::Size>
198            + Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>
199            + Allocator<<D::StateType as State>::VecLength>,
200        <DefaultAllocator as Allocator<<D::StateType as State>::VecLength>>::Buffer<f64>: Send,
201    {
202        self.resume_run_until_epoch(prop, almanac, 0, end_epoch, num_runs)
203    }
204
205    /// Resumes a Monte Carlo run by skipping the first `skip` items, generating states only after that, and propagate each independently until the specified epoch.
206    #[must_use = "Monte Carlo result must be used"]
207    #[allow(clippy::needless_lifetimes)]
208    pub fn resume_run_until_epoch<D>(
209        &self,
210        prop: Propagator<D>,
211        almanac: Arc<Almanac>,
212        skip: usize,
213        end_epoch: Epoch,
214        num_runs: usize,
215    ) -> Results<S, PropResult<S>>
216    where
217        D: Dynamics<StateType = S>,
218
219        DefaultAllocator: Allocator<<D::StateType as State>::Size>
220            + Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>
221            + Allocator<<D::StateType as State>::VecLength>,
222        <DefaultAllocator as Allocator<<D::StateType as State>::VecLength>>::Buffer<f64>: Send,
223    {
224        // Generate the initial states
225        let init_states = self.generate_states(skip, num_runs, self.seed);
226        // Setup the progress bar
227        let pb = self.progress_bar(num_runs);
228        // Setup the thread friendly communication
229        let (tx, rx) = channel();
230
231        // And propagate on the thread pool
232        #[cfg(not(target_arch = "wasm32"))]
233        let start = StdInstant::now();
234        init_states.par_iter().progress_with(pb).for_each_with(
235            (prop, tx),
236            |(arc_prop, tx), (index, dispersed_state)| {
237                let result = arc_prop
238                    .with(dispersed_state.state, almanac.clone())
239                    .quiet()
240                    .until_epoch_with_traj(end_epoch);
241
242                // Build a single run result
243                let run = Run {
244                    index: *index,
245                    dispersed_state: dispersed_state.clone(),
246                    result: result.map(|r| PropResult {
247                        state: r.0,
248                        traj: r.1,
249                    }),
250                };
251
252                tx.send(run).unwrap();
253            },
254        );
255
256        #[cfg(not(target_arch = "wasm32"))]
257        {
258            let clock_time = StdInstant::now() - start;
259            info!(
260                "Propagated {} states in {}",
261                num_runs,
262                clock_time.as_secs_f64() * Unit::Second
263            );
264        }
265
266        // Collect all of the results and sort them by run index
267        let mut runs = rx.iter().collect::<Vec<Run<S, PropResult<S>>>>();
268        runs.par_sort_by_key(|run| run.index);
269
270        Results {
271            runs,
272            scenario: self.scenario.clone(),
273        }
274    }
275
276    /// Set up the seed and generate the states. This is useful for checking the generated states before running a large scale Monte Carlo.
277    #[must_use = "Generated states for a Monte Carlo run must be used"]
278    pub fn generate_states(
279        &self,
280        skip: usize,
281        num_runs: usize,
282        seed: Option<u128>,
283    ) -> Vec<(usize, DispersedState<S>)> {
284        // Setup the RNG
285        let rng = match seed {
286            Some(seed) => Pcg64Mcg::new(seed),
287            None => Pcg64Mcg::from_os_rng(),
288        };
289
290        // Generate the states, forcing the borrow as specified in the `sample_iter` docs.
291        (&self.random_state)
292            .sample_iter(rng)
293            .skip(skip)
294            .take(num_runs)
295            .enumerate()
296            .collect::<Vec<(usize, DispersedState<S>)>>()
297    }
298}
299
300impl<S: Interpolatable, Distr: Distribution<DispersedState<S>>> fmt::Display
301    for MonteCarlo<S, Distr>
302where
303    DefaultAllocator: Allocator<S::Size> + Allocator<S::Size, S::Size> + Allocator<S::VecLength>,
304{
305    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
306        write!(
307            f,
308            "{} - Nyx Monte Carlo - seed: {:?}",
309            self.scenario, self.seed
310        )
311    }
312}
313
314impl<S: Interpolatable, Distr: Distribution<DispersedState<S>>> fmt::LowerHex
315    for MonteCarlo<S, Distr>
316where
317    DefaultAllocator: Allocator<S::Size> + Allocator<S::Size, S::Size> + Allocator<S::VecLength>,
318{
319    /// Returns a filename friendly name
320    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
321        write!(
322            f,
323            "mc-data-{}-seed-{:?}",
324            self.scenario.replace(' ', "-"),
325            self.seed
326        )
327    }
328}