import torch
import torch.nn as nn
from Emb import MultiEmbedding
from utils import MLP
from basis_layers import rbf_class_mapping
from typing import Final

EPS = 1e-4

class VanillaInputEncoder(nn.Module):
    """Spectral input encoder built on the eigendecomposition of ``A``.

    ``forward`` eigendecomposes the (optionally normalised) adjacency with
    ``torch.linalg.eigh``, expands the eigenvalues with ``LambdaEmb`` and the
    eigenvectors (and their negation) with ``UEmb``. Node features go
    through ``xemb`` and are optionally multiplied by a degree embedding.
    """

    def __init__(self, featdim, UExpDim, LambdaExpDim, **kwargs) -> None:
        super().__init__()
        self.permutedata = kwargs["permutedata"]
        self.featdim = featdim
        self.xemb = nn.Identity()
        if len(kwargs["xembdims"]) > 0:
            self.xemb = MultiEmbedding(featdim, kwargs["xembdims"],
                                        **kwargs["xemb"])
        self.UEmb = rbf_class_mapping[kwargs["uexp"]](UExpDim,
                                                      **kwargs["basic"],
                                                      **kwargs["uemb"])
        self.LambdaEmb = rbf_class_mapping[kwargs["lexp"]](
            LambdaExpDim, **kwargs["basic"], **kwargs["lambdaemb"])
        # Optional degree embedding (MultiEmbedding over [100]); None disables it.
        self.degreeEmb = MultiEmbedding(
            featdim, [100], **kwargs["xemb"]) if kwargs["degreeemb"] else None
        self.normA = kwargs["normA"]
        self.sizenormU = kwargs["sizenormU"]

    def forward(self, A, X, nodemask):
        """Compute spectral embeddings for a batch of graphs.

        Args:
            A: batched adjacency matrices, indexed ``(graph, node, node)``.
            X: node features, fed to ``self.xemb``.
            nodemask: boolean mask over ``(graph, node)``; True entries are
                excluded from the graph size used by ``sizenormU``.

        Returns:
            ``(LambdaEmb, Lambdamask, UEmb, negUEmb, X, nodemask)`` where
            ``Lambdamask`` is True for eigenvalues with ``|Lambda| < EPS``.
        """
        D = torch.sum(A, dim=-1)  # (#graph, N) weighted degree
        if self.normA:
            # Symmetric normalisation D^-1/2 A D^-1/2. The scaling vector is
            # the per-node degree (#graph, N); degrees below 1 (e.g. isolated
            # nodes) are clamped to 1 to avoid division by zero.
            tD = torch.rsqrt(torch.clamp_min(D, 1))
            A = torch.einsum("bij,bi,bj->bij", A, tD, tD)
        if self.permutedata and self.training:
            # Decompose a randomly relabelled graph, then map the eigenvector
            # rows back to the original node order.
            N = A.shape[1]
            perm = torch.randperm(N, device=A.device)
            A = A[:, perm][:, :, perm]
            invperm = torch.empty_like(perm)
            invperm[perm] = torch.arange(N, device=perm.device)
            Lambda, U = torch.linalg.eigh(A)
            U = U[:, invperm]
        else:
            # Lambda (#graph, M), U (#graph, N, M)
            Lambda, U = torch.linalg.eigh(A)

        X = self.xemb(X)
        if self.degreeEmb is not None:
            X *= self.degreeEmb(D.to(torch.long).unsqueeze(-1))

        Lambdamask = torch.abs(Lambda) < EPS  # (#graph, M) zero frequencies
        LambdaEmb = self.LambdaEmb(Lambda)  # (#graph, M, d2)
        gsizes = U.shape[1] - nodemask.sum(dim=1)  # unmasked nodes per graph

        if self.sizenormU:
            U *= torch.sqrt(gsizes.unsqueeze(-1).unsqueeze(-1))  # (G, N, M)

        UEmb = self.UEmb(U)
        negUEmb = self.UEmb(-U)
        return LambdaEmb, Lambdamask, UEmb, negUEmb, X, nodemask


class EdgeFeatureInputEncoder(nn.Module):
    """Spectral input encoder that also mixes edge labels into eigenvector features.

    ``forward`` eigendecomposes the (optionally normalised) adjacency. Each
    entry of the raw adjacency is embedded with ``edgeEmb`` and aggregated
    along the eigenvectors (``Ef``). ``Umlp`` then fuses ``Ef`` with the
    expanded eigenvector values ``Uexp(U)``.
    """

    def __init__(self, featdim, UExpDim, LambdaExpDim, **kwargs) -> None:
        super().__init__()
        self.xemb = nn.Identity()
        if len(kwargs["xembdims"]) > 0:
            self.xemb = MultiEmbedding(featdim, kwargs["xembdims"],
                                       **kwargs["xemb"])
        # NOTE(review): unused in forward (edgeEmb is used instead); kept
        # because removing it would change the module's state_dict keys.
        self.edge_feature = nn.Embedding(100, UExpDim)
        self.Uexp = rbf_class_mapping[kwargs["uexp"]](UExpDim,
                                                      **kwargs["basic"],
                                                      **kwargs["uemb"])
        # Edge-label embedding; label 0 (no edge) is the padding row.
        self.edgeEmb = nn.Embedding(100, UExpDim, 0)
        self.Umlp = MLP(2 * UExpDim, 2 * UExpDim, UExpDim,
                        kwargs["uemb"]["numlayer"], True, **kwargs["basic"])
        self.LambdaEmb = rbf_class_mapping[kwargs["lexp"]](
            LambdaExpDim, **kwargs["basic"], **kwargs["lambdaemb"])
        self.degreeEmb = MultiEmbedding(
            featdim, [100], **kwargs["xemb"]) if kwargs["degreeemb"] else None
        self.normA = kwargs["normA"]
        self.sizenormU = kwargs["sizenormU"]

    def forward(self, A, X, nodemask):
        """Compute spectral + edge-label embeddings for a batch of graphs.

        Args:
            A: batched adjacency matrices, indexed ``(graph, node, node)``;
                entries cast to long are used as edge labels for ``edgeEmb``.
            X: node features, fed to ``self.xemb``.
            nodemask: boolean mask over ``(graph, node)``; True entries are
                excluded from the graph size used by ``sizenormU``.

        Returns:
            ``(LambdaEmb, Lambdamask, UEmb, negUEmb, X, nodemask)`` where
            ``Lambdamask`` is True for eigenvalues with ``|Lambda| < EPS``.
        """
        D = torch.sum(A, dim=-1)  # (#graph, N) weighted degree
        # Edge labels are looked up from the unnormalised adjacency: after
        # normalisation the entries are fractional and would mostly truncate
        # to 0 (the padding label) when cast to long.
        Alabel = A
        if self.normA:
            # Symmetric normalisation with the per-node degree (#graph, N).
            tD = torch.rsqrt(torch.clamp_min(D, 1))
            A = torch.einsum("bij,bi,bj->bij", A, tD, tD)
        # Lambda (#graph, M), U (#graph, N, M)
        Lambda, U = torch.linalg.eigh(A)

        X = self.xemb(X)
        if self.degreeEmb is not None:
            X *= self.degreeEmb(D.to(torch.long).unsqueeze(-1))

        Lambdamask = torch.abs(Lambda) < EPS  # (#graph, M) zero frequencies
        LambdaEmb = self.LambdaEmb(Lambda)  # (#graph, M, d2)
        gsizes = U.shape[1] - nodemask.sum(dim=1)  # unmasked nodes per graph

        eA = self.edgeEmb(Alabel.to(torch.long))  # (#graph, N, N, ExpDim)
        # Aggregate edge embeddings along each eigenvector (before size scaling).
        Ef = torch.einsum("bijd,bjk->bikd", eA, U)

        if self.sizenormU:
            U *= torch.sqrt(gsizes.unsqueeze(-1).unsqueeze(-1))  # (G, N, M)

        UEmb = self.Umlp(torch.concat((self.Uexp(U), Ef), dim=-1))
        negUEmb = self.Umlp(torch.concat((self.Uexp(-U), -Ef), dim=-1))
        return LambdaEmb, Lambdamask, UEmb, negUEmb, X, nodemask


class QInputEncoder(nn.Module):
    """Spectral input encoder that returns (masked) raw eigenvectors.

    ``forward`` eigendecomposes either the adjacency ``A`` or, when
    ``laplacian`` is set, the Laplacian ``diag(D) - A`` (optionally
    degree-normalised). It embeds the eigenvalues with ``LambdaEmb``, the
    node features with ``xemb`` and the rounded edge weights with ``AEmb``.
    """
    LambdaBound: Final[float]
    laplacian: Final[bool]

    def __init__(self, featdim, hiddim, LambdaBound=1e-4, **kwargs) -> None:
        super().__init__()
        # Eigenvalues with |Lambda| below this bound are masked as zero
        # frequencies in forward().
        self.LambdaBound = LambdaBound
        self.permutedata = kwargs["permutedata"]
        self.featdim = featdim
        self.xemb = nn.Sequential(nn.Linear(featdim, hiddim))
        if len(kwargs["xembdims"]) > 0:
            self.xemb = MultiEmbedding(hiddim, kwargs["xembdims"],
                                        **kwargs["xemb"])
        self.LambdaEmb = rbf_class_mapping[kwargs["lexp"]](
            hiddim, **kwargs["basic"], **kwargs["lambdaemb"])
        self.degreeEmb = MultiEmbedding(
            hiddim, [100], **kwargs["xemb"]) if kwargs["degreeemb"] else None
        # Edge-weight embedding indexed with round(A) in forward(), so rounded
        # weights must lie in [0, 20); weight 0 maps to the padding row.
        self.AEmb = nn.Embedding(20, embedding_dim=hiddim, padding_idx=0)
        self.normA = kwargs["normA"]
        self.sizenormU = kwargs["sizenormU"]
        self.laplacian = kwargs["laplacian"]
        # Std of the Gaussian noise added to Lambda and U in training mode.
        self.decompnoise = kwargs["decompnoise"]

    def setnoiseratio(self, ratio):
        """Set the std of the noise added to Lambda and U during training."""
        self.decompnoise = ratio

    def forward(self, A, X, nodemask):
        """Eigendecompose each graph and build the encoder inputs.

        Args:
            A: batched adjacency matrices, indexed ``(graph, node, node)``.
            X: node features, fed to ``self.xemb``.
            nodemask: boolean ``(graph, node)`` mask; True entries are zeroed
                in ``U`` and ``X`` and excluded from the graph size.

        Returns:
            ``(LambdaEmb, Lambdamask, U, X, nodemask, AEmb)`` where
            ``Lambdamask`` is True for ``|Lambda| < LambdaBound`` and ``AEmb``
            is ``self.AEmb`` applied to ``round(A)``.
        """
        D = torch.sum(A, dim=-1)  # (#graph, N)
        if self.laplacian:
            L = torch.diag_embed(D) - A
        else:
            L = A  # (#graph, N, N)
        if self.normA:
            tD = torch.clamp_min(D, 1)  # (#graph, N)
            tD = torch.rsqrt_(tD)
            L = torch.einsum("bij,bi,bj->bij", L, tD, tD)
        # NOTE(review): permutes whenever the module is training OR permutedata
        # is set (VanillaInputEncoder requires both). The __main__ perm test
        # relies on permuting in eval mode -- confirm this is intended.
        if self.permutedata or self.training:
            N = L.shape[1]
            perm = torch.randperm(N, device=L.device)
            L = L[:, perm][:, :, perm]
            invperm = torch.empty_like(perm)
            invperm[perm] = torch.arange(N, device=perm.device)
            Lambda, U = torch.linalg.eigh(L)
            U = U[:, invperm]  # back to original node order, (#graph, N, M)
        else:
            Lambda, U = torch.linalg.eigh(L)
        if self.laplacian:
            # Drop the smallest eigenpair (the Laplacian's zero mode).
            Lambda = Lambda[:, 1:]
            U = U[:, :, 1:]
        X = self.xemb(X)
        if self.degreeEmb is not None:
            X *= self.degreeEmb(D.to(torch.long).unsqueeze(-1))
        Lambdamask = torch.abs(Lambda) < self.LambdaBound  # (#graph, M) zero frequencies
        if self.laplacian:
            # Laplacian eigenvalues are >= 0 in exact arithmetic, but eigh can
            # return tiny negatives for zero modes (e.g. isolated padding
            # nodes); clamp so sqrt cannot produce NaN.
            Lambda = torch.sqrt(torch.clamp_min(Lambda, 0))
        if self.training:
            Lambda += self.decompnoise * torch.randn_like(Lambda)
            U += self.decompnoise * torch.randn_like(U)
        LambdaEmb = self.LambdaEmb(Lambda)  # (#graph, M, d2)
        gsizes = U.shape[1] - nodemask.sum(dim=1)  # unmasked nodes per graph
        if self.sizenormU:
            U *= torch.sqrt(gsizes.unsqueeze(-1).unsqueeze(-1))
        U.masked_fill_(Lambdamask.unsqueeze(1), 0)  # zero-frequency columns
        U.masked_fill_(nodemask.unsqueeze(-1), 0)  # masked node rows
        X.masked_fill_(nodemask.unsqueeze(-1), 0)
        return LambdaEmb, Lambdamask, U, X, nodemask, self.AEmb(torch.round(A).long())


    
# Registry: config name -> input encoder class.
input_encoder_dict = {
    "vanilla": VanillaInputEncoder,
    "edgefeat": EdgeFeatureInputEncoder,
    "q": QInputEncoder,
}


if __name__ == "__main__":
    # Smoke test for QInputEncoder on a 3-node weighted graph padded to 4
    # nodes. It prints max deviations for (1) two identical batch entries,
    # (2) two forward passes (a random node permutation is applied because
    # permutedata=True), and (3) the padded vs. the unpadded graph.
    # Eigenvectors are compared by absolute value since their sign is arbitrary.
    hiddim = 64
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    kwargs = {
        "permutedata": True,
        "xemb": {
            "orthoinit": True,
            "bn": True,
            "ln": True,
            "dropout": 0.1,
            "lastzeropad": 0,
        },
        "lambdaemb": {
            "numlayer": 2,
            "norm": "ln",
        },
        "sizenormU": False,
        "normA": False,
        "degreeemb": True,
        "lexp": "mlp",
        "xembdims": [1],
        # Both keys are read unconditionally by QInputEncoder.__init__.
        # The pad test compares LambdaEmb[:, 1:] of the padded graph with the
        # unpadded one. That lines up only in Laplacian mode, where the extra
        # zero mode of the isolated padding node sorts first.
        "laplacian": True,
        "decompnoise": 0.0,
        "basic": {
            "dropout": 0.5,
            "activation": nn.ReLU(),
        },
    }

    A0 = torch.tensor([[0, 3, 2, 0], [3, 0, 1, 0], [2, 1, 0, 0], [0, 0, 0, 0]],
                      dtype=torch.float).unsqueeze(0).repeat((2, 1, 1)).to(device)
    X0 = torch.zeros((2, 4, 1), dtype=torch.long).to(device)
    mask0 = torch.tensor([0, 0, 0, 1], dtype=torch.bool).unsqueeze(0).repeat((2, 1)).to(device)
    mod = QInputEncoder(16, hiddim, 0.001, **kwargs).to(device)
    mod.eval()
    with torch.no_grad():
        # forward() returns six values; the last (edge embedding) is unused here.
        LambdaEmb, Lambdamask, U, X, nodemask, _ = mod.forward(A0, X0, mask0)

        # batch test: both batch entries hold the same graph
        print(torch.max(torch.abs(LambdaEmb[0] - LambdaEmb[1])),
              torch.any(Lambdamask[0] ^ Lambdamask[1]),
              torch.max(torch.abs(U[1].abs() - U[0].abs())),
              torch.max(torch.abs(X[1] - X[0])),
              torch.any(nodemask[1] ^ nodemask[0]))

        # perm test: second pass uses a different random node permutation
        LambdaEmb1, Lambdamask1, U1, X1, nodemask1, _ = mod.forward(A0, X0, mask0)
        print(torch.max(torch.abs(LambdaEmb - LambdaEmb1)),
              torch.any(Lambdamask ^ Lambdamask1),
              torch.max(torch.abs(U.abs() - U1.abs())),
              torch.max(torch.abs(X - X1)),
              torch.any(nodemask ^ nodemask1))

        # pad test: same graph without the padding node
        At = torch.tensor([[0, 3, 2], [3, 0, 1], [2, 1, 0]],
                          dtype=torch.float).unsqueeze(0).repeat((2, 1, 1)).to(device)
        Xt = torch.zeros((2, 3, 1), dtype=torch.long).to(device)
        maskt = torch.tensor([0, 0, 0], dtype=torch.bool).unsqueeze(0).repeat((2, 1)).to(device)
        LambdaEmbt, Lambdamaskt, Ut, Xt, nodemaskt, _ = mod.forward(At, Xt, maskt)
        print(torch.max(torch.abs(LambdaEmb[:, 1:] - LambdaEmbt)),
              torch.any(Lambdamask[:, 1:] ^ Lambdamaskt),
              torch.max(torch.abs(U[:, :-1, 1:].abs() - Ut.abs())),
              torch.max(torch.abs(X[:, :-1] - Xt)),
              torch.any(nodemask[:, :-1] ^ nodemaskt))

    # os._exit skips normal interpreter shutdown (including stdout flushing),
    # so flush explicitly first. Exit code 0 is used because os.EX_OK is not
    # defined on Windows.
    print(end="", flush=True)
    import os
    os._exit(0)