import argparse
import time
from alive_progress import alive_bar
from tqdm import tqdm
import seaborn as sns
import numpy as np
import pickle as pkl

import torch
import torch.nn.functional as F

from sup_models import *
from dataset_loader import DataLoader
from utils import random_splits, set_seed


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='seed.')
    parser.add_argument('--dataset', type=str, default='Cora')
    parser.add_argument('--device', type=int, default=0, help='GPU device.')
    parser.add_argument('--runs', type=int, default=10, help='number of runs.')
    parser.add_argument('--net', type=str, choices=['GCN', 'APPNP', 'ChebNet','MLP','ChebNetII','ChebBase', 'SGC', 'BernNet', 'GPRGNN'], default='ChebNetII')

    parser.add_argument('--epochs', type=int, default=1000, help='max epochs.')
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate.')       
    parser.add_argument('--weight_decay', type=float, default=0.0005, help='weight decay.')  
    parser.add_argument('--prop_lr', type=float, default=0.01, help='learning rate for propagation layer.')
    parser.add_argument('--prop_wd', type=float, default=0.0005, help='learning rate for propagation layer.')
    parser.add_argument('--early_stopping', type=int, default=200, help='early stopping.')

    parser.add_argument('--hidden', type=int, default=64, help='hidden units.')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout for neural networks.')
    parser.add_argument('--K', type=int, default=10, help='propagation steps.')
    parser.add_argument('--dprate', type=float, default=0.5, help='dropout for propagation layer.')
    parser.add_argument('--q', type=int, default=0, help='The constant for ChebBase.')
    parser.add_argument('--alpha', type=float, default=0.1, help='alpha for APPN/GPRGNN.')
    parser.add_argument('--Init', type=str,choices=['SGC', 'PPR', 'NPPR', 'Random', 'WS', 'Null'], default='PPR', help='initialization for GPRGNN.')

    parser.add_argument('--train_rate', type=float, default=0.6, help='train set rate.')
    parser.add_argument('--val_rate', type=float, default=0.2, help='val set rate.')
    parser.add_argument('--full', type=bool, default=False, help='full-supervise with random splits')

    parser.add_argument('--pretrain', action='store_true')
    parser.add_argument('--fix_prop', action='store_true')
    parser.add_argument('--fix_linear', action='store_true')
    parser.add_argument('--random_prop', action='store_true')

    parser.add_argument('--residual', action='store_true')
    parser.add_argument('--num_layers', type=int, default=2)
    args = parser.parse_args()
    return args


def train(model, optimizer, data):
    model.train()
    optimizer.zero_grad()
    out = model(x=data.x, edge_index=data.edge_index)[data.train_mask]
    loss = F.nll_loss(out, data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    del out
    return loss


def test(model, data):
    model.eval()
    logits, accs, losses, preds = model(x=data.x, edge_index=data.edge_index), [], [], []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        loss = F.nll_loss(model(x=data.x, edge_index=data.edge_index)[mask], data.y[mask])
        preds.append(pred.detach().cpu())
        accs.append(acc)
        losses.append(loss.detach().cpu())
    return accs, preds, losses


def RunExp(data, args, model, tmp_net, tag): 
    if args.net in ['ChebNetII','ChebBase', 'BernNet', 'GPRGNN']:
        optimizer = torch.optim.Adam([
            {'params': model.lin1.parameters(), 'weight_decay': args.weight_decay, 'lr': args.lr},
            {'params': model.lin2.parameters(), 'weight_decay': args.weight_decay, 'lr': args.lr},
            {'params': model.prop1.parameters(), 'weight_decay': args.prop_wd, 'lr': args.prop_lr}])
    else:
        optimizer = torch.optim.Adam(model.parameters(),lr=args.lr,weight_decay=args.weight_decay)

    best_val_acc = test_acc = 0
    best_val_loss = float('inf')
    val_loss_history = []
    val_acc_history = []

    time_run = []
    with alive_bar(args.epochs) as bar:
        for epoch in range(args.epochs):
            t_st=time.time()
            loss = train(model, optimizer, data)
            time_epoch = time.time() - t_st
            time_run.append(time_epoch)
            [train_acc, val_acc, tmp_test_acc], preds, [train_loss, val_loss, tmp_test_loss] = test(model, data)

            if val_loss < best_val_loss:
                best_val_acc = val_acc
                best_val_loss = val_loss
                test_acc = tmp_test_acc
                if args.net in ['ChebNetII', 'BernNet', 'GPRGNN']:
                    TEST = tmp_net.prop1.temp.clone()
                    theta = TEST.detach().cpu()
                    theta = torch.relu(theta).numpy()
                elif args.net in ['ChebBase']:
                    TEST = tmp_net.prop1.temp.clone()
                    theta = TEST.detach().cpu().numpy()
                else:
                    theta = args.alpha
                # torch.save(model.state_dict(), f'pretrain_pkl/sl_{args.dataset}_{args.net}_{args.hidden}_{args.K}_{tag}.pkl')

            if epoch >= 0:
                val_loss_history.append(val_loss)
                val_acc_history.append(val_acc)
                if args.early_stopping > 0 and epoch > args.early_stopping:
                    tmp = torch.tensor(val_loss_history[-(args.early_stopping + 1):-1])
                    if val_loss > tmp.mean().item():
                        break
            bar()
    return test_acc, best_val_acc, theta, time_run


if __name__ == '__main__':
    args = parse_args()
    print(args)
    print("---------------------------------------------")
    
    set_seed(args.seed)
    #10 fixed seeds for random splits from BernNet
    SEEDS=[1941488137,4198936517,983997847,4023022221,4019585660,2108550661,1648766618,629014539,3212139042,2424918363]
    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')

    gnn_name = args.net
    if gnn_name == 'GCN':
        Net = GCN_Net
    elif gnn_name =='MLP':
        Net = MLP
    elif gnn_name == 'APPNP':
        Net = APPNP_Net
    elif gnn_name == 'ChebNet':
        Net = ChebNet
    elif gnn_name == "ChebNetII":
        Net = ChebNetII
    elif gnn_name == 'SGC':
        Net = SGC_Net
    elif gnn_name == 'BernNet':
        Net = BernNet
    elif gnn_name == 'GPRGNN':
        Net = GPRGNN
    
    dataset = DataLoader(args.dataset)
    data = dataset[0]

    percls_trn = int(round(args.train_rate * len(data.y) / dataset.num_classes))
    val_lb = int(round(args.val_rate * len(data.y)))
    
    results = []
    time_results=[]
    for RP in range(args.runs):
        args.seed = SEEDS[RP]
        tag = str(args.seed)

        tmp_net = Net(dataset, args)
        if args.pretrain and args.net in ['ChebNetII', 'BernNet', 'GPRGNN']:
            pretrained_dict = torch.load(f'pretrain_pkl/cl_{args.dataset}_{args.net}_{args.hidden}_{args.K}_{tag}.pkl', map_location=device)
            encoder_dict = tmp_net.state_dict()
            if args.fix_prop:
                pretrained_dict = {'prop1.temp': pretrained_dict['encoder.prop1.temp']}
                tmp_net.prop1.temp.requires_grad = False
                encoder_dict.update(pretrained_dict)
                tmp_net.load_state_dict(encoder_dict)
                    
            if args.fix_linear:
                pretrained_dict = {'lin1.weight': pretrained_dict['encoder.lin1.weight'], 'lin1.bias': pretrained_dict['encoder.lin1.bias']}
                tmp_net.lin1.weight.requires_grad = False
                tmp_net.lin1.bias.requires_grad = False
                encoder_dict.update(pretrained_dict)
                tmp_net.load_state_dict(encoder_dict)

            if args.random_prop:
                temp = np.random.uniform(0, 2, args.K+1)
                tmp_net.prop1.temp = Parameter(torch.tensor(temp), requires_grad=False)
            
            theta = torch.relu(tmp_net.prop1.temp)
            print([float('{:.4f}'.format(i)) for i in theta])

        data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed)
        model, data = tmp_net.to(device), data.to(device)

        test_acc, best_val_acc, theta_0, time_run = RunExp(data=data, args=args, model=model, tmp_net=tmp_net, tag=tag)
        time_results.append(time_run)
        results.append([test_acc, best_val_acc, theta_0])
        print(f'run_{str(RP+1)} \t test_acc: {test_acc:.4f}')
        if args.net in ["ChebNetII", 'BernNet', 'GPRGNN']:
            print('Weights:', [float('{:.4f}'.format(i)) for i in theta_0])

    run_sum=0
    epochsss=0
    for i in time_results:
        run_sum+=sum(i)
        epochsss+=len(i)

    test_acc_mean, val_acc_mean, _ = np.mean(results, axis=0) * 100
    test_acc_std = np.sqrt(np.var(results, axis=0)[0]) * 100
    values=np.asarray(results, dtype=object)[:,0]
    uncertainty=np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(values,func=np.mean,n_boot=1000),95)-values.mean()))
    print(f'{gnn_name} on dataset {args.dataset}, in {args.runs} repeated experiment:')
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty*100:.4f}  \t val acc mean = {val_acc_mean:.4f}')