import os
import random
import datetime

from typing import List, Any, Tuple, Dict

import numpy as np

from sklearn.metrics import mean_squared_error, mean_absolute_error

import torch
from torch import nn, Tensor
from torch.distributions import MultivariateNormal, Uniform
from torch.utils.data import DataLoader

import lightning as pl
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

from lib.utils import set_seed_everything
from lib.dre.common.estimate import calc_Lp_error
from lib.dre.common.estimate import gen_estimate_density_rate_func
from lib.dre.common.estimate import gen_estimate_target_divergence_func, gen_estimate_KL_divergence_func
from lib.dre.lightning.linear import LightningNaiveDREdenseNN




#torch.set_default_tensor_type('torch.cuda.FloatTensor')


def _to_float32_tensor(array_np: Any) -> Tensor:
    """Return ``array_np`` cast to float32 and wrapped as a CPU torch tensor."""
    return torch.from_numpy(array_np.astype(np.float32))


def _make_tensor_dataset(
        denominator_tsr: Tensor,
        numerator_tsr: Tensor,
        true_rate_np: Any = None) -> torch.utils.data.TensorDataset:
    """Pair denominator/numerator samples, appending the true rate if given."""
    if true_rate_np is None:
        return torch.utils.data.TensorDataset(denominator_tsr, numerator_tsr)
    return torch.utils.data.TensorDataset(
        denominator_tsr, numerator_tsr, _to_float32_tensor(true_rate_np))


def _make_dataloader(dataset, batch_size: int, shuffle: bool = False) -> DataLoader:
    """Wrap ``dataset`` in a DataLoader driven by an explicit CPU generator."""
    # An explicit CPU generator presumably keeps sampling on the CPU even if
    # the default tensor device is switched to CUDA elsewhere.
    return DataLoader(
        dataset,
        shuffle=shuffle,
        generator=torch.Generator(device='cpu'),
        batch_size=batch_size)


def _build_trainer(logger, max_epochs: int, early_stopping_patience: int,
                   device_id: int = None):
    """Create a single-GPU Lightning trainer; add early stopping if patience > 0."""
    callbacks = []
    if early_stopping_patience > 0:
        # Stop once the validation target divergence stops increasing.
        callbacks.append(EarlyStopping(
            monitor='validate_target_divergence',
            min_delta=0.00,
            patience=early_stopping_patience,
            verbose=False,
            mode='max'))
    # ``[None]`` is not a valid ``devices`` value for Lightning; when no index
    # is given, use one GPU, mirroring ``torch.device('cuda')`` used afterwards.
    devices = [device_id] if device_id is not None else 1
    return pl.Trainer(
        accelerator='gpu',
        devices=devices,
        strategy='auto',
        logger=logger,
        max_epochs=max_epochs,
        callbacks=callbacks,
        enable_progress_bar=False)


def _rate_accuracy_metrics(estimated_rate_np: np.ndarray,
                           true_rate_np: Any) -> Dict[str, Any]:
    """Return L1 / RMSE / L3 / bias of the estimated rate against the truth."""
    # ``mean_squared_error(..., squared=False)`` was removed from scikit-learn
    # (1.6). Averaging the per-output root-MSEs reproduces it exactly.
    rmse = np.mean(np.sqrt(mean_squared_error(
        true_rate_np, estimated_rate_np, multioutput='raw_values')))
    L1 = mean_absolute_error(true_rate_np, estimated_rate_np)
    L3 = calc_Lp_error(estimated_rate_np, true_rate_np, p=3.0)
    # Positive bias means the estimator overestimates on average.
    bias = np.mean(estimated_rate_np - true_rate_np)
    return {'L1': L1, 'rmse': rmse, 'L3': L3, 'bias': bias}


def _save_results(out_results_dir: str, out_model_dir: str, model_list,
                  test_estimated_rate_np: np.ndarray,
                  true_rate_for_test: Any = None) -> None:
    """Save test predictions, the model weights and (if known) test errors."""
    out_pred_dir = os.path.join(out_results_dir, 'results')
    os.makedirs(out_pred_dir, exist_ok=True)
    np.save(os.path.join(out_pred_dir, 'test_estimated'), test_estimated_rate_np)
    torch.save(
        model_list.state_dict(),
        os.path.join(out_model_dir, 'energy_model.mdl'))
    if true_rate_for_test is not None:
        # Previously the estimates were written here instead of the errors.
        test_error_np = true_rate_for_test - test_estimated_rate_np
        np.save(os.path.join(out_pred_dir, 'test_errors'), test_error_np)


def dre_train_for_all_epoch(
            train_denominator_data_np: Any,
            train_numerator_data_np: Any,
            eval_denominator_data_np: Any,
            eval_numerator_data_np: Any,
            test_denominator_data_np: Any,
            test_numerator_data_np: Any,
            params_training: Dict[str, Any],
            method: str,
            logger,
            out_results_dir: str,
            do_save_result: bool,
            seed: int,
            device_id: int = None,
            true_rate_for_test: Any = None,
            true_rate_for_eval: Any = None
            ) -> Tuple[Any, Dict[str, Any]]:
    """Train a ``LightningNaiveDREdenseNN`` and evaluate it on the test split.

    Args:
        train_/eval_/test_ ``*_denominator_data_np`` / ``*_numerator_data_np``:
            Sample arrays, cast to float32. Assumed 2-D (samples, features):
            ``shape[1]`` of the training denominator is used as the input
            dimension. TODO confirm the other splits share that layout.
        params_training: Hyper-parameters. Reads ``'batch_size'``,
            ``'max_epochs'`` and ``'early_stoppping_partience'``. NOTE: it
            is mutated in place, with ``'eval_mse'``, ``'test_mse'`` and
            ``'input_dim'`` set before it is handed to the model.
        method: Estimation method name, forwarded to the model and to the
            estimate-function factories.
        logger: Lightning logger (or logger config) passed to ``pl.Trainer``.
        out_results_dir: Root output directory. ``energy_models/`` is always
            created under it; ``results/`` only when saving.
        do_save_result: If true, save test predictions, model weights and,
            when the true test rate is known, the test errors.
        seed: Seed passed to ``set_seed_everything``.
        device_id: CUDA device index. ``None`` means the default CUDA device
            (and a single GPU for the trainer).
        true_rate_for_test / true_rate_for_eval: Optional ground-truth
            density ratios. When given, they are appended to the
            corresponding dataset and enable the accuracy metrics.

    Returns:
        ``(model_trained, result_optimization)``. ``model_trained`` is
        ``model_to_train.model_`` moved to the CUDA device. The dict holds
        ``'estimated_target_divergence'`` and ``'estimated_KL_divergence'``,
        plus ``'L1'``, ``'rmse'``, ``'L3'`` and ``'bias'`` when
        ``true_rate_for_test`` is given.
    """
    set_seed_everything(seed)
    out_model_dir = os.path.join(out_results_dir, 'energy_models')
    os.makedirs(out_model_dir, exist_ok=True)
    batchsize = params_training['batch_size']
    max_epochs = params_training['max_epochs']

    if device_id is None:
        device_to_use = torch.device('cuda')
    else:
        device_to_use = torch.device(f'cuda:{device_id}')

    train_denominator_tsr = _to_float32_tensor(train_denominator_data_np)
    train_numerator_tsr = _to_float32_tensor(train_numerator_data_np)
    eval_denominator_tsr = _to_float32_tensor(eval_denominator_data_np)
    eval_numerator_tsr = _to_float32_tensor(eval_numerator_data_np)
    test_denominator_tsr = _to_float32_tensor(test_denominator_data_np)
    test_numerator_tsr = _to_float32_tensor(test_numerator_data_np)

    train_dataloader = _make_dataloader(
        torch.utils.data.TensorDataset(
            train_denominator_tsr, train_numerator_tsr),
        batchsize,
        shuffle=True)

    # Flags presumably consumed by the Lightning module, telling it whether a
    # true-rate column is present in the eval/test batches.
    params_training['eval_mse'] = true_rate_for_eval is not None
    eval_dataset = _make_tensor_dataset(
        eval_denominator_tsr, eval_numerator_tsr, true_rate_for_eval)
    params_training['test_mse'] = true_rate_for_test is not None
    test_dataset = _make_tensor_dataset(
        test_denominator_tsr, test_numerator_tsr, true_rate_for_test)

    eval_dataloader = _make_dataloader(eval_dataset, batchsize)
    test_dataloader = _make_dataloader(test_dataset, batchsize)

    params_training['input_dim'] = train_denominator_tsr.shape[1]
    model_to_train = LightningNaiveDREdenseNN(method, params_training)

    pytorch_trainer = _build_trainer(
        logger,
        max_epochs,
        params_training['early_stoppping_partience'],
        device_id)
    pytorch_trainer.fit(model_to_train, train_dataloader, eval_dataloader)
    pytorch_trainer.test(model_to_train, test_dataloader)

    # Post-training evaluation on the full test split, on the CUDA device.
    model_trained = model_to_train.model_.to(device_to_use)
    test_denominator_tsr = test_denominator_tsr.to(device_to_use)
    test_numerator_tsr = test_numerator_tsr.to(device_to_use)

    esti_divergence_func = gen_estimate_target_divergence_func(
        method, params_training)
    estimated_div = esti_divergence_func(
        test_denominator_tsr, test_numerator_tsr, model_trained)
    esti_KLdivergence_func = gen_estimate_KL_divergence_func(
        method, params_training)
    estimated_KL, _ = esti_KLdivergence_func(
        test_denominator_tsr, test_numerator_tsr, model_trained)

    # Distinct loop name so the ``logger`` argument is not shadowed.
    for trainer_logger in pytorch_trainer.loggers:
        trainer_logger.log_metrics({
            'estimated_target_divergence': estimated_div})

    esti_density_rate_func = gen_estimate_density_rate_func(method)
    test_estimated_rate_tsr = esti_density_rate_func(
        test_denominator_tsr, model_trained)
    test_estimated_rate_np = test_estimated_rate_tsr.detach().cpu().numpy()

    result_optimization = {
        'estimated_target_divergence': estimated_div,
        'estimated_KL_divergence': estimated_KL}
    if true_rate_for_test is not None:
        result_optimization.update(
            _rate_accuracy_metrics(test_estimated_rate_np, true_rate_for_test))

    if do_save_result:
        _save_results(
            out_results_dir,
            out_model_dir,
            model_to_train.model_.model_list_,
            test_estimated_rate_np,
            true_rate_for_test)

    print('-'*100)
    print(f'Directory to output results:  {out_results_dir}')
    print(f'Estimated target divergence = {estimated_div}')
    print('-'*100)

    return model_trained, result_optimization


