# Graph Neural Networks with Heterophily, AAAI 2021
# A PyTorch reproduction.

import copy
import math
import os
import random
import time
import warnings

import dgl
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score as ACC
from torch.nn import Parameter
from torch.nn import init

from utils import sys_normalized_adjacency, sparse_mx_to_torch_sparse_tensor


class CPGNN(nn.Module):
    """CPGNN: compatibility-guided propagation on top of a backbone model.

    Training (see ``fit``) has two stages:

    1. ``pretrain`` trains the backbone (MLP or GCN) alone and returns its
       per-node class probabilities (the prior beliefs ``B``).
    2. ``init_H`` builds a learnable ``class_num x class_num`` matrix ``H``
       from those priors, and the whole model is trained with ``K`` rounds of
       ``B <- B_0 + A_prop @ B @ H`` (``forward``) plus ``eta * reg_loss()``.

    Hyper-parameters read from ``args``: ``lr``, ``l2_coef``, ``epochs``,
    ``patience``, ``backbone_type`` ('MLP' or 'GCN'), ``eta``, ``K``,
    ``nhidden`` and ``nlayers``.
    """

    def __init__(
            self,
            in_features: int,
            class_num: int,
            device,
            args,
        ) -> None:
        super().__init__()
        #------------- Parameters ----------------
        self.class_num = class_num
        self.device = device
        self.lr = args.lr
        self.l2_coef = args.l2_coef
        self.epochs = args.epochs
        self.patience = args.patience

        self.backbone_type = args.backbone_type
        self.eta = args.eta
        self.K = args.K

        #---------------- Layer -------------------
        if self.backbone_type == 'MLP':
            self.backbone = MLP_(in_features, class_num, args.nhidden, args.nlayers)
        elif self.backbone_type == 'GCN':
            self.backbone = GCN_(in_features, class_num, args.nhidden, args.nlayers)
        else:
            # Fail fast. Otherwise the missing ``self.backbone`` only surfaces
            # later as an opaque AttributeError inside ``fit``.
            raise ValueError(
                f"Unsupported backbone_type {self.backbone_type!r}; "
                "expected 'MLP' or 'GCN'.")

    @staticmethod
    def _prob_nll_loss(probs, target, eps=1e-12):
        """Cross-entropy for inputs that are already class probabilities.

        ``torch.nn.CrossEntropyLoss`` expects unnormalised logits and applies
        its own log-softmax. Feeding it the softmax outputs of ``forward`` /
        ``backbone_forward`` would softmax twice and squash the gradients, so
        take the log of the probabilities directly. ``eps`` guards against
        ``log(0)``.
        """
        return F.nll_loss(torch.log(probs.clamp_min(eps)), target)

    def fit(self, graph, labels, train_mask, val_mask, test_mask):
        """Pretrain the backbone, initialise ``H``, then train the full model.

        Uses early stopping on validation accuracy (``self.patience`` epochs
        without improvement) and restores the best-validation parameters.

        Args:
            graph: DGL graph whose ``ndata["feat"]`` holds the node features.
            labels: integer class label per node.
            train_mask, val_mask, test_mask: node selectors used to index
                ``labels`` and the model output.
        """
        # model init
        graph = graph.to(self.device)
        labels = labels.to(self.device)
        self.train_mask = train_mask.to(self.device)
        self.valid_mask = val_mask.to(self.device)
        self.test_mask = test_mask.to(self.device)
        self.to(self.device)

        # Adjacency with self-loops, normalised by utils.sys_normalized_adjacency
        # (presumably D^-1/2 A D^-1/2). Used by the GCN backbone and init_H.
        graph = graph.remove_self_loop().add_self_loop()
        adj = graph.adj(scipy_fmt='csr')
        adj_norm = sys_normalized_adjacency(adj)
        self.A = sparse_mx_to_torch_sparse_tensor(adj_norm).to(self.device)
        X = graph.ndata["feat"]

        # Unnormalised adjacency without self-loops, used for the compatibility
        # propagation in forward().
        graph_ = graph.remove_self_loop()
        adj_ = graph_.adj(scipy_fmt='csr')
        self.A_prop = sparse_mx_to_torch_sparse_tensor(adj_).to(self.device)

        B = self.pretrain(X, labels)
        # init_H registers self.H, so the optimizer must be built afterwards.
        self.init_H(B, labels)

        best_epoch = 0
        best_acc = 0.
        cnt = 0
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.l2_coef)
        best_state_dict = None

        for epoch in range(self.epochs):
            self.train()

            Z = self.forward(X)
            loss = (self._prob_nll_loss(Z[self.train_mask], labels[self.train_mask])
                    + self.eta * self.reg_loss())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            [train_acc, valid_acc, test_acc] = self.test(X, labels, [self.train_mask, self.valid_mask, self.test_mask])

            if valid_acc > best_acc:
                cnt = 0
                best_acc = valid_acc
                best_epoch = epoch
                best_state_dict = copy.deepcopy(self.state_dict())
                print(f'\nEpoch:{epoch}, Loss:{loss.item()}')
                print(f'train acc: {train_acc:.3f} valid acc: {valid_acc:.3f}, test acc: {test_acc:.3f}')

            else:
                cnt += 1
                if cnt == self.patience:
                    print(f"Early Stopping! Best Epoch: {best_epoch}, best val acc: {best_acc}")
                    break
        # If validation accuracy never rose above 0 (or epochs == 0) there is
        # no snapshot; keep the current parameters instead of crashing.
        if best_state_dict is not None:
            self.load_state_dict(best_state_dict)
        self.best_epoch = best_epoch

    def pretrain(self, X, labels, epochs=200):
        """Train the backbone alone and return its best prior beliefs.

        Runs ``epochs`` full-batch steps on the training nodes and restores the
        backbone weights from the epoch with the best validation accuracy.

        Returns:
            The eval-mode class-probability matrix (one row per node) of the
            restored backbone.
        """
        print("pretraining...")
        optimizer = torch.optim.Adam(self.backbone.parameters(), lr=self.lr, weight_decay=self.l2_coef)
        best_acc = 0.
        best_state_dict = None
        best_B = None

        for epoch in range(epochs):
            self.backbone.train()

            B = self.backbone_forward(X)
            loss = self._prob_nll_loss(B[self.train_mask], labels[self.train_mask])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            [train_acc, valid_acc, test_acc], B_ = self.backbone_test(X, labels, [self.train_mask, self.valid_mask, self.test_mask])

            if valid_acc > best_acc:
                best_acc = valid_acc
                best_B = B_
                best_state_dict = copy.deepcopy(self.backbone.state_dict())
                print(f'\nEpoch:{epoch}, Loss:{loss.item()}')
                print(f'train acc: {train_acc:.3f} valid acc: {valid_acc:.3f}, test acc: {test_acc:.3f}')

        if best_state_dict is None:
            # Validation accuracy never improved: fall back to the final
            # backbone instead of returning None (which would break init_H).
            _, best_B = self.backbone_test(X, labels, [])
        else:
            self.backbone.load_state_dict(best_state_dict)
        return best_B

    def backbone_forward(self, X):
        """Return the backbone's softmax class probabilities (over dim 1)."""
        if self.backbone_type == 'MLP':
            B = self.backbone(X)
        else:  # 'GCN' (validated in __init__)
            B = self.backbone(X, self.A)
        return F.softmax(B, dim=1)

    def backbone_test(self, X, labels, index_list):
        """Evaluate the backbone in eval mode without gradients.

        Returns:
            (list of accuracies, one per entry of ``index_list``,
             class-probability matrix for all nodes)
        """
        self.backbone.eval()
        with torch.no_grad():
            B = self.backbone_forward(X)
            y_pred = torch.argmax(B, dim=1)
        acc_list = []
        for index in index_list:
            acc_list.append(ACC(labels[index].cpu(), y_pred[index].cpu()))
        return acc_list, B

    def init_H(self, B_, labels):
        """Initialise the learnable compatibility matrix ``self.H`` from priors.

        Training-node beliefs are replaced by their one-hot labels and
        aggregated with ``self.A``. The result is summed per training class
        into a class x class matrix, made (approximately) doubly stochastic,
        symmetrised and centred by subtracting ``1 / class_num``.
        """
        C_true = F.one_hot(labels, num_classes=self.class_num).float().to(self.device)
        B_ = B_.clone()  # don't mutate the caller's prior tensor
        B_[self.train_mask] = C_true[self.train_mask]
        C_train = torch.zeros_like(C_true)
        C_train[self.train_mask] = C_true[self.train_mask]
        H_ = torch.mm(C_train.t(), torch.sparse.mm(self.A, B_))
        H_ = self.makeDoubleStochasticH(H_)
        H_0 = (H_ + H_.t()) / 2 - 1 / self.class_num
        self.H = Parameter(H_0, requires_grad=True)

    def makeDoubleStochasticH(self, H, max_iterations=3000, delta=1e-7):
        """Sinkhorn-normalise ``H`` so its rows and columns each sum to 1.

        Alternately divides by column sums and row sums. It stops when one
        sweep changes the matrix by less than ``delta`` (matrix 1-norm), or
        after ``max_iterations`` sweeps, in which case it emits a warning.

        The input tensor is not modified. Iteration runs in float64 so the
        stopping criterion is reachable. The result is returned in ``H``'s
        original dtype.

        Assumes every row and column of ``H`` has a positive sum; otherwise
        the division produces inf/NaN.
        """
        out_dtype = H.dtype
        H = H.to(torch.float64)
        converged = False
        i = 0
        while not converged and i < max_iterations:
            prev_H = H
            # Out-of-place division is essential: with ``H /= ...`` prev_H
            # would alias H, the change below would always be 0 and the loop
            # would stop after a single sweep.
            H = H / torch.sum(H, dim=0, keepdim=True)
            H = H / torch.sum(H, dim=1, keepdim=True)
            change = torch.linalg.norm(H - prev_H, ord=1)
            converged = bool(change < delta)
            i += 1
        if not converged:
            warnings.warn(
                "makeDoubleStochasticH: maximum number of iterations reached.")
        return H.to(out_dtype)

    def reg_loss(self,):
        """Mean absolute row sum of ``H``; penalises rows drifting from zero sum."""
        return torch.mean(torch.abs(torch.sum(self.H, dim=1)))

    def forward(self, X):
        """Compute class probabilities with ``K`` rounds of propagation.

        ``B_0`` is the backbone prior centred by subtracting ``1/class_num``.
        Each round sets ``B = B_0 + A_prop @ B @ H``. Returns ``softmax(B)``
        over dim 1. Requires ``fit`` to have built ``A``/``A_prop`` and ``H``.
        """
        B_0 = self.backbone_forward(X) - 1 / self.class_num

        B = B_0
        for _ in range(self.K):
            B = B_0 + torch.mm(torch.sparse.mm(self.A_prop, B), self.H)
        return F.softmax(B, dim=1)

    def test(self, X, labels, index_list):
        """Return the model's accuracy on each index set in ``index_list``."""
        self.eval()
        with torch.no_grad():
            Z = self.forward(X)
            y_pred = torch.argmax(Z, dim=1)
        acc_list = []
        for index in index_list:
            acc_list.append(ACC(labels[index].cpu(), y_pred[index].cpu()))
        return acc_list

    def predict(self, graph):
        """Predict labels and class probabilities for the nodes of ``graph``.

        Only ``graph.ndata['feat']`` is read here. Propagation uses the
        adjacency matrices cached by the last ``fit`` call, so ``graph``
        should be the graph the model was fitted on.

        Returns:
            (predicted labels, class probabilities, None), all on the CPU.
        """
        self.eval()
        graph = graph.remove_self_loop().add_self_loop()
        graph = graph.to(self.device)
        X = graph.ndata['feat']
        with torch.no_grad():
            C = self.forward(X)
            y_pred = torch.argmax(C, dim=1)

        return y_pred.cpu(), C.cpu(), None



class MLP_(nn.Module):
    """Multi-layer perceptron backbone that produces per-node class logits.

    Stacks ``nlayers`` linear layers (in_dim -> hidden_dim -> ... -> out_dim),
    with ReLU and dropout between them. Dropout is also applied to the input.
    With ``nlayers == 1`` it is a single linear map in_dim -> out_dim.
    """

    def __init__(
            self,
            in_dim: int,
            out_dim: int,
            hidden_dim: int,
            nlayers = 2,
            dropout: float = 0.5,
        ) -> None:
        super().__init__()
        #------------- Parameters ----------------
        if nlayers < 1:
            # nlayers < 1 would otherwise build a hidden->out layer and apply it
            # straight to the in_dim input, failing only at forward time.
            raise ValueError(f"nlayers must be >= 1, got {nlayers}")
        self.nlayers = nlayers
        self.dropout = dropout
        dims = [in_dim] + [hidden_dim] * (nlayers - 1) + [out_dim]
        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:]))

    def forward(self, X):
        """Map features ``X`` (last dim ``in_dim``) to logits (last dim ``out_dim``)."""
        Z = F.dropout(X, self.dropout, training=self.training)
        for layer in self.layers[:-1]:
            Z = F.dropout(F.relu(layer(Z)), self.dropout, training=self.training)
        return self.layers[-1](Z)


class GCN_(nn.Module):
    """Graph-convolution backbone that produces per-node class logits.

    Each of the ``nlayers`` layers computes ``A @ Linear(Z)``. ReLU and
    dropout follow every layer except the last, and dropout is also applied
    to the input. Layer widths are in_dim -> hidden_dim -> ... -> out_dim.
    """

    def __init__(
            self,
            in_dim: int,
            out_dim: int,
            hidden_dim: int,
            nlayers = 2,
            dropout: float = 0.5,
        ) -> None:
        super().__init__()
        #------------- Parameters ----------------
        if nlayers < 1:
            # nlayers < 1 would otherwise build a hidden->out layer and apply it
            # straight to the in_dim input, failing only at forward time.
            raise ValueError(f"nlayers must be >= 1, got {nlayers}")
        self.nlayers = nlayers
        self.dropout = dropout
        dims = [in_dim] + [hidden_dim] * (nlayers - 1) + [out_dim]
        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:]))

    def forward(self, X, A):
        """Propagate features ``X`` through the layers using adjacency ``A``.

        ``A`` is applied with ``torch.sparse.mm``, so it is presumably a
        sparse (num_nodes x num_nodes) tensor.
        """
        Z = F.dropout(X, self.dropout, training=self.training)
        for layer in self.layers[:-1]:
            Z = torch.sparse.mm(A, layer(Z))
            Z = F.dropout(F.relu(Z), self.dropout, training=self.training)
        return torch.sparse.mm(A, self.layers[-1](Z))
