import datetime
import math
import os
import random
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from sklearn.metrics import mean_squared_error

import torch
from torch import nn, Tensor
from torch.distributions import MultivariateNormal, Uniform
from torch.utils.data import DataLoader


import lightning as pl
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping


# def set_seed_everything(seed):
#     np.random.seed(seed)
#     random.seed(seed)
#     os.environ['PYTHONHASHSEED'] = str(seed) 
#     torch.manual_seed(seed)
#     torch.cuda.manual_seed(seed)




class ABCPrimitiveProbRateDense(metaclass=ABCMeta):
    """Abstract root of the dense probability-rate model hierarchy.

    Reads the dimensions shared by every variant from a configuration
    mapping and declares the three methods concrete subclasses must
    implement. The class cannot be instantiated directly.

    Args:
        params_for_training: Configuration mapping. It must contain the keys
            ``'input_dim'`` and ``'hidden_dim'``; a missing key raises
            ``KeyError``.
    """

    def __init__(self,
                 params_for_training: Dict[str, Any],
                 ):
        # Cooperative call so that any base class later in the MRO is
        # initialised before attributes are assigned.
        # NOTE(review): presumably a concrete subclass also mixes in
        # nn.Module — confirm with the subclasses.
        super().__init__()
        self.input_dim_ = params_for_training['input_dim']
        self.hidden_dim_ = params_for_training['hidden_dim']

    @abstractmethod
    def _forward_train(self,
                       denominator_data: torch.Tensor,
                       numerator_data: torch.Tensor):
        """Subclass hook that takes both denominator and numerator samples.

        NOTE(review): the name suggests this is the training-time path.
        Confirm against the concrete implementations.
        """

    @abstractmethod
    def _forward_estimate(self,
                          denominator_data: torch.Tensor):
        """Subclass hook that takes denominator samples only.

        NOTE(review): the name suggests this is the estimation/inference
        path. Confirm against the concrete implementations.
        """

    @abstractmethod
    def forward(self,
                pred_mode: str,
                denominator_data: torch.Tensor,
                numerator_data: Optional[torch.Tensor] = None,
                ) -> List[torch.Tensor]:
        """Public entry point returning a list of tensors.

        ``numerator_data`` is optional and defaults to ``None``. How
        ``pred_mode`` is interpreted is left to subclasses (presumably it
        selects between the two hooks above; verify).
        """



class ABCSimpleProbRateDense(ABCPrimitiveProbRateDense):
    """Abstract variant configured by a single layer count.

    On top of the parent's ``input_dim_`` and ``hidden_dim_``, this class
    stores ``n_layers_``. The parent's abstract methods are re-declared here
    so that the variant's interface is visible in one place.

    Args:
        params_for_training: Configuration mapping. It must contain the keys
            the parent reads (``'input_dim'``, ``'hidden_dim'``) plus
            ``'n_layers'``; a missing key raises ``KeyError``.
    """

    def __init__(self,
                 params_for_training: Dict[str, Any],
                 ):
        super().__init__(params_for_training)
        self.n_layers_ = params_for_training['n_layers']

    @abstractmethod
    def _forward_train(self,
                       denominator_data: torch.Tensor,
                       numerator_data: torch.Tensor):
        """Subclass hook that takes both denominator and numerator samples.

        NOTE(review): presumably the training-time path — confirm.
        """

    @abstractmethod
    def _forward_estimate(self,
                          denominator_data: torch.Tensor):
        """Subclass hook that takes denominator samples only.

        NOTE(review): presumably the estimation path — confirm.
        """

    @abstractmethod
    def forward(self,
                pred_mode: str,
                denominator_data: torch.Tensor,
                numerator_data: Optional[torch.Tensor] = None,
                ) -> List[torch.Tensor]:
        """Public entry point returning a list of tensors.

        ``numerator_data`` is optional and defaults to ``None``. The meaning
        of ``pred_mode`` is defined by concrete subclasses.
        """



class ABCProbRateDense(ABCPrimitiveProbRateDense):
    """Abstract variant configured with blocks of layers, dropout and batch norm.

    On top of the parent's ``input_dim_`` and ``hidden_dim_``, this class
    stores the block layout and regularisation settings. It also creates two
    attributes initialised to ``None``: ``energy_model_list_`` and
    ``n_energy_model_`` (presumably populated by subclasses; verify).

    Args:
        params_for_training: Configuration mapping. It must contain the keys
            the parent reads (``'input_dim'``, ``'hidden_dim'``) plus
            ``'n_layers_per_block'``, ``'n_blocks'``, ``'dropout'`` and
            ``'DoBatchNormarize'`` (spelled exactly like this). A missing key
            raises ``KeyError``.
    """

    def __init__(self,
                 params_for_training: Dict[str, Any],
                 ):
        super().__init__(params_for_training)
        self.n_layers_per_block_ = params_for_training['n_layers_per_block']
        self.n_blocks_ = params_for_training['n_blocks']
        self.dropout_ = params_for_training['dropout']
        # Key and attribute keep the existing "Normarize" spelling; callers
        # and configs depend on it.
        self.DoBatchNormarize_ = params_for_training['DoBatchNormarize']
        # Placeholders; nothing in this class assigns them further.
        self.energy_model_list_ = None
        self.n_energy_model_ = None

    @abstractmethod
    def _forward_train(self,
                       denominator_data: torch.Tensor,
                       numerator_data: torch.Tensor):
        """Subclass hook that takes both denominator and numerator samples.

        NOTE(review): presumably the training-time path — confirm.
        """

    @abstractmethod
    def _forward_estimate(self,
                          denominator_data: torch.Tensor):
        """Subclass hook that takes denominator samples only.

        NOTE(review): presumably the estimation path — confirm.
        """

    @abstractmethod
    def forward(self,
                pred_mode: str,
                denominator_data: torch.Tensor,
                numerator_data: Optional[torch.Tensor] = None,
                ) -> List[torch.Tensor]:
        """Public entry point returning a list of tensors.

        ``numerator_data`` is optional and defaults to ``None``. The meaning
        of ``pred_mode`` is defined by concrete subclasses.
        """

# class ABCProbRateDense(metaclass=ABCMeta):
#     def __init__(self,
#                  params_for_training: Dict[str, Any],
#                  ):
#         super().__init__()
#         #self.alpha_ = params_for_training['alpha']
#         self.input_dim_ = params_for_training['input_dim']
#         self.hidden_dim_ = params_for_training['hidden_dim']
#         self.n_layers_per_block_ = params_for_training['n_layers_per_block']
#         self.n_blocks_ = params_for_training['n_blocks']
#         self.dropout_ = params_for_training['dropout']
#         self.DoBatchNormarize_ = params_for_training['DoBatchNormarize']
#         self.energy_model_list_ = None
#         self.n_energy_model_ = None

#     @abstractmethod
#     def _forward_train(self,
#                        denominator_data: torch.Tensor,
#                        numerator_data: torch.Tensor):
#         ...

#     @abstractmethod
#     def _forward_estimate(self,
#                           denominator_data: torch.Tensor):
#         ...

#     @abstractmethod
#     def forward(self,
#                 pred_mode: str,
#                 denominator_data: torch.Tensor,
#                 numerator_data: torch.Tensor = None,
#                 ) -> List[torch.Tensor]:
#         ...

class Truncated(nn.Module):
    """Limit every element of the input to the interval ``[min, max]``.

    Built from two ``torch.nn.Threshold`` modules:

    * The first writes ``min`` wherever the input is ``<= min``.
    * The second is applied to the negated result. It writes ``-max``
      wherever that value is ``<= -max``, i.e. wherever the original value
      is ``>= max``, and the final negation turns this back into ``max``.

    For ``min <= max`` and ordinary finite inputs, this matches clamping.

    Args:
        min: Lower bound of the output range.
        max: Upper bound of the output range.
        inplace: Stored as ``inplace_`` and reported by ``extra_repr``. It is
            not passed to the Threshold modules, so ``forward`` always
            returns a new tensor and leaves ``input`` unmodified.
    """

    def __init__(self, min: float, max: float, inplace: bool = False) -> None:
        super().__init__()
        self.min_ = min
        self.max_ = max
        # NOTE(review): recorded only; not forwarded to the Threshold modules.
        self.inplace_ = inplace
        # Threshold(t, v): keep x where x > t, write v elsewhere.
        # This lifts everything <= min up to min.
        self.min_threshold_ = torch.nn.Threshold(self.min_, self.min_)
        # Used on the negated signal: keeps -x > -max (x < max) and writes
        # -max elsewhere, which becomes max after negating back.
        self.max_threshold_ = torch.nn.Threshold(-self.max_, -self.max_)

    def forward(self, input: Tensor) -> Tensor:
        """Return a new tensor with ``input``'s elements limited to ``[min, max]``."""
        res = -self.max_threshold_(-self.min_threshold_(input))
        return res

    def extra_repr(self) -> str:
        """Describe the bounds (and the inplace flag, if set) for ``repr``."""
        # Read the stored ``inplace_``. The original read ``self.inplace``,
        # which does not exist, so nn.Module.__getattr__ raised
        # AttributeError and repr()/print() of any model containing this
        # module crashed.
        inplace_str = ', inplace=True' if self.inplace_ else ''
        return f'min={self.min_}, max={self.max_}{inplace_str}'