# Towards Dynamic Message Passing on Graphs, NeurIPS 2024
# The source of model's main code: https://github.com/sunjss/N2

import os
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn.parameter import Parameter
import math
import torch.optim as optim
import copy
import time
from sklearn.metrics import accuracy_score as ACC

from torch_geometric.utils import (to_dense_batch, 
                                   add_remaining_self_loops)
from torch_scatter import scatter
from torch_geometric.utils.num_nodes import maybe_num_nodes
from typing import Any, Optional, Tuple, Union
from torch import Tensor


class N2(nn.Module):
    """Training / inference wrapper around :class:`N2Node` for node classification.

    Bundles the network with a full-batch training loop (Adam + NLL loss with
    early stopping on validation accuracy) and a prediction helper.

    Args:
        in_features: Width of the input node features.
        class_num: Number of target classes.
        device: Device the model, graph and labels are moved to.
        args: Namespace providing ``epochs``, ``patience``, ``lr``, ``l2_coef``,
            ``nlayers``, ``d_model``, ``q_dim``, ``n_q``, ``n_pnode``,
            ``dropout`` and ``wo_selfloop``.
    """

    def __init__(
        self,
        in_features,
        class_num,
        device,
        args,
    ):
        super(N2, self).__init__()
        # ------------- Parameters ----------------
        self.device = device
        self.epochs = args.epochs
        self.patience = args.patience
        self.lr = args.lr
        self.l2_coef = args.l2_coef

        # ---------------- Model -------------------
        self.model = N2Node(T=args.nlayers,
                d_in=in_features,
                d_ein=0,
                d_model=args.d_model,
                nclass=class_num, 
                q_dim=args.q_dim,
                n_q=args.n_q,
                n_c=1,
                n_pnode=args.n_pnode,
                task_type="single-class",
                dropout=args.dropout,
                # Logical negation: ``~`` on a Python bool yields -1 / -2,
                # which are both truthy.
                self_loop=not args.wo_selfloop,
                pre_encoder=None,
                pos_encoder=None)

    def _graph_to_edge_index(self, graph):
        """Return the non-zero adjacency entries of ``graph`` as a ``(2, E)`` long tensor on ``self.device``."""
        adj = graph.adj(scipy_fmt='csr')
        return torch.tensor(
            np.array(adj.nonzero()), device=self.device, dtype=torch.long
        )

    def forward(self, x, edge_index, return_Z=False):
        """Run the wrapped network.

        Returns:
            The network's first output, or the pair ``(output, Z)`` when
            ``return_Z`` is true, where ``Z`` is the network's second output.
        """
        output, Z = self.model(x, edge_index)
        if return_Z:
            return output, Z
        return output

    def fit(self, graph, labels, train_mask, val_mask, test_mask):
        """Train full-batch with early stopping on validation accuracy.

        The weights of the epoch with the best validation accuracy are restored
        at the end and that epoch index is stored in ``self.best_epoch``.

        Args:
            graph: Graph object exposing ``to``, ``ndata["feat"]`` and
                ``adj(scipy_fmt='csr')``.
            labels: Class index per node.
            train_mask: Index/mask selecting the training nodes.
            val_mask: Index/mask selecting the validation nodes.
            test_mask: Index/mask selecting the test nodes.
        """
        graph = graph.to(self.device)
        labels = labels.to(self.device)
        self.train_mask = train_mask.to(self.device)
        self.valid_mask = val_mask.to(self.device)
        self.test_mask = test_mask.to(self.device)
        self.to(self.device)
        X = graph.ndata["feat"]
        edge_index = self._graph_to_edge_index(graph)

        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.l2_coef)

        best_epoch = 0
        best_acc = 0.0
        cnt = 0  # epochs since the last validation improvement
        best_state_dict = None
        for epoch in range(self.epochs):
            self.train()
            optimizer.zero_grad()
            output = self.forward(X, edge_index)
            loss = F.nll_loss(output[self.train_mask], labels[self.train_mask])
            loss.backward()
            optimizer.step()

            train_acc, valid_acc, test_acc = self.test(
                X,
                edge_index,
                labels,
                [self.train_mask, self.valid_mask, self.test_mask],
            )

            if valid_acc > best_acc:
                cnt = 0
                best_acc = valid_acc
                best_epoch = epoch
                best_state_dict = copy.deepcopy(self.state_dict())
                print(f'\nEpoch:{epoch}, Loss:{loss.item()}')
                print(
                    f'train acc: {train_acc:.3f} valid acc: {valid_acc:.3f}, test acc: {test_acc:.3f}'
                )
            else:
                cnt += 1
                if cnt == self.patience:
                    print(
                        f"Early Stopping! Best Epoch: {best_epoch}, best val acc: {best_acc}"
                    )
                    break
        # No snapshot exists when ``epochs`` is 0 or validation accuracy never
        # rose above 0; keep the current weights in that case.
        if best_state_dict is not None:
            self.load_state_dict(best_state_dict)
        self.best_epoch = best_epoch

    def test(self, X, edge_index, labels, index_list):
        """Return one accuracy value per index/mask in ``index_list`` (evaluation mode, no gradients)."""
        self.eval()
        with torch.no_grad():
            C = self.forward(X, edge_index)
            y_pred = torch.argmax(C, dim=1)
        acc_list = []
        for index in index_list:
            acc_list.append(ACC(labels[index].cpu(), y_pred[index].cpu()))
        return acc_list

    def predict(self, graph):
        """Predict a class for every node of ``graph``.

        Returns:
            Tuple ``(y_pred, C, Z)`` of CPU tensors: the arg-max class per node,
            the first output of :meth:`forward` (class scores) and its second
            output ``Z``.
        """
        self.eval()
        graph = graph.to(self.device)
        X = graph.ndata['feat']
        edge_index = self._graph_to_edge_index(graph)

        with torch.no_grad():
            # ``forward`` returns (output, Z) in that order.
            C, Z = self.forward(X, edge_index, return_Z=True)
            y_pred = torch.argmax(C, dim=1)

        return y_pred.cpu(), C.cpu(), Z.cpu()




class N2_model(nn.Module):
    """Shared base of the N2 networks: builds the layers and offers graph / read-out helpers.

    All constructor arguments are keyword-only.

    Args:
        d_in: Node feature width.
        d_ein: Edge feature width; added to ``d_in`` for the input projections.
        nclass: Size of the second axis of the ``class_neuron`` parameter.
        d_model: Hidden feature width.
        q_dim: Width of the node and pseudo-node states.
        n_q: Number of weighted copies used by each :class:`PathIntegral`.
        n_c: Leading dimension of the ``class_neuron`` parameter.
        n_pnode: Number of learnable pseudo-node states.
        T: Stored as ``self.T``.
        task_type: Selects the read-out in :meth:`get_output`. ``"reg"``
            additionally disables normalisation in the updater and skips the
            creation of ``class_neuron`` and ``out_ff``.
        self_loop: Stored as ``self.self_loop``.
        pre_encoder: Stored as ``self.pre_encoder``.
        pos_encoder: Stored as ``self.pos_encoder``.
        dropout: Dropout probability of the input projections and the updater.
    """

    def __init__(self, *, d_in=1, 
                          d_ein=0, 
                          nclass=1, 
                          d_model=64, 
                          q_dim=64, 
                          n_q=8, 
                          n_c=8, 
                          n_pnode=256, 
                          T=1, 
                          task_type="single-class",
                          self_loop=True,
                          pre_encoder=None, 
                          pos_encoder=None, 
                          dropout=0.1):
        super(N2_model, self).__init__()
        d_in = d_in + d_ein
        self.node_state_interface = nn.Sequential(nn.Linear(d_in, q_dim),
                                                  nn.LeakyReLU(),
                                                  nn.Dropout(dropout))
        self.feat_ff = nn.Sequential(nn.Linear(d_in, d_model),
                                     nn.LeakyReLU(),
                                     nn.Dropout(dropout))
        self.pnode_state = nn.Parameter(torch.randn(1, n_pnode, q_dim))
        if task_type != "reg":
            self.node_state_updater = NodePseudoSubsystem(d_in, d_ein, q_dim, n_pnode, d_model, q_dim, n_q, dropout)
            self.class_neuron = nn.Parameter(torch.randn(n_c, nclass, q_dim))
            self.out_ff = PathIntegral(q_dim, n_q)
        else:
            # Regression: updater without LayerNorm, no class read-out parameters.
            self.node_state_updater = NodePseudoSubsystem(d_in, d_ein, q_dim, n_pnode, d_model, q_dim, n_q, dropout, False)
        self.pre_encoder = pre_encoder
        self.pos_encoder = pos_encoder
        self.task_type = task_type
        self.T = T
        self.n_q = n_q
        self.q_dim = q_dim
        self.n_pnode = n_pnode
        self.d_model = d_model
        self.self_loop = self_loop

    def _get_sparse_normalized_adj(self, *, edge_index=None, max_num_nodes=None, edge_weight=None, batch=None):
        """Symmetrically normalise edge weights and re-index edges per graph.

        Args:
            edge_index: ``(2, E)`` tensor of edge endpoints.
            max_num_nodes: Per-graph node budget; edges touching a local index
                at or above it are dropped. Defaults to the size of the largest
                graph in ``batch``.
            edge_weight: Optional weight per edge; defaults to ones.
            batch: Optional graph id per node; defaults to a single graph whose
                size is inferred from ``edge_index``.

        Returns:
            ``(idx, edge_weight)`` where ``idx`` is the ``(2, E')`` index into a
            block layout with ``max_num_nodes`` slots per graph and
            ``edge_weight`` equals ``deg^-1/2[row] * w * deg^-1/2[col]``.
        """
        if edge_weight is None:
            edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)
        if edge_weight.dtype == torch.long:
            edge_weight = edge_weight.type(torch.float32)

        # normalize edge weight
        row, col = edge_index[0], edge_index[1]
        deg = scatter(edge_weight, row, 0, 
                      dim_size=maybe_num_nodes(edge_index), 
                      reduce='sum')
        deg_inv_sqrt = deg.pow_(-0.5)
        # Nodes without outgoing weight have deg == 0 -> inf; zero them out.
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

        # batch gen
        if batch is None:
            num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
            batch = edge_index.new_zeros(num_nodes)
        batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1

        # transform edge index into batched index with padding
        one = batch.new_ones(batch.size(0))
        num_nodes = scatter(one, batch, dim=0, dim_size=batch_size, reduce='sum')
        cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])

        if max_num_nodes is None:
            # Without an explicit budget, size every block for the largest graph.
            max_num_nodes = int(num_nodes.max())

        idx0 = batch[edge_index[0]]
        idx1 = edge_index[0] - cum_nodes[batch][edge_index[0]]
        idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]]

        if ((idx1.numel() > 0 and idx1.max() >= max_num_nodes)
            or (idx2.numel() > 0 and idx2.max() >= max_num_nodes)):
            mask = (idx1 < max_num_nodes) & (idx2 < max_num_nodes)
            idx0 = idx0[mask]
            idx1 = idx1[mask]
            idx2 = idx2[mask]
            edge_weight = edge_weight[mask]

        idx = torch.stack((idx1, idx2), 0)
        idx = idx0 * max_num_nodes + idx
        return idx, edge_weight

    def _feature_prep(self, X, edge_index):
        """Add a batch axis to 2-D features and build the normalised adjacency.

        Returns:
            ``(features, edge_index, edge_weight, edge_attr, mask)``; ``mask`` is
            always ``None`` here and ``edge_attr`` is whatever
            ``add_remaining_self_loops`` returns for a ``None`` attribute.
        """
        features = X
        mask = None
        if features.ndim == 2:
            features = features.unsqueeze(0)
        # NOTE(review): ``self.self_loop`` is not consulted; self-loops are
        # always added here -- confirm this is intended.
        edge_index, edge_attr = add_remaining_self_loops(edge_index, 
                                                         None, 
                                                         num_nodes=X.shape[0])
        edge_index, edge_weight = self._get_sparse_normalized_adj(edge_index=edge_index, 
                                                                  max_num_nodes=X.shape[0],
                                                                  batch=None)

        return features, edge_index, edge_weight, edge_attr, mask

    def get_output(self, features):
        """Map raw read-out values to the final output for ``self.task_type``.

        * ``"single-class"`` in the task type: log-softmax over the last axis.
        * ``"multi-class"`` / ``"reg"``: returned unchanged.
        * ``"binary-class"``: flattened.
        * ``"link"`` in the task type: the feature axis is split in half into
          "in" and "out" embeddings and all pairs are scored, either by a dot
          product divided by the half width (``"scale-dot"``) or by cosine
          similarity (``"cosine"``); the score is then doubled.

        Raises:
            ValueError: If the task type (or link scoring mode) is unsupported.
        """
        if "single-class" in self.task_type:
            output = F.log_softmax(features, dim=-1)
        elif self.task_type in ["multi-class",  "reg"]:
            output = features
        elif self.task_type == "binary-class":
            output = features.flatten()
        elif "link" in self.task_type:
            outdim = features.shape[-1] // 2
            node_in = features[:, :outdim]
            node_out = features[:, outdim:]
            scores = torch.matmul(node_in, node_out.T)
            if "scale-dot" in self.task_type:
                output = scores / outdim
            elif "cosine" in self.task_type:
                # Entry (i, j) must be divided by |in_i| * |out_j|, i.e. by the
                # outer product of the norms, not their elementwise product.
                norm_in = torch.norm(node_in, dim=-1, keepdim=True)
                norm_out = torch.norm(node_out, dim=-1, keepdim=True)
                output = scores / (norm_in * norm_out.T)
            else:
                raise ValueError("Unsupported task type " + self.task_type)
            output = output * 2
        else:
            raise ValueError("Unsupported task type " + self.task_type)
        return output

    def forward(self):
        """Abstract: concrete networks must override ``forward``."""
        raise NotImplementedError("N2_model subclasses must implement forward().")


class N2Node(N2_model):
    """Node-level N2 network.

    Accepts the same keyword-only configuration as :class:`N2_model` and
    forwards it unchanged; only the forward pass is defined here.
    """

    def __init__(self, *, d_in=1, 
                          d_ein=0,
                          nclass=1, 
                          d_model=64, 
                          q_dim=64, 
                          n_q=8, 
                          n_c=8, 
                          n_pnode=256, 
                          T=1, 
                          task_type="single-class",
                          pre_encoder=None, 
                          pos_encoder=None, 
                          self_loop=True,
                          dropout=0.1):
        super().__init__(
            d_in=d_in,
            d_ein=d_ein,
            nclass=nclass,
            d_model=d_model,
            q_dim=q_dim,
            n_q=n_q,
            n_c=n_c,
            n_pnode=n_pnode,
            T=T,
            task_type=task_type,
            pre_encoder=pre_encoder,
            pos_encoder=pos_encoder,
            self_loop=self_loop,
            dropout=dropout,
        )

    def forward(self, X, edge_index):
        """Compute per-node outputs.

        Args:
            X: Node feature tensor.
            edge_index: ``(2, E)`` tensor of edge endpoints.

        Returns:
            ``(output, scores)``: ``scores`` is the class read-out of the final
            node states and ``output`` is ``get_output(scores)``.
        """
        feats, adj_index, adj_weight, edge_attr, mask = self._feature_prep(X, edge_index)
        n_nodes = feats.shape[-2]

        # Project the raw input twice: once into state space, once into feature space.
        state = self.node_state_interface(feats)  # (b_s, n, q_dim)
        feats = self.feat_ff(feats)
        if mask is not None:
            mask = mask.unsqueeze(-1)
            feats = mask * feats
            state = mask * state

        # Apply the shared updater T times, threading all three states through.
        pseudo_state = self.pnode_state
        for _ in range(self.T):
            state, pseudo_state, feats = self.node_state_updater(
                edge_index=adj_index,
                edge_weight=adj_weight,
                edge_attr=edge_attr,
                features=feats,
                node_state=state,
                pnode_state=pseudo_state,
                mask=mask,
                size=n_nodes,
            )

        # Read out class scores from the node states and merge the two leading axes.
        scores = self.out_ff(state.unsqueeze(1), self.class_neuron).flatten(0, 1)
        if mask is not None:
            scores = scores[mask.flatten()]
        return self.get_output(scores), scores


class PathIntegral(nn.Module):
    """Weighted multi-copy similarity between two sets of states.

    Computes scaled dot products between ``in_subsystem`` and ``out_subsystem``,
    weights each of the ``n_q`` copies (third-from-last axis) by a learnable
    scalar, averages over the copies and finally rescales rows whose absolute
    sum exceeds ``1e4``.

    Args:
        q_dim: Width of the state vectors (last axis); used as the scaling divisor.
        n_q: Number of copies, i.e. number of learnable weights.
    """

    def __init__(self, q_dim, n_q):
        super(PathIntegral, self).__init__()
        self.lambda_copies = nn.Parameter(torch.randn(n_q, 1, 1))
        self.n_q = n_q
        self.q_dim = q_dim
        # Optional fallbacks used when ``_path_integral`` is called without arguments.
        self.in_subsystem = None
        self.out_subsystem = None

    def _path_integral(self, in_subsystem=None, out_subsystem=None):
        """Return the weighted similarity matrix between the two subsystems.

        Missing arguments fall back to ``self.in_subsystem`` /
        ``self.out_subsystem``.

        Raises:
            ValueError: If either subsystem is still ``None`` after the fallback.
        """
        # Identity checks: ``== None`` on a tensor goes through tensor comparison.
        if in_subsystem is None:
            in_subsystem = self.in_subsystem
        if out_subsystem is None:
            out_subsystem = self.out_subsystem
        if in_subsystem is None or out_subsystem is None:
            raise ValueError("Path integral requires computational object.")

        # TODO: ablation study between attention and dist
        if in_subsystem.shape[-3] == 1 and out_subsystem.shape[-3] == self.n_q:
            # Only the output side has copies: fold the copy weights into it
            # before the (single) matrix product.
            out_subsystem = out_subsystem * self.lambda_copies
            out_subsystem_sum = out_subsystem.sum(-3) / (self.n_q * self.q_dim)
            weighted_dist_sum = torch.matmul(in_subsystem.squeeze(-3), 
                                             out_subsystem_sum.transpose(-2, -1))
        elif in_subsystem.shape[-3] == self.n_q and out_subsystem.shape[-3] == 1:
            # Mirror case: only the input side has copies.
            in_subsystem = in_subsystem * self.lambda_copies
            in_subsystem_sum = in_subsystem.sum(-3) / (self.n_q * self.q_dim)
            weighted_dist_sum = torch.matmul(in_subsystem_sum, 
                                             out_subsystem.squeeze(-3).transpose(-2, -1))
        else:
            # General case: per-copy products, then weight and average the copies.
            dist = torch.matmul(in_subsystem, out_subsystem.transpose(-2, -1)) / self.q_dim
            weighted_dist = dist * self.lambda_copies
            weighted_dist_sum = weighted_dist.sum(-3) / self.n_q  # (n_in, n_out)

        # dist clip: rows whose absolute sum exceeds 1e4 are scaled back to 1e4.
        # The scale factor is computed without gradient tracking.
        with torch.no_grad():
            clip_check = torch.abs(weighted_dist_sum.sum(-1, keepdim=True))
            fill_value = torch.ones_like(clip_check)
            scaler = torch.where(clip_check > 1e+4, 1e+4 / clip_check, fill_value)
        weighted_dist_sum = scaler * weighted_dist_sum
        return weighted_dist_sum

    def forward(self, in_subsystem, out_subsystem):
        """Return ``_path_integral(in_subsystem, out_subsystem)``."""
        return self._path_integral(in_subsystem, out_subsystem)


class NodePseudoSubsystem(nn.Module):
    """One N2 update step exchanging information between nodes and pseudo-nodes.

    A step (1) pools node features into the pseudo-nodes and sends an
    "inspector" signal back to the nodes, (2) propagates the concatenated node
    information along the graph edges, and (3) pools the propagated information
    into the pseudo-nodes again and dispatches the result to the nodes.
    """

    def __init__(self, d_in, d_ein, d_out, n_pnode, d_model, q_dim, n_q, dropout=0.0, norm=True):
        super().__init__()
        # Sub-modules are created in a fixed order so parameter registration
        # (and random initialisation) stays stable.
        self.collection1 = PathIntegral(q_dim, n_q)
        self.pnode_agg1 = PNodeCommunicator(d_model, d_model, q_dim, n_q, dropout)

        self.inspection = PathIntegral(q_dim, n_q)
        self.edge_wise_ff = nn.Linear(d_model, d_model)
        self.hstate_interface = nn.Sequential(
            nn.Linear(d_model * 2 + q_dim, q_dim),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
        )

        self.collection2 = PathIntegral(q_dim, n_q)
        self.pnode_agg2 = PNodeCommunicator(d_model * 3 + q_dim, d_out, q_dim, n_q, dropout)

        self.dispatch = PathIntegral(q_dim, n_q)
        self.feat_ff = nn.Sequential(
            nn.Linear(q_dim, d_model),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
        )
        if norm:
            self.phidden_norm = nn.LayerNorm(q_dim)
            self.hidden_norm = nn.LayerNorm(q_dim)
            self.pout_norm = nn.LayerNorm(q_dim)
            self.out_norm = nn.LayerNorm(q_dim)
            self.feat_norm = nn.LayerNorm(d_model)
        self.norm = norm
        print(f"Using norm: {norm}" )

        self.time_embedding = nn.Embedding(6, d_model)
        self.q_dim = q_dim
        self.pnode_num = n_pnode
        self.d_model = d_model

    def _feature_inspection(self, features, node_state, pnode_state, node_num):
        """Pool features into the pseudo-nodes and build the per-node inspector signal.

        Side effects: stores the pseudo-node values in ``self.str_inspector``
        and the updated pseudo-node state in ``self.pnode_state``.

        Returns:
            ``(inspector, pnode_state)``.
        """
        # node -> pseudo-node: similarity-weighted average of the node features
        to_pnode = self.collection1(pnode_state, node_state)  # (n_pnode, n)
        pooled = torch.matmul(to_pnode, features) / node_num  # (n_pnode, d_model)
        shift, values = self.pnode_agg1(pnode_state, pooled)
        self.str_inspector = values

        refreshed = shift + pnode_state
        if self.norm:
            refreshed = self.phidden_norm(refreshed)
        self.pnode_state = refreshed

        # pseudo-node -> node: hand the pseudo-node values back to every node
        to_node = self.inspection(node_state, refreshed)  # (n, n_pnode)
        return torch.matmul(to_node, values), refreshed  # (n, d_model)

    def _pnode_aggregator(self, pnode_state, hnode_state, insp_out, node_num):
        """Pool the edge-propagated information into the pseudo-nodes and dispatch it to the nodes.

        Returns:
            ``(dispatch_value, pnode_state)``.
        """
        # node -> pseudo-node collection, extended by the stored inspector values
        to_pnode = self.collection2(pnode_state, hnode_state)  # (n_pnode, n)
        pooled = torch.matmul(to_pnode, insp_out) / node_num
        pooled = torch.cat((pooled, self.str_inspector), -1)

        # pseudo-node level refinement
        shift, values = self.pnode_agg2(pnode_state, pooled)
        refreshed = pnode_state + shift
        if self.norm:
            refreshed = self.pout_norm(refreshed)

        # pseudo-node -> node dispatch
        to_node = self.dispatch(hnode_state, refreshed)  # (b_s, n, n_pnode)
        return torch.matmul(to_node, values), refreshed

    def _edge_aggregation(self, insp_in, edge_index, edge_weight, edge_attr, size):
        """Multiply ``insp_in`` by the sparse adjacency built from the weighted edge list."""
        sparse_adj = to_torch_coo_tensor(edge_index, edge_weight, size=size)
        return torch.matmul(sparse_adj, insp_in)  # (b_s * n, 2 * d_model)

    def forward(self, *, edge_index=None, 
                         edge_weight=None, 
                         edge_attr=None,
                         features=None, 
                         node_state=None, 
                         pnode_state=None, 
                         mask=None,
                         size=None):
        """Run one update step.

        Returns:
            ``(node_state, pnode_state, features)`` after the step.
        """
        # node_state (b_s, n, q_dim), features (b_s, n, d_model)
        batch_size, n_nodes = features.shape[0], features.shape[1]
        if mask is None:
            node_num = n_nodes
        else:
            node_num = mask.sum(-2, keepdim=True)

        # inspector generation
        inspector, pnode_state = self._feature_inspection(
            features, node_state.unsqueeze(1), pnode_state, node_num
        )

        # feature inspection: propagate [features | inspector | state] along the edges
        stacked = torch.cat((features, inspector, node_state), -1)
        propagated = self._edge_aggregation(
            stacked.flatten(0, -2), edge_index, edge_weight, edge_attr, size
        ).view(batch_size, n_nodes, -1)

        hidden_state = self.hstate_interface(propagated)  # (b_s, n, q_dim)
        hidden_state = hidden_state + node_state
        if self.norm:
            hidden_state = self.hidden_norm(hidden_state)
        if mask is not None:
            hidden_state = mask * hidden_state

        # feature aggregation
        dispatched, pnode_state = self._pnode_aggregator(
            pnode_state, hidden_state.unsqueeze(1), propagated, node_num
        )
        features = self.feat_ff(dispatched) + features
        node_state = hidden_state + dispatched  # (n, q_dim)
        if self.norm:
            node_state = self.out_norm(node_state)
            features = self.feat_norm(features)
        if mask is not None:
            node_state = mask * node_state
            features = mask * features

        return node_state, pnode_state, features


class PNodeCommunicator(nn.Module):
    """Mixes pooled information among pseudo-nodes and derives two projections of it.

    Args:
        d_in: Width of the pooled input.
        d_out: Width of the returned dispatch values.
        q_dim: State width; the displacement carries ``q_dim`` entries per copy.
        n_q: Number of copies produced by the displacement projection.
        dropout: Dropout probability of both projections.
    """

    def __init__(self, d_in, d_out, q_dim, n_q, dropout):
        super().__init__()

        def _projection(width):
            # Linear -> LeakyReLU -> Dropout from ``d_in`` to ``width``.
            return nn.Sequential(nn.Linear(d_in, width),
                                 nn.LeakyReLU(),
                                 nn.Dropout(dropout))

        self.q_dim = q_dim
        self.pnode_agg = PathIntegral(q_dim, n_q)
        self.glob2disp = _projection(q_dim * n_q)
        self.glob2value = _projection(d_out)

    def forward(self, state, glob):
        """Return ``(displacement, dispatch_value)`` for the given pseudo-node state.

        ``glob`` is first mixed with the state-to-state similarity. The
        displacement projection then has its last axis split into
        ``(q_dim, copies)`` and the copy axis is moved to position 1.
        """
        mixing = self.pnode_agg(state, state)  # (n_pnode, n_pnode)
        mixed = torch.matmul(mixing, glob)  # (n_pnode, d_in)

        displacement = self.glob2disp(mixed)
        displacement = displacement.unflatten(-1, (self.q_dim, -1)).permute(0, 3, 1, 2)

        dispatch_value = self.glob2value(mixed)  # (n_pnode, d_out)
        return displacement, dispatch_value



def to_torch_coo_tensor(
    edge_index: Tensor,
    edge_attr: Optional[Tensor] = None,
    size: Optional[Union[int, Tuple[int, int]]] = None,
) -> Tensor:
    """Converts a sparse adjacency matrix defined by edge indices and edge
    attributes to a :class:`torch.sparse.Tensor`.

    Duplicate edges are merged by summing their attributes, because the result
    is coalesced.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): The edge attributes.
            (default: :obj:`None`, i.e. a weight of one per edge)
        size (int or (int, int), optional): The size of the sparse matrix.
            If given as an integer, will create a quadratic sparse matrix.
            If set to :obj:`None`, will infer a quadratic sparse matrix based
            on :obj:`edge_index.max() + 1`, or an empty :obj:`0 x 0` matrix
            when :obj:`edge_index` holds no edges. (default: :obj:`None`)

    :rtype: :class:`torch.sparse.FloatTensor`

    Example:

        >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
        ...                            [1, 0, 2, 1, 3, 2]])
        >>> to_torch_coo_tensor(edge_index)
        tensor(indices=tensor([[0, 1, 1, 2, 2, 3],
                            [1, 0, 2, 1, 3, 2]]),
            values=tensor([1., 1., 1., 1., 1., 1.]),
            size=(4, 4), nnz=6, layout=torch.sparse_coo)

    """
    if size is None:
        # ``max()`` raises on an empty tensor, so an edge-less graph is
        # treated as having zero nodes.
        size = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    if not isinstance(size, (tuple, list)):
        size = (size, size)

    if edge_attr is None:
        edge_attr = torch.ones(edge_index.size(1), device=edge_index.device)

    # Trailing attribute dimensions become dense dimensions of the sparse tensor.
    size = tuple(size) + edge_attr.size()[1:]
    out = torch.sparse_coo_tensor(edge_index, edge_attr, size,
                                  device=edge_index.device)
    out = out.coalesce()
    return out