import argparse
import warnings
import seaborn as sns

import torch
from alive_progress import alive_bar
import random
import numpy as np
import torch as th
import torch.nn as nn
from sklearn.metrics import roc_auc_score
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.utils import to_dense_adj
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

from unsup_model import LogReg
from dataset_loader import HeterophilousGraphDataset
from utils import set_seed, random_splits
warnings.filterwarnings("ignore")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='seed.')
    parser.add_argument('--dataset', type=str,default='Cora')
    parser.add_argument('--device', type=int, default=0, help='GPU device.')
    parser.add_argument('--runs', type=int, default=10, help='number of runs.')
    parser.add_argument('--K', type=int, default=10, help='propagation steps.')

    parser.add_argument("--lr2", type=float, default=0.01, help="Learning rate of linear evaluator.")
    parser.add_argument("--wd2", type=float, default=0.0, help="Weight decay of linear evaluator.")

    parser.add_argument('--random_split', action='store_true')
    parser.add_argument('--eval', default='linear', type=str)
    parser.add_argument('--knn_k', type=int, default=10)
    parser.add_argument('--knn_temp', type=float, default=5)
    parser.add_argument('--train_rate', type=float, default=0.6)
    parser.add_argument('--val_rate', type=float, default=0.2)
    args = parser.parse_args()
    return args


def unsupervised_test_linear(data, embeds, n_classes, device, args):
    results = []

    label = data.y if args.dataset in ['roman_empire', 'amazon_ratings'] else data.y.to(torch.float)
    label = label.to(args.device)
    assert label.shape[0] == n_node
    
    for i in range(args.runs):
        train_mask, val_mask, test_mask = data.train_mask[:, i].to(args.device), data.val_mask[:, i].to(args.device), data.test_mask[:, i].to(args.device)
        assert torch.sum(train_mask + val_mask + test_mask) == n_node

        train_embs = embeds[train_mask]
        val_embs = embeds[val_mask]
        test_embs = embeds[test_mask]

        train_labels = label[train_mask]
        val_labels = label[val_mask]
        test_labels = label[test_mask]

        best_val_acc = 0
        eval_acc = 0
        bad_counter = 0

        logreg = LogReg(hid_dim=train_embs.size(1), n_classes=n_classes)
        opt = th.optim.Adam(logreg.parameters(), lr=args.lr2, weight_decay=args.wd2)
        logreg = logreg.to(args.device)

        loss_fn = nn.CrossEntropyLoss() if args.dataset in ['roman_empire', 'amazon_ratings'] else nn.BCEWithLogitsLoss()

        for epoch in range(2000):
            logreg.train()
            opt.zero_grad()
            logits = logreg(train_embs)
            logits = logits if args.dataset in ['roman_empire', 'amazon_ratings'] else logits.squeeze(-1)

            loss = loss_fn(logits, train_labels)
            loss.backward()
            opt.step()

            logreg.eval()
            with th.no_grad():
                val_logits = logreg(val_embs)
                test_logits = logreg(test_embs)

                if args.dataset in ['roman_empire', 'amazon_ratings']:
                    val_preds = th.argmax(val_logits, dim=1)
                    test_preds = th.argmax(test_logits, dim=1)
                    val_acc = th.sum(val_preds == val_labels).float() / val_labels.shape[0]
                    test_acc = th.sum(test_preds == test_labels).float() / test_labels.shape[0]
                else:
                    val_acc = roc_auc_score(y_true=val_labels.cpu().numpy(), y_score=val_logits.squeeze(-1).cpu().numpy())
                    test_acc = roc_auc_score(y_true=test_labels.cpu().numpy(), y_score=test_logits.squeeze(-1).cpu().numpy())

                if val_acc >= best_val_acc:
                    bad_counter = 0
                    best_val_acc = val_acc
                    if test_acc > eval_acc:
                        eval_acc = test_acc
                else:
                    bad_counter += 1

        if torch.is_tensor(eval_acc):
            results.append(eval_acc.cpu().data.item())
        else:
            results.append(eval_acc)
    return results


@torch.no_grad()
def unsupervised_test_knn(data, embeds, n_classes, device, args):
    results = []

    label = data.y if args.dataset in ['roman_empire', 'amazon_ratings'] else data.y.to(torch.float)
    label = label.to(args.device)
    assert label.shape[0] == n_node
    
    for i in range(args.runs):
        train_mask, val_mask, test_mask = data.train_mask[:, i].to(args.device), data.val_mask[:, i].to(args.device), data.test_mask[:, i].to(args.device)
        assert torch.sum(train_mask + val_mask + test_mask) == data.num_nodes

        train_embs = embeds[train_mask].detach().cpu().numpy()
        val_embs = embeds[val_mask].detach().cpu().numpy()
        test_embs = embeds[test_mask].detach().cpu().numpy()

        train_labels = label[train_mask].detach().cpu().numpy()
        val_labels = label[val_mask].detach().cpu().numpy()
        test_labels = label[test_mask].detach().cpu().numpy()

        neigh = KNeighborsClassifier(n_neighbors=args.knn_k)
        neigh.fit(train_embs, train_labels)
        preds = neigh.predict(test_embs)
        probs = neigh.predict_proba(test_embs)
        probs = np.max(probs, axis=1, keepdims=True)

        if args.dataset in ['roman_empire', 'amazon_ratings']:
            acc = accuracy_score(test_labels, preds)
        else:
            acc = roc_auc_score(y_true=test_labels, y_score=probs)

        print(i, 'KNN evaluation accuracy:{:.4f}'.format(acc))
        if torch.is_tensor(acc):
            results.append(acc.cpu().data.item())
        else:
            results.append(acc)
    return results


if __name__ == "__main__":
    args = parse_args()
    print(args)
    print("---------------------------------------------")
    
    set_seed(args.seed)
    #10 fixed seeds for random splits from BernNet
    SEEDS=[1941488137,4198936517,983997847,4023022221,4019585660,2108550661,1648766618,629014539,3212139042,2424918363]
    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')

    root = './data/'
    dataset = HeterophilousGraphDataset(root=root,name=args.dataset)
    data = dataset[0]
    data = data.to(device)

    n_classes = dataset.num_classes if args.dataset in ['roman_empire', 'amazon_ratings'] else 1
    n_node = data.x.size(0)
    n_feat = data.x.size(1)

    edge_index, _ = gcn_norm(edge_index=data.edge_index)
    A = to_dense_adj(edge_index=edge_index).squeeze(0)
    embeds = data.x
    for _ in range(args.K):
        embeds = A @ embeds

    results = unsupervised_test_linear(data=data, embeds=embeds, n_classes=n_classes, device=device, args=args)

    test_acc_mean = np.mean(results, axis=0) * 100
    values = np.asarray(results, dtype=object)
    uncertainty = np.max(
        np.abs(sns.utils.ci(sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000), 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')