from abc import ABC, abstractmethod
import torch
import torch.nn as nn

from typing import Dict, List, Tuple, Union


class RecurrentCell(nn.Module, ABC):
    """Abstract interface for a single recurrent step.

    Subclasses implement ``forward``, which consumes the current ``inputs``
    together with the carried ``memory`` and returns a pair of tensors.
    """

    @abstractmethod
    def forward(self, inputs: torch.Tensor, memory: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one recurrent step.

        Args:
            inputs: Input tensor for this step (shape is defined by the
                concrete subclass).
            memory: Recurrent state carried over from the previous step.

        Returns:
            A 2-tuple of tensors. Presumably ``(outputs, new_memory)`` —
            NOTE(review): confirm the ordering against concrete subclasses.

        Raises:
            NotImplementedError: Always, in this abstract base.
        """
        raise NotImplementedError


class RecurrentModel(nn.Module, ABC):
    """Abstract base for models that wrap all the NN weights.

    The model holds a pluggable recurrent cell (installed via
    ``set_recurrent_cell``). Subclasses must define how outputs are computed
    and how the recurrent memory is constructed and reset.
    """

    # NOTE: forward() is intentionally not declared abstract here, so
    # subclasses are not forced to override nn.Module.forward.

    def set_recurrent_cell(self, recurrent_cell: "RecurrentCell"):
        """Install the recurrent cell used by this model.

        Assigning an ``nn.Module`` to an attribute of an ``nn.Module``
        registers it as a child module, so its parameters become part of
        this model's parameters.

        Args:
            recurrent_cell: The cell module to store as ``self.recurrent_cell``.
        """
        self.recurrent_cell = recurrent_cell

    @abstractmethod
    def compute_outputs(self, recurrent_outputs, recurrent_inputs, training: bool = True) -> Dict:
        """Map the recurrent results to the model's outputs.

        Args:
            recurrent_outputs: Results produced by the recurrent cell.
            recurrent_inputs: The inputs that were fed to the recurrent cell.
            training: Whether the call is made in training mode.

        Returns:
            A dict of outputs. Its key schema is defined by the subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def construct_memory(self, batch_size):
        """Build the recurrent memory for ``batch_size`` sequences.

        Presumably this returns the initial memory passed to the cell —
        NOTE(review): confirm the return type/shape against subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def reset_memory(self, memory_reset_signal: torch.Tensor):
        """Reset memory according to ``memory_reset_signal``.

        NOTE(review): the signal's shape/dtype (e.g. a per-sequence mask) and
        whether the reset is in-place are defined by subclasses — confirm.
        """
        raise NotImplementedError