# %%
import os
import re

from typing import List, Any, Tuple, Dict

import numpy as np
import pandas as pd

from lib.generate_datasets import GenDataCorrelatedGaussVsNonCorrelatedGauss
from lib.utils import set_seed_everything
from lib_exp.exp_stability_optimization import dre_train_for_all_epoch

# Run modes. Add 'TEST' (uncomment the next line) for a reduced smoke-test run;
# see the `'TEST' in MODE` check further down for what it changes.
MODE = []
#MODE += ['TEST']

# The output sub-directory is named after this script file:
# '<prefix>_run_experiment_<name>.py' -> '<name>'.
python_file_name = os.path.basename(__file__)
# The dot before 'py' is escaped so it matches a literal '.' only.
_out_dirname_match = re.search(r'.+_run_experiment_(.+)\.py', python_file_name)
if _out_dirname_match is None:
    # Fail with an explicit message instead of an opaque IndexError.
    raise ValueError(
        f"Script name {python_file_name!r} does not follow the "
        "'<prefix>_run_experiment_<name>.py' convention used to derive "
        "the output directory name.")
out_dirname = _out_dirname_match.group(1)

# Relative to the current working directory, not to this script's location.
WORK_DIR = './'
PARENT_OUT_DIR = os.path.join(WORK_DIR, 'out', out_dirname)

# Number of samples requested from the data generator per simulation.
N_DATA_TRAIN = 5000

# Network / optimisation settings forwarded to the trainer via `params_training`.
N_LAYERS_PER_BLOCKS = 3
N_UNITS_HIDDEN_NN = 100
N_BLOCKS = 1            # batch norm is enabled only when N_BLOCKS > 1
DROPOUT = 0.0
N_ALL_EPOCHS = 500
BATCHSIZE = 2500
LEARNING_RATE = 0.001
SAVE_RESULT = False     # forwarded as `do_save_result`
EARLY_STOPPING_PATIENCE = 0  # NOTE(review): presumably 0 disables early stopping — confirm in trainer
CUDA_DEVICE_ID = 0      # forwarded as `device_id`


def _one_simulation(
          n_dimensions_data: int,
          params_training: Dict[str, Any],
          batchsize: int,
          n_all_epochs: int,
          out_results_dir: str,
          do_save_result: bool,
          seed: int
      ) -> pd.DataFrame:
    """Sample one training set and train a density-ratio model on it.

    Draws ``N_DATA_TRAIN`` samples from a
    ``GenDataCorrelatedGaussVsNonCorrelatedGauss`` generator and passes the
    denominator/numerator training sets to ``dre_train_for_all_epoch`` on
    CUDA device ``CUDA_DEVICE_ID``.

    Args:
        n_dimensions_data: Dimensionality passed to the data generator.
        params_training: Hyper-parameter dict forwarded to the trainer.
        batchsize: Mini-batch size forwarded to the trainer.
        n_all_epochs: Number of training epochs forwarded to the trainer.
        out_results_dir: Directory forwarded to the trainer for its outputs.
        do_save_result: Forwarded to the trainer as ``do_save_result``.
        seed: Forwarded to the trainer as ``seed``.

    Returns:
        A DataFrame built from the trainer's result dict with
        ``orient='index'``, so each key of that dict becomes one row
        (presumably one row per epoch — TODO confirm against the trainer).
    """
    # NOTE(review): 0.8 is passed straight to the generator; presumably the
    # correlation strength of the correlated Gaussian — confirm.
    gen_data = GenDataCorrelatedGaussVsNonCorrelatedGauss(0.8, n_dimensions_data)
    # The generator's third output is not used here.
    train_datasets_de, train_datasets_nu, _ = gen_data.sample(N_DATA_TRAIN)
    res_dict = dre_train_for_all_epoch(
        train_datasets_de,
        train_datasets_nu,
        params_training,
        batchsize,
        n_all_epochs,
        out_results_dir,
        do_save_result=do_save_result,
        seed=seed,
        device_id=CUDA_DEVICE_ID)
    return pd.DataFrame.from_dict(res_dict, orient='index')


def run_all_simulations(
      n_simulations,
      alpha,
      n_dimensions_data,
      output_top_dir) -> None:
    """Run ``n_simulations`` independent simulations for one (alpha, ndim) setting.

    Simulation ``i`` gets its own log directory
    ``<output_top_dir>/log/alpha-<alpha>_ndim-<ndim>/SIM<i>`` and is seeded with
    ``i``. The per-simulation result frames are placed side by side under a
    two-level column index (``SIMxxxx``, original column). The combined table
    is written, with its row index turned into a regular column, to a single
    CSV file directly inside ``output_top_dir``.
    """
    setting_tag = f'alpha-{alpha:.3f}_ndim-{n_dimensions_data}'
    setting_log_root = os.path.join(output_top_dir, 'log', setting_tag)
    os.makedirs(setting_log_root, exist_ok=True)

    # --- run ---
    # Same hyper-parameters for every simulation of this setting.
    training_config = {
        'alpha': alpha,
        'earlystopping_patience': EARLY_STOPPING_PATIENCE,
        'hidden_dim': N_UNITS_HIDDEN_NN,
        'learning_rate': LEARNING_RATE,
        'n_layers_per_block': N_LAYERS_PER_BLOCKS,
        'n_blocks': N_BLOCKS,
        'dropout': DROPOUT,
        'DoBatchNormarize': N_BLOCKS > 1,
    }

    per_sim_frames = []
    for _i_sim in range(n_simulations):
        sim_tag = f'SIM{_i_sim:04}'
        sim_log_dir = os.path.join(setting_log_root, sim_tag)
        os.makedirs(sim_log_dir, exist_ok=True)

        sim_frame = _one_simulation(
            n_dimensions_data, training_config, BATCHSIZE, N_ALL_EPOCHS,
            sim_log_dir, do_save_result=SAVE_RESULT, seed=_i_sim)
        # Prefix every column with the simulation id so frames can sit side by side.
        sim_frame.columns = pd.MultiIndex.from_product([[sim_tag], sim_frame.columns])
        per_sim_frames.append(sim_frame)

    combined = pd.concat(per_sim_frames, axis=1).reset_index(drop=False)

    n_data = N_DATA_TRAIN
    csv_path = os.path.join(
        output_top_dir,
        f'Ndata{n_data:05}_dimensions{n_dimensions_data:02}_alpha{alpha:.3f}.csv')
    combined.to_csv(csv_path, index=False)

# TEST mode: 'TEST' sub-directory of PARENT_OUT_DIR and only 3 simulations;
# otherwise PARENT_OUT_DIR itself and 100 simulations.
parent_outdir = (
    os.path.join(PARENT_OUT_DIR, 'TEST') if 'TEST' in MODE else PARENT_OUT_DIR)
n_simulations = 3 if 'TEST' in MODE else 100

# --- run experiment ---
if __name__ == '__main__':
    # Results go to `parent_outdir`, which is the 'TEST' sub-directory in TEST
    # mode. Previously every run wrote to PARENT_OUT_DIR, so a TEST run
    # overwrote the full experiment's outputs.
    os.makedirs(parent_outdir, exist_ok=True)

    # Each alpha appears once. The previous call list repeated -2, 3 and 0.5.
    # The repeats were re-seeded reruns writing to the same log directory and
    # CSV (both are named from alpha formatted with '.3f'), so they only
    # wasted compute.
    ALPHAS = (-2.0, 3.0, 0.5, 2.0, 4.0, 0.2, 0.8, -1.0, -3.0)
    N_DIMENSIONS_DATA = 5

    for alpha in ALPHAS:
        # Re-seed before every setting, as before; presumably this makes each
        # alpha's results independent of the order settings are run in.
        set_seed_everything(1)
        run_all_simulations(
            n_simulations,
            alpha,
            N_DIMENSIONS_DATA,
            parent_outdir,
        )
