import os, sys
import datetime

import math
from typing import List, Any, Tuple, Dict
import numpy as np

import torch
from torch import nn, Tensor
from torch import optim

import lightning as pl

from lib.dre.nn.linear import EnergyBasedModelDenseLinear, DensityRateModelDenseLinear
from lib.dre.common.estimate import gen_estimate_density_rate_func
from lib.dre.common.estimate import gen_estimate_target_divergence_func, gen_estimate_KL_divergence_func
from lib.dre.common.loss import gen_loss_func


class LightningNaiveDREdenseNN(pl.LightningModule):
    """Lightning wrapper that trains a dense density-ratio-estimation model.

    ``method`` selects the wrapped model class and is also forwarded to the
    ``lib.dre`` factories that build the training loss, the target-divergence
    estimator, the KL-divergence estimator and (optionally) the density-ratio
    estimator used for error metrics.

    Args:
        method: Name of the estimation method; must be one of
            ``_ENERGY_METHODS`` or ``_DENSITY_RATE_METHODS``.
        params_training: Training configuration. Keys read directly here:
            ``'learning_rate'``, ``'eval_mse'``, ``'test_mse'``, the optional
            ``'log_on_step'`` (defaults to ``True``), and ``'alpha'`` when
            ``method == 'alphaDiv'``. The whole dict is also passed to the
            model constructor and to the ``gen_*`` factory functions.

    Raises:
        ValueError: If ``method`` is not a supported method name.
    """

    # Methods whose model is an ``EnergyBasedModelDenseLinear``.
    _ENERGY_METHODS = frozenset({
        'alphaDiv',
        'LSIF-energy',
        'KLdivergence-energy',
        'nnBD-KLdivergence-energy',
        'penalty-KLdivergence-energy',
    })
    # Methods whose model is a ``DensityRateModelDenseLinear``.
    _DENSITY_RATE_METHODS = frozenset({
        'LSIF',
        'nnBD-LSIF',
        'penalty-LSIF',
        'KLdivergence',
        'nnBD-KLdivergence',
        'penalty-KLdivergence',
        'GAN',
        'alphaDiv-biased',
        'alphaDiv-biased-truncated',
    })

    def __init__(self,
                 method: str,
                 params_training: Dict[str, Any],
                 ):
        super().__init__()
        self.save_hyperparameters()

        self.method_ = method
        if method == 'alphaDiv':
            self.alpha_ = params_training['alpha']

        if method in self._ENERGY_METHODS:
            self.model_ = EnergyBasedModelDenseLinear(params_training)
        elif method in self._DENSITY_RATE_METHODS:
            self.model_ = DensityRateModelDenseLinear(params_training)
        else:
            # Raise instead of sys.exit(): a library constructor must not
            # terminate the interpreter (or a Lightning worker process).
            supported = sorted(self._ENERGY_METHODS | self._DENSITY_RATE_METHODS)
            raise ValueError(
                f'LightningNaiveDREdenseNN error: unsupported method '
                f'{method!r}; expected one of {supported}.')

        self.log_on_step_ = params_training.get('log_on_step', True)
        self.do_eval_mse_ = params_training['eval_mse']
        self.do_test_mse_ = params_training['test_mse']
        if self.do_eval_mse_ or self.do_test_mse_:
            # Only needed when batches carry the true density ratio.
            self.esti_density_rate_func_ = gen_estimate_density_rate_func(method)
        self.learning_rate_ = params_training['learning_rate']
        self.loss_func_ = gen_loss_func(
            method=method, params=params_training)
        self.esti_target_div_func_ = gen_estimate_target_divergence_func(
            method, params_training)
        self.esti_KL_div_func_ = gen_estimate_KL_divergence_func(
            method, params_training)

    def forward(self,
                denominator_data: torch.Tensor,
                numerator_data: torch.Tensor):
        """Call the wrapped model (``self.model_``) in ``'optimization'`` mode."""
        # NOTE(review): assumes the wrapped model accepts a mode string as its
        # first positional argument — TODO confirm against lib.dre.nn.linear.
        return self.model_(
            'optimization', denominator_data, numerator_data)

    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        """Compute and return the training loss for one batch.

        ``batch`` is unpacked as ``(denominator_data, numerator_data)``.
        """
        denominator_data, numerator_data = batch
        loss = self.loss_func_(
            denominator_data, numerator_data,
            self.model_)
        return loss

    def _shared_eval_step(self, batch, stage: str, with_true_dre: bool) -> None:
        """Log error metrics (optional) and divergence estimates for one batch.

        Args:
            batch: ``(denominator_data, numerator_data, true_dre)`` when
                ``with_true_dre`` is true, otherwise
                ``(denominator_data, numerator_data)``.
            stage: Prefix of every logged metric name (``'validate'`` or
                ``'test'``).
            with_true_dre: Whether to log L2 / L1 / bias of the estimated
                density ratio against ``true_dre``.
        """
        with torch.inference_mode():
            if with_true_dre:
                denominator_data, numerator_data, true_dre = batch
                estimated_dre = self.esti_density_rate_func_(
                    denominator_data,
                    self.model_)
                error = estimated_dre - true_dre
                self.log(f'{stage}_L2',
                         torch.mean(torch.pow(error, 2)).item(),
                         on_step=self.log_on_step_)
                self.log(f'{stage}_L1',
                         torch.mean(torch.absolute(error)).item(),
                         on_step=self.log_on_step_)
                self.log(f'{stage}_bias',
                         torch.mean(error).item(),
                         on_step=self.log_on_step_)
            else:
                denominator_data, numerator_data = batch

            div = self.esti_target_div_func_(
                denominator_data, numerator_data, self.model_)
            self.log(f'{stage}_target_divergence',
                     div,
                     on_step=self.log_on_step_)

            # The KL estimator returns two estimates; both are logged.
            kl_1, kl_2 = self.esti_KL_div_func_(
                denominator_data, numerator_data, self.model_)
            self.log(f'{stage}_KL_divergence_1',
                     kl_1,
                     on_step=self.log_on_step_)
            self.log(f'{stage}_KL_divergence_2',
                     kl_2,
                     on_step=self.log_on_step_)

    def validation_step(self, batch, batch_idx: int) -> None:
        """Log validation metrics under the ``validate_`` prefix."""
        self._shared_eval_step(batch, 'validate', self.do_eval_mse_)

    def test_step(self, batch, batch_idx: int) -> None:
        """Log test metrics under the ``test_`` prefix."""
        self._shared_eval_step(batch, 'test', self.do_test_mse_)

    def configure_optimizers(self) -> optim.Optimizer:
        """Return an Adam optimizer over the wrapped model's parameters."""
        return optim.Adam(self.model_.parameters(),
                          lr=self.learning_rate_)



