import nf
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def exclusive_mean_pooling(x, mask):
    """
    Pool each set element with the aggregate of all *other* elements.

    For every position along dim -2 the element itself is removed from the
    set total, and the remainder is divided by the number of valid elements
    (`mask` summed over dim -2). Note that the divisor is the full valid
    count, not `count - 1`.

    Args:
        x: Tensor whose second-to-last axis enumerates set elements.
            The masked entries are not zeroed here; the caller is expected
            to have done that already if it matters.
        mask: Tensor broadcastable against `x`; summed over dim -2 to get
            the number of valid elements.

    Returns:
        Tensor with the broadcast shape of `x` and the per-set count.
    """
    total = x.sum(-2, keepdim=True)
    count = mask.sum(-2, keepdim=True)
    # A fully masked set has a count of zero; dividing by it would inject
    # NaN/inf into the output. Only the exact-zero case is replaced so that
    # every other input yields the same result as a plain division.
    safe_count = torch.where(count == 0, torch.ones_like(count), count)
    return (total - x) / safe_count

def exclusive_max_pooling(x, mask):
    """
    Pool each set element with the maximum over all *other* elements.

    The two largest values along the set axis (dim -2) are taken per feature.
    An element that equals the largest value receives the runner-up; every
    other element receives the largest value.

    Args:
        x: Tensor whose second-to-last axis enumerates set elements; that
            axis must hold at least two entries because the top two values
            are requested.
        mask: Accepted for signature parity with `exclusive_mean_pooling`
            but not used; masked entries of `x` take part in the maximum
            with whatever value they hold.

    Returns:
        Tensor with the same shape and dtype as `x`.
    """
    # Use the same set axis (-2) as the mean pooling so that inputs with
    # extra leading batch axes are handled consistently.
    largest, runner_up = torch.topk(x, 2, dim=-2).values.chunk(2, dim=-2)
    # Select instead of blending arithmetically: this avoids 0 * inf = NaN
    # and keeps the input dtype.
    return torch.where(x == largest, runner_up, largest)

class ExclusiveInteractionNet(nn.Module):
    """
    Embed every set element, then pool each element over the other elements.

    The pooled embedding is repeated `in_dim` times along a new
    second-to-last axis.

    Inputs (..., N, in_dim)
    Outputs (..., N, in_dim, E), where E is the last dimension produced by
    `set_emb` (presumably `out_dim` -- confirm against `nf.net.MLP`).

    `interaction` selects the pooling: 'sum' uses `exclusive_mean_pooling`
    and 'max' uses `exclusive_max_pooling`.
    """
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers, interaction, **kwargs):
        super().__init__()
        self.interaction = interaction
        self.in_dim = in_dim
        self.set_emb = nf.net.MLP(in_dim, [hidden_dim] * num_layers, out_dim)

    def forward(self, t, x, mask=None, **kwargs):
        """
        Args:
            t: Unused; kept for a uniform call signature.
            x: Tensor of set elements, last axis of size `in_dim`.
            mask: Optional tensor; only its first entry along the last axis
                is used, kept as a trailing axis of size 1. When omitted,
                all elements are treated as valid.

        Raises:
            ValueError: If `self.interaction` is neither 'sum' nor 'max'.
        """
        if mask is None:
            # Create the default mask on the same device as `x`; a plain
            # torch.ones(...) lives on the CPU and breaks accelerator inputs.
            mask = torch.ones(*x.shape[:-1], 1, device=x.device)
        else:
            mask = mask[..., 0, None]

        # Zero the embeddings of masked elements before pooling.
        x = self.set_emb(x) * mask
        if self.interaction == 'sum':
            y = exclusive_mean_pooling(x, mask)
        elif self.interaction == 'max':
            y = exclusive_max_pooling(x, mask)
        else:
            # Previously this fell through and failed on an unbound `y`.
            raise ValueError(
                f"Unknown interaction {self.interaction!r}; expected 'sum' or 'max'"
            )
        y = y.unsqueeze(-2).repeat_interleave(self.in_dim, dim=-2)
        return y

class ExclusiveSetNet(nn.Module):
    """
    Adds a per-element network output to an exclusive set-interaction term.

    Inputs (..., in_dim)
    Outputs (..., in_dim, out_dim)
    """
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers, interaction='sum', **kwargs):
        super().__init__()
        hidden_dims = [hidden_dim] * num_layers
        self.out_dim = out_dim
        # Produces in_dim * out_dim values per element; split apart in forward.
        self.elementwise = nf.net.MADE(in_dim, hidden_dims, in_dim * out_dim)
        self.interaction = ExclusiveInteractionNet(in_dim, hidden_dim, out_dim, num_layers, interaction)

    def forward(self, t, x, mask=None, **kwargs):
        flat = self.elementwise(x)
        # Split the last axis into (out_dim, rest) and swap the two so that
        # out_dim ends up as the trailing axis.
        split_shape = (*flat.shape[:-1], self.out_dim, -1)
        per_element = flat.reshape(split_shape).transpose(-2, -1).contiguous()
        pooled = self.interaction(t, x, mask=mask)
        return per_element + pooled

class ExclusiveAttentionNet(nn.Module):
    """
    Embeds each element, then runs attention separately for every input
    dimension by folding that dimension into the batch axis.

    Inputs (..., in_dim)
    Outputs (..., in_dim, out_dim)
    """
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers, n_heads=1, n_points=None, induced=False, **kwargs):
        super().__init__()
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        hidden_dims = [hidden_dim] * num_layers
        self.emb = nf.net.MADE(in_dim, hidden_dims, in_dim * out_dim)
        attention_options = dict(
            num_layers=num_layers,
            n_heads=n_heads,
            n_points=n_points,
            induced=induced,
            mask_diagonal=False,
            residual=False,
            layernorm=False,
        )
        self.att = nf.net.SelfAttentionNet(out_dim, hidden_dim, out_dim, **attention_options)

    def forward(self, t, x, mask=None, **kwargs):
        # Shape comments use B = batch, N = set size, D = in_dim, H = out_dim.
        embedded = self.emb(x)  # (B, N, D * H)

        embedded = embedded.reshape(*x.shape[:-1], self.out_dim, -1)  # (B, N, H, D)
        embedded = embedded.transpose(-1, -2).contiguous()  # (B, N, D, H)

        # Move D next to the batch axis and merge the two, so attention runs
        # over the N axis independently for each input dimension.
        per_dim = embedded.transpose(1, 2)  # (B, D, N, H)
        per_dim = per_dim.reshape(-1, *per_dim.shape[2:])  # (B * D, N, H)

        # NOTE(review): `mask` is forwarded unchanged although the leading
        # axis is now B * D -- confirm SelfAttentionNet accepts that.
        attended = self.att(per_dim, mask=mask)  # (B * D, N, H)

        # Undo the batch merge and restore the (B, N, D, H) layout.
        unmerged_shape = (x.shape[0], x.shape[-1], x.shape[1], self.out_dim)
        attended = attended.reshape(unmerged_shape)  # (B, D, N, H)
        return attended.transpose(1, 2).contiguous()  # (B, N, D, H)

# class ExclusiveAttentionPoolNet(nn.Module):
#     """
#     Inputs (..., in_dim)
#     Outputs (..., in_dim, out_dim)
#     """
#     def __init__(self, in_dim, hidden_dim, out_dim, num_layers, n_heads=1, n_points=None, induced=False, **kwargs):
#         super().__init__()
#         self.out_dim = out_dim
#         self.hidden_dim = hidden_dim
#         self.att = nf.net.SelfAttentionNet(in_dim, hidden_dim, out_dim, num_layers=num_layers,
#                                            n_heads=n_heads, n_points=n_points, induced=induced,
#                                            mask_diagonal=False, residual=False, layernorm=False)

#     def forward(self, t, x, mask=None, **kwargs):
#         if mask is None:
#             mask = torch.ones(*x.shape[:-1], 1)
#         else:
#             mask = mask[...,0,None]

#         print(x.shape)
#         h = self.att(x) # (B, N, D * H)
#         print(h.shape)
#         return
#         y = exclusive_mean_pooling(h, mask)
#         y = y.unsqueeze(-2).repeat_interleave(x.shape[-1], dim=-2)

#         return y
