from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from torchfly.training.flymodel import FlyModel

from typing import Any, Dict, List, Tuple, Union
from overrides import overrides

# pylint:disable=undefined-variable,unused-variable, assignment-from-no-return, not-callable, no-member


class RecurrentFlyModel(FlyModel, ABC):
    """Abstract ``FlyModel`` that carries a working memory between calls.

    Subclasses must implement :meth:`detach_working_memory` and :meth:`reset`.
    The remaining helpers are default, tensor-based implementations for
    reading, zeroing and back-propagating gradients of a memory tensor.
    Subclasses may override them if their memory is not a single tensor.

    NOTE(review): the helpers look intended for passing memory gradients
    across sequence segments (truncated-BPTT style); confirm against callers.
    """

    def __init__(self, config):
        super().__init__(config)
        # Current working memory; stays None until set_current_working_memory
        # is called (or a subclass assigns it).
        self._working_memory = None
        # Flag exposed to callers/subclasses. Only initialised here; nothing in
        # this class updates it.
        self.is_memory_fresh = True

    @abstractmethod
    def detach_working_memory(self) -> Any:
        """Subclass hook; no default implementation.

        NOTE(review): by its name this presumably detaches the stored working
        memory from the autograd graph and returns it; confirm in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def reset(self, batch_size=None) -> None:
        """Subclass hook; no default implementation.

        NOTE(review): presumably (re)initialises the working memory, optionally
        for a given ``batch_size``; confirm in subclasses.
        """
        raise NotImplementedError

    def get_memory_grad(self, memory):
        """Return ``memory.grad``.

        PyTorch only populates ``.grad`` on leaf tensors (or on tensors that
        called ``retain_grad()``) after a backward pass. Otherwise this is None.
        """
        return memory.grad

    def get_zero_memory_grad(self, memory):
        """Return a zero tensor with the same shape, dtype and device as *memory*."""
        return torch.zeros_like(memory)

    def backward_memory_grad(self, memory, memory_grad) -> None:
        """Back-propagate *memory_grad* through the graph that produced *memory*.

        ``retain_graph=True`` keeps the graph's buffers alive, so later
        backward calls through the same graph remain possible.
        """
        memory.backward(memory_grad, retain_graph=True)

    def set_requires_grad(self, memory) -> None:
        """Enable gradient tracking on *memory* in place if it is not already on."""
        if not memory.requires_grad:
            memory.requires_grad = True

    def set_current_working_memory(self, memory) -> None:
        """Store *memory* as the model's current working memory."""
        self._working_memory = memory