from typing import Tuple, Optional

import os

import json
import torch
from torch.nn import LSTM, ModuleList

from BaselineModel.BaselineBaseLSTMRegressor import BaseLSTMRegressor
from Utils import config as cfg
from Utils.Constants import Diff, FileNamesConstants


class CombinedLSTMRegressor(BaseLSTMRegressor):
    """
    LSTM regressor that processes weights and gradients sequences with two separate LSTMs.

    The input sequence is split in half along the time axis: the first half is fed to a
    "weights" LSTM and the second half to a "gradients" LSTM. The last-time-step outputs of
    both LSTMs are concatenated and passed through a dense head built with the inherited
    ``_build_dense_layers`` / ``_build_last_layer`` helpers.
    """

    def __init__(self, weights_embedding_size: int, gradients_embedding_size: int, hidden_size: int, lstm_layers: int,
                 inner_dense_layer_sizes: Tuple[int, ...], tb_log_path: Optional[str], bi_directional: bool,
                 sequence_len: int = int(Diff.NUMBER_STEPS_SAVED*2), batch_first: bool = True,
                 last_layer: Optional[str] = 'sigmoid', **kwargs):
        """
        LSTM regression model using both gradients and weights data.

        :param weights_embedding_size: ``input_size`` of the weights LSTM.
        :param gradients_embedding_size: ``input_size`` of the gradients LSTM.
        :param hidden_size: hidden size of each of the two LSTMs.
        :param lstm_layers: number of stacked layers in each LSTM.
        :param inner_dense_layer_sizes: sizes forwarded to ``_build_dense_layers`` for the dense head.
        :param tb_log_path: directory in which the hyper-parameter JSON file is (re)written;
            when ``None`` nothing is written by this constructor.
        :param bi_directional: whether both LSTMs are bidirectional.
        :param sequence_len: forwarded to the base class (defaults to ``2 * Diff.NUMBER_STEPS_SAVED``).
            ``forward`` itself splits whatever sequence length it receives in half.
        :param batch_first: input layout of both LSTMs (``(batch, seq, feat)`` when True,
            ``(seq, batch, feat)`` when False).
        :param last_layer: forwarded to ``_build_last_layer`` (default ``'sigmoid'``).
        :param kwargs: forwarded to the base class as a single ``kwargs`` keyword argument.
        """
        # NOTE(review): meaning of is_double is defined by the base class (presumably double precision) — confirm.
        is_double = True
        # These must be set before the base constructor runs, because ``_build_model`` reads them.
        # It is presumably invoked from ``BaseLSTMRegressor.__init__`` — verify in the base class.
        self.__weights_embedding_size = weights_embedding_size
        self.__gradients_embedding_size = gradients_embedding_size
        # NOTE(review): ``kwargs`` is passed as one keyword named ``kwargs`` rather than unpacked
        # with ``**kwargs``; confirm this matches the base-class signature.
        super().__init__(embedding_size=weights_embedding_size,
                         hidden_size=hidden_size, is_double=is_double, lstm_layers=lstm_layers,
                         last_layer=last_layer, tb_log_path=tb_log_path,
                         inner_dense_layer_sizes=inner_dense_layer_sizes,
                         sequence_len=sequence_len, batch_first=batch_first,
                         bi_directional=bi_directional, kwargs=kwargs)

        if tb_log_path is not None:
            # Replace any existing hyper-parameter file with this model's own parameters.
            # Mode 'w' creates the file or truncates an existing one, so no separate delete is
            # needed (and deleting first would raise FileNotFoundError if the file were absent).
            # NOTE(review): sequence_len is not recorded here; a model rebuilt from this file
            # would fall back to the default.
            hparams_path = os.path.join(tb_log_path, FileNamesConstants.MODEL_HYPER_PARAMS)
            with open(hparams_path, 'w') as f:
                json.dump(dict(weights_embedding_size=weights_embedding_size, gradients_embedding_size=gradients_embedding_size,
                               hidden_size=hidden_size, lstm_layers=lstm_layers,
                               last_layer=last_layer, inner_dense_layer_sizes=inner_dense_layer_sizes,
                               batch_first=batch_first, is_double=is_double, device=cfg.device,
                               bi_directional=bi_directional), f)

    def _build_model(self, embedding_size: int, hidden_size: int, lstm_layers: int,
                     inner_dense_layer_sizes: Tuple[int, ...], last_layer: Optional[str], batch_first: bool,
                     bi_directional: bool, **kwargs):
        """
        Build the weights LSTM, the gradients LSTM and the dense head.

        ``embedding_size`` and ``kwargs`` are accepted (presumably to match the base-class hook
        signature) but unused. The LSTM input sizes come from the weights/gradients embedding
        sizes stored in ``__init__``.
        """
        self._lstm_weights = LSTM(input_size=self.__weights_embedding_size, hidden_size=hidden_size,
                                  num_layers=lstm_layers, batch_first=batch_first, bidirectional=bi_directional)
        self._lstm_gradients = LSTM(input_size=self.__gradients_embedding_size, hidden_size=hidden_size,
                                    num_layers=lstm_layers, batch_first=batch_first, bidirectional=bi_directional)
        # Each LSTM emits hidden_size features per direction, and forward() concatenates both LSTMs' outputs.
        num_directions = 2 if bi_directional else 1
        dense_inp_size = 2 * num_directions * hidden_size
        dense_layers, last_size = self._build_dense_layers(dense_inp_size, inner_dense_layer_sizes)
        self._dense_layers = ModuleList(dense_layers)
        self._dense_layers.extend(self._build_last_layer(last_size, last_layer))

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """
        Run both LSTMs on their halves of the input sequence and apply the dense head.

        :param inp: 3-D tensor laid out as ``(batch, seq, feat)`` when the LSTMs were built with
            ``batch_first=True`` and ``(seq, batch, feat)`` otherwise. The first ``seq // 2`` time
            steps go to the weights LSTM and the remaining steps go to the gradients LSTM. Both halves
            share ``feat``, so it must equal both LSTMs' ``input_size``.
        :return: output of the final dense layer, one row per batch element.
        """
        # Split along the time axis that matches the layout the LSTMs were built with.
        time_dim = 1 if self._lstm_weights.batch_first else 0
        seq_len = inp.shape[time_dim]
        half = seq_len // 2
        weights_inp, gradients_inp = inp.split([half, seq_len - half], dim=time_dim)
        weights_out, _ = self._lstm_weights(weights_inp)
        gradients_out, _ = self._lstm_gradients(gradients_inp)
        # Use each LSTM's output at its last time step.
        # NOTE(review): for bidirectional LSTMs the backward-direction part of this slice has only
        # seen the final time step; its full-sequence summary sits at time step 0 instead.
        x = torch.cat((weights_out.select(time_dim, -1), gradients_out.select(time_dim, -1)), 1)
        for layer in self._dense_layers:
            x = layer(x)
        return x
