import warnings
import time
import numpy as np
import argparse
import seaborn as sns
from sklearn.metrics import roc_auc_score

import torch
import torch.nn as nn

from models import *
from dataset_loader import HeterophilousGraphDataset
from utils import set_seed
from unsup_model import DGI, LogReg
warnings.filterwarnings("ignore")


def print_params(model, string):
    """Dump every named parameter of ``model.encoder`` to stdout.

    A header line containing ``string`` is printed first, then for each entry
    yielded by ``model.encoder.named_parameters()`` the name and the parameter
    are printed on separate lines, and finally a closing rule.

    Args:
        model: object exposing an ``encoder`` attribute that provides
            ``named_parameters()``.
        string: label embedded in the header line.
    """
    print(f'----------- {string} ----------')
    for param_name, tensor in model.encoder.named_parameters():
        # One call, newline-separated: same output as two consecutive prints.
        print(param_name, tensor, sep='\n')
    print('-----------------------------------')


def get_encoder(dataset, args):
    """Instantiate the encoder network selected by ``args.net``.

    Args:
        dataset: dataset object forwarded unchanged to the encoder constructor.
        args: namespace whose ``net`` attribute names the encoder; it is also
            forwarded unchanged to the encoder constructor.

    Returns:
        The constructed encoder, built as ``EncoderClass(dataset=dataset, args=args)``.

    Raises:
        NotImplementedError: if ``args.net`` is not one of the supported names.
    """
    net = args.net

    # The class names are resolved lazily (only in the matching branch) so a
    # model class missing from the star import only breaks its own option.
    if net == 'ChebNetII':
        encoder_cls = ChebNetII
    elif net == 'GCN':
        encoder_cls = GCN_Net
    elif net == 'BernNet':
        encoder_cls = BernNet
    elif net == 'GPRGNN':
        encoder_cls = GPRGNN
    elif net == 'PropChebNetII':
        encoder_cls = PropChebNetII
    elif net == 'PropBernNet':
        encoder_cls = PropBernNet
    elif net == 'PropGPRGNN':
        encoder_cls = PropGPRGNN
    else:
        supported = (
            'ChebNetII', 'GCN', 'BernNet', 'GPRGNN',
            'PropChebNetII', 'PropBernNet', 'PropGPRGNN',
        )
        raise NotImplementedError(
            f"Unknown encoder '{net}'; expected one of: {', '.join(supported)}"
        )

    return encoder_cls(dataset=dataset, args=args)


def unsupervised_test_linear(embeds, label, train_mask, val_mask, test_mask, n_classes, epochs=2000):
    """Train a linear probe (``LogReg``) on fixed embeddings and score it.

    For the datasets ``roman_empire`` and ``amazon_ratings`` the probe has
    ``n_classes`` outputs, is trained with cross-entropy and scored by
    accuracy. For every other dataset it has a single output, is trained with
    ``BCEWithLogitsLoss`` and scored by ROC AUC.

    NOTE: this function reads the module-level ``args`` (``dataset``, ``lr2``,
    ``wd2``); it must be set before the call.

    Args:
        embeds: node embeddings, indexed by the three masks.
        label: node labels, indexed by the three masks.
        train_mask: selects the training nodes.
        val_mask: selects the validation nodes.
        test_mask: selects the test nodes.
        n_classes: number of classes; only used for the multi-class datasets.
        epochs: number of full-batch training steps (default 2000).

    Returns:
        The largest test metric seen at any epoch whose validation metric was
        at least as good as the best validation metric so far (a Python number).
    """
    multiclass = args.dataset in ['roman_empire', 'amazon_ratings']

    train_embs, val_embs, test_embs = embeds[train_mask], embeds[val_mask], embeds[test_mask]
    train_labels, val_labels, test_labels = label[train_mask], label[val_mask], label[test_mask]

    out_dim = n_classes if multiclass else 1
    # Put the probe wherever the embeddings live, rather than on the raw
    # integer ``args.device``, which is a GPU index and fails on CPU-only hosts.
    logreg = LogReg(hid_dim=train_embs.size(1), n_classes=out_dim).to(train_embs.device)
    opt = torch.optim.Adam(logreg.parameters(), lr=args.lr2, weight_decay=args.wd2)
    loss_fn = nn.CrossEntropyLoss() if multiclass else nn.BCEWithLogitsLoss()

    if not multiclass:
        # The AUC targets never change; convert them once instead of per epoch.
        val_labels_np = val_labels.cpu().numpy()
        test_labels_np = test_labels.cpu().numpy()

    best_val_acc = 0
    eval_acc = 0

    for _ in range(epochs):
        logreg.train()
        opt.zero_grad()
        logits = logreg(train_embs)
        if not multiclass:
            logits = logits.squeeze(-1)
        loss = loss_fn(logits, train_labels)
        loss.backward()
        opt.step()

        logreg.eval()
        with torch.no_grad():
            val_logits = logreg(val_embs)
            test_logits = logreg(test_embs)

            if multiclass:
                val_preds = torch.argmax(val_logits, dim=1)
                test_preds = torch.argmax(test_logits, dim=1)
                val_acc = (torch.sum(val_preds == val_labels).float() / val_labels.shape[0]).item()
                test_acc = (torch.sum(test_preds == test_labels).float() / test_labels.shape[0]).item()
            else:
                val_acc = roc_auc_score(y_true=val_labels_np, y_score=val_logits.squeeze(-1).cpu().numpy())
                test_acc = roc_auc_score(y_true=test_labels_np, y_score=test_logits.squeeze(-1).cpu().numpy())

        # NOTE(review): this keeps the *maximum* test metric over all epochs
        # where validation did not get worse, which is optimistic compared to
        # reporting the test metric at the single best-validation epoch.
        # Behaviour is preserved here; confirm it is the intended protocol.
        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            if test_acc > eval_acc:
                eval_acc = test_acc

    return eval_acc


def unsupervised_learning(dataset, data, args, device):
    """Train a DGI model around the selected encoder and return node embeddings.

    The encoder comes from ``get_encoder``. Each step feeds the model the
    features and a row-shuffled copy of them, and optimises ``model.loss_fn``
    against a target vector of ``n_node`` ones followed by ``n_node`` zeros.
    Training stops after ``args.unsup_epochs`` steps or once the training loss
    has failed to improve for ``args.patience`` consecutive steps. The weights
    with the lowest training loss are checkpointed under ``unsup_pkl/`` and
    reloaded before the embeddings are computed.

    Args:
        dataset: dataset object (provides ``num_node_features``; forwarded to
            the encoder constructor).
        data: graph object providing ``x``, ``edge_index`` and
            ``num_node_features``.
        args: namespace with ``net``, ``hidden``, ``residual``, ``wd1``,
            ``lr1``, ``unsup_epochs``, ``patience`` and ``dataset``.
        device: device the model and target vector are moved to.

    Returns:
        The result of ``model.get_embedding(data.edge_index, data.x)``.

    Raises:
        RuntimeError: if no checkpoint was ever written (no training step ran,
            or the loss was never below infinity, e.g. NaN).
    """
    import os  # local import: only needed for checkpoint path handling

    encoder = get_encoder(dataset=dataset, args=args)

    n_node = data.x.shape[0]
    lbl = torch.cat((torch.ones(n_node), torch.zeros(n_node))).to(device)

    if args.net in ['PropChebNetII', 'PropBernNet', 'PropGPRGNN']:
        out_dim = dataset.num_node_features * 2 if args.residual else dataset.num_node_features
        if args.hidden < out_dim:
            # Train on a random subset of ``args.hidden`` feature columns.
            rand_idx = torch.randperm(n=data.num_node_features)
            feat = data.x[:, rand_idx[:args.hidden]]
            out_dim = args.hidden * 2 if args.residual else args.hidden
        else:
            feat = data.x
    else:
        out_dim = args.hidden
        feat = data.x

    model = DGI(encoder=encoder, out_dim=out_dim).to(device)
    optimizer = torch.optim.Adam([{'params': model.parameters(), 'weight_decay': args.wd1, 'lr': args.lr1}])

    # NOTE(review): the tag has one-second resolution, so two runs with the
    # same net/dataset started in the same second share a checkpoint file.
    unsup_tag = str(int(time.time()))
    ckpt_dir = 'unsup_pkl'
    os.makedirs(ckpt_dir, exist_ok=True)  # torch.save does not create directories
    ckpt_path = os.path.join(
        ckpt_dir, 'dgi_cl_' + args.net + '_best_model_' + args.dataset + unsup_tag + '.pkl'
    )

    best = float("inf")
    cnt_wait = 0
    saved = False

    for _ in range(args.unsup_epochs):
        model.train()
        optimizer.zero_grad()

        shuf_idx = np.random.permutation(n_node)
        out = model(edge_index=data.edge_index, feat=feat, shuf_feat=feat[shuf_idx, :])
        loss = model.loss_fn(out, lbl)

        loss.backward()
        optimizer.step()

        # Track the best loss as a plain float so no tensor is kept alive
        # between iterations.
        loss_val = loss.item()
        if loss_val < best:
            best = loss_val
            cnt_wait = 0
            torch.save(model.state_dict(), ckpt_path)
            saved = True
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            break

    if not saved:
        raise RuntimeError(
            'No DGI checkpoint was written: either unsup_epochs is 0 or the '
            'training loss was never finite.'
        )

    model.load_state_dict(torch.load(ckpt_path))
    model.eval()
    # NOTE(review): embeddings are computed from the full ``data.x`` even when
    # training used the column-subsampled ``feat`` - confirm this is intended.
    with torch.no_grad():
        embeds = model.get_embedding(data.edge_index, data.x)
    return embeds


def parse_args(argv=None):
    """Build the command-line parser for this script and parse arguments.

    Args:
        argv: optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # general
    add('--seed', type=int, default=42, help='seed.')
    add('--dataset', type=str, default='Cora')
    add('--device', type=int, default=0, help='GPU device.')
    add('--runs', type=int, default=10, help='number of runs.')
    add('--net', type=str, default='ChebNetII')

    # encoder
    add('--num_layers', type=int, default=2)
    add('--hidden', type=int, default=64, help='hidden units.')
    add('--dropout', type=float, default=0.5, help='dropout for neural networks.')
    add('--K', type=int, default=10, help='propagation steps.')
    add('--alpha', type=float, default=0.1, help='alpha for APPNP/GPRGNN.')
    add('--dprate', type=float, default=0.0, help='dropout for propagation layer.')
    add('--q', type=int, default=0, help='The constant for ChebBase.')
    add('--Init', type=str, choices=['SGC', 'PPR', 'NPPR', 'Random', 'WS', 'Null'],
        default='PPR', help='initialization for GPRGNN.')

    add('--train_rate', type=float, default=0.6, help='train set rate.')
    add('--val_rate', type=float, default=0.2, help='val set rate.')

    # unsupervised learning
    add("--patience", type=int, default=20, help="Patient epochs to wait before early stopping.")
    add("--unsup_epochs", type=int, default=500, help="Unsupervised training epochs.")
    add("--lr1", type=float, default=0.001, help="Learning rate of the unsupervised model.")
    add("--lr2", type=float, default=0.01, help="Learning rate of linear evaluator.")
    add("--wd1", type=float, default=0.0, help="Weight decay of the unsupervised model.")
    add("--wd2", type=float, default=0.0, help="Weight decay of linear evaluator.")

    add('--residual', action='store_true')
    add('--random', action='store_true')
    add('--constant', action='store_true')
    add('--reg_coef', default=0.5, type=float)

    # norm layer
    add('--norm_type', default='none', type=str)
    add('--scale', default=0.5, type=float)
    add('--norm_x', action='store_true')
    add('--plusone', action='store_true')
    add('--layer_norm', action='store_true')

    return parser.parse_args(argv)


if __name__ == '__main__':
    # Entry point: train the unsupervised model once, then score its
    # embeddings with a linear probe on each of the dataset's stored splits.
    args = parse_args()
    print(args)
    print("---------------------------------------------")

    supported_datasets = ['roman_empire', 'amazon_ratings', 'tolokers', 'minesweeper', 'questions']
    if args.dataset not in supported_datasets:
        # Explicit error instead of ``assert`` so the check survives ``python -O``.
        raise ValueError(
            f"Unsupported dataset '{args.dataset}'; expected one of: {', '.join(supported_datasets)}"
        )

    set_seed(args.seed)
    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')

    root = './data/'
    dataset = HeterophilousGraphDataset(root=root,name=args.dataset)
    data = dataset[0]
    data = data.to(device)

    label = data.y
    n_classes = dataset.num_classes
    n_node = data.x.size(0)

    # The split masks are indexed by column below; fail early with a clear
    # message if more runs are requested than splits are stored.
    n_splits = data.train_mask.size(1)
    if args.runs > n_splits:
        raise ValueError(f'--runs={args.runs} exceeds the {n_splits} splits stored in the dataset')

    embeds = unsupervised_learning(dataset=dataset, data=data, args=args, device=device)

    # Labels stay as loaded for the two multi-class datasets and are cast to
    # float for the others (the probe uses BCEWithLogitsLoss there).
    label = label if args.dataset in ['roman_empire', 'amazon_ratings'] else label.to(torch.float)
    # Use the resolved ``device`` rather than the integer ``args.device`` so
    # the CPU fallback chosen above keeps working.
    label = label.to(device)
    assert label.shape[0] == n_node

    results = []
    for i in range(args.runs):
        train_mask = data.train_mask[:, i].to(device)
        val_mask = data.val_mask[:, i].to(device)
        test_mask = data.test_mask[:, i].to(device)
        assert torch.sum(train_mask + val_mask + test_mask) == data.num_nodes

        eval_acc = unsupervised_test_linear(embeds, label, train_mask, val_mask, test_mask, n_classes)
        # ``float`` accepts Python/NumPy scalars and 0-dim tensors alike, so the
        # result list is homogeneous.
        results.append(float(eval_acc))

    test_acc_mean = np.mean(results, axis=0) * 100
    values = np.asarray(results, dtype=float)
    # Half-width of the 95% bootstrap confidence interval of the mean.
    uncertainty = np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000), 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')
