import os
import datetime

from typing import List, Any, Tuple, Dict
import numpy as np

import torch
from torch.nn.parameter import Parameter
from torch import nn, Tensor
from torch import optim

import lightning as pl

from lib.dre.common.util import ABCProbRateDense

# Largest density-ratio magnitude the energy-based model may represent:
# its network outputs are clamped to +/- log(DRE_UPPER_BOUND).
DRE_UPPER_BOUND = 10_000

class DenseLinear(nn.Module):
    """Chain of dense blocks followed by a single-output linear head.

    The network is ``n_blocks`` instances of ``SingleBlockDenseLinear``
    applied in sequence, followed by one ``nn.Linear`` layer with a
    single output feature.

    Args:
        input_dim: Number of input features.
        hidden_dim: Output width of every block.
        n_layers_per_block: Layer count forwarded to each block.
        n_blocks: Number of blocks. With ``0`` the network consists of
            the linear head only, applied directly to the input.
        dropout: Dropout probability forwarded to each block.
        DoBatchNormarize: Batch-normalization switch forwarded to each
            block.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 n_layers_per_block: int,
                 n_blocks: int,
                 dropout: float,
                 DoBatchNormarize: bool
                 ):
        super().__init__()
        layers: List[nn.Module] = []
        # Width of the features entering the next module in the chain.
        _i_input_dim: int = input_dim
        for _i_block in range(n_blocks):
            layers.append(
                SingleBlockDenseLinear(
                  _i_input_dim,
                  hidden_dim,
                  n_layers_per_block,
                  dropout,
                  DoBatchNormarize))
            _i_input_dim = hidden_dim
        output_dim = 1
        # Size the head from the width actually produced so far. This is
        # ``hidden_dim`` whenever at least one block exists, and
        # ``input_dim`` when ``n_blocks == 0`` (no block has projected
        # the input yet).
        layers.append(nn.Linear(_i_input_dim, output_dim))
        self.layers_: nn.Module = nn.Sequential(*layers)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """Run ``data`` through all blocks and the linear head.

        Returns:
            The head's output; its last dimension has size 1.
        """
        return self.layers_(data)


class SingleBlockDenseLinear(nn.Module):
    """One dense block: optional batch norm, Linear/ReLU pairs, optional dropout.

    Module order inside the block:

    1. ``nn.BatchNorm1d`` over the input (only if ``DoBatchNormarize``).
    2. ``n_layers - 1`` pairs of ``nn.Linear`` + ``nn.ReLU``.
    3. ``nn.Dropout`` (only if ``dropout > 0``).
    4. A final ``nn.Linear`` + ``nn.ReLU``; this linear layer has no bias
       when ``DoBatchNormarize`` is set.

    When ``input_dim == hidden_dim`` the block input is added to the
    block output (residual connection).

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of every linear layer's output.
        n_layers: Total number of linear layers; values of 1 or less
            leave only the final Linear/ReLU pair.
        dropout: Dropout probability; no dropout module is created
            unless it is positive.
        DoBatchNormarize: Enables the leading batch norm and disables
            the bias of the final linear layer.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 n_layers: int,
                 dropout: float,
                 DoBatchNormarize: bool
                 ):
        super().__init__()
        self.input_dim_ = input_dim
        self.hidden_dim_ = hidden_dim

        stack: List[nn.Module] = []
        if DoBatchNormarize:
            stack.append(nn.BatchNorm1d(input_dim, affine=DoBatchNormarize))

        # Leading Linear/ReLU pairs; only the first one sees ``input_dim``.
        width: int = input_dim
        for _ in range(n_layers - 1):
            stack += [nn.Linear(width, hidden_dim), nn.ReLU()]
            width = hidden_dim

        # Dropout sits directly in front of the final linear layer.
        if dropout > 0:
            stack.append(nn.Dropout(dropout))

        stack += [
            nn.Linear(width, hidden_dim, bias=not DoBatchNormarize),
            nn.ReLU(),
        ]

        self.layers_: nn.Module = nn.Sequential(*stack)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """Apply the block, adding the input back when the widths match."""
        out = self.layers_(data)
        if self.input_dim_ != self.hidden_dim_:
            return out
        # Residual (skip) connection: shapes agree, so the sum is valid.
        return out + data


class EnergyBasedModelDenseLinear(
        ABCProbRateDense, nn.Module):
    """Energy model backed by a single ``DenseLinear`` network.

    The raw network output is clamped to
    ``[-log(DRE_UPPER_BOUND), log(DRE_UPPER_BOUND)]`` so that a later
    exponential of the energy stays within
    ``[1 / DRE_UPPER_BOUND, DRE_UPPER_BOUND]``.

    NOTE(review): ``input_dim_``, ``hidden_dim_``, ``n_layers_per_block_``,
    ``n_blocks_``, ``dropout_`` and ``DoBatchNormarize_`` are read in
    ``__init__`` but never assigned in this class; presumably
    ``ABCProbRateDense.__init__`` sets them from ``params_for_training``
    -- confirm against the base class.
    """

    def __init__(self,
                 params_for_training: Dict[str, Any],
                 ):
        """Build the single underlying network.

        Args:
            params_for_training: Passed unchanged to the base-class
                constructor.
        """
        super().__init__(params_for_training)
        self.n_model_ = 1
        self.model_list_ = nn.ModuleList()
        mdl = DenseLinear(
                self.input_dim_,
                self.hidden_dim_,
                self.n_layers_per_block_,
                self.n_blocks_,
                self.dropout_,
                self.DoBatchNormarize_)
        self.model_list_.append(mdl)

    @staticmethod
    def _clam_dre(input: torch.Tensor) -> torch.Tensor:
        """Clamp ``input`` to ``+/- log(DRE_UPPER_BOUND)``.

        Declared as a static method so that it works both as
        ``self._clam_dre(x)`` and ``EnergyBasedModelDenseLinear._clam_dre(x)``.
        """
        log_dre_upper_bound = float(np.log(DRE_UPPER_BOUND))
        return torch.clamp(input,
                           -log_dre_upper_bound,
                           log_dre_upper_bound)

    def _forward_nn(self,
                    input: torch.Tensor,
                    ) -> torch.Tensor:
        """Evaluate the network on ``input`` and clamp the result."""
        tmp_out = self.model_list_[0](input)
        # To prevent values from diverging when they are fed into an
        # exponential function, the outputs of the energy function are
        # limited to +/- log(DRE_UPPER_BOUND).
        return self._clam_dre(tmp_out)

    def _forward_train(self,
                       denominator_data: torch.Tensor,
                       numerator_data: torch.Tensor,
                       ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Compute clamped energies for both samples.

        Returns:
            A pair ``([energy_P], [energy_Q])`` of one-element lists, where
            ``energy_P`` is computed from ``denominator_data`` and
            ``energy_Q`` from ``numerator_data``.
        """
        preds_energy_P = self._forward_nn(denominator_data)
        preds_energy_Q = self._forward_nn(numerator_data)

        return [preds_energy_P], [preds_energy_Q]

    def _forward_estimate(self,
                          denominator_data: torch.Tensor) -> torch.Tensor:
        """Return the clamped energies of ``denominator_data`` as a 1-D tensor."""
        preds_energy_P = (self._forward_nn(denominator_data)).flatten()
        return preds_energy_P

    def forward(self,
                pred_mode: str,
                denominator_data: torch.Tensor,
                numerator_data: torch.Tensor = None,
                ) -> Any:
        """Dispatch to estimation or optimization mode.

        Args:
            pred_mode: ``'estimation'`` or ``'optimization'``.
            denominator_data: Input used in both modes.
            numerator_data: Second input; only used in
                ``'optimization'`` mode.

        Returns:
            In ``'estimation'`` mode, a flattened tensor of energies for
            ``denominator_data``. In ``'optimization'`` mode, the pair of
            one-element lists produced by ``_forward_train``.

        Raises:
            ValueError: If ``pred_mode`` is not one of the two supported
                values (previously this silently returned ``None``).
        """
        if pred_mode == 'estimation':
            return self._forward_estimate(denominator_data)
        if pred_mode == 'optimization':
            return self._forward_train(denominator_data, numerator_data)
        raise ValueError(
            "pred_mode must be 'estimation' or 'optimization', "
            f"got {pred_mode!r}")


class DensityRateModelDenseLinear(
        ABCProbRateDense, nn.Module):
    """Density-rate model backed by a single ``DenseLinear`` network.

    The network output is used directly as the rate and is floored so it
    never becomes negative.

    NOTE(review): ``input_dim_``, ``hidden_dim_``, ``n_layers_per_block_``,
    ``n_blocks_``, ``dropout_`` and ``DoBatchNormarize_`` are read in
    ``__init__`` but never assigned in this class; presumably
    ``ABCProbRateDense.__init__`` sets them from ``params_for_training``
    -- confirm against the base class.
    """

    def __init__(self,
                 params_for_training: Dict[str, Any],
                 ):
        """Build the single underlying network.

        Args:
            params_for_training: Passed unchanged to the base-class
                constructor.
        """
        super().__init__(params_for_training)
        self.n_model_ = 1
        self.model_list_ = nn.ModuleList()
        mdl = DenseLinear(self.input_dim_,
                    self.hidden_dim_,
                    self.n_layers_per_block_,
                    self.n_blocks_,
                    self.dropout_,
                    self.DoBatchNormarize_)
        self.model_list_.append(mdl)

    def _forward_train(self,
                       denominator_data: torch.Tensor,
                       numerator_data: torch.Tensor,
                       ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Compute floored rates for both samples.

        Every network output that is not strictly greater than ``1e-8``
        is replaced by ``1e-8``.

        Returns:
            A pair ``([rate_P], [rate_Q])`` of one-element lists, where
            ``rate_P`` is computed from ``denominator_data`` and
            ``rate_Q`` from ``numerator_data``.
        """
        machine_eps = 10**-8
        net = self.model_list_[0]
        # Functional form of nn.Threshold(machine_eps, machine_eps); avoids
        # building a new module object on every forward pass.
        # NOTE(review): training floors at 1e-8 while estimation floors at
        # 0 (ReLU) -- presumably intentional; confirm.
        preds_rate_P = nn.functional.threshold(
            net(denominator_data), machine_eps, machine_eps)
        preds_rate_Q = nn.functional.threshold(
            net(numerator_data), machine_eps, machine_eps)
        return [preds_rate_P], [preds_rate_Q]

    def _forward_estimate(self,
                          denominator_data: torch.Tensor) -> torch.Tensor:
        """Return the ReLU-rectified rates of ``denominator_data`` as a 1-D tensor."""
        estimated_rate_P = torch.relu(
            self.model_list_[0](denominator_data)).flatten()
        return estimated_rate_P

    def forward(self,
                pred_mode: str,
                denominator_data: torch.Tensor,
                numerator_data: torch.Tensor = None,
                ) -> Any:
        """Dispatch to estimation or optimization mode.

        Args:
            pred_mode: ``'estimation'`` or ``'optimization'``.
            denominator_data: Input used in both modes.
            numerator_data: Second input; only used in
                ``'optimization'`` mode.

        Returns:
            In ``'estimation'`` mode, a flattened tensor of rates for
            ``denominator_data``. In ``'optimization'`` mode, the pair of
            one-element lists produced by ``_forward_train``.

        Raises:
            ValueError: If ``pred_mode`` is not one of the two supported
                values (previously this silently returned ``None``).
        """
        if pred_mode == 'estimation':
            return self._forward_estimate(denominator_data)
        if pred_mode == 'optimization':
            return self._forward_train(denominator_data, numerator_data)
        raise ValueError(
            "pred_mode must be 'estimation' or 'optimization', "
            f"got {pred_mode!r}")
