# Simplifying Node Classification on Heterophilous Graphs with Compatible Label Propagation, TMLR 2022
# The source of model's main code: https://github.com/zhiqiangzhongddu/TMLR-CLP

import os
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn.parameter import Parameter
import math
import torch.optim as optim
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch import Tensor
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.data import Data
from typing import Callable, Optional
import torch_geometric.transforms as T
from torch_sparse import SparseTensor, matmul
import copy
import time
from sklearn.metrics import accuracy_score as ACC
from utils import row_normalized_adjacency, sparse_mx_to_torch_sparse_tensor


class CLP(nn.Module):
    """Compatible Label Propagation (CLP) node classifier.

    ``fit`` (1) obtains a soft prior prediction for every node from an MLP on
    the node features (cached on disk under ``./save/res/CLP/<dataset>/``),
    (2) estimates a class-compatibility matrix from the training labels and
    that prior, and (3) runs compatibility-weighted label propagation from
    several seedings, storing the selected propagated scores in
    ``self.C_pred``. ``predict`` then reads ``self.C_pred``.
    """

    def __init__(
            self,
            in_features: int,
            class_num: int,
            device,
            args,
        ) -> None:
        """Build the prior MLP and the propagation module.

        Args:
            in_features: node feature dimension.
            class_num: number of classes (MLP output width).
            device: torch device the MLP and all tensors in ``fit`` are moved to.
            args: config namespace; reads ``alpha``, ``diff_post_step``,
                ``echo_set``, ``select_eval``, ``weight_type``, ``dataset``
                and ``run``.

        Side effect: creates the directory ``./save/res/CLP/<dataset>/``.
        """
        super().__init__()
        #------------- Parameters ----------------
        self.device = device

        self.alpha = args.alpha
        self.diff_post_step = args.diff_post_step
        self.echo_set = args.echo_set
        self.select_eval = args.select_eval
        self.weight_type = args.weight_type

        os.makedirs(f'./save/res/CLP/{args.dataset}/', exist_ok=True)
        # The MLP prior prediction is cached here, keyed by (dataset, run).
        self.res_path = f'./save/res/CLP/{args.dataset}/{args.run}.pt'
        #---------------- Layer -------------------
        self.model_mlp = Basic_MLP(in_features, class_num, hid_dim=64, dropout=0.5).to(device)

        self.model = LabelPropagationHeterophilyTog(max_layers=20, num_hops=1)

    def fit(self, graph, labels, train_mask, val_mask, test_mask):
        """Fit the prior MLP and run compatible label propagation on ``graph``.

        Args:
            graph: DGL graph with node features in ``graph.ndata["feat"]``.
            labels: node labels (one integer per node).
            train_mask: nodes whose true labels seed the propagation and the
                compatibility estimate.
            val_mask: nodes used to select the best propagation result.
            test_mask: nodes used only for reporting.

        The selected propagated class scores are stored in ``self.C_pred``.
        Only training labels are used to build the propagation weights.
        """
        # model init
        graph = graph.to(self.device)
        labels = labels.to(self.device)
        self.train_mask = train_mask.to(self.device)
        self.valid_mask = val_mask.to(self.device)
        self.test_mask = test_mask.to(self.device)
        self.to(self.device)
        X = graph.ndata["feat"]
        adj = graph.adj(scipy_fmt='csr')
        edge_index = torch.tensor(np.array(adj.nonzero()), device=self.device, dtype=torch.long)

        data = Data(x=X, y=labels, edge_index=edge_index)
        to_sparse = T.ToSparseTensor(remove_edge_index=False)
        data = to_sparse(data)

        adj_t = data.adj_t

        best_pred = self.init_prediction(X, labels)
        prior_estim = best_pred.clone()

        # Degree-based per-edge scaling, read off the sparse storage.
        deg = adj_t.sum(dim=1).to(torch.float)
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        # NOTE(review): both factors scale rows, so this is D^-1 A rather than
        # the symmetric D^-1/2 A D^-1/2; kept as in the reference code -- confirm intent.
        DA = deg_inv_sqrt.view(-1, 1) * deg_inv_sqrt.view(-1, 1) * adj_t
        DA = DA.storage.value()

        gcn_weight = gcn_norm(adj_t, add_self_loops=False).storage.value()

        A = adj_t
        # H: compatibility estimate; B: prior with training rows replaced by
        # one-hot true labels. Only training labels are consulted here.
        H, B = get_myo_h(
            A, labels, prior_estim,
            mask=self.train_mask
        )
        A_row, A_col, _ = A.coo()
        # Per-edge, per-class propagation weights: (B[col] @ H) * B[row].
        cm_weight = torch.matmul(B[A_col], H) * B[A_row]
        # "Echo" weights use the two-step compatibility H @ H, scaled by DA.
        echo_H = torch.matmul(H, H)
        echo_value = torch.matmul(B[A_col], echo_H) * B[A_row]
        echo_value = echo_value * DA.view(-1, 1)

        y_soft = best_pred.clone()
        edge_weight = cm_weight if self.weight_type==0 else cm_weight * gcn_weight.view(-1, 1)
        echo_weight = echo_value if self.echo_set is True else None
        if self.diff_post_step:
            post_step=lambda y: F.softmax(y, dim=-1)
        else:
            post_step=lambda y: y.clamp_(0., 1.)

        idx, best_epoch, val_acc, test_acc, output = self.model(
            y_true=labels, y_soft=y_soft, alpha=self.alpha,
            spread_mask=self.train_mask, eval_mask=self.valid_mask, test_mask=self.test_mask,
            adj=adj_t, edge_weight=edge_weight, echo_weight=echo_weight,
            verbose=False, post_step=post_step, select_eval=self.select_eval,
        )
        print('Method: {}, Best epoch: {} -- Best Val: {:.4f}, Best Test: {:.4f}'.
                format(idx+1, best_epoch, val_acc, test_acc))
        self.C_pred = output

    def init_prediction(self, X, labels):
        """Return the MLP's softmax prediction for every node (the propagation prior).

        If ``self.res_path`` exists, its cached ``best_pred`` is returned without
        training. Otherwise the MLP is trained with Adam on the training nodes
        for up to 499 epochs, stopping early after 100 epochs without a
        validation improvement. The prediction from the epoch with the best
        validation accuracy is saved to ``self.res_path`` and returned.
        """
        if os.path.exists(self.res_path):
            # NOTE(review): the cache is keyed only by (dataset, run); a file
            # produced with a different split/config is reused silently.
            res = torch.load(self.res_path)
            best_pred = res['best_pred']
            return best_pred.to(self.device)

        optimizer = torch.optim.Adam(params=self.model_mlp.parameters(), lr=0.01, weight_decay=5e-5)

        best_train_acc = best_test_acc = 0
        # Start below any achievable accuracy so the first epoch is always
        # recorded; starting at 0 left ``best_pred`` unbound (UnboundLocalError)
        # whenever validation accuracy never rose above 0.
        best_val_acc = -1.0
        best_epoch = 0
        best_pred = None
        criterion = nn.CrossEntropyLoss()
        for epoch in range(1, 500):
            self.model_mlp.train()
            optimizer.zero_grad()
            out = self.model_mlp(X)
            loss = criterion(out[self.train_mask], labels[self.train_mask])

            loss.backward()
            optimizer.step()
            [train_acc, val_acc, test_acc], pred = self.test_MLP(X, labels, [self.train_mask, self.valid_mask, self.test_mask])
            if val_acc > best_val_acc:
                best_epoch = epoch
                best_train_acc = train_acc
                best_val_acc = val_acc
                best_test_acc = test_acc
                best_pred = pred
            # Early stopping: 100 epochs without validation improvement.
            if epoch - best_epoch == 100:
                break
        print('Best Train: {:.4f}, Best Val: {:.4f}, Best Test: {:.4f}'.
            format(best_train_acc, best_val_acc, best_test_acc))
        res = {
            'best_pred': best_pred.cpu(),
        }
        torch.save(res, self.res_path)
        return best_pred

    def test_MLP(self, X, labels, index_list):
        """Evaluate the MLP without gradients.

        Returns:
            ``(acc_list, probs)``: one accuracy per entry of ``index_list``
            (computed on CPU), and the softmax class probabilities for all
            nodes.
        """
        self.model_mlp.eval()
        with torch.no_grad():
            C = self.model_mlp(X)
            logits = F.softmax(C, dim=1)
            y_pred = torch.argmax(logits, dim=1)
        acc_list = []
        for index in index_list:
            acc_list.append(ACC(labels[index].cpu(), y_pred[index].cpu()))
        return acc_list, logits

    def predict(self, graph):
        """Return ``(labels, scores, None)`` on CPU from the last ``fit`` call.

        ``graph`` is unused; predictions come from ``self.C_pred``.
        """
        self.eval()
        y_pred = torch.argmax(self.C_pred, dim=1)

        return y_pred.cpu(), self.C_pred.cpu(), None



class Basic_MLP(nn.Module):
    """Plain multi-layer perceptron (used by CLP for its prior predictions).

    Hidden layers apply ReLU followed by dropout; the last layer returns raw
    scores with no activation.

    Args:
        in_dim: input feature dimension.
        out_dim: output dimension.
        hid_dim: width of every hidden layer.
        num_layers: total number of linear layers; must be >= 1. With a single
            layer the model is just ``self.lin``.
        dropout: dropout probability applied after each hidden ReLU.

    Raises:
        ValueError: if ``num_layers`` < 1. Such a model would have no layer to
            apply, and previously failed only later inside ``forward``.
    """

    def __init__(self, in_dim, out_dim, hid_dim=64, num_layers=3, dropout=0.5):
        super(Basic_MLP, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.dropout = dropout
        self.num_layers = num_layers

        # Attribute names (``conv_layers`` / ``lin``) keep the parameter names
        # in state_dict unchanged.
        self.conv_layers = nn.ModuleList()
        if num_layers == 1:
            self.lin = nn.Linear(self.in_dim, self.out_dim)
        else:
            # Widths: in_dim -> hid_dim -> ... -> hid_dim -> out_dim,
            # giving exactly num_layers linear layers, created first to last.
            widths = [self.in_dim] + [hid_dim] * (num_layers - 1) + [self.out_dim]
            for fan_in, fan_out in zip(widths[:-1], widths[1:]):
                self.conv_layers.append(nn.Linear(fan_in, fan_out))

    def forward(self, h):
        """Map node features ``h`` to output scores (last layer is linear)."""
        if self.num_layers == 1:
            return self.lin(h)
        for layer in self.conv_layers[:-1]:
            h = F.relu(layer(h))
            h = F.dropout(h, p=self.dropout, training=self.training)
        return self.conv_layers[-1](h)


def makeDoubleStochastic(h, max_iterations=1000, delta_limit=1e-12):
    """Sinkhorn-normalise a 2-D matrix towards a doubly stochastic one.

    Each round first rescales the columns to sum to one, then the rows. The
    loop stops once the matrix 1-norm (max absolute column sum) of the change
    made by a round drops below ``delta_limit``, or after ``max_iterations``
    rounds.

    The input tensor is not modified; a new tensor is returned. (Previously
    the caller's tensor was divided in place.) Non-floating inputs are
    promoted by true division instead of raising.

    Args:
        h: 2-D tensor to normalise. Zero row/column sums produce NaN/inf,
            which the callers in this file avoid by making entries positive.
        max_iterations: upper bound on normalisation rounds.
        delta_limit: convergence threshold on the per-round change.

    Returns:
        The normalised matrix. Rows sum to one after the final round;
        columns sum to one approximately.
    """
    for _ in range(max_iterations):
        prev_h = h
        h = h / h.sum(0, keepdim=True)
        h = h / h.sum(1, keepdim=True)
        if torch.linalg.norm(h - prev_h, ord=1) < delta_limit:
            break
    return h


def get_myo_h(A, y_true, prior_estimation, mask) -> [Tensor, Tensor]:
    """Estimate the class-compatibility matrix used by CLP.

    Args:
        A: adjacency as a ``torch_sparse.SparseTensor``. Any edge values are
            ignored: an unweighted copy is used, and the caller's tensor is
            left untouched (it used to be cleared in place).
        y_true: integer node labels (flattened with ``view(-1)``).
        prior_estimation: soft prior predictions, one row per node and one
            column per class (presumably (num_nodes, num_classes) -- confirm).
        mask: selects the nodes whose true labels may be used (the training
            nodes in ``CLP.fit``).

    Returns:
        ``(H, B)``.
        ``B``: ``prior_estimation`` with the masked rows replaced by one-hot
        true labels.
        ``H``: ``H[k, l] = sum_{i,j} A[i, j] * Y[j, k] * B[i, l]``, where ``Y``
        holds one-hot labels on masked nodes and zeros elsewhere. Zero entries
        are set to 1e-7 and the result is Sinkhorn-normalised.
    """
    if A.has_value():
        # Out-of-place: drop edge values without mutating the caller's tensor.
        A = A.set_value(None)

    # Tie the one-hot width to the prior's class dimension so B and Y line up
    # even if the highest class index is absent from y_true.
    one_hot = F.one_hot(y_true.view(-1), num_classes=prior_estimation.size(-1)).float()

    B = prior_estimation.clone()
    B[mask] = one_hot[mask]

    # Only labels of masked nodes are visible; all other rows stay zero.
    Y = torch.zeros_like(one_hot)
    Y[mask] = one_hot[mask]
    H = torch.matmul(matmul(A, Y).transpose(0, 1), B)
    # Keep entries strictly positive (inputs are non-negative here) so the
    # Sinkhorn row/column sums never divide by zero.
    H[H == 0] = 1e-7
    H = makeDoubleStochastic(H)

    return H, B


def epoch_eval(y_true, out, eval_mask, test_mask, epoch, verbose):
    """Score one propagation round on the evaluation and test nodes.

    The predicted class of each node is the arg-max of its row in ``out``.
    Accuracy on a mask is the number of correct masked predictions divided
    by ``mask.sum()``, so boolean masks are expected.

    Returns:
        ``(eval_accuracy, test_accuracy)`` as Python floats.
    """
    predicted = out.max(dim=1).indices

    def _masked_accuracy(mask):
        correct = predicted[mask].eq(y_true[mask]).sum().item()
        return correct / mask.sum().item()

    acc_eval = _masked_accuracy(eval_mask)
    acc_test = _masked_accuracy(test_mask)
    if verbose:
        message = 'Layer {}, Evaluation acc: {:.4f}, Test acc: {:.4f}'.format(epoch, acc_eval, acc_test)
        print(message)
    return acc_eval, acc_test


class LabelPropagationHeterophily(MessagePassing):
    """Label propagation with (optionally per-class) edge weights and an echo term.

    Each of ``max_layers`` rounds does the following:
      1. Propagate the current belief ``num_hops`` times over ``adj`` weighted
         by ``edge_weight``.
      2. From the second round on, if ``echo_weight`` is given, subtract the
         previous belief propagated once with ``echo_weight``.
      3. L1-normalise each row.
      4. Mix with the initial belief as ``alpha * out + (1 - alpha) * y_soft``.
      5. Apply ``post_step``.
    """
    def __init__(self, max_layers: int, num_hops: int = 1):
        # Sum aggregation; self.aggr is also the reduce op in message_and_aggregate.
        super(LabelPropagationHeterophily, self).__init__(aggr='add')
        self.max_layers = max_layers
        self.num_hops = num_hops

    def forward(
            self, y_true: Tensor, y_soft: Tensor, alpha: [Tensor, float],
            eval_mask: Tensor, test_mask: Tensor, 
            adj: Adj, edge_weight: OptTensor = None, echo_weight: OptTensor = None,
            select_eval: bool = True,
            verbose: bool = True, post_step: Callable = lambda y: y.clamp_(0., 1.)
    ) -> (int, Tensor):
        """Propagate ``y_soft`` over ``adj`` and score the result.

        Args:
            y_true: node labels, used only for accuracy via ``epoch_eval``.
            y_soft: initial belief; it is cloned and not modified.
            alpha: weight of the propagated belief versus the initial one.
            eval_mask, test_mask: node selectors passed to ``epoch_eval``.
            adj: adjacency as a SparseTensor; its values are overwritten
                (out-of-place) with ``edge_weight`` / ``echo_weight``.
            edge_weight: per-edge weights, 1-D or 2-D (one column per belief
                column). Must not be None: message_and_aggregate asserts that
                the adjacency carries values.
            echo_weight: optional weights for the echo correction.
            select_eval: if True, track the round with the best eval accuracy.
            verbose: print per-round accuracies.
            post_step: applied to the belief after each round.

        Returns:
            If ``select_eval``: ``(best_epoch, best_eval, best_test, best_out)``.
            When no round beats an eval accuracy of 0, this is
            ``(0, 0, 0, copy of y_soft)``.
            Otherwise: ``(-1, eval_acc, test_acc, out)`` for the final round.
            NOTE(review): with ``max_layers < 1`` that branch references an
            unbound ``epoch``.
        """
        # init
        best_eval = best_test = best_epoch = 0
        best_out = out = y_soft.clone()
        # Fixed residual: the (1 - alpha) share of the initial belief.
        res = (1 - alpha) * out

        for epoch in range(1, self.max_layers+1):
            # Re-attach edge_weight each round, since the echo branch below
            # replaces adj's values with echo_weight.
            adj = adj.set_value(value=edge_weight)
            layer_out = out.clone()
            for _ in range(self.num_hops):
                # The edge_weight kwarg is not read by message_and_aggregate,
                # which takes weights from adj's storage.
                layer_out = self.propagate(
                    edge_index=adj, x=layer_out, edge_weight=edge_weight, size=None
                )
            if epoch > 1 and echo_weight is not None:
                # Echo term: previous round's belief propagated once with
                # echo_weight, then subtracted.
                adj = adj.set_value(value=echo_weight)
                echo = self.propagate(
                    edge_index=adj, x=out, edge_weight=echo_weight, size=None
                )
                out = layer_out - echo
            else:
                out = layer_out
            # Row-wise L1 normalisation, then alpha-mix with the initial belief.
            out = F.normalize(out, p=1, dim=-1)
            out.mul_(alpha).add_(res)
            out = post_step(out)

            if select_eval:
                acc_eval, acc_test = epoch_eval(
                    y_true=y_true, out=out, eval_mask=eval_mask, test_mask=test_mask, epoch=epoch, verbose=verbose
                )
                # save best states (strict improvement: earliest round wins ties)
                if acc_eval > best_eval:
                    best_epoch, best_eval, best_test, best_out = epoch, acc_eval, acc_test, out.clone()
                    # if verbose:
                    #     print('Update edge weight!')
                    # # TODO: speed up this calculation
                    # edge_weight = get_edge_weight(y_true, best_out, adj, spread_mask)

        if select_eval:
            return best_epoch, best_eval, best_test, best_out
        else:
            acc_eval, acc_test = epoch_eval(
                y_true=y_true, out=out, eval_mask=eval_mask, test_mask=test_mask, epoch=epoch,
                verbose=verbose
            )
            return -1, acc_eval, acc_test, out
        # return best_epoch, best_eval, best_test, best_out, -1, acc_eval, acc_test, out

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        """Weighted sparse aggregation ``adj_t @ x`` using adj_t's stored values.

        1-D values: a single sparse matmul.
        2-D values: column ``c`` of the values weights the propagation of
        column ``c`` of ``x``; the per-column results are concatenated on the
        last dim. This presumably expects as many value columns as ``x`` has
        columns -- confirm with callers.
        NOTE(review): values of any other rank fall through and return None.
        """
        assert adj_t.has_value()
        edge_weight = adj_t.storage.value()
        if len(edge_weight.size()) == 1:
            # adj_t.set_value_(value=edge_weight)
            return matmul(adj_t, x, reduce=self.aggr)

        elif len(edge_weight.size()) == 2:
            res = []
            for idx in range(edge_weight.size(1)):
                # Out-of-place: swap in this column's weights for one matmul.
                adj_t = adj_t.set_value(edge_weight[:, idx])
                res.append(matmul(adj_t, x[:, idx].view(-1, 1), reduce=self.aggr))
            return torch.cat(res, dim=-1)

    def __repr__(self):
        return '{}(max_layers={})'.format(self.__class__.__name__, self.max_layers)


class LabelPropagationHeterophilyTog(MessagePassing):
    """Run LabelPropagationHeterophily from four seedings and keep the best.

    The seedings, in order, are:
      1. the prior belief ``y_soft`` on every node;
      2. ``y_soft`` with spread nodes overwritten by their one-hot true labels;
      3. one-hot true labels on spread nodes, zeros elsewhere (plain LP);
      4. ``y_soft`` on spread nodes only, zeros elsewhere.
    The run with the highest evaluation accuracy wins (the earliest seeding on
    ties).
    """

    def __init__(self, max_layers: int, num_hops: int = 1):
        super().__init__(aggr='add')
        self.max_layers = max_layers
        self.num_hops = num_hops
        self.LPHete = LabelPropagationHeterophily(max_layers=max_layers, num_hops=num_hops)

    def forward(
            self, y_true: Tensor, y_soft: Tensor, alpha: [Tensor, float],
            spread_mask: Tensor, eval_mask: Tensor, test_mask: Tensor,
            adj: Adj, edge_weight: OptTensor = None, echo_weight: OptTensor = None,
            select_eval: bool = True,
            verbose: bool = True, post_step: Callable = lambda y: y.clamp_(0., 1.),
    ) -> (int, Tensor):
        """Propagate every seeding and return the best one.

        Returns:
            ``(idx, best_epoch, best_eval, best_test, best_out)``, where ``idx``
            is the 0-based index of the winning seeding and the rest is that
            seeding's result from ``LabelPropagationHeterophily``.
        """
        assert y_true.unique().size(0) == y_soft.size(1)

        true_one_hot = F.one_hot(y_true.view(-1)).float()

        prior_everywhere = y_soft.clone()

        prior_with_labels = y_soft.clone()
        prior_with_labels[spread_mask] = true_one_hot[spread_mask]

        labels_only = torch.zeros_like(true_one_hot)
        labels_only[spread_mask] = true_one_hot[spread_mask]

        prior_on_spread = torch.zeros_like(true_one_hot)
        prior_on_spread[spread_mask] = y_soft[spread_mask]

        seedings = (prior_everywhere, prior_with_labels, labels_only, prior_on_spread)
        titles = ('Method 1', 'Method 2', 'Method 3', 'Method 4')

        runs = []
        for title, seed in zip(titles, seedings):
            if verbose:
                print(title)
            runs.append(self.LPHete(
                y_true=y_true, y_soft=seed, alpha=alpha,
                eval_mask=eval_mask, test_mask=test_mask,
                adj=adj, edge_weight=edge_weight, echo_weight=echo_weight,
                verbose=verbose, post_step=post_step, select_eval=select_eval,
            ))

        # Each run is (epoch, eval_acc, test_acc, out); list.index picks the
        # first seeding reaching the maximum eval accuracy.
        eval_scores = [run[1] for run in runs]
        winner = eval_scores.index(max(eval_scores))
        best_epoch, best_eval, best_test, best_out = runs[winner]

        return winner, best_epoch, best_eval, best_test, best_out

    def __repr__(self):
        return f'{type(self).__name__}(max_layers={self.max_layers})'
