#include <iostream>
#include <fstream>
#include <vector>
#include <map>
#include <algorithm>
#include <cstdlib>
#include <chrono>
#include <random>
#include <functional>
#include <string>
#include "def_the_type.h"
#include "data.h"
#include "NEW_algorithm.h"
#include "IBOSS_algorithm.h"
#include "UNIFORM_algorithm.h"
#include "SRHT_algorithm.h"
#include "SLEV_algorithm.h"
#include "VDA_algorithm.h"

#include "experiments.h"
using namespace std;

// Experiment dimensions, taken from the configuration macros (presumably
// defined in experiments.h — confirm). N and p are passed to the_Data(N, p)
// in main(), and main() builds true_beta with p + 1 entries.
// n = N / p uses integer division (truncates); it is passed as the second
// constructor argument to the IBOSS, UNIFORM, SRHT and SLEV algorithms.
const int N = _N_CONFIG;
const int p = _p_CONFIG;
const int n = N / p;

// One estimation method, selected by name, evaluated over repeated runs.
// Each run() call appends one squared error and one elapsed time; average()
// reduces them into MSE_mean / time_mean, which print() reports.
class one_experiment{
    private:
        string method_name;
        void run_new(the_Data*, const vector<T>&);
        void run_VDA(the_Data*, const vector<T>&);
        void run_IBOSS(the_Data*, const vector<T>&);
        void run_IBOSS_ADJUST(the_Data*, const vector<T>&);
        void run_uniform(the_Data*, const vector<T>&);
        void run_full(the_Data*, const vector<T>&);
        void run_SRHT(the_Data*, const vector<T>&);
        void run_SLEV(the_Data*, const vector<T>&);

    public:
        // explicit: a bare string should not silently convert into an experiment.
        explicit one_experiment(const string);
        vector<T> MSE_vector;   // sum_j (estimate[j] - true_beta[j])^2, one entry per run()
        vector<T> time_vector;  // elapsed seconds, one entry per run()
        void run(the_Data*, const vector<T>&);
        void average();
        void print();
        // Value-initialised so reading them before average() is well-defined.
        T MSE_mean{};
        T time_mean{};
};
// Remembers the method name; run() dispatches on it and print() labels output with it.
one_experiment::one_experiment(const string method_name) : method_name(method_name) {}
void one_experiment::average(){
    MSE_mean = 0.;
    for(int i = 0; i < MSE_vector.size();++i){
        MSE_mean += MSE_vector[i];
    }
    MSE_mean /= (1. * MSE_vector.size());

    time_mean = 0;
    for(int i = 0; i < time_vector.size();++i){
        time_mean += time_vector[i];
    }
    time_mean /= (1. * time_vector.size());

    return;
}
// Writes a banner with the method name, then MSE_mean scaled by 1000 and
// time_mean, each preceded by its own "&%" label line.
void one_experiment::print(){
    const auto scaled_mse = MSE_mean * 1000.;
    cout << "%############################## " << method_name
         << " ###########################" << endl
         << "&%" << method_name << " MSE: \n"
         << scaled_mse << endl
         << "&%" << method_name << " time: \n"
         << time_mean << endl;
}
// Runs the estimator chosen by method_name once on `data`. A name that
// matches none of the branches below is ignored and records nothing.
void one_experiment::run(the_Data *data, const vector<T> &true_beta){
    if(method_name == "new")
        run_new(data, true_beta);
    else if(method_name == "VDA")
        run_VDA(data, true_beta);
    else if(method_name == "IBOSS")
        run_IBOSS(data, true_beta);
    else if(method_name == "IBOSS_ADJUST")
        run_IBOSS_ADJUST(data, true_beta);
    else if(method_name == "uniform")
        run_uniform(data, true_beta);
    else if(method_name == "SRHT")
        run_SRHT(data, true_beta);
    else if(method_name == "SLEV")
        run_SLEV(data, true_beta);
    else if(method_name == "full")
        run_full(data, true_beta);
}
// Times the construction plus Execute() of NEW_algorithm on `data`. It then
// appends the elapsed seconds to time_vector and the squared error
// sum_j (estimate[j] - true_beta[j])^2 to MSE_vector.
// Throws std::out_of_range if the estimator has fewer entries than true_beta;
// in that case neither vector is modified.
void one_experiment::run_new(the_Data *data, const vector<T> &true_beta){
    // steady_clock is monotonic. high_resolution_clock may alias the
    // adjustable system_clock, which is not reliable for measuring intervals.
    const auto time0 = chrono::steady_clock::now();
    NEW_algorithm NEW_alg(data);
    const vector<T> NEW_estimator = NEW_alg.Execute();
    const auto time1 = chrono::steady_clock::now();
    const auto dd_time = chrono::duration_cast<chrono::microseconds>(time1 - time0).count();

    // .at() turns a too-short estimator into an exception instead of an
    // out-of-bounds read; size_t avoids a signed/unsigned comparison.
    T the_MSE = 0.;
    for(size_t j = 0; j < true_beta.size(); ++j){
        const T diff = NEW_estimator.at(j) - true_beta[j];
        the_MSE += diff * diff;
    }

    // Record both results only after everything that can throw has run,
    // so the two vectors stay the same length.
    time_vector.push_back(dd_time / 1e6);
    MSE_vector.push_back(the_MSE);
}
// Times the construction plus Execute() of VDA_algorithm on `data`. It then
// appends the elapsed seconds to time_vector and the squared error
// sum_j (estimate[j] - true_beta[j])^2 to MSE_vector.
// Throws std::out_of_range if the estimator has fewer entries than true_beta;
// in that case neither vector is modified.
void one_experiment::run_VDA(the_Data *data, const vector<T> &true_beta){
    // steady_clock is monotonic. high_resolution_clock may alias the
    // adjustable system_clock, which is not reliable for measuring intervals.
    const auto time0 = chrono::steady_clock::now();
    VDA_algorithm VDA_alg(data);
    const vector<T> VDA_estimator = VDA_alg.Execute();
    const auto time1 = chrono::steady_clock::now();
    const auto dd_time = chrono::duration_cast<chrono::microseconds>(time1 - time0).count();

    // .at() turns a too-short estimator into an exception instead of an
    // out-of-bounds read; size_t avoids a signed/unsigned comparison.
    T the_MSE = 0.;
    for(size_t j = 0; j < true_beta.size(); ++j){
        const T diff = VDA_estimator.at(j) - true_beta[j];
        the_MSE += diff * diff;
    }

    // Record both results only after everything that can throw has run,
    // so the two vectors stay the same length.
    time_vector.push_back(dd_time / 1e6);
    MSE_vector.push_back(the_MSE);
}
// Times the construction plus Execute() of IBOSS_algorithm, built from
// `data` and the file-level n. It then appends the elapsed seconds to
// time_vector and the squared error sum_j (estimate[j] - true_beta[j])^2
// to MSE_vector.
// Throws std::out_of_range if the estimator has fewer entries than true_beta;
// in that case neither vector is modified.
void one_experiment::run_IBOSS(the_Data *data, const vector<T> &true_beta){
    // steady_clock is monotonic. high_resolution_clock may alias the
    // adjustable system_clock, which is not reliable for measuring intervals.
    const auto time0 = chrono::steady_clock::now();
    IBOSS_algorithm IBOSS_alg(data, n);
    const vector<T> IBOSS_estimator = IBOSS_alg.Execute();
    const auto time1 = chrono::steady_clock::now();
    const auto dd_time = chrono::duration_cast<chrono::microseconds>(time1 - time0).count();

    // Iterate over true_beta itself rather than assuming it has p + 1
    // entries; .at() guards against a too-short estimator.
    T IBOSS_MSE = 0.;
    for(size_t j = 0; j < true_beta.size(); ++j){
        const T diff = IBOSS_estimator.at(j) - true_beta[j];
        IBOSS_MSE += diff * diff;
    }

    // Record both results only after everything that can throw has run,
    // so the two vectors stay the same length.
    time_vector.push_back(dd_time / 1e6);
    MSE_vector.push_back(IBOSS_MSE);
}
// Times the construction plus Execute_adjust() of IBOSS_algorithm, built from
// `data` and the file-level n. It then appends the elapsed seconds to
// time_vector and the squared error sum_j (estimate[j] - true_beta[j])^2
// to MSE_vector.
// Throws std::out_of_range if the estimator has fewer entries than true_beta;
// in that case neither vector is modified.
void one_experiment::run_IBOSS_ADJUST(the_Data *data, const vector<T> &true_beta){
    // steady_clock is monotonic. high_resolution_clock may alias the
    // adjustable system_clock, which is not reliable for measuring intervals.
    const auto time0 = chrono::steady_clock::now();
    IBOSS_algorithm IBOSS_ADJUST_alg(data, n);
    const vector<T> IBOSS_ADJUST_estimator = IBOSS_ADJUST_alg.Execute_adjust();
    const auto time1 = chrono::steady_clock::now();
    const auto dd_time = chrono::duration_cast<chrono::microseconds>(time1 - time0).count();

    // Iterate over true_beta itself rather than assuming it has p + 1
    // entries; .at() guards against a too-short estimator.
    T IBOSS_ADJUST_MSE = 0.;
    for(size_t j = 0; j < true_beta.size(); ++j){
        const T diff = IBOSS_ADJUST_estimator.at(j) - true_beta[j];
        IBOSS_ADJUST_MSE += diff * diff;
    }

    // Record both results only after everything that can throw has run,
    // so the two vectors stay the same length.
    time_vector.push_back(dd_time / 1e6);
    MSE_vector.push_back(IBOSS_ADJUST_MSE);
}
// Times the construction plus Execute() of UNIFORM_algorithm, built from
// `data` and the file-level n. It then appends the elapsed seconds to
// time_vector and the squared error sum_j (estimate[j] - true_beta[j])^2
// to MSE_vector.
// Throws std::out_of_range if the estimator has fewer entries than true_beta;
// in that case neither vector is modified.
void one_experiment::run_uniform(the_Data *data, const vector<T> &true_beta){
    // steady_clock is monotonic. high_resolution_clock may alias the
    // adjustable system_clock, which is not reliable for measuring intervals.
    const auto time0 = chrono::steady_clock::now();
    UNIFORM_algorithm UNIFORM_alg(data, n);
    const vector<T> UNIFORM_estimator = UNIFORM_alg.Execute();
    const auto time1 = chrono::steady_clock::now();
    const auto dd_time = chrono::duration_cast<chrono::microseconds>(time1 - time0).count();

    // Iterate over true_beta itself rather than assuming it has p + 1
    // entries; .at() guards against a too-short estimator.
    T UNIFORM_MSE = 0.;
    for(size_t j = 0; j < true_beta.size(); ++j){
        const T diff = UNIFORM_estimator.at(j) - true_beta[j];
        UNIFORM_MSE += diff * diff;
    }

    // Record both results only after everything that can throw has run,
    // so the two vectors stay the same length.
    time_vector.push_back(dd_time / 1e6);
    MSE_vector.push_back(UNIFORM_MSE);
}
// Times the construction plus Execute() of SRHT_algorithm, built from
// `data` and the file-level n. It then appends the elapsed seconds to
// time_vector and the squared error sum_j (estimate[j] - true_beta[j])^2
// to MSE_vector.
// Throws std::out_of_range if the estimator has fewer entries than true_beta;
// in that case neither vector is modified.
void one_experiment::run_SRHT(the_Data *data, const vector<T> &true_beta){
    // steady_clock is monotonic. high_resolution_clock may alias the
    // adjustable system_clock, which is not reliable for measuring intervals.
    const auto time0 = chrono::steady_clock::now();
    SRHT_algorithm SRHT_alg(data, n);
    const vector<T> SRHT_estimator = SRHT_alg.Execute();
    const auto time1 = chrono::steady_clock::now();
    const auto dd_time = chrono::duration_cast<chrono::microseconds>(time1 - time0).count();

    // Iterate over true_beta itself rather than assuming it has p + 1
    // entries; .at() guards against a too-short estimator.
    T SRHT_MSE = 0.;
    for(size_t j = 0; j < true_beta.size(); ++j){
        const T diff = SRHT_estimator.at(j) - true_beta[j];
        SRHT_MSE += diff * diff;
    }

    // Record both results only after everything that can throw has run,
    // so the two vectors stay the same length.
    time_vector.push_back(dd_time / 1e6);
    MSE_vector.push_back(SRHT_MSE);
}
// Times the construction plus Execute() of SLEV_algorithm, built from
// `data` and the file-level n. It then appends the elapsed seconds to
// time_vector and the squared error sum_j (estimate[j] - true_beta[j])^2
// to MSE_vector.
// Throws std::out_of_range if the estimator has fewer entries than true_beta;
// in that case neither vector is modified.
void one_experiment::run_SLEV(the_Data *data, const vector<T> &true_beta){
    // steady_clock is monotonic. high_resolution_clock may alias the
    // adjustable system_clock, which is not reliable for measuring intervals.
    const auto time0 = chrono::steady_clock::now();
    SLEV_algorithm SLEV_alg(data, n);
    const vector<T> SLEV_estimator = SLEV_alg.Execute();
    const auto time1 = chrono::steady_clock::now();
    const auto dd_time = chrono::duration_cast<chrono::microseconds>(time1 - time0).count();

    // Iterate over true_beta itself rather than assuming it has p + 1
    // entries; .at() guards against a too-short estimator.
    T SLEV_MSE = 0.;
    for(size_t j = 0; j < true_beta.size(); ++j){
        const T diff = SLEV_estimator.at(j) - true_beta[j];
        SLEV_MSE += diff * diff;
    }

    // Record both results only after everything that can throw has run,
    // so the two vectors stay the same length.
    time_vector.push_back(dd_time / 1e6);
    MSE_vector.push_back(SLEV_MSE);
}
// Times the construction plus Execute_full() of UNIFORM_algorithm, built from
// `data` and the file-level N (not n). It then appends the elapsed seconds to
// time_vector and the squared error sum_j (estimate[j] - true_beta[j])^2
// to MSE_vector.
// Throws std::out_of_range if the estimator has fewer entries than true_beta;
// in that case neither vector is modified.
void one_experiment::run_full(the_Data *data, const vector<T> &true_beta){
    // steady_clock is monotonic. high_resolution_clock may alias the
    // adjustable system_clock, which is not reliable for measuring intervals.
    const auto time0 = chrono::steady_clock::now();
    UNIFORM_algorithm FULL_alg(data, N);
    const vector<T> FULL_estimator = FULL_alg.Execute_full();
    const auto time1 = chrono::steady_clock::now();
    const auto dd_time = chrono::duration_cast<chrono::microseconds>(time1 - time0).count();

    // Iterate over true_beta itself rather than assuming it has p + 1
    // entries; .at() guards against a too-short estimator.
    T FULL_MSE = 0.;
    for(size_t j = 0; j < true_beta.size(); ++j){
        const T diff = FULL_estimator.at(j) - true_beta[j];
        FULL_MSE += diff * diff;
    }

    // Record both results only after everything that can throw has run,
    // so the two vectors stay the same length.
    time_vector.push_back(dd_time / 1e6);
    MSE_vector.push_back(FULL_MSE);
}

int main(){
    unsigned seed = chrono::system_clock::now().time_since_epoch().count();
    default_random_engine generator(seed);
    normal_distribution<T> the_normal(0., 1.);
    auto the_generator = bind(the_normal, generator);

    vector<T> true_beta(p+1);
    for(int j = 0; j < p + 1; ++j){
        true_beta[j] = 1.;
    }
    auto time0 = chrono::high_resolution_clock::now();
    auto time1 = chrono::high_resolution_clock::now();
    T dd_time = chrono::duration_cast<chrono::microseconds>(time1 - time0).count();

    vector<string> total_methods {string("new"), string("VDA"), string("uniform"), string("SRHT"), string("SLEV"), string("IBOSS"), string("full")};
    vector<string> all_methods;
    if(_RUN_FULL_CONFIG)
        all_methods.push_back(string("full"));
    if(_RUN_NEW_CONFIG)
        all_methods.push_back(string("new"));
    if(_RUN_VDA_CONFIG)
        all_methods.push_back(string("VDA"));
    if(_RUN_IBOSS_CONFIG)
        all_methods.push_back(string("IBOSS"));
    //if(_RUN_IBOSS_ADJUST_CONFIG)
    //    all_methods.push_back("IBOSS_ADJUST");
    if(_RUN_UNIFORM_CONFIG)
        all_methods.push_back(string("uniform"));
    if(_RUN_SLEV_CONFIG)
        all_methods.push_back(string("SLEV"));
    if(_RUN_SRHT_CONFIG)
        all_methods.push_back(string("SRHT"));

    map<string, one_experiment*> all_experiments;
    for(auto name : all_methods)
        all_experiments[name] = new one_experiment(name);

    for(int repeat_time = 0; repeat_time < _REPEAT_TIME_CONFIG; ++repeat_time){
        // generate data
        the_Data myD(N, p);

        switch(_Z_DIS_CONFIG){
            case _Z_UNIFORM_DISTRIBUTION_CODE:
                myD.uniform_generate();
                break;
            case _Z_NORMAL_DISTRIBUTION_CODE:
                myD.normal_generate();
                break;
            case _Z_LOGNORMAL_DISTRIBUTION_CODE:
                myD.lognormal_generate();
                break;
            case _Z_T_DISTRIBUTION_CODE:
                myD.t_generate();
                break;
            case _Z_NORMAL_CORRELATION_DISTRIBUTION_CODE:
                myD.normal_equal_correlation_generate();
            case _Z_NORMAL_MIXTURE_DISTRIBUTION_CODE:
                myD.normal_mixture_generate();
        }
        switch(_y_DIS_CONFIG){
            case _y_NORMAL_DISTRIBUTION_CODE:
                myD.normal_generate_y(true_beta);
                break;
            case _y_CHI_SQUARED_DISTRIBUTION_CODE:
                myD.chi_squared_generate_y(true_beta);
                break;
        }
        for(auto method_name : all_methods)
            all_experiments[method_name] -> run(&myD, true_beta);

        cout << "complete "<< repeat_time + 1 << " \n";
    }

    for(auto method_name : all_methods){
        all_experiments[method_name] -> average();
        all_experiments[method_name] -> print();     
    }

    map<int, string> Z_dis_map;
    Z_dis_map[_Z_UNIFORM_DISTRIBUTION_CODE] = "uniform";
    Z_dis_map[_Z_NORMAL_DISTRIBUTION_CODE] = "normal";
    Z_dis_map[_Z_LOGNORMAL_DISTRIBUTION_CODE] = "lognormal";
    Z_dis_map[_Z_T_DISTRIBUTION_CODE] = "t";
    Z_dis_map[_Z_NORMAL_CORRELATION_DISTRIBUTION_CODE] = "normalCorrelation";
    Z_dis_map[_Z_NORMAL_MIXTURE_DISTRIBUTION_CODE] = "normalMixture";

    map<int, string> y_dis_map;
    y_dis_map[_y_NORMAL_DISTRIBUTION_CODE] = "normal";
    y_dis_map[_y_CHI_SQUARED_DISTRIBUTION_CODE] = "chiSquared";

    ofstream the_file;
    string file_name;
    file_name = string("results/") + "Z_" + Z_dis_map[_Z_DIS_CONFIG] + "_y_" + y_dis_map[_y_DIS_CONFIG] + "_N_" + to_string(N) + "_p_" + to_string(p) + ".txt";
    cout << file_name << endl;
    the_file.open(file_name);
    the_file << "p";
    for(string method : total_methods){
        if(find(all_methods.begin(), all_methods.end(), method) != all_methods.end()){
            the_file << " & " << method;
        }
    }
    the_file << endl << "MSE:" << endl;
    the_file <<  to_string(p);
    for(string method : total_methods){
        if(find(all_methods.begin(), all_methods.end(), method) != all_methods.end()){
            the_file << " & " << ((all_experiments[method] -> MSE_mean) * 1000.);
        }
    }
    the_file << endl << "time:" << endl;
    the_file <<  to_string(p);
    for(auto method : total_methods){
        if(find(all_methods.begin(), all_methods.end(), method) != all_methods.end()){
            the_file << " & " << all_experiments[method] -> time_mean;
        }
    }
    the_file.close();
    return 0;
}

