# %%
from __future__ import division
import os
import datetime
import random
import re
import pickle
import argparse
import glob

from typing import List, Any, Tuple
from typing import Optional

import scipy
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import OneHotEncoder
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import GradientBoostingRegressor

import torch
from torch import nn, Tensor
from torch import optim
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

import optuna

from libs.NeuralBW import train_and_enhance_NBW, train_NBW
from libs.NeuralBW import estimate_balancing_weights
from libs.OptunaFuncitons import objective_GBT

# ---------------------------------------------------------------------------
# Script-wide parameters
# ---------------------------------------------------------------------------
# Test mode: only the first MAX_SIM_TEST simulation directories are run and
# results are written under an extra 'TEST' / 'Epochs<N>' sub-tree.
IS_TEST = False
MAX_SIM_TEST = 5

# Input and output locations.
INPUT_TRAIN_PARENTDIR = './data/train_data'
INPUT_TRAIN_DATA_DIR = os.path.join(INPUT_TRAIN_PARENTDIR, 'N100000')
INPUT_TEST_DATA_DIR = './data/test_data'
PARENT_OUT_DIR = './out/Res_Section_8'

# NBW model / training hyperparameters.
ALPHA = 0.5
N_LAYER_NN = 10
N_UNITS_HIDDEN_NN = 100
N_EPOCHS_MODEL_EXP1 = 140
N_EPOCHS_MODEL_EXP2 = 140
N_EPOCHS_ESTIMATE_EXP1 = N_EPOCHS_MODEL_EXP1
N_EPOCHS_ESTIMATE_EXP2 = N_EPOCHS_MODEL_EXP2
BATCHSIZE = 5000
LEARNING_RATE = 0.0001

# Final output root:
#   <PARENT_OUT_DIR>[/TEST]/<train-data dir name>[/Epochs<N_EPOCHS_MODEL_EXP2>]
# where the bracketed parts are present only in test mode.
PARENT_OUT_DIR = os.path.join(
    PARENT_OUT_DIR,
    *(['TEST'] if IS_TEST else []),
    os.path.basename(os.path.normpath(INPUT_TRAIN_DATA_DIR)),
    *([f'Epochs{N_EPOCHS_MODEL_EXP2}'] if IS_TEST else []))

# Hyperparameter search settings for the gradient boosting tree (GBT) models.
N_TRY_OPT = 25                        # maximum Optuna trials per study
RATE_SUBSAMPLE_TUNE_GDB = 1           # passed through to objective_GBT
OPTUNA_TIMEOUT_SECONDS = 6 * 60 * 60  # wall-clock cap per study (6 hours)
# Suppress Optuna's per-trial log output.
optuna.logging.set_verbosity(optuna.logging.CRITICAL)


def _split_expls_and_resp(data_df):
    """
    Split a frame that has a 'Y' column into model-ready numpy arrays.

    Returns ``(expl_mats_list, expl_mat, resp_arr)`` where ``expl_mat`` holds
    every column except 'Y' (in frame order), ``expl_mats_list`` is
    ``[expl_mat[:, [0]], expl_mat[:, 1:]]`` (first column kept 2-D as an
    (n, 1) array, then the remaining columns), and ``resp_arr`` is 'Y'.
    """
    expl_mat = data_df.drop('Y', axis=1).to_numpy()
    resp_arr = data_df['Y'].to_numpy()
    expl_mats_list = [expl_mat[:, [0]], expl_mat[:, 1:]]
    return expl_mats_list, expl_mat, resp_arr


def read_train_data(
    input_dir_str
      ):
    """
    Load one simulation's training data and split it for tuning and refit.

    Reads ``train.csv`` (a 'Y' response column plus explanatory columns) and
    ``train_EBweights.csv`` from *input_dir_str*. The EB-weight rows are
    assumed to be row-aligned with ``train.csv`` (TODO confirm). Both frames
    are split 50/50 in a single ``train_test_split`` call (random_state=0),
    so the weight rows stay paired with their data rows. The unsplit data is
    returned as the "refit" set.

    Returns
    -------
    list of four ``(train, eval, refit)`` tuples:
        0. explanatory-matrix lists ``[first column (n, 1), other columns]``
        1. full explanatory matrices (all non-'Y' columns)
        2. response arrays ('Y')
        3. EB-weight DataFrames
    """
    d_train_all_df = pd.read_csv(os.path.join(input_dir_str, 'train.csv'))
    eb_weights_df = pd.read_csv(
        os.path.join(input_dir_str, 'train_EBweights.csv'))

    # Split data and weights together so their rows remain aligned.
    d_train, d_eval, train_eb_weights_df, eval_eb_weights_df = \
        train_test_split(
            d_train_all_df,
            eb_weights_df,
            test_size=0.5,
            train_size=0.5,
            random_state=0)

    train_expl_mats_list, train_expl_mat, train_resp_arr = \
        _split_expls_and_resp(d_train)
    eval_expl_mats_list, eval_expl_mat, eval_resp_arr = \
        _split_expls_and_resp(d_eval)
    # The refit set is the full, unsplit training data.
    refit_expl_mats_list, refit_expl_mat, refit_resp_arr = \
        _split_expls_and_resp(d_train_all_df)
    refit_eb_weights_df = eb_weights_df

    return [
            (train_expl_mats_list,
             eval_expl_mats_list,
             refit_expl_mats_list),
            (train_expl_mat,
             eval_expl_mat,
             refit_expl_mat),
            (train_resp_arr,
             eval_resp_arr,
             refit_resp_arr),
            (train_eb_weights_df,
             eval_eb_weights_df,
             refit_eb_weights_df)
           ]


def _read_test_data(input_dir_str, csv_file_name):
    """
    Load a test CSV and split it into (explanatory matrix, response array).

    The file must contain a 'Y' column (the response). Every other column,
    in file order, becomes a column of the returned explanatory matrix.
    """
    d_test_df = pd.read_csv(os.path.join(input_dir_str, csv_file_name))
    test_expl_mat = d_test_df.drop('Y', axis=1).to_numpy()
    test_resp_arr = d_test_df['Y'].to_numpy()
    return test_expl_mat, test_resp_arr


def read_data_for_experiment1(
        input_dir_str
      ):
    """
    Read ``test_Experiment1.csv`` from *input_dir_str*.

    Returns ``(test_expl_mat, test_resp_arr)``: all non-'Y' columns as a
    numpy matrix and the 'Y' column as a numpy array.
    """
    return _read_test_data(input_dir_str, 'test_Experiment1.csv')


def read_data_for_experiment2(
        input_dir_str
      ):
    """
    Read ``test_Experiment2.csv`` from *input_dir_str*.

    Returns ``(test_expl_mat, test_resp_arr)``: all non-'Y' columns as a
    numpy matrix and the 'Y' column as a numpy array.
    """
    return _read_test_data(input_dir_str, 'test_Experiment2.csv')


def build_and_validate_for_GBT(
          train_expl_mat,
          train_resp_arr,
          eval_expl_mat,
          eval_resp_arr,
          refit_expl_mat,
          refit_resp_arr,
          test_expl_mat,
          test_resp_arr,
          train_balancing_weights=None,
          refit_balancing_weights=None
      ) -> float:
    """
    Tune, refit and score a Gradient Boosting Tree (GBT) regressor.

    An Optuna study (at most ``N_TRY_OPT`` trials, capped at
    ``OPTUNA_TIMEOUT_SECONDS``) minimizes the objective returned by
    ``objective_GBT`` for the train/eval split. The best ``max_leaf_nodes``
    and ``learning_rate`` are then used to fit a ``GradientBoostingRegressor``
    on the refit data, which is scored on the test data.

    Parameters
    ----------
    train_expl_mat, train_resp_arr : data given to the hyperparameter search.
    eval_expl_mat, eval_resp_arr : held-out data given to the search.
    refit_expl_mat, refit_resp_arr : data the final model is fitted on.
    test_expl_mat, test_resp_arr : data the final model is scored on.
    train_balancing_weights : optional per-sample weights passed to
        ``objective_GBT``; None for an unweighted search.
    refit_balancing_weights : optional per-sample weights for the final fit.
        They are used only when BOTH weight arguments are not None.

    Returns
    -------
    float
        Root mean squared error of the refit model on the test data.
    """
    study = optuna.create_study(direction='minimize')
    study.optimize(
        objective_GBT(
          train_expl_mat,
          train_resp_arr,
          eval_expl_mat,
          eval_resp_arr,
          RATE_SUBSAMPLE_TUNE_GDB,
          train_balancing_weights),
        n_trials=N_TRY_OPT,
        timeout=OPTUNA_TIMEOUT_SECONDS)
    best_trial = study.best_trial
    print(best_trial)

    # Build the final GBT model with the best hyperparameters found.
    max_leaf_nodes = best_trial.params['max_leaf_nodes']
    learning_rate = best_trial.params['learning_rate']
    print(max_leaf_nodes, learning_rate)
    rgg = GradientBoostingRegressor(
            random_state=0,
            max_leaf_nodes=max_leaf_nodes,
            learning_rate=learning_rate)

    # Refit with sample weights only when both weight arguments are given;
    # sample_weight=None is an ordinary unweighted fit.
    use_weights = (train_balancing_weights is not None
                   and refit_balancing_weights is not None)
    rgg.fit(
        refit_expl_mat,
        refit_resp_arr,
        sample_weight=refit_balancing_weights if use_weights else None)

    # Predict the response on the test data with the refit model.
    Y_estimated = rgg.predict(test_expl_mat)

    # RMSE as sqrt(MSE): mean_squared_error(..., squared=False) was
    # deprecated in scikit-learn 1.4 and removed in 1.6.
    return np.sqrt(mean_squared_error(test_resp_arr, Y_estimated))


def build_and_validate_for_LR(
          refit_expl_mat,
          refit_resp_arr,
          test_expl_mat,
          test_resp_arr,
          refit_balancing_weights=None
      ) -> float:
    """
    Fit a linear regression on the refit data and score it on the test data.

    Parameters
    ----------
    refit_expl_mat, refit_resp_arr : data the model is fitted on.
    test_expl_mat, test_resp_arr : data the model is scored on.
    refit_balancing_weights : optional per-sample weights for the fit;
        None gives an ordinary unweighted least-squares fit.

    Returns
    -------
    float
        Root mean squared error of the model's predictions on the test data.
    """
    # sample_weight=None is equivalent to an unweighted fit.
    rgg = LinearRegression().fit(
        refit_expl_mat,
        refit_resp_arr,
        sample_weight=refit_balancing_weights)

    # Predict the response on the test data with the fitted linear model.
    Y_estimated = rgg.predict(test_expl_mat)

    # RMSE as sqrt(MSE): mean_squared_error(..., squared=False) was
    # deprecated in scikit-learn 1.4 and removed in 1.6.
    return np.sqrt(mean_squared_error(test_resp_arr, Y_estimated))


def _run_one_simulation(
          simulation_id,
          experiment_type,
          output_top_dir
      ):
    """
    Train NBW models and score every model/weighting pair for one simulation.

    Training data comes from ``INPUT_TRAIN_DATA_DIR/<simulation_id>`` and the
    test data for *experiment_type* from ``INPUT_TEST_DATA_DIR/<simulation_id>``.
    NBW training logs go to ``<output_top_dir>/logs/<simulation_id>``. GBT and
    linear-regression (LR) models are then scored with NBW weights, EB
    weights (moments 1-4) and no weights.

    Parameters
    ----------
    simulation_id : name of the per-simulation sub-directory.
    experiment_type : 'Experiment1' or 'Experiment2'; selects the test file
        and the number of NBW training epochs.
    output_top_dir : root output directory for this execution.

    Returns
    -------
    pandas.DataFrame
        One row indexed by *simulation_id*, with RMSE columns 'GBT_NBW',
        'LR_NBW', 'GBT_EB(k)', 'LR_EB(k)' (k = 1..4), 'GBT_NoWeights',
        'LR_NoWeights'.

    Raises
    ------
    ValueError
        If *experiment_type* is not 'Experiment1' or 'Experiment2'.
    """
    # Validate the experiment type up front; previously an unknown value
    # left the test data unbound and failed much later.
    experiment_settings = {
        'Experiment1': (read_data_for_experiment1, N_EPOCHS_MODEL_EXP1),
        'Experiment2': (read_data_for_experiment2, N_EPOCHS_MODEL_EXP2),
    }
    if experiment_type not in experiment_settings:
        raise ValueError(
            f'Unknown experiment_type {experiment_type!r}; expected one of '
            f'{sorted(experiment_settings)}')
    read_test_data, n_epochs_model = experiment_settings[experiment_type]

    # Load training data (train / eval / refit) and the experiment's test set.
    (
        (train_expl_mats_list, eval_expl_mats_list, refit_expl_mats_list),
        (train_expl_mat, eval_expl_mat, refit_expl_mat),
        (train_resp_arr, eval_resp_arr, refit_resp_arr),
        (train_eb_weights_df, _, refit_eb_weights_df),
    ) = read_train_data(os.path.join(INPUT_TRAIN_DATA_DIR, simulation_id))
    test_expl_mat, test_resp_arr = read_test_data(
        os.path.join(INPUT_TEST_DATA_DIR, simulation_id))

    # Directory for all logs of this simulation.
    out_log_dir = os.path.join(output_top_dir, 'logs', simulation_id)
    os.makedirs(out_log_dir, exist_ok=True)

    # Network structure of the NBW models.
    params_nbw = {
        'n_layers': N_LAYER_NN,
        'hidden_dim': N_UNITS_HIDDEN_NN}

    # Train the NBW models; the third returned value is not used here.
    filePaths_of_nbw_models_to_use_list, \
        alpha_infomation_estimated, \
        _alpha_infos_of_all_nbw_models = train_NBW(
            ALPHA,
            train_expl_mats_list,
            eval_expl_mats_list,
            False,      # do_estimate_alpha_inf
            params_nbw,
            LEARNING_RATE,
            BATCHSIZE,
            n_epochs_model,
            out_log_dir)

    print('--------------------------------------------------------------')
    print(f'Estimateid alpha-information for the balanced distribution = ',
          f'{alpha_infomation_estimated}')
    print('--------------------------------------------------------------')

    result_rmse_dict = dict()

    def _score_weighting(suffix, train_weights=None, refit_weights=None):
        # Record the GBT and then the LR test RMSE for one weighting scheme
        # under the keys 'GBT_<suffix>' and 'LR_<suffix>'.
        result_rmse_dict[f'GBT_{suffix}'] = build_and_validate_for_GBT(
            train_expl_mat,
            train_resp_arr,
            eval_expl_mat,
            eval_resp_arr,
            refit_expl_mat,
            refit_resp_arr,
            test_expl_mat,
            test_resp_arr,
            train_weights,
            refit_weights)
        result_rmse_dict[f'LR_{suffix}'] = build_and_validate_for_LR(
            refit_expl_mat,
            refit_resp_arr,
            test_expl_mat,
            test_resp_arr,
            refit_weights)

    # NBW: balancing weights estimated by the trained NBW models.
    train_balancing_weights_estimated = estimate_balancing_weights(
        train_expl_mats_list,
        params_nbw,
        filePaths_of_nbw_models_to_use_list)
    refit_balancing_weights_estimated = estimate_balancing_weights(
        refit_expl_mats_list,
        params_nbw,
        filePaths_of_nbw_models_to_use_list)
    _score_weighting(
        'NBW',
        train_balancing_weights_estimated,
        refit_balancing_weights_estimated)

    # EB: precomputed entropy-balancing weights for moments 1 to 4.
    for i_moment in range(1, 5):
        colname_to_use = f'moment_{i_moment}'
        _score_weighting(
            f'EB({i_moment})',
            train_eb_weights_df[colname_to_use],
            refit_eb_weights_df[colname_to_use])

    # Baseline without any weights.
    _score_weighting('NoWeights')

    res_df = pd.DataFrame(
      data=result_rmse_dict,
      index=[simulation_id])

    print('--------------------------------------------------------------')
    print(f'All RMSE: simulation id = {simulation_id}')
    print(res_df)
    print('--------------------------------------------------------------')

    return res_df


def run_all_simulations(
        experiment_type,
        current_datetime):
    """
    Run one experiment over every simulation directory and save the RMSEs.

    Each sub-directory of ``INPUT_TRAIN_DATA_DIR`` is one simulation; in test
    mode (``IS_TEST``) only the first ``MAX_SIM_TEST`` of them, in sorted
    order, are run. The per-simulation RMSE rows are concatenated, the
    simulation id is moved into a column, and the table is written to
    ``<PARENT_OUT_DIR>/<experiment_type>_<current_datetime>/`` under the
    file name ``<experiment_type>_<current_datetime>.csv``.

    Raises
    ------
    ValueError
        If no simulation directories are found.
    """
    print(
      f'Experiment type: {experiment_type}')
    print(
      f'Start Building a NGB model... --- current time: {current_datetime}')

    execution_id = f'{experiment_type}_{current_datetime}'
    output_top_dir = os.path.join(
        PARENT_OUT_DIR,
        execution_id)

    # glob() returns entries in arbitrary, filesystem-dependent order; sort
    # so the output row order and the IS_TEST subset are reproducible.
    all_data_dirs = sorted(glob.glob(
        os.path.join(INPUT_TRAIN_DATA_DIR, '*/')))
    if IS_TEST:
        all_data_dirs = all_data_dirs[:MAX_SIM_TEST]
    if not all_data_dirs:
        raise ValueError(
            f'No simulation directories found in {INPUT_TRAIN_DATA_DIR!r}')

    os.makedirs(output_top_dir, exist_ok=True)

    all_res_dfs_list = list()
    for data_dir_path in all_data_dirs:
        # The simulation id is the directory's own name.
        sim_id_str = os.path.basename(
            os.path.normpath(data_dir_path))
        all_res_dfs_list.append(_run_one_simulation(
          sim_id_str,
          experiment_type,
          output_top_dir))
    all_result_df = pd.concat(all_res_dfs_list, axis=0)
    # Keep the simulation id (the index) as a regular column in the CSV.
    all_result_df.reset_index(drop=False, inplace=True)

    csv_file_name = execution_id + '.csv'
    all_result_df.to_csv(
        os.path.join(output_top_dir, csv_file_name), index=False)


# %%
if __name__ == '__main__':
    # Both experiments share one timestamp; their output directories still
    # differ because the experiment type prefixes the execution id.
    current_datetime = datetime.datetime.now().strftime('%Y%m%d_%H%M_%S')
    for experiment_type in ('Experiment2', 'Experiment1'):
        # Re-seed before each experiment so it starts from the same RNG state
        # regardless of what ran before. torch is seeded too because the NBW
        # models are trained with PyTorch.
        np.random.seed(1)
        random.seed(1)
        torch.manual_seed(1)
        run_all_simulations(
              experiment_type,
              current_datetime)

