from Deepset import Set2Set, Set2Vec
import torch
import torch.nn as nn
from typing import Callable
from torch_geometric.nn import Sequential as PygSequential
from utils import MLP


class PermEquiLayer(nn.Module):
    """Set encoder: stacked Set2Set layers followed by a Set2Vec or MLP head.

    ``numlayers`` ``Set2Set`` layers are applied to every set in the batch.
    The result then goes either through ``Set2Vec`` (``invout=True``,
    one vector per set, expected to be permutation-invariant) or through a
    per-element ``MLP`` (``invout=False``, expected to be
    permutation-equivariant).

    Args:
        hiddim: feature width consumed by the Set2Set layers (and by the head).
        outdim: output feature width of the head.
        set2set: backbone type. Only ``"deepset"`` is implemented;
            ``"transformer"`` is accepted by name but raises
            ``NotImplementedError``.
        invout: if True, use ``Set2Vec`` as the head; otherwise an ``MLP``.
        numlayers: number of stacked Set2Set layers.
        **kwargs: layer options. Keys read here:

            - ``combine``, ``aggr``, ``res``, ``mlpargs1`` (dict), all passed
              to each ``Set2Set``;
            - ``mlpargs2`` (dict), passed to ``Set2Vec`` / ``MLP``;
            - ``pool``, passed to ``Set2Vec`` (only when ``invout`` is True).

            Any other keys are ignored.

    Raises:
        ValueError: if ``set2set`` is not ``"deepset"`` or ``"transformer"``.
        NotImplementedError: if ``set2set == "transformer"``.
        KeyError: if one of the required ``kwargs`` entries is missing.
    """

    def __init__(self, hiddim: int, outdim: int, set2set: str, invout: bool,
                 numlayers: int, **kwargs) -> None:
        super().__init__()
        # Explicit check instead of `assert`: asserts vanish under `python -O`,
        # which would leave `self.set2set` undefined and fail later in forward.
        if set2set not in ("deepset", "transformer"):
            raise ValueError(
                f"set2set must be 'deepset' or 'transformer', got {set2set!r}")
        if set2set == "transformer":
            raise NotImplementedError(
                "set2set='transformer' is not implemented yet")

        # Module layout (attribute names and Sequential contents) is kept
        # as-is so existing state_dict checkpoints keep loading.
        self.set2set = PygSequential(
            "x, mask",
            [(Set2Set(hiddim,
                      kwargs["combine"],
                      kwargs["aggr"],
                      res=kwargs["res"],
                      **kwargs["mlpargs1"]), "x, mask -> x")
             for _ in range(numlayers)]
            # Trailing identity; presumably keeps the Sequential non-empty
            # when numlayers == 0 — TODO confirm.
            + [(nn.Identity(), "x -> x")])

        if invout:
            self.set2vec = Set2Vec(hiddim,
                                   outdim,
                                   aggr=kwargs["pool"],
                                   **kwargs["mlpargs2"])
        else:
            # The MLP only sees `x`; `mask` is accepted by the Sequential's
            # input signature so both heads share the same call interface.
            self.set2vec = PygSequential(
                "x, mask",
                [(MLP(hiddim, outdim, outdim, **kwargs["mlpargs2"]), "x->x")])

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Encode a batch of sets.

        Args:
            x: element features of shape (B, N, d); presumably d == hiddim.
            mask: per-element mask of shape (B, N).

        Returns:
            Output of the head. Presumably (B, N, outdim) when ``invout`` is
            False and (B, outdim) when ``invout`` is True — TODO confirm
            against ``MLP`` / ``Set2Vec``.
        """
        x = self.set2set(x, mask)
        return self.set2vec(x, mask)

def _max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the largest absolute elementwise difference between a and b."""
    return torch.max(torch.abs(a - b)).item()


def check_batch_independence(mod, x: torch.Tensor, mask: torch.Tensor) -> float:
    """Check that a sample's output does not depend on its batch neighbours.

    Sample ``x[1]`` is run once at index 1 of batch ``x[:2]`` and once at
    index 0 of batch ``x[1:]``. The two outputs should agree.

    Args:
        mod: callable taking ``(x, mask)``; needs at least 3 samples in ``x``.
        x: batched set features, batch along dim 0.
        mask: per-element mask, batch along dim 0.

    Returns:
        The maximum absolute difference between the two outputs.
    """
    with torch.no_grad():
        h1 = mod(x[:2], mask[:2])[1]
        h2 = mod(x[1:], mask[1:])[0]
    return _max_abs_diff(h1, h2)


def check_permutation_symmetry(mod, x: torch.Tensor, mask: torch.Tensor,
                               invout: bool) -> float:
    """Check equivariance (``invout=False``) or invariance (``invout=True``).

    The set dimension (dim 1) of ``x`` and ``mask`` is shuffled with a random
    permutation.

    Returns:
        The maximum absolute difference between the expected output and the
        output on the shuffled input.
    """
    with torch.no_grad():
        h1 = mod(x, mask)
        perm = torch.randperm(x.shape[1])
        h2 = mod(x[:, perm], mask[:, perm])
    if not invout:
        # Equivariant output: permuting the input must permute the output rows.
        h1 = h1[:, perm]
    return _max_abs_diff(h1, h2)


def check_masked_padding(mod, x: torch.Tensor, mask: torch.Tensor,
                         invout: bool) -> float:
    """Check that extra, mask-flagged set elements leave the output unchanged.

    Steps:

    1. Append ``N`` large-magnitude random elements with an all-True mask,
       where ``N`` is the size of the set dimension.
    2. Shuffle the enlarged set.
    3. For a per-element output (``invout=False``), undo the shuffle and keep
       only the first ``N`` rows.
    4. Compare the result against the output on the original set.

    Returns:
        The maximum absolute difference between the two outputs.
    """
    n = x.shape[1]
    with torch.no_grad():
        extra_x = 100 * torch.randn_like(x)
        # NOTE(review): an all-True mask for the extra elements only makes
        # sense if True marks positions to ignore (padding); confirm against
        # Deepset.Set2Set / Set2Vec.
        extra_mask = torch.ones_like(mask)
        h1 = mod(x, mask)
        perm = torch.randperm(2 * n)
        big_x = torch.cat((x, extra_x), dim=1)[:, perm]
        big_mask = torch.cat((mask, extra_mask), dim=1)[:, perm]
        h2 = mod(big_x, big_mask)
    if not invout:
        invperm = torch.empty_like(perm)
        invperm[perm] = torch.arange(2 * n)
        h2 = h2[:, invperm][:, :n]
    return _max_abs_diff(h1, h2)


if __name__ == "__main__":
    import os

    # Self-test: for both heads, print the max deviation of the batch,
    # permutation and padding checks. All printed values should be ~0.
    torch.manual_seed(0)
    hiddim = 64
    N = 32
    paramdict = {
            "set2set": "deepset",
            "outdim": hiddim,
            "numlayers": 3,
            "nhead": 1,
            "dffn": hiddim,
            "norm_first": True,
            "aggr": "sum",
            "combine": "mul",
            "res": True,
            "mlpargs1": {
                "numlayer": 5,
                "norm": "ln",
                "tailact": True,
                "dropout": 0.5,
                "activation": nn.ReLU(inplace=True),
            },
            "mlpargs2": {
                "numlayer": 5,
                "norm": "ln",
                "tailact": True,
                "dropout": 0.5,
                "activation": nn.ReLU(inplace=True),
            },
            "invout": False,
            "pool": "sum"
        }
    # Fall back to CPU so the self-test also runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn((3, N, hiddim)).to(device)
    mask1 = torch.randint(0, 2, (3, N), dtype=torch.bool).to(device)

    for invout in (False, True):
        # Fresh kwargs per variant instead of mutating paramdict in place.
        mod = PermEquiLayer(hiddim, **{**paramdict, "invout": invout}).to(device)
        mod.eval()  # disable dropout so the outputs are deterministic
        print(check_batch_independence(mod, x, mask1))
        print(check_permutation_symmetry(mod, x, mask1, invout))
        print(check_masked_padding(mod, x, mask1, invout))

    # os._exit skips normal interpreter shutdown (presumably to avoid slow or
    # hanging CUDA teardown) and does not flush stdio, so flush explicitly.
    # Exit code 0 is used because os.EX_OK is only defined on Unix.
    print(flush=True)
    os._exit(0)