import torch
import torch.nn as nn
from utils import MLP
import MaskedReduce


class Set2Set(nn.Module):
    """Set layer that mixes every element with a pooled summary of its set.

    Two MLPs are applied to the input:

    * ``mlp1`` output is reduced over ``setdim`` with a masked aggregation from
      ``MaskedReduce.reduce_dict``. The reduced axis is then re-inserted so the
      summary broadcasts over all elements.
    * ``mlp2`` gives per-element features.

    The two are combined by elementwise product (``"mul"``) or sum (``"add"``).
    If ``res`` is True, the input is then added back as a residual connection.
    The layer is intended to be permutation-equivariant over ``setdim``, which
    assumes ``MLP`` acts per element and the reduction is order-independent.

    Args:
        hiddim: width passed for every ``MLP`` argument of both MLPs. With
            ``res=True`` the MLP output must have the same shape as ``x``.
        combine: ``"mul"`` or ``"add"``; how pooled and per-element features merge.
        aggr: key into ``MaskedReduce.reduce_dict`` selecting the reduction.
            An unknown key raises ``KeyError``.
        res: if True, add the input ``x`` to the output.
        setdim: axis of ``x`` that indexes set elements.
        **mlpargs: extra keyword arguments forwarded to both ``MLP`` constructors.

    Raises:
        ValueError: if ``combine`` is not ``"mul"`` or ``"add"``.
    """

    def __init__(self, hiddim: int, combine: str = "mul", aggr: str = "sum", res: bool = True, setdim: int = -2, **mlpargs) -> None:
        super().__init__()
        # An explicit check instead of ``assert`` still fires under ``python -O``.
        # Otherwise an unknown value would silently take the "add" branch in forward().
        if combine not in ("mul", "add"):
            raise ValueError(f"combine must be 'mul' or 'add', got {combine!r}")
        # Construction order (mlp1 before mlp2) is kept so that parameter
        # registration order and initialization are unchanged.
        self.mlp1 = MLP(hiddim, hiddim, hiddim, **mlpargs)
        self.mlp2 = MLP(hiddim, hiddim, hiddim, **mlpargs)
        self.setdim = setdim
        self.aggr = MaskedReduce.reduce_dict[aggr]
        self.res = res
        self.combine = combine

    def forward(self, x, mask):
        """Combine per-element features with a masked pooled summary of the set.

        Args:
            x: set features, documented as (B, N, d).
            mask: per-element mask, documented as (B, N). Assumed: True marks
                elements the reduction ignores; confirm against MaskedReduce.

        Returns:
            Tensor broadcast from the pooled summary and ``mlp2(x)``. With the
            defaults this has the same shape as ``x``.
        """
        # mlp1 runs before mlp2, as in the original, so dropout RNG order in
        # training mode is unchanged.
        # The trailing singleton axis on the mask lets it broadcast over features.
        pooled = self.aggr(self.mlp1(x), mask.unsqueeze(-1), self.setdim)
        # Re-insert the reduced set axis so the summary broadcasts over every element.
        pooled = pooled.unsqueeze(self.setdim)
        per_elem = self.mlp2(x)
        if self.combine == "mul":
            out = pooled * per_elem
        else:
            out = pooled + per_elem
        if self.res:
            # `out` is a fresh tensor here, so the in-place add never mutates `x`.
            out += x
        return out

class Set2Vec(nn.Module):
    """Encode each set element, pool over the set axis, then decode the pooled vector.

    Args:
        hiddim: first positional argument of ``mlp1``'s ``MLP`` constructor
            (presumably the input feature width; confirm against ``utils.MLP``).
        outdim: used for the remaining widths of ``mlp1`` and all widths of ``mlp2``.
        aggr: name looked up in ``MaskedReduce.reduce_dict`` to pick the pooling op.
        setdim: axis of ``x`` that indexes set elements; passed to the reduction.
        **mlpargs: extra keyword arguments shared by both ``MLP`` constructors.
    """

    def __init__(
        self,
        hiddim: int,
        outdim: int,
        aggr: str = "sum",
        setdim: int = -2,
        **mlpargs,
    ) -> None:
        super().__init__()
        # mlp1 is built before mlp2 on purpose. This keeps parameter registration
        # order and, presumably, the order of random initialization draws.
        self.mlp1 = MLP(hiddim, outdim, outdim, **mlpargs)
        self.mlp2 = MLP(outdim, outdim, outdim, **mlpargs)
        self.setdim = setdim
        self.aggr = MaskedReduce.reduce_dict[aggr]

    def forward(self, x, mask):
        """Pool ``mlp1(x)`` over ``self.setdim`` and return ``mlp2`` of the result.

        Args:
            x: set features, documented as (B, N, d).
            mask: per-element mask, documented as (B, N). A trailing singleton
                axis is appended before it is handed to the reduction.
        """
        pooled = self.aggr(self.mlp1(x), mask[..., None], self.setdim)
        return self.mlp2(pooled)
    

def _max_abs_diff(a, b):
    """Return the largest absolute elementwise difference between two tensors."""
    return torch.max(torch.abs(a - b)).item()


def _symmetry_errors(mod, x, mask, equivariant):
    """Run batch, permutation and padding checks on ``mod``; return their max errors.

    Args:
        mod: module called as ``mod(x, mask)``; the caller puts it in eval mode.
        x: input of shape (B, N, d) with B >= 2 (sample 1 is compared across batches).
        mask: boolean mask of shape (B, N).
        equivariant: True when the output keeps the set axis (Set2Set), so it is
            permuted alongside the input. False when the set axis is pooled away
            (Set2Vec), so the output must be unchanged by permutation.

    Returns:
        ``(batch_err, permute_err, padding_err)``; each should be ~0 up to float noise.
    """
    n = x.shape[1]
    with torch.no_grad():
        # Batch test: sample 1 must not depend on which other samples share its batch.
        h1 = mod(x[:2], mask[:2])[1]
        h2 = mod(x[1:], mask[1:])[0]
        batch_err = _max_abs_diff(h1, h2)

        # Permute test: reordering set elements permutes (or leaves unchanged) the output.
        h1 = mod(x, mask)
        perm = torch.randperm(n)
        h2 = mod(x[:, perm], mask[:, perm])
        permute_err = _max_abs_diff(h1[:, perm] if equivariant else h1, h2)

        # Padding test: append N large random elements with mask=True and shuffle.
        # Outputs for the original elements must not change, i.e. this test
        # assumes mask==True marks elements the reduction ignores.
        noise = 100 * torch.randn_like(x)
        pad_mask = torch.ones_like(mask)
        h1 = mod(x, mask)
        perm = torch.randperm(2 * n)
        h2 = mod(torch.concat((x, noise), dim=1)[:, perm],
                 torch.concat((mask, pad_mask), dim=1)[:, perm])
        if equivariant:
            # Undo the shuffle, then keep only the original (unpadded) elements.
            invperm = torch.empty_like(perm)
            invperm[perm] = torch.arange(2 * n)
            h2 = h2[:, invperm][:, :n]
        padding_err = _max_abs_diff(h1, h2)
    return batch_err, permute_err, padding_err


if __name__ == "__main__":
    import os

    hiddim = 64
    N = 32
    mlpargs = {
        "numlayer": 5,
        "norm": "ln",
        "tailact": True,
        "dropout": 0.5,
        "activation": nn.ReLU(inplace=True),
    }
    # Fall back to CPU so the self-test also runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn((3, N, hiddim)).to(device)
    mask1 = torch.randint(0, 2, (3, N), dtype=torch.bool).to(device)
    for comb in ["mul", "add"]:
        for aggr in ["max", "sum", "mean"]:
            mod = Set2Set(hiddim, comb, aggr, True, -2, **mlpargs).to(device)
            mod.eval()
            b_err, p_err, z_err = _symmetry_errors(mod, x, mask1, equivariant=True)
            print(f"Set2Set combine={comb} aggr={aggr}: "
                  f"batch={b_err:.3e} permute={p_err:.3e} padding={z_err:.3e}")
    for aggr in ["max", "sum", "mean"]:
        mod = Set2Vec(hiddim, hiddim, aggr, -2, **mlpargs).to(device)
        mod.eval()
        b_err, p_err, z_err = _symmetry_errors(mod, x, mask1, equivariant=False)
        print(f"Set2Vec aggr={aggr}: "
              f"batch={b_err:.3e} permute={p_err:.3e} padding={z_err:.3e}")
    # os._exit skips normal interpreter cleanup, including flushing stdio buffers,
    # hence the explicit flush first. NOTE(review): presumably used to avoid slow
    # or hanging CUDA teardown; confirm. 0 replaces os.EX_OK, which is Unix-only.
    print(flush=True)
    os._exit(0)