import argparse
import warnings
import seaborn as sns
import time
import os
import math

import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_self_loops, get_laplacian
from sklearn.metrics import roc_auc_score

from utils import cheby
from dataset_loader import HeterophilousGraphDataset
warnings.filterwarnings("ignore")


def presum_tensor(h, initial_val):
    """Return the running sums of ``h`` seeded with ``initial_val``.

    The result ``out`` has ``len(h) + 1`` entries, with
    ``out[0] = initial_val`` and ``out[i] = out[i - 1] + h[i - 1]``.

    Args:
        h: 1-D tensor of increments.
        initial_val: Scalar (Python number or 0-dim tensor) used as ``out[0]``.

    Returns:
        A 1-D tensor of length ``len(h) + 1`` that lives on the same device
        and has the same dtype as ``h``. Gradients flow back to ``h``.
    """
    # One differentiable cumsum replaces the element-by-element fill of a
    # freshly allocated CPU buffer, so the result follows h's device/dtype.
    start = torch.as_tensor(initial_val, dtype=h.dtype, device=h.device).reshape(1)
    return torch.cumsum(torch.cat((start, h)), dim=0)

def preminus_tensor(h, initial_val):
    """Return the running differences of ``h`` seeded with ``initial_val``.

    The result ``out`` has ``len(h) + 1`` entries, with
    ``out[0] = initial_val`` and ``out[i] = out[i - 1] - h[i - 1]``.

    Args:
        h: 1-D tensor of decrements.
        initial_val: Scalar (Python number or 0-dim tensor) used as ``out[0]``.

    Returns:
        A 1-D tensor of length ``len(h) + 1`` that lives on the same device
        and has the same dtype as ``h``. Gradients flow back to ``h``.
    """
    # Subtracting successive entries equals a cumulative sum over the negated
    # increments; this keeps the result on h's device/dtype and avoids a
    # Python-level loop of in-place writes.
    start = torch.as_tensor(initial_val, dtype=h.dtype, device=h.device).reshape(1)
    return torch.cumsum(torch.cat((start, -h)), dim=0)

def reverse_tensor(h):
    """Return a copy of ``h`` with its entries reversed along dimension 0.

    Args:
        h: Tensor with at least one dimension.

    Returns:
        A new tensor with the same shape, dtype and device as ``h`` whose
        ``i``-th entry along dimension 0 is ``h[len(h) - 1 - i]``.
    """
    # torch.flip always returns a copy, matching the original copy semantics.
    return torch.flip(h, dims=[0])


class ChebnetII_prop(MessagePassing):
    """Order-``K`` Chebyshev propagation with a learnable low- or high-pass filter.

    Two sets of ``K`` learnable increments are kept. After a ReLU they are
    turned into ``K + 1`` filter values:

    * high-pass: running sums starting from ``initial_val_high`` (0.0), so
      the values are non-decreasing;
    * low-pass: running differences starting from ``initial_val_low`` (2.0),
      so the values are non-increasing.

    These values are converted to Chebyshev expansion coefficients and applied
    to the input through the Chebyshev three-term recurrence.

    Args:
        K: Polynomial order. ``forward`` reads coefficients 0 and 1
            unconditionally, so ``K`` is expected to be at least 1.
        **kwargs: Forwarded to ``MessagePassing``.
    """

    def __init__(self, K, **kwargs):
        super(ChebnetII_prop, self).__init__(aggr='add', **kwargs)
        self.K = K
        # Fixed (non-trainable) starting values of the two filters.
        self.initial_val_low = Parameter(torch.tensor(2.0), requires_grad=False)
        # Storage is uninitialised here; reset_parameters() fills it below.
        self.temp_low = Parameter(torch.Tensor(self.K), requires_grad=True)
        self.temp_high = Parameter(torch.Tensor(self.K), requires_grad=True)
        self.initial_val_high = Parameter(torch.tensor(0.0), requires_grad=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Set every learnable increment to ``2 / K``."""
        self.temp_low.data.fill_(2.0/self.K)
        self.temp_high.data.fill_(2.0/self.K)

    def forward(self, x, edge_index, edge_weight=None, highpass=True):
        """Filter ``x`` over the graph given by ``edge_index``.

        Args:
            x: Node feature tensor; the node count is read from
                ``x.size(self.node_dim)``.
            edge_index: Graph connectivity passed to ``get_laplacian``.
            edge_weight: Optional edge weights passed to ``get_laplacian``.
            highpass: Use the high-pass increments when true, the low-pass
                increments otherwise.

        Returns:
            The filtered features, ``sum_i coe[i] * T_i`` with the zeroth
            coefficient halved.
        """
        if highpass:
            TEMP = F.relu(self.temp_high)
            coe_tmp = presum_tensor(TEMP, self.initial_val_high)
        else:
            TEMP = F.relu(self.temp_low)
            coe_tmp = preminus_tensor(TEMP, self.initial_val_low)

        # Interpolation nodes x_j = cos((K - j + 0.5) * pi / (K + 1)). They
        # depend only on K, so compute them once instead of per coefficient.
        nodes = [math.cos((self.K - j + 0.5) * math.pi / (self.K + 1)) for j in range(self.K + 1)]

        # coe[i] = 2 / (K + 1) * sum_j coe_tmp[j] * cheby(i, x_j)
        coe = coe_tmp.clone()
        for i in range(self.K + 1):
            acc = coe_tmp[0] * cheby(i, nodes[0])
            for j in range(1, self.K + 1):
                acc = acc + coe_tmp[j] * cheby(i, nodes[j])
            coe[i] = 2 * acc / (self.K + 1)

        # Symmetric-normalised Laplacian entries, then self-loops of weight
        # -1.0 added on top (presumably yielding the shifted operator L - I).
        edge_index1, norm1 = get_laplacian(edge_index, edge_weight, normalization='sym', dtype=x.dtype, num_nodes=x.size(self.node_dim))
        edge_index_tilde, norm_tilde = add_self_loops(edge_index1, norm1, fill_value=-1.0, num_nodes=x.size(self.node_dim))

        # Chebyshev recurrence: T_0 = x, T_1 = L~ x, T_k = 2 L~ T_{k-1} - T_{k-2}.
        Tx_0 = x
        Tx_1 = self.propagate(edge_index_tilde, x=x, norm=norm_tilde, size=None)
        out = coe[0] / 2 * Tx_0 + coe[1] * Tx_1
        for i in range(2, self.K + 1):
            Tx_2 = self.propagate(edge_index_tilde, x=Tx_1, norm=norm_tilde, size=None)
            Tx_2 = 2 * Tx_2 - Tx_0
            out = out + coe[i] * Tx_2
            Tx_0, Tx_1 = Tx_1, Tx_2
        return out

    def message(self, x_j, norm):
        """Scale each neighbour feature row by its edge weight."""
        return norm.view(-1, 1) * x_j

    def __repr__(self):
        # The layer has no ``temp`` attribute; report both increment vectors.
        return '{}(K={}, temp_low={}, temp_high={})'.format(
            self.__class__.__name__, self.K,
            self.temp_low.detach().tolist(), self.temp_high.detach().tolist())
    

class LogReg(nn.Module):
    """Linear classifier head: one fully connected layer producing logits."""

    def __init__(self, hid_dim, n_classes):
        """Create the layer mapping ``hid_dim`` inputs to ``n_classes`` outputs."""
        super().__init__()
        self.fc = nn.Linear(in_features=hid_dim, out_features=n_classes)

    def forward(self, x):
        """Return the raw (unnormalised) logits of the linear layer for ``x``."""
        return self.fc(x)
    

class Discriminator(nn.Module):
    """Bilinear scorer comparing node representations with a summary vector."""

    def __init__(self, dim):
        """Create a bilinear form over two ``dim``-sized inputs with one output."""
        super().__init__()
        self.fn = nn.Bilinear(in1_features=dim, in2_features=dim, out_features=1)

    def forward(self, h1, h2, h3, h4, c):
        """Score four representation matrices against the summary ``c``.

        ``c`` is expanded to the shape of ``h1`` and every matrix is scored
        against it. The scores are concatenated in the order ``h2``, ``h1``
        (positive pair) followed by ``h4``, ``h3`` (negative pair).
        """
        summary = c.expand_as(h1).contiguous()
        # Order matters: it has to line up with the caller's target vector.
        scores = [self.fn(view, summary).squeeze(1) for view in (h2, h1, h4, h3)]
        return torch.cat(scores)


class Model(nn.Module):
    """Contrastive model combining a high-pass and a low-pass encoder view.

    A single ``ChebNetII`` encoder is evaluated twice per input (once with
    ``highpass=True`` and once with ``highpass=False``). The two views are
    mixed with the learnable scalars ``alpha`` and ``beta``, and a
    ``Discriminator`` scores the views against a graph-level summary.

    Args:
        in_dim: Input feature dimension; also used to size the discriminator.
        out_dim: Forwarded to ``ChebNetII`` as ``hidden``.
        K, dprate, dropout, is_bns, act_fn: Forwarded to ``ChebNetII``.
    """

    def __init__(self, in_dim, out_dim, K, dprate, dropout, is_bns, act_fn):
        super(Model, self).__init__()
        self.encoder = ChebNetII(num_features=in_dim, hidden=out_dim, K=K, dprate=dprate, dropout=dropout, is_bns=is_bns, act_fn=act_fn)
        self.disc = Discriminator(in_dim)
        self.act_fn = nn.ReLU()
        self.alpha = nn.Parameter(torch.tensor(0.5), requires_grad=True)
        self.beta = nn.Parameter(torch.tensor(0.5), requires_grad=True)

    def _combine(self, h_high, h_low):
        """Mix the high-pass and low-pass views as ``alpha * high + beta * low``."""
        return torch.mul(self.alpha, h_high) + torch.mul(self.beta, h_low)

    def get_embedding(self, edge_index, feat):
        """Return the mixed node embeddings, detached from the autograd graph.

        The computation runs under ``torch.no_grad()`` so no intermediate
        activations are kept; the values are the same as computing with
        gradients enabled and detaching afterwards.
        """
        with torch.no_grad():
            h1 = self.encoder(x=feat, edge_index=edge_index, highpass=True)
            h2 = self.encoder(x=feat, edge_index=edge_index, highpass=False)
            h = self._combine(h1, h2)
        return h.detach()

    def forward(self, edge_index, feat, shuf_feat):
        """Return discriminator logits for the original and shuffled features.

        Args:
            edge_index: Graph connectivity passed to the encoder.
            feat: Node features (positive samples).
            shuf_feat: Corrupted node features (negative samples).

        Returns:
            The concatenated logits produced by the discriminator: the two
            positive views first, then the two negative views.
        """
        # positive
        h1 = self.encoder(x=feat, edge_index=edge_index, highpass=True)
        h2 = self.encoder(x=feat, edge_index=edge_index, highpass=False)

        # negative
        h3 = self.encoder(x=shuf_feat, edge_index=edge_index, highpass=True)
        h4 = self.encoder(x=shuf_feat, edge_index=edge_index, highpass=False)

        # Summary vector: activation of the node-wise mean of the mixed
        # positive embedding.
        h = self._combine(h1, h2)
        c = self.act_fn(torch.mean(h, dim=0))
        out = self.disc(h1, h2, h3, h4, c)
        return out


class ChebNetII(torch.nn.Module):
    """Encoder: dropout, Chebyshev propagation, dropout, optional batch norm, activation.

    ``hidden`` is accepted for interface compatibility but is not used by
    this class; the propagation layer keeps the input feature width.
    """

    def __init__(self, num_features, hidden=512, K=10, dprate=0.50, dropout=0.50, is_bns=False, act_fn='relu'):
        super().__init__()
        self.prop1 = ChebnetII_prop(K=K)
        assert act_fn in ['relu', 'prelu']
        if act_fn == 'prelu':
            self.act_fn = nn.PReLU()
        else:
            self.act_fn = nn.ReLU()
        self.bn = torch.nn.BatchNorm1d(num_features, momentum=0.01)
        self.is_bns = is_bns
        self.dprate = dprate
        self.dropout = dropout
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the propagation layer's parameters."""
        self.prop1.reset_parameters()

    def forward(self, x, edge_index, highpass=True):
        """Encode ``x`` with the selected (high- or low-pass) filter."""
        # Input dropout is skipped entirely when its rate is exactly zero.
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)
        x = self.prop1(x, edge_index, highpass=highpass)

        x = F.dropout(x, p=self.dropout, training=self.training)
        if self.is_bns:
            x = self.bn(x)
        return self.act_fn(x)
    

def _str2bool(value):
    """Convert a command-line string such as 'True'/'false'/'1'/'0' to a bool.

    ``type=bool`` cannot be used with argparse for this: ``bool('False')`` is
    ``True`` because every non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line options for unsupervised pre-training and linear evaluation.
parser = argparse.ArgumentParser(description="PolyGCL")
parser.add_argument('--seed', type=int, default=42, help='Random seed.')  # Default seed same as GCNII
parser.add_argument('--device', type=int, default=0, help='device id')
parser.add_argument("--dataset", type=str, default="roman_empire", help="Name of dataset.")

# Optimisation settings.
parser.add_argument("--unsup_epochs", type=int, default=500, help="Training epochs.")
parser.add_argument("--patience", type=int, default=20, help="Patient epochs to wait before early stopping.")
parser.add_argument("--prop_lr", type=float, default=0.010, help="Learning rate of prop.")
parser.add_argument("--lr1", type=float, default=0.001, help="Learning rate of PolyGCL.")
parser.add_argument("--lr2", type=float, default=0.01, help="Learning rate of linear evaluator.")
parser.add_argument("--prop_wd", type=float, default=0.0, help="Weight decay of PolyGCL prop.")
parser.add_argument("--wd1", type=float, default=0.0, help="Weight decay of PolyGCL.")
parser.add_argument("--wd2", type=float, default=0.0, help="Weight decay of linear evaluator.")
parser.add_argument("--hidden", type=int, default=512, help="Hidden layer dim.")

# Encoder settings.
parser.add_argument("--K", type=int, default=10, help="Layer of encoder.")
parser.add_argument('--dropout', type=float, default=0.5, help='dropout for neural networks.')
parser.add_argument('--dprate', type=float, default=0.5, help='dropout for propagation layer.')
# Parsed explicitly so that "--is_bns False" really disables batch norm.
parser.add_argument('--is_bns', type=_str2bool, default=False, help='apply batch normalization in the encoder (True/False).')
# The encoder only accepts these two activations; reject others up front.
parser.add_argument('--act_fn', default='relu', choices=['relu', 'prelu'], help='activation function')
args = parser.parse_args()


# Use the requested CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{args.device}')
else:
    device = torch.device('cpu')

# Seed the Python, NumPy and PyTorch (CPU and CUDA) random number generators
# with the same value.
for _seed_fn in (random.seed,
                 np.random.seed,
                 torch.manual_seed,
                 torch.cuda.manual_seed,
                 torch.cuda.manual_seed_all):
    _seed_fn(args.seed)



if __name__ == "__main__":
    print(args)
    # Step 1: Load data =================================================================== #
    root = './data/'
    dataset = HeterophilousGraphDataset(root=root, name=args.dataset)
    data = dataset[0]

    feat = data.x
    label = data.y
    edge_index = data.edge_index

    n_feat = feat.shape[1]
    n_classes = np.unique(label).shape[0]

    # These two datasets are evaluated as multi-class problems (cross-entropy,
    # accuracy); every other dataset is evaluated as binary (BCE, ROC-AUC).
    is_multiclass = args.dataset in ['roman_empire', 'amazon_ratings']

    edge_index = edge_index.to(device)
    feat = feat.to(device)

    # Targets for the discriminator: the first 2 * n_node logits belong to the
    # positive views, the last 2 * n_node to the negative (shuffled) views.
    n_node = feat.shape[0]
    lbl1 = torch.ones(n_node * 2)
    lbl2 = torch.zeros(n_node * 2)
    lbl = torch.cat((lbl1, lbl2))

    # Step 2: Create model =================================================================== #
    model = Model(in_dim=n_feat, out_dim=args.hidden, K=args.K, dprate=args.dprate, dropout=args.dropout, is_bns=args.is_bns, act_fn=args.act_fn)
    model = model.to(device)

    lbl = lbl.to(device)

    # Step 3: Create training components ===================================================== #
    # Only the discriminator, the propagation layer and the two mixing scalars
    # are handed to the optimizer.
    optimizer = torch.optim.Adam([{'params': model.disc.parameters(), 'weight_decay': args.wd1, 'lr': args.lr1},
                                  {'params': model.encoder.prop1.parameters(), 'weight_decay': args.prop_wd, 'lr': args.prop_lr},
                                  {'params': model.alpha, 'weight_decay': args.prop_wd, 'lr': args.prop_lr},
                                  {'params': model.beta, 'weight_decay': args.prop_wd, 'lr': args.prop_lr}
                                  ])

    loss_fn = nn.BCEWithLogitsLoss()

    # Step 4: Training epochs ================================================================ #
    best = float("inf")
    cnt_wait = 0

    # Timestamp-based tag that keeps this run's checkpoint file name distinct.
    tag = str(int(time.time()))
    ckpt_dir = 'unsup_pkl'
    # torch.save does not create missing directories.
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt_path = os.path.join(ckpt_dir, 'polygcl_best_model_' + args.dataset + tag + '.pkl')

    for epoch in range(args.unsup_epochs):
        model.train()
        optimizer.zero_grad()

        # Negative samples: the same features with the node order permuted.
        shuf_idx = np.random.permutation(n_node)
        shuf_feat = feat[shuf_idx, :]

        out = model(edge_index, feat, shuf_feat)
        loss = loss_fn(out, lbl)

        loss.backward()
        optimizer.step()

        # Track the best loss as a plain float rather than holding on to the
        # loss tensor.
        loss_value = loss.item()
        if loss_value < best:
            best = loss_value
            cnt_wait = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            cnt_wait += 1

        # Early stopping on the training loss.
        if cnt_wait == args.patience:
            break

    # Restore the best checkpoint, extract the embeddings, then delete the file.
    model.load_state_dict(torch.load(ckpt_path))
    model.eval()
    embeds = model.get_embedding(edge_index, feat)
    os.remove(ckpt_path)

    # Linear evaluation: fit a linear classifier on the frozen embeddings for
    # each of the 10 provided splits.
    results = []

    label = label if is_multiclass else label.to(torch.float)
    label = label.to(device)

    # Binary datasets use a single output logit.
    n_out = n_classes if is_multiclass else 1
    eval_loss_fn = nn.CrossEntropyLoss() if is_multiclass else nn.BCEWithLogitsLoss()

    for i in range(10):
        assert label.shape[0] == n_node

        train_mask, val_mask, test_mask = data.train_mask[:, i].to(device), data.val_mask[:, i].to(device), data.test_mask[:, i].to(device)

        assert torch.sum(train_mask + val_mask + test_mask) == n_node

        train_embs = embeds[train_mask]
        val_embs = embeds[val_mask]
        test_embs = embeds[test_mask]

        train_labels = label[train_mask]
        val_labels = label[val_mask]
        test_labels = label[test_mask]

        best_val_acc = 0
        eval_acc = 0

        logreg = LogReg(hid_dim=train_embs.shape[1], n_classes=n_out)
        opt = torch.optim.Adam(logreg.parameters(), lr=args.lr2, weight_decay=args.wd2)
        logreg = logreg.to(device)

        for epoch in range(2000):
            logreg.train()
            opt.zero_grad()
            logits = logreg(train_embs)
            logits = logits if is_multiclass else logits.squeeze(-1)

            loss = eval_loss_fn(logits, train_labels)
            loss.backward()
            opt.step()

            logreg.eval()
            with torch.no_grad():
                val_logits = logreg(val_embs)
                test_logits = logreg(test_embs)

                if is_multiclass:
                    val_preds = torch.argmax(val_logits, dim=1)
                    test_preds = torch.argmax(test_logits, dim=1)
                    val_acc = torch.sum(val_preds == val_labels).float() / val_labels.shape[0]
                    test_acc = torch.sum(test_preds == test_labels).float() / test_labels.shape[0]
                else:
                    val_acc = roc_auc_score(y_true=val_labels.cpu().numpy(), y_score=val_logits.squeeze(-1).cpu().numpy())
                    test_acc = roc_auc_score(y_true=test_labels.cpu().numpy(), y_score=test_logits.squeeze(-1).cpu().numpy())

                # Keep the highest test score seen at any epoch where the
                # validation score matched or improved on its best so far.
                if val_acc >= best_val_acc:
                    best_val_acc = val_acc
                    if test_acc > eval_acc:
                        eval_acc = test_acc

        # float() handles 0-dim tensors, NumPy scalars and plain Python
        # numbers alike, unlike .item().
        results.append(float(eval_acc))

    test_acc_mean = np.mean(results, axis=0) * 100
    values = np.asarray(results, dtype=object)
    # Half-width of the bootstrapped 95% interval around the mean.
    uncertainty = np.max(
        np.abs(sns.utils.ci(sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000), 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')