import nf
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Multi-head scaled dot-product attention with learned projections.

    Query, key and value inputs each go through their own linear layer
    (``in_dim -> hidden_dim``), are split into ``n_heads`` heads of size
    ``hidden_dim // n_heads``, combined with softmax attention, merged back
    and projected to ``out_dim``.

    Args:
        in_dim: Size of the last dimension of the query/key/value inputs.
        hidden_dim: Total width of the projected queries/keys/values. Must be
            divisible by ``n_heads``.
        out_dim: Size of the last dimension of the output.
        n_heads: Number of attention heads.
        mask_diagonal: If true, query position ``i`` cannot attend to key
            position ``i``.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``hidden_dim``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, n_heads=1, mask_diagonal=False, **kwargs):
        super().__init__()

        # Fail early with a readable message instead of an opaque `view`
        # error during the first forward pass.
        if n_heads < 1 or hidden_dim % n_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive n_heads ({n_heads})"
            )

        self.mask_diagonal = mask_diagonal
        # (heads, per-head width): the trailing shape used to split projections.
        self._shape = (n_heads, hidden_dim // n_heads)

        self.key = nn.Linear(in_dim, hidden_dim)
        self.query = nn.Linear(in_dim, hidden_dim)
        self.value = nn.Linear(in_dim, hidden_dim)
        self.proj = nn.Linear(hidden_dim, out_dim)

    def _split_heads(self, x):
        """Reshape ``(batch, length, hidden)`` to ``(batch, heads, length, head_dim)``."""
        return x.view(*x.shape[:-1], *self._shape).transpose(1, 2)

    def forward(self, query, key, value, mask=None, **kwargs):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Inputs are expected to be 3-D, ``(batch, length, in_dim)``: the head
        axis is moved to dimension 1 with ``transpose(1, 2)``.

        Args:
            query: Tensor ``(batch, n_query, in_dim)``.
            key: Tensor ``(batch, n_key, in_dim)``.
            value: Tensor ``(batch, n_key, in_dim)``.
            mask: Optional key mask ``(batch, n_key, k)`` with ``k >= 1``; only
                channel 0 of the last dimension is used. Numeric masks keep
                the entries equal to 1; boolean masks keep the ``True``
                entries. ``None`` means every key is valid.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor ``(batch, n_query, out_dim)``.
        """
        q_shape = query.shape
        query = self._split_heads(self.query(query))
        key = self._split_heads(self.key(key))
        value = self._split_heads(self.value(value))

        # Scores (batch, heads, n_query, n_key), scaled by 1/sqrt(head_dim).
        scale = (1 / key.shape[-1]) ** 0.5
        att = query @ key.transpose(-2, -1) * scale

        if self.mask_diagonal:
            # Rectangular eye on the scores' device, so that cross-attention
            # (n_query != n_key) and non-CPU inputs both work.
            diagonal = torch.eye(att.shape[-2], att.shape[-1], device=att.device).bool()
            att = att.masked_fill(diagonal, -np.inf)

        if mask is not None:
            # (batch, n_key, 1) -> (batch, 1, 1, n_key); broadcasts over heads
            # and query positions.
            key_mask = mask[..., 0, None].transpose(-1, -2).unsqueeze(-2)
            if key_mask.dtype == torch.bool:
                padded = ~key_mask
            else:
                padded = (1 - key_mask).bool()
            # Large finite value rather than -inf: a row whose keys are all
            # padded yields a uniform softmax instead of NaN.
            att = att.masked_fill(padded.to(att.device), -1e8)

        att = F.softmax(att, -1)

        y = att @ value

        # Merge heads back: (batch, heads, n_query, head_dim) -> (batch, n_query, hidden).
        y = y.transpose(1, 2).reshape(*q_shape[:-1], -1)
        return self.proj(y)

class SelfAttention(nn.Module):
    """Attention layer in which one input serves as query, key and value.

    Wraps a single ``Attention`` module built from ``in_dim``, ``hidden_dim``,
    ``out_dim``, ``n_heads`` and ``mask_diagonal``. Any other keyword
    arguments given to the constructor are accepted and ignored.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, n_heads=1, mask_diagonal=False, **kwargs):
        super().__init__()
        self.attention = Attention(
            in_dim,
            hidden_dim,
            out_dim,
            n_heads=n_heads,
            mask_diagonal=mask_diagonal,
        )

    def forward(self, x, mask=None, **kwargs):
        """Run the wrapped attention with ``x`` as query, key and value.

        ``mask`` and any extra keyword arguments are handed to the wrapped
        module unchanged.
        """
        return self.attention(query=x, key=x, value=x, mask=mask, **kwargs)

class InducedSelfAttention(nn.Module):
    """Two-stage attention routed through a learned set of points.

    A trainable tensor of ``n_points`` vectors (width ``in_dim``, initialised
    uniformly in ``[-1, 1]``) first attends to the input; the input then
    attends to the result of that first stage.

    Args:
        in_dim: Width of the input and of the learned points.
        hidden_dim: Projection width used inside both attention modules.
        out_dim: Width of the output.
        n_heads: Number of heads in both attention modules.
        n_points: Number of learned points.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, n_heads=1, n_points=32, **kwargs):
        super().__init__()

        # Stage 1 maps back to `in_dim` so its output can feed stage 2 as keys/values.
        self.att1 = Attention(in_dim, hidden_dim, in_dim, n_heads)
        self.att2 = Attention(in_dim, hidden_dim, out_dim, n_heads)

        initial_points = torch.Tensor(1, n_points, in_dim)
        initial_points.uniform_(-1., 1.)
        self.points = nn.Parameter(initial_points)

    def forward(self, x, mask=None, **kwargs):
        """Return the input attended through the learned points.

        ``mask`` is applied only in the first stage, where ``x`` provides the
        keys; the second stage attends over the learned points without a mask.
        """
        batch_size = x.shape[0]
        # One copy of the learned points per batch element.
        tiled = self.points.repeat(batch_size, 1, 1)
        summary = self.att1(tiled, x, x, mask=mask, **kwargs)
        return self.att2(x, summary, summary, **kwargs)

class SelfAttentionNet(nn.Module):
    """Stack of attention + linear layers, gated against an element-wise MLP.

    The input is embedded to ``hidden_dim`` and passed through ``num_layers``
    blocks, each made of an attention layer (always with a skip connection)
    followed by a linear layer, then projected to ``out_dim``. The result is
    blended with the output of an MLP applied directly to the input, using a
    sigmoid gate that is also computed from the input.

    Args:
        in_dim: Width of the input.
        hidden_dim: Width used inside the attention stack.
        out_dim: Width of the output.
        n_heads: Number of attention heads per layer.
        n_points: Number of learned points, passed to every attention layer;
            ``InducedSelfAttention`` uses it to size its parameter, so it
            must be an integer when ``induced`` is true.
        num_layers: Number of attention/linear blocks. Must be an integer
            even though the default is ``None`` (it is passed to ``range``).
        induced: Use ``InducedSelfAttention`` instead of ``SelfAttention``.
        mask_diagonal: Passed to every attention layer.
        residual: Add a skip connection around each linear layer.
        layernorm: Apply ``LayerNorm`` after each attention and linear layer.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, n_heads=1, n_points=None,
                 num_layers=None, induced=False, mask_diagonal=False,
                 residual=False, layernorm=False, **kwargs):
        super().__init__()
        self.num_layers = num_layers
        self.residual = residual
        self.layernorm = layernorm

        att_net = InducedSelfAttention if induced else SelfAttention

        self.emb = nn.Linear(in_dim, hidden_dim)
        self.ff = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)])
        self.att = nn.ModuleList([
            att_net(hidden_dim, hidden_dim, hidden_dim,
                    n_heads=n_heads, n_points=n_points, mask_diagonal=mask_diagonal)
            for _ in range(num_layers)
        ])

        if layernorm:
            # Two norms per block: index 2*i follows attention, 2*i + 1 follows the linear layer.
            self.ln = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers * 2)])
        self.proj = nn.Linear(hidden_dim, out_dim)
        self.gate = nf.net.MLP(in_dim, [hidden_dim], out_dim, final_activation='Sigmoid')
        self.elementwise = nf.net.MLP(in_dim, [hidden_dim], out_dim)

    def forward(self, x, mask=None, **kwargs):
        """Apply the attention stack and blend it with the element-wise path.

        Args:
            x: Input tensor whose last dimension is ``in_dim``.
            mask: Optional mask forwarded to every attention layer, which
                reads it as ``mask[..., 0, None]`` to decide which positions
                may be attended to. ``None`` disables masking.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor whose last dimension is ``out_dim``, computed as
            ``g * elementwise(x) + (1 - g) * attention_stack(x)`` where
            ``g = gate(x)``.
        """
        y = self.emb(x)

        for i in range(self.num_layers):
            # Forward the mask so padded positions are not attended to;
            # previously it was accepted here but silently dropped.
            h = y + self.att[i](y, mask=mask)
            if self.layernorm:
                h = self.ln[i*2](h)

            y = self.ff[i](h)
            if self.residual:
                y = h + y
            if self.layernorm:
                y = self.ln[i*2+1](y)
        y = self.proj(y)

        g = self.gate(x)
        y_ = self.elementwise(x)
        y = y_ * g + (1 - g) * y

        return y
