import os
import random
import datetime

from typing import List, Any, Tuple, Dict

import numpy as np

from sklearn.metrics import mean_squared_error

import torch
from torch import nn, Tensor
from torch.distributions import MultivariateNormal, Uniform
from torch.utils.data import DataLoader


import lightning as pl
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

from lib.ProbabilityRateModeling import *

def seed_everything(seed):
    """Seed every random number generator this module relies on.

    Reseeds NumPy's global generator, Python's ``random`` module and
    PyTorch (via ``torch.manual_seed`` and ``torch.cuda.manual_seed``),
    and stores ``str(seed)`` in the ``PYTHONHASHSEED`` environment
    variable.

    Args:
        seed: Seed value handed unchanged to each seeding function.
    """
    for reseed in (np.random.seed, random.seed):
        reseed(seed)

    # Only affects interpreters started after this point (e.g. child
    # processes); the running interpreter has already fixed its hash seed.
    os.environ['PYTHONHASHSEED'] = str(seed)

    for reseed in (torch.manual_seed, torch.cuda.manual_seed):
        reseed(seed)

def estimate_density_rate(
      data_denominator: Tensor,
      prob_rate_model: ProbRateDense
    ) -> Tensor:
    """Estimate the density ratio dQ/dP at the given denominator samples.

    The model is called as ``prob_rate_model('density', data_denominator)``
    and must return a pair ``(T_P, offset)``. The estimate is
    ``exp(-T_P - offset)``.

    NOTE(review): ``offset`` is used as an additive log-domain term; its
    exact definition lives in the model's 'density' mode -- confirm there.

    Args:
        data_denominator: Samples on which the ratio is evaluated.
        prob_rate_model: Callable model supporting the 'density' mode.

    Returns:
        Tensor of estimated ratios, shaped by broadcasting the two model
        outputs. It is not detached from the autograd graph here.
    """
    energy, log_offset = prob_rate_model('density', data_denominator)
    log_ratio_q_over_p = -energy - log_offset
    return log_ratio_q_over_p.exp()

def estimate_alpha_div(
      alpha: float,
      data_for_P_distribution: Tensor,
      data_for_Q_distribution: Tensor,
      prob_rate_model: ProbRateDense,
  ) -> float:
    """Estimate the alpha-divergence between P and Q from model outputs.

    The model is called in 'divergence' mode with both sample sets and
    must return ``(T_P, T_Q)``. With ``E_P`` / ``E_Q`` denoting sample
    means over the flattened outputs, the value returned is::

        1 / (alpha * (1 - alpha))
          - ( E_Q[exp(alpha * T_Q)] / alpha
              + E_P[exp((alpha - 1) * T_P)] / (1 - alpha) )

    Args:
        alpha: Divergence order. ``alpha`` equal to 0 or 1 divides by
            zero and raises ``ZeroDivisionError``.
        data_for_P_distribution: Samples passed to the model as P data.
        data_for_Q_distribution: Samples passed to the model as Q data.
        prob_rate_model: Callable model supporting the 'divergence' mode.

    Returns:
        The estimate as a Python float.
    """
    T_P, T_Q = prob_rate_model(
        'divergence',
        data_for_P_distribution, data_for_Q_distribution)

    # Both sample means are pulled out as Python floats, so the remaining
    # arithmetic is plain float math.
    q_moment = torch.exp(alpha * T_Q.flatten()).mean().item()
    p_moment = torch.exp((alpha - 1) * T_P.flatten()).mean().item()

    variational_loss = q_moment / alpha + p_moment / (1 - alpha)
    return 1 / (alpha * (1 - alpha)) - variational_loss

def estimate_KL_div(
      data_for_base_distribution: Tensor,
      prob_rate_model: ProbRateDense,
  ) -> float:
    """Estimate the KL divergence as the sample mean of log(dP/dQ).

    The model is called as
    ``prob_rate_model('density', data_for_base_distribution)`` and must
    return a pair ``(T_P, offset)``. The pointwise log-ratio is taken to
    be ``T_P + offset`` and its mean over all elements is returned.

    NOTE(review): ``offset`` is used as an additive log-domain term; its
    exact definition lives in the model's 'density' mode -- confirm there.

    Args:
        data_for_base_distribution: Samples over which the mean is taken.
        prob_rate_model: Callable model supporting the 'density' mode.

    Returns:
        The estimate as a Python float.
    """
    model_outputs = prob_rate_model('density', data_for_base_distribution)
    energy, log_offset = model_outputs
    pointwise_log_ratio = energy + log_offset
    return pointwise_log_ratio.mean().item()

def estimate_JS_div(
      data_for_P_distribution: Tensor,
      data_for_Q_distribution: Tensor,
      prob_rate_model: ProbRateDense,
  ) -> float:
    """Estimate the Jensen-Shannon divergence (in bits) from model outputs.

    The model is called in 'divergence' mode with both sample sets and
    must return ``(T_P, T_Q)``. Writing ``e_P = exp(-T_P)`` and
    ``e_Q = exp(T_Q)`` for the flattened outputs, the returned value is
    the average of the two terms::

        mean(log2(mean(e_P) / (e_P + mean(e_P)))) + 1
        mean(log2(mean(e_Q) / (e_Q + mean(e_Q)))) + 1

    Everything is evaluated in the log domain (``logsumexp`` /
    ``logaddexp``), so large-magnitude model outputs do not overflow
    ``exp`` and turn the estimate into NaN.

    Args:
        data_for_P_distribution: Samples passed to the model as P data.
        data_for_Q_distribution: Samples passed to the model as Q data.
        prob_rate_model: Callable model supporting the 'divergence' mode.

    Returns:
        The estimate as a Python float.
    """
    T_P, T_Q = prob_rate_model(
        'divergence',
        data_for_P_distribution, data_for_Q_distribution)

    def mean_log2_ratio_to_mixture(log_terms: Tensor) -> Tensor:
        # For e = exp(log_terms) and m = mean(e), computes
        # (log(m) - mean(log(e + m))) / log(2) + 1 without forming e.
        sample_count = log_terms.new_tensor(float(log_terms.numel()))
        # log(mean(exp(x))) == logsumexp(x) - log(n)
        log_mean = torch.logsumexp(log_terms, dim=0) - torch.log(sample_count)
        # log(exp(x) + m) == logaddexp(x, log(m))
        log_mixture = torch.logaddexp(log_terms, log_mean)
        log_of_two = torch.log(log_terms.new_tensor(2.0))
        return (log_mean - torch.mean(log_mixture)) / log_of_two + 1

    mean_under_P_pred_logdPdM = mean_log2_ratio_to_mixture(-T_P.flatten())
    mean_under_Q_pred_logdQdM = mean_log2_ratio_to_mixture(T_Q.flatten())

    JS_divergence = (mean_under_P_pred_logdPdM/2
                      + mean_under_Q_pred_logdQdM/2).item()
    return JS_divergence

def dre_train_for_all_epoch(
            train_denominator_data_np: Any,
            train_numerator_data_np: Any, 
            eval_denominator_data_np: Any,
            eval_numerator_data_np: Any,
            test_denominator_data_np: Any,
            test_numerator_data_np: Any,
            params_training: Dict[str, Any],
            learning_rate: float,
            batchsize: int,
            n_all_epochs: int,
            earlystopping_patience: int,
            out_results_dir: str,
            do_save_result: bool,
            seed: int,
            true_rate_for_test: Any  = None
            ) -> Tuple[float, Any]:
    """Train a density-ratio model, evaluate it on test data and report.

    Steps performed:
      1. Seed all RNGs and create ``all_logs`` / ``energy_models`` under
         ``out_results_dir``.
      2. Convert the six arrays to float32 tensors and wrap each
         (denominator, numerator) pair in a shuffled ``DataLoader``.
      3. Build ``LightningDREdenseNN(params_training, learning_rate)`` and
         run ``fit`` followed by ``test`` with a single-GPU Lightning
         trainer that logs to TensorBoard.
      4. Estimate the KL divergence, the JS divergence and the pointwise
         density ratio on the test data, optionally save the results,
         and print a short summary.

    Side effects:
      * ``params_training['input_dim']`` is set in place to the second
        dimension of the training denominator data.
      * Requires CUDA: the trainer uses ``accelerator='gpu'`` and every
        loader is given a ``torch.Generator(device='cuda')``.

    Args:
        train_denominator_data_np: Training denominator samples; must
            provide ``.astype`` and have at least two dimensions.
        train_numerator_data_np: Training numerator samples.
        eval_denominator_data_np: Validation denominator samples.
        eval_numerator_data_np: Validation numerator samples.
        test_denominator_data_np: Test denominator samples.
        test_numerator_data_np: Test numerator samples.
        params_training: Model hyper-parameters (mutated, see above).
        learning_rate: Learning rate handed to the Lightning module.
        batchsize: Batch size used by all three loaders.
        n_all_epochs: Maximum number of training epochs.
        earlystopping_patience: If positive, early stopping on the
            'validate_alpha_divergence' metric (mode 'max') with this
            patience; otherwise no early stopping.
        out_results_dir: Root directory for logs, models and estimates.
        do_save_result: When true, save the estimated test ratios and the
            energy model's ``state_dict`` under ``out_results_dir``.
        seed: Seed passed to ``seed_everything``.
        true_rate_for_test: Optional ground-truth ratios for the test
            denominator samples, used only to compute the MSE.

    Returns:
        ``(estimated_kl_divergence, mse)`` where ``mse`` is ``None`` when
        ``true_rate_for_test`` is not supplied.
    """
    seed_everything(seed)
    out_log_dir = os.path.join(out_results_dir, 'all_logs')
    os.makedirs(out_log_dir, exist_ok=True)
    out_model_dir = os.path.join(out_results_dir, 'energy_models')
    os.makedirs(out_model_dir, exist_ok=True)

    def as_float32_tensor(array_np: Any) -> Tensor:
        # Cast to float32 first so every tensor shares one dtype.
        return torch.from_numpy(array_np.astype(np.float32))

    def make_shuffled_loader(
            denominator_tsr: Tensor, numerator_tsr: Tensor) -> DataLoader:
        # One (denominator, numerator) pair per item, shuffled with a
        # CUDA generator.
        paired_dataset = torch.utils.data.TensorDataset(
            denominator_tsr, numerator_tsr)
        return DataLoader(
            paired_dataset,
            shuffle=True,
            generator=torch.Generator(device='cuda'),
            batch_size=batchsize)

    train_denominator_tsr = as_float32_tensor(train_denominator_data_np)
    eval_denominator_tsr = as_float32_tensor(eval_denominator_data_np)
    train_numerator_tsr = as_float32_tensor(train_numerator_data_np)
    eval_numerator_tsr = as_float32_tensor(eval_numerator_data_np)
    test_denominator_tsr = as_float32_tensor(test_denominator_data_np)
    test_numerator_tsr = as_float32_tensor(test_numerator_data_np)

    train_dataloader = make_shuffled_loader(
        train_denominator_tsr, train_numerator_tsr)
    eval_dataloader = make_shuffled_loader(
        eval_denominator_tsr, eval_numerator_tsr)
    test_dataloader = make_shuffled_loader(
        test_denominator_tsr, test_numerator_tsr)

    # The model is built after seeding and before any loader is iterated,
    # so weight initialisation sees the same RNG state on every run.
    n_dim_input = train_denominator_tsr.shape[1]
    params_training['input_dim'] = n_dim_input
    model_to_train = LightningDREdenseNN(
        params_training,
        learning_rate)

    now = datetime.datetime.now()
    logname_suffix = now.strftime('%Y-%m-%d_%H_%M_%S')
    logname = f'Run_{logname_suffix}'
    logger = TensorBoardLogger(
        out_log_dir,
        logname,
        default_hp_metric=False)

    # Early stopping is the only difference between the two trainer
    # configurations, so build the callback list first and share the rest.
    if earlystopping_patience > 0:
        trainer_callbacks = [
            EarlyStopping(
                monitor='validate_alpha_divergence',
                min_delta=0.00,
                patience=earlystopping_patience,
                verbose=False,
                mode='max'
            )
        ]
    else:
        trainer_callbacks = None
    pytorch_trainer = pl.Trainer(
        accelerator='gpu',
        devices=1,
        strategy='auto',
        logger=logger,
        max_epochs=n_all_epochs,
        callbacks=trainer_callbacks,
        enable_progress_bar=False)

    pytorch_trainer.fit(
        model_to_train,
        train_dataloader,
        eval_dataloader)
    pytorch_trainer.test(model_to_train, test_dataloader)

    trained_rate_model = model_to_train.prob_rate_model_

    test_kl_div_estimated = estimate_KL_div(
        test_denominator_tsr,
        trained_rate_model)

    test_JS_div_estimated = estimate_JS_div(
        test_denominator_tsr,
        test_numerator_tsr,
        trained_rate_model)

    test_estimated_rate_tsr = estimate_density_rate(
        test_denominator_tsr,
        trained_rate_model)
    test_estimated_rate_np = test_estimated_rate_tsr.cpu().detach().numpy()

    if true_rate_for_test is not None:
        mse = mean_squared_error(
            test_estimated_rate_np, true_rate_for_test)
    else:
        mse = None

    if do_save_result:
        # Persist the pointwise estimates and the trained energy model.
        out_pred_dir = os.path.join(out_results_dir, 'estimated')
        os.makedirs(out_pred_dir, exist_ok=True)
        np.save(
              os.path.join(out_pred_dir, 'vals_estimated'),
              test_estimated_rate_np)
        out_energy_mdl_filepath = os.path.join(
            out_model_dir, 'energy_model.mdl')
        torch.save(
            trained_rate_model.energy_model_.state_dict(),
            out_energy_mdl_filepath)

    print('-'*100)
    print(f'Directory to output results:  {out_results_dir}')
    print(f'Estimated KL divergence = {test_kl_div_estimated}')
    print(f'Estimated JS divergence = {test_JS_div_estimated}')
    print('-'*100)

    return test_kl_div_estimated, mse


