# %%
import os, sys
import re
import itertools

import pickle

import numpy as np
import pandas as pd
import torch
from lightning.pytorch.loggers import WandbLogger

from lib.dre.train.linear import *
from lib.concat_results import concat_and_out

# Training below is GPU-only. Fail fast with a clear error instead of an
# `assert`, which is stripped when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError('CUDA is not available; this experiment requires a GPU.')

# Derive the experiment name from this script's file name:
# '<prefix>_run_experiment_<name>[-<suffix>].py' -> '<name>'.
# That name selects the out/ and data/ subdirectories and the default W&B project.
python_file_name = os.path.basename(__file__)
# The '.' before 'py' must be escaped. Unescaped, the lazy group stops early on
# names containing 'py' (e.g. 'happy' would be captured as 'ha').
_file_name_match = re.search(r'.+_run_experiment_(.+?)(|-.*)\.py', python_file_name)
if _file_name_match is None:
    raise ValueError(
        f"Cannot derive the experiment name from '{python_file_name}'; "
        "expected '<prefix>_run_experiment_<name>[-<suffix>].py'.")
out_dirname = _file_name_match.group(1)

# Weights & Biases project name. Leave as None to use the experiment name
# parsed from this script's file name (out_dirname).
# e.g. WANDB_PROJECT_NAME = 'project_X'
WANDB_PROJECT_NAME = None

# Index of the CUDA device handed to the trainer.
CUDA_DEVICE_ID = 0

# Run modes. Adding 'TEST' shrinks the sweep for a quick smoke run.
MODE = []
# MODE.append('TEST')

# --- paths and repetitions ---
N_TEST = 100  # number of dataset replicates (data IDs) per setting
WORK_DIR = './'
PARENT_OUT_DIR = os.path.join(WORK_DIR, 'out', out_dirname)
PARENT_DATA_DIR = os.path.join(WORK_DIR, 'data', out_dirname)
WANDB_PROJECT_NAME = out_dirname if WANDB_PROJECT_NAME is None else WANDB_PROJECT_NAME

# --- training schedule ---
N_BLOCKS = 1
EARLY_STOPPING_PATIENCE = 3
N_EPOCHS = 5000
DROPOUT = 0.0
SAVE_RESULT = False
SEED = 0

# --- network / optimiser ---
N_UNITS_HIDDEN_NN = 1024
N_LAYERS_PER_BLOCKS = 4
LEARNING_RATE = 1e-4
BATCHSIZE = 128

# --- sweep grid ---
N_modal_list = [1, 2, 3, 4]
DIM_data_list = [5]
KL_div_list = [1, 2, 4, 6, 8, 10, 12, 14]

# Methods to evaluate. 'KLdivergence-energy' is currently disabled.
method_list = ['alphaDiv']

# Method-specific training options, merged over a copy of params_training_base.
# NOTE(review): the misspelled key 'early_stoppping_partience' is kept verbatim;
# presumably the trainer looks it up under exactly this name.
train_params_dict = {
    'alphaDiv': dict(alpha=0.5,
                     early_stoppping_partience=EARLY_STOPPING_PATIENCE),
    'KLdivergence-energy': dict(
        early_stoppping_partience=EARLY_STOPPING_PATIENCE),
}

if 'TEST' in MODE:
    # Smoke-test mode: one replicate, the first two values of each sweep list,
    # and outputs under a separate 'TEST' subdirectory.
    N_TEST = 1
    PARENT_OUT_DIR = os.path.join(PARENT_OUT_DIR, 'TEST')
    N_modal_list, DIM_data_list, KL_div_list = (
        values[:2] for values in (N_modal_list, DIM_data_list, KL_div_list))

def create_param_string(para_dict):
    """Render a parameter dict as a compact, order-preserving label.

    Each entry becomes ``'<key>-<value>'`` and entries are joined with ``'_'``.
    Numeric values (including bools, which are ints) are formatted with six
    decimals, e.g. ``{'lr': 0.0001, 'bn': True}`` -> ``'lr-0.000100_bn-1.000000'``.
    Values that cannot be formatted with ``:.6f`` fall back to ``str(value)``.

    Args:
        para_dict: Mapping of parameter names to values.

    Returns:
        The label string; ``''`` for an empty mapping.
    """
    def _format_value(val):
        # Numbers get a fixed 6-decimal form; anything else is stringified
        # instead of raising from the float format spec.
        try:
            return f'{val:.6f}'
        except (TypeError, ValueError):
            return str(val)

    return '_'.join(f'{key}-{_format_value(val)}' for key, val in para_dict.items())

# Training hyper-parameters shared by every method. The experiment loop copies
# this dict and layers the method's entry from train_params_dict on top.
params_training_base = dict(
    hidden_dim=N_UNITS_HIDDEN_NN,
    learning_rate=LEARNING_RATE,
    n_layers_per_block=N_LAYERS_PER_BLOCKS,
    n_blocks=N_BLOCKS,
    dropout=DROPOUT,
    batch_size=BATCHSIZE,
    max_epochs=N_EPOCHS,
    # True only when more than one block is stacked (presumably toggles
    # batch normalisation in the trainer; key spelling kept as-is).
    DoBatchNormarize=N_BLOCKS > 1,
)

# %%
def _read_npz_arrays(file_path, *keys):
    """Read the named arrays from the ``.npz`` archive at ``file_path``.

    The archive is opened in a ``with`` block so its file handle is released
    before returning. Members of an ``.npz`` are read into memory on access,
    so the returned arrays stay valid after the archive is closed.

    Returns:
        A tuple with one array per entry in ``keys``, in the same order.
    """
    with np.load(file_path) as npz:
        return tuple(npz[key] for key in keys)


# W&B writes its local files under log/wandb/<project>. The path does not
# depend on the method, so create it once.
wandb_logdir = os.path.join(WORK_DIR, 'log', 'wandb', WANDB_PROJECT_NAME)
os.makedirs(wandb_logdir, exist_ok=True)

for method in method_list:
    ##### Run experiment #####
    # One W&B logger per method, shared by every run of that method.
    wandb_logger = WandbLogger(project=WANDB_PROJECT_NAME,
                               save_dir=wandb_logdir)

    # Training parameters depend only on the method, so build them (and their
    # label) once instead of once per data setting.
    method_params = params_training_base.copy()
    method_params.update(train_params_dict[method])
    params_str = create_param_string(method_params)

    for n_mdl, dim, kl in itertools.product(N_modal_list, DIM_data_list, KL_div_list):
        # Single-mode datasets are skipped in this experiment.
        if n_mdl == 1:
            continue
        data_type = f'dim-{dim}_kl-{kl}_nmdl-{n_mdl}'
        out_file_path = os.path.join(PARENT_OUT_DIR, f'{method}_{data_type}.csv')
        # Fresh copy per setting, as before, so anything the trainer might do to
        # the dict cannot carry over into the next setting.
        params_training = method_params.copy()
        out_results_dir = os.path.join(PARENT_OUT_DIR, method, data_type)
        os.makedirs(out_results_dir, exist_ok=True)

        result_df_list = []
        for data_id in range(N_TEST):
            data_id_str = f'{data_id:04d}'
            out_modeling_log_dir = os.path.join(out_results_dir, data_id_str)
            os.makedirs(out_modeling_log_dir, exist_ok=True)

            data_dir_path = os.path.join(PARENT_DATA_DIR, data_type, data_id_str)
            # --- read data (each archive is closed right after reading) ---
            train_denominator_data_np, train_numerator_data_np = _read_npz_arrays(
                os.path.join(data_dir_path, 'train.npz'), 'de', 'nu')
            eval_denominator_data_np, eval_numerator_data_np = _read_npz_arrays(
                os.path.join(data_dir_path, 'eval.npz'), 'de', 'nu')
            (test_denominator_data_np, test_numerator_data_np,
             test_true_density_rate_np, test_kl_div_np) = _read_npz_arrays(
                os.path.join(data_dir_path, 'test.npz'),
                'de', 'nu', 'true_density_rate', 'KL_div')
            true_KL = np.float32(test_kl_div_np)

            # --- run experiment ---
            target_estimated_div, res_dict = dre_train_for_all_epoch(
                train_denominator_data_np,
                train_numerator_data_np,
                eval_denominator_data_np,
                eval_numerator_data_np,
                test_denominator_data_np,
                test_numerator_data_np,
                params_training,
                method=method,
                logger=wandb_logger,
                out_results_dir=out_modeling_log_dir,
                do_save_result=SAVE_RESULT,
                seed=SEED,
                device_id=CUDA_DEVICE_ID,
                true_rate_for_test=test_true_density_rate_np)

            row_index = pd.MultiIndex.from_tuples(
                [(method, data_type, data_id_str, params_str)])

            run_df = pd.DataFrame(
                index=row_index,
                data={
                    'estimated_target_divergence': [res_dict['estimated_target_divergence']],
                    'RMSE': [res_dict['rmse']],
                    'L1': [res_dict['L1']],
                    'L3': [res_dict['L3']],
                    'bias': [res_dict['bias']],
                    'estimated_KL_divergence': [res_dict['estimated_KL_divergence']],
                    'true_KL_divergence': [true_KL],
                    'n_dimension': [dim],
                    'n_mode': [n_mdl],
                })
            print(
                data_type, "/",
                data_id_str)
            print(run_df)
            result_df_list.append(run_df)

        if not result_df_list:
            # No runs for this setting (e.g. N_TEST == 0). pd.concat raises on
            # an empty list, so there is nothing to write.
            continue
        result_df = pd.concat(result_df_list)
        result_df.to_csv(out_file_path)


