# %%
from __future__ import division
import os
import datetime
import random
import re
import pickle
import argparse
import glob

from typing import List, Any, Tuple, Dict
from typing import Optional

import scipy
import numpy as np
import pandas as pd


import torch
from torch import nn, Tensor
from torch import optim
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

import optuna

from libs.CheckConvergence import generate_data, calc_true_alpha_info
from libs.CheckConvergence import train_for_all_epochs


# Script Global Parameters
# Passed to generate_data and calc_true_alpha_info. Presumably a correlation
# parameter of the simulated data -- TODO confirm in libs.CheckConvergence.
RHO = 0.8
# When True, the __main__ section runs two short 2-simulation checks instead
# of the full 100-simulation sweep over data dimensions.
IS_TEST = False
# Parent directory; each script run creates a timestamped sub-directory in it.
PARENT_OUT_DIR = './out/Res_Section_7'
# First argument to generate_data; also embedded in the result CSV file name.
N_DATA_TRAIN = 5000
# Second argument to generate_data.
N_DATA_TEST = 5000
# Stored as 'n_layers' in the model-parameter dict given to the trainer.
N_LAYER_NN = 3
# Stored as 'hidden_dim' in the model-parameter dict given to the trainer.
N_UNITS_HIDDEN = 100
# The three values below are forwarded unchanged to train_for_all_epochs.
N_ALL_EPOCHS = 500
BATCHSIZE = 2500
LEARNING_RATE = 0.002


def _one_simulation(
          alpha: float,
          n_dimensions_data: int,
          out_result_dir: str
      ) -> pd.DataFrame:
    """Run a single simulation and return its results relative to the truth.

    Fresh train/test data are generated, a model is trained with
    ``train_for_all_epochs``, and every value it reports is divided by the
    reference value from ``calc_true_alpha_info``.

    Args:
        alpha: Alpha parameter forwarded to ``calc_true_alpha_info`` and
            ``train_for_all_epochs``.
        n_dimensions_data: Dimensionality of the generated data; also used
            as the network's ``input_dim``.
        out_result_dir: Directory handed to ``train_for_all_epochs`` for its
            output.

    Returns:
        A DataFrame with one column, ``'val_rates'``, indexed by the keys of
        the dict returned by ``train_for_all_epochs`` (presumably epoch
        numbers -- confirm against libs.CheckConvergence).
    """
    # Structure of the neural network model handed to the trainer.
    params_nbw = {
        'n_layers': N_LAYER_NN,
        'hidden_dim': N_UNITS_HIDDEN,
        'input_dim': n_dimensions_data}

    # NOTE: the order of the next three calls is kept as-is on purpose; any
    # of them may consume random numbers, so reordering could change results.
    train_explanatories_to_be_balanced, \
        test_explanatories_to_be_balanced = generate_data(
            N_DATA_TRAIN,
            N_DATA_TEST,
            RHO, n_dimensions_data)

    true_alpha_info = calc_true_alpha_info(alpha, RHO, n_dimensions_data)

    res_dict = train_for_all_epochs(
        alpha,
        train_explanatories_to_be_balanced,
        test_explanatories_to_be_balanced,
        params_nbw,
        LEARNING_RATE,
        BATCHSIZE,
        N_ALL_EPOCHS,
        out_result_dir)

    # Express every reported value as a ratio to the true reference value.
    keys_arr = np.fromiter(res_dict.keys(), dtype=int)
    vals_arr = np.fromiter(res_dict.values(), dtype=float)
    val_rates_arr = vals_arr / true_alpha_info

    return pd.DataFrame(data={'val_rates': val_rates_arr}, index=keys_arr)


def run_all_simulations(
      n_simulations,
      alpha,
      n_dimensions_data,
      output_top_dir) -> None:
    """Repeat the simulation and store all results side by side in one CSV.

    Each repetition gets its own log directory
    ``<output_top_dir>/logs/alpha-<alpha>_ndim-<dim>/SIM<nnnn>`` and
    contributes one column, named ``SIM<nnnn>``, to the combined table. The
    table is written directly under ``output_top_dir``.

    Args:
        n_simulations: Number of independent repetitions to run.
        alpha: Alpha parameter forwarded to every repetition.
        n_dimensions_data: Data dimensionality forwarded to every repetition.
        output_top_dir: Root directory for the logs and the result CSV.
    """
    log_root = os.path.join(
        output_top_dir,
        'logs',
        f'alpha-{alpha:.3f}_ndim-{n_dimensions_data}')
    os.makedirs(log_root, exist_ok=True)

    per_sim_frames = []
    for sim_index in range(n_simulations):
        sim_label = f'SIM{sim_index:04}'
        sim_log_dir = os.path.join(log_root, sim_label)
        os.makedirs(sim_log_dir, exist_ok=True)

        sim_frame = _one_simulation(alpha, n_dimensions_data, sim_log_dir)
        # One column per simulation, labelled with the simulation id.
        per_sim_frames.append(
            sim_frame.rename(columns={'val_rates': sim_label}))

    # Align the simulations on their shared index and expose that index as
    # an ordinary column so it is kept in the CSV.
    combined = pd.concat(per_sim_frames, axis=1).reset_index(drop=False)

    csv_path = os.path.join(
        output_top_dir,
        f'Ndata{N_DATA_TRAIN:05}_dimensions{n_dimensions_data:02}'
        f'_alpha{alpha:.3f}.csv')
    combined.to_csv(csv_path, index=False)


# %%
if __name__ == '__main__':
    # Every invocation writes into its own timestamped output directory.
    run_stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M_%S')
    output_top_dir = os.path.join(PARENT_OUT_DIR, run_stamp)
    os.makedirs(output_top_dir, exist_ok=True)

    np.random.seed(1)
    random.seed(1)

    if IS_TEST:
        # Short check: two simulations each for two alpha values on
        # 4-dimensional data. The RNGs are seeded only once (above), not
        # between these two runs.
        for test_alpha in (0.5, 0.7):
            run_all_simulations(
                2,           # n_simulations
                test_alpha,  # alpha
                4,           # n_dimensions_data
                output_top_dir)
    else:
        # Full sweep: 100 simulations with alpha = 0.5 for each data
        # dimension from 2 to 7. Both RNGs are reset before every
        # configuration so each one starts from the same seed.
        for n_dims in range(2, 8):
            np.random.seed(1)
            random.seed(1)
            run_all_simulations(
                100,     # n_simulations
                0.5,     # alpha
                n_dims,  # n_dimensions_data
                output_top_dir)