# %%
import os, sys
import re
import pickle

import numpy as np
import pandas as pd
import torch
from lightning.pytorch.loggers import WandbLogger

from lib.dre.train.linear import *
from lib.concat_results import concat_and_out

# Training runs on a CUDA device; fail fast with a clear message otherwise.
# (An `assert` would be silently skipped when Python runs with -O.)
if not torch.cuda.is_available():
    raise RuntimeError('CUDA is not available; this experiment requires a GPU.')

# The experiment name is parsed from this script's file name, which is
# expected to look like `<prefix>_run_experiment_<name>[-<suffix>].py`.
# `<name>` selects the sub-directories under ./out/ and ./data/.
python_file_name = os.path.basename(__file__)
_file_name_match = re.search(
    r'.+_run_experiment_(.+?)(?:|-.*)\.py', python_file_name)
if _file_name_match is None:
    raise ValueError(
        f'Cannot derive the experiment name from {python_file_name!r}; '
        "expected '<prefix>_run_experiment_<name>[-<suffix>].py'.")
out_dirname = _file_name_match.group(1)

# Weights & Biases project name. Leave it as None to fall back to the
# experiment name parsed from this script's file name, or set it
# explicitly, for example:
#     WANDB_PROJECT_NAME = 'project_X'
WANDB_PROJECT_NAME = None

# Index of the CUDA device handed to the trainer (`device_id`).
CUDA_DEVICE_ID = 0


### Experimental settings (please do not change anything below) ###

# Run modes. Append 'TEST' (e.g. `MODE += ['TEST']`) for a quick smoke run
# that uses only 2 data sets and a single denominator sigma (0.9), writing
# results under an extra `TEST/<hyper-parameter tag>` output sub-directory.
MODE = []

# Directory layout: data is read from ./data/<experiment>/ and results are
# written to ./out/<experiment>/, where <experiment> is `out_dirname`.
WORK_DIR = './'
PARENT_OUT_DIR = os.path.join(WORK_DIR, 'out', out_dirname)
PARENT_DATA_DIR = os.path.join(WORK_DIR, 'data', out_dirname)
WANDB_PROJECT_NAME = (
    out_dirname if WANDB_PROJECT_NAME is None else WANDB_PROJECT_NAME)

# Number of data sets (sub-directories 0000, 0001, ...) per configuration.
N_TEST = 100

# Network architecture.
N_UNITS_HIDDEN_NN = 256
N_LAYERS_PER_BLOCKS = 2
N_BLOCKS = 1
DROPOUT = 0.0

# Optimisation / bookkeeping.
N_EPOCHS = 5000
BATCHSIZE = 256
LEARNING_RATE = 0.00005
EARLY_STOPPING_PATIENCE = 32
SEED = 0
SAVE_RESULT = False

# Denominator sigmas to sweep; each value selects the data directory
# `denominator_sigma_<sigma*1000, zero-padded to 4 digits>`.
denominator_sigma_list = [1.0, 1.1, 1.2, 1.4, 1.6, 2.0, 2.5, 3.0]

# Methods to run. 'KLdivergence-energy' is configured in train_params_dict
# below but is currently not in the list.
method_list = ['alphaDiv']

# Per-method training parameters merged on top of the shared base params.
# NOTE: the key spelling is kept exactly as the training code expects it.
train_params_dict = {
    'alphaDiv': dict(
        alpha=0.5,
        early_stoppping_partience=EARLY_STOPPING_PATIENCE,
    ),
    'KLdivergence-energy': dict(
        early_stoppping_partience=EARLY_STOPPING_PATIENCE,
    ),
}

if 'TEST' in MODE:
    N_TEST = 2
    # NOTE(review): assumes data for sigma 0.9 exists under PARENT_DATA_DIR.
    denominator_sigma_list = [0.9]
    sub_outdir = (
        f'NN{N_UNITS_HIDDEN_NN}x{N_LAYERS_PER_BLOCKS}x{N_BLOCKS}'
        f'_bs{BATCHSIZE:03}'
        f'_lr{LEARNING_RATE*10**9:09.0f}'
        f'_dropout{DROPOUT*100:03.0f}'
        f'_er{EARLY_STOPPING_PATIENCE:02}'
    )
    PARENT_OUT_DIR = os.path.join(PARENT_OUT_DIR, 'TEST', sub_outdir)


##### Run Experiment #####
# NOTE(review): a single WandbLogger instance is reused for every
# (method, sigma, data set) run; whether that is meant to log all runs into
# one W&B run depends on dre_train_for_all_epoch — confirm.
wandb_logdir = os.path.join(WORK_DIR, 'log', 'wandb', WANDB_PROJECT_NAME)
os.makedirs(wandb_logdir, exist_ok=True)
wandb_logger = WandbLogger(project=WANDB_PROJECT_NAME,
                           save_dir=wandb_logdir)

# Hyper-parameters shared by every method; the per-method entries of
# `train_params_dict` are merged on top of these.
params_training_base = {
    'hidden_dim': N_UNITS_HIDDEN_NN,
    'learning_rate': LEARNING_RATE,
    'n_layers_per_block': N_LAYERS_PER_BLOCKS,
    'n_blocks': N_BLOCKS,
    'dropout': DROPOUT,
    'batch_size': BATCHSIZE,
    'max_epochs': N_EPOCHS,
    'DoBatchNormarize': N_BLOCKS > 1}


def _load_npz_arrays(file_path, *keys):
    """Return the arrays stored under ``keys`` in an ``.npz`` file, in order.

    ``np.load`` on an ``.npz`` archive returns an ``NpzFile`` that keeps the
    file open. Reading inside ``with`` materialises each requested array and
    then closes the handle, instead of waiting for garbage collection.
    """
    with np.load(file_path) as archive:
        return tuple(archive[key] for key in keys)


for method in method_list:
    # Per-method params override the shared base (same as copy() + update()).
    params_training = {**params_training_base, **train_params_dict[method]}
    for denominator_sigma in denominator_sigma_list:
        data_name = f'denominator_sigma_{denominator_sigma*1000:04.0f}'
        out_results_dir = os.path.join(PARENT_OUT_DIR, method, data_name)
        os.makedirs(out_results_dir, exist_ok=True)
        out_file_name = f'{method}_{data_name}'

        for data_id in range(N_TEST):
            test_id_str = f'{data_id:04d}'
            out_modeling_log_dir = os.path.join(out_results_dir, test_id_str)
            os.makedirs(out_modeling_log_dir, exist_ok=True)

            data_dir_path = os.path.join(
                PARENT_DATA_DIR, data_name, test_id_str)

            # --- read data ---
            # 'de' / 'nu' are the denominator / numerator samples of a split.
            (train_denominator_data_np,
             train_numerator_data_np) = _load_npz_arrays(
                os.path.join(data_dir_path, 'train.npz'), 'de', 'nu')
            (eval_denominator_data_np,
             eval_numerator_data_np) = _load_npz_arrays(
                os.path.join(data_dir_path, 'eval.npz'), 'de', 'nu')
            (test_denominator_data_np,
             test_numerator_data_np,
             test_true_density_rate_np,
             true_KL_np) = _load_npz_arrays(
                os.path.join(data_dir_path, 'test.npz'),
                'de', 'nu', 'true_density_rate', 'KL_div')
            true_KL = np.float32(true_KL_np)

            # --- run experiment ---
            target_estimated_div, res_dict = dre_train_for_all_epoch(
                train_denominator_data_np,
                train_numerator_data_np,
                eval_denominator_data_np,
                eval_numerator_data_np,
                test_denominator_data_np,
                test_numerator_data_np,
                params_training,
                method=method,
                logger=wandb_logger,
                out_results_dir=out_modeling_log_dir,
                do_save_result=SAVE_RESULT,
                seed=SEED,
                device_id=CUDA_DEVICE_ID,
                true_rate_for_test=test_true_density_rate_np)

            # One-row result table indexed by (method, data_name, data id).
            row_index = pd.MultiIndex.from_tuples(
                [(method, data_name, test_id_str)])
            result_df = pd.DataFrame(
                index=row_index,
                data={
                    'estimated_target_divergence': [res_dict['estimated_target_divergence']],
                    'RMSE': [res_dict['rmse']],
                    'L1': [res_dict['L1']],
                    'bias': [res_dict['bias']],
                    'estimated_KL_divergence': [res_dict['estimated_KL_divergence']],
                    'true_KL_divergence': [true_KL]})

            print(data_name, "/", test_id_str)
            print(result_df)
            result_path = os.path.join(
                out_results_dir, f'{test_id_str}.pickle')
            with open(result_path, 'wb') as fp:
                pickle.dump(result_df, fp)

        # Combine the per-data-set result pickles of this (method, sigma)
        # into PARENT_OUT_DIR/<out_file_name> (see lib.concat_results).
        concat_and_out(
            1,  # max level to concat
            out_results_dir,
            PARENT_OUT_DIR,
            out_file_name,
        )
 
