# Encode eigenvalues and eigenvectors.

import torch.nn as nn
from torch_geometric.utils import scatter, to_dense_batch, to_dense_adj
from .clifford.dataprocess import process_data
from .clifford.layers import *
from .clifford.modules import *
from .clifford.infrastructure import CliffordAlgebra
from .clifford.decoder import MessagePassingDecoder as GINEncoder
from .utils import *
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_node_encoder
import logging
import torch.nn.functional as F
from torch_geometric.utils import scatter, to_dense_batch
from typing import Literal, Callable, Tuple
from pygho.honn.TensorOp import OpMessagePassing
from pygho import SparseTensor, MaskedTensor
from pygho.backend.Spspmm import spspmm_ind, filterind
from pygho.hodata.SpData import KEYSEP, SpHoData
from torch import Tensor
import torch
from .Deepset import Set2Set

class Encoder(nn.Module):
    """Encode eigenvalues and eigenvectors into node-level embeddings.

    Two mutually exclusive paths are built in ``__init__``:

    * ``use_signnet=True``: a single message-passing encoder (``self.encoder``)
      is applied to ``U`` and ``-U`` and the two results are summed.
    * otherwise: one encoder per eigenspace dimension is stored in
      ``self.encoder_dict`` under the key ``str(dim)`` for
      ``dim = 1 .. max_mult``; the per-eigenspace outputs are either pooled
      per graph (default) or concatenated (``use_concat=True``).
    """
    def __init__(self, num_channels: int,
                 num_layers: int,
                 gp_layer: Literal['ew', 'fc'] = 'ew',
                 set2set_layers: int = 1,
                 max_mult: int = 5,
                 eigen_noise: float = 1e-5,
                 use_eigenvalue: bool = False,
                 # use eigenvalue information in Clifford NN
                 use_concat: bool = False,
                 # concatenate invariant representations of different
                 # eigenspaces for each graph, instead of pooling
                 use_signnet: bool = False,
                 # override other settings to use SignNet
                 dim_reduction_method: Literal['hard_split', 
                                               'learnable_project'] = 'hard_split',
                 eigvec_lim: int = 0,
                 restrict_grade = None,
                 share_weights: bool = False,
                 lambda_norm: str = 'none'):
        """Build the sub-encoders.

        Args:
            num_channels: Hidden width. Every per-dimension encoder ends in
                ``Linear(num_channels, num_channels // 2)``.
            num_layers: Forwarded to ``ScalarEncoder`` / ``FixedDimEncoder``
                (or to the SignNet message-passing encoder).
            gp_layer: Forwarded to ``FixedDimEncoder``.
            set2set_layers: Number of ``Set2Set(num_channels // 2)`` modules
                chained at the end of ``self.lambda_final``.
            max_mult: Largest eigenspace dimension that gets its own encoder.
                Overridden to 1 when ``use_signnet`` is set.
            eigen_noise: Scale of the Gaussian noise added to ``Lambda`` and
                ``U`` in training mode (non-SignNet path only).
            use_eigenvalue: Pass an eigenvalue encoder into the sub-encoders;
                also changes the layout of ``lambda_encoder`` / ``lambda_final``.
            use_concat: Concatenate per-eigenspace outputs instead of pooling.
                Requires ``eigvec_lim > 0``.
            use_signnet: Use the SignNet-style path. Requires ``eigvec_lim > 0``.
            dim_reduction_method: Forwarded to ``process_data``. With
                ``'learnable_project'`` an orthogonally parametrized
                ``Linear(i, max_mult, bias=False)`` is created for every
                ``i`` in ``max_mult + 1 .. 99``.
            eigvec_lim: Number of eigenvectors kept in the SignNet and concat
                paths.
            restrict_grade: Forwarded to ``FixedDimEncoder``.
            share_weights: Build the encoders for ``dim = 2 .. max_mult - 1``
                with ``share_weights_from`` set to the ``max_mult`` encoder.
            lambda_norm: Forwarded to ``MLP2`` inside the eigenvalue encoder.
        """
        super().__init__()

        self.num_channels = num_channels
        self.num_layers = num_layers
        self.gp_layer = gp_layer
        self.max_mult = max_mult
        self.eigen_noise = eigen_noise
        self.use_eigenvalue = use_eigenvalue
        self.lambda_norm = lambda_norm
        self.set2set_layers = set2set_layers
        self.restrict_grade = restrict_grade
        self.use_concat = use_concat
        self.eigvec_lim = eigvec_lim
        self.use_signnet = use_signnet
        self.dim_reduction_method = dim_reduction_method

        # Optional learnable projections, keyed by the input dimension as a
        # string. The upper bound of 100 is hard-coded here.
        if dim_reduction_method == 'learnable_project':
            self.projects = nn.ModuleDict()
            for i in range(max_mult+1, 100):
                self.projects[str(i)] = nn.utils.parametrizations.orthogonal(
                    nn.Linear(i, max_mult, bias=False)
                )
        else:
            self.projects = None

        if use_signnet:
            # SignNet path: one encoder shared by all eigenvectors.
            # `max_mult` is overridden after the projections above were built.
            assert self.eigvec_lim > 0
            self.max_mult = 1
            self.encoder = GINEncoder(1, num_channels, 0, num_channels, num_channels // 2,
                                      num_layers, 0.0, use_ln=True, use_bn=False, task='node')
            # Note: `lambda_encoder` only exists on this path when
            # `use_eigenvalue` is set; `forward` reads it under the same flag.
            if self.use_eigenvalue:
                self.lambda_encoder = nn.Sequential(
                    MLP2(1, num_channels, lambda_norm),
                    nn.LayerNorm(num_channels),
                    nn.ELU(),
                    nn.Linear(num_channels, num_channels // 2)
                )
        
        else:
            if self.use_concat:
                assert self.eigvec_lim > 0
                # contract_lin collapses the channel axis to size 1;
                # transform_lin then maps the eigenvector axis to channels.
                self.contract_lin = nn.Linear(num_channels // 2, 1)
                self.transform_lin = nn.Linear(eigvec_lim, num_channels // 2)

            # With `use_eigenvalue` the shared eigenvalue encoder stops at
            # width `num_channels` (the final projection lives in
            # `lambda_final`); otherwise it projects to `num_channels // 2`.
            if self.use_eigenvalue:
                self.lambda_encoder = nn.Sequential(
                    MLP2(1, num_channels, lambda_norm),
                    nn.LayerNorm(num_channels),
                    nn.SiLU()
                )
            else:
                self.lambda_encoder = nn.Sequential(
                    MLP2(1, num_channels, lambda_norm),
                    nn.LayerNorm(num_channels),
                    nn.SiLU(),
                    nn.Linear(num_channels, num_channels // 2)
                )

            self.encoder_dict = nn.ModuleDict()

            """
            Each transforms
                - (..., eigenspace_dim) -> (..., num_channels // 2)
            """
            # Dimension-1 eigenspaces use the scalar (sign-invariant) encoder.
            self.encoder_dict[str(1)] = nn.Sequential(
                ScalarEncoder(num_channels, num_layers, 
                            use_eigenvalue=use_eigenvalue,
                            eigval_encoder=nn.Sequential(
                                    self.lambda_encoder,
                                    nn.Linear(num_channels, num_channels)
                            ) if use_eigenvalue else None),
                nn.Linear(num_channels, num_channels),
                nn.ELU(),
                nn.Linear(num_channels, num_channels // 2)
            )

            if not share_weights:
                # Independent encoder per dimension; the hidden width shrinks
                # as num_channels // (i + 1).
                for i in range(2, max_mult+1):
                    self.encoder_dict[str(i)] = nn.Sequential(
                        FixedDimEncoder(i, num_channels // (i+1), num_layers, gp_layer,
                                        use_eigenvalue=use_eigenvalue,
                                        eigval_encoder=nn.Sequential(
                                            self.lambda_encoder,
                                            nn.Linear(num_channels,
                                                    num_channels // (i+1))
                                        ) if use_eigenvalue else None,
                                        restrict_grade=restrict_grade),
                        nn.Linear(num_channels // (i+1), num_channels),
                        nn.ELU(),
                        nn.Linear(num_channels, num_channels // 2)
                    )
            else:
                # Weight sharing: the `max_mult` encoder is built first and
                # handed to the lower-dimensional ones via `share_weights_from`.
                self._hidden_channels = num_channels // 3
                # NOTE(review): when `use_eigenvalue` is False, `lambda_encoder`
                # ends in width num_channels // 2, which does not match the
                # Linear(num_channels, ...) below. The module is registered but
                # only passed on when `use_eigenvalue` is True.
                self._lambda_encoder = nn.Sequential(
                                            self.lambda_encoder,
                                            nn.Linear(num_channels, self._hidden_channels)
                                        )
                self._max_dim_encoder = FixedDimEncoder(max_mult, self._hidden_channels, num_layers, gp_layer,
                                                        use_eigenvalue=use_eigenvalue,
                                                        eigval_encoder=self._lambda_encoder 
                                                        if use_eigenvalue else None,
                                                        restrict_grade=restrict_grade)
                self.encoder_dict[str(max_mult)] = nn.Sequential(
                        self._max_dim_encoder,
                        nn.Linear(self._hidden_channels, num_channels),
                        nn.ELU(),
                        nn.Linear(num_channels, num_channels // 2)
                    )
                for i in range(2, max_mult):
                    self.encoder_dict[str(i)] = nn.Sequential(
                        FixedDimEncoder(i, self._hidden_channels, num_layers, gp_layer,
                                        use_eigenvalue=use_eigenvalue,
                                        eigval_encoder=self._lambda_encoder 
                                        if use_eigenvalue else None,
                                        restrict_grade=restrict_grade,
                                        share_weights_from=self._max_dim_encoder),
                        nn.Linear(self._hidden_channels, num_channels),
                        nn.ELU(),
                        nn.Linear(num_channels, num_channels // 2)
                    )

            set2sets = [Set2Set(num_channels // 2) for _ in range(set2set_layers)]
            
            # `lambda_final` post-processes the output of `lambda_encoder`;
            # with `use_eigenvalue` it first projects num_channels down to
            # num_channels // 2 before the Set2Set stack.
            if self.use_eigenvalue:
                self.lambda_final = nn.Sequential(
                    nn.Linear(num_channels, num_channels),
                    nn.ELU(),
                    nn.Linear(num_channels, num_channels // 2),
                    *set2sets
                )
            else:
                self.lambda_final = nn.Sequential(*set2sets)

    def to(self, device):
        """Move the module to ``device`` and return ``self``.

        In addition to ``nn.Module.to``, this explicitly calls ``.to(device)``
        on the first element of every ``encoder_dict`` entry for
        ``dim = 2 .. max_mult`` (the ``FixedDimEncoder``, which overrides
        ``to``) and on every projection in ``self.projects``.
        """
        super().to(device)
        if hasattr(self, 'encoder_dict'):
            for i in range(2, self.max_mult+1):
                self.encoder_dict[str(i)][0].to(device)
        if self.projects is not None:
            for _, proj in self.projects.items():
                proj.to(device)
        return self

    def forward(self, Lambda, U, E_emb = None, mask = None):
        """Compute node embeddings from an eigendecomposition.

        Shapes below are the ones stated by the original author.

        Input:
            - Lambda: (batch_size, num_eigvals)
            - U: (batch_size, num_nodes, num_eigvals)
            - E_emb (optional): (batch_size, num_nodes, num_nodes, num_channels);
              only used on the SignNet path.
            - mask (optional): only forwarded to the SignNet encoder; its
              shape is not documented here.

        Output:
            - U_emb: (batch_size, num_nodes, emb_dim)
            - where `emb_dim = num_channels // 2`
        """
        if self.use_signnet:
            Lambda_, U_ = process_data(Lambda, U, use_signnet=True, 
                                       num_eigenvecs=self.eigvec_lim)

            # Sign invariance: encode both U and -U and add the results.
            U_emb_pos = self.encoder(U_.unsqueeze(-1), E_emb, None, mask)
            U_emb_neg = self.encoder(-U_.unsqueeze(-1), E_emb, None, mask)

            U_emb = U_emb_pos + U_emb_neg
            # Flatten everything after the first two axes into one feature axis.
            U_emb = U_emb.reshape(*U_emb.shape[:2], -1)

            if self.use_eigenvalue:
                # Gate the flattened eigenvector features by the flattened
                # eigenvalue embedding, broadcast over axis 1.
                Lambda_emb = self.lambda_encoder(Lambda_.unsqueeze(-1))
                return Lambda_emb.reshape(Lambda_emb.shape[0], -1).unsqueeze(1) * U_emb
            
            return U_emb

        batch_size = Lambda.shape[0]

        # `data_dict` maps an eigenspace dimension to a dict holding the keys
        # 'U', 'Lambda' and 'graph_idx' (read below).
        data_dict = process_data(Lambda, U, self.max_mult, 
                                 method=self.dim_reduction_method, 
                                 projects=self.projects)
        
        U_emb_list = []

        if self.use_concat:
            mask_list = []
            # Starts as the int 0 and becomes a per-graph tensor after the
            # first in-place addition in the loop.
            cum_eigvecs = 0

        for dim, data in data_dict.items():
            U, Lambda, graph_idx = data['U'], data['Lambda'], data['graph_idx']

            if self.training:
                # In-place noise: this modifies the tensors held in `data_dict`.
                Lambda += self.eigen_noise * torch.randn_like(Lambda)
                U += self.eigen_noise * torch.randn_like(U)

            # (num_eigenspaces, num_nodes, num_channels // 2)
            U_emb = self.encoder_dict[str(dim)]((U, Lambda))
            # (num_eigenspaces, num_channels // 2)
            Lambda_emb = self.lambda_final(self.lambda_encoder(
                Lambda.unsqueeze(-1)))
            
            # (num_eigenspaces, num_nodes, num_channels // 2)
            U_emb = Lambda_emb.unsqueeze(-2) * U_emb

            if self.use_concat:
                emb_batch, emb_mask, num_eigvecs = scatter_cat(U_emb, graph_idx, 
                                                               dim, batch_size)
                # contract_lin reduces the channel axis to size 1.
                U_emb_list.append(self.contract_lin(emb_batch))
                mask_list.append(emb_mask)
                cum_eigvecs += num_eigvecs
                # Stop early once every graph has at least `eigvec_lim` entries.
                if torch.all(cum_eigvecs >= self.eigvec_lim): 
                    break
            else:
                # pooling over all eigenspaces
                # (batch_size, num_nodes, num_channels // 2)
                U_emb_pooled = scatter(U_emb, graph_idx, dim=0, dim_size=batch_size)

                U_emb_list.append(U_emb_pooled)

        if self.use_concat:
            # Pad every graph with zero entries so that all graphs end up with
            # the same number (`fill_to`) of valid entries.
            if torch.any(cum_eigvecs >= self.eigvec_lim):
                fill_to = cum_eigvecs.max()
            else:
                fill_to = self.eigvec_lim
            device = U_emb_list[0].device
            # Largest number of padding slots needed by any graph.
            hsize = (fill_to - cum_eigvecs).max()
            data = torch.zeros(batch_size, hsize, *U_emb_list[0].shape[2:], 
                               device=device, dtype=U_emb_list[0].dtype)
            filter = torch.zeros(batch_size, hsize, device=device, 
                                 dtype=torch.bool)
            # Mark the first (fill_to - cum_eigvecs[b]) padding slots of graph b
            # as valid.
            filter[torch.arange(hsize, device=device).repeat(batch_size).reshape(
                batch_size, hsize) < (fill_to - cum_eigvecs
                                      ).unsqueeze(-1)] = True
            U_emb_list.append(data)
            mask_list.append(filter)
            U_emb_final = torch.cat(U_emb_list, dim=1)
            mask_final = torch.cat(mask_list, dim=1)
            num_eigvecs = mask_final.sum(-1)[0]
            # Every graph must now hold the same number of valid entries,
            # otherwise the reshape below would be invalid.
            assert torch.all(mask_final.sum(-1) == num_eigvecs)
            # Keep the first `eigvec_lim` valid entries, move that axis last and
            # map it to channels with transform_lin.
            # (batch_size, num_nodes, num_channels // 2)
            return self.transform_lin(U_emb_final[mask_final].reshape(
                batch_size, num_eigvecs, -1, 1)[
                    :, :self.eigvec_lim].squeeze(-1).permute(0, 2, 1))
        else:
            return sum(U_emb_list) # (batch_size, num_nodes, num_channels // 2)


@register_node_encoder('CAPE')
class CAPENodeEncoder(nn.Module):
    """GraphGym node encoder that appends a CAPE positional encoding to ``batch.x``.

    Depending on ``cfg.posenc_CAPE.type`` the encoding comes either from the
    per-eigenspace ``Encoder`` ("subspace") or from ``EigenEncoder``
    ("onespace", only when ``decouple`` is enabled).
    """

    def __init__(self, dim_emb, expand_x=True):
        """Read ``cfg.posenc_CAPE`` and build the positional-encoding network.

        Args:
            dim_emb: Total node embedding size after concatenation.
            expand_x: If true, project the raw node features to
                ``dim_emb - dim_pe`` with a linear layer.

        Raises:
            ValueError: If the PE size leaves no room for node features.
        """
        super().__init__()
        dim_in = cfg.share.dim_in  # original input node feature dimension
        pecfg = cfg.posenc_CAPE
        self.dim_pe = dim_pe = pecfg.dim_pe
        self.use_signnet = pecfg.get('use_signnet', False)
        # Whether the PE is also stored separately on the batch.
        self.pass_as_var = pecfg.get('pass_as_var', False)

        dim_x = dim_emb - dim_pe
        if dim_x < 1:
            raise ValueError(f"CAPE size {dim_pe} is too large for "
                             f"desired embedding size of {dim_emb}.")
        if expand_x:
            self.linear_x = nn.Linear(dim_in, dim_x)
        self.expand_x = expand_x

        self.type = pecfg.get('type', 'subspace')
        assert self.type in ('subspace', 'onespace')

        if self.type != 'subspace':
            self.decouple = pecfg.get('decouple', True)
            if self.decouple:
                self.dim_h = pecfg.get('dim_h', 64)
                self.encoder = EigenEncoder(hiddim=self.dim_h,
                                            outdim=dim_pe,
                                            numlayer=pecfg.layers)
            # TODO: otherwise act as identity and embed jointly in the main model
            return

        encoder_kwargs = {
            'num_channels': dim_pe * 2,
            'num_layers': pecfg.layers,
            'gp_layer': pecfg.get('gp_layer', 'ew'),
            'set2set_layers': pecfg.get('set2set_layers', 1),
            'max_mult': pecfg.get('max_mult', 5),
            'eigen_noise': pecfg.get('eigen_noise', 1e-5),
            # use eigenvalue information in Clifford NN
            'use_eigenvalue': pecfg.get('use_eigenvalue', False),
            # concatenate per-eigenspace representations instead of pooling
            'use_concat': pecfg.get('use_concat', False),
            # override other settings to use SignNet
            'use_signnet': pecfg.get('use_signnet', False),
            'dim_reduction_method': pecfg.get('dim_reduction_method', 'hard_split'),
            'eigvec_lim': pecfg.get('eigvec_lim', 0),
            'restrict_grade': pecfg.get('restrict_grade', None),
            'share_weights': pecfg.get('share_weights', False),
            'lambda_norm': pecfg.get('lambda_norm', 'none'),
        }
        self.encoder = Encoder(**encoder_kwargs)

    def _onespace_pe(self, batch, N):
        """Eigendecompose the dense graph matrix, fill ``batch`` and encode."""
        zero_eps = 1e-2
        gap_eps = cfg.posenc_CAPE.get('gap_eps', 5e-2)
        lap_noise = 1e-6
        decomp_noise = 0.  # 1e-6
        shift = 1

        adj = to_dense_adj(batch.edge_index, batch.batch, max_num_nodes=N)
        deg = adj.sum(-1)

        if cfg.posenc_CAPE.get('use_laplacian', True):
            mat = torch.diag_embed(deg) - adj
        else:
            mat = adj

        if cfg.posenc_CAPE.get('normalize_matrix', False):
            # mat <- D^(-1/2) mat D^(-1/2), with degrees clamped to >= 1
            inv_sqrt_deg = torch.clamp_min(deg, 1).rsqrt()
            mat = inv_sqrt_deg.unsqueeze(-1) * mat * inv_sqrt_deg.unsqueeze(-2)

        # Perturb slightly and shift the spectrum by `shift` before eigh; the
        # shift is removed from the eigenvalues afterwards.
        mat_f = mat.to(torch.float)
        mat_f = mat_f + lap_noise * torch.randn_like(mat_f)
        mat_f = mat_f + shift * torch.eye(mat_f.shape[1], dtype=mat_f.dtype,
                                          device=mat_f.device)
        # dim -1 of the eigenvector tensor indexes eigenvalues, dim -2 nodes
        eigvals, eigvecs = torch.linalg.eigh(mat_f)
        eigvals -= shift
        batch['U'] = eigvecs
        batch['Lambda'] = eigvals

        nonzero = eigvals.abs() > zero_eps
        eigvals += decomp_noise * torch.randn_like(eigvals)
        eigvecs += decomp_noise * torch.randn_like(eigvecs)

        # Zero out entries belonging to padded nodes or near-zero eigenvalues.
        node_mask = batch.M.reshape(batch.num_graphs, N)
        keep = torch.logical_and(node_mask.unsqueeze(-1), nonzero.unsqueeze(1))
        eigvecs.masked_fill_(keep.logical_not(), 0)
        batch["lambdamask"] = nonzero

        # Index triples (graph, i, j) of non-zero eigenvalue pairs whose gap is
        # below `gap_eps`.
        pair_nonzero = torch.logical_and(nonzero.unsqueeze(1), nonzero.unsqueeze(2))
        pair_close = torch.abs(eigvals.unsqueeze(1) - eigvals.unsqueeze(2)) < gap_eps
        batch["ind2"] = torch.stack(
            torch.nonzero(pair_nonzero.logical_and(pair_close), as_tuple=True), dim=0)
        ind2 = batch['ind2']

        # Cosine cutoff weight in [0, 1] for every retained pair.
        gap = torch.abs(eigvals[ind2[0], ind2[1]] - eigvals[ind2[0], ind2[2]])
        batch["cutoff2"] = 0.5 * torch.cos(
            gap.clamp_max(gap_eps) / gap_eps * torch.pi) + 0.5

        batch[f"v2{KEYSEP}v2{KEYSEP}2{KEYSEP}v2{KEYSEP}1{KEYSEP}1{KEYSEP}acd"] = filterind(
            ind2, *spspmm_ind(ind2, 2, ind2, 1, False, 1))

        node_feat = torch.zeros([batch.num_graphs, N, self.dim_h],
                                dtype=torch.float, device=batch.x.device)
        edge_feat = adj.unsqueeze(-1).repeat(1, 1, 1, self.dim_h)
        return self.encoder(node_feat, batch["M"].reshape(batch.num_graphs, N),
                            edge_feat, batch['U'], batch['Lambda'],
                            batch["lambdamask"], batch)

    def forward(self, batch):
        """Concatenate the positional encoding to the node features of ``batch``.

        Returns the same ``batch`` object. In the non-decoupled "onespace"
        configuration it is returned without modification.
        """
        # Expand node features if needed
        h = self.linear_x(batch.x) if self.expand_x else batch.x
        N = cfg.dataset.get('max_num_nodes', 37)

        if self.use_signnet:
            # The SignNet path is not wired up in this encoder.
            raise NotImplementedError

        if self.type == 'subspace':
            lam = batch['Lambda'].reshape(batch.num_graphs, N)
            vec = batch['U'].reshape(batch.num_graphs, N, N)
            pos_enc = self.encoder(lam, vec)
        elif self.decouple:
            pos_enc = self._onespace_pe(batch, N)
        else:
            return batch  # TODO: forward EigenEncoder in the main model

        # Drop padded node slots using the flattened mask.
        pos_enc = pos_enc.reshape(batch.num_graphs * N, self.dim_pe)
        pos_enc = pos_enc[batch.M.reshape(-1)]
        # Concatenate final PEs to input embedding
        batch.x = torch.cat((h, pos_enc), 1)
        # Keep PE also separate in a variable (e.g. for skip connections to input)
        if self.pass_as_var:
            batch.pe_CAPE = pos_enc
        return batch


@register_node_encoder('CAPE_joint')
class CAPEJointEncoder(nn.Module):
    """GraphGym node encoder that adds an ``EigenEncoder`` output to ``batch.x``.

    Only the non-decoupled "onespace" configuration is accepted.
    """

    def __init__(self):
        super().__init__()
        pecfg = cfg.posenc_CAPE
        # should align with the dimension of the main model
        self.dim_pe = dim_pe = cfg.gt.dim_hidden
        self.type = pecfg.get('type', 'onespace')
        assert self.type == 'onespace'
        self.decouple = pecfg.get('decouple', False)
        assert not self.decouple
        self.dim_h = pecfg.get('dim_h', dim_pe)
        self.encoder = EigenEncoder(hiddim=self.dim_h,
                                    outdim=dim_pe,
                                    numlayer=pecfg.layers)

    def forward(self, batch):
        """Eigendecompose the graph matrix, encode it and add it to ``batch.x``.

        Also stores 'U', 'Lambda', 'lambdamask', 'ind2', 'cutoff2' and the
        precomputed sparse index on ``batch``. Returns the same ``batch``.
        """
        N = cfg.dataset.get('max_num_nodes', 37)
        residual = batch.x
        dense_x, node_mask = to_dense_batch(batch.x, batch.batch, max_num_nodes=N)
        dense_e = to_dense_adj(batch.edge_index, batch.batch, batch.edge_attr, N)

        zero_eps = 1e-2
        gap_eps = cfg.posenc_CAPE.get('gap_eps', 5e-2)
        lap_noise = 1e-6
        decomp_noise = 0.  # 1e-6
        shift = 1

        adj = to_dense_adj(batch.edge_index, batch.batch, max_num_nodes=N)
        deg = adj.sum(-1)
        if cfg.posenc_CAPE.get('use_laplacian', True):
            mat = torch.diag_embed(deg) - adj
        else:
            mat = adj
        if cfg.posenc_CAPE.get('normalize_matrix', False):
            # mat <- D^(-1/2) mat D^(-1/2), with degrees clamped to >= 1
            inv_sqrt_deg = torch.clamp_min(deg, 1).rsqrt()
            mat = inv_sqrt_deg.unsqueeze(-1) * mat * inv_sqrt_deg.unsqueeze(-2)

        # Perturb slightly and shift the spectrum before eigh; the shift is
        # removed from the eigenvalues afterwards.
        mat_f = mat.to(torch.float)
        mat_f = mat_f + lap_noise * torch.randn_like(mat_f)
        mat_f = mat_f + shift * torch.eye(mat_f.shape[1], dtype=mat_f.dtype,
                                          device=mat_f.device)
        # dim -1 of the eigenvector tensor indexes eigenvalues, dim -2 nodes
        eigvals, eigvecs = torch.linalg.eigh(mat_f)
        eigvals -= shift
        batch['U'] = eigvecs
        batch['Lambda'] = eigvals

        nonzero = eigvals.abs() > zero_eps
        eigvals += decomp_noise * torch.randn_like(eigvals)
        eigvecs += decomp_noise * torch.randn_like(eigvecs)

        # Zero out entries belonging to padded nodes or near-zero eigenvalues.
        keep = torch.logical_and(node_mask.unsqueeze(-1), nonzero.unsqueeze(1))
        eigvecs.masked_fill_(keep.logical_not(), 0)
        batch["lambdamask"] = nonzero

        # Index triples (graph, i, j) of non-zero eigenvalue pairs whose gap is
        # below `gap_eps`.
        pair_nonzero = torch.logical_and(nonzero.unsqueeze(1), nonzero.unsqueeze(2))
        pair_close = torch.abs(eigvals.unsqueeze(1) - eigvals.unsqueeze(2)) < gap_eps
        batch["ind2"] = torch.stack(
            torch.nonzero(pair_nonzero.logical_and(pair_close), as_tuple=True), dim=0)
        ind2 = batch['ind2']

        # Cosine cutoff weight in [0, 1] for every retained pair.
        gap = torch.abs(eigvals[ind2[0], ind2[1]] - eigvals[ind2[0], ind2[2]])
        batch["cutoff2"] = 0.5 * torch.cos(
            gap.clamp_max(gap_eps) / gap_eps * torch.pi) + 0.5

        batch[f"v2{KEYSEP}v2{KEYSEP}2{KEYSEP}v2{KEYSEP}1{KEYSEP}1{KEYSEP}acd"] = filterind(
            ind2, *spspmm_ind(ind2, 2, ind2, 1, False, 1))

        pos_enc = self.encoder(dense_x, node_mask, dense_e, batch['U'],
                               batch['Lambda'], batch["lambdamask"], batch)

        # Flatten, drop padded node slots and add to the original features.
        pos_enc = pos_enc.reshape(batch.num_graphs * N, self.dim_pe)
        batch.x = pos_enc[batch.M.reshape(-1)] + residual

        return batch



def scatter_cat(U_emb, graph_idx, eigen_dim, batch_size):
    """Group per-eigenspace embeddings by graph and repeat each one ``eigen_dim`` times.

    Args:
        U_emb: Per-eigenspace embeddings; the first axis indexes eigenspaces.
        graph_idx: 1-D tensor giving the graph index of every eigenspace.
        eigen_dim: Dimension of the eigenspaces in ``U_emb``; every eigenspace
            is repeated this many times along the eigenspace axis.
        batch_size: Number of graphs in the batch.

    Returns:
        A tuple ``(batch, mask, counts)`` where ``batch`` is the dense,
        zero-padded tensor with one row per graph, ``mask`` marks the valid
        slots along axis 1, and ``counts = mask.sum(-1)`` is the number of
        valid slots (eigenvectors) per graph.
    """
    # A stable sort keeps eigenspaces of the same graph in their original
    # relative order. The default sort gives no such guarantee, and the slot
    # order matters to callers that treat axis 1 as positional.
    values, indices = torch.sort(graph_idx, stable=True)
    # batch: (batch_size, max_num_eigenspace_per_graph, num_nodes, num_channels)
    # mask: (batch_size, max_num_eigenspace_per_graph)
    batch, mask = to_dense_batch(U_emb[indices], values, batch_size=batch_size)

    # batch: (batch_size, max_num_eigenspace_per_graph * eigenspace_dim,
    #         num_nodes, num_channels)
    # mask: (batch_size, max_num_eigenspace_per_graph * eigenspace_dim)
    batch = batch.repeat_interleave(eigen_dim, dim=1)
    mask = mask.repeat_interleave(eigen_dim, dim=1)

    return batch, mask, mask.sum(-1)  # number of eigenvectors of each graph
    

class FixedDimEncoder(nn.Module):
    """Clifford-algebra encoder for eigenspaces of one fixed dimension.

    The module stacks ``num_layers`` geometric-product layers with ``MVSiLU``
    activations in between, on the algebra ``CliffordAlgebra([1.0] * dim)``.
    """

    def __init__(self, dim: int, 
                 num_channels: int, 
                 num_layers: int,
                 gp_layer: Literal['ew', 'fc'] = 'ew',
                 use_eigenvalue: bool = False,
                 eigval_encoder = None,
                 restrict_grade = None,
                 share_weights_from = None,
                 residual: bool = True):
        """Build the layer stack.

        Args:
            dim: Eigenspace dimension (number of algebra generators).
            num_channels: Channel width of every layer.
            num_layers: Number of geometric-product layers.
            gp_layer: ``'ew'`` selects ``ElementWiseLayer``, ``'fc'`` selects
                ``FullyConnectedLayer``.
            use_eigenvalue: Forwarded to every layer after the first.
            eigval_encoder: Forwarded to every layer after the first.
            restrict_grade: Forwarded to every layer.
            share_weights_from: Optional ``FixedDimEncoder`` whose layers at
                the same positions are handed to the new layers.
            residual: Add the layer input to the layer output in ``forward``.

        Raises:
            ValueError: If ``gp_layer`` is neither ``'ew'`` nor ``'fc'``.
        """
        super().__init__()
        # Validate first: an unknown value used to leave `layer_class` unbound
        # and fail later with an unrelated UnboundLocalError.
        if gp_layer not in ('ew', 'fc'):
            raise ValueError(
                f"Unknown gp_layer {gp_layer!r}; expected 'ew' or 'fc'.")

        self.dim = dim
        self.algebra = CliffordAlgebra([1.0] * dim)
        self.num_channels = num_channels
        self.num_layers = num_layers
        self.gp_layer = gp_layer
        self.use_eigenvalue = use_eigenvalue
        self.residual = residual
        self.restrict_grade = restrict_grade

        self.module_list = nn.ModuleList()

        layer_class = ElementWiseLayer if gp_layer == 'ew' else FullyConnectedLayer

        # Layout of module_list: [layer, act, layer, act, layer, ...];
        # product layers sit at the even indices.
        self.module_list.append(layer_class(self.algebra, 1, num_channels, 
                                            restrict_grade=restrict_grade,
                                            share_weights_from=share_weights_from.module_list[0]
                                            if share_weights_from is not None else None))

        for i in range(num_layers-1):
            self.module_list.append(MVSiLU(self.algebra, num_channels))
            self.module_list.append(layer_class(self.algebra, num_channels, num_channels,
                                                use_eigenvalue=use_eigenvalue,
                                                eigval_encoder=eigval_encoder,
                                                restrict_grade=restrict_grade,
                                                share_weights_from=share_weights_from.module_list[2*i+2]
                                                if share_weights_from is not None else None))

    def to(self, device):
        """Move the module, its algebra and every product layer to ``device``."""
        super().to(device)
        self.algebra.to(device)
        # Product layers live at the even indices of module_list.
        for i in range(self.num_layers):
            self.module_list[2*i].to(device)
        return self

    def forward(self, input_):
        """Encode eigenvectors.

        Args:
            input_: Tuple ``(x, eigenvalue)``. The last axis of ``x`` holds the
                eigenspace coordinates; ``eigenvalue`` is passed to every
                product layer.

        Returns:
            Tensor of shape ``(*x.shape[:-1], num_channels)`` holding the
            component at index 0 of the last axis of the final multivector.
        """
        x, eigenvalue = input_
        # One input channel per row; coordinates are embedded as grade 1.
        y = x.reshape(-1, 1, x.shape[-1])
        y = self.algebra.embed_grade(y, 1)
        self.algebra.to(y.device)
        for i in range(self.num_layers-1):
            out = self.module_list[2*i+1](self.module_list[2*i](y, eigenvalue))
            y = out + y if self.residual else out
        out = self.module_list[-1](y, eigenvalue)
        y = out + y if self.residual else out
        return y[:, :, 0].reshape(*x.shape[:-1], self.num_channels)


class ScalarEncoder(nn.Module):
    """
    Sign-invariant encoder for one-dimensional eigenspaces (based on SignNet):
    the same network is applied to ``x`` and ``-x`` and the results are added.
    """
    def __init__(self, 
                 num_channels: int, 
                 num_layers: int,
                 use_eigenvalue: bool = False,
                 eigval_encoder = None,
                 residual: bool = True):
        super().__init__()
        self.num_channels = num_channels
        self.num_layers = num_layers
        self.use_eigenvalue = use_eigenvalue
        self.residual = residual

        if self.use_eigenvalue:
            assert eigval_encoder is not None
            self.eigval_encoder = eigval_encoder

        # Layout: [linear, act, linear, act, ..., linear]
        stack = [NormedLinear(1, num_channels)]
        for _ in range(num_layers - 1):
            stack.append(nn.SiLU())
            stack.append(NormedLinear(num_channels, num_channels))
        self.module_list = nn.ModuleList(stack)

    def forward(self, input_):
        """Return ``forward_(x) + forward_(-x)`` for ``input_ = (x, eigenvalue)``."""
        x, eigenvalue = input_
        pos_branch = self.forward_(input_)
        neg_branch = self.forward_((-x, eigenvalue))
        return pos_branch + neg_branch
    
    def forward_(self, input_):
        """Run the layer stack once on ``input_ = (x, eigenvalue)``.

        Every scalar in ``x`` is treated as one sample. With ``use_eigenvalue``
        each block output is multiplied by the encoded eigenvalue; with
        ``residual`` the block input is added afterwards.
        """
        x, eigenvalue = input_
        hidden = x.reshape(-1, 1)

        if self.use_eigenvalue:
            assert eigenvalue is not None
            assert hidden.shape[0] % eigenvalue.numel() == 0
            # Repeat every eigenvalue so it lines up with the flattened samples.
            copies = hidden.shape[0] // eigenvalue.numel()
            eigenvalue = eigenvalue.reshape(-1)
            eigenvalue = eigenvalue.repeat_interleave(copies).unsqueeze(-1)

        def combine(update, skip):
            # Optional eigenvalue gating first, optional residual second.
            if self.use_eigenvalue:
                update = update * self.eigval_encoder(eigenvalue)
            if self.residual:
                update = update + skip
            return update

        for block in range(self.num_layers - 1):
            linear = self.module_list[2 * block]
            activation = self.module_list[2 * block + 1]
            hidden = combine(activation(linear(hidden)), hidden)

        hidden = combine(self.module_list[-1](hidden), hidden)

        return hidden.reshape(*x.shape[:-1], self.num_channels)


class VMean(nn.Module):
    """Subtract the mean over the last axis from both inputs.

    ``hiddim`` is accepted but unused; an elementwise affine transform is not
    supported.
    """

    def __init__(self, hiddim, elementwise_affine: bool = False) -> None:
        super().__init__()
        assert not elementwise_affine

    def forward(self, v1: MaskedTensor, v2: SparseTensor):
        '''
        v (*, m, d)

        Returns the centred ``(v1, v2)`` pair.
        '''
        def center(values):
            return values - torch.mean(values, dim=-1, keepdim=True)

        centered_v1 = v1.tuplewiseapply(center)
        centered_v2 = v2.tuplewiseapply(center)
        return centered_v1, centered_v2


class VNorm(nn.Module):
    """Rescale ``v1`` and ``v2`` to (approximately) unit norm.

    ``hiddim`` is accepted but unused; an elementwise affine transform is not
    supported.
    """

    def __init__(self, hiddim, elementwise_affine: bool = False) -> None:
        super().__init__()
        assert not elementwise_affine

    def forward(self, v1: MaskedTensor, v2: SparseTensor):
        '''
        v1 (b, m, n, d)
        v2 (b, m, m, n, d)

        Returns the rescaled ``(v1, v2)`` pair.
        '''
        def unit_along_dim3(values):
            return F.normalize(values, dim=-3, eps=1e-1)

        normed_v1 = v1.tuplewiseapply(unit_along_dim3)

        # Inverse norm per leading index: squared entries summed over sparse
        # dims 1 and 2, offset by 1e-1 for numerical stability.
        squared_sum = v2.tuplewiseapply(torch.square).sum([1, 2])
        inv_norm = torch.rsqrt(squared_sum + 1e-1)

        def rescale(values):
            # Look up the inverse norm of the leading index of every entry.
            return values * inv_norm[v2.indices[0]]

        normed_v2 = v2.tuplewiseapply(rescale)
        return normed_v1, normed_v2


class TensorProduct_2(nn.Module):
    """Mix a scalar feature ``s``, a masked tensor ``v1`` and a sparse tensor ``v2``.

    Each of the three outputs is a weighted sum of products between the
    original inputs and linearly transformed, normalized copies of them. The
    eight rows of ``self.coeff`` are the per-channel weights of the eight
    product terms.
    """
    def __init__(self, hiddim, res: bool = True) -> None:
        """
        Args:
            hiddim: Channel width of all linear layers and of ``coeff``.
            res: If true, add the inputs to the outputs (residual connection).
        """
        super().__init__()
        # Sparse/dense message-passing operators from pygho; the second one
        # multiplies its three operands elementwise.
        self.v2v1_v1 = OpMessagePassing("SD", "sum", "v2", 1, "v1", 0, "v1", 0)
        self.v2v2_v2 = OpMessagePassing("SS", "sum", "v2", 2, "v2", 1, "v2", 1,
                                        message_func=lambda a, b, c, _: a * b * c)
        # self.lin0 = nn.Linear(hiddim, hiddim, bias=True)
        self.lin0 = nn.Sequential(nn.Linear(hiddim, hiddim), nn.LayerNorm(hiddim, elementwise_affine=False),
                                  nn.SiLU(inplace=True))
        self.lin1 = nn.Linear(hiddim, hiddim, bias=False)
        self.lin2 = nn.Linear(hiddim, hiddim, bias=False)
        # One weight row per product term used in forward (rows 0..7).
        self.coeff = nn.Parameter(torch.randn((8, hiddim)))
        self.linout0 = nn.Linear(hiddim, hiddim, bias=True)
        self.linout1 = nn.Linear(hiddim, hiddim, bias=False)
        self.linout2 = nn.Linear(hiddim, hiddim, bias=False)
        # self.linE = nn.Linear(hiddim, hiddim, bias=True)
        '''
        with torch.no_grad():
            self.lin0_.weight.fill_(0.)
            self.lin1_.weight.fill_(0.)
            self.lin2_.weight.fill_(0.)
        '''
        # self.norm = nn.LayerNorm(hiddim, elementwise_affine=False)
        self.vmean = VMean(hiddim)
        self.vnorm = VNorm(hiddim)
        self.res = res

    def forward(self, s: Tensor, v1: MaskedTensor, v2: SparseTensor, v2cosmask: SparseTensor, datadict: dict) -> Tuple[
        Tensor, MaskedTensor, SparseTensor, int]:
        '''
        Shapes as stated by the original author:

        n is not sparse or mask
        s, s_ (b, n, d)
        v1, v1_ (b, m, n, d)
        v2, v2_ sparse tensor of shape (b, m, m, n, d)

        v2cosmask shares its index pattern with v2 (its values are multiplied
        with entries gathered at v2.indices); datadict is forwarded to the
        message-passing operators.

        Returns (rets, retv1, retv2, nE), where nE is currently the constant 0.
        '''
        # print("prod in")
        # print(torch.std(s, dim=1).mean(), torch.std(v1.fill_masked(0), dim=1).mean(), torch.std(v2.values, dim=0).mean())
        # print(torch.std(s_, dim=1).mean(), torch.std(v1_.fill_masked(0), dim=1).mean(), torch.std(v2_.values, dim=0).mean())
        coeff = self.coeff  # 2*torch.sigmoid(self.coeff)-1 # tanh cause compile error
        # coeff = coeff.clone() # self.coeff #self.coeff#
        # Keep references to the raw inputs for the products and the residual.
        rs, rv1, rv2 = s, v1, v2
        # s = self.norm(s)
        # The centred/normalized v1, v2 are only used to build the transformed
        # copies s_, v1_, v2_; the raw inputs are restored right after.
        v1, v2 = self.vnorm(*self.vmean(v1, v2))
        s_, v1_, v2_ = self.lin0(s), v1.tuplewiseapply(self.lin1), v2.tuplewiseapply(self.lin2)
        s, v1, v2 = rs, rv1, rv2

        # --- scalar output ---
        rets = coeff[0] * (s * s_)  # 0, 0 -> 0
        rets = rets + coeff[1] * torch.sum(v1.fill_masked(0.) * v1_.fill_masked(0.), dim=1)  # 1, 1 -> 0
        rets = rets + coeff[2] * v2.tuplewiseapply(lambda v: v * v2_.values).sum([1, 2])  # 2, 2 -> 0
        rets = self.linout0(rets)

        # --- v1 output ---
        retv1 = self.v2v1_v1.forward(v2_, v1, v1, datadict).tuplewiseapply(lambda x: x * coeff[3])  # 2, 1 -> 1
        # 1, 0 -> 1
        retv1 = retv1.add(v1.tuplewiseapply(lambda x: x * s_.unsqueeze(1) * coeff[4]), samesparse=True)
        retv1 = retv1.tuplewiseapply(self.linout1)

        # --- v2 output ---
        retv2 = self.v2v2_v2.forward(v2, v2_, v2cosmask, datadict).tuplewiseapply(lambda x: x * coeff[5])  # 2, 2 -> 2
        # s_ is gathered per sparse entry through its leading index.
        retv2 = retv2.add(v2.tuplewiseapply(lambda v: v * s_[v2.indices[0]] * coeff[6]), True)  # 0, 2 -> 2
        # 1, 1 -> 2
        # Product of v1 and v1_ gathered at the index pairs of v2, weighted by
        # the values of v2cosmask.
        ind = v2.indices
        v112 = v1.fill_masked(0)[ind[0], ind[1]] * v1_.fill_masked(0)[ind[0], ind[2]] * v2cosmask.values * coeff[7]
        retv2 = retv2.tuplewiseapply(lambda v: v112 + v)
        retv2 = retv2.tuplewiseapply(self.linout2)
        # Placeholder for an edge update that is currently disabled.
        nE = 0  # self.linE(torch.einsum("bmid,bmjd->bijd", v1.fill_masked(0), v1_.fill_masked(0)))
        if self.res:
            rets = rets + rs
            retv1 = retv1.add(rv1, True)
            retv2 = retv2.add(rv2, True)
        return rets, retv1, retv2, nE


class MPNN2(nn.Module):
    '''
    Propagation layer that pushes scalar (s), first-order (v1) and
    second-order (v2) features through the dense edge tensor E.

    Each of the three feature types goes through its own linear map before
    being contracted with E. When `res` is true the untouched inputs are
    added back onto the corresponding outputs.
    '''

    def __init__(self, hiddim, res: bool = True) -> None:
        super().__init__()
        self.res = res

        # Scalar branch: linear -> parameter-free layer norm -> SiLU.
        scalar_layers = [
            nn.Linear(hiddim, hiddim),
            nn.LayerNorm(hiddim, elementwise_affine=False),
            nn.SiLU(inplace=True),
        ]
        self.linIns = nn.Sequential(*scalar_layers)

        # Bias-free projections for the v1 and v2 branches.
        self.linInv1 = nn.Linear(hiddim, hiddim, bias=False)
        self.linInv2 = nn.Linear(hiddim, hiddim, bias=False)

        # Pre-processing applied jointly to (v1, v2) at the start of forward.
        self.vmean = VMean(hiddim)
        self.vnorm = VNorm(hiddim)

    def forward(self, E: torch.Tensor, s: torch.Tensor, v1: MaskedTensor, v2: SparseTensor):
        '''
        Shapes as given by the original author:
        s: (b, N, d)
        E: (b, N, N, d)
        v1: (b, m, n, d)
        v2: (b, m, m, n, d)

        Returns the tuple (rets, retv1, retv2) of updated s, v1 and v2.
        '''
        # Remember the raw inputs for the optional residual connection.
        skip_s, skip_v1, skip_v2 = s, v1, v2

        # Only the vector features are pre-processed; s is used as passed in.
        vmean_out = self.vmean(v1, v2)
        v1, v2 = self.vnorm(*vmean_out)

        # Scalar branch: sum over the second node axis of E.
        scalar_msg = self.linIns(s)
        rets = torch.einsum("bjd,bijd->bid", scalar_msg, E)

        # v1 branch: same contraction, done on the zero-filled dense tensor
        # and then re-wrapped with the mask of the normalised v1.
        v1_dense = v1.tuplewiseapply(self.linInv1).fill_masked(0)
        v1_agg = torch.einsum("bijd,bmjd->bmid", E, v1_dense)
        retv1 = MaskedTensor(v1_agg, v1.mask, is_filled=True)

        def _propagate_v2(_values):
            # The argument handed over by tuplewiseapply is not used; the
            # values are read from v2.values directly.
            # NOTE(review): after transpose(1, 3) this matmul contracts over
            # E's first node axis, whereas the two einsums above contract
            # over the second one. The results only agree if E is symmetric
            # in its node axes - confirm that this is intended.
            edges_per_entry = E.transpose(1, 3)[v2.indices[0]]
            projected = self.linInv2(v2.values).transpose(1, 2).unsqueeze(-1)
            mixed = edges_per_entry @ projected
            return mixed.squeeze(-1).transpose(1, 2)

        retv2 = v2.tuplewiseapply(_propagate_v2)

        if not self.res:
            return rets, retv1, retv2

        # Residual connection onto the un-normalised inputs.
        rets = rets + skip_s
        retv1 = retv1.add(skip_v1, True)
        retv2 = retv2.add(skip_v2, True)
        return rets, retv1, retv2


# Default keyword arguments used by EigenEncoder for its mlpkwarg1/2/3
# parameters, which are forwarded to the MLP and Set2Set sub-modules.
default_mlp_dict = {"norm": "ln", "activation": "silu", "dropout": 0.0}
# Second name bound to the very same dict object (not a copy).
mlpdict = default_mlp_dict


class EigenEncoder(nn.Module):
    '''
    Node encoder built from eigenvalues (Lambda) and eigenvectors (U).

    The eigenvalues are embedded and refined by three Set2Set blocks, then
    combined with U into a first-order MaskedTensor (v1) and a second-order
    SparseTensor (v2). `numlayer` rounds of TensorProduct_2 followed by MPNN2
    update the node features together with v1 and v2; only the node features
    are returned.
    '''

    def __init__(self, hiddim, outdim, numlayer, mlpkwarg1=default_mlp_dict, mlpkwarg2=default_mlp_dict,
                 mlpkwarg3=default_mlp_dict) -> None:
        super().__init__()
        self.hiddim = hiddim
        self.outdim = outdim
        self.numlayer = numlayer

        double_dim = 2 * hiddim
        quad_dim = 4 * hiddim

        # One (tensor product, message passing) pair per layer.
        self.mpnns = nn.ModuleList([MPNN2(hiddim, res=True) for _ in range(numlayer)])
        self.prods = nn.ModuleList([TensorProduct_2(hiddim, res=True) for _ in range(numlayer)])

        # Eigenvalue embedding and the three set encoders applied on top of it.
        self.linlambda0 = MLP(1, double_dim, double_dim, 1, True, **mlpkwarg1)
        self.lambdasetencoder = Set2Set(double_dim, aggr="sum", **mlpkwarg2)
        self.lambdasetencoder2 = Set2Set(double_dim, aggr="sum", **mlpkwarg2)
        self.lambdasetencoder3 = Set2Set(double_dim, aggr="sum", **mlpkwarg2)

        # Heads producing per-eigenvalue and per-eigenvalue-pair coefficients.
        self.linlambda1 = MLP(double_dim, double_dim, hiddim, 2, tailact=False, **mlpkwarg3)
        self.linlambda2 = MLP(quad_dim, quad_dim, hiddim, 2, tailact=False, **mlpkwarg3)

        # Project to the output width only when it differs from the hidden one.
        if self.hiddim != self.outdim:
            self.outmlp = nn.Linear(self.hiddim, self.outdim)
        else:
            self.outmlp = nn.Identity()

    def forward(self, X: Tensor, nodemask: Tensor, E: Tensor, U: Tensor, Lambda: Tensor, lambdamask: Tensor,
                datadict: dict):
        '''
        Shapes as given by the original author unless noted:
        X: (b, N, d)
        nodemask: listed as (b, N, d) originally; the sum over dim 1 followed
            by reshape(-1, 1, 1) suggests (b, N) - TODO confirm
        E: (b, N, N, d)
        U: (b, N, M)
        Lambda: (b, M)
        lambdamask: (b, M)
        datadict: must provide "ind2" (indices of the eigenvalue pairs kept
            in the sparse tensors) and "cutoff2" (one weight per such pair);
            it is also handed on to every TensorProduct_2 layer.

        Returns the node features after the last layer, passed through outmlp.
        '''
        # 1 / sqrt(mask sum along dim 1), reshaped once and reused by all
        # three set encoders.
        mask_sum = nodemask.to(torch.float).sum(dim=1)
        inv_sqrt_nodes = torch.rsqrt(mask_sum).reshape(-1, 1, 1)

        # tanh(Lambda) serves as a multiplicative gate on every eigenvalue
        # feature computed below.
        lambdasmooth = torch.tanh(Lambda).unsqueeze(-1)
        lambdaset = self.linlambda0(Lambda.unsqueeze(-1)) * lambdasmooth
        set_encoders = (self.lambdasetencoder, self.lambdasetencoder2, self.lambdasetencoder3)
        for set_encoder in set_encoders:
            lambdaset = set_encoder.forward(lambdaset, lambdamask, lambdasmooth, inv_sqrt_nodes)

        # Pairwise eigenvalue features on the sparse index set ind2.
        ind2 = datadict["ind2"]
        batch_idx, first_idx, second_idx = ind2[0], ind2[1], ind2[2]
        cutoff2 = (datadict["cutoff2"].unsqueeze(-1)
                   * lambdasmooth[batch_idx, first_idx]
                   * lambdasmooth[batch_idx, second_idx])
        pair_features = torch.concat(
            (lambdaset[batch_idx, first_idx], lambdaset[batch_idx, second_idx]), dim=-1)
        lambda2 = self.linlambda2(pair_features) * cutoff2
        lambda1 = self.linlambda1(lambdaset) * lambdasmooth

        num_graphs, num_nodes, num_eig = U.shape

        # First-order features: every eigenvector scaled by its coefficients.
        v1 = MaskedTensor(U.transpose(1, 2).unsqueeze(-1) * lambda1.unsqueeze(2), lambdamask)

        # Second-order features: element-wise product of two eigenvectors,
        # scaled by the pair coefficients.
        eigvec_pair = U[batch_idx, :, first_idx] * U[batch_idx, :, second_idx]
        v2value = lambda2.unsqueeze(-2) * eigvec_pair.unsqueeze(-1)
        v2 = SparseTensor(ind2, v2value,
                          shape=(num_graphs, num_eig, num_eig, num_nodes, self.hiddim),
                          is_coalesced=True)
        v2cosmask = SparseTensor(ind2, cutoff2.unsqueeze(-1),
                                 shape=(num_graphs, num_eig, num_eig, 1, 1),
                                 is_coalesced=True)

        s = X
        for prod_layer, mpnn_layer in zip(self.prods, self.mpnns):
            # The fourth value returned by the product layer is not used.
            s, v1, v2, _ = prod_layer(s, v1, v2, v2cosmask, datadict)
            s, v1, v2 = mpnn_layer.forward(E, s, v1, v2)
        return self.outmlp(s)

