from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from torchfly.training.flymodel import FlyModel

from typing import Any, Dict, List, Tuple, Union

# pylint:disable=undefined-variable,unused-variable, assignment-from-no-return, not-callable, no-member


class RecurrentFlyModel(FlyModel, ABC):
    """Abstract ``FlyModel`` variant that carries a working-memory slot.

    Concrete subclasses must implement :meth:`detach_working_memory` and
    :meth:`reset`; this base class only stores the memory object and a
    freshness flag.
    """

    def __init__(self, config) -> None:
        """Initialise the parent model and the memory bookkeeping.

        Args:
            config: Configuration object, forwarded unchanged to
                ``FlyModel.__init__``.
        """
        super().__init__(config)
        # No memory is held until `set_current_working_memory` is called.
        self._working_memory = None
        # Starts as True and is never modified by this base class.
        # NOTE(review): presumably subclasses/trainers flip this once the
        # memory has been used — confirm against callers.
        self.is_memory_fresh = True

    @abstractmethod
    def detach_working_memory(self) -> Any:
        """Subclass hook; must be overridden.

        NOTE(review): by its name this presumably detaches the stored memory
        from the autograd graph — the contract is defined by subclasses.
        """
        raise NotImplementedError()

    @abstractmethod
    def reset(self, batch_size=None) -> None:
        """Subclass hook; must be overridden.

        Args:
            batch_size: Optional value (defaults to ``None``); its
                interpretation is left to the subclass.
        """
        raise NotImplementedError()

    def set_current_working_memory(self, memory: Any) -> None:
        """Replace the stored working memory with ``memory``.

        The object is stored as-is (no copy, no validation), and
        ``is_memory_fresh`` is left untouched.
        """
        self._working_memory = memory