import json
import os
from abc import ABC, abstractmethod
from typing import Tuple, Optional

import torch
from torch.nn import Module, Linear, ReLU, Sigmoid
from torch.utils import tensorboard

from Utils import config as cfg, logger
from Utils.Constants import FileNamesConstants


class BaseLSTMRegressor(Module, ABC):
    """
    Abstract base for LSTM regressors.

    Performs the logistics shared by every concrete regressor: it calls the abstract ``_build_model``, moves the
    module to ``cfg.device``, optionally converts it to double precision, logs its configuration and (optionally)
    sets up a tensorboard writer plus a JSON dump of the hyper-parameters. Subclasses implement ``_build_model``
    and ``forward``.
    """

    def __init__(self, embedding_size: int, hidden_size: int, lstm_layers: int,
                 inner_dense_layer_sizes: Tuple[int, ...], tb_log_path: Optional[str], sequence_len: int,
                 batch_first: bool, last_layer: Optional[str], is_double: bool, bi_directional: bool, **kwargs):
        """
        Base class for any LSTM Regressor I want. This class does all the logistics (tensorboard and logging) of
        building a LSTM regressor. It skips the actual building of the model, but it does call the build function.
        :param embedding_size: size of input embeddings
        :param hidden_size: lstm hidden size
        :param lstm_layers: number of LSTM layers (forwarded to _build_model)
        :param inner_dense_layer_sizes: tuple of final dense layers sizes
        :param tb_log_path: path for saving tensorboard; None disables tensorboard and the hyper-params dump
        :param sequence_len: expected size of sequences (used to build the dummy input for the tensorboard graph)
        :param batch_first: True if the input first dimension is the batch size
        :param last_layer: name of last layer type. supported types: sigmoid / relu / linear (None means linear)
        :param is_double: is the module of type double
        :param bi_directional: should use a bi_directional LSTM
        :param kwargs: extra keyword arguments; not forwarded directly. They reach _build_model only nested
                       inside its ``kwargs['kwargs']['kwargs']`` (see the note at the _build_model call).
        """
        super(BaseLSTMRegressor, self).__init__()
        self._is_double = is_double
        self._tb_writer = None

        # NOTE: ``kwargs=locals()`` hands the subclass a snapshot of every constructor argument (including
        # ``self`` and this constructor's own ``**kwargs`` dict) under the single key ``'kwargs'``. Kept as-is
        # because subclasses may read ``kwargs['kwargs']``. Do not introduce new local variables above this
        # call: they would silently become part of that snapshot.
        self._build_model(embedding_size=embedding_size, hidden_size=hidden_size, lstm_layers=lstm_layers,
                          inner_dense_layer_sizes=inner_dense_layer_sizes, last_layer=last_layer,
                          batch_first=batch_first, bi_directional=bi_directional, kwargs=locals())
        self.to(cfg.device)
        if self._is_double:
            self.double()

        # Log the plain bool; the previous `{...}` outside an f-string built a one-element set instead.
        logger().log('BaseLSTMRegressor::__init__',
                     f'BaseLSTMRegressor is double: {self._is_double} and device in config is: ', cfg.device,
                     ' model device is cuda: ', next(self.parameters()).is_cuda)

        self.__set_up_tb_log(tb_log_path, embedding_size=embedding_size, hidden_size=hidden_size,
                             lstm_layers=lstm_layers, last_layer=last_layer,
                             inner_dense_layer_sizes=inner_dense_layer_sizes, sequence_len=sequence_len,
                             batch_first=batch_first, is_double=self._is_double, bi_directional=bi_directional)

        logger().force_log_and_print('BaseLSTMRegressor',
                                     f'created BaseLSTMRegressor tensorboard path: {tb_log_path}.\t'
                                     f'Input size: {embedding_size} and dense layer sizes: {inner_dense_layer_sizes}')

    @staticmethod
    def _build_dense_layers(hidden_size, inner_dense_layer_sizes):
        """
        Build a stack of ``Linear -> ReLU`` pairs, one pair per entry of ``inner_dense_layer_sizes``.

        :param hidden_size: input feature size of the first Linear layer
        :param inner_dense_layer_sizes: output sizes of the successive Linear layers (may be empty)
        :return: (list of modules, output size of the last layer); with no sizes the list is empty and the
                 returned size is ``hidden_size``
        """
        dense_layers_lst = list()
        last_size = hidden_size
        for curr_size in inner_dense_layer_sizes:
            dense_layers_lst.append(Linear(in_features=last_size, out_features=curr_size))
            dense_layers_lst.append(ReLU())
            last_size = curr_size
        return dense_layers_lst, last_size

    @staticmethod
    def _build_last_layer(last_size, last_layer):
        """
        Build the output head: a ``Linear(last_size, 1)`` optionally followed by an activation.

        :param last_size: input feature size of the output Linear layer
        :param last_layer: 'sigmoid', 'relu', 'linear', or None (None is treated as 'linear', i.e. no activation)
        :return: list of modules forming the head
        :raises ValueError: for any other ``last_layer`` value
        """
        if last_layer is None:
            # The constructor types last_layer as Optional[str]; interpret "no last layer" as a plain linear head.
            last_layer = 'linear'

        layers_lst = list()
        layers_lst.append(Linear(in_features=last_size, out_features=1))
        if last_layer == 'sigmoid':
            layers_lst.append(Sigmoid())
        elif last_layer == 'relu':
            layers_lst.append(ReLU())
        elif last_layer != 'linear':
            raise ValueError(f'Bad last layer name: {last_layer!r}')

        return layers_lst

    def __set_up_tb_log(self, tb_log_path: Optional[str], embedding_size: int, hidden_size: int, lstm_layers: int,
                        inner_dense_layer_sizes: Tuple[int, ...], sequence_len: int, batch_first: bool,
                        is_double: bool, last_layer: Optional[str], bi_directional: bool):
        """
        Set up tensorboard logging and dump the hyper-parameters, if ``tb_log_path`` is given.

        Creates a SummaryWriter under ``<tb_log_path>/logs``, records the model graph by tracing a random dummy
        input, and writes the hyper-parameters as JSON to ``<tb_log_path>/<MODEL_HYPER_PARAMS>``.
        Does nothing when ``tb_log_path`` is None.
        """
        if tb_log_path is None:
            return

        self._tb_writer = tensorboard.SummaryWriter(os.path.join(tb_log_path, 'logs'))

        # The dummy input must match the parameters' dtype (double only when is_double, otherwise the default
        # dtype the parameters were created with) and the declared input layout (batch dimension first only when
        # batch_first). Always using double / batch-first made tracing fail for float models.
        dummy_dtype = torch.double if is_double else torch.get_default_dtype()
        dummy_batch_size = 2
        if batch_first:
            dummy_shape = (dummy_batch_size, sequence_len, embedding_size)
        else:
            dummy_shape = (sequence_len, dummy_batch_size, embedding_size)
        self._tb_writer.add_graph(model=self, input_to_model=torch.rand(*dummy_shape, device=cfg.device,
                                                                        dtype=dummy_dtype))

        # str() keeps the dump JSON-serializable even if cfg.device is a torch.device rather than a string.
        hyper_params = dict(embedding_size=embedding_size, hidden_size=hidden_size, lstm_layers=lstm_layers,
                            last_layer=last_layer, inner_dense_layer_sizes=inner_dense_layer_sizes,
                            batch_first=batch_first, is_double=is_double, device=str(cfg.device),
                            bi_directional=bi_directional)
        with open(os.path.join(tb_log_path, FileNamesConstants.MODEL_HYPER_PARAMS), 'w') as f:
            json.dump(hyper_params, f)

    def tb_writer(self):
        """Return the tensorboard SummaryWriter, or None if no ``tb_log_path`` was given."""
        return self._tb_writer

    def is_double(self):
        """Return True if the module was converted to double precision at construction."""
        return self._is_double

    @abstractmethod
    def _build_model(self, embedding_size: int, hidden_size: int, lstm_layers: int,
                     inner_dense_layer_sizes: Tuple[int, ...], last_layer: str, batch_first: bool,
                     bi_directional: bool, **kwargs):
        """
        Create the submodules of the concrete regressor. Called once from ``__init__`` before the module is moved
        to the configured device.

        ``kwargs`` arrives as ``{'kwargs': <snapshot of __init__'s locals>}``; the snapshot holds every
        constructor argument (e.g. ``sequence_len``, ``tb_log_path``, ``is_double``) and, under its own
        ``'kwargs'`` key, any extra keyword arguments given to the constructor.
        """
        pass

    @abstractmethod
    def forward(self, inp):
        """
        Run the regressor on ``inp``. The expected layout presumably follows ``batch_first``
        (NOTE(review): defined by the subclass, confirm there).
        """
        pass
