import torch
from typing import Final
import torch.nn as nn
from MaskedReduce import reduce_dict
from InputEncoder import QInputEncoder
import PermEquiLayer
from QGNNLF import sv2el,svMix,DirCFConv,QGNNLF
from utils import MLP
from ONDeepset import SimGlobalAggr, SimTprod
import torch.nn.functional as F


class VNorm(nn.Module):
    """Rescale every column of ``v`` to unit L2 length along axis -2.

    Expects ``v`` of shape ``(*, m, d)``. For each of the ``d`` columns, the
    length-``m`` vector is divided by its Euclidean norm. The norm is floored
    at ``1e-3``, so near-zero columns are scaled up by at most ``1e3`` rather
    than dividing by zero. ``hiddim`` is accepted only for constructor
    compatibility. A learnable affine transform is not supported.
    """

    def __init__(self, hiddim, elementwise_affine: bool=False) -> None:
        super().__init__()
        assert not elementwise_affine

    def forward(self, v):
        # Same arithmetic as F.normalize(v, p=2, dim=-2, eps=1e-3).
        length = v.norm(p=2.0, dim=-2, keepdim=True).clamp_min(1e-3)
        return v / length
    

class VMean(nn.Module):
    """Center ``v`` by removing its mean over the last axis.

    For an input of shape ``(*, m, d)``, the average over the ``d`` entries of
    each row is subtracted from that row. The output has the same shape.
    ``hiddim`` is accepted only for constructor compatibility. A learnable
    affine transform is not supported.
    """

    def __init__(self, hiddim, elementwise_affine: bool=False) -> None:
        super().__init__()
        assert not elementwise_affine

    def forward(self, v):
        row_mean = v.mean(dim=-1, keepdim=True)
        return v - row_mean
    
class Imod(nn.Module):
    """Pass-through module: calling it returns its positional inputs as a tuple.

    In this file it stands in for ``svMix`` when mixing is disabled. Call
    sites can therefore still unpack ``ts, tv = module(s, v)`` and receive
    ``(s, v)`` unchanged.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, *inputs):
        # `inputs` is already a tuple; hand it back untouched.
        return inputs

class PiOModel(nn.Module):
    """Layered scalar/vector graph model applied to an encoded graph.

    ``forward`` runs ``QInputEncoder`` on ``(A, X, nodemask)`` and hands its
    outputs to ``eigenforward``. Judging by names (``LambdaEmb``, ``U``,
    ``eigenforward``), the encoder presumably yields eigenvalue embeddings and
    eigenvector coordinates. Confirm against ``QInputEncoder``.

    Each of ``num_layers`` layers updates:
    - a scalar node stream ``X``, via ``svMix`` -> ``DirCFConv``;
    - a vector stream ``coord``, via the same modules, conditioned on an edge
      tensor ``el`` built by ``sv2el``;
    - optionally, when ``usetg`` is set, a global branch
      (``SimGlobalAggr`` -> ``SimTprod``) that adds to ``coord``.

    Node features are then pooled (unless ``pool == "none"``) and mapped
    through an MLP.
    """
    elres: Final[bool]       # recompute and accumulate `el` in every layer (else layer 0 only)
    tgres: Final[bool]       # same for the global term `gv2`; forced False when not usetg
    usetg: Final[bool]       # enable the SimGlobalAggr / SimTprod branch
    num_layers: Final[int]
    num_tasks: Final[int]    # equals outdim
    nodetask: Final[bool]    # True when pool == "none": per-node predictions
    gsizenorm: Final[float]  # edge terms scaled by (#real nodes) ** (-gsizenorm / 2)
    lsizenorm: Final[float]  # global terms scaled by (#unmasked Lambda entries) ** (-lsizenorm / 2)
    def __init__(self,
                 featdim: int,
                 caldim: int,
                 hiddim: int,
                 outdim: int,
                 lambda_encoder: str,
                 num_layers: int,
                 pool: str,
                 **kwargs) -> None:
        """Build all sub-modules.

        Args:
            featdim: input node-feature dimension (passed to QInputEncoder).
            caldim: passed to every ``sv2el`` edge projection.
            hiddim: hidden width shared by all sub-modules.
            outdim: output width of the prediction MLP.
            lambda_encoder: currently unused.
            num_layers: number of layers. ``eigenforward`` indexes layer 0
                unconditionally, so it must be >= 1 for a forward pass.
            pool: key into ``reduce_dict``. ``"none"`` makes this a node task.
            **kwargs: flags ``elres``, ``usetg``, ``tgres``, ``usesvmix``,
                ``vmean``, ``vnorm``, ``elvmean``, ``elvnorm``, ``tgvmean``,
                ``tgvnorm``, ``snorm``, and floats ``gsizenorm`` and
                ``lsizenorm``. Also the sub-module option dicts
                ``inputencoder``, ``l_model``, ``elproj``, ``conv``,
                ``predlin``, ``gaggr`` and ``tprod``, plus ``svmix`` when
                ``usesvmix`` is true.
        """
        super().__init__()
        self.num_layers = num_layers
        self.num_tasks = outdim
        self.elres = kwargs["elres"]
        self.nodetask = (pool=="none")
        self.pool = reduce_dict[pool]
        self.usetg = kwargs["usetg"]
        # The global residual only makes sense if the global branch exists.
        self.tgres = kwargs["tgres"] and self.usetg
        usesvmix = kwargs["usesvmix"]

        self.inputencoder = QInputEncoder(featdim, hiddim,
                                          **kwargs["inputencoder"])
        self.LambdaEncoder = PermEquiLayer.PermEquiLayer(hiddim, hiddim, "deepset",
                                           False, **kwargs["l_model"])

        self.elprojs = nn.ModuleList(
            [sv2el(hiddim, caldim, **kwargs["elproj"]) for _ in range(num_layers)]
        )

        # Imod is a pass-through, so `ts, tv = svmix(s, v)` still unpacks.
        self.svmixs = nn.ModuleList(
            [svMix(hiddim, **kwargs["svmix"]) if usesvmix else Imod() for _ in range(num_layers)]
        )

        self.convs = nn.ModuleList(
            [DirCFConv(hiddim, **kwargs["conv"]) for _ in range(num_layers)]
        )

        self.predlin = MLP(hiddim, hiddim, outdim, **kwargs["predlin"])

        self.gaggrs = nn.ModuleList(
            [SimGlobalAggr(hiddim, **kwargs["gaggr"]) for _ in range(num_layers)]
        )

        self.tprods = nn.ModuleList(
            [SimTprod(hiddim, **kwargs["tprod"]) for _ in range(num_layers)]
        )

        # Optional VMean -> VNorm applied to `coord` before svMix (vln),
        # sv2el (elvln) and SimGlobalAggr (tgvln) respectively.
        self.vln = nn.Sequential(VMean(hiddim) if kwargs["vmean"] else nn.Identity(), VNorm(hiddim) if kwargs["vnorm"] else nn.Identity())
        self.elvln = nn.Sequential(VMean(hiddim) if kwargs["elvmean"] else nn.Identity(), VNorm(hiddim) if kwargs["elvnorm"] else nn.Identity())
        self.tgvln = nn.Sequential(VMean(hiddim) if kwargs["tgvmean"] else nn.Identity(), VNorm(hiddim) if kwargs["tgvnorm"] else nn.Identity())
        self.sln = nn.LayerNorm(hiddim, elementwise_affine=False) if kwargs["snorm"] else nn.Identity()
        self.gsizenorm = kwargs["gsizenorm"]
        self.lsizenorm = kwargs["lsizenorm"]

    def eigenforward(self, LambdaEmb, LambdaMask, U, X, nodemask, A, debug: str="1"):
        '''
        LambdaEmb (#graph, M, d1)
        LambdaMask (#graph, M)      True marks padded entries (zeroed below)
        U (#graph, N, M, dU)        must be 4-D for the "bnmd" einsum; dU presumably
                                    equals (or broadcasts to) LambdaEmb's last dim
        X (#graph, N, dx)
        nodemask (#graph, N)        True marks padded nodes
        A                           currently unused (the "+ A" edge term is disabled)
        debug                       unused

        Returns ``self.predlin`` applied to the node features:
        - when pool == "none", to every node (padded ones included);
        - otherwise, to the pooled per-graph features.
        '''
        N, M = U.shape[1], U.shape[2]
        # Count real nodes / unmasked Lambda entries per graph. Clamp to 1:
        # rsqrt(0) = inf, and inf times the all-zero mask of a fully padded
        # graph would put NaN into `el` (or into gv2 for an empty Lambda set).
        # For non-empty graphs the clamp changes nothing.
        gsize = (N - torch.sum(nodemask.float(), dim=1)).clamp_min(1)
        # gsize ** (-gsizenorm / 2), broadcast over the (N, N, channel) axes.
        gsizenorm = torch.rsqrt(gsize).pow(self.gsizenorm).reshape(-1, 1, 1, 1)
        lsize = (M - torch.sum(LambdaMask.float(), dim=1)).clamp_min(1)
        lsizenorm = torch.rsqrt(lsize).pow(self.lsizenorm).reshape(-1, 1, 1, 1)
        LambdaEmb = self.LambdaEncoder(LambdaEmb, LambdaMask)  # LambdaEmb (#graph, M, d1)
        LambdaEmb = torch.where(LambdaMask.unsqueeze(-1), 0, LambdaEmb)
        # Scale U[b, :, m, :] elementwise by LambdaEmb[b, m, :].
        coord = torch.einsum("bnmd,bmd->bnmd", U, LambdaEmb)  # (#graph, N, M, d)
        # 1.0 where both endpoints i and j are real nodes, else 0.0: (#graph, N, N, 1).
        nnfilter = torch.logical_not(
            torch.logical_or(nodemask.unsqueeze(-1), nodemask.unsqueeze(1))
        ).float().unsqueeze(-1)
        # Size normalisation and padding mask applied to every edge term.
        # The mask is 0/1, so applying it once is sufficient.
        elscale = gsizenorm * nnfilter

        # Layer 0 always builds `el` and, when usetg, `gv2`.
        elvlncoord = self.elvln(coord)
        el = self.elprojs[0](X, elvlncoord, elvlncoord) * elscale  # "+ A" term disabled
        ts, tv = self.svmixs[0](self.sln(X), self.vln(coord))
        ts1, tv1 = self.convs[0](ts, tv, el)
        if self.usetg:
            tgvlncoord = self.tgvln(coord)
            gv2 = self.gaggrs[0](tgvlncoord, tgvlncoord) * lsizenorm
            tv2 = self.tprods[0].forward(tv, gv2)
            coord = coord + tv1 + tv2
        else:
            coord = coord + tv1
        X = X + ts1

        # Later layers reuse `el` / `gv2` unless the matching residual flag asks
        # to accumulate a fresh term. `gv2` is only bound when usetg, but
        # tgres implies usetg, so every read below is safe.
        for i in range(1, self.num_layers):
            if self.elres:
                elvlncoord = self.elvln(coord)
                el = el + self.elprojs[i](X, elvlncoord, elvlncoord) * elscale
            if self.tgres:
                tgvlncoord = self.tgvln(coord)
                gv2 = gv2 + self.gaggrs[i](tgvlncoord, tgvlncoord) * lsizenorm
            ts, tv = self.svmixs[i](self.sln(X), self.vln(coord))
            ts1, tv1 = self.convs[i](ts, tv, el)
            if self.usetg:
                tv2 = self.tprods[i].forward(tv, gv2)
                coord = coord + tv1 + tv2
            else:
                coord = coord + tv1
            X = X + ts1

        if not self.nodetask:
            # Presumably a masked reduction over dim 1 (nodes); see MaskedReduce.
            X = self.pool(X, nodemask.unsqueeze(-1), 1)
        return self.predlin(X)

    def forward(self, A, X, nodemask, debug: str="1"):
        '''
        A (#graph, N, N)
        X (#graph, N, d)
        nodemask (#graph, N)   True marks padded nodes
        debug                  forwarded to eigenforward (unused there)

        The input encoder is expected to return eigenforward's six positional
        inputs (LambdaEmb, LambdaMask, U, X, nodemask, A).

        Returns:
        - when pool == "none", predictions for real nodes only, with the graph
          and node axes flattened together;
        - otherwise, eigenforward's per-graph output unchanged.
        '''
        pred = self.eigenforward(*self.inputencoder(A, X, nodemask), debug)
        if self.nodetask:
            # Boolean indexing drops padded nodes and flattens (graph, node).
            pred = pred[torch.logical_not(nodemask)]
        return pred
