import os, sys
import argparse
import pickle

import numpy as np
import pandas as pd
import torch

from lib.DensityRateEstimation import *
from lib.ConcatResults import concat_and_out

# Fail fast when no GPU is present: the default tensor type set below is a
# CUDA type, so nothing in this script can run on CPU. An explicit raise is
# used instead of `assert` because `python -O` strips assert statements.
if not torch.cuda.is_available():
    raise RuntimeError(
        'CUDA is required to run this experiment, but '
        'torch.cuda.is_available() returned False.')
device = torch.device('cuda')
# Newly created float tensors default to CUDA float32.
# NOTE(review): set_default_tensor_type is deprecated in recent PyTorch
# releases in favour of set_default_dtype / set_default_device; kept as-is
# for compatibility with whatever torch version this project pins.
torch.set_default_tensor_type('torch.cuda.FloatTensor')

# --- Paths -------------------------------------------------------------------
WORK_DIR = './'
PARENT_DATA_DIR = os.path.join(WORK_DIR, 'data/DRE')  # input datasets root
PARENT_OUT_DIR = os.path.join(WORK_DIR, 'out/Expr_Section_D_1')  # results root

# --- Model / training hyper-parameters ----------------------------------------
# These are forwarded to dre_train_for_all_epoch by the experiment loop.
ALPHA = 0.5
N_BLOCKS = 1
N_LAYERS_PER_BLOCKS = 1
N_UNITS_HIDDEN_NN = 100
DROPOUT = 0.0
LEARNING_RATE = 1e-4
BATCHSIZE = 128
# NOTE(review): meaning of a patience of 0 (disabled vs. stop immediately)
# is defined by the training library; confirm there.
EARLY_STOPPING_PATIENCE = 0
SEED = 0
SAVE_RESULT = False

# --- Experiment grid -----------------------------------------------------------
N_TEST = 100  # number of independent datasets per dimension
# Number of training epochs used for each data dimension.
DIM_to_N_EPOCHS_DICT = dict(zip(
    (10, 20, 30, 50, 100),
    (40, 50, 50, 50, 60),
))
# Dimensions to sweep, in the same order as the epoch table above.
DATA_DIM_LIST = list(DIM_to_N_EPOCHS_DICT)

def _load_split(split_dir_path):
    """Load the denominator and numerator samples of one data split.

    Args:
        split_dir_path: Directory containing ``de.csv`` (denominator samples)
            and ``nu.csv`` (numerator samples), both comma-delimited.

    Returns:
        Tuple ``(denominator_np, numerator_np)`` of arrays from ``np.loadtxt``.
    """
    denominator_np = np.loadtxt(
        os.path.join(split_dir_path, 'de.csv'), delimiter=',')
    numerator_np = np.loadtxt(
        os.path.join(split_dir_path, 'nu.csv'), delimiter=',')
    return denominator_np, numerator_np


# Experiment sweep: for each data dimension, run dre_train_for_all_epoch on
# each of the N_TEST datasets, write one single-row (KL_divergence, MSE)
# DataFrame pickle per dataset, then pass the per-dimension output directory
# to concat_and_out.
for _data_dim in DATA_DIM_LIST:
    DATA_DIM = _data_dim
    N_EPOCHS = DIM_to_N_EPOCHS_DICT[DATA_DIM]
    data_dirname = f'Dim_{DATA_DIM:03d}'
    out_results_dir = os.path.join(
        PARENT_OUT_DIR, f'AlphaDiv_Dim_{DATA_DIM:03d}')
    os.makedirs(out_results_dir, exist_ok=True)

    for _test_id in range(N_TEST):
        DATA_ID = _test_id
        data_filename_base = f'{DATA_ID:04d}'
        # Per-dataset directory handed to the trainer for its own logs.
        out_modeling_log_dir = os.path.join(
            out_results_dir, data_filename_base)
        os.makedirs(out_modeling_log_dir, exist_ok=True)

        # --- read data ---
        # Layout: <PARENT_DATA_DIR>/Dim_XXX/NNNN/{train,eval,test}/
        data_case_dir = os.path.join(
            PARENT_DATA_DIR, data_dirname, data_filename_base)
        train_denominator_data_np, train_numerator_data_np = _load_split(
            os.path.join(data_case_dir, 'train'))
        eval_denominator_data_np, eval_numerator_data_np = _load_split(
            os.path.join(data_case_dir, 'eval'))
        test_data_dir_path = os.path.join(data_case_dir, 'test')
        test_denominator_data_np, test_numerator_data_np = _load_split(
            test_data_dir_path)
        # Ground-truth ratio values, only available for the test split.
        test_true_rate_np = np.loadtxt(
            os.path.join(test_data_dir_path, 'true_rate.csv'),
            delimiter=',')

        # --- run experiment ---
        # Rebuilt for every run so that a callee which mutates the dict
        # cannot leak state into the next dataset.
        params_training = {
            'alpha': ALPHA,
            'hidden_dim': N_UNITS_HIDDEN_NN,
            'n_layers_per_block': N_LAYERS_PER_BLOCKS,
            'n_blocks': N_BLOCKS,
            'dropout': DROPOUT,
            # Key spelling is what the training library reads; batch norm
            # is only requested for multi-block networks.
            'DoBatchNormarize': N_BLOCKS > 1}
        kl_div, mse = dre_train_for_all_epoch(
            train_denominator_data_np,
            train_numerator_data_np,
            eval_denominator_data_np,
            eval_numerator_data_np,
            test_denominator_data_np,
            test_numerator_data_np,
            params_training,
            LEARNING_RATE,
            BATCHSIZE,
            N_EPOCHS,
            EARLY_STOPPING_PATIENCE,
            out_modeling_log_dir,
            do_save_result=SAVE_RESULT,
            seed=SEED,
            true_rate_for_test=test_true_rate_np)

        # --- store the per-dataset result ---
        row_index = pd.MultiIndex.from_tuples(
            [(data_dirname, data_filename_base)])
        result_df = pd.DataFrame(
            index=row_index,
            data={
                'KL_divergence': [kl_div],
                'MSE': [mse]})
        result_path = os.path.join(
            out_results_dir, f'{data_filename_base}.pickle')
        with open(result_path, 'wb') as fp:
            pickle.dump(result_df, fp)

    # NOTE(review): presumably merges the per-dataset pickles written above
    # into a summary under PARENT_OUT_DIR; verify in lib.ConcatResults.
    concat_and_out(out_results_dir, PARENT_OUT_DIR)