import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .. import net


def attention(query, key, value, n_heads=1, mask_diagonal=False, mask=None):
    """
    Multi-head scaled dot-product attention.

    The last dimension of every input is split into `n_heads` equal chunks,
    attention is computed independently for each chunk and the results are
    concatenated back along the last dimension.

    Args:
        query (tensor): Queries with shape (..., M, D)
        key (tensor): Keys with shape (..., N, D)
        value (tensor): Values with shape (..., N, D)
        n_heads (int, optional): Number of heads, `D` has to be divisible
            by it. Default: 1
        mask_diagonal (bool, optional): If True, entry (i, i) of every
            attention matrix is set to -inf before the softmax. The diagonal
            mask is square (N x N), so this is meant for M == N.
            Default: False
        mask (tensor, optional): Mask with shape (..., N, 1). Keys whose mask
            value differs from 1 get a score of -inf before the softmax.
            When M == N the output is additionally multiplied by `mask`.
            Default: None

    Returns:
        y (tensor): Outputs with shape (..., M, D)
    """
    # `util` is not imported anywhere at module level, so the bare name used
    # to raise NameError on every call. It is imported here, at call time, to
    # stay clear of import-order cycles with the parent package.
    # Assumes `util` is a sibling of `net` in the parent package -- TODO confirm.
    from .. import util

    *query_shape, _ = query.shape
    *value_shape, D = value.shape
    head_dim = D // n_heads

    # Split features into heads: (..., L, D) -> (..., n_heads, L, head_dim).
    # `reshape` instead of `view` so that non-contiguous inputs are accepted.
    query = query.reshape(*query_shape, n_heads, head_dim).transpose(-2, -3)
    value = value.reshape(*value_shape, n_heads, head_dim).transpose(-2, -3)
    key = key.reshape(*value_shape, n_heads, head_dim).transpose(-2, -3)

    # Scores scaled by 1 / sqrt(head_dim), shape (..., n_heads, M, N)
    att = query @ key.transpose(-1, -2) * (1 / head_dim) ** 0.5

    if mask_diagonal:
        # Build the identity on the same device as the scores; a CPU mask
        # cannot be applied to a tensor that lives on another device.
        diagonal = torch.eye(att.shape[-1], device=att.device).bool()
        att.masked_fill_(diagonal, -np.inf)
    if mask is not None:
        # (..., N, 1) -> (..., 1, 1, N); `masked_fill_` broadcasts this over
        # heads and queries, so no explicit repetition is needed.
        # `!= 1` selects the same entries as `(1 - mask).bool()` but also
        # works for boolean masks, where `1 - mask` raises an error.
        invalid_keys = mask.transpose(-1, -2).unsqueeze(-2) != 1
        att.masked_fill_(invalid_keys, -np.inf)

    # Project helper; presumably a softmax that tolerates rows consisting
    # only of -inf -- confirm against `util.safe_softmax`.
    att = util.safe_softmax(att, -1)

    y = att @ value

    # Merge heads back: (..., n_heads, M, head_dim) -> (..., M, D)
    y = y.transpose(-2, -3).reshape(*query_shape, -1)

    # NOTE(review): whether `mask` also describes the queries is decided only
    # by comparing lengths, so a cross-attention call where M happens to
    # equal N gets its output masked as well -- confirm this is intended.
    if mask is not None and y.shape[-2] == mask.shape[-2]:
        y = y * mask
    return y


class Attention(nn.Module):
    """
    Attention layer.
    Queries, keys and values each go through their own `net.MLP`, the results
    are combined with multi-head `attention` and a final linear layer maps
    the output to `out_dim`.

    Args:
        in_dim (int): Dimension of the inputs
        hidden_dims (List[int]): All entries but the last are passed to
            `net.MLP` as its second argument, the last entry as its third
            argument. The last entry is also the input size of the final
            linear layer
        out_dim (int): Dimension of the output
        n_heads (int, optional): Number of attention heads. Default: 1
        mask_diagonal (bool, optional): Forwarded to `attention`.
            Default: False
        **kwargs: Accepted but not used
    """

    def __init__(self, in_dim, hidden_dims, out_dim, n_heads=1, mask_diagonal=False, **kwargs):
        super().__init__()

        self.n_heads = n_heads
        self.mask_diagonal = mask_diagonal

        def embedding():
            # One independent network per role, all with the same layout
            return net.MLP(in_dim, hidden_dims[:-1], hidden_dims[-1])

        # Submodules are created in the order key, query, value, proj
        self.key = embedding()
        self.query = embedding()
        self.value = embedding()
        self.proj = nn.Linear(hidden_dims[-1], out_dim)

    def forward(self, query, key, value, mask=None, **kwargs):
        """
        Args:
            query (tensor): Queries with shape (..., M, in_dim)
            key (tensor): Keys with shape (..., N, in_dim)
            value (tensor): Values with shape (..., N, in_dim)
            mask (tensor, optional): Mask with shape (..., N, 1)

        Returns:
            y (tensor): Outputs with shape (..., M, out_dim)
        """
        q = self.query(query)
        k = self.key(key)
        v = self.value(value)
        attended = attention(
            q,
            k,
            v,
            n_heads=self.n_heads,
            mask_diagonal=self.mask_diagonal,
            mask=mask,
        )
        return self.proj(attended)


class SelfAttention(Attention):
    """
    Self attention layer.
    Behaves like `st.net.Attention`, except that `forward` receives a single
    input `x` (plus an optional `mask`) that is used as query, key and value.

    Args:
        Same as in `st.net.Attention`; `hidden_dim` is handed to the parent
        as its `hidden_dims` argument. Extra `**kwargs` are not used.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, n_heads=1, mask_diagonal=False, **kwargs):
        super().__init__(
            in_dim,
            hidden_dim,
            out_dim,
            n_heads=n_heads,
            mask_diagonal=mask_diagonal,
        )

    def forward(self, x, mask=None, **kwargs):
        """
        Args:
            x (tensor): Input with shape (..., N, in_dim), used as query,
                key and value at the same time
            mask (tensor, optional): Mask with shape (..., N, 1)

        Returns:
            y (tensor): Outputs with shape (..., N, out_dim)
        """
        return super().forward(query=x, key=x, value=x, mask=mask)


class InducedSelfAttention(nn.Module):
    """
    Self attention routed through a set of learned points.
    A first attention layer uses `n_points` learned vectors as queries over
    the input; a second attention layer then uses the input as queries over
    the result of the first one.

    Args:
        in_dim (int): Dimension of the inputs
        hidden_dim (List[int]): Passed to both `Attention` layers as their
            `hidden_dims` argument
        out_dim (int): Dimension of the output
        n_heads (int, optional): Number of attention heads. Default: 1
        n_points (int, optional): Number of learned points. Default: 32
        **kwargs: Accepted but not used
    """

    def __init__(self, in_dim, hidden_dim, out_dim, n_heads=1, n_points=32, **kwargs):
        super().__init__()

        # First stage keeps the input dimension, second stage maps to out_dim
        self.att1 = Attention(in_dim=in_dim, hidden_dims=hidden_dim,
                              out_dim=in_dim, n_heads=n_heads)
        self.att2 = Attention(in_dim=in_dim, hidden_dims=hidden_dim,
                              out_dim=out_dim, n_heads=n_heads)

        # Learned points, initialised uniformly in [-1, 1)
        initial_points = torch.empty(n_points, in_dim).uniform_(-1., 1.)
        self.points = nn.Parameter(initial_points)

    def forward(self, x, mask=None, **kwargs):
        """
        Args:
            x (tensor): Input with shape (..., N, in_dim)
            mask (tensor, optional): Mask with shape (..., N, 1)
            **kwargs: Forwarded to both attention layers

        Returns:
            y (tensor): Outputs with shape (..., N, out_dim)
        """
        batch_shape = x.shape[:-2]

        # Give the points one singleton axis per batch dimension of `x`,
        # then tile them so every batch element has its own copy.
        singleton_axes = (1,) * len(batch_shape)
        seeds = self.points.view(*singleton_axes, *self.points.shape)
        seeds = seeds.repeat(*batch_shape, 1, 1)

        # NOTE(review): `attention` multiplies its output by `mask` whenever
        # the number of queries equals the mask length, so for N == n_points
        # this stage's output is masked too -- confirm this is intended.
        summary = self.att1(seeds, x, x, mask=mask, **kwargs)

        # Zero out masked inputs before they act as queries; the second
        # stage itself is called without a mask.
        keep = 1 if mask is None else mask
        return self.att2(x * keep, summary, summary, **kwargs)
