from typing import Optional, Tuple
import torch
import torch.nn.functional as F

from pado.nn.parameter import ParameterModule
from pado.nn.modules import Linear
from pado.nn.transformer.multihead_attention import _MultiheadAttentionBase
from pado.nn.transformer.utils import apply_attn_mask, apply_weak_attention_suppression


class _RelMultiheadAttentionBase(_MultiheadAttentionBase):
    """
    Base class for multi-head attention that adds relative-position scores.

    On top of `_MultiheadAttentionBase` this class provides
      - `r_proj`: bias-free projection of positional embeddings (hidden_dim -> qk_proj_dim),
      - optional layer-local query biases (`query_key_bias`, `query_pos_bias`),
      - `_rel_shift_left` / `_rel_shift_right`, which re-align relative-position scores.
    `forward` is abstract and must be implemented by subclasses.
    """

    def __init__(self,
                 hidden_dim: int,
                 num_heads: int,
                 attn_drop_prob: float = 0.0,
                 bias: bool = False, *,
                 share_query_bias: bool = False,
                 k_dim: Optional[int] = None,
                 v_dim: Optional[int] = None,
                 qk_proj_dim: Optional[int] = None,
                 v_proj_dim: Optional[int] = None,
                 was: bool = False,
                 was_gamma: float = 0.5) -> None:
        """
        :param hidden_dim:          input feature size, forwarded to the parent constructor
        :param num_heads:           forwarded to the parent constructor
        :param attn_drop_prob:      forwarded to the parent constructor
        :param bias:                forwarded to the parent constructor
        :param share_query_bias:    if True, no query biases are created here and they must be given to forward()
        :param k_dim:               forwarded to the parent constructor
        :param v_dim:               forwarded to the parent constructor
        :param qk_proj_dim:         forwarded to the parent constructor
        :param v_proj_dim:          forwarded to the parent constructor
        :param was:                 stored as `self.was`; the parent itself is always built with was=False
        :param was_gamma:           stored as `self.was_gamma` and forwarded to the parent constructor
        """
        # The parent is built with attn_normalize=False and was=False: the subclasses in this file
        # sum two score terms first and only then apply masking, WAS and softmax themselves.
        super().__init__(hidden_dim, num_heads, attn_drop_prob, bias=bias, attn_normalize=False,
                         k_dim=k_dim, v_dim=v_dim, qk_proj_dim=qk_proj_dim, v_proj_dim=v_proj_dim,
                         was=False, was_gamma=was_gamma)
        self.r_proj = Linear(hidden_dim, self.qk_proj_dim, bias=False, init_type="attn")
        # Keep the user-requested WAS flag on this module only; self.attn must stay with was=False.
        self.was = was
        self.was_gamma = was_gamma
        assert self.attn.was is False

        self.share_query_bias = share_query_bias
        if not share_query_bias:
            # The biases are added to the projected query, so they are sized like r_proj's output
            # (qk_proj_dim) instead of hidden_dim; the two coincide when qk_proj_dim == hidden_dim.
            self.query_key_bias = ParameterModule(torch.zeros(self.qk_proj_dim, dtype=torch.float32))
            self.query_pos_bias = ParameterModule(torch.zeros(self.qk_proj_dim, dtype=torch.float32))
        else:
            self.query_key_bias = self.query_pos_bias = None

    @staticmethod
    def _rel_shift_left(x: torch.Tensor) -> torch.Tensor:
        """
        Shift relative length to LEFT (same as TXL impl.)
        :param x:   (batch_size, num_heads, query_length, key_length)

        6 5 4 3 2 1 0    4 3 2 1 0 - -
        6 5 4 3 2 1 0 -> 5 4 3 2 1 0 -
        6 5 4 3 2 1 0    6 5 4 3 2 1 0

        x[:, :, 0] << shift left (q - 1)
        x[:, :, 1] << shift left (q - 2)
        ...
        x[:, :, q-1] << shift left 0

        The positions marked '-' are not zero-filled: they hold the padding value or entries that
        wrapped around from the following row, so the caller has to mask them out.
        :return:    (batch_size, num_heads, query_length, key_length)
        """
        batch_size, num_heads, query_len, key_len = x.shape
        x_padded = F.pad(x, (1, 0))  # one zero column in front, (..., query_len, 1 + key_len)
        # - 3 2 1 0
        # - 3 2 1 0
        # - 3 2 1 0

        # Re-read the same memory with swapped row/column counts.
        x_padded = x_padded.view(batch_size, num_heads, 1 + key_len, query_len)  # (..., 1 + key_len, query_len)
        # - 3 2
        # 1 0 -
        # 3 2 1
        # 0 - 3
        # 2 1 0

        # Drop the first row and read the rest back in the original layout.
        x = x_padded[:, :, 1:, :].view_as(x)  # (..., query_len, key_len)
        # 1 0 - 3
        # 2 1 0 -
        # 3 2 1 0
        return x

    @staticmethod
    def _rel_shift_right(x: torch.Tensor) -> torch.Tensor:
        """
        Shift relative length to RIGHT
        :param x:   (batch_size, num_heads, query_length, key_length)

                                      |--use--|
        -3 -2 -1  0  1  2    -3 -2 -1  0  1  2
        -3 -2 -1  0  1  2 ->  - -3 -2 -1  0  1
        -3 -2 -1  0  1  2     -  - -3 -2 -1  0

        x[:, :, 0] >> shift right 0
        x[:, :, 1] >> shift right 1
        ...
        x[:, :, q-1] >> shift right (q - 1)

        The positions marked '-' hold the padding value or entries that wrapped around from the
        previous row; only the columns that never receive such entries should be used.
        :return     (batch_size, num_heads, query_length, key_length)
        """
        batch_size, num_heads, query_len, key_len = x.shape
        x_padded = F.pad(x, (0, 1))  # one zero column at the end, (..., query_len, key_len + 1)
        # -2 -1 0 1 2 -
        # -2 -1 0 1 2 -
        # -2 -1 0 1 2 -

        x_padded = x_padded.view(batch_size, num_heads, -1)  # flatten each (query, key) plane
        # -2 -1 0 1 2 - -2 -1 0 1 2 - -2 -1 0 1 2 -

        # Cut the flat buffer back to query_len * key_len items; every row starts one item later.
        x = x_padded[:, :, :query_len * key_len].view_as(x)
        # -2 -1 0 1 2
        # - -2 -1 0 1
        # 2 - -2 -1 0

        return x

    def forward(self, *args, **kwargs):
        """Abstract; subclasses implement the actual attention computation."""
        raise NotImplementedError

    def extra_repr(self) -> str:
        """Append the `share_query_bias` flag to the parent's representation string."""
        s = super().extra_repr()
        s += f", share_query_bias={self.share_query_bias}"
        return s


class AutoRegressiveRelMultiheadAttention(_RelMultiheadAttentionBase):
    """Relative-position self-attention over `[memory, query]` keys, using the left relative shift."""

    def forward(self,
                query_state: torch.Tensor,
                pos_emb: torch.Tensor,
                query_key_bias: Optional[torch.Tensor],
                query_pos_bias: Optional[torch.Tensor],
                memory: Optional[torch.Tensor] = None,
                attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        AR, Rel-pos Self-attention MHA
        :param query_state:     (batch_size, query_length, hidden_dim)
        :param pos_emb:         (1, memory_length + query_length, hidden_dim)
                                only the last (memory_length + query_length) entries are used
        :param query_key_bias:  (hidden_dim,)   required if share_query_bias, ignored otherwise;
                                must broadcast against the projected query
        :param query_pos_bias:  (hidden_dim,)   required if share_query_bias, ignored otherwise;
                                must broadcast against the projected query
        :param memory:          (batch_size, memory_length, hidden_dim), prepended to the keys/values
        :param attn_mask:       (batch_size, query_length, memory_length + query_length)
                                positions where the mask is False get zero attention probability
        :return:                output (batch_size, query_length, hidden_dim),
                                attention probabilities after dropout
                                (batch_size, num_heads, query_length, memory_length + query_length)
        :raises ValueError:     if pos_emb is shorter than memory_length + query_length
        """
        batch_size, query_len, hidden_dim = query_state.shape
        assert hidden_dim == self.hidden_dim

        q = self.q_proj(query_state)  # (batch_size, query_len, proj_dim)

        if memory is not None:
            kv_query_state = torch.cat([memory, query_state], dim=1)  # (batch_size, memory_length + query_length, d)
        else:
            kv_query_state = query_state
        k = self.k_proj(kv_query_state)  # (batch_size, memory_len + query_len, proj_dim)
        v = self.v_proj(kv_query_state)  # (batch_size, memory_len + query_len, proj_dim)
        key_len = k.shape[1]

        pos_emb = pos_emb[:, -key_len:]  # keep only the trailing key_len positions
        rk = self.r_proj(pos_emb)  # (1, memory_len + query_len, qk_proj_dim)
        if rk.shape[1] != key_len:
            raise ValueError(f"Relative position sequence length {rk.shape[1]} mismatch, "
                             f"query_len: {query_len}, key_len: {key_len}.")
        # -1 keeps the existing sizes, so this also works when qk_proj_dim differs from hidden_dim.
        rk = rk.expand(batch_size, -1, -1)  # (batch_size, memory_len + query_len, qk_proj_dim)

        if self.share_query_bias:
            assert (query_key_bias is not None) and (query_pos_bias is not None)
            wq = q + query_key_bias  # content-term query
            rq = q + query_pos_bias  # position-term query
        else:
            wq = q + self.query_key_bias()
            rq = q + self.query_pos_bias()

        wq = self._transpose_for_attn(wq)  # (batch_size, num_heads, query_len, head_dim)
        wk = self._transpose_for_attn(k)  # (batch_size, num_heads, key_len, head_dim)
        rq = self._transpose_for_attn(rq)  # (batch_size, num_heads, query_len, head_dim)
        rk = self._transpose_for_attn(rk)  # (batch_size, num_heads, key_len, head_dim)

        v = self._transpose_for_attn(v)  # (batch_size, num_heads, key_len, head_dim)

        scores_ac = self.attn(wq, wk, mask=None)  # (batch_size, num_heads, query_len, key_len), no normalize
        scores_bd = self.attn(rq, rk, mask=None)  # (batch_size, num_heads, query_len, key_len), no normalize
        scores_bd = self._rel_shift_left(scores_bd)  # will be masked later, (batch_size, num_heads, query_len, key_len)

        scores = scores_ac + scores_bd  # sum at un-normalized state
        scores = apply_attn_mask(scores, mask=attn_mask)
        if self.was:
            scores = apply_weak_attention_suppression(scores, attn_mask, gamma=self.was_gamma)
        scores = torch.softmax(scores, dim=-1, dtype=torch.float32)
        if attn_mask is not None:
            # force exact zeros on masked positions after normalization
            scores = scores.masked_fill(torch.logical_not(attn_mask.unsqueeze(1)), 0.0)

        scores = self.attn_drop(scores)

        # softmax is forced to float32; torch.matmul needs both operands in the same dtype, so use
        # the value dtype for the product (a no-op when v is already float32).
        output = torch.matmul(scores.to(v.dtype), v)  # (batch_size, num_heads, query_len, head_dim)
        output = self._transpose_for_output(output)  # (batch_size, query_len, hidden_dim)

        return output, scores


class AutoEncodingRelMultiheadAttention(_RelMultiheadAttentionBase):
    """Bidirectional relative-position self-attention, using the right relative shift."""

    def forward(self,
                query_state: torch.Tensor,
                pos_emb: torch.Tensor,
                query_key_bias: Optional[torch.Tensor],
                query_pos_bias: Optional[torch.Tensor],
                attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        AE, Rel-pos Self-attention MHA
        :param query_state:     (batch_size, query_length, hidden_dim)
        :param pos_emb:         (1, 2 * query_length - 1, hidden_dim)
        :param query_key_bias:  (hidden_dim,)   required if share_query_bias, ignored otherwise;
                                must broadcast against the projected query
        :param query_pos_bias:  (hidden_dim,)   required if share_query_bias, ignored otherwise;
                                must broadcast against the projected query
        :param attn_mask:       (batch_size, query_length, query_length)
                                positions where the mask is False get zero attention probability
        :return:                output (batch_size, query_length, hidden_dim),
                                attention probabilities after dropout
                                (batch_size, num_heads, query_length, query_length)
        :raises ValueError:     if pos_emb has fewer than 2 * query_length - 1 positions
        """
        batch_size, query_len, hidden_dim = query_state.shape
        assert hidden_dim == self.hidden_dim

        # With fewer than 2 * query_len - 1 relative positions, the right shift below would pull
        # padding / neighbouring-row values into the columns that are kept, silently corrupting
        # the positional scores. Longer inputs keep working as before (the trailing
        # 2 * query_len - 1 positions end up being used).
        rel_len = 2 * query_len - 1
        if pos_emb.shape[1] < rel_len:
            raise ValueError(f"Relative position sequence length {pos_emb.shape[1]} is too short, "
                             f"query_len: {query_len}, expected: {rel_len}.")

        q = self.q_proj(query_state)  # (batch_size, query_len, proj_dim)
        k = self.k_proj(query_state)
        v = self.v_proj(query_state)

        rk = self.r_proj(pos_emb)  # (1, query_len * 2 - 1, qk_proj_dim)
        # -1 keeps the existing sizes, so this also works when qk_proj_dim differs from hidden_dim.
        rk = rk.expand(batch_size, -1, -1)

        if self.share_query_bias:
            assert (query_pos_bias is not None) and (query_key_bias is not None)
            wq = q + query_key_bias  # content-term query
            rq = q + query_pos_bias  # position-term query
        else:
            wq = q + self.query_key_bias()
            rq = q + self.query_pos_bias()

        wq = self._transpose_for_attn(wq)  # (batch_size, num_heads, query_len, head_dim)
        wk = self._transpose_for_attn(k)
        rq = self._transpose_for_attn(rq)
        rk = self._transpose_for_attn(rk)

        v = self._transpose_for_attn(v)  # (batch_size, num_heads, query_len, head_dim)

        scores_ac = self.attn(wq, wk, mask=None)  # (batch_size, num_heads, query_len, query_len), no normalize
        scores_bd = self.attn(rq, rk, mask=None)  # (batch_size, num_heads, query_len, query_len * 2 - 1)
        scores_bd = self._rel_shift_right(scores_bd)
        scores_bd = scores_bd[:, :, :, -query_len:]  # keep the columns untouched by shift wrap-around

        scores = scores_ac + scores_bd  # sum at un-normalized state
        scores = apply_attn_mask(scores, mask=attn_mask)
        if self.was:
            scores = apply_weak_attention_suppression(scores, attn_mask, gamma=self.was_gamma)
        scores = torch.softmax(scores, dim=-1, dtype=torch.float32)
        if attn_mask is not None:
            # force exact zeros on masked positions after normalization
            scores = scores.masked_fill(torch.logical_not(attn_mask.unsqueeze(1)), 0.0)
        scores = self.attn_drop(scores)

        # softmax is forced to float32; torch.matmul needs both operands in the same dtype, so use
        # the value dtype for the product (a no-op when v is already float32).
        output = torch.matmul(scores.to(v.dtype), v)  # (batch_size, num_heads, query_len, head_dim)
        output = self._transpose_for_output(output)  # (batch_size, query_len, hidden_dim)

        return output, scores
