import warnings
import numpy as np
import argparse
import seaborn as sns

import torch
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.utils import to_dense_adj

from utils import set_seed, random_splits
from eval import unsupervised_test_linear
from dataset_loader import DataLoader
warnings.filterwarnings("ignore")


def print_params(model, string):
    """Print every named parameter of ``model.encoder`` between two banners.

    Args:
        model: object with an ``encoder`` attribute exposing
            ``named_parameters()`` (presumably a ``torch.nn.Module``).
        string: label shown in the opening banner.
    """
    banner_open = f'----------- {string} ----------'
    banner_close = '-----------------------------------'
    print(banner_open)
    for param_name, param_value in model.encoder.named_parameters():
        # Name on one line, full tensor/value on the next.
        print(param_name)
        print(param_value)
    print(banner_close)


def parse_args(argv=None):
    """Parse the command-line options for this evaluation script.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            command-line behaviour; passing a list makes the parser usable
            programmatically and from tests.

    Returns:
        argparse.Namespace holding seed, dataset, device, runs, K, the
        train/val split rates and the linear evaluator's lr2/wd2.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='seed.')
    parser.add_argument('--dataset', type=str, default='Cora', help='dataset name.')
    parser.add_argument('--device', type=int, default=0, help='GPU device.')
    parser.add_argument('--runs', type=int, default=10, help='number of runs.')
    parser.add_argument('--K', type=int, default=1, help='number of propagation steps.')

    parser.add_argument('--train_rate', type=float, default=0.6, help='train set rate.')
    parser.add_argument('--val_rate', type=float, default=0.2, help='val set rate.')

    parser.add_argument("--lr2", type=float, default=0.01, help="Learning rate of linear evaluator.")
    parser.add_argument("--wd2", type=float, default=0.0, help="Weight decay of linear evaluator.")
    # parse_args(None) falls back to sys.argv[1:], preserving the CLI behaviour.
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Parameter-free baseline: smooth node features K times with the
    # GCN-normalised adjacency, then score the resulting embeddings with a
    # linear evaluator on `args.runs` random splits and report mean ± 95% CI.
    args = parse_args()
    print(args)
    print("---------------------------------------------")

    set_seed(args.seed)
    # 10 fixed seeds for random splits from BernNet
    SEEDS = [1941488137, 4198936517, 983997847, 4023022221, 4019585660,
             2108550661, 1648766618, 629014539, 3212139042, 2424918363]
    if not 1 <= args.runs <= len(SEEDS):
        # Fail fast, before loading data, instead of an IndexError mid-experiment
        # (runs > len(SEEDS)) or a crash in the bootstrap on an empty result list.
        raise ValueError(
            f'--runs must be between 1 and {len(SEEDS)} '
            f'(one fixed split seed per run), got {args.runs}'
        )
    device = torch.device('cuda:' + str(args.device) if torch.cuda.is_available() else 'cpu')

    dataset = DataLoader(args.dataset)
    data = dataset[0]
    data = data.to(device)

    percls_trn = int(round(args.train_rate * len(data.y) / dataset.num_classes))
    val_lb = int(round(args.val_rate * len(data.y)))

    # K-step propagation. It depends only on the graph and the node features,
    # not on the split, so it is computed once instead of once per run.
    # NOTE(review): assumes random_splits only sets split masks and leaves
    # data.x / data.edge_index unchanged.
    num_nodes = data.num_nodes
    # gcn_norm adds self-loops and returns symmetric-normalised edge weights;
    # those weights MUST be passed to to_dense_adj, otherwise A is just the
    # 0/1 adjacency and repeated multiplication blows up instead of smoothing.
    # Passing num_nodes keeps A at N x N even with trailing isolated nodes.
    edge_index, edge_weight = gcn_norm(edge_index=data.edge_index, num_nodes=num_nodes)
    A = to_dense_adj(edge_index=edge_index, edge_attr=edge_weight,
                     max_num_nodes=num_nodes).squeeze(0)
    embeds = data.x
    for _ in range(args.K):
        embeds = A @ embeds

    unsup_results = []
    for RP in range(args.runs):
        args.seed = SEEDS[RP]
        data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)

        # NOTE(review): embeds is shared across runs; assumes the evaluator
        # does not modify it in place.
        eval_acc = unsupervised_test_linear(data=data, embeds=embeds, n_classes=dataset.num_classes, device=device, args=args)
        print(f'eval acc = {eval_acc}')
        unsup_results.append(eval_acc)

    test_acc_mean = np.mean(unsup_results) * 100
    values = np.asarray(unsup_results, dtype=object)
    # Half-width of the bootstrapped 95% confidence interval of the mean accuracy.
    uncertainty = np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000), 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')

