# taken from https://github.com/CyberZHG/torch-multi-head-attention/blob/master/torch_multi_head_attention/multi_head_attention.py

import torch
import torch.nn as nn

from pdb import set_trace
from peagang.models.components.attention.scaled_dot_product import (
    ScaledDotProductAttention,
)
from peagang.models.components.utilities_classes import Swish, LinearTransmissionLayer
from peagang.models.components.utilities_functions import sn_wrap

__all__ = ["MultiHeadAttention"]


class MultiHeadAttention(nn.Module):
    """Multi-head attention built on top of ``ScaledDotProductAttention``.

    Queries, keys and values are linearly projected to ``out_features``,
    split into ``head_num`` heads that are folded into the batch dimension,
    passed through ``ScaledDotProductAttention``, merged back and projected
    once more by an output linear layer.
    """

    _MODES = ("QK", "QQ")

    def __init__(
        self,
        in_features,
        out_features=None,
        head_num=1,
        bias=True,
        activation=None,
        out_activation=None,
        mode="QK",
        score_function="sigmoid",
        spectral_norm=None,
        skip=False,
        name=None,
    ):
        """Multi-head attention.

        :param in_features: Size of each input sample.
        :param out_features: Size of the projected/output features; defaults
            to ``in_features``.
        :param head_num: Number of heads. Must divide both ``in_features``
            and ``out_features``.
        :param bias: Whether to use the bias term in the linear layers.
        :param activation: The activation after each of the q/k/v linear
            transformations. Either a callable or one of the strings
            ``"relu"``, ``"swish"``, ``"leakyrelu"``.
        :param out_activation: The activation after the output linear
            transformation (same accepted values as ``activation``).
        :param mode: ``"QK"`` uses a dedicated key projection; ``"QQ"``
            reuses the projected queries as keys (``k`` is then ignored).
        :param score_function: Forwarded to ``ScaledDotProductAttention`` as
            ``att_sig``.
        :param spectral_norm: Forwarded to ``sn_wrap`` together with every
            linear layer (semantics are defined by ``sn_wrap``).
        :param skip: If true, ``v`` is added to the output after the output
            activation; when ``out_features != in_features`` it is first
            passed through an extra linear projection.
        :param name: Optional label stored on the instance; not used by this
            class itself.
        :raises ValueError: If ``head_num`` does not divide ``in_features``
            or ``out_features``, or if ``mode`` is not ``"QK"``/``"QQ"``.
        """
        super(MultiHeadAttention, self).__init__()
        if out_features is None:
            out_features = in_features
        if in_features % head_num != 0:
            raise ValueError(
                "`in_features`({}) should be divisible by `head_num`({})".format(
                    in_features, head_num
                )
            )
        # The head split in `_reshape_to_batches` operates on the projected
        # tensors, whose last dimension is `out_features`.
        if out_features % head_num != 0:
            raise ValueError(
                "`out_features`({}) should be divisible by `head_num`({})".format(
                    out_features, head_num
                )
            )
        if mode not in self._MODES:
            raise ValueError(
                "`mode`({!r}) should be one of {}".format(mode, self._MODES)
            )

        acts = {"relu": torch.nn.ReLU, "swish": Swish, "leakyrelu": torch.nn.LeakyReLU}
        if isinstance(activation, str):
            activation = acts[activation]()
        if isinstance(out_activation, str):
            out_activation = acts[out_activation]()

        self.mode = mode
        self.skip = skip
        self.in_features = in_features
        self.out_features = out_features
        self.head_num = head_num
        self.activation = activation
        self.out_activation = out_activation
        self.bias = bias
        self.score_function = score_function
        self.name = name

        self.linear_q = sn_wrap(
            nn.Linear(in_features, out_features, bias), spectral_norm
        )
        if self.mode == "QK":
            self.linear_k = sn_wrap(
                nn.Linear(in_features, out_features, bias), spectral_norm
            )
        else:
            # "QQ": the projected queries double as keys, no key layer needed.
            self.linear_k = None
        self.linear_v = sn_wrap(
            nn.Linear(in_features, out_features, bias), spectral_norm
        )
        self.linear_o = sn_wrap(
            nn.Linear(out_features, out_features, bias), spectral_norm
        )
        if self.skip and out_features != in_features:
            # The residual branch needs matching feature sizes.
            self.linear_o_proj = sn_wrap(
                nn.Linear(in_features, out_features, bias), spectral_norm
            )
        else:
            self.linear_o_proj = None

    def forward(self, q, k=None, v=None, mask=None, return_attention_and_scores=False):
        """Apply multi-head attention.

        :param q: Query tensor of shape (batch_size, seq_len, in_features).
        :param k: Key tensor; defaults to ``q``. Ignored in ``"QQ"`` mode.
        :param v: Value tensor; defaults to ``q``.
        :param mask: Optional mask handed to ``ScaledDotProductAttention``
            after being expanded to one copy per head (a 3-D mask is assumed
            to be (batch_size, seq_len, seq_len), like the output of
            ``gen_history_mask``).
        :param return_attention_and_scores: If true, additionally return the
            attention and score tensors produced by
            ``ScaledDotProductAttention``.
        :return: The output tensor, or a tuple ``(output, attention, scores)``
            when ``return_attention_and_scores`` is true.
        :raises ValueError: If ``self.mode`` is not ``"QK"``/``"QQ"``.
        """
        if k is None:
            k = q
        if v is None:
            v = q
        ql = self.linear_q(q)
        vl = self.linear_v(v)
        if self.mode == "QQ":
            kl = ql
        elif self.mode == "QK":
            kl = self.linear_k(k)
        else:
            raise ValueError(
                "`mode`({!r}) should be one of {}".format(self.mode, self._MODES)
            )

        if self.activation is not None:
            ql = self.activation(ql)
            kl = self.activation(kl)
            vl = self.activation(vl)

        ql = self._reshape_to_batches(ql)
        kl = self._reshape_to_batches(kl)
        vl = self._reshape_to_batches(vl)
        if mask is not None:
            mask = self._repeat_mask_for_heads(mask)

        attention_out = ScaledDotProductAttention()(
            ql,
            kl,
            vl,
            mask,
            return_attention_and_scores=return_attention_and_scores,
            att_sig=self.score_function,
        )
        if return_attention_and_scores:
            y, _attn, _attn_scores = attention_out
        else:
            y = attention_out
        y_att = self._reshape_from_batches(y)

        yo = self.linear_o(y_att)
        if self.out_activation is not None:
            yo = self.out_activation(yo)
        if self.skip:
            # Residual connection from the (optionally projected) values.
            # Compare against None explicitly instead of relying on the
            # truth value of a module object.
            if self.linear_o_proj is not None:
                yo = yo + self.linear_o_proj(v)
            else:
                yo = yo + v
        if return_attention_and_scores:
            return yo, _attn, _attn_scores
        return yo

    @staticmethod
    def gen_history_mask(x):
        """Generate the mask that only uses history data.

        :param x: Input tensor of shape (batch_size, seq_len, features).
        :return: A lower-triangular mask of ones with shape
            (batch_size, seq_len, seq_len), created on the device of ``x``.
        """
        batch_size, seq_len, _ = x.size()
        return (
            torch.tril(torch.ones(seq_len, seq_len, device=x.device))
            .view(1, seq_len, seq_len)
            .repeat(batch_size, 1, 1)
        )

    def _repeat_mask_for_heads(self, mask):
        """Expand ``mask`` to one copy per head, aligned with the head split.

        ``_reshape_to_batches`` orders rows batch-major
        (b0h0, b0h1, ..., b1h0, b1h1, ...), so each sample's mask has to be
        repeated ``head_num`` times consecutively. Tiling the whole batch
        (``mask.repeat``) would pair masks with the wrong samples whenever
        both batch size and head count exceed one.
        """
        if mask.dim() == 3:
            return mask.repeat_interleave(self.head_num, dim=0)
        # Any other rank keeps the previous tiling behaviour.
        return mask.repeat(self.head_num, 1, 1)

    def _reshape_to_batches(self, x):
        """Fold heads into the batch axis.

        (batch, seq_len, features) -> (batch * head_num, seq_len,
        features // head_num), ordered batch-major.
        """
        batch_size, seq_len, in_feature = x.size()
        sub_dim = in_feature // self.head_num
        return (
            x.reshape(batch_size, seq_len, self.head_num, sub_dim)
            .permute(0, 2, 1, 3)
            .reshape(batch_size * self.head_num, seq_len, sub_dim)
        )

    def _reshape_from_batches(self, x):
        """Inverse of ``_reshape_to_batches``.

        (batch * head_num, seq_len, sub_dim) -> (batch, seq_len,
        sub_dim * head_num).
        """
        batch_size, seq_len, in_feature = x.size()
        batch_size //= self.head_num
        out_dim = in_feature * self.head_num
        return (
            x.reshape(batch_size, self.head_num, seq_len, in_feature)
            .permute(0, 2, 1, 3)
            .reshape(batch_size, seq_len, out_dim)
        )

    def extra_repr(self):
        """Return the extra fields shown in the module's ``repr``."""
        return "in_features={}, head_num={}, bias={}, activation={}".format(
            self.in_features, self.head_num, self.bias, self.activation,
        )
