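"""Attention modules: multi-headed attention and Transformer blocks operating
on sparse ([N x C]) and dense ([1 x C x H x W]) features.
"""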
from typing import Dict

import torch
import torch.nn as nn
from torch.nn.functional import interpolate


class LayerNorm(nn.Module):
    """Normalize a tensor to zero mean and unit variance along one dimension.

    Unlike ``nn.LayerNorm`` this has no learnable affine parameters, and the
    variance uses the unbiased (N - 1) estimator.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        """Initialize the normalization.
        Args:
            * dim: The dimension along which mean and variance are computed.
            * eps: Added to the variance for numerical stability.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(x - mean) / sqrt(var + eps)`` computed along ``self.dim``."""
        # var_mean does not modify its input, so no defensive copy is needed.
        var, mean = torch.var_mean(x, self.dim, keepdim=True, unbiased=True)
        return (x - mean) / torch.sqrt(var + self.eps)


class MultiHeadedAttention(nn.Module):
    """Multi-headed dot-product attention between sparse and/or dense features.

    Sparse features are [N x C] descriptor matrices; dense features are
    [1 x C x H x W] feature maps (the reshapes below assume a batch of one).
    """

    def __init__(
        self, channels: int, h: int, stride: int, mode: str, kernel: str, gated: bool
    ):
        """Initialize the Multi-Headed attention module.
        Args:
            * channels: The number of channels in the descriptors.
            * h: The number of heads to be created.
            * stride: The stride applied to the dense key/value projections
                in "dense_to_dense" mode (ignored in the other modes).
            * mode: Describes the type of input, can be "sparse_to_dense",
                "dense_to_sparse" or "dense_to_dense".
            * kernel: The attention kernel, choose from "dot", "linear", and "dot_linear".
            * gated: If True, a post-softmax gating parameter is learned.
        Raises:
            * ValueError: If `mode` is not one of the supported modes.
        """
        super().__init__()
        assert channels % h == 0
        self.dim = channels // h
        self.h = h
        self.stride = stride
        self.mode = mode
        # NOTE(review): `kernel` is stored but never read in this class; the
        # attention below is always dot-product regardless of its value.
        self.kernel = kernel

        # Linear projection layers, implemented as 1x1 convolutions: Conv1d for
        # sparse inputs (reshaped to [N x C x 1]) and Conv2d for dense inputs.
        if mode == "sparse_to_dense":
            query_linear, key_linear, value_linear = nn.Conv1d, nn.Conv2d, nn.Conv2d
            strides = [1, 1, 1]
        elif mode == "dense_to_sparse":
            query_linear, key_linear, value_linear = nn.Conv2d, nn.Conv1d, nn.Conv1d
            strides = [1, 1, 1]
        elif mode == "dense_to_dense":
            query_linear = key_linear = value_linear = nn.Conv2d
            # Only keys/values are downsampled; the query keeps full resolution.
            strides = [1, stride, stride]
        else:
            raise ValueError(
                f"Unknown mode {mode!r}; expected one of "
                "'sparse_to_dense', 'dense_to_sparse' or 'dense_to_dense'."
            )

        self.projection_layers = nn.ModuleList(
            [
                query_linear(channels, channels, kernel_size=1, stride=strides[0]),
                key_linear(channels, channels, kernel_size=1, stride=strides[1]),
                value_linear(channels, channels, kernel_size=1, stride=strides[2]),
            ]
        )

        # Gating parameters: a scalar affine map of the max attention score,
        # initialised close to zero (weight 1e-3, bias 0).
        self.gated = gated
        if gated:
            self.affine = nn.Linear(1, 1, bias=True)
            nn.init.constant_(self.affine.bias, 0.0)
            nn.init.constant_(self.affine.weight, 1e-3)

    def attention(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
        """Compute and apply the attention maps on projected Q, K, V.

        Returns [N x h x d] in "sparse_to_dense" mode and [h x d x M_q]
        (M_q = number of query pixels) in the dense-query modes.
        """

        # Reshape per head so that:
        # Q : h x (num queries) x d
        # K : h x d x (num keys)
        # V : h x d x (num keys)
        if self.mode == "sparse_to_dense":
            Q = Q.view(len(Q), self.h, self.dim).transpose(0, 1)
            K, V = [x.view(self.h, self.dim, -1) for x in (K, V)]
        elif self.mode == "dense_to_sparse":
            Q = Q.view(self.h, self.dim, -1).transpose(1, 2)
            K = K.view(len(K), self.h, self.dim).transpose(0, 1).transpose(1, 2)
            V = V.view(len(V), self.h, self.dim).transpose(0, 1).transpose(1, 2)
        elif self.mode == "dense_to_dense":
            Q = Q.view(self.h, self.dim, -1).transpose(1, 2)
            K, V = [x.view(self.h, self.dim, -1) for x in (K, V)]

        # Pre-normalize: splitting the usual 1/sqrt(d) scale between Q and K.
        Q = Q / (self.dim ** 0.25)
        K = K / (self.dim ** 0.25)

        # Dot-product attention
        QK = torch.bmm(Q, K)
        if self.gated:
            # The gate is derived from the pre-softmax maximum score per query.
            max_, _ = QK.max(dim=-1, keepdim=True)
            gating = torch.sigmoid(self.affine(max_))
        QK = torch.nn.functional.softmax(QK, dim=-1)
        if self.gated:
            QK = gating * QK
        QKV = torch.bmm(QK, V.transpose(1, 2))  # h x (num queries) x d

        if self.mode == "sparse_to_dense":
            return QKV.transpose(0, 1).contiguous()  # N x h x d
        else:
            return QKV.transpose(1, 2).contiguous()  # h x d x (num queries)

    def projection_reshape(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ):
        """Append a trailing length-1 axis to sparse [N x C] inputs for Conv1d."""
        if self.mode == "sparse_to_dense":
            return query[..., None], key, value
        elif self.mode == "dense_to_sparse":
            return query, key[..., None], value[..., None]
        return query, key, value

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
        """Apply the multi-headed attention module.
        Args:
            * Q: The query tensor: [N x C] in "sparse_to_dense" mode,
                [1 x C x H x W] otherwise.
            * K: The key tensor: [N x C] in "dense_to_sparse" mode,
                [1 x C x Hk x Wk] otherwise.
            * V: The value tensor, with the same layout as K.
        Returns:
            * The attended features with the same layout as Q.
        """
        if self.mode == "dense_to_dense":
            init_height, init_width = Q.shape[-2:]

        # Apply linear projection
        Q, K, V = [
            p(i)
            for p, i in zip(self.projection_layers, self.projection_reshape(Q, K, V))
        ]

        # Attention module
        x = self.attention(Q, K, V)

        if self.mode == "sparse_to_dense":
            x = x.view(len(Q), self.h * self.dim)
        else:
            x = x.view(1, self.h * self.dim, *list(Q.shape[-2:]))

        # The query projection uses stride 1, so x normally already matches the
        # input resolution; resample only if it does not (a no-op otherwise).
        if self.mode == "dense_to_dense" and x.shape[-2:] != (init_height, init_width):
            x = interpolate(
                x,
                (init_height, init_width),
                mode="bilinear",
                align_corners=True,
            )

        return x


class TransformerBlock(nn.Module):
    """A Transformer encoder block: layer norm, multi-headed attention, residual."""

    def __init__(
        self,
        num_dim: int,
        config: Dict,
        mode: str,
        kernel: str,
        is_first: bool = False,
    ):
        """Initialize the block.
        Args:
            * num_dim: The number of channels C of the source and target features.
            * config: Dict providing the "num_heads", "stride" and "gated" entries
                forwarded to MultiHeadedAttention.
            * mode: "sparse_to_dense", "dense_to_sparse" or "dense_to_dense".
            * kernel: "dot", "linear" or "dot_linear" (forwarded to the attention).
            * is_first: If True, an adaptation layer is applied to the source first.
        """
        super().__init__()
        assert mode in ["sparse_to_dense", "dense_to_sparse", "dense_to_dense"]
        assert kernel in ["dot", "linear", "dot_linear"]
        self.mode = mode
        self.attention = MultiHeadedAttention(
            num_dim,
            config["num_heads"],
            config["stride"],
            mode,
            kernel,
            config["gated"],
        )
        self.ln = LayerNorm(dim=1)

        # The source is sparse [N x C] only in "sparse_to_dense" mode; dense
        # [1 x C x H x W] sources need a per-pixel channel map (1x1 Conv2d).
        if mode == "sparse_to_dense":
            linear_layer, kwargs = nn.Linear, {}
        else:
            linear_layer, kwargs = nn.Conv2d, {"kernel_size": 1}
        self.merge = linear_layer(num_dim, num_dim, bias=False, **kwargs)
        # NOTE(review): non_linearity is constructed but never applied in
        # forward(); confirm whether the feed-forward sub-layer was intended.
        self.non_linearity = nn.Sequential(
            *[
                linear_layer(num_dim, num_dim, **kwargs),
                nn.ReLU(),
                linear_layer(num_dim, num_dim, **kwargs),
            ]
        )

        # Adaptation layer for the first attention layer. It must match the
        # source layout like `merge`: a plain nn.Linear would act on the last
        # (width) axis of a dense [1 x C x H x W] source instead of the channels.
        self.is_first = is_first
        if is_first:
            self.adap = linear_layer(num_dim, num_dim, **kwargs)

    def forward(self, source: torch.Tensor, target: torch.Tensor):
        """Perform cross or self attention from source to target.
        Args:
            * source: The source features, [N x C] in "sparse_to_dense" mode,
                [1 x C x H x W] otherwise.
            * target: The target features, [N x C] in "dense_to_sparse" mode,
                [1 x C x H x W] otherwise.
        Returns:
            * The updated source features, with the same layout as `source`.
        """
        # Sanity checks
        assert source.shape[1] == target.shape[1]
        if self.mode == "sparse_to_dense":
            assert source.ndim == 2 and target.ndim == 4
        elif self.mode == "dense_to_dense":
            assert source.ndim == 4 and target.ndim == 4
        elif self.mode == "dense_to_sparse":
            assert target.ndim == 2 and source.ndim == 4

        # Adap layer for first attention module
        if self.is_first:
            source = self.adap(source)

        # Layer Norm (over the channel dimension)
        source, target = self.ln(source), self.ln(target)

        # Attention + Residual (the residual is added to the normalized source)
        source = source + self.merge(self.attention(source, target, target))

        return source
