# %%
import os, sys
import argparse
import pickle
import random

import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error
import torch
import torch.optim as optim 

from D3RE.train import train, loss_func, test 
from D3RE.model import NN, CNN

from lib.DensityRateEstimation import seed_everything
from lib.ConcatResults import concat_and_out

# This experiment is run on a CUDA GPU only. An explicit raise is used instead
# of `assert` so the check still fires when Python runs with `-O`.
if not torch.cuda.is_available():
    raise RuntimeError('This experiment requires a CUDA-capable GPU, '
                       'but torch.cuda.is_available() returned False.')
device = torch.device('cuda')


# --- locations (all relative to the current working directory) ---
WORK_DIR = './'
# Per-dimension result pickles and the concatenated summary go under here.
PARENT_OUT_DIR = os.path.join(WORK_DIR, 'out/Expr_Section_D_1')
# Expected layout: <PARENT_DATA_DIR>/Dim_XXX/NNNN/{train,test}/*.csv
PARENT_DATA_DIR = os.path.join(WORK_DIR, 'data/DRE')

# --- optimisation hyperparameters ---
BATCHSIZE = 128
LEARNING_RATE = 5e-5
N_EPOCHS = 250
# Passed to train() as `upper_bound` for the 'boundeduLSIF' method.
UPPER_BOUND_C = 2.0

# --- experiment grid ---
N_TEST = 100                          # datasets (trials) per dimension
DATA_DIM_LIST = [10, 20, 30, 50, 100]  # input dimensionalities to sweep

def _load_split(split_dir):
    """Load one split's numerator and denominator samples.

    Reads ``nu.csv`` and ``de.csv`` (comma-separated) from *split_dir* and
    returns them as the tuple ``(x_nu, x_de)`` of NumPy arrays.
    """
    x_nu = np.loadtxt(os.path.join(split_dir, 'nu.csv'), delimiter=',')
    x_de = np.loadtxt(os.path.join(split_dir, 'de.csv'), delimiter=',')
    return x_nu, x_de


def _stack_with_labels(x_nu, x_de):
    """Stack numerator and denominator samples with 1/0 labels.

    Numerator rows come first and are labelled 1.0; denominator rows follow
    and are labelled 0.0. Returns ``(x, t)``.
    """
    x = np.concatenate([x_nu, x_de], axis=0)
    t = np.concatenate([np.ones(len(x_nu)), np.zeros(len(x_de))], axis=0)
    return x, t


# For every input dimensionality and every dataset: train a bounded-uLSIF
# density-ratio model, score it by MSE against the stored true ratio on the
# test denominator samples, and pickle a one-row result DataFrame. After each
# dimension, the per-dataset pickles are merged by concat_and_out().
for DATA_DIM in DATA_DIM_LIST:
    data_dirname = f'Dim_{DATA_DIM:03d}'
    out_results_dir = os.path.join(PARENT_OUT_DIR, f'Kato_Dim_{DATA_DIM:03d}')
    os.makedirs(out_results_dir, exist_ok=True)

    for DATA_ID in range(N_TEST):
        data_filename_base = f'{DATA_ID:04d}'
        dataset_dir = os.path.join(
            PARENT_DATA_DIR, data_dirname, data_filename_base)

        # --- read data ---
        train_dir = os.path.join(dataset_dir, 'train')
        test_dir = os.path.join(dataset_dir, 'test')
        x_train_nu, x_train_de = _load_split(train_dir)
        x_test_nu, x_test_de = _load_split(test_dir)
        test_true_rate_np = np.loadtxt(
            os.path.join(test_dir, 'true_rate.csv'), delimiter=',')

        x_train, t_train = _stack_with_labels(x_train_nu, x_train_de)
        x_test, t_test = _stack_with_labels(x_test_nu, x_test_de)

        # --- run experiment ---
        # Seed before constructing the model so that weight init and training
        # are reproducible per dataset.
        seed_everything(DATA_ID)
        Net = NN  # CNN is also imported; swap it in here to change architecture.
        # Use the module-level `device`. Redefining it here with a CPU fallback
        # would contradict the CUDA requirement checked at the top of the file.
        model_boundedulsif = Net(DATA_DIM).to(device)
        optimizer_boundedulsif = optim.Adam(
            params=model_boundedulsif.parameters(), lr=LEARNING_RATE)
        # train() returns four values that this experiment does not use.
        # Unpacking them keeps the original arity check on the return value.
        _train_hist, _test_hist, _auc, _mean = train(
            x_train, t_train, x_test, t_test,
            N_EPOCHS,
            model_boundedulsif,
            optimizer_boundedulsif,
            device,
            batchsize=BATCHSIZE,
            method='boundeduLSIF',
            upper_bound=UPPER_BOUND_C)

        # --- evaluate on the test denominator samples ---
        # Switch to inference: eval() disables train-only layer behaviour (if
        # NN has any), and no_grad() avoids building an unused autograd graph.
        model_boundedulsif.eval()
        x_test_de_tsr = torch.from_numpy(
            x_test_de.astype(np.float32)).to(device)
        with torch.no_grad():
            pred_rate = model_boundedulsif(x_test_de_tsr)
        # Flatten both sides so an (N, 1) model output and an (N,) / scalar
        # true-rate array compare elementwise.
        estimated_rate_np = pred_rate.cpu().numpy().reshape(-1)
        true_rate_np = np.ravel(test_true_rate_np)
        # sklearn's signature is mean_squared_error(y_true, y_pred).
        mse = mean_squared_error(true_rate_np, estimated_rate_np)

        # --- save one-row result, indexed by (dimension dir, dataset id) ---
        row_index = pd.MultiIndex.from_tuples(
            [(data_dirname, data_filename_base)])
        result_df = pd.DataFrame(
            index=row_index,
            data={
                'KL_divergence': [np.nan],  # not computed in this experiment
                'MSE': [mse]})
        result_path = os.path.join(
            out_results_dir, f'{data_filename_base}.pickle')
        with open(result_path, 'wb') as fp:
            pickle.dump(result_df, fp)

    concat_and_out(out_results_dir, PARENT_OUT_DIR)
