from typing import Any

import torch
from torch import Tensor
import warnings
import math

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.nn.conv.gcn_conv import gcn_norm

import torch.nn as nn
from torch.nn import Parameter, ReLU, Module


import torch
import torch.nn as nn
from torch.nn import Module, Linear, ModuleList, Sequential, LeakyReLU
from torch_geometric.data import Data
import torch_geometric.nn as pyg_nn
from torch_geometric.nn import global_add_pool, global_max_pool, global_mean_pool
from typing import Optional
from collections import OrderedDict
from torch import tanh
from torch_geometric.nn import ChebConv, BatchNorm#, BatchNorm1d
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd.functional import jacobian

from torch import Tensor
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import OptTensor
from torch_geometric.utils import get_laplacian
from torch.nn.utils.parametrize import register_parametrization
class AntiSymmetric(Module):
    r"""
    Anti-symmetric parametrization with optional diagonal damping.

    A square weight matrix :math:`\mathbf{W}` is mapped to
    :math:`\mathbf{U} - \mathbf{U}^T - g\,\mathbf{I}`, where
    :math:`\mathbf{U}` is the strictly upper-triangular part of
    :math:`\mathbf{W}` and :math:`g` is ``dissipative_force``. With
    :math:`g = 0` the result is exactly anti-symmetric.

    Args:
        dissipative_force: scalar :math:`g` subtracted on the diagonal.
    """
    def __init__(self, dissipative_force=0.0):
        super().__init__()
        self.g = dissipative_force

    def forward(self, W: Tensor) -> Tensor:
        """Return the parametrized matrix built from the upper triangle of ``W``.

        ``W`` must be square: the strictly upper-triangular part is
        subtracted from its own transpose.
        """
        upper = W.triu(diagonal=1)
        # Build the identity directly on W's device and dtype so the result
        # keeps W's precision and no host-to-device copy is needed.
        identity = torch.eye(W.shape[0], dtype=W.dtype, device=W.device)
        return upper - upper.T - self.g * identity

    def right_inverse(self, W: Tensor) -> Tensor:
        """Return the strictly upper-triangular part of ``W``.

        Only these entries are read by :meth:`forward`, so they are what
        gets stored as the underlying unconstrained tensor.
        """
        return W.triu(diagonal=1)


class NonDissip_ChebConv_old(MessagePassing):
    r"""Chebyshev graph convolution with anti-symmetric weight matrices.

    The layer computes

    .. math::
        \mathbf{X}' = \mathbf{X} + \sum_{k=1}^{K-1}
        T_k(\hat{\mathbf{L}})\,\mathbf{X}\,\mathbf{W}_k^T + \mathbf{b}

    where :math:`\hat{\mathbf{L}}` is the scaled Laplacian built by
    :meth:`__norm__`, :math:`T_k` follows the Chebyshev recurrence, and each
    :math:`\mathbf{W}_k` is constrained through :class:`AntiSymmetric`.
    The input itself takes the place of the order-0 term, so only ``K - 1``
    linear maps are created.

    Because the input is added to the output and :class:`AntiSymmetric`
    subtracts a matrix from its transpose, ``in_channels`` must equal
    ``out_channels``.

    Args:
        in_channels: size of each input node feature.
        out_channels: size of each output node feature.
        K: number of Chebyshev orders (including the identity term).
        bias: if ``False``, no additive bias is learned.
        **kwargs: forwarded to :class:`MessagePassing` (``aggr`` defaults
            to ``'add'``).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        K: int,
        bias: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        assert K > 0

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalization = 'sym'
        self.lins = torch.nn.ModuleList([
            Linear(in_channels, out_channels, bias=False,
                   weight_initializer='glorot') for _ in range(K - 1)
        ])
        for lin in self.lins:
            register_parametrization(lin, 'weight', AntiSymmetric())

        if bias:
            # Uninitialised here; zeroed by reset_parameters() below.
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Reset the linear maps and zero the bias."""
        super().reset_parameters()
        # NOTE(review): lins[0] is skipped here, unlike the forward pass which
        # uses every entry. Kept as-is so that random-number consumption at
        # construction is unchanged; confirm whether this is intended, and
        # whether reset_parameters() on a parametrized weight actually reaches
        # the underlying tensor.
        for lin in self.lins[1:]:
            lin.reset_parameters()
        zeros(self.bias)

    def __norm__(
        self,
        edge_index: Tensor,
        num_nodes: Optional[int],
        edge_weight: OptTensor,
        normalization: Optional[str],
        lambda_max: OptTensor = None,
        dtype: Optional[torch.dtype] = None,
        batch: OptTensor = None,
    ):
        """Build the scaled Laplacian ``2 * L / lambda_max - I`` in COO form.

        Args:
            edge_index: graph connectivity passed to ``get_laplacian``.
            num_nodes: number of nodes in the graph.
            edge_weight: optional edge weights.
            normalization: Laplacian normalization scheme.
            lambda_max: largest Laplacian eigenvalue (scalar, tensor, or
                ``None``). When ``None`` it is replaced by twice the largest
                Laplacian entry.
            dtype: dtype used for the Laplacian weights.
            batch: optional node-to-graph assignment, used to look up a
                per-graph ``lambda_max`` for every edge.

        Returns:
            The Laplacian's ``edge_index`` and the rescaled edge weights.
        """
        edge_index, edge_weight = get_laplacian(edge_index, edge_weight,
                                                normalization, dtype,
                                                num_nodes)
        assert edge_weight is not None

        if lambda_max is None:
            lambda_max = 2.0 * edge_weight.max()
        elif not isinstance(lambda_max, Tensor):
            lambda_max = torch.tensor(lambda_max, dtype=dtype,
                                      device=edge_index.device)
        assert lambda_max is not None

        if batch is not None and lambda_max.numel() > 1:
            # One lambda_max per graph: pick the value of each edge's graph.
            lambda_max = lambda_max[batch[edge_index[0]]]

        edge_weight = (2.0 * edge_weight) / lambda_max
        edge_weight.masked_fill_(edge_weight == float('inf'), 0)

        # Subtract the identity: only self-loop entries are affected.
        loop_mask = edge_index[0] == edge_index[1]
        edge_weight[loop_mask] -= 1

        return edge_index, edge_weight

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_weight: OptTensor = None,
        batch: OptTensor = None,
        lambda_max: OptTensor = None,
        eig: bool = False
    ) -> Tensor:
        """Apply the convolution.

        Args:
            x: node feature matrix; its first two dimensions are cached for
                :meth:`conv_jac`.
            edge_index: graph connectivity.
            edge_weight: optional edge weights.
            batch: optional node-to-graph assignment.
            lambda_max: optional largest Laplacian eigenvalue(s).
            eig: if ``True``, additionally compute the eigenvalues of the
                layer's Jacobian with respect to ``x`` and show them with
                matplotlib (diagnostic only; blocks on ``plt.show()``).

        Returns:
            The updated node features.
        """
        edge_index, norm = self.__norm__(
            edge_index,
            x.size(self.node_dim),
            edge_weight,
            self.normalization,
            lambda_max,
            dtype=x.dtype,
            batch=batch,
        )

        # Cached so that conv_jac() can replay the layer on a flat input.
        self.edge_index, self.norm = edge_index, norm
        self.n1, self.n2 = x.shape[0], x.shape[1]

        out = self._apply_filter(x, edge_index, norm)

        if eig:
            self._plot_jacobian_spectrum(x)

        return out

    def conv_jac(self, x: Tensor) -> Tensor:
        """Flat-in / flat-out version of the layer used for the Jacobian.

        Relies on the graph and feature shape cached by the most recent
        :meth:`forward` call.
        """
        x = x.reshape(self.n1, self.n2)
        return self._apply_filter(x, self.edge_index, self.norm).reshape(-1)

    def _apply_filter(self, x: Tensor, edge_index: Tensor,
                      norm: Tensor) -> Tensor:
        """Evaluate the Chebyshev filter shared by forward() and conv_jac()."""
        Tx_0 = x
        Tx_1 = x  # Dummy.
        out = Tx_0

        # propagate_type: (x: Tensor, norm: Tensor)
        if len(self.lins) >= 1:
            Tx_1 = self.propagate(edge_index, x=x, norm=norm)
            out = out + self.lins[0](Tx_1)

        for lin in self.lins[1:]:
            # Chebyshev recurrence: T_k = 2 * L_hat * T_{k-1} - T_{k-2}.
            Tx_2 = self.propagate(edge_index, x=Tx_1, norm=norm)
            Tx_2 = 2. * Tx_2 - Tx_0
            out = out + lin(Tx_2)
            Tx_0, Tx_1 = Tx_1, Tx_2

        if self.bias is not None:
            out = out + self.bias

        return out

    def _plot_jacobian_spectrum(self, x: Tensor) -> None:
        """Plot the eigenvalues of d(layer)/dx at ``x`` (diagnostic only)."""
        # No graph is needed for the Jacobian itself: it is only inspected.
        # reshape() also copes with non-contiguous inputs, unlike view().
        J = jacobian(self.conv_jac, x.reshape(-1))

        print('Computing Eigenvalues')

        eigs = torch.linalg.eigvals(J).detach().cpu().numpy()
        eigs_r = eigs.real.tolist()
        eigs_im = eigs.imag.tolist()

        # Left: histogram of real/imaginary parts. Right: eigenvalues in the
        # complex plane together with the unit circle.
        fig, axs = plt.subplots(1, 2)
        axs[0].hist(eigs_r, bins=100, label='Real', alpha=0.5)
        axs[0].hist(eigs_im, bins=100, label='Imaginary', alpha=0.5)
        axs[0].set_xlabel('Eigenvalue')
        axs[0].set_ylabel('Frequency')
        axs[0].set_title(f'Distribution of Eigenvalues')
        axs[0].legend()
        theta = np.linspace(0, 2*np.pi, 1000)
        axs[1].plot(np.cos(theta), np.sin(theta), linewidth=1, color='k')
        axs[1].scatter(eigs_r, eigs_im, marker='x', label='Eigs, After Training', linewidth=2, rasterized=True)
        axs[1].set_xlabel('Real')
        axs[1].set_ylabel('Imaginary')
        axs[1].set_title(f'Eigenvalues')
        plt.show()

    def message(self, x_j: Tensor, norm: Tensor) -> Tensor:
        """Scale each neighbour feature by its Laplacian weight."""
        return norm.view(-1, 1) * x_j

    def __repr__(self) -> str:
        # Only K - 1 linear maps are stored (the input is the order-0 term),
        # so add one to report the K that was passed to the constructor.
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, K={len(self.lins) + 1}, '
                f'normalization={self.normalization})')

class Euler_ChebConv(MessagePassing):
    r"""Forward-Euler step of a Chebyshev graph convolution.

    The layer computes

    .. math::
        \mathbf{X}' = \mathbf{X} + \epsilon \Big( \sum_{k=0}^{K-1}
        T_k(\hat{\mathbf{L}})\,\mathbf{X}\,\mathbf{W}_k^T + \mathbf{b} \Big)

    where :math:`\epsilon` is ``step_size``, :math:`\hat{\mathbf{L}}` is the
    scaled Laplacian built by :meth:`__norm__`, and every
    :math:`\mathbf{W}_k` is constrained through :class:`AntiSymmetric` with
    diagonal damping ``dissipation_force``.

    Because the input is added to the output and :class:`AntiSymmetric`
    subtracts a matrix from its transpose, ``in_channels`` must equal
    ``out_channels``.

    Args:
        in_channels: size of each input node feature.
        out_channels: size of each output node feature.
        K: number of Chebyshev orders (one linear map per order).
        step_size: Euler step size multiplying the update.
        dissipation_force: damping passed to :class:`AntiSymmetric`.
        bias: if ``False``, no additive bias is learned.
        **kwargs: forwarded to :class:`MessagePassing` (``aggr`` defaults
            to ``'add'``).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        K: int,
        step_size: float,
        # step_size: float = 0.2,
        dissipation_force: float = 0.05,
        bias: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        assert K > 0

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalization = 'sym'
        self.e = step_size
        self.g = dissipation_force
        self.lins = torch.nn.ModuleList()
        for _ in range(K):
            lin = Linear(in_channels, out_channels, bias=False,
                         weight_initializer='glorot')
            self.lins.append(lin)
            register_parametrization(
                lin, 'weight', AntiSymmetric(dissipative_force=self.g))

        if bias:
            # Uninitialised here; zeroed by reset_parameters() below.
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Reset the linear maps and zero the bias."""
        super().reset_parameters()
        # NOTE(review): lins[0] is skipped here although forward() uses it.
        # Kept as-is so that random-number consumption at construction is
        # unchanged; confirm whether this is intended, and whether
        # reset_parameters() on a parametrized weight actually reaches the
        # underlying tensor.
        for lin in self.lins[1:]:
            lin.reset_parameters()
        zeros(self.bias)

    def __norm__(
        self,
        edge_index: Tensor,
        num_nodes: Optional[int],
        edge_weight: OptTensor,
        normalization: Optional[str],
        lambda_max: OptTensor = None,
        dtype: Optional[torch.dtype] = None,
        batch: OptTensor = None,
    ):
        """Build the scaled Laplacian ``2 * L / lambda_max - I`` in COO form.

        Args:
            edge_index: graph connectivity passed to ``get_laplacian``.
            num_nodes: number of nodes in the graph.
            edge_weight: optional edge weights.
            normalization: Laplacian normalization scheme.
            lambda_max: largest Laplacian eigenvalue (scalar, tensor, or
                ``None``). When ``None`` it is replaced by twice the largest
                Laplacian entry.
            dtype: dtype used for the Laplacian weights.
            batch: optional node-to-graph assignment, used to look up a
                per-graph ``lambda_max`` for every edge.

        Returns:
            The Laplacian's ``edge_index`` and the rescaled edge weights.
        """
        edge_index, edge_weight = get_laplacian(edge_index, edge_weight,
                                                normalization, dtype,
                                                num_nodes)
        assert edge_weight is not None

        if lambda_max is None:
            lambda_max = 2.0 * edge_weight.max()
        elif not isinstance(lambda_max, Tensor):
            lambda_max = torch.tensor(lambda_max, dtype=dtype,
                                      device=edge_index.device)
        assert lambda_max is not None

        if batch is not None and lambda_max.numel() > 1:
            # One lambda_max per graph: pick the value of each edge's graph.
            lambda_max = lambda_max[batch[edge_index[0]]]

        edge_weight = (2.0 * edge_weight) / lambda_max
        edge_weight.masked_fill_(edge_weight == float('inf'), 0)

        # Subtract the identity: only self-loop entries are affected.
        loop_mask = edge_index[0] == edge_index[1]
        edge_weight[loop_mask] -= 1

        return edge_index, edge_weight

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_weight: OptTensor = None,
        batch: OptTensor = None,
        lambda_max: OptTensor = None,
        eig: bool = False
    ) -> Tensor:
        """Apply one Euler step of the convolution.

        Args:
            x: node feature matrix; its first two dimensions are cached for
                :meth:`conv_jac`.
            edge_index: graph connectivity.
            edge_weight: optional edge weights.
            batch: optional node-to-graph assignment.
            lambda_max: optional largest Laplacian eigenvalue(s).
            eig: if ``True``, additionally compute the eigenvalues of the
                layer's Jacobian with respect to ``x`` and show them with
                matplotlib (diagnostic only; blocks on ``plt.show()``).

        Returns:
            The updated node features.
        """
        edge_index, norm = self.__norm__(
            edge_index,
            x.size(self.node_dim),
            edge_weight,
            self.normalization,
            lambda_max,
            dtype=x.dtype,
            batch=batch,
        )

        # Cached so that conv_jac() can replay the layer on a flat input.
        self.edge_index, self.norm = edge_index, norm
        self.n1, self.n2 = x.shape[0], x.shape[1]

        out = self._euler_step(x, edge_index, norm)

        if eig:
            self._plot_jacobian_spectrum(x)

        return out

    def conv_jac(self, x: Tensor) -> Tensor:
        """Flat-in / flat-out version of the layer used for the Jacobian.

        Relies on the graph and feature shape cached by the most recent
        :meth:`forward` call.
        """
        x = x.reshape(self.n1, self.n2)
        return self._euler_step(x, self.edge_index, self.norm).reshape(-1)

    def _euler_step(self, x: Tensor, edge_index: Tensor,
                    norm: Tensor) -> Tensor:
        """Evaluate ``x + step_size * filter(x)``; shared by forward() and conv_jac()."""
        Tx_0 = x
        Tx_1 = x  # Dummy.
        out = self.lins[0](Tx_0)

        # propagate_type: (x: Tensor, norm: Tensor)
        if len(self.lins) > 1:
            Tx_1 = self.propagate(edge_index, x=x, norm=norm)
            out = out + self.lins[1](Tx_1)

        for lin in self.lins[2:]:
            # Chebyshev recurrence: T_k = 2 * L_hat * T_{k-1} - T_{k-2}.
            Tx_2 = self.propagate(edge_index, x=Tx_1, norm=norm)
            Tx_2 = 2. * Tx_2 - Tx_0
            out = out + lin(Tx_2)
            Tx_0, Tx_1 = Tx_1, Tx_2

        if self.bias is not None:
            out = out + self.bias

        return x + self.e * out

    def _plot_jacobian_spectrum(self, x: Tensor) -> None:
        """Plot the eigenvalues of d(layer)/dx at ``x`` (diagnostic only)."""
        # No graph is needed for the Jacobian itself: it is only inspected.
        # reshape() also copes with non-contiguous inputs, unlike view().
        J = jacobian(self.conv_jac, x.reshape(-1))

        print('Computing Eigenvalues')

        eigs = torch.linalg.eigvals(J).detach().cpu().numpy()
        eigs_r = eigs.real.tolist()
        eigs_im = eigs.imag.tolist()

        # Left: histogram of real/imaginary parts. Right: eigenvalues in the
        # complex plane together with the unit circle.
        fig, axs = plt.subplots(1, 2)
        axs[0].hist(eigs_r, bins=100, label='Real', alpha=0.5)
        axs[0].hist(eigs_im, bins=100, label='Imaginary', alpha=0.5)
        axs[0].set_xlabel('Eigenvalue')
        axs[0].set_ylabel('Frequency')
        axs[0].set_title(f'Distribution of Eigenvalues')
        axs[0].legend()
        theta = np.linspace(0, 2*np.pi, 1000)
        axs[1].plot(np.cos(theta), np.sin(theta), linewidth=1, color='k')
        axs[1].scatter(eigs_r, eigs_im, marker='x', label='Eigs, After Training', linewidth=2, rasterized=True)
        axs[1].set_xlabel('Real')
        axs[1].set_ylabel('Imaginary')
        axs[1].set_title(f'Eigenvalues')
        plt.show()

    def message(self, x_j: Tensor, norm: Tensor) -> Tensor:
        """Scale each neighbour feature by its Laplacian weight."""
        return norm.view(-1, 1) * x_j

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, K={len(self.lins)}, '
                f'normalization={self.normalization})')



class GCNConv(MessagePassing):
    """Graph convolution: linear map followed by normalised neighbour aggregation.

    Self-loops are added to the graph before aggregation.

    Args:
        in_dim: size of each input node feature.
        out_dim: size of each output node feature.
        normalize: ``'sym'`` for the symmetric normalisation computed by
            ``gcn_norm``, or ``'rw'`` to divide each aggregated row by the
            node degree. Any other value makes ``forward`` raise
            ``NotImplementedError``.
        bias: if ``False``, no additive bias is learned.
        **kwargs: accepted for interface compatibility and ignored.
    """
    def __init__(self, in_dim, out_dim, normalize='sym', bias=True, **kwargs):
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.normalize = normalize
        self.bias_bool = bias

        self.linear = torch.nn.Linear(in_dim, out_dim, bias=False)
        if self.bias_bool:
            self.bias = Parameter(torch.zeros(out_dim))
        else:
            self.register_parameter('bias', None)

    def forward(self, x, edge_index, **kwargs):
        """Return the convolved node features.

        Args:
            x: node feature matrix with one row per node.
            edge_index: graph connectivity.
            **kwargs: ignored.

        Raises:
            NotImplementedError: if ``self.normalize`` is neither ``'sym'``
                nor ``'rw'``.
        """
        num_nodes = x.size(0)
        # Pass num_nodes explicitly so nodes that appear in no edge still
        # receive a self-loop.
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        deg_inv = None
        if self.normalize == 'sym':
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight=None, num_nodes=num_nodes,
                add_self_loops=False, dtype=x.dtype)
        elif self.normalize == 'rw':
            # Degree is taken from the first row of edge_index, as in the
            # other layers of this file (assumes an undirected graph --
            # TODO confirm). Self-loops guarantee deg >= 1, so no inf here.
            deg = degree(edge_index[0], num_nodes=num_nodes, dtype=x.dtype)
            deg_inv = 1 / deg
            edge_weight = None
        else:
            raise NotImplementedError

        h = self.linear(x)
        h = self.propagate(edge_index, x=h, edge_weight=edge_weight)
        if deg_inv is not None:
            # Row-normalise the aggregated features.
            h = torch.einsum('n, nd -> nd', deg_inv, h)

        if self.bias_bool:
            h = h + self.bias

        return h

    def message(self, x_j: Tensor, edge_weight) -> Tensor:
        """Return neighbour features, scaled by ``edge_weight`` when given."""
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j


class LinConv(nn.Module):
    """Graph-agnostic layer that applies one affine map to every node.

    It exposes the same call shape as the graph layers in this file but
    ignores everything except the node features.

    Args:
        in_dim: size of each input node feature.
        out_dim: size of each output node feature.
        **kwargs: accepted for interface compatibility and ignored.
    """
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.linear = nn.Linear(in_features=in_dim, out_features=out_dim)

    def forward(self, x, **kwargs):
        """Return ``self.linear(x)``; extra keyword arguments are ignored."""
        projected = self.linear(x)
        return projected

class SAGEConv(MessagePassing):
    """GraphSAGE-style layer: separate linear maps for a node and its neighbour mean.

    Args:
        in_dim: size of each input node feature.
        out_dim: size of each output node feature.
        bias: if ``False``, no additive bias is learned.
        **kwargs: accepted for interface compatibility and ignored.
    """

    def __init__(self, in_dim, out_dim, bias=True, **kwargs):
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.bias_bool = bias

        self.linear_self = torch.nn.Linear(in_dim, out_dim, bias=False)
        self.linear_ngb = torch.nn.Linear(in_dim, out_dim, bias=False)

        if self.bias_bool:
            self.bias = Parameter(torch.zeros(out_dim))
        else:
            self.register_parameter('bias', None)

    def forward(self, x, edge_index, **kwargs):
        """Return ``linear_self(x) + linear_ngb(mean of neighbours) (+ bias)``.

        Args:
            x: node feature matrix with one row per node.
            edge_index: graph connectivity.
            **kwargs: ignored.
        """
        # Degree is taken from the first row of edge_index (assumes an
        # undirected graph -- TODO confirm).
        deg = degree(edge_index[0], num_nodes=x.shape[0], dtype=x.dtype)
        deg_inv = 1 / deg
        # Nodes with degree 0 would give inf here and inf * 0 = NaN below;
        # map them to 0 so such nodes get a zero neighbour term.
        deg_inv.masked_fill_(deg_inv == float('inf'), 0)

        agg = self.propagate(edge_index, x=x, edge_weight=None)
        agg = torch.einsum('n, nd -> nd', deg_inv, agg)

        h = self.linear_self(x) + self.linear_ngb(agg)
        if self.bias_bool:
            h = h + self.bias

        return h

class SumConv(MessagePassing):
    """Layer combining a node's own features with its aggregated neighbours.

    Neighbour features are aggregated by ``propagate`` with the base class's
    default aggregation (presumably a sum, as the class name suggests) and
    are not rescaled afterwards.

    Args:
        in_dim: size of each input node feature.
        out_dim: size of each output node feature.
        bias: if ``False``, no additive bias is learned.
        **kwargs: accepted for interface compatibility and ignored.
    """

    def __init__(self, in_dim, out_dim, bias=True, **kwargs):
        super().__init__()

        self.in_dim, self.out_dim = in_dim, out_dim
        self.bias_bool = bias

        # Independent weights for the node itself and for its neighbourhood.
        self.linear_self = torch.nn.Linear(in_dim, out_dim, bias=False)
        self.linear_ngb = torch.nn.Linear(in_dim, out_dim, bias=False)

        if not self.bias_bool:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.zeros(out_dim))

    def forward(self, x, edge_index, **kwargs):
        """Return ``linear_self(x) + linear_ngb(aggregated neighbours) (+ bias)``."""
        neighbourhood = self.propagate(edge_index, x=x, edge_weight=None)
        own_term = self.linear_self(x)
        ngb_term = self.linear_ngb(neighbourhood)
        combined = own_term + ngb_term
        return combined + self.bias if self.bias_bool else combined


class GINConv(MessagePassing):
    """GIN-style layer: an MLP applied to ``aggregate(x_j) + (1 + eps) * x``.

    Args:
        in_dim: size of each input node feature.
        out_dim: size of each output node feature.
        bias: whether the two linear layers of the MLP carry a bias.
        eps: if ``True``, ``eps`` is a learnable scalar initialised to 0;
            if ``False``, no ``eps`` is used (the node's own features are
            added with weight 1).
        **kwargs: accepted for interface compatibility and ignored.
    """
    def __init__(self, in_dim, out_dim, bias=True, eps=True, **kwargs):
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim

        self.bias_bool = bias
        self.eps_bool = eps

        self.mlp = nn.Sequential(
            nn.Linear(self.in_dim, self.out_dim, bias=self.bias_bool),
            ReLU(),
            nn.Linear(self.out_dim, self.out_dim, bias=self.bias_bool)
        )

        if self.eps_bool:
            self.eps = torch.nn.Parameter(torch.zeros(1))
        else:
            self.register_parameter('eps', None)

    def forward(self, x, edge_index, **kwargs):
        """Return the MLP output for every node.

        Args:
            x: node feature matrix with one row per node.
            edge_index: graph connectivity.
            **kwargs: ignored.
        """
        agg = self.propagate(edge_index, x=x)
        # When eps was disabled it is registered as None; adding it to 1
        # would raise a TypeError, so fall back to a plain weight of 1.
        self_weight = 1 if self.eps is None else 1 + self.eps
        h = self.mlp(agg + self_weight * x)

        return h


class TaylorBuNNConv(MessagePassing):
    r"""Bundle layer approximating heat diffusion with a truncated Taylor series.

    Node features are rotated by ``node_rep``, passed through a linear map,
    diffused with

    .. math::
        \sum_{k=0}^{\text{max\_deg}} \frac{(-\tau \Delta)^k}{k!}\,\mathbf{H},
        \qquad \Delta = \mathbf{I} - \mathbf{A}_{\text{norm}},

    and rotated back with the transpose of ``node_rep``.

    Args:
        in_dim: size of each input node feature.
        out_dim: size of each output node feature.
        normalize: ``'sym'`` (weights from ``gcn_norm``) or ``'rw'``
            (aggregated rows divided by the node degree). Any other value
            makes ``forward`` raise ``NotImplementedError``.
        bias: if ``False``, no additive bias is learned.
        max_deg: highest Taylor order that is evaluated.
        tau: diffusion time.
        **kwargs: accepted for interface compatibility and ignored.
    """
    def __init__(self, in_dim, out_dim, normalize='sym', bias=True,
                 max_deg=4, tau=1, **kwargs):
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.normalize = normalize
        self.bias_bool = bias
        self.max_deg = max_deg
        self.tau = tau

        self.linear = torch.nn.Linear(in_dim, out_dim, bias=False)
        if self.bias_bool:
            self.bias = Parameter(torch.zeros(out_dim))
        else:
            self.register_parameter('bias', None)

    def forward(self, x, edge_index, node_rep, **kwargs):
        """Return the diffused node features.

        Args:
            x: node feature matrix with one row per node; its width must be
                divisible by ``num_bundles * bundle_dim``.
            edge_index: graph connectivity.
            node_rep: 4-D tensor of per-node bundle matrices; dimension 1 is
                the number of bundles and the last dimension is the bundle
                size. Its transpose over the last two dimensions is used as
                its inverse (presumably the matrices are orthogonal).
            **kwargs: ignored.

        Raises:
            NotImplementedError: if ``self.normalize`` is neither ``'sym'``
                nor ``'rw'``.
        """
        bundle_dim = node_rep.shape[-1]
        num_nodes = x.shape[0]
        dim = x.shape[1]
        num_bundles = node_rep.shape[1]

        deg_inv = None
        if self.normalize == 'sym':
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight=None, num_nodes=num_nodes,
                add_self_loops=False, dtype=x.dtype)
        elif self.normalize == 'rw':
            # Degree is taken from the first row of edge_index (assumes an
            # undirected graph -- TODO confirm).
            deg = degree(edge_index[0], num_nodes=num_nodes, dtype=x.dtype)
            deg_inv = 1 / deg
            # No self-loops are added, so degree-0 nodes would give inf and
            # then inf * 0 = NaN; map them to 0 instead.
            deg_inv.masked_fill_(deg_inv == float('inf'), 0)
            edge_weight = None
        else:
            raise NotImplementedError

        vector_field = x.reshape(num_nodes, num_bundles, bundle_dim,
                                 -1)  # works since dim divisible by bundle dim
        # first transform into the 'edge space'
        vector_field = torch.einsum('abcd, abde -> abce', node_rep, vector_field)
        h = vector_field.reshape(num_nodes, dim)

        h = self.linear(h)

        # curr_pow holds the k-th Taylor term (-tau * Delta)^k / k! applied
        # to h; each iteration derives term k from term k - 1.
        curr_pow = h
        for k in range(1, self.max_deg + 1):
            agg = self.propagate(edge_index, x=curr_pow, edge_weight=edge_weight)  # A * term_{k-1}
            if deg_inv is not None:
                agg = torch.einsum('n, nd -> nd', deg_inv, agg)
            # term_k = -tau / k * Delta * term_{k-1}, with Delta = I - A.
            curr_pow = -self.tau * 1 / k * (curr_pow - agg)
            h = h + curr_pow

        if self.bias_bool:
            h = h + self.bias

        # transform back to 'node space'
        vector_field = h.reshape(num_nodes, num_bundles, bundle_dim, -1)
        vector_field = torch.einsum('abcd, abde -> abce', node_rep.transpose(2, 3),
                                    vector_field)  # inverse is transpose
        # After the linear map the width is out_dim, not the input width.
        h = vector_field.reshape(num_nodes, -1)
        return h

    def message(self, x_j: Tensor, edge_weight) -> Tensor:
        """Return neighbour features, scaled by ``edge_weight`` when given."""
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j


class SpectralBuNNConv(Module):
    r"""Bundle layer applying heat diffusion in a supplied spectral basis.

    Node features are rotated by ``node_rep``, passed through a linear map,
    projected onto the given eigenvectors, damped by
    :math:`\exp(-\lambda_k \,\mathrm{ReLU}(\tau_b))` per eigenvalue
    :math:`\lambda_k` and bundle :math:`b`, projected back, and finally
    rotated back with the transpose of ``node_rep``.

    Args:
        in_dim: size of each input node feature.
        out_dim: size of each output node feature.
        normalize: must be ``'sym'`` or ``'rw'``; the eigen-decomposition
            passed to ``forward`` is used as-is (presumably it should match
            this normalisation -- confirm with the caller).
        bias: if ``False``, no additive bias is learned.
        learn_tau: if ``True``, ``tau`` is a learnable per-bundle parameter
            initialised uniformly in ``[0.5 * tau, 1.5 * tau)``; otherwise it
            is a constant buffer filled with ``tau``.
        tau: diffusion time (initial scale when ``learn_tau`` is set).
        num_bundles: number of bundles, i.e. the length of ``self.tau``.
        **kwargs: accepted for interface compatibility and ignored.
    """
    def __init__(self, in_dim, out_dim, normalize='sym', bias=True,
                 learn_tau=True, tau=1, num_bundles=1, **kwargs):
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.normalize = normalize
        self.bias_bool = bias
        self.num_bundles = num_bundles

        self.relu = ReLU()   # to keep the taus positive

        self.linear = torch.nn.Linear(in_dim, out_dim, bias=False)
        if self.bias_bool:
            self.bias = Parameter(torch.zeros(out_dim))
        else:
            self.register_parameter('bias', None)

        if learn_tau:
            self.tau = Parameter((torch.rand([self.num_bundles]) + 0.5) * tau)
        else:
            # register_buffer already exposes the tensor as self.tau.
            big_tau = tau * torch.ones([self.num_bundles], dtype=torch.float)
            self.register_buffer('tau', big_tau)

    def forward(self, x, edge_index, node_rep, eig_vecs, eig_vals, **kwargs):
        """Return the diffused node features.

        Args:
            x: node feature matrix with one row per node.
            edge_index: graph connectivity; not used by the computation.
            node_rep: 4-D tensor of per-node bundle matrices; dimension 1 is
                the number of bundles and the last dimension is the bundle
                size.
            eig_vecs: tensor of shape ``[n, k]`` with ``k <= n``; entry
                ``[i, q]`` is the value of eigenvector ``q`` at node ``i``.
            eig_vals: tensor of shape ``[k]``.
            **kwargs: ignored.

        Raises:
            NotImplementedError: if ``self.normalize`` is neither ``'sym'``
                nor ``'rw'``.
        """
        # TODO: this works for a single graph, need to generalize to batches
        bundle_dim = node_rep.shape[-1]
        num_nodes = x.shape[0]
        num_bundles = node_rep.shape[1]
        num_eigs = eig_vecs.shape[1]

        # The graph normalisation itself is never needed here (the spectrum
        # is supplied by the caller), so only the mode is validated.
        if self.normalize not in ('sym', 'rw'):
            raise NotImplementedError

        vector_field = x.reshape(num_nodes, num_bundles, bundle_dim,
                                 -1)  # works since dim divisible by bundle dim
        # first transform into the 'edge space'
        vector_field = torch.einsum('abcd, abde -> abce', node_rep, vector_field)
        h = vector_field.reshape(num_nodes, self.in_dim)

        h = self.linear(h)

        # compute projections
        proj = torch.einsum('kn, nd -> kd', eig_vecs.transpose(0, 1), h)  # phi^T*X
        proj = proj.reshape(num_eigs, num_bundles, bundle_dim, -1)
        # compute exponentials
        rep_eigvals = eig_vals.unsqueeze(1).repeat(1, self.num_bundles)
        exponent = torch.einsum('kb, b -> kb', rep_eigvals, self.relu(self.tau))  # lambda*tau
        expo = torch.exp(-exponent)
        # evolve the projections
        evol = torch.einsum('kbdc, kb -> kbdc', proj, expo)
        # unproject
        unproj = torch.einsum('nk, kbdc -> nbdc', eig_vecs, evol)
        h = unproj.reshape(num_nodes, self.out_dim)

        if self.bias_bool:
            h = h + self.bias

        # transform back to 'node space'
        vector_field = h.reshape(num_nodes, num_bundles, bundle_dim, -1)
        vector_field = torch.einsum('abcd, abde -> abce', node_rep.transpose(2, 3),
                                    vector_field)  # inverse is transpose
        h = vector_field.reshape(num_nodes, self.out_dim)
        return h


class LimBuNNConv(Module):
    """Bundle layer that replaces every node by a degree-weighted global average.

    Node features are rotated by ``node_rep``, passed through a linear map,
    collapsed to one degree-weighted average row that is copied to all
    nodes, and rotated back with the transpose of ``node_rep``.

    Args:
        in_dim: size of each input node feature.
        out_dim: size of each output node feature.
        normalize: ``'sym'`` weights nodes by the square root of their
            degree before and after averaging; ``'rw'`` weights them by the
            degree before averaging only. Any other value makes ``forward``
            raise ``NotImplementedError``.
        bias: if ``False``, no additive bias is learned.
        num_bundles: stored on the instance; ``forward`` reads the bundle
            count from ``node_rep`` instead.
        **kwargs: accepted for interface compatibility and ignored.
    """
    def __init__(self, in_dim, out_dim, normalize='sym', bias=True,
                 num_bundles=1, **kwargs):
        super().__init__()

        self.in_dim, self.out_dim = in_dim, out_dim
        self.normalize = normalize
        self.bias_bool = bias
        self.num_bundles = num_bundles

        self.linear = torch.nn.Linear(in_dim, out_dim, bias=False)
        if not self.bias_bool:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.zeros(out_dim))

    def forward(self, x, edge_index, node_rep, **kwargs):
        """Return the averaged features, one identical pre-rotation row per node.

        Args:
            x: node feature matrix with one row per node.
            edge_index: graph connectivity; its first row gives the degrees
                and its number of columns is the averaging denominator.
            node_rep: 4-D tensor of per-node bundle matrices; dimension 1 is
                the number of bundles and the last dimension is the bundle
                size.
            **kwargs: ignored.

        Raises:
            NotImplementedError: if ``self.normalize`` is neither ``'sym'``
                nor ``'rw'``.
        """
        # TODO: this works for a single graph, need to generalize to batches...
        n_nodes = x.shape[0]
        n_bundles = node_rep.shape[1]
        d_bundle = node_rep.shape[-1]

        if self.normalize not in ('sym', 'rw'):
            raise NotImplementedError
        symmetric = self.normalize == 'sym'

        deg = degree(edge_index[0], num_nodes=x.shape[0])
        # Per-node weight applied before averaging.
        node_weight = deg.pow(0.5) if symmetric else deg

        # Rotate into the 'edge space'; works since the feature width is
        # divisible by the bundle size.
        field = x.reshape(n_nodes, n_bundles, d_bundle, -1)
        field = torch.einsum('abcd, abde -> abce', node_rep, field)
        h = self.linear(field.reshape(n_nodes, self.in_dim))

        # Weighted sum over nodes, divided by the number of edges.
        weighted = torch.einsum('n, nd -> nd', node_weight, h)
        average = weighted.sum(dim=0) / edge_index.size(1)
        h = average.repeat(n_nodes, 1)
        if symmetric:
            h = torch.einsum('n, nd -> nd', node_weight, h)

        if self.bias_bool:
            h = h + self.bias

        # Rotate back to 'node space'; the inverse is the transpose.
        field = h.reshape(n_nodes, n_bundles, d_bundle, -1)
        field = torch.einsum('abcd, abde -> abce',
                             node_rep.transpose(2, 3), field)
        return field.reshape(n_nodes, self.out_dim)
