import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import create_relative_position
from .modules.normalizations import NORM2FN
from .modules.feedforward import FeedForward
from .modules.cross_attention import CrossAttention
from .modules.relative_self_attention import RelativeSelfAttention

# pylint:disable=no-member


class MemoryAttentionWriter(nn.Module):
    """Write memory-token outputs back into a set of memory slots via attention.

    Each memory slot (one query per slot) attends over exactly ``1 + num_tokens``
    keys: its own key and the key of every memory token. Per head, the updated
    slot is

        p_self * slot_state + sum_j p_j * v_proj(token_j)

    where ``slot_state`` is the slot's *unprojected* hidden state split into
    heads. Because the raw state is used as a value, ``dim_model`` must equal
    ``num_heads * dim_head``.

    Args:
        num_heads: number of attention heads.
        dim_model: hidden size of the slots, the tokens and the output.
        dim_head: per-head size of the query/key/value projections.
        dropout: stored on the module; not applied in ``forward``.
        dropattn: dropout probability on the attention probabilities
            (only active in training mode).
        layer_norm_type: accepted for signature parity with the other
            attention modules; unused here.
    """

    def __init__(
        self,
        num_heads: int,
        dim_model: int,
        dim_head: int,
        dropout: float = 0.0,
        dropattn: float = 0.0,
        layer_norm_type: str = "layer_norm"
    ):
        super().__init__()

        self.num_heads = num_heads
        self.dim_model = dim_model
        self.dim_head = dim_head
        self.dim_inner = num_heads * dim_head
        self.dropout = dropout
        self.dropattn = dropattn

        self.q_proj = nn.Linear(dim_model, self.dim_inner, bias=True)
        self.k_proj = nn.Linear(dim_model, self.dim_inner, bias=True)
        self.v_proj = nn.Linear(dim_model, self.dim_inner, bias=True)
        # Standard 1/sqrt(d_head) scaling keeps the logit variance around 1.
        self.scale = 1.0 / (dim_head**0.5)

        self.reset_parameters()

    def _split_heads(self, x):
        """Reshape (batch, length, num_heads * dim_head) -> (batch, num_heads, length, dim_head)."""
        batch_size, length, _ = x.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        return x.reshape(batch_size, length, self.num_heads, self.dim_head).transpose(1, 2)

    def forward(self, memory_hidden_states, memory_token_hidden):
        """Update the memory slots from the memory-token outputs.

        Args:
            memory_hidden_states: (batch, num_slots, dim_model) current memory.
            memory_token_hidden: (batch, num_tokens, dim_model) hidden states
                at the memory-token positions. ``num_tokens`` may be 0, in
                which case the memory is returned unchanged.

        Returns:
            (batch, num_slots, dim_model) updated memory.
        """
        # TODO: Implement fast decoding
        batch_size, num_slots, _ = memory_hidden_states.shape

        query = self._split_heads(self.q_proj(memory_hidden_states))
        # Keys are projected separately for slots and tokens; the projection is
        # per-position, so this equals projecting the concatenation and slicing,
        # without the risk of slicing the wrong end.
        self_key = self._split_heads(self.k_proj(memory_hidden_states))
        token_key = self._split_heads(self.k_proj(memory_token_hidden))
        token_value = self._split_heads(self.v_proj(memory_token_hidden))
        # The slot's own unprojected state is its value, so the update
        # interpolates between keeping the old memory and reading the tokens.
        self_value = self._split_heads(memory_hidden_states)

        # Each slot scores only its own key: (batch, heads, slots, 1).
        self_logits = (query * self_key).sum(dim=-1, keepdim=True) * self.scale
        # Every slot scores every token: (batch, heads, slots, tokens).
        token_logits = torch.matmul(query, token_key.transpose(-1, -2)) * self.scale

        attn_probs = torch.softmax(torch.cat([self_logits, token_logits], dim=-1), dim=-1)
        attn_probs = F.dropout(attn_probs, p=self.dropattn, training=self.training)

        self_probs, token_probs = attn_probs[..., :1], attn_probs[..., 1:]
        new_memory = self_probs * self_value + torch.matmul(token_probs, token_value)
        return new_memory.transpose(1, 2).reshape(batch_size, num_slots, self.dim_model)

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj.weight.data, 1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.k_proj.weight.data, 1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.v_proj.weight.data, 1 / math.sqrt(2))
        nn.init.constant_(self.q_proj.bias.data, 0.)
        nn.init.constant_(self.k_proj.bias.data, 0.)
        nn.init.constant_(self.v_proj.bias.data, 0.)

class EncoderLayer(nn.Module):
    """One encoder block: relative self-attention, cross-attention, feed-forward.

    Every sub-layer's output is added back onto its input (residual
    connection). The two attention modules are constructed with
    ``layer_norm_type``; the feed-forward input is normalized explicitly by
    ``ff_layer_norm``.
    """

    def __init__(
        self,
        num_heads: int,
        dim_model: int,
        dim_head: int,
        dim_ff_inner: int,
        dropout: float = 0.0,
        dropattn: float = 0.0,
        layer_norm_type: str = "layer_norm",
        act_type: str = "swish",
    ):
        super().__init__()

        # Both attention sub-layers take exactly the same hyper-parameters.
        attn_config = dict(
            num_heads=num_heads,
            dim_model=dim_model,
            dim_head=dim_head,
            dropout=dropout,
            dropattn=dropattn,
            layer_norm_type=layer_norm_type,
        )
        self.self_attn = RelativeSelfAttention(**attn_config)
        self.cross_attn = CrossAttention(**attn_config)

        self.feedforward = FeedForward(
            hidden_size=dim_model,
            intermediate_size=dim_ff_inner,
            dropout=dropout,
            act_type=act_type,
        )
        # FeedForward does not have its own layer norm, so its input is normalized here.
        self.ff_layer_norm = NORM2FN[layer_norm_type](dim_model)

    def forward(
        self,
        hidden_states,
        rel_pos_embedding,
        cross_hidden_states,
        self_attn_mask,
        cross_attn_mask,
    ):
        """Run the block on ``hidden_states``.

        Args:
            hidden_states: input sequence of the block.
            rel_pos_embedding: relative-position embeddings for self-attention,
                shape (query_len, key_len, dim_head).
            cross_hidden_states: states attended to by the cross-attention
                (e.g. encoder outputs / memory).
            self_attn_mask: extended mask forwarded to the self-attention.
            cross_attn_mask: extended mask forwarded to the cross-attention.

        Returns:
            dict with the single key ``"hidden_states"`` holding the block output.
        """
        # Self-attention sub-layer (attention probabilities are not used).
        self_attn_out, _ = self.self_attn(
            hidden_states=hidden_states,
            rel_pos_embedding=rel_pos_embedding,
            decoder_cache=None,
            extended_attn_mask=self_attn_mask,
        )
        hidden_states = hidden_states + self_attn_out

        # Cross-attention sub-layer (attention probabilities are not used).
        cross_attn_out, _ = self.cross_attn(
            hidden_states=hidden_states,
            cross_hidden_states=cross_hidden_states,
            decoder_cache=None,
            extended_attn_mask=cross_attn_mask,
        )
        hidden_states = hidden_states + cross_attn_out

        # Pre-normalized feed-forward sub-layer.
        ff_out = self.feedforward(self.ff_layer_norm(hidden_states))
        return {"hidden_states": hidden_states + ff_out}


class Encoder(nn.Module):
    """Encoder with learned memory tokens and a memory write-back step.

    ``encoder_memory_token_len`` learned memory-token embeddings are prepended
    to the input sequence. The result runs through ``num_encoder_layers``
    ``EncoderLayer`` blocks (relative self-attention plus cross-attention over
    ``cross_hidden_states``) and a final layer norm. The outputs at the
    memory-token positions are then written into ``cross_hidden_states`` by
    ``MemoryAttentionWriter``, producing the updated memory.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_attn_buckets = config.num_attn_buckets
        self.max_attn_distance = config.max_attn_distance
        self.num_encoder_layers = config.num_encoder_layers
        self.num_heads = config.num_heads
        self.encoder_memory_token_len = config.encoder_memory_token_len

        # NOTE: the attribute name is misspelled ("relatve") but is kept as-is
        # because it determines the state_dict key of saved checkpoints.
        self.relatve_position_embedding = nn.Embedding(config.num_attn_buckets, config.dim_head)
        # Learned memory tokens shared across the batch; initialized in reset_parameters.
        self.memory_embedding = nn.Parameter(torch.empty(self.encoder_memory_token_len, config.dim_model))

        self.memory_writer = MemoryAttentionWriter(
            num_heads=config.num_heads,
            dim_model=config.dim_model,
            dim_head=config.dim_head,
            dropout=config.dropout,
            dropattn=config.dropattn,
            layer_norm_type=config.layer_norm_type
        )

        self.layers = nn.ModuleList(
            [
                EncoderLayer(
                    num_heads=config.num_heads,
                    dim_model=config.dim_model,
                    dim_head=config.dim_head,
                    dim_ff_inner=config.dim_ff_inner,
                    dropout=config.dropout,
                    dropattn=config.dropattn,
                    layer_norm_type=config.layer_norm_type,
                    act_type=config.act_type,
                ) for _ in range(self.num_encoder_layers)
            ]
        )

        self.layer_norm = NORM2FN[config.layer_norm_type](config.dim_model)

        self.reset_parameters()

    def _get_extended_self_attn_mask(self, attn_mask):
        """Build a per-head pairwise self-attention mask.

        Args:
            attn_mask: mask of shape (batch, seq_len).

        Returns:
            (batch, num_heads, seq_len, seq_len) mask that is True where both
            the query and the key position are True.
        """
        # NOTE(review): self-attention runs over [memory tokens, inputs], so the
        # caller presumably passes a mask that already covers the memory
        # positions — confirm against the caller.
        seq_len = attn_mask.shape[1]
        attn_mask = attn_mask.unsqueeze(1)
        extended_attn_mask = attn_mask & attn_mask.transpose(-1, -2)
        # shape (batch, query_len, key_len)
        extended_attn_mask = extended_attn_mask[:, -seq_len:, :]
        extended_attn_mask = extended_attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
        return extended_attn_mask

    def _get_extended_cross_attn_mask(self, cross_attn_mask, query_len, key_len):
        """Broadcast a (batch, >= key_len) key mask to (batch, num_heads, query_len, key_len).

        Only the first ``key_len`` columns of ``cross_attn_mask`` are used.
        """
        cross_attn_mask = cross_attn_mask[:, None, None, :key_len]
        return cross_attn_mask.repeat(1, self.num_heads, query_len, 1)

    def _get_rel_pos_embedding(self, hidden_states):
        """Relative-position embeddings of shape (seq_len, seq_len, dim_head) for ``hidden_states``."""
        query_len = hidden_states.shape[1]
        key_len = query_len
        relative_position = create_relative_position(
            query_len=query_len,
            key_len=key_len,
            bidirectional=True,
            num_buckets=self.num_attn_buckets,
            max_distance=self.max_attn_distance
        ).to(hidden_states.device)
        # Untie the memory tokens from distance: every row/column involving a
        # memory token is mapped to bucket 0 instead of a distance-based bucket.
        relative_position[:self.encoder_memory_token_len, :] = 0
        relative_position[:, :self.encoder_memory_token_len] = 0
        rel_pos_embedding = self.relatve_position_embedding(relative_position)
        return rel_pos_embedding

    def forward(
        self, hidden_states: torch.FloatTensor, cross_hidden_states: torch.FloatTensor,
        encoder_attn_mask: torch.BoolTensor, cross_attn_mask: torch.BoolTensor
    ):
        """Encode ``hidden_states`` and write the result into the memory.

        Args:
            hidden_states: (batch, seq_len, dim_model) input embeddings.
            cross_hidden_states: (batch, cross_len, dim_model) memory that is
                attended to by cross-attention and updated by the memory writer.
            encoder_attn_mask: self-attention mask, see
                ``_get_extended_self_attn_mask``.
            cross_attn_mask: (batch, >= cross_len) mask over the cross
                positions; only the first ``cross_len`` columns are used.

        Returns:
            dict with
              ``"last_hidden_states"``: layer-normalized outputs, with the
              memory-token positions at the front;
              ``"memory_hidden"``: the updated memory from ``memory_writer``.
        """
        batch_size = hidden_states.shape[0]

        # Prepend the learned memory tokens to every sequence in the batch.
        memory_tokens = self.memory_embedding.unsqueeze(0).expand(batch_size, -1, -1)
        # shape: (batch_size, memory_len + seq_len, dim_model)
        hidden_states = torch.cat([memory_tokens, hidden_states], dim=1)
        rel_pos_embedding = self._get_rel_pos_embedding(hidden_states)

        self_attn_mask = self._get_extended_self_attn_mask(encoder_attn_mask)
        # Previously this mask was built and then dropped (layers got None),
        # so padded cross positions were never masked out.
        extended_cross_attn_mask = self._get_extended_cross_attn_mask(
            cross_attn_mask,
            query_len=hidden_states.shape[1],
            key_len=cross_hidden_states.shape[1],
        )

        for layer in self.layers:
            hidden_states = layer(
                hidden_states=hidden_states,
                rel_pos_embedding=rel_pos_embedding,
                cross_hidden_states=cross_hidden_states,
                self_attn_mask=self_attn_mask,
                cross_attn_mask=extended_cross_attn_mask,
            )["hidden_states"]

        hidden_states = self.layer_norm(hidden_states)

        new_cross_hidden_states = self.memory_writer(
            cross_hidden_states, memory_token_hidden=hidden_states[:, :self.encoder_memory_token_len]
        )

        return {"last_hidden_states": hidden_states, "memory_hidden": new_cross_hidden_states}

    def reset_parameters(self):
        nn.init.normal_(self.relatve_position_embedding.weight.data, std=0.02)
        nn.init.normal_(self.memory_embedding.data, std=0.02)