import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from omegaconf import OmegaConf
from typing import List, Dict

from .encoder import Encoder
from .decoder import Decoder


class EncoderDecoder(nn.Module):
    """Encoder-decoder model whose two halves share one word-embedding table.

    Each forward call embeds the encoder and decoder token ids with the same
    ``nn.Embedding``, runs the encoder with cross-attention inputs taken from
    a ``memory`` dict, runs the decoder on top of the encoder output, and
    hands back a refreshed memory dict plus whatever cache the decoder
    returned.
    """

    def __init__(self, config):
        """Create the embedding table, encoder and decoder from ``config``.

        Attributes read directly here are ``encoder_memory_token_len``,
        ``vocab_size`` and ``dim_embed``; ``config`` itself is kept on the
        module and forwarded unchanged to ``Encoder`` and ``Decoder``.
        """
        super().__init__()
        self.config = config
        self.encoder_memory_token_len = config.encoder_memory_token_len

        # Submodules are registered in this order: embedding, encoder, decoder.
        self.word_embedding = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.dim_embed,
        )
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

        self.reset_parameters()

    def forward(self, encoder_input_ids, decoder_input_ids, memory, decoder_caches):
        """Run the encoder, then the decoder, and roll the memory forward.

        Args:
            encoder_input_ids: token ids fed to the encoder; entries equal to
                ``config.pad_token_id`` are marked False in the attention mask.
            decoder_input_ids: token ids fed to the decoder, masked the same way.
            memory: mapping that must provide ``"encoder_cross_hidden"`` and
                ``"encoder_cross_mask"``; both are passed to the encoder as
                its cross-attention inputs.
            decoder_caches: cache list handed to the decoder, or ``None``. It
                is replaced by ``config.num_decoder_layers`` empty entries
                when it is ``None`` or when ``config.decoder_cache_len`` is 0.

        Returns:
            A 3-tuple ``(outputs, new_memory, new_decoder_cache)``:
            outputs: dict with ``"encoder_hidden_states"`` (the encoder's
                ``"last_hidden_states"``) and ``"decoder_hidden_states"``
                (the decoder's ``"last_hidden_states"``).
            new_memory: dict with ``"encoder_cross_hidden"`` (the encoder's
                ``"memory_hidden"``) and ``"encoder_cross_mask"`` (the
                left-padded encoder attention mask built in this call).
            new_decoder_cache: the second value returned by the decoder.
        """
        pad_id = self.config.pad_token_id

        # True marks a non-pad token. The encoder mask gets
        # `encoder_memory_token_len` extra True entries on the left of its
        # last dimension ("prepend memory"); presumably the encoder prepends
        # that many memory tokens itself -- confirm against Encoder.
        left_pad = (self.encoder_memory_token_len, 0)
        enc_mask = F.pad(encoder_input_ids != pad_id, left_pad, "constant", True)
        dec_mask = decoder_input_ids != pad_id

        # Incoming caches are only kept when they exist and caching is enabled;
        # otherwise every decoder layer starts from an empty cache entry.
        caches_usable = decoder_caches is not None and not (self.config.decoder_cache_len == 0)
        if not caches_usable:
            decoder_caches = [
                dict(past_hidden_states=None)
                for _ in range(self.config.num_decoder_layers)
            ]

        enc_out = self.encoder(
            hidden_states=self.word_embedding(encoder_input_ids),
            cross_hidden_states=memory["encoder_cross_hidden"],
            encoder_attn_mask=enc_mask,
            cross_attn_mask=memory["encoder_cross_mask"],
        )
        enc_states = enc_out["last_hidden_states"]
        next_memory_hidden = enc_out["memory_hidden"]

        dec_out, new_decoder_cache = self.decoder(
            hidden_states=self.word_embedding(decoder_input_ids),
            cross_hidden_states=enc_states,
            decoder_caches=decoder_caches,
            decoder_attn_mask=dec_mask,
            encoder_attn_mask=enc_mask,
        )

        outputs = dict(
            encoder_hidden_states=enc_states,
            decoder_hidden_states=dec_out["last_hidden_states"],
        )
        new_memory = dict(
            encoder_cross_hidden=next_memory_hidden,
            encoder_cross_mask=enc_mask,
        )
        return outputs, new_memory, new_decoder_cache

    def reset_parameters(self):
        """Re-draw the word-embedding weights from a normal with std 0.02."""
        embedding_weights = self.word_embedding.weight.data
        nn.init.normal_(embedding_weights, std=0.02)