# %%
import os
import re

import numpy as np
import pandas as pd
import torch

from lib.dre.train.with_monitoring import *

# This script requires a CUDA GPU. Use an explicit raise rather than `assert`
# so the check still runs when Python is started with -O (asserts stripped).
if not torch.cuda.is_available():
    raise RuntimeError('CUDA is not available; this experiment script requires a GPU.')

# Run-mode flags. Append 'TEST' (uncomment the line below) for a quick smoke
# run that uses a single data set and a separate 'TEST' output directory.
MODE = []
#MODE += ['TEST']

# The experiment name is derived from this script's file name, which must look
# like '<prefix>_run_experiment_<name>.py'. <name> selects both the output and
# the input-data sub-directory.
python_file_name = os.path.basename(__file__)
_script_name_match = re.fullmatch(r'.+_run_experiment_(.+)\.py', python_file_name)
if _script_name_match is None:
    raise ValueError(
        f'cannot derive the experiment name from script file name '
        f'{python_file_name!r}; expected "<prefix>_run_experiment_<name>.py"')
out_dirname = _script_name_match.group(1)

WORK_DIR = './'
# Results are written under out/<name>/; input data is read from data/<name>/.
PARENT_OUT_DIR = os.path.join(WORK_DIR, 'out/', out_dirname)
PARENT_DATA_DIR = os.path.join(WORK_DIR, 'data/', out_dirname)

# --- model / optimisation settings (forwarded in params_training) ---
N_BLOCKS = 1                # 'DoBatchNormarize' is set only when this is > 1
N_UNITS_HIDDEN_NN = 100     # forwarded as 'hidden_dim'
N_LAYERS_PER_BLOCKS = 2     # forwarded as 'n_layers_per_block'
BATCHSIZE = 2500            # must be a key of batchsize_to_train_def_dict
DROPOUT = 0.0
SAVE_RESULT = False         # forwarded as do_save_result to the trainer

# --- data-set selection (only used to build the data directory name) ---
DATA_DIM = 5
dist_centers_nu_and_de = 5

SEED = 0
CUDA_DEVICE_ID = 0
N_TEST = 100                # number of data sets (0000, 0001, ...) run per method

# Methods passed to dre_train_for_all_epoch, and the method-specific
# hyperparameters merged into the shared training parameters for each one.
method_list = ['alphaDiv', 'alphaDiv-biased', 'nnBD-LSIF']
max_rate = 10
train_params_dict = {
    'alphaDiv': {'alpha': 0.5},
    'alphaDiv-biased': {'alpha': 0.5},
    # C is the reciprocal of max_rate.
    'nnBD-LSIF': {'C': 1 / max_rate},
}

# Optimisation schedule keyed by mini-batch size. Each row is
#   (batch size, learning rate, steps between monitoring, total epochs).
batchsize_to_train_def_dict = {
    batch_size: {
        'learning_rate': lr,
        'steps_per_monitor': monitor_every,
        'all_epochs': epochs,
    }
    for batch_size, lr, monitor_every, epochs in (
        (500, 0.00001, 10, 250),
        (1000, 0.00005, 5, 500),
        (2500, 0.00005, 2, 1000),
        (5000, 0.0001, 2, 1500),
        (10000, 0.0001, 1, 2000),
    )
}

# Smoke-test overrides: run one data set per method and keep the output in a
# separate 'TEST' sub-directory so it never mixes with real results.
if 'TEST' in MODE:
    PARENT_OUT_DIR = os.path.join(PARENT_OUT_DIR, 'TEST')
    N_TEST = 1
    method_list = ['alphaDiv', 'alphaDiv-biased', 'nnBD-LSIF']
    # NOTE(review): data_typ_list is not read anywhere in this file; presumably
    # kept for interactive use or parity with sibling scripts — confirm.
    data_typ_list = ['GenDataTwoShiftedGausses']

def _load_npz_arrays(file_path, keys):
    """Read the arrays named in ``keys`` from the ``.npz`` archive at ``file_path``.

    Returns a dict mapping each key to its array. The archive is opened in a
    ``with`` block so its file handle is released immediately, instead of
    lingering until the ``NpzFile`` object is garbage-collected. Members are
    loaded lazily, so every requested key is read before the archive closes.
    """
    with np.load(file_path) as npz:
        return {key: npz[key] for key in keys}


# The settings below depend only on BATCHSIZE and the global configuration,
# not on the method, so they are resolved once instead of on every iteration.
_train_def = batchsize_to_train_def_dict[BATCHSIZE]
LEARNING_RATE = _train_def['learning_rate']
N_EPOCHS = _train_def['all_epochs']
steps_per_monitor = _train_def['steps_per_monitor']
params_training_base = {
    'hidden_dim': N_UNITS_HIDDEN_NN,
    'learning_rate': LEARNING_RATE,
    'n_layers_per_block': N_LAYERS_PER_BLOCKS,
    'n_blocks': N_BLOCKS,
    'dropout': DROPOUT,
    'batch_size': BATCHSIZE,
    'max_epochs': N_EPOCHS,
    'DoBatchNormarize': N_BLOCKS > 1}
# Model/optimiser tag used in directory and file names
# (e.g. 'NN100x2x1_bs2500_lr000050000_dropout000').
sub_outdir = f'NN{N_UNITS_HIDDEN_NN}x{N_LAYERS_PER_BLOCKS}x{N_BLOCKS}_bs{BATCHSIZE:03}_lr{LEARNING_RATE*10**9:09.0f}_dropout{DROPOUT*100:03.0f}'
data_dirname = f'Dim_{DATA_DIM:03}_dist{dist_centers_nu_and_de}'

for method in method_list:
    # Shared settings + monitoring interval + this method's own hyperparameters.
    params_training = params_training_base.copy()
    params_training['n_steps_per_monitoring'] = steps_per_monitor
    train_params = train_params_dict[method]
    params_training.update(train_params)

    # In TEST mode the model tag is also folded into the method directory name.
    if 'TEST' in MODE:
        out_method_dir = method + '_' + sub_outdir
    else:
        out_method_dir = method
    out_results_dir = os.path.join(
        PARENT_OUT_DIR, out_method_dir, data_dirname, sub_outdir)
    os.makedirs(out_results_dir, exist_ok=True)
    out_file_name = '_'.join([out_method_dir, data_dirname, sub_outdir])

    all_res_dfs_list = []
    for DATA_ID in range(N_TEST):
        data_filename_base = f'{DATA_ID:04d}'
        out_modeling_log_dir = os.path.join(out_results_dir, data_filename_base)
        os.makedirs(out_modeling_log_dir, exist_ok=True)

        # --- read data ---
        data_dir_path = os.path.join(
            PARENT_DATA_DIR, data_dirname, data_filename_base)
        train_data = _load_npz_arrays(
            os.path.join(data_dir_path, 'train.npz'), ('de', 'nu'))
        test_data = _load_npz_arrays(
            os.path.join(data_dir_path, 'test.npz'),
            ('de', 'nu', 'true_density_rate'))
        train_denominator_data_np = train_data['de']
        train_numerator_data_np = train_data['nu']
        test_denominator_data_np = test_data['de']
        test_numerator_data_np = test_data['nu']
        test_true_density_rate_np = test_data['true_density_rate']

        # --- run experiment ---
        res_dict = dre_train_for_all_epoch(
            train_denominator_data_np,
            train_numerator_data_np,
            test_denominator_data_np,
            test_numerator_data_np,
            params_training,
            method=method,
            out_results_dir=out_modeling_log_dir,
            do_save_result=SAVE_RESULT,
            seed=SEED,
            device_id=CUDA_DEVICE_ID,
            true_rate_for_test=test_true_density_rate_np)

        # Each top-level key of res_dict becomes a row. The columns get the
        # data-set id as an extra top level, so the per-data-set frames can
        # be placed side by side without column-name clashes.
        res_df = pd.DataFrame.from_dict(res_dict, orient='index')
        res_df.columns = pd.MultiIndex.from_product(
            [[data_filename_base], res_df.columns])
        all_res_dfs_list.append(res_df)

    # One CSV per method: all data sets side by side, with the row labels moved
    # into a column and run metadata attached.
    all_result_df = pd.concat(all_res_dfs_list, axis=1)
    all_result_df.reset_index(drop=False, inplace=True)
    all_result_df['method'] = method
    all_result_df['batchsize'] = BATCHSIZE
    all_result_df['model_info'] = sub_outdir
    all_result_df['data_info'] = data_dirname
    csv_file_name = out_file_name + '.csv'
    all_result_df.to_csv(
        os.path.join(PARENT_OUT_DIR, csv_file_name), index=False)
