# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.scalar_bias import scalar_bias


class SingleHeadAttention(nn.Module):
    """
    Single-head attention that supports Gating and Downsampling
    """

    def __init__(
        self,
        out_channels,
        embed_dim,
        head_dim,
        head_index,
        dropout=0.0,
        bias=True,
        project_input=True,
        gated=False,
        downsample=False,
        num_heads=1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.head_index = head_index
        self.head_dim = head_dim
        self.project_input = project_input
        self.gated = gated
        self.downsample = downsample
        self.num_heads = num_heads
        self.projection = None

        k_layers = []
        v_layers = []
        if self.downsample:
            k_layers.append(Downsample(self.head_index))
            v_layers.append(Downsample(self.head_index))
            out_proj_size = self.head_dim
        else:
            out_proj_size = self.head_dim * self.num_heads
        if self.gated:
            k_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias))
            self.in_proj_q = GatedLinear(self.embed_dim, out_proj_size, bias=bias)
            v_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias))
        else:
            k_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias))
            self.in_proj_q = Linear(self.embed_dim, out_proj_size, bias=bias)
            v_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias))

        self.in_proj_k = nn.Sequential(*k_layers)
        self.in_proj_v = nn.Sequential(*v_layers)

        if self.downsample:
            self.out_proj = Linear(out_proj_size, self.head_dim, bias=bias)
        else:
            self.out_proj = Linear(out_proj_size, out_channels, bias=bias)

        self.scaling = self.head_dim**-0.5

    def forward(
        self,
        query,
        key,
        value,
        mask_future_timesteps=False,
        key_padding_mask=None,
        use_scalar_bias=False,
    ):
        """Input shape: Time x Batch x Channel
        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Future timesteps can be masked with the
        `mask_future_timesteps` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
        src_len, bsz, out_channels = key.size()
        tgt_len = query.size(0)
        assert list(query.size()) == [tgt_len, bsz, out_channels]
        assert key.size() == value.size()

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.downsample:
            size = bsz
        else:
            size = bsz * self.num_heads

        k = key
        v = value
        q = query
        if self.project_input:
            q = self.in_proj_q(q)
            k = self.in_proj_k(k)
            v = self.in_proj_v(v)
            src_len = k.size()[0]
        q *= self.scaling

        if not self.downsample:
            q = q.view(tgt_len, size, self.head_dim)
            k = k.view(src_len, size, self.head_dim)
            v = v.view(src_len, size, self.head_dim)

        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        if mask_future_timesteps:
            assert (
                query.size() == key.size()
            ), "mask_future_timesteps only applies to self-attention"
            attn_weights *= torch.tril(
                attn_weights.data.new([1]).expand(tgt_len, tgt_len).clone(),
                diagonal=-1,
            )[:, :: self.head_index + 1 if self.downsample else 1].unsqueeze(0)
            attn_weights += torch.triu(
                attn_weights.data.new([-math.inf]).expand(tgt_len, tgt_len).clone(),
                diagonal=0,
            )[:, :: self.head_index + 1 if self.downsample else 1].unsqueeze(0)
        tgt_size = tgt_len
        if use_scalar_bias:
            attn_weights = scalar_bias(attn_weights, 2)
            v = scalar_bias(v, 1)
            tgt_size += 1

        if key_padding_mask is not None:
            # don't attend to padding symbols
            if key_padding_mask.max() > 0:
                if self.downsample:
                    attn_weights = attn_weights.view(bsz, 1, tgt_len, src_len)
                else:
                    attn_weights = attn_weights.view(
                        size, self.num_heads, tgt_len, src_len
                    )
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    -math.inf,
                )
                attn_weights = attn_weights.view(size, tgt_len, src_len)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout_module(attn_weights)

        attn = torch.bmm(attn_weights, v)
        if self.downsample:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.head_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)

        attn = self.out_proj(attn)

        return attn, attn_weights


class DownsampledMultiHeadAttention(nn.ModuleList):
    """
    Multi-headed attention with Gating and Downsampling
    """

    def __init__(
        self,
        out_channels,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        project_input=True,
        gated=False,
        downsample=False,
    ):
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.downsample = downsample
        self.gated = gated
        self.project_input = project_input
        assert self.head_dim * num_heads == embed_dim

        if self.downsample:
            attention_heads = []
            for index in range(self.num_heads):
                attention_heads.append(
                    SingleHeadAttention(
                        out_channels,
                        self.embed_dim,
                        self.head_dim,
                        index,
                        dropout,
                        bias,
                        self.project_input,
                        self.gated,
                        self.downsample,
                        self.num_heads,
                    )
                )
            super().__init__(modules=attention_heads)
            self.out_proj = Linear(embed_dim, out_channels, bias=bias)
        else:
            # either we have a list of attention heads, or just one attention head
            # if not being downsampled, we can do the heads with one linear layer instead of separate ones
            super().__init__()
            self.attention_module = SingleHeadAttention(
                out_channels,
                self.embed_dim,
                self.head_dim,
                1,
                dropout,
                bias,
                self.project_input,
                self.gated,
                self.downsample,
                self.num_heads,
            )

    def forward(
        self,
        query,
        key,
        value,
        mask_future_timesteps=False,
        key_padding_mask=None,
        use_scalar_bias=False,
    ):
        src_len, bsz, embed_dim = key.size()
        tgt_len = query.size(0)
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        tgt_size = tgt_len
        if use_scalar_bias:
            tgt_size += 1

        attn = []
        attn_weights = []
        if self.downsample:
            for attention_head_number in range(self.num_heads):
                # call the forward of each attention head
                _attn, _attn_weight = self[attention_head_number](
                    query,
                    key,
                    value,
                    mask_future_timesteps,
                    key_padding_mask,
                    use_scalar_bias,
                )
                attn.append(_attn)
                attn_weights.append(_attn_weight)
            full_attn = torch.cat(attn, dim=2)
            full_attn = self.out_proj(full_attn)
            return full_attn, attn_weights[0].clone()
        else:
            _attn, _attn_weight = self.attention_module(
                query,
                key,
                value,
                mask_future_timesteps,
                key_padding_mask,
                use_scalar_bias,
            )
            attn.append(_attn)
            attn_weights.append(_attn_weight)
            full_attn = torch.cat(attn, dim=2)
            full_attn_weights = torch.cat(attn_weights)
            full_attn_weights = full_attn_weights.view(
                bsz, self.num_heads, tgt_size, src_len
            )
            full_attn_weights = full_attn_weights.sum(dim=1) / self.num_heads
            return full_attn, full_attn_weights


class Downsample(nn.Module):
    """
    Selects every nth element, where n is the index
    """

    def __init__(self, index):
        super().__init__()
        self.index = index

    def forward(self, x):
        return x[:: self.index + 1]


def Linear(in_features, out_features, dropout=0.0, bias=True):
    """Weight-normalized Linear layer (input: B x T x C)"""
    m = nn.Linear(in_features, out_features, bias=bias)
    m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features))
    m.bias.data.zero_()
    return nn.utils.weight_norm(m)


def GatedLinear(in_features, out_features, dropout=0.0, bias=True):
    """Weight-normalized Linear layer (input: B x T x C) with interspersed GLU units"""
    return nn.Sequential(
        Linear(in_features, out_features * 4, dropout, bias),
        nn.GLU(),
        Linear(out_features * 2, out_features * 2, dropout, bias),
        nn.GLU(),
        Linear(out_features, out_features, dropout, bias),
    )
