from typing import Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import create_relative_position
from .modules.normalizations import NORM2FN
from .modules.feedforward import FeedForward
from .modules.cross_attention import CrossAttention
from .modules.relative_self_attention import RelativeSelfAttention

# pylint:disable=no-member


class DecoderLayer(nn.Module):
    """One decoder block: relative self attention, cross attention, feed-forward.

    Each sub-block is added back onto its input as a residual. Only the
    feed-forward path is normalized here (``ff_layer_norm``). The attention
    modules receive ``layer_norm_type`` and presumably normalize internally —
    TODO confirm against their implementations.

    The layer also returns the next-step cache for itself: a detached copy of
    the hidden states it received as input.
    """

    def __init__(
        self,
        num_heads: int,
        dim_model: int,
        dim_head: int,
        dim_ff_inner: int,
        cache_len: int,
        dropout: float = 0.0,
        dropattn: float = 0.0,
        layer_norm_type: str = "layer_norm",
        act_type: str = "swish",
    ):
        super().__init__()
        # Stored for reference; not read anywhere in this class.
        self.cache_len = cache_len

        # Both attention modules are configured identically.
        shared_attn_kwargs = dict(
            num_heads=num_heads,
            dim_model=dim_model,
            dim_head=dim_head,
            dropout=dropout,
            dropattn=dropattn,
            layer_norm_type=layer_norm_type,
        )
        self.self_attn = RelativeSelfAttention(**shared_attn_kwargs)
        self.cross_attn = CrossAttention(**shared_attn_kwargs)

        self.feedforward = FeedForward(
            hidden_size=dim_model,
            intermediate_size=dim_ff_inner,
            dropout=dropout,
            act_type=act_type,
        )

        # FeedForward carries no normalization of its own, so the layer owns one.
        self.ff_layer_norm = NORM2FN[layer_norm_type](dim_model)

    def forward(
        self,
        hidden_states,
        cross_hidden_states,
        rel_pos_embedding,
        decoder_cache,
        self_attn_mask,
        cross_attn_mask=None
    ):
        """Apply self attention, cross attention and feed-forward with residuals.

        Args:
            hidden_states: input to this layer. A detached copy of it becomes
                this layer's cache for the next call.
            cross_hidden_states: encoder outputs attended to by cross attention.
            rel_pos_embedding: shape (query_len, key_len, dim_head).
            decoder_cache: dict whose ``"past_hidden_states"`` entry is the
                XL-style memory fed to self attention. The whole dict is also
                handed to cross attention. NOTE(review): confirm that
                CrossAttention expects the dict rather than the tensor.
            self_attn_mask: extended self-attention mask.
            cross_attn_mask: extended cross-attention mask.

        Returns:
            ``({"hidden_states": output}, {"past_hidden_states": detached input})``
        """
        layer_input = hidden_states

        self_out, _ = self.self_attn(
            hidden_states=layer_input,
            rel_pos_embedding=rel_pos_embedding,
            decoder_cache=decoder_cache["past_hidden_states"],
            extended_attn_mask=self_attn_mask,
        )
        after_self = layer_input + self_out

        cross_out, _ = self.cross_attn(
            hidden_states=after_self,
            cross_hidden_states=cross_hidden_states,
            decoder_cache=decoder_cache,
            extended_attn_mask=cross_attn_mask,
        )
        after_cross = after_self + cross_out

        # Pre-norm feed-forward: normalize, transform, add back.
        normed = self.ff_layer_norm(after_cross)
        layer_output = after_cross + self.feedforward(normed)

        # Memory for the next segment is this layer's *input*, without gradient.
        next_cache = {"past_hidden_states": layer_input.detach()}
        return {"hidden_states": layer_output}, next_cache


class Decoder(nn.Module):
    """Stack of ``DecoderLayer``s with Transformer-XL style hidden-state memory.

    Each layer self-attends causally over ``decoder_cache_len`` cached hidden
    states plus the current segment, using relative-position embeddings. It
    then cross-attends to the encoder outputs.

    Args:
        config: object exposing ``decoder_cache_len``, ``num_attn_buckets``,
            ``max_attn_distance``, ``num_decoder_layers``, ``num_heads``,
            ``dim_model``, ``dim_head``, ``dim_ff_inner``, ``dropout``,
            ``dropattn``, ``layer_norm_type`` and ``act_type``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.decoder_cache_len = config.decoder_cache_len
        self.num_attn_buckets = config.num_attn_buckets
        self.max_attn_distance = config.max_attn_distance
        self.num_decoder_layers = config.num_decoder_layers
        self.num_heads = config.num_heads

        # (sic) The misspelled attribute name is part of the state_dict keys;
        # renaming it would break loading existing checkpoints.
        self.relatve_position_embedding = nn.Embedding(config.num_attn_buckets, config.dim_head)

        self.layers = nn.ModuleList(
            [
                DecoderLayer(
                    num_heads=config.num_heads,
                    dim_model=config.dim_model,
                    dim_head=config.dim_head,
                    dim_ff_inner=config.dim_ff_inner,
                    cache_len=self.decoder_cache_len,
                    dropout=config.dropout,
                    dropattn=config.dropattn,
                    layer_norm_type=config.layer_norm_type,
                    act_type=config.act_type,
                ) for _ in range(self.num_decoder_layers)
            ]
        )

        # NOTE(review): constructed but not applied in ``forward``. If this is a
        # pre-norm stack, a final norm on the output may be intended, or it may
        # be applied by the caller. Confirm.
        self.layer_norm = NORM2FN[config.layer_norm_type](config.dim_model)

        self.reset_parameters()

    def _get_extended_self_attn_mask(self, attn_mask):
        """Build the causal self-attention mask over ``[cache, current segment]``.

        Args:
            attn_mask: boolean ``(batch, seq_len)`` mask of valid positions in
                the current segment.

        Returns:
            ``(batch, num_heads, seq_len, decoder_cache_len + seq_len)`` mask,
            ``True`` where a query may attend to a key.
        """
        seq_len = attn_mask.shape[1]
        # Every cache slot is marked valid. NOTE(review): slots that
        # _update_decoder_cahches zero-padded are therefore attended to as well.
        attn_mask = F.pad(attn_mask, (self.decoder_cache_len, 0), "constant", True)
        # (batch, 1, key_len) & (batch, key_len, 1) -> (batch, key_len, key_len):
        # entry [q, k] is allowed only if both query q and key k are valid.
        attn_mask = attn_mask.unsqueeze(1)
        extended_attn_mask = attn_mask & attn_mask.transpose(-1, -2)
        # Causal (keys at or before the query), then keep only the rows for the
        # current segment's queries -> (batch, query_len, key_len).
        extended_attn_mask = torch.tril(extended_attn_mask)[:, -seq_len:, :]
        extended_attn_mask = extended_attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
        return extended_attn_mask

    def _get_extended_cross_attn_mask(self, attn_mask, query_len):
        """Broadcast the encoder padding mask for cross attention.

        Args:
            attn_mask: ``(batch, enc_len)`` mask of valid encoder positions.
            query_len: number of decoder query positions.

        Returns:
            ``(batch, num_heads, query_len, enc_len)`` mask. Every query row
            equals the encoder mask.
        """
        extended_attn_mask = attn_mask[:, None, None, :]
        extended_attn_mask = extended_attn_mask.repeat(1, self.num_heads, query_len, 1)
        return extended_attn_mask

    def _get_rel_pos_embedding(self, hidden_states):
        """Look up relative-position embeddings for ``[cache, current segment]``.

        Queries are the current segment (``hidden_states.shape[1]``). Keys are
        the cache plus the current segment. Positions are bucketed
        unidirectionally by ``create_relative_position`` (presumably returning
        ``(query_len, key_len)`` bucket ids — TODO confirm).
        """
        query_len = hidden_states.shape[1]
        key_len = query_len + self.decoder_cache_len
        relative_position = create_relative_position(
            query_len=query_len,
            key_len=key_len,
            bidirectional=False,
            num_buckets=self.num_attn_buckets,
            max_distance=self.max_attn_distance
        ).to(hidden_states.device)
        rel_pos_embedding = self.relatve_position_embedding(relative_position)
        return rel_pos_embedding

    def forward(
        self, hidden_states: torch.FloatTensor, cross_hidden_states: torch.FloatTensor, decoder_caches: List[Dict],
        decoder_attn_mask: torch.BoolTensor, encoder_attn_mask: torch.BoolTensor
    ):
        """Run all decoder layers over the current segment.

        Args:
            hidden_states: decoder inputs; ``shape[1]`` is the query length.
            cross_hidden_states: encoder outputs used by cross attention.
            decoder_caches: one dict per layer, each holding
                ``"past_hidden_states"`` from the previous call.
            decoder_attn_mask: boolean ``(batch, seq_len)`` mask of valid
                decoder positions.
            encoder_attn_mask: boolean mask of valid encoder positions.

        Returns:
            ``({"last_hidden_states": output}, new_caches)``. ``new_caches`` is
            ``None`` when ``decoder_cache_len == 0``; otherwise it is one cache
            dict per layer.
        """
        query_len = hidden_states.shape[1]

        self_attn_mask = self._get_extended_self_attn_mask(decoder_attn_mask)
        cross_attn_mask = self._get_extended_cross_attn_mask(encoder_attn_mask, query_len)
        rel_pos_embedding = self._get_rel_pos_embedding(hidden_states)

        new_decoder_caches = []
        for i, layer in enumerate(self.layers):
            outputs, new_decoder_cache = layer(
                hidden_states=hidden_states,
                cross_hidden_states=cross_hidden_states,
                rel_pos_embedding=rel_pos_embedding,
                decoder_cache=decoder_caches[i],
                self_attn_mask=self_attn_mask,
                cross_attn_mask=cross_attn_mask,
            )
            hidden_states = outputs["hidden_states"]
            new_decoder_caches.append(new_decoder_cache)

        outputs = {"last_hidden_states": hidden_states}
        new_decoder_caches = self._update_decoder_cahches(new_decoder_caches, decoder_attn_mask)

        return outputs, new_decoder_caches

    def _update_decoder_cahches(self, decoder_caches: List[Dict], decoder_attn_mask: torch.BoolTensor):
        """(sic) Compact each layer's cache to the last valid positions.

        For every sample, keep the hidden states at valid positions of
        ``decoder_attn_mask``, take the last ``decoder_cache_len`` of them, and
        left-pad with zeros when fewer remain. The method name is kept as-is for
        callers that reference it.

        Args:
            decoder_caches: per-layer dicts whose ``"past_hidden_states"`` has
                shape ``(batch, seq_len, dim)``.
            decoder_attn_mask: ``(batch, seq_len)`` valid-position mask. Integer
                0/1 masks are accepted and treated as boolean.

        Returns:
            ``None`` if ``decoder_cache_len == 0``. Otherwise a list with one
            ``{"past_hidden_states": (batch, decoder_cache_len, dim)}`` per layer.
        """
        if self.decoder_cache_len == 0:
            return None

        # (num_layers, batch, seq_len, dim)
        stacked = torch.stack([cache["past_hidden_states"] for cache in decoder_caches], dim=0)

        per_sample = []
        for i in range(decoder_attn_mask.shape[0]):
            # Cast so that an integer 0/1 mask selects positions, instead of
            # being used as integer indices (gathering rows 0/1 repeatedly).
            valid = decoder_attn_mask[i].bool()
            # (num_layers, n_valid, dim) -> most recent decoder_cache_len positions
            hidden = stacked[:, i][:, valid][:, -self.decoder_cache_len:]
            missing = self.decoder_cache_len - hidden.shape[1]
            if missing > 0:
                # Left-pad along the position axis (dim -2) with zeros.
                hidden = F.pad(hidden, (0, 0, missing, 0), "constant", 0.0)
            per_sample.append(hidden)

        # (num_layers, batch, decoder_cache_len, dim)
        new_past_hidden_states = torch.stack(per_sample, dim=1)
        return [{"past_hidden_states": new_past_hidden_states[i]} for i in range(self.num_decoder_layers)]

    def reset_parameters(self):
        """Initialise the relative-position embedding table with N(0, 0.02)."""
        nn.init.normal_(self.relatve_position_embedding.weight.data, std=0.02)