import torch
import torch.nn as nn
from utils import MLP
from InputEncoder import QInputEncoder
from PermEquiLayer import PermEquiLayer
from MaskedReduce import reduce_dict
from typing import Final
from QGNNLF import svMix

# Small positive constant. It is not referenced anywhere in the visible part
# of this module -- presumably a numerical-stability epsilon; confirm there are
# no external users before removing it.
EPS = 1e-3

class ONDeepset(nn.Module):
    """Stack of svMix / Tprod / GlobalAggr layers followed by an MLP head.

    ``forward`` runs the input encoder on ``(A, X, nodemask)`` and feeds its
    outputs to ``eigenforward``.  When ``pool == "none"`` the model predicts
    per node (one output row per unmasked node); otherwise the prediction is
    made from the accumulated global scalar feature.

    Note: ``self.pool`` is stored but not used by the code in this class.
    """
    num_layers: Final[int]
    num_tasks: Final[int]
    nodetask: Final[bool]

    def __init__(self,
                 featdim: int,
                 hiddim: int,
                 outdim: int,
                 lambda_encoder: str,
                 num_layers: int,
                 pool: str,
                 **kwargs) -> None:
        """Build the encoders, the per-layer modules and the prediction head.

        ``kwargs`` must provide the sub-configurations ``"inputencoder"``,
        ``"l_model"``, ``"svmix"``, ``"gaggr"``, ``"tprod"`` and ``"predlin"``;
        each is forwarded as keyword arguments to the matching module.
        """
        super().__init__()
        self.num_layers = num_layers
        self.num_tasks = outdim
        self.nodetask = (pool == "none")
        self.pool = reduce_dict[pool]

        self.inputencoder = QInputEncoder(featdim, hiddim,
                                          **kwargs["inputencoder"])
        self.LambdaEncoder = PermEquiLayer(hiddim, hiddim, lambda_encoder,
                                           False, **kwargs["l_model"])

        self.svmixs = nn.ModuleList(
            [svMix(hiddim, **kwargs["svmix"]) for _ in range(num_layers)]
        )

        self.gaggrs = nn.ModuleList(
            [GlobalAggr(hiddim, **kwargs["gaggr"]) for _ in range(num_layers)]
        )

        # num_layers Tprod modules are created although eigenforward only uses
        # the first num_layers - 1; the count is kept so that existing
        # checkpoints still load.
        self.tprods = nn.ModuleList(
            [Tprod(hiddim, **kwargs["tprod"]) for _ in range(num_layers)]
        )

        self.predlin = MLP(hiddim, hiddim, outdim,
                           **kwargs["predlin"])

    def eigenforward(self, LambdaEmb, LambdaMask, U, X, nodemask, debug: str="1"):
        '''
        Run the layer stack on already-encoded inputs.

        LambdaEmb (#graph, M, d1)
        LambdaMask (#graph, M)   True marks entries that are zeroed out
        U (#graph, N, M)
        X (#graph, N, dx)
        nodemask (#graph, N)     True marks nodes that are excluded
        debug                    unused; kept for interface compatibility

        Returns predlin applied to the unmasked node features (node task) or
        to the accumulated global scalar feature (graph task).
        '''
        N, M = U.shape[1], U.shape[2]
        # Number of unmasked nodes / unmasked spectral entries per graph.
        gsize = N - torch.sum(nodemask.float(), dim=1)
        lsize = M - torch.sum(LambdaMask.float(), dim=1)
        LambdaEmb = self.LambdaEncoder(LambdaEmb, LambdaMask)  # (#graph, M, d1)
        LambdaEmb = torch.masked_fill(LambdaEmb, LambdaMask.unsqueeze(-1), 0)
        coord = torch.einsum("bnm,bmd->bnmd", U, LambdaEmb)  # (#graph, N, M, d)

        # First layer: no global features exist yet, so only svMix is applied.
        ts, tv = self.svmixs[0](X, coord)
        X = X + ts
        coord = coord + tv
        gs, gv, gv2 = self.gaggrs[0](X, coord, nodemask, gsize, lsize)

        for i in range(1, self.num_layers):
            # Tprod.forward takes (s, v, gs, gv, gv2) only; passing nodemask
            # and lsize as well raised a TypeError whenever num_layers > 1.
            ts, tv = self.tprods[i - 1](X, coord, gs, gv, gv2)
            ts, tv = self.svmixs[i](ts, tv)
            X = X + ts
            # X is a fresh tensor here, so zeroing masked rows in place does
            # not touch the caller's input.
            X[nodemask] = 0
            coord = coord + tv
            tgs, tgv, tgv2 = self.gaggrs[i](X, coord, nodemask, gsize, lsize)
            # Global features are accumulated residually across layers.
            gs = gs + tgs
            gv = gv + tgv
            gv2 = gv2 + tgv2

        if self.nodetask:
            X = X[torch.logical_not(nodemask)]
        else:
            X = gs
        return self.predlin(X)

    def forward(self, A, X, nodemask, debug: str="1"):
        '''
        A (#graph, N, N)
        X (#graph, N, d)
        nodemask (#graph, N)

        The input encoder's outputs are unpacked positionally into
        eigenforward as (LambdaEmb, LambdaMask, U, X, nodemask).
        '''
        pred = self.eigenforward(*self.inputencoder(A, X, nodemask), debug)
        return pred


class VLinear(nn.Linear):
    """``nn.Linear`` with the bias term disabled.

    Because there is no bias, an all-zero input maps to an all-zero output.
    """

    def __init__(self, indim, outdim):
        nn.Linear.__init__(self, indim, outdim, bias=False)

    def forward(self, args):
        # Plain delegation; the parameter keeps its original name ``args``.
        return nn.Linear.forward(self, args)

    
class GlobalAggr(nn.Module):
    """Pool node-level scalar and vector features into per-graph features."""

    def __init__(self, hiddim, **kwargs) -> None:
        """``kwargs`` must contain ``"permlayer"`` (a dict of keyword
        arguments for the scalar PermEquiLayer) and ``"isreduce"``."""
        super().__init__()
        permlayer_cfg = kwargs["permlayer"]
        self.scalar = PermEquiLayer(hiddim, hiddim, "deepset", True, **permlayer_cfg)
        self.linv1 = VLinear(hiddim, hiddim)
        self.linv2 = VLinear(hiddim, hiddim)
        self.isreduce = kwargs["isreduce"]

    def forward(self, s, v, nodemask, gsize, lsize):
        '''
        s (B, N, d)
        v (B, N, M, d)
        nodemask (B, N)
        gsize (B, )
        lsize is accepted but not used here
        return (B, d), (B, M, d), (B, M, M, d)

        The vector outputs are sums over the node axis; when isreduce is set
        they are additionally divided by gsize.
        '''
        pooled_s = self.scalar.forward(s, nodemask)
        pooled_v = v.sum(dim=1)
        left = self.linv1(v)
        right = self.linv2(v)
        # Contract the node axis of two linear views of v.
        pooled_vv = torch.einsum("bnad,bncd->bacd", left, right)
        if not self.isreduce:
            return pooled_s, pooled_v, pooled_vv
        # Both tensors were created above, so scaling them in place does not
        # modify any input.
        pooled_v.mul_(1/gsize.reshape(-1, 1, 1))
        pooled_vv.mul_(1/gsize.reshape(-1, 1, 1, 1))
        return pooled_s, pooled_v, pooled_vv
    

class Tprod(nn.Module):
    """Update node scalar/vector features from the per-graph features."""

    def __init__(self, hiddim, **kwargs) -> None:
        """``kwargs["mlp"]`` holds the keyword arguments for the output MLP."""
        super().__init__()
        mlp_cfg = kwargs["mlp"]
        self.lin1 = nn.Linear(hiddim, hiddim)
        self.lin2 = nn.Linear(hiddim, hiddim)
        self.lin3 = nn.Linear(hiddim, hiddim)
        self.lins = MLP(hiddim, hiddim, hiddim, **mlp_cfg)
        self.linv1 = VLinear(hiddim, hiddim)
        self.linv2 = VLinear(hiddim, hiddim)
        self.linv3 = VLinear(hiddim, hiddim)

    def forward(self, s, v, gs, gv, gv2):
        '''
        s (B, N, d)
        v (B, N, M, d)
        gs (B, d)
        gv (B, M, d)
        gv2 (B, M, M, d)
        return (B, N, d), (B, N, M, d)
        '''
        # Three vector messages, summed in this order.
        from_pair = self.linv1(torch.einsum("bnmd,bmcd->bncd", v, gv2))
        from_outer = self.linv2(torch.einsum("bnd,bmd->bnmd", s, gv))
        gate = gs.unsqueeze(-2).unsqueeze(-2)
        from_gate = self.linv3(gate * v)
        v_new = from_pair + from_outer + from_gate

        # Scalar update: trace of gv2 over its two M axes, and the contraction
        # of gv with the *updated* vectors.
        trace = torch.einsum("bmmd->bd", gv2)
        proj = torch.einsum("bmd,bnmd->bnd", gv, v_new)
        graph_gate = (self.lin1(trace) * gs).unsqueeze(-2)
        s_new = graph_gate * self.lin2(s) * self.lin3(proj)
        return self.lins(s_new), v_new

    
class SimGlobalAggr(nn.Module):
    """Contract two vector-feature tensors over their shared node axis."""

    def __init__(self, hiddim, **kwargs) -> None:
        # kwargs is accepted but not read.
        super().__init__()
        self.linv1 = VLinear(hiddim, hiddim)
        self.linv2 = VLinear(hiddim, hiddim)

    def forward(self, v1, v2):
        '''
        v1 (B, N, A, d)
        v2 (B, N, C, d)
        return (B, A, C, d)
        '''
        left = self.linv1(v1)
        right = self.linv2(v2)
        # Sum over the node axis n, keeping the feature axis d elementwise.
        return torch.einsum("bnad,bncd->bacd", left, right)
    

class SimTprod(nn.Module):
    """Contract linearly-mapped vector features with a pairwise global tensor."""

    def __init__(self, hiddim, **kwargs) -> None:
        # kwargs is accepted but not read.
        super().__init__()
        self.linv1 = VLinear(hiddim, hiddim)
        # linv2 is registered but not used by forward.
        self.linv2 = VLinear(hiddim, hiddim)

    def forward(self, v, gv2):
        '''
        v (B, N, M, d)
        gv2 (B, M, C, d)
        return (B, N, C, d)
        '''
        mapped = self.linv1(v)
        # Sum over the shared M axis, keeping the feature axis d elementwise.
        return torch.einsum("bnmd,bmcd->bncd", mapped, gv2)