# %%
import os, sys
import re
import itertools

import pickle

import numpy as np
import pandas as pd
import torch
from lightning.pytorch.loggers import WandbLogger

from lib.dre.train.linear import *
from lib.concat_results import concat_and_out

# Fail fast: every run below is configured for a CUDA device (see CUDA_DEVICE_ID).
# An explicit check is used instead of ``assert`` so it is not stripped by ``python -O``.
if not torch.cuda.is_available():
    raise RuntimeError('CUDA is not available; this experiment script requires a GPU.')

# Derive the experiment name from this script's own file name:
#   ``<prefix>_run_experiment_<name>[-<suffix>].py``  ->  ``<name>``
# It names both the output directory and the input data directory below.
python_file_name = os.path.basename(__file__)
_file_name_match = re.search(r'.+_run_experiment_(.+?)(|-.*)\.py$', python_file_name)
if _file_name_match is None:
    raise ValueError(
        f'Cannot derive the experiment name from {python_file_name!r}; expected a '
        'file name like "<prefix>_run_experiment_<name>[-<suffix>].py".')
out_dirname = _file_name_match.group(1)

# Weights & Biases project name. If left as None, the experiment name derived
# from the script's file name is used instead.
# e.g. WANDB_PROJECT_NAME = 'project_X'
WANDB_PROJECT_NAME = None

# CUDA device to use.
CUDA_DEVICE_ID = 0

# Run modes. Uncomment the line below to run the reduced smoke-test
# configuration defined at the end of this section.
MODE = []
#MODE += ['TEST']

# Number of data sets (data IDs 0000 .. N_TEST-1) evaluated per data type.
N_TEST = 100
WORK_DIR = './'
PARENT_OUT_DIR = os.path.join(WORK_DIR, 'out', out_dirname)
PARENT_DATA_DIR = os.path.join(WORK_DIR, 'data', out_dirname)
if WANDB_PROJECT_NAME is None:
    WANDB_PROJECT_NAME = out_dirname


# --- model / training hyper-parameters ---
N_BLOCKS = 1
EARLY_STOPPING_PATIENCE = 1
N_EPOCHS = 5000
DROPOUT = 0.0
SAVE_RESULT = False
SEED = 0

N_UNITS_HIDDEN_NN = 1024
N_LAYERS_PER_BLOCKS = 4
LEARNING_RATE = 0.0001
BATCHSIZE = 128

# --- data grid: every (n_modal, dim, kl) combination is one data type ---
N_modal_list = [1, 2, 3, 4]
DIM_data_list = [50, 100, 200]
KL_div_list = [3]


# Training-set sizes to evaluate (each is a prefix of the stored training split).
n_data_list = [1000, 2000, 4000, 8000, 16000]

# Methods to run. Each name is passed as ``method=`` to the training routine
# and selects its entry in ``train_params_dict``; 'KLdivergence-energy' also
# has parameters defined there and can be appended to this list.
method_list = ['alphaDiv']

# Method-specific training parameters, merged on top of the shared base
# hyper-parameters.
# NOTE(review): the key spelling 'early_stoppping_partience' is kept as-is;
# presumably it must match what lib.dre.train.linear reads — confirm before renaming.
train_params_dict = {
    'alphaDiv': {'alpha': 0.5,
                 'early_stoppping_partience': EARLY_STOPPING_PATIENCE},
    'KLdivergence-energy': {
        'early_stoppping_partience': EARLY_STOPPING_PATIENCE}}

# Reduced smoke-test configuration: one data set, a smaller grid, and a separate
# 'TEST' output sub-directory so full results are not overwritten.
if 'TEST' in MODE:
    N_TEST = 1
    PARENT_OUT_DIR = os.path.join(PARENT_OUT_DIR, 'TEST')
    N_modal_list = N_modal_list[:2]
    DIM_data_list = DIM_data_list[:2]
    KL_div_list = KL_div_list[:2]
    n_data_list = n_data_list[:2]

def create_param_string(para_dict):
    """Serialize a parameter mapping into a compact ``key-value`` string.

    Entries are joined with ``_`` in the mapping's iteration order. Numeric
    values (including ``bool``) are formatted with six decimal places, e.g.
    ``{'lr': 0.0001, 'n_blocks': 1}`` -> ``'lr-0.000100_n_blocks-1.000000'``.
    Values that do not accept a float format spec (e.g. strings, ``None``)
    fall back to ``str(value)`` instead of raising.

    Args:
        para_dict: Mapping of parameter names to values.

    Returns:
        The serialized string; ``''`` for an empty mapping.
    """
    def _format_value(val):
        try:
            return f'{val:.6f}'
        except (TypeError, ValueError):
            # Non-numeric values cannot take the ``.6f`` spec; keep them readable.
            return str(val)

    return '_'.join(f'{key}-{_format_value(val)}' for key, val in para_dict.items())

# Hyper-parameters shared by every method. Batch normalization is switched on
# only when the network has more than one block.
# NOTE(review): the key 'DoBatchNormarize' keeps its original spelling;
# presumably the training code reads it under that exact name — confirm.
params_training_base = dict(
    hidden_dim=N_UNITS_HIDDEN_NN,
    learning_rate=LEARNING_RATE,
    n_layers_per_block=N_LAYERS_PER_BLOCKS,
    n_blocks=N_BLOCKS,
    dropout=DROPOUT,
    batch_size=BATCHSIZE,
    max_epochs=N_EPOCHS,
    DoBatchNormarize=(N_BLOCKS > 1),
)

# %%
def _load_npz_arrays(file_path, keys):
    """Read the arrays stored under ``keys`` from the ``.npz`` archive at ``file_path``.

    ``np.load`` on an ``.npz`` archive returns an ``NpzFile`` that holds the file
    open until it is closed; using it as a context manager releases the handle as
    soon as the requested arrays have been read into memory.

    Returns:
        A tuple of arrays in the same order as ``keys``.
    """
    with np.load(file_path) as npz:
        return tuple(npz[key] for key in keys)


for method in method_list:
    ##### Run Experiment #####
    wandb_logdir = os.path.join(WORK_DIR, 'log', 'wandb', WANDB_PROJECT_NAME)
    os.makedirs(wandb_logdir, exist_ok=True)
    # A single W&B logger is shared by every run of this method.
    wandb_logger = WandbLogger(project=WANDB_PROJECT_NAME,
                               save_dir=wandb_logdir)
    for n_mdl, dim, kl in itertools.product(N_modal_list, DIM_data_list, KL_div_list):
        # NOTE(review): this combination is skipped without explanation —
        # presumably no data set exists for it; confirm.
        if dim == 100 and n_mdl == 1:
            continue
        data_type = f'dim-{dim}_kl-{kl}_nmdl-{n_mdl}'
        # One CSV per (method, data type), collecting every data ID and training size.
        out_file_path = os.path.join(PARENT_OUT_DIR, '_'.join([method, data_type]) + '.csv')
        result_df_list = []
        # Fresh copy per data type: shared base hyper-parameters + method-specific ones.
        params_training = params_training_base.copy()
        params_training.update(train_params_dict[method])
        params_str = create_param_string(params_training)
        out_results_dir = os.path.join(PARENT_OUT_DIR, method, data_type)
        os.makedirs(out_results_dir, exist_ok=True)

        for data_id in range(N_TEST):
            data_id_str = f'{data_id:04d}'
            out_modeling_log_dir = os.path.join(out_results_dir, data_id_str)
            os.makedirs(out_modeling_log_dir, exist_ok=True)

            # --- read data (each archive is closed right after reading) ---
            data_dir_path = os.path.join(PARENT_DATA_DIR, data_type, data_id_str)
            train_denominator_data_np, train_numerator_data_np = _load_npz_arrays(
                os.path.join(data_dir_path, 'train.npz'), ('de', 'nu'))
            eval_denominator_data_np, eval_numerator_data_np = _load_npz_arrays(
                os.path.join(data_dir_path, 'eval.npz'), ('de', 'nu'))
            (test_denominator_data_np, test_numerator_data_np,
             test_true_density_rate_np, test_kl_np) = _load_npz_arrays(
                os.path.join(data_dir_path, 'test.npz'),
                ('de', 'nu', 'true_density_rate', 'KL_div'))
            true_KL = np.float32(test_kl_np)
            # Slicing past the end silently yields fewer rows, so track what exists.
            n_available_train = min(len(train_denominator_data_np),
                                    len(train_numerator_data_np))

            # --- run experiment for each training-set size ---
            for n_train_data in n_data_list:
                if n_train_data > n_available_train:
                    print(f'WARNING: {n_train_data} training samples requested but only '
                          f'{n_available_train} available in {data_dir_path}; '
                          'the row below is still labelled with the requested size.')
                _, res_dict = dre_train_for_all_epoch(
                    train_denominator_data_np[0:n_train_data, :],
                    train_numerator_data_np[0:n_train_data, :],
                    eval_denominator_data_np,
                    eval_numerator_data_np,
                    test_denominator_data_np,
                    test_numerator_data_np,
                    params_training,
                    method=method,
                    logger=wandb_logger,
                    out_results_dir=out_modeling_log_dir,
                    do_save_result=SAVE_RESULT,
                    seed=SEED,
                    device_id=CUDA_DEVICE_ID,
                    true_rate_for_test=test_true_density_rate_np)

                row_index = pd.MultiIndex.from_tuples(
                    [(method, data_type, data_id_str,
                      n_train_data, params_str)])

                row_df = pd.DataFrame(
                    index=row_index,
                    data={
                        'estimated_target_divergence': [res_dict['estimated_target_divergence']],
                        'RMSE': [res_dict['rmse']],
                        'L1': [res_dict['L1']],
                        'L3': [res_dict['L3']],
                        'bias': [res_dict['bias']],
                        'estimated_KL_divergence': [res_dict['estimated_KL_divergence']],
                        'true_KL_divergence': [true_KL],
                        'n_dimension': [dim],
                        'n_mode': [n_mdl],
                        })
                print(
                    data_type, "/",
                    data_id_str,
                    '/ n train data = ', n_train_data)
                print(row_df)
                result_df_list.append(row_df)

        if not result_df_list:
            # pd.concat raises on an empty list (e.g. N_TEST == 0 or empty n_data_list).
            print(f'No results for {method} / {data_type}; {out_file_path} not written.')
            continue
        result_df = pd.concat(result_df_list)
        result_df.to_csv(out_file_path)

# %%
