from typing import Optional, Dict, Tuple

import torch
from pado.core import PadoModule
from pado.nn.modules import Linear

__all__ = ["ReuseMultiheadSelfAttentionWithV"]


class _ReuseMultiheadAttentionBaseWithV(PadoModule):
    """Multi-head attention base that applies an externally supplied attention map.

    Only a value projection (``v_proj``) is owned by this module; no query/key
    projections are created because the attention weights come from the caller.
    Subclasses must implement ``forward``.
    """

    def __init__(self,
                 hidden_dim: int,
                 num_heads: int,
                 bias: bool = True, *,  # bias is just dummy argument
                 v_dim: Optional[int] = None,
                 v_proj_dim: Optional[int] = None) -> None:
        """
        :param hidden_dim:  model hidden size (stored as ``self.hidden_dim``)
        :param num_heads:   number of attention heads; must be positive
        :param bias:        only stored as ``self.use_bias``; ``v_proj`` always uses a bias
        :param v_dim:       input feature size of ``v_proj`` (defaults to ``hidden_dim``)
        :param v_proj_dim:  output feature size of ``v_proj`` (defaults to ``2 * v_dim``);
                            must be a multiple of ``num_heads``
        :raises ValueError: if ``num_heads <= 0`` or ``v_proj_dim % num_heads != 0``
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.use_bias = bias
        if num_heads <= 0:
            raise ValueError(f"Number of heads should be positive, but got {num_heads}.")

        if v_dim is None:
            v_dim = hidden_dim
        self.v_dim = v_dim

        if v_proj_dim is None:
            v_proj_dim = self.v_dim * 2  # to compensate decreased #params

        if v_proj_dim % num_heads != 0:
            raise ValueError(f"V dim {v_proj_dim} should be multiple of num_heads {num_heads}.")
        self.v_proj_dim = v_proj_dim
        # Width of each head after splitting the projected values; subclasses use it
        # to validate cached per-head value tensors.
        self.proj_head_dim = v_proj_dim // num_heads

        self.v_proj = Linear(v_dim, v_proj_dim, bias=True, init_type="attn")

    def _transpose_for_attn(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dim into heads.

        (batch_size, length, dim) -> (batch_size, num_heads, length, dim // num_heads).
        ``dim`` must be divisible by ``num_heads`` or ``view`` raises.
        """
        batch_size, length, hidden_dim = x.shape
        head_dim = hidden_dim // self.num_heads
        x = x.view(batch_size, length, self.num_heads, head_dim).transpose(1, 2).contiguous()
        return x

    def _transpose_for_output(self, x: torch.Tensor) -> torch.Tensor:
        """Merge heads back into the feature dim.

        (batch_size, num_heads, length, head_dim) -> (batch_size, length, num_heads * head_dim).
        """
        batch_size, num_heads, length, head_dim = x.shape
        assert num_heads == self.num_heads
        # contiguous() is required because view() cannot follow the transposed strides.
        x = x.transpose(2, 1).contiguous().view(batch_size, length, -1)
        return x

    def forward(self, *args, **kwargs):
        """Abstract; implemented by subclasses."""
        raise NotImplementedError


class ReuseMultiheadSelfAttentionWithV(_ReuseMultiheadAttentionBaseWithV):
    """Self-attention that applies a caller-supplied attention map to projected values.

    Values are projected from ``query_state`` with ``v_proj`` and split into
    ``num_heads`` heads; ``attn_scores`` is then matrix-multiplied against them.
    The output feature size is ``v_proj_dim`` (``2 * hidden_dim`` by default),
    which is not necessarily ``hidden_dim``.
    """

    def __init__(self,
                 hidden_dim: int,
                 num_heads: int,
                 bias: bool = True, *,
                 v_proj_dim: Optional[int] = None) -> None:
        # v_dim is left at its default (hidden_dim) because values are projected from query_state.
        super().__init__(hidden_dim, num_heads, bias, v_proj_dim=v_proj_dim)

    def forward(self,
                query_state: torch.Tensor,
                attn_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Self-attention MHA using an external attention map.
        :param query_state:     (batch_size, query_length, hidden_dim)
        :param attn_scores:     (batch_size, num_heads, query_length, query_length);
                                applied as-is (presumably already normalized by the caller — confirm)
        :return:                tuple (output, attn_scores, v):
                                output:      (batch_size, query_length, v_proj_dim)
                                attn_scores: the input attention map, returned unchanged
                                v:           (batch_size, num_heads, query_length, v_proj_dim // num_heads)
        """
        batch_size, query_len, hidden_dim = query_state.shape
        assert hidden_dim == self.hidden_dim

        v = self.v_proj(query_state)  # (batch_size, query_len, v_proj_dim)
        v = self._transpose_for_attn(v)  # (batch_size, num_heads, query_len, v_head_dim)

        output = torch.matmul(attn_scores, v)  # (batch_size, num_heads, query_len, v_head_dim)
        output = self._transpose_for_output(output)  # (batch_size, query_len, v_proj_dim)

        return output, attn_scores, v

    def step(self,
             query_state: torch.Tensor,
             attn_scores: torch.Tensor, *,
             prev_state: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Self-attention MHA generation step.
        For generation, query_state is expected to have seq_len == 1 (not restricted).
        Assuming eval() mode.
        :param query_state:     (batch_size, query_length, hidden_dim), typically query_length == 1
        :param attn_scores:     (batch_size, num_heads, query_length, prev_key_length + query_length)
        :param prev_state:      {'prev_value'}; may be empty on the first step
                prev_value:     (batch_size, num_heads, prev_key_length, v_proj_dim // num_heads)
                (1) will be automatically UPDATED after run (in-place to prev_state, so do not return state)
                (2) no mask is applied here; attn_scores is used as given
        :return:                (batch_size, query_length, v_proj_dim)
        """
        batch_size, query_len, hidden_dim = query_state.shape
        assert hidden_dim == self.hidden_dim
        # Per-head value width; matches the split performed by _transpose_for_attn on v_proj output.
        v_head_dim = self.v_proj_dim // self.num_heads

        # load previous states (an empty dict simply yields None)
        prev_value = prev_state.get("prev_value")
        prev_len = 0
        if prev_value is not None:
            _, _, prev_len, _ = prev_value.shape
            assert prev_value.shape == (batch_size, self.num_heads, prev_len, v_head_dim)

        v = self.v_proj(query_state)  # (batch_size, query_len, v_proj_dim)
        v = self._transpose_for_attn(v)  # (batch_size, num_heads, query_len, v_head_dim)

        if prev_len > 0:
            v = torch.cat([prev_value, v], dim=2)  # (batch_size, num_heads, prev_len + query_len, v_head_dim)
        # Cache a detached copy so the stored values carry no autograd history and do not alias v.
        prev_state["prev_value"] = v.detach().clone()

        output = torch.matmul(attn_scores, v)  # (batch_size, num_heads, query_len, v_head_dim)
        output = self._transpose_for_output(output)  # (batch_size, query_len, v_proj_dim)
        return output
