# Sign is Not a Remedy: Multiset-to-Multiset Message Passing for Learning on Heterophilic Graphs, ICML 2024
# Source of the model's main code: https://github.com/Jinx-byebye/m2mgnn

import os
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torch.nn.parameter import Parameter
import numpy as np
import scipy.sparse as sp
import copy
import time
from torch import FloatTensor
import torch_sparse
from sklearn.metrics import accuracy_score as ACC
from utils import sys_normalized_adjacency, sparse_mx_to_torch_sparse_tensor
from torch_scatter import scatter
from torch_geometric.utils import remove_self_loops


class M2M2_layer(nn.Module):
    """Multiset-to-multiset message-passing layer used by ``M2MGNN``.

    For every edge ``(row, col)`` the layer computes a soft assignment of the
    edge to ``c`` channels with a temperature-scaled softmax over
    ``att(relu(0.5 * h[row] + h[col]))`` (``h = lin(x)``). The neighbour
    embedding ``h[col]`` is weighted by each channel score, the ``c`` weighted
    copies are concatenated, and the messages are summed into node ``row``.
    Computation is done in float64 and the output is cast to float32.

    Args:
        in_features: width of the input node features.
        nhidden: width of the per-channel projection ``lin``.
        c: number of channels; the output width is ``c * nhidden``.
        dropout: stored on the module; ``forward`` does not use it.
        temperature: softmax temperature for the channel assignment.

    After every forward pass ``self.reg`` holds a scalar regularisation term
    (see ``forward``).
    """

    def __init__(self, in_features, nhidden, c, dropout, temperature=1):
        super(M2M2_layer, self).__init__()

        # Both projections are float64; forward() casts its input to match.
        self.lin = nn.Linear(in_features, nhidden, bias=False).double()
        self.att = nn.Linear(nhidden, c, bias=False).double()
        self.temperature = temperature
        self.c = c
        # NOTE(review): kept for interface compatibility but never applied in
        # forward() -- confirm whether dropout on the edge scores was intended.
        self.dropout = dropout
        self.reg = None
        self.reset_parameters()

    def forward(self, x, edge_index):
        """Aggregate channel-weighted neighbour messages into each node.

        Args:
            x: node features, ``(N, in_features)``.
            edge_index: ``(2, E)`` integer tensor; features are gathered from
                ``edge_index[1]`` and summed into ``edge_index[0]``.

        Returns:
            float32 tensor of shape ``(N, c * nhidden)``; channel ``i`` occupies
            columns ``[i * nhidden, (i + 1) * nhidden)``. Nodes that never
            appear in ``edge_index[0]`` get zeros.

        Side effect:
            sets ``self.reg = sqrt(c) / E * ||sum_e p_e||_2 - 1``, where
            ``p_e`` is the channel distribution of edge ``e``. When ``E == 0``
            it is set to 0 instead (the formula would divide by zero and
            produce NaN, which would poison the training loss).
        """
        h = self.lin(x.double())
        row, col = edge_index
        h_src = h[col]  # gathered once; reused for the scores and the messages
        num_edges = h_src.size(0)

        scores = self.att(F.relu(0.5 * h[row] + h_src))
        bin_rela = F.softmax(scores / self.temperature, dim=1)

        if num_edges == 0:
            self.reg = h.new_zeros(())
        else:
            # torch.norm(p=2) on a 1-D tensor equals torch.linalg.vector_norm(., 2);
            # torch.norm is kept for compatibility with older torch releases.
            self.reg = (
                float(np.sqrt(self.c))
                / num_edges
                * torch.norm(bin_rela.sum(dim=0), p=2)
                - 1
            )

        # (E, c, 1) * (E, 1, nhidden) -> (E, c, nhidden) -> (E, c * nhidden):
        # identical to concatenating h_src * bin_rela[:, i] for i in range(c).
        # The explicit row count keeps the reshape valid when E == 0.
        messages = (bin_rela.unsqueeze(2) * h_src.unsqueeze(1)).reshape(
            num_edges, self.c * h.size(1)
        )
        # Sum-scatter of every message into its target node.
        out = messages.new_zeros(h.size(0), messages.size(1)).index_add(
            0, row, messages
        )
        return out.float()

    def reset_parameters(self):
        """Re-initialise both projections with Xavier-uniform weights."""
        init.xavier_uniform_(self.lin.weight)
        init.xavier_uniform_(self.att.weight)


class M2MGNN(nn.Module):
    """M2M-GNN node classifier with a built-in full-batch training loop.

    Pipeline: optional input dropout (``dropout2``) -> ``lin1`` -> ReLU ->
    LayerNorm -> dropout, giving the initial embedding ``ego``; then
    ``nlayers`` blocks of M2M2 layer -> ReLU -> LayerNorm -> dropout, each
    mixed with ``ego`` as ``(1 - beta) * x + beta * ego``; finally ``lin2`` and
    ``log_softmax``. The per-layer regularisers are averaged into ``self.reg``
    and added to the loss with weight ``lamda`` during ``fit``.

    Args:
        in_features: input feature width.
        class_num: number of output classes.
        device: device that ``fit``/``predict`` move the graph and model to.
        args: namespace providing ``epochs``, ``patience``, ``lr``,
            ``l2_coef_1``, ``l2_coef_2``, ``lamda``, ``dropout``,
            ``dropout2``, ``nlayers``, ``beta``, ``nhidden``, ``c``,
            ``temperature`` and ``remove_self_loop``.
    """

    def __init__(
        self,
        in_features,
        class_num,
        device,
        args,
    ):
        super(M2MGNN, self).__init__()
        # ------------- Parameters ----------------
        self.device = device
        self.epochs = args.epochs
        self.patience = args.patience
        self.lr = args.lr
        self.l2_coef_1 = args.l2_coef_1
        self.l2_coef_2 = args.l2_coef_2
        self.lamda = args.lamda
        self.dropout = args.dropout
        self.dropout2 = args.dropout2
        self.nlayers = args.nlayers
        self.beta = args.beta
        self.reg = None

        # ---------------- Model -------------------
        # Construction order is significant: it fixes RNG consumption at init
        # and the ordering of the state_dict / parameter lists.
        width = args.nhidden * args.c
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.lin1 = nn.Linear(in_features, width)
        self.lin2 = nn.Linear(width, class_num)
        self.norms.append(nn.LayerNorm(width))
        self.remove_self_loop = args.remove_self_loop
        for _ in range(args.nlayers):
            self.convs.append(
                M2M2_layer(
                    width,
                    args.nhidden,
                    args.c,
                    args.dropout,
                    args.temperature,
                )
            )
            self.norms.append(nn.LayerNorm(width))
        # Two optimiser groups with separate weight decay (see fit()).
        self.params1 = list(self.lin2.parameters()) + list(self.lin1.parameters())
        self.params2 = list(self.convs.parameters()) + list(self.norms.parameters())
        self.reset_parameters()

    def set_precision(self):
        """Disable TF32 for CUDA matmuls and cuDNN (process-wide flags)."""
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False

    def forward(self, x, edge_index, return_Z=False):
        """Return per-node class log-probabilities.

        Args:
            x: node features, ``(N, in_features)``.
            edge_index: ``(2, E)`` long tensor of edges.
            return_Z: when true, also return the final hidden representation.

        Returns:
            ``log_probs`` of shape ``(N, class_num)``, or ``(Z, log_probs)``
            when ``return_Z`` is true.

        Side effect:
            sets ``self.reg`` to the mean of the layers' ``reg`` values, or to
            0 when the model has no message-passing layers.
        """
        self.set_precision()
        if self.remove_self_loop == True:
            edge_index, _ = remove_self_loops(edge_index)
        self.reg = 0
        if self.dropout2 != 0:
            x = F.dropout(x, p=self.dropout2, training=self.training)
        x = F.relu(self.lin1(x))
        x = self.norms[0](x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        ego = x

        # norms[0] belongs to lin1; norms[i + 1] follows convs[i].
        for conv, norm in zip(self.convs, self.norms[1:]):
            x = F.relu(conv(x, edge_index))
            x = norm(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = (1 - self.beta) * x + self.beta * ego  # residual mix with ego
            self.reg = self.reg + conv.reg

        num_layers = len(self.convs)
        if num_layers:  # avoid ZeroDivisionError when nlayers == 0
            self.reg = self.reg / num_layers
        log_probs = F.log_softmax(self.lin2(x), dim=1)
        if return_Z:
            return x, log_probs
        return log_probs

    def reset_parameters(self):
        """Re-initialise lin1/lin2 (Xavier weights, zero biases) and all norms.

        The M2M2 layers are initialised by their own constructors and are not
        re-initialised here.
        """
        init.xavier_uniform_(self.lin1.weight)
        init.xavier_uniform_(self.lin2.weight)
        init.zeros_(self.lin1.bias)
        init.zeros_(self.lin2.bias)
        for norm in self.norms:
            norm.reset_parameters()

    def _edge_index_from_graph(self, graph):
        """Return the nonzero (row, col) indices of ``graph.adj(scipy_fmt='csr')``
        as a ``(2, E)`` long tensor on ``self.device``."""
        adj = graph.adj(scipy_fmt='csr')
        return torch.tensor(
            np.array(adj.nonzero()), device=self.device, dtype=torch.long
        )

    def fit(self, graph, labels, train_mask, val_mask, test_mask):
        """Train full-batch with Adam and early stopping on validation accuracy.

        The weights with the best validation accuracy are restored at the end
        (if any epoch improved on 0.0), and that epoch is stored in
        ``self.best_epoch``.
        """
        graph = graph.to(self.device)
        labels = labels.to(self.device)
        self.train_mask = train_mask.to(self.device)
        self.valid_mask = val_mask.to(self.device)
        self.test_mask = test_mask.to(self.device)
        self.to(self.device)
        X = graph.ndata["feat"]
        edge_index = self._edge_index_from_graph(graph)

        optimizer = optim.Adam(
            [
                {'params': self.params1, 'weight_decay': self.l2_coef_1},
                {'params': self.params2, 'weight_decay': self.l2_coef_2},
            ],
            lr=self.lr,
            eps=1e-15
        )
        best_epoch = 0
        best_acc = 0.0
        cnt = 0
        best_state_dict = None
        for epoch in range(self.epochs):
            self.train()
            optimizer.zero_grad()
            output = self.forward(X, edge_index)
            loss = (
                F.nll_loss(output[self.train_mask], labels[self.train_mask])
                + self.lamda * self.reg
            )
            loss.backward()
            optimizer.step()

            [train_acc, valid_acc, test_acc] = self.test(
                X,
                edge_index,
                labels,
                [self.train_mask, self.valid_mask, self.test_mask],
            )

            if valid_acc > best_acc:
                cnt = 0
                best_acc = valid_acc
                best_epoch = epoch
                best_state_dict = copy.deepcopy(self.state_dict())
                print(f'\nEpoch:{epoch}, Loss:{loss.item()}')
                print(
                    f'train acc: {train_acc:.3f} valid acc: {valid_acc:.3f}, test acc: {test_acc:.3f}'
                )
            else:
                cnt += 1
                if cnt == self.patience:
                    print(
                        f"Early Stopping! Best Epoch: {best_epoch}, best val acc: {best_acc}"
                    )
                    break

        # Without this guard, load_state_dict(None) crashes when no epoch
        # improved validation accuracy (e.g. epochs == 0 or accuracy stuck at 0).
        if best_state_dict is not None:
            self.load_state_dict(best_state_dict)
        else:
            print("No epoch improved validation accuracy; keeping final weights.")
        self.best_epoch = best_epoch

    def test(self, X, edge_index, labels, index_list):
        """Return the accuracy on each mask/index in ``index_list`` (eval mode)."""
        self.eval()
        with torch.no_grad():
            Z = self.forward(X, edge_index)
            y_pred = torch.argmax(Z, dim=1)
        return [ACC(labels[index].cpu(), y_pred[index].cpu()) for index in index_list]

    def predict(self, graph):
        """Return ``(y_pred, log_probs, Z)`` for every node, all on the CPU."""
        self.eval()
        graph = graph.to(self.device)
        X = graph.ndata['feat']
        edge_index = self._edge_index_from_graph(graph)

        with torch.no_grad():
            Z, C = self.forward(X, edge_index, return_Z=True)
            y_pred = torch.argmax(C, dim=1)

        return y_pred.cpu(), C.cpu(), Z.cpu()
