import os
import datetime

from typing import List, Any, Tuple, Dict
import numpy as np

import torch
from torch import nn, Tensor
from torch import optim

import lightning as pl


class EnergyDense(nn.Module):
    """Stack of ``EnergyDenseBlock`` modules followed by a scalar linear head.

    The input is passed through ``n_blocks`` blocks in sequence and then
    through a final ``nn.Linear`` with a single output feature, so the module
    produces one value per input row.

    Args:
        input_dim: Feature width of the data fed to the first block.
        hidden_dim: Output width of every block.
        n_layers_per_block: Forwarded to each ``EnergyDenseBlock`` as its
            ``n_layers`` argument.
        n_blocks: Number of blocks stacked before the output layer. With
            ``0`` (or a negative value) no block is created and the module
            is a single linear layer applied directly to the input.
        dropout: Forwarded unchanged to each block.
        DoBatchNormarize: Forwarded unchanged to each block.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 n_layers_per_block: int,
                 n_blocks: int,
                 dropout: float,
                 DoBatchNormarize: bool
                 ):
        super().__init__()
        layers: List[nn.Module] = []
        # Width of the tensor entering the next module. Only the first block
        # sees the raw input width; every later module sees hidden_dim.
        _i_input_dim: int = input_dim
        for _i_block in range(n_blocks):
            layers.append(
                EnergyDenseBlock(
                  _i_input_dim,
                  hidden_dim,
                  n_layers_per_block,
                  dropout,
                  DoBatchNormarize))
            _i_input_dim = hidden_dim
        output_dim = 1
        # Size the head from the tracked width instead of hidden_dim so the
        # head still matches its input when no block was created.
        layers.append(nn.Linear(_i_input_dim, output_dim))
        self.layers_: nn.Module = nn.Sequential(*layers)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """Apply all blocks and the output layer to ``data``."""
        return self.layers_(data)


class EnergyDenseBlock(nn.Module):
    """Fully connected block: optional batch norm, Linear/ReLU pairs, and an
    optional dropout placed just before the last Linear/ReLU pair.

    Layout of the internal ``nn.Sequential``::

        [BatchNorm1d]  (Linear -> ReLU) * max(n_layers - 1, 0)
        [Dropout]      Linear -> ReLU

    Args:
        input_dim: Feature width of the block input.
        hidden_dim: Output width of every Linear layer in the block.
        n_layers: Total number of Linear layers; values below 1 still yield
            one Linear layer.
        dropout: Dropout probability; a Dropout layer is added only when
            this is greater than 0.
        DoBatchNormarize: When true, a ``BatchNorm1d`` is placed first and
            the last Linear layer is created without a bias term.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 n_layers: int,
                 dropout: float,
                 DoBatchNormarize: bool
                 ):
        super().__init__()
        self.input_dim_ = input_dim
        self.hidden_dim_ = hidden_dim

        modules: List[nn.Module] = []
        if DoBatchNormarize:
            modules.append(nn.BatchNorm1d(input_dim, affine=DoBatchNormarize))

        # Input width of each Linear layer, in order: the first one sees
        # input_dim, all following ones see hidden_dim.
        n_leading = max(n_layers - 1, 0)
        in_widths = [input_dim] + [hidden_dim] * n_leading

        for width in in_widths[:-1]:
            modules.extend((nn.Linear(width, hidden_dim), nn.ReLU()))

        if dropout > 0:
            modules.append(nn.Dropout(dropout))

        closing_linear = nn.Linear(in_widths[-1], hidden_dim,
                                   bias=not DoBatchNormarize)
        modules += [closing_linear, nn.ReLU()]

        self.layers_: nn.Module = nn.Sequential(*modules)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """Run the block; add the input back when the widths agree."""
        transformed = self.layers_(data)
        if self.input_dim_ != self.hidden_dim_:
            # Widths differ, so a residual sum is not possible.
            return transformed
        # Skip connection: used whenever input and output widths are equal.
        return transformed + data


class ProbRateDense(nn.Module):
    """Wrapper around an ``EnergyDense`` model with two prediction modes.

    Args:
        params_for_training: Configuration dictionary. The keys read here are
            ``'alpha'``, ``'input_dim'``, ``'hidden_dim'``,
            ``'n_layers_per_block'``, ``'n_blocks'``, ``'dropout'`` and
            ``'DoBatchNormarize'``; all but ``'alpha'`` are forwarded to
            ``EnergyDense``.
    """

    def __init__(self,
                 params_for_training: Dict[str, Any],
                 ):
        super().__init__()
        self.alpha_ = params_for_training['alpha']
        self.input_dim_ = params_for_training['input_dim']
        self.hidden_dim_ = params_for_training['hidden_dim']
        self.n_layers_per_block_ = params_for_training['n_layers_per_block']
        self.n_blocks_ = params_for_training['n_blocks']
        self.dropout_ = params_for_training['dropout']
        self.DoBatchNormarize_ = params_for_training['DoBatchNormarize']
        self.energy_model_ = self._build_energy_model()

    def _build_energy_model(self) -> nn.Module:
        """Create a new ``EnergyDense`` from the stored hyper-parameters."""
        return EnergyDense(
            self.input_dim_,
            self.hidden_dim_,
            self.n_layers_per_block_,
            self.n_blocks_,
            self.dropout_,
            self.DoBatchNormarize_)

    def load_energy_model(self, energy_model_parameters: dict):
        """Replace the energy model by a new one loaded from a state dict.

        The replacement is built with the same hyper-parameters as the
        current model so that the state-dict keys and shapes line up.

        Args:
            energy_model_parameters: State dict accepted by
                ``EnergyDense.load_state_dict``.
        """
        # NOTE(review): the new model is created without an explicit device;
        # callers may need to move it after loading - confirm against usage.
        mdl = self._build_energy_model()
        mdl.load_state_dict(energy_model_parameters)
        self.energy_model_ = mdl

    def _forward_train(self,
                       observed_data: torch.Tensor,
                       generated_data: torch.Tensor,
                       ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the energy-model outputs for both data sets.

        Returns:
            ``(preds_P, preds_Q)``: model output on ``observed_data`` and on
            ``generated_data`` respectively.
        """
        preds_P = self.energy_model_(observed_data)
        preds_Q = self.energy_model_(generated_data)
        return preds_P, preds_Q

    def _forward_estimate(self,
                          observed_data: torch.Tensor
                          ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the flattened model output and the mean of ``exp(-output)``.

        Returns:
            ``(T_P, mean_exp_minus_T_P)``: the flattened energy-model output
            on ``observed_data`` and the scalar mean of ``exp(-T_P)``.
        """
        T_P = self.energy_model_(observed_data).flatten()
        mean_exp_minus_T_P = torch.mean(torch.exp(-T_P))

        return T_P, mean_exp_minus_T_P

    def forward(self,
                pred_mode: str,
                observed_data: torch.Tensor,
                generated_data: torch.Tensor = None,
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Dispatch to the estimator selected by ``pred_mode``.

        Args:
            pred_mode: ``'density'`` to get the result of
                ``_forward_estimate(observed_data)``, or ``'divergence'`` to
                get the result of
                ``_forward_train(observed_data, generated_data)``.
            observed_data: Input used in both modes.
            generated_data: Second input; only used in ``'divergence'`` mode.

        Raises:
            ValueError: If ``pred_mode`` is neither of the supported values.
        """
        if pred_mode == 'density':
            return self._forward_estimate(observed_data)
        elif pred_mode == 'divergence':
            return self._forward_train(observed_data, generated_data)
        # Fail loudly instead of silently returning None for a bad mode.
        raise ValueError(
            f"unknown pred_mode {pred_mode!r}; "
            "expected 'density' or 'divergence'")

class LightningDREdenseNN(pl.LightningModule):
    """Lightning module that trains a ``ProbRateDense`` model with an
    alpha-divergence style loss and logs divergence estimates.

    Each batch is expected to unpack into ``(denominator_data,
    numerator_data)``. ``denominator_data`` is passed to the model as its
    ``observed_data`` (outputs named ``T_P``), ``numerator_data`` as its
    ``generated_data`` (outputs named ``T_Q``).

    Args:
        params_training: Configuration dictionary forwarded to
            ``ProbRateDense``; ``'alpha'`` is also read here. The loss
            divides by ``alpha`` and by ``1 - alpha``, so neither 0 nor 1 is
            a usable value.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(self,
                 params_training: Dict[str, Any],
                 learning_rate: float
                 ):
        super().__init__()
        self.save_hyperparameters()
        self.alpha_ = params_training['alpha']
        self.prob_rate_model_ = ProbRateDense(params_training)
        self.learning_rate_ = learning_rate

    def forward(self,
                denominator_data: torch.Tensor,
                numerator_data: torch.Tensor):
        """Return ``(T_P, T_Q)`` from the model in ``'divergence'`` mode."""
        return self.prob_rate_model_(
            'divergence', denominator_data, numerator_data)

    def _alpha_loss(self,
                    T_P: torch.Tensor,
                    T_Q: torch.Tensor) -> torch.Tensor:
        """Compute mean(exp(a*T_Q))/a + mean(exp((a-1)*T_P))/(1-a)."""
        term_Q = torch.mean(torch.exp(self.alpha_*T_Q))
        term_P = torch.mean(torch.exp((self.alpha_ - 1)*T_P))
        return term_Q/self.alpha_ + term_P/(1 - self.alpha_)

    def _evaluate_and_log(self, batch, stage: str) -> torch.Tensor:
        """Compute and log the divergence estimates shared by the
        validation and test loops.

        Logs ``{stage}_alpha_divergence``, ``{stage}_KL_divergence`` and
        ``{stage}_Jensen–Shannon_divergence`` with ``on_step=True``.

        Args:
            batch: Pair ``(denominator_data, numerator_data)``.
            stage: Metric-name prefix, ``'validate'`` or ``'test'``.

        Returns:
            The alpha-divergence estimate for the batch.
        """
        with torch.inference_mode():
            denominator_data, numerator_data = batch
            T_P, T_Q = self(denominator_data, numerator_data)

            loss = self._alpha_loss(T_P, T_Q)
            alpha_div = 1/(self.alpha_*(1 - self.alpha_)) - loss
            self.log(f'{stage}_alpha_divergence', alpha_div,
                     on_step=True)

            T_P_vec = T_P.flatten()
            exp_minus_T_P = torch.exp(-T_P_vec)
            mean_exp_minus_T_P = torch.mean(exp_minus_T_P)
            T_Q_vec = T_Q.flatten()
            exp_T_Q = torch.exp(T_Q_vec)
            mean_exp_T_Q = torch.mean(exp_T_Q)

            # The Jensen-Shannon terms below use the batch-normalised ratio
            # exp(-T_P) / mean(exp(-T_P)). Its negative log is
            # T_P + log(mean(exp(-T_P))), so the normaliser has to enter the
            # KL estimate through a log as well (the log was missing before).
            # NOTE(review): assumes T estimates log dP/dQ, as the variable
            # names suggest - confirm against the method's derivation.
            under_P_mean_pred_log_dPdQ = torch.mean(
                T_P_vec + torch.log(mean_exp_minus_T_P))
            self.log(f'{stage}_KL_divergence',
                     under_P_mean_pred_log_dPdQ.item(),
                     on_step=True)

            # log(2) with the dtype/device of the model output; dividing by
            # it converts natural logs to base 2. T_P and T_Q come from the
            # same model, so one constant serves both terms.
            log2 = torch.log(Tensor([2]).type_as(T_P))

            under_P_mean_logdPdM_numerator = torch.log(mean_exp_minus_T_P)
            under_P_mean_logdPdM_denominator = torch.mean(
                torch.log(exp_minus_T_P + mean_exp_minus_T_P))
            mean_under_P_pred_logdPdM = (
                under_P_mean_logdPdM_numerator
                - under_P_mean_logdPdM_denominator)/log2 + 1

            under_Q_mean_logdQdM_numerator = torch.log(mean_exp_T_Q)
            under_Q_mean_logdQdM_denominator = torch.mean(
                torch.log(exp_T_Q + mean_exp_T_Q))
            mean_under_Q_pred_logdQdM = (
                under_Q_mean_logdQdM_numerator
                - under_Q_mean_logdQdM_denominator)/log2 + 1

            JS_divergence = (mean_under_P_pred_logdPdM/2
                             + mean_under_Q_pred_logdQdM/2)
            self.log(f'{stage}_Jensen–Shannon_divergence',
                     JS_divergence.item(),
                     on_step=True)

            return alpha_div

    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        """Return the alpha-divergence loss for one training batch."""
        denominator_data, numerator_data = batch
        T_P, T_Q = self(denominator_data, numerator_data)
        return self._alpha_loss(T_P, T_Q)

    def validation_step(self, batch, batch_idx: int) -> torch.Tensor:
        """Log the ``validate_*`` metrics and return the alpha divergence."""
        return self._evaluate_and_log(batch, 'validate')

    def test_step(self, batch, batch_idx: int) -> None:
        """Log the ``test_*`` metrics; nothing is returned."""
        self._evaluate_and_log(batch, 'test')

    def configure_optimizers(self) -> optim.Optimizer:
        """Return an Adam optimizer over the energy model's parameters."""
        return optim.Adam(self.prob_rate_model_.energy_model_.parameters(),
                          lr=self.learning_rate_)



