from abc import ABC, abstractmethod
import torch
import torch.nn as nn

from typing import Dict, List, Tuple, Union


class RecurrentCell(nn.Module, ABC):
    """Abstract single-step recurrent cell.

    Subclasses must implement :meth:`forward`; the class cannot be
    instantiated until they do (it combines ``nn.Module`` with ``ABC``).
    """

    @abstractmethod
    def forward(
        self, step_inputs: torch.Tensor, memory: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Process one step of input given the current memory.

        Args:
            step_inputs: Input tensor for the current step (shape is
                defined by concrete subclasses).
            memory: Current recurrent memory tensor.

        Returns:
            A pair of tensors. By naming convention this is presumably
            ``(outputs, updated_memory)`` — NOTE(review): confirm against
            concrete subclasses.

        Raises:
            NotImplementedError: Always, in this abstract base.
        """
        raise NotImplementedError


class RecurrentModel(nn.Module, ABC):
    """Base class for recurrent models.

    A RecurrentModel should wrap all the NN weights.
    """

    # NOTE: ``forward`` is not declared abstract here; concrete subclasses are
    # expected to override ``nn.Module.forward`` themselves.

    def set_recurrent_cell(self, recurrent_cell: "RecurrentCell") -> None:
        """Attach the recurrent cell used by this model.

        Args:
            recurrent_cell: The cell to store as ``self.recurrent_cell``.

        Because this class is an ``nn.Module``, assigning an ``nn.Module``
        attribute registers it as a child submodule, so its parameters are
        included in ``self.parameters()`` and ``state_dict()``.
        """
        self.recurrent_cell = recurrent_cell