#![allow(non_upper_case_globals)]
#![allow(non_snake_case)]

use average::MeanWithError;
use nalgebra::*;
use rand::prelude::*;
use rayon::prelude::*;

use crate::gradient::{gd::GD, saga::SAGA, sgd::SGD, svrg2::SVRG2};
use crate::gradient::{FixedBatchGradientOracle, GradientEstimator};
use crate::model::gaussian_model::GaussianModel;
use crate::tool::simple_type_name;
use crate::uld::info::BasicInfoBuilder;
use crate::uld::multiple::{LPMLikeRMMMS, ULDProcessNewMS, OMMS, RMMMS};
use crate::uld::{optimal::OM, rmm::RMM};
use crate::uld::{FixedBatchSeq, ULDProcess, ULDProcessNew};

mod gaussian_model_parameters {
    //! Parameters shared by the Gaussian-model scaling experiments in this file.

    // The lower-case constant names (`gamma`, `uldp`, ...) are intentional.
    #![allow(non_upper_case_globals)]

    use crate::model::gaussian_model::GaussianModel;
    use crate::uld::ULDParam;
    use nalgebra::*;

    /// Scalar type used by every experiment.
    pub type DType = f64;
    /// Dimension of the state vectors (statically 5).
    pub type D = Const<5>;
    /// Vector type used for both the position `x` and the velocity `v`.
    pub type State = OVector<DType, D>;
    /// The Gaussian model instantiated at `DType` / `D`.
    pub type Model = GaussianModel<DType, D>;

    /// Seed for the `StdRng` each experiment starts from.
    pub const SEED: u64 = 123;

    /// Passed to `GaussianModel::init` and reused as the ULD parameter `xi`.
    pub const L: DType = 10.0;
    /// `gamma` field of [`uldp`].
    pub const gamma: DType = 2.0;
    /// `temperature` field of [`uldp`].
    pub const temperature: DType = 1.0;
    /// ULD parameters used by every process: `xi = L`, plus `temperature` and `gamma`.
    pub const uldp: ULDParam<DType> = ULDParam {
        xi: L,
        temperature,
        gamma,
    };

    /// Data-size argument passed to `GaussianModel::init`.
    pub const data_size: usize = 100;
    /// Batch size given to the gradient estimators and their fixed-batch oracles.
    pub const batch_size: usize = 20;
}
use gaussian_model_parameters::{DType, State};

/// Builds `(name, iterator)` pairs that drive a ULD process with a
/// fixed-batch gradient oracle.
///
/// Forms:
/// - `make_iter!(ALG, single, (PROC, INFO, x0, v0, uldp, step, model, batch_size, rng))`
///   constructs the process with `PROC::new(step, uldp, rng.clone())`.
/// - `make_iter!(ALG, single, (PROC, INFO, x0, v0, uldp, segments, step, model, batch_size, rng))`
///   constructs the process with `PROC::new(segments, step, uldp, rng.clone())`.
///   The two `single` arms are told apart purely by argument count (9 vs 10).
/// - `make_iter![[ALG1, ALG2, ...], (params...)]` expands to a `Vec` holding
///   one `single` pair per estimator, all built from the same parameter tuple.
///
/// In every form, the estimator is `ALG::new(&x0, &model, rng.clone())`. Its batch
/// size is set to `batch_size` and it is wrapped in a `FixedBatchGradientOracle`.
/// The iterator is a `FixedBatchSeq` starting at `(x0, v0)`, with `INFO` as its
/// phantom info type, boxed as `dyn Iterator + Send`. `name` is the estimator's
/// `GradientEstimator::name`.
///
/// `rng` is only ever cloned, never advanced. The estimator, the oracle and the
/// process each receive their own clone.
///
/// The expansion names `State`, `Model`, `GradientEstimator`,
/// `FixedBatchGradientOracle` and `FixedBatchSeq` without qualification, so
/// those names must be in scope where the macro is invoked.
macro_rules! make_iter {
  // Single estimator, single-step process: `$ULDProcess::new(step, uldp, rng)`.
  ($ALG:tt, single, ($ULDProcess:tt, $INFO:ty, $initial_x:tt, $initial_v:tt, $uldp:tt, $step:tt, $model:tt, $batch_size:tt, $rng:tt)) => {
    {
      let mut ge = $ALG::new(&$initial_x, &$model, $rng.clone());
      ge.set_batch_size($batch_size, &$model);
      let name = GradientEstimator::<State, Model>::name(&ge);
      let go = FixedBatchGradientOracle::new(&$model, $rng.clone(), $batch_size, ge);

      let uldprocess = $ULDProcess::new(($step), $uldp, $rng.clone());

      use std::marker::PhantomData;
      let iter = FixedBatchSeq {
          x: $initial_x,
          v: $initial_v,
          uldprocess,
          phantom:PhantomData::<$INFO>,
          go,
      };
      // Boxed so that iterators built with different estimator types have one
      // common type (required by the list form's `vec!`).
      let iter2: Box<dyn Iterator<Item = _>+Send> = Box::new(iter);
      (name,iter2)
    }
  };
  // Single estimator, multi-segment process: `$ULDProcess::new(segments, step, uldp, rng)`.
  ($ALG:tt, single, ($ULDProcess:tt, $INFO:ty, $initial_x:tt, $initial_v:tt, $uldp:tt, $segments:tt, $step:tt, $model:tt, $batch_size:tt, $rng:tt)) => {
    {
      let mut ge = $ALG::new(&$initial_x, &$model, $rng.clone());
      ge.set_batch_size($batch_size, &$model);
      let name = GradientEstimator::<State, Model>::name(&ge);
      let go = FixedBatchGradientOracle::new(&$model, $rng.clone(), $batch_size, ge);

      let uldprocess = $ULDProcess::new($segments, ($step), $uldp, $rng.clone());

      use std::marker::PhantomData;
      let iter = FixedBatchSeq {
          x: $initial_x,
          v: $initial_v,
          uldprocess,
          phantom:PhantomData::<$INFO>,
          go,
      };
      // Boxed so that iterators built with different estimator types have one
      // common type (required by the list form's `vec!`).
      let iter2: Box<dyn Iterator<Item = _>+Send> = Box::new(iter);
      (name,iter2)
    }
  };
  // List form: one `single` expansion per estimator, collected into a Vec.
  [[$($ALGS:tt),*], $params:tt] => {
    vec![$(make_iter!($ALGS, single, $params)),*]
  };
}

/// Explores how many segments are enough for the multi-segment `RMMMS`
/// integrator, by running `_gaussian_model_plot_scale3_data::<RMMMS<_>>`.
///
/// # Errors
/// Propagates any error returned by the underlying experiment instead of
/// panicking on it.
pub fn gaussian_model_plot_scale3_data() -> anyhow::Result<()> {
    println!("RMM");
    // The experiment already parallelises internally via rayon, so a scope
    // wrapping a single spawned task adds nothing; call it directly so its
    // `Result` reaches our caller.
    _gaussian_model_plot_scale3_data::<RMMMS<_>>()
}
/// Segment-count sweep for the multi-segment integrator `U`.
///
/// For every segment count `segments` in `2..=100`, in parallel, this builds
/// two trajectories from the same initial state and RNG state, both using the
/// full-batch `GD` estimator:
/// - `U` constructed with `segments` and step `step_var / segments`;
/// - an `OM` reference with step `step_var / segments`. The reference is
///   sampled every `segments` iterations (starting at iteration
///   `segments - 1`). Presumably one `U` iteration spans `segments` reference
///   steps; TODO confirm.
///
/// It then records the mean and standard error of the Euclidean distance
/// between the paired `(x, v)` states over `total_iterations_var` samples.
///
/// Results go to `out/<U>_GD_scale3`, and the `out` directory is created if
/// missing. Each line has the form `segments,mean,error`, and lines are
/// sorted by `segments`.
///
/// # Errors
/// Returns an error if the result mutex is poisoned, or if the output
/// directory or file cannot be created, written or flushed.
pub fn _gaussian_model_plot_scale3_data<U: ULDProcessNewMS<DType> + ULDProcess<State> + Send>() -> anyhow::Result<()> {
    use gaussian_model_parameters::*;
    use std::collections::HashMap;
    use std::fs::{self, File};
    use std::io::{BufWriter, Write};
    use std::sync::Mutex;

    let step_var = 0.1;
    let burnin_var = 0;
    // Use a numerator of 10_000.0 instead for a quick smoke run.
    let total_iterations_var = (10_000_000.0 / (L * step_var)) as usize;
    let max_segments = 100;

    let result = Mutex::new(HashMap::new());

    let mut rng = StdRng::seed_from_u64(SEED);
    let model = GaussianModel::init(&mut rng, data_size, L);
    let initial_x: State = model.init_state(&mut rng);
    let initial_v = State::zeros();

    (2..=max_segments).into_par_iter().for_each(|segments| {
        let segments = segments as usize;
        let (_name, it) = make_iter![
            GD,
            single,
            (
                U,
                BasicInfoBuilder::<_>,
                initial_x,
                initial_v,
                uldp,
                segments,
                (step_var / segments as f64),
                model,
                batch_size,
                rng
            )
        ];
        // Fine-step reference trajectory (RMM and LPMLikeRMM are drop-in
        // alternatives for OM here).
        let (_, ref_it) = make_iter![
            GD,
            single,
            (
                OM,
                BasicInfoBuilder::<_>,
                initial_x,
                initial_v,
                uldp,
                (step_var / (segments as f64)),
                model,
                batch_size,
                rng
            )
        ];
        let values: MeanWithError = it
            .zip(ref_it.skip(segments - 1).step_by(segments))
            .skip(burnin_var)
            .take(total_iterations_var)
            .map(|((x, v, _, _), (ref_x, ref_v, _, _))| {
                let x_diff = x - ref_x;
                let v_diff = v - ref_v;
                (x_diff.norm_squared() + v_diff.norm_squared()).sqrt()
            })
            .collect();
        result
            .lock()
            .expect("result mutex poisoned")
            .entry("GD")
            .or_insert_with(Vec::new)
            .push([segments as f64, values.mean(), values.error()]);
    });

    let result = result.into_inner()?;
    // NOTE(review): this looks equivalent to `simple_type_name::<U>()` (used by
    // the scale4 experiment). Confirm, then deduplicate.
    let uname = std::any::type_name::<U>()
        .split('<')
        .next()
        .unwrap()
        .rsplit(':')
        .next()
        .unwrap();

    fs::create_dir_all("out")?;
    for (name, mut rows) in result {
        // Rows were pushed by parallel workers in arbitrary order; sort them so
        // the output file is deterministic and ordered by segment count.
        rows.sort_by(|a, b| a[0].partial_cmp(&b[0]).expect("segment counts are finite"));
        let mut file = BufWriter::new(File::create(format!("out/{}_{}_scale3", uname, name))?);
        for [segments, mean, error] in rows {
            writeln!(file, "{},{},{}", segments, mean, error)?;
        }
        // Flush explicitly: dropping a BufWriter discards any flush error.
        file.flush()?;
    }

    Ok(())
}

/// Runs the step-size scaling experiment (trajectory mean error vs. step size)
/// concurrently for the `RMMMS`, `OMMS` and `LPMLikeRMMMS` integrators.
///
/// # Errors
/// All three runs are allowed to finish. If any of them fails, the first error
/// in the order RMM, OM, LPMLikeRMM is returned, instead of panicking as
/// before.
pub fn gaussian_model_plot_scale4_data() -> anyhow::Result<()> {
    let (rmm, (om, lpm_like_rmm)) = rayon::join(
        || {
            println!("RMM");
            _gaussian_model_plot_scale4_data::<RMMMS<_>>()
        },
        || {
            rayon::join(
                || {
                    println!("OM");
                    _gaussian_model_plot_scale4_data::<OMMS<_>>()
                },
                || {
                    println!("LPMLikeRMM");
                    _gaussian_model_plot_scale4_data::<LPMLikeRMMMS<_>>()
                },
            )
        },
    );
    rmm?;
    om?;
    lpm_like_rmm?;
    Ok(())
}
/// Step-size sweep of the trajectory error for the multi-segment integrator `U`.
///
/// The sweep covers `step_num` step sizes, geometrically spaced from
/// `min_step` to `max_step`. For each step size and each estimator (`GD`,
/// `SGD`, `SAGA`, `SVRG2`), it runs, in parallel:
/// - `U` constructed with `segments` and step `step_var / segments`;
/// - a full-batch `GD` reference `RMM` trajectory with step
///   `step_var / segments`, sampled every `segments` iterations (starting at
///   iteration `segments - 1`).
///
/// For each pair it records:
/// - the mean and standard error of the Euclidean distance between the paired
///   `(x, v)` states;
/// - the last value of the fourth tuple field (`num`) yielded by `U`'s
///   iterator. This is presumably a cumulative gradient-evaluation count;
///   TODO confirm.
///
/// Results go to `out/<U>_<estimator>_scale4`, and the `out` directory is
/// created if missing. Each line has the form `step,gradient_num,mean,error`,
/// and lines are sorted by step size.
///
/// # Errors
/// Returns an error if the result mutex is poisoned, or if the output
/// directory or file cannot be created, written or flushed.
pub fn _gaussian_model_plot_scale4_data<U: ULDProcessNewMS<DType> + ULDProcess<State> + Send>() -> anyhow::Result<()> {
    use gaussian_model_parameters::*;
    use std::collections::HashMap;
    use std::fs::{self, File};
    use std::io::{BufWriter, Write};
    use std::sync::Mutex;

    let segments = 10;
    let step_num = 30;
    let max_step = 1e-1;
    // Lower end of the sweep. 1e-4 gives a cheaper run, since the iteration
    // count scales as 1 / step.
    let min_step = 1e-5;
    let step_range = (0..step_num)
        .map(|p| (p as f64) / ((step_num - 1) as f64))
        .map(|p| min_step * (max_step / min_step).powf(p))
        .collect::<Vec<_>>();
    // Burn-in is currently zero for every step size. It is kept in the same
    // `horizon / (L * step)` form as the iteration count so it can be re-enabled.
    let burnin_fn = |step_var| (0.0 / (L * step_var)) as usize;
    // Sample count such that samples * step_var == 10 / L for every step size.
    let total_iterations_fn = |step_var| (10.0 / (L * step_var)) as usize;

    let result = Mutex::new(HashMap::new());

    let mut rng = StdRng::seed_from_u64(SEED);
    let model = GaussianModel::init(&mut rng, data_size, L);
    let initial_x: State = model.init_state(&mut rng);
    let initial_v = State::zeros();

    step_range.into_par_iter().for_each(|step_var| {
        let burnin_var = burnin_fn(step_var);
        let total_iterations_var = total_iterations_fn(step_var);
        let iters = make_iter![
            [GD, SGD, SAGA, SVRG2],
            (
                U,
                BasicInfoBuilder::<_>,
                initial_x,
                initial_v,
                uldp,
                segments,
                (step_var / segments as f64),
                model,
                batch_size,
                rng
            )
        ];
        iters.into_par_iter().for_each(|(name, it)| {
            let (_, ref_it) = make_iter![
                GD,
                single,
                (
                    RMM,
                    BasicInfoBuilder::<_>,
                    initial_x,
                    initial_v,
                    uldp,
                    (step_var / segments as f64),
                    model,
                    batch_size,
                    rng
                )
            ];
            // Overwritten on every sample, so it ends up holding the last `num`.
            let mut gradient_num = 0;
            let values: MeanWithError = it
                .zip(ref_it.skip(segments - 1).step_by(segments))
                .skip(burnin_var)
                .take(total_iterations_var)
                .map(|((x, v, _, num), (ref_x, ref_v, _, _))| {
                    let x_diff = x - ref_x;
                    let v_diff = v - ref_v;
                    gradient_num = num;
                    (x_diff.norm_squared() + v_diff.norm_squared()).sqrt()
                })
                .collect();

            result
                .lock()
                .expect("result mutex poisoned")
                .entry(name)
                .or_insert_with(Vec::new)
                .push((step_var, gradient_num, values.mean(), values.error()));
        });
    });

    let result = result.into_inner()?;
    let uname = simple_type_name::<U>();

    fs::create_dir_all("out")?;
    for (name, mut rows) in result {
        // Rows were pushed by parallel workers in arbitrary order; sort them so
        // the output file is deterministic and ordered by step size.
        rows.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("step sizes are finite"));
        let mut file = BufWriter::new(File::create(format!("out/{}_{}_scale4", uname, name))?);
        for (step, gradients, mean, error) in rows {
            writeln!(file, "{},{},{},{}", step, gradients, mean, error)?;
        }
        // Flush explicitly: dropping a BufWriter discards any flush error.
        file.flush()?;
    }

    Ok(())
}
