from typing import Optional, Tuple, Dict

import torch

from pado.core import PadoModule
from pado.nn.transformer.dot_product import ScaledDotProduct
from pado.nn.modules import Linear, Dropout

# Names exported by `from ... import *`; the base class is exported despite its leading underscore.
__all__ = [
    "MultiheadSelfAttentionWithV",
    "_MultiheadAttentionBaseWithV",
]


class _MultiheadAttentionBaseWithV(PadoModule):
    """Shared machinery for multi-head attention variants that also expose projected values.

    Builds the Q/K/V projections, the ``ScaledDotProduct`` scoring module and the
    dropout applied to attention scores. Subclasses provide ``forward``.

    :param hidden_dim:      feature size of the query input (also the default for other sizes)
    :param num_heads:       number of attention heads; must be positive
    :param attn_drop_prob:  dropout probability for ``self.attn_drop``
    :param bias:            bias flag for the Q and K projections (the V projection always has a bias)
    :param attn_normalize:  forwarded as ``ScaledDotProduct(normalize=...)``
    :param k_dim:           key input feature size (default: ``hidden_dim``)
    :param v_dim:           value input feature size (default: resolved ``k_dim``)
    :param qk_proj_dim:     projected size of Q and K (default: ``hidden_dim``)
    :param v_proj_dim:      projected size of V (default: ``hidden_dim``)
    :param was:             forwarded as ``ScaledDotProduct(was=...)``
    :param was_gamma:       forwarded as ``ScaledDotProduct(was_gamma=...)``
    :raises ValueError:     if ``num_heads <= 0`` or either projected size is not a multiple of it
    """

    def __init__(self,
                 hidden_dim: int,
                 num_heads: int,
                 attn_drop_prob: float = 0.0,
                 bias: bool = True, *,
                 attn_normalize: bool = True,
                 k_dim: Optional[int] = None,
                 v_dim: Optional[int] = None,
                 qk_proj_dim: Optional[int] = None,
                 v_proj_dim: Optional[int] = None,
                 was: bool = False,
                 was_gamma: float = 0.5) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.use_bias = bias
        if num_heads <= 0:
            raise ValueError(f"Number of heads should be positive, but got {num_heads}.")

        # Input sizes: keys fall back to the query size, values fall back to the (resolved) key size.
        self.k_dim = hidden_dim if k_dim is None else k_dim
        self.v_dim = self.k_dim if v_dim is None else v_dim

        # Projected sizes: each falls back to hidden_dim independently.
        self.qk_proj_dim = hidden_dim if qk_proj_dim is None else qk_proj_dim
        self.v_proj_dim = hidden_dim if v_proj_dim is None else v_proj_dim

        # Both projected sizes must split evenly across the heads.
        if any(size % num_heads != 0 for size in (self.qk_proj_dim, self.v_proj_dim)):
            raise ValueError(f"QKV dim {self.qk_proj_dim} or {self.v_proj_dim} should be multiple of num_heads {num_heads}.")
        # Per-head size used for scoring is derived from the Q/K projection only.
        self.head_dim = self.qk_proj_dim // num_heads

        self.q_proj = Linear(hidden_dim, self.qk_proj_dim, bias=bias, init_type="attn")
        self.k_proj = Linear(self.k_dim, self.qk_proj_dim, bias=bias, init_type="attn")
        # The value projection ignores `bias` and always carries one.
        self.v_proj = Linear(self.v_dim, self.v_proj_dim, bias=True, init_type="attn")

        self.was = was
        self.was_gamma = was_gamma
        self.attn = ScaledDotProduct(self.head_dim, normalize=attn_normalize, was=was, was_gamma=was_gamma)
        self.attn_drop = Dropout(attn_drop_prob)

    def _transpose_for_attn(self, x: torch.Tensor):
        """Split features into heads: (B, L, D) -> (B, num_heads, L, D // num_heads)."""
        bsz, seq_len, feat = x.shape
        per_head = x.view(bsz, seq_len, self.num_heads, feat // self.num_heads)
        return per_head.transpose(1, 2).contiguous()

    def _transpose_for_output(self, x: torch.Tensor):
        """Merge heads back: (B, num_heads, L, head_dim) -> (B, L, num_heads * head_dim)."""
        bsz, heads, seq_len, _ = x.shape
        assert heads == self.num_heads
        merged = x.transpose(1, 2).contiguous()
        return merged.view(bsz, seq_len, -1)

    def forward(self, *args, **kwargs):
        # Concrete attention variants must implement their own forward pass.
        raise NotImplementedError

    def extra_repr(self) -> str:
        """Summarize the configuration, listing only values that differ from their defaults."""
        parts = [f"{self.hidden_dim}", f"num_heads={self.num_heads}"]
        if self.attn_drop.p > 0:
            parts.append(f"attn_drop_prob={self.attn_drop.p}")
        if not self.use_bias:
            parts.append("bias=False")
        # Each dimension is shown only when it differs from hidden_dim.
        sized = (
            ("k_dim", self.k_dim),
            ("v_dim", self.v_dim),
            ("qk_proj_dim", self.qk_proj_dim),
            ("v_proj_dim", self.v_proj_dim),
        )
        parts.extend(f"{label}={value}" for label, value in sized if value != self.hidden_dim)
        if self.was:
            parts.append(f"was=True, was_gamma={self.was_gamma}")
        return ", ".join(parts)


class MultiheadSelfAttentionWithV(_MultiheadAttentionBaseWithV):
    """Multi-head self-attention that also returns attention scores and per-head values.

    Queries, keys and values are all projected from the same input. All dimension
    options of the base class keep their ``hidden_dim`` defaults, and
    ``attn_normalize=True`` is always passed to the base class.
    """

    def __init__(self,
                 hidden_dim: int,
                 num_heads: int,
                 attn_drop_prob: float = 0.0,
                 bias: bool = True, *,
                 was: bool = False,
                 was_gamma: float = 0.5) -> None:
        super().__init__(
            hidden_dim,
            num_heads,
            attn_drop_prob,
            bias,
            attn_normalize=True,
            was=was,
            was_gamma=was_gamma,
        )

    def _project_qkv(self, x: torch.Tensor):
        """Project ``x`` into Q, K and V, then split each into heads.

        All three projections run before any head split, in q, k, v order.
        Each result is (batch_size, num_heads, length, head_dim).
        """
        projected = (self.q_proj(x), self.k_proj(x), self.v_proj(x))
        return tuple(self._transpose_for_attn(t) for t in projected)

    def forward(self,
                query_state: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Self-attention MHA over a full sequence.
        :param query_state:     (batch_size, query_length, hidden_dim)
        :param attn_mask:       passed unchanged as ``mask`` to ``self.attn``;
                                expected (batch_size, query_length, query_length)
        :return:                tuple (output, scores, v):
                                output  (batch_size, query_length, hidden_dim)
                                scores  (batch_size, num_heads, query_length, query_length), after attn dropout
                                v       (batch_size, num_heads, query_length, head_dim)
        """
        _, _, in_dim = query_state.shape
        assert in_dim == self.hidden_dim

        q, k, v = self._project_qkv(query_state)

        # Dropout is applied to the scores themselves, so the returned scores are post-dropout.
        probs = self.attn_drop(self.attn(q, k, mask=attn_mask))
        context = self._transpose_for_output(torch.matmul(probs, v))
        return context, probs, v

    def step(self,
             query_state: torch.Tensor, *,
             prev_state: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Incremental self-attention for generation.
        The query usually has length 1, but any length is accepted. Attention dropout is
        not applied here (the method is intended for eval() mode).
        :param query_state:     (batch_size, query_length, hidden_dim)
        :param prev_state:      cache dict with optional keys 'prev_key' / 'prev_value',
                                each (batch_size, num_heads, prev_key_length, head_dim).
                                The cache is used only when both keys are present.
                                It is UPDATED IN PLACE with the concatenated keys/values,
                                so the new state is not returned.
                                No mask is used: every cached and new position is attended.
        :return:                (batch_size, query_length, hidden_dim)
        """
        batch_size, _, in_dim = query_state.shape
        assert in_dim == self.hidden_dim

        # Read the cache, if any.
        cached_k = cached_v = None
        cached_len = 0
        if len(prev_state):
            cached_k = prev_state.get("prev_key")
            cached_v = prev_state.get("prev_value")
            if cached_k is not None and cached_v is not None:
                _, _, cached_len, _ = cached_k.shape
                assert cached_k.shape == cached_v.shape == (batch_size, self.num_heads, cached_len, self.head_dim)

        q, k, v = self._project_qkv(query_state)

        if cached_len > 0:
            # New positions go after the cached ones along the length axis.
            k = torch.cat([cached_k, k], dim=2)
            v = torch.cat([cached_v, v], dim=2)
        # Store detached copies with their own storage as the next step's cache.
        prev_state.update({"prev_key": k.detach().clone(), "prev_value": v.detach().clone()})

        # (batch_size, num_heads, query_len, cached_len + query_len)
        weights = self.attn(q, k, mask=None)
        return self._transpose_for_output(torch.matmul(weights, v))
