#![allow(non_upper_case_globals)]
#![allow(non_snake_case)]

use average::MeanWithError;
use itertools::Itertools;
use nalgebra::*;
use rand::distributions::Standard;
use rand::prelude::*;
use rayon::prelude::*;
use std::fs::File;
use std::io::{prelude::*, BufReader, BufWriter};

use crate::for_type;
use crate::gradient::{gd::GD, saga::SAGA, sgd::SGD, svrg2::SVRG2};
use crate::gradient::{FixedBatchGradientOracle, GradientEstimator};
use crate::model::{logistic::Logistic, SumDecomposableModel};
use crate::tool::simple_type_name;
use crate::uld::info::BasicInfoBuilder;
use crate::uld::multiple::{LPMLikeRMMMS, ULDProcessNewMS, OMMS, RMMMS};
use crate::uld::{optimal::OM, rmm::RMM};
use crate::uld::{FixedBatchSeq, ULDParam, ULDProcess, ULDProcessNew};

/// Scalar type used throughout the German credit experiments.
pub type DType = f64;
/// Seed for the train/test split shuffle and for the initial-state RNG of the experiments.
pub const SEED: u64 = 123;
/// Number of features per sample, i.e. the dimension of the chain state.
pub const d: usize = 24;
/// Type-level version of [`d`], used for statically sized nalgebra vectors.
pub type D = Const<d>;
/// Position / velocity vector of a chain.
pub type State = OVector<DType, D>;
/// `Logistic` model over `d`-dimensional features.
pub type Model = Logistic<DType, D>;

pub fn read_german(m: DType) -> anyhow::Result<(Logistic<DType, D>, Logistic<DType, D>)> {
    let file = File::open("dataset/german.numer_scale")?;
    let reader = BufReader::new(file);

    fn read_line(s: &str) -> anyhow::Result<(OVector<DType, D>, DType)> {
        use scan_fmt::scan_fmt;
        let mut a = OVector::<DType, D>::zeros_generic(D::from_usize(d), Const::<1>);
        let mut y = 0.0;

        for t in s.split(' ') {
            if t.contains(':') {
                let (i, v) = scan_fmt!(t, "{}:{}", usize, DType)?;
                a[i - 1] = v;
            } else {
                y = scan_fmt!(t, "{}", DType)?;
            }
        }
        Ok((a, y))
    }

    let yas = reader
        .lines()
        .map(|line| read_line(line?.trim()))
        .collect::<anyhow::Result<Vec<_>>>()?;
    let total_data_size = yas.len();
    let mut is_train = vec![false; total_data_size];
    is_train.iter_mut().take(total_data_size / 2).for_each(|is| *is = true);
    is_train.shuffle(&mut StdRng::seed_from_u64(SEED));

    let mut train_ys = DVector::<DType>::zeros(0);
    train_ys.extend(
        yas.iter()
            .zip(is_train.iter())
            .filter_map(|((_a, y), is)| if *is { Some(*y) } else { None }),
    );
    let mut train_a_s = OMatrix::<DType, D, Dynamic>::zeros_generic(D::from_usize(d), Dynamic::from_usize(0));
    train_a_s.extend(
        yas.iter()
            .zip(is_train.iter())
            .filter_map(|((a, _y), is)| if *is { Some(a.clone()) } else { None }),
    );
    let norm = train_a_s.norm();
    train_a_s.scale_mut(1.0 / norm);
    let mut test_ys = DVector::<DType>::zeros(0);
    test_ys.extend(
        yas.iter()
            .zip(is_train.iter())
            .filter_map(|((_a, y), is)| if !*is { Some(*y) } else { None }),
    );
    let mut test_a_s = OMatrix::<DType, D, Dynamic>::zeros_generic(D::from_usize(d), Dynamic::from_usize(0));
    test_a_s.extend(
        yas.iter()
            .zip(is_train.iter())
            .filter_map(|((a, _y), is)| if !*is { Some(a.clone()) } else { None }),
    );
    test_a_s.scale_mut(1.0 / norm);

    let scale = (d as f64).sqrt();
    train_a_s.scale_mut(scale);
    test_a_s.scale_mut(scale);

    Ok((
        Logistic::init(train_a_s, train_ys, m),
        Logistic::init(test_a_s, test_ys, 0.0),
    ))
}

/// Hyper-parameters and models shared by the German credit experiments.
///
/// Built by `GermanModelParameters::default`, which loads the dataset from disk.
struct GermanModelParameters {
    /// `tL + m`, where `tL` is `calculate_L()` of the unregularised training model.
    L: DType,
    /// Value written into the training model's `m` field, chosen so that `L / m == kappa`.
    m: DType,
    /// Target ratio `L / m`.
    kappa: DType,
    /// Passed as `gamma` to `ULDParam`.
    gamma: DType,
    /// Passed as `temperature` to `ULDParam`.
    temperature: DType,
    /// Mini-batch size handed to the gradient estimators / oracles.
    batch_size: usize,
    /// Integrator step size, `1 / (40 L)`.
    step: DType,
    /// Number of steps derived as roughly `50 / (L * step)`.
    burnin: usize,
    /// Gradient budget derived as roughly `100 / (L * step)`.
    total_gradient_query: usize,
    /// Training-split model (with `m` set).
    model: Model,
    /// Test-split model (built with `m = 0.0` by `read_german`).
    test_model: Model,
    /// Number of independent chains averaged per estimator.
    ensemble_size: usize,
}

impl GermanModelParameters {
    fn default() -> anyhow::Result<Self> {
        let (mut model, test_model) = read_german(0.0)?;
        let tL = model.calculate_L();
        let kappa = 1e3;
        let m = tL / (kappa - 1.0);
        model.m = m;
        let L = tL + m;
        let step = 1.0 / 4.0 / 10.0 / L;

        Ok(GermanModelParameters {
            L,
            m,
            kappa,
            gamma: 2.0,
            temperature: 1.0,
            batch_size: 1,
            step,
            burnin: (50.0 / (L * step)) as usize,
            total_gradient_query: (100.0 / (L * step)) as usize,
            model,
            test_model,
            ensemble_size: 100,
        })
    }
}

/// Runs the step-size scaling experiment for the three multi-segment integrators
/// (`RMMMS`, `LPMLikeRMMMS`, `OMMS`) in parallel.
///
/// # Errors
/// All three runs always complete. If any of them fails, the first failing run in the order
/// RMM, LPMLikeRMM, OM has its error returned (instead of panicking inside the thread pool).
pub fn german_scale_data() -> anyhow::Result<()> {
    let (rmm, (lpm_like_rmm, om)) = rayon::join(
        || {
            println!("RMM");
            _german_scale_data::<RMMMS<_>>()
        },
        || {
            rayon::join(
                || {
                    println!("LPMLikeRMM");
                    _german_scale_data::<LPMLikeRMMMS<_>>()
                },
                || {
                    println!("OM");
                    _german_scale_data::<OMMS<_>>()
                },
            )
        },
    );
    rmm?;
    lpm_like_rmm?;
    om?;
    Ok(())
}
pub fn _german_scale_data<U: ULDProcessNewMS<DType> + ULDProcess<State> + Send>() -> anyhow::Result<()> {
    let amp = GermanModelParameters::default()?;
    #[allow(unused_variables)]
    let GermanModelParameters {
        L,
        m,
        kappa,
        gamma,
        temperature,
        batch_size,
        model,
        test_model,
        ..
    } = amp;
    let uldp: ULDParam<DType> = ULDParam::<DType> {
        xi: L,
        temperature,
        gamma,
    };
    let batch_size = 40;

    let segments = 10;
    let ges_num = 30;
    let max_ges = 1e7;
    let min_ges = 1e3;

    use std::collections::HashMap;
    use std::sync::Mutex;
    let result = Mutex::new(HashMap::new());

    let mut rng = StdRng::seed_from_u64(SEED);
    let initial_x: OVector<DType, D> = model.init_state(&mut rng);
    let initial_v = State::zeros_generic(D::from_usize(d), Const::<1>);

    rayon::scope(|s| {
        for i in 0..ges_num {
            for_type!(Alg::<GD, SGD, SVRG2<State>, SAGA<State, Model>>, {
                let rng = &rng;
                let model = &model;
                let uldp = &uldp;
                let initial_x = &initial_x;
                let initial_v = &initial_v;
                let result = &result;
                s.spawn(move |_| {
                    let mut ge = Alg::new(initial_x, model, rng.clone());
                    ge.set_batch_size(batch_size, model);
                    let overhead = ge.gradient_overhead(model);
                    let name = GradientEstimator::<State, Model>::name(&ge);
                    let go = FixedBatchGradientOracle::new(model, rng.clone(), batch_size, ge);

                    let burnin_var = 0;
                    let total_iterations_var =
                        (min_ges * (max_ges / min_ges).powf((ges_num - i - 1) as f64 / (ges_num - 1) as f64)
                            / (batch_size as f64 * overhead as f64)) as usize;
                    let step_var = 1e2 / L / (total_iterations_var as f64);

                    let uldprocess = U::new(segments, step_var / (segments as f64), uldp.clone(), rng.clone());

                    use std::marker::PhantomData;
                    let it = FixedBatchSeq {
                        x: initial_x.clone(),
                        v: initial_v.clone(),
                        uldprocess,
                        phantom: PhantomData::<BasicInfoBuilder<_>>,
                        go,
                    };
                    let ref_it = {
                        let ge = GD::new(initial_x, model, rng.clone());
                        let go = FixedBatchGradientOracle::new(model, rng.clone(), batch_size, ge);

                        let uldprocess = RMM::new(step_var / (segments as f64), uldp.clone(), rng.clone());

                        use std::marker::PhantomData;
                        let iter = FixedBatchSeq {
                            x: initial_x.clone(),
                            v: initial_v.clone(),
                            uldprocess,
                            phantom: PhantomData::<BasicInfoBuilder<_>>,
                            go,
                        };
                        iter
                    };

                    let mut gradient_num = 0;
                    let values: MeanWithError = it
                        .zip(ref_it.skip(segments - 1).step_by(segments))
                        .skip(burnin_var)
                        .take(total_iterations_var)
                        .map(|((x, v, _, num), (ref_x, ref_v, _, _))| {
                            let x_diff = x - ref_x;
                            let v_diff = v - ref_v;
                            gradient_num = num;
                            (x_diff.norm_squared() + v_diff.norm_squared()).sqrt()
                        })
                        .collect();

                    result.lock().unwrap().entry(name).or_insert(Vec::new()).push((
                        step_var,
                        gradient_num,
                        values.mean(),
                        values.error(),
                    ));
                });
            });
        }
    });

    let result = result.into_inner()?;
    let uname = simple_type_name::<U>();
    for (name, v) in result {
        let mut file = BufWriter::new(File::create(format!("out/german_{}_{}_scale", uname, name)).unwrap());

        for x in v {
            write!(file, "{},{},{},{}\n", x.0, x.1, x.2, x.3).unwrap();
        }
    }

    Ok(())
}
/// Generates the time-series plot data for the German credit experiment with the `OM`
/// integrator.
///
/// # Errors
/// Propagates any error from `_german_plot_data` instead of panicking.
pub fn german_plot_data() -> anyhow::Result<()> {
    println!("OM");
    _german_plot_data::<OM<_>>()
}
pub fn _german_plot_data<U: ULDProcessNew<DType> + ULDProcess<State> + Send>() -> anyhow::Result<()> {
    let amp = GermanModelParameters::default()?;
    #[allow(unused_variables)]
    let GermanModelParameters {
        L,
        m,
        kappa,
        gamma,
        temperature,
        batch_size,
        step,
        burnin,
        total_gradient_query,
        model,
        test_model,
        ensemble_size,
        ..
    } = amp;
    let uldp: ULDParam<DType> = ULDParam::<DType> {
        xi: L,
        temperature,
        gamma,
    };

    let mut rng = StdRng::seed_from_u64(SEED);
    let initial_x: OVector<DType, D> = model.init_state(&mut rng);
    let initial_v = State::zeros_generic(D::from_usize(d), Const::<1>);
    let seed_for_ensemble: Vec<u64> = rng.sample_iter(Standard).take(ensemble_size).collect();

    rayon::scope(|s| {
        for_type!(Alg::<GD, SGD, SVRG2<State>, SAGA<State, Model>>, {
            s.spawn(|_| {
                let mut temp_ge = Alg::new(&initial_x, &model, StdRng::seed_from_u64(111));
                temp_ge.set_batch_size(batch_size, &model);
                let total_iterations =
                    ((total_gradient_query / batch_size) as f64 / (temp_ge.gradient_overhead(&model)) as f64) as usize;

                let gap = max(total_iterations / 100, 1);
                let history = (0..ensemble_size)
                    .into_par_iter()
                    .map(|i| {
                        let local_seed = seed_for_ensemble[i];
                        let local_rng = StdRng::seed_from_u64(local_seed);

                        let mut ge = Alg::new(&initial_x, &model, local_rng.clone());
                        ge.set_batch_size(batch_size, &model);
                        let go = FixedBatchGradientOracle::new(&model, local_rng.clone(), batch_size, ge);

                        let uldprocess = U::new(step, uldp.clone(), local_rng.clone());

                        use std::marker::PhantomData;
                        let iter = FixedBatchSeq {
                            x: initial_x.clone(),
                            v: initial_v.clone(),
                            uldprocess,
                            phantom: PhantomData::<BasicInfoBuilder<_>>,
                            go,
                        };
                        let iter: Box<dyn Iterator<Item = _> + Send> = Box::new(iter);

                        let history = iter
                            .take(total_iterations)
                            .chunks(gap)
                            .into_iter()
                            .map(|chunk| {
                                let mut len = 0;
                                chunk
                                    .into_iter()
                                    .step_by(100)
                                    .map(|(x, _, _, gradient_query_num)| {
                                        len += 1;
                                        let potential = model.all_loss(&x);
                                        let test_neg_log_likelihood = test_model.all_loss(&x);
                                        let prob = test_model.probs(&x);
                                        let test_acc = prob.iter().filter(|a| **a > 0.5).count();
                                        let mut v = vec![
                                            potential,
                                            test_neg_log_likelihood,
                                            (test_acc as f64) / (model.N() as f64),
                                            gradient_query_num as f64,
                                        ];
                                        v.extend(prob.iter());
                                        let v: DVector<DType> = v.into();
                                        v
                                    })
                                    .sum::<OMatrix<_, _, _>>()
                                    / len as f64
                            })
                            .collect::<Vec<_>>();
                        Some(OMatrix::from_columns(&history))
                    })
                    .reduce(
                        || None,
                        |a: Option<OMatrix<DType, Dynamic, Dynamic>>, b: Option<OMatrix<DType, _, _>>| {
                            if let Some(aa) = a {
                                if let Some(bb) = b {
                                    Some(aa + bb)
                                } else {
                                    Some(aa)
                                }
                            } else {
                                b
                            }
                        },
                    )
                    .unwrap()
                    / ensemble_size as f64;
                let mut output = File::create(format!(
                    "out/german_{}_{}_plot",
                    simple_type_name::<U>(),
                    simple_type_name::<Alg>()
                ))
                .unwrap();
                for row in history.row_iter() {
                    let s = row.iter().map(|i| i.to_string()).join(",");
                    write!(output, "{}\n", s).unwrap();
                }
            });
        });
    });

    Ok(())
}
