import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from omegaconf import OmegaConf
from typing import List, Dict

from .recurrent.recurrent_model import RecurrentCell
from .encoder_decoder import EncoderDecoder


class RNNCell(RecurrentCell):
    """Recurrent cell that delegates each step to an ``EncoderDecoder``.

    The cell threads a two-part memory, ``(encoder_memory, decoder_memory)``,
    through successive calls of ``forward``. Only the training step is
    implemented. The remaining ``RecurrentCell`` hooks raise
    ``NotImplementedError``.
    """

    def __init__(self, config):
        super().__init__()
        # Kept on the instance; the same config also builds the wrapped model.
        self.config = config
        self.encoder_decoder = EncoderDecoder(config)

    def forward(self, step_inputs, memory):
        """Run one recurrent step (training only).

        The signature and the shape of the return value must match
        ``RecurrentCell``.

        Args:
            step_inputs: mapping that provides ``"source_ids"`` (encoder input)
                and ``"decoder_input_ids"`` (decoder input).
            memory: pair ``(encoder_memory, decoder_memory)`` from the previous
                step.

        Returns:
            ``(outputs, (encoder_memory, decoder_cache))``. The second element
            is the updated memory pair, exactly as produced by
            ``self.encoder_decoder``.
        """
        source_ids = step_inputs["source_ids"]
        target_ids = step_inputs["decoder_input_ids"]
        prev_encoder_memory, prev_decoder_memory = memory

        step_outputs, next_encoder_memory, next_decoder_cache = self.encoder_decoder(
            source_ids, target_ids, prev_encoder_memory, prev_decoder_memory
        )

        # The decoder cache returned here becomes the decoder memory of the next step.
        next_memory = (next_encoder_memory, next_decoder_cache)
        return step_outputs, next_memory

    def compute_outputs(self, recurrent_outputs, recurrent_inputs, training: bool = True) -> Dict:
        """Not implemented for this cell; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def construct_memory(self, batch_size):
        """Not implemented for this cell; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def reset_memory(self, memory_reset_signal: torch.Tensor):
        """Not implemented for this cell; always raises ``NotImplementedError``."""
        raise NotImplementedError
