import torch
import torch.nn as nn
from utils import MLP
from InputEncoder import QInputEncoder
from PermEquiLayer import PermEquiLayer
from MaskedReduce import reduce_dict
from typing import Final

# Small module-level constant. No definition in the visible part of this file reads it.
# NOTE(review): presumably a numerical tolerance used by importing code - confirm or remove.
EPS = 1e-3

class QGNNLF(nn.Module):
    '''
    Stack of (sv2el -> DirCFConv -> svMix) layers operating on per-node scalar
    features X and per-node "coordinates" built from an eigen-decomposition
    that the input encoder returns.

    Pipeline (see eigenforward):
      1. encode the eigenvalue embeddings with LambdaEncoder and zero the
         entries flagged by LambdaMask;
      2. build coord = U * LambdaEmb, shape (#graph, N, M, d);
      3. for each layer, accumulate a pairwise tensor el, run the convolution
         and the scalar/vector mixer, and add the results residually;
      4. pool the node features and apply the prediction MLP.

    Masks: a True entry in nodemask / LambdaMask marks a slot that is filtered
    out (pairs touching it get a zero filter, masked eigen-embeddings are set
    to 0, and node-level output keeps only the entries where nodemask is False).

    NOTE(review): `caldim` is the output width of every sv2el, and its output
    is contracted with hiddim-wide features in DirCFConv over a shared index,
    so caldim presumably has to equal hiddim - confirm against the configs.
    '''
    num_layers: Final[int]
    num_tasks: Final[int]
    nodetask: Final[bool]

    def __init__(self,
                 featdim: int,
                 caldim: int,
                 hiddim: int,
                 outdim: int,
                 input_encoder: str,
                 lambda_encoder: str,
                 num_layers: int,
                 pool: str,
                 **kwargs) -> None:
        '''
        featdim: forwarded to QInputEncoder as its first argument
        caldim: second argument of every sv2el (its output width)
        hiddim: hidden width shared by all sub-modules
        outdim: output width of the prediction MLP (stored as num_tasks)
        input_encoder: accepted for interface compatibility; not read here
        lambda_encoder: forwarded to PermEquiLayer
        num_layers: number of (sv2el, DirCFConv, svMix) layers
        pool: key into reduce_dict; "none" switches to node-level output
        kwargs: must contain the dict entries "inputencoder", "l_model",
                "elproj", "svmix", "conv" and "predlin", each forwarded to
                the matching sub-module constructor
        '''
        super().__init__()
        self.num_layers = num_layers
        self.num_tasks = outdim
        self.nodetask = (pool == "none")
        self.pool = reduce_dict[pool]

        self.inputencoder = QInputEncoder(featdim, hiddim,
                                          **kwargs["inputencoder"])
        self.LambdaEncoder = PermEquiLayer(hiddim, hiddim, lambda_encoder,
                                           False, **kwargs["l_model"])

        # Construction order is kept (elprojs, svmixs, convs, predlin) so that
        # parameter initialisation and state_dict layout do not change.
        self.elprojs = nn.ModuleList(
            [sv2el(hiddim, caldim, **kwargs["elproj"]) for _ in range(num_layers)]
        )
        self.svmixs = nn.ModuleList(
            [svMix(hiddim, **kwargs["svmix"]) for _ in range(num_layers)]
        )
        self.convs = nn.ModuleList(
            [DirCFConv(hiddim, **kwargs["conv"]) for _ in range(num_layers)]
        )

        self.predlin = MLP(hiddim, hiddim, outdim,
                           **kwargs["predlin"])

    def eigenforward(self, LambdaEmb, LambdaMask, U, X, nodemask, debug: str = "1"):
        '''
        LambdaEmb (#graph, M, d1)
        LambdaMask (#graph, M)
        U (#graph, N, M)
        X (#graph, N, dx)
        nodemask (#graph, N)
        debug: unused; kept so existing callers that pass it keep working

        Returns self.predlin applied to the pooled node features.
        '''
        LambdaEmb = self.LambdaEncoder(
            LambdaEmb, LambdaMask)  # LambdaEmb (#graph, M, d1)
        # Masked eigen-slots must not contribute to the coordinates.
        LambdaEmb = torch.where(LambdaMask.unsqueeze(-1), 0, LambdaEmb)
        coord = torch.einsum("bnm,bmd->bnmd", U, LambdaEmb)  # (#graph, N, M, d)

        # 1.0 for (i, j) pairs where neither node is masked, else 0.0;
        # shape (#graph, N, N, 1) so it broadcasts over the feature axis.
        pair_masked = torch.logical_or(nodemask.unsqueeze(-1),
                                       nodemask.unsqueeze(1))
        nnfilter = torch.logical_not(pair_masked).float().unsqueeze(-1)

        el = None
        for i in range(self.num_layers):
            # sv2el.forward takes two vector inputs; the current coordinates
            # are supplied for both (the previous two-argument call raised
            # TypeError).
            # NOTE(review): confirm the second one is not meant to be the
            # initial coordinates.
            update = nnfilter * self.elprojs[i](X, coord, coord)
            # Out-of-place accumulation: the previous layer's el was already
            # consumed by conv, so mutating it in place can break autograd.
            el = update if el is None else el + update
            ts, tv = self.convs[i](X, coord, el)
            ts, tv = self.svmixs[i](ts, tv)
            X = X + ts
            coord = coord + tv

        X = self.pool(X, nodemask.unsqueeze(-1), 1)
        return self.predlin(X)

    def forward(self, A, X, nodemask, debug: str = "1"):
        '''
        A (#graph, N, N)
        X (#graph, N, d)
        nodemask (#graph, N)

        Returns the prediction of eigenforward. For node-level tasks
        (pool == "none") only the rows where nodemask is False are returned.
        '''
        pred = self.eigenforward(*self.inputencoder(A, X, nodemask), debug)
        if self.nodetask:
            pred = pred[torch.logical_not(nodemask)]
        return pred


class DirCFConv(nn.Module):
    '''
    Convolution that aggregates neighbour scalars and vectors with a dense
    pairwise weight tensor el, channel by channel.
    '''

    def __init__(self, hiddim, uselinv=True, **kwargs):
        '''
        hiddim: feature width of both the scalar and the vector branch
        uselinv: if true the vector branch gets a bias-free linear map,
                 otherwise it is passed through unchanged
        kwargs: forwarded to the MLP of the scalar branch
        '''
        super().__init__()
        self.lins = MLP(hiddim, hiddim, hiddim, **kwargs)
        if uselinv:
            self.linv = nn.Linear(hiddim, hiddim, bias=False)
        else:
            self.linv = nn.Identity()

    def forward(self, s, v, el):
        '''
        s (B, N, d)
        v (B, N, M, d)
        el (B, N, N, d)

        Returns (s_out, v_out) with shapes (B, N, d) and (B, N, M, d):
        for every target node i the transformed features of all nodes j are
        summed with weight el[b, i, j, d].
        '''
        scalar_msg = self.lins(s)
        s_out = torch.einsum("bijd,bjd->bid", el, scalar_msg)
        vector_msg = self.linv(v)
        v_out = torch.einsum("bijd,bjmd->bimd", el, vector_msg)
        return s_out, v_out


class svMix(nn.Module):
    '''
    Mixes the scalar branch s and the vector branch v of a node:
    the scalars are gated by a per-channel product of two views of v summed
    over the M axis, and the vectors are scaled channel-wise by a function
    of s.
    '''

    res: Final[bool]

    def __init__(self, hiddim, uselinv=True, res=True, boostsv=False, **kwargs) -> None:
        '''
        hiddim: feature width of s and v
        uselinv: use bias-free linear maps on v (identity otherwise)
        res: add the inputs back onto the mixed outputs
        boostsv: enable the extra maps linv3 (only together with uselinv)
                 and lins3
        kwargs: forwarded to every MLP
        '''
        super().__init__()

        def vector_map(enabled):
            # Bias-free so that an all-zero vector stays all-zero.
            return nn.Linear(hiddim, hiddim, bias=False) if enabled else nn.Identity()

        self.linv1 = vector_map(uselinv)
        self.linv2 = vector_map(uselinv)
        self.linv3 = vector_map(boostsv and uselinv)
        self.lins1 = MLP(hiddim, hiddim, hiddim, **kwargs)
        self.lins2 = MLP(hiddim, hiddim, hiddim, **kwargs)
        self.lins3 = MLP(hiddim, hiddim, hiddim, **kwargs) if boostsv else nn.Identity()
        self.res = res

    def forward(self, s, v):
        '''
        s (B, N, d)
        v (B, N, M, d)
        keep zero

        Returns (ts, tv) with the same shapes as (s, v).
        '''
        v_left = self.linv1(v)
        v_right = self.linv2(v)
        # Per-channel product of the two views of v, summed over M: (B, N, d)
        vprod = self.lins3(torch.einsum("bnmd,bnmd->bnd", v_left, v_right))
        s_term = self.lins1(s) * vprod  # (B, N, d)
        v_term = torch.einsum("bnmd,bnd->bnmd", self.linv3(v), self.lins2(s))
        if not self.res:
            return s_term, v_term
        return s + s_term, v + v_term


class sv2el(nn.Module):
    '''
    Builds a dense pairwise tensor (b, n, n, d) from node scalars and node
    vectors: for every pair (i, j) the per-channel product of the two vector
    inputs is summed over the m axis, optionally multiplied by scalar
    features of i and j, and passed through an MLP.
    '''
    uses: Final[bool]

    def __init__(self, indim, hiddim, uselinv=True, uselins=True, uses=True, **kwargs) -> None:
        '''
        indim: feature width of the inputs s and v
        hiddim: feature width of the produced pairwise tensor
        uselinv: put bias-free linear maps on the vector inputs
        uselins: put linear maps on the scalar inputs
        uses: include the scalar features in the pairwise product
        kwargs: forwarded to the output MLP

        A linear map is always created when indim != hiddim, because an
        identity could not change the feature width.
        '''
        super().__init__()
        self.linv1 = nn.Linear(indim, hiddim, bias=False) if (uselinv or indim != hiddim) else nn.Identity()
        self.linv2 = nn.Linear(indim, hiddim, bias=False) if (uselinv or indim != hiddim) else nn.Identity()
        self.lins1 = nn.Linear(indim, hiddim) if (uselins or indim != hiddim) else nn.Identity()
        self.lins2 = nn.Linear(indim, hiddim) if (uselins or indim != hiddim) else nn.Identity()
        self.lin = MLP(hiddim, hiddim, hiddim, **kwargs)
        self.uses = uses

    def forward(self, s, v1, v0=None):
        '''
        s (b, n, d)
        v1 (b, n, m, d)  vectors used on the i side of each pair
        v0 (b, n, m, d)  vectors used on the j side; defaults to v1, so the
                         two-argument call forward(s, v) works

        Returns a tensor of shape (b, n, n, d).
        '''
        if v0 is None:
            v0 = v1
        if self.uses:
            ret = self.lin(
                torch.einsum("bid,bjd,bimd,bjmd->bijd",
                             self.lins1(s), self.lins2(s), self.linv1(v1),
                             self.linv2(v0)))
        else:
            ret = self.lin(torch.einsum("bimd,bjmd->bijd",
                                        self.linv1(v1), self.linv2(v0)))
        return ret