import argparse
from ast import Pass
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from time import perf_counter
from utils import load_citation, set_seed, load_sbm, load_heterdata, load_csbm_v2

from dgc import sign_precompute, DGC, sgc_precompute, center_precompute, label_precompute, base_precompute # , sys_precompute, op_precompute, CN_precompute
from cSBM_dataset import ContextualSBM
# Command-line interface for the experiment script.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default="texas",
                    help='Dataset to use.')
parser.add_argument('--lr', type=float, default=1.39,
                    help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=2e-4,
                    help='Weight decay (L2 loss on parameters).')
parser.add_argument('--epochs', type=int, default=100,
                    help='Number of epochs to train.')
# The two normalizations must be separate list entries: a single fused
# 'AugNormAdj, NormAdj' string makes argparse reject every explicit value.
parser.add_argument('--normalization', type=str, default='AugNormAdj',
                    choices=['AugNormAdj', 'NormAdj'],
                    help='Normalization method for the adjacency matrix.')
# NOTE(review): no visible code in this script reads `args.no_cuda`; the
# device is chosen from `--device` and CUDA availability only.
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='Disables CUDA training.')
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--trials', type=int, default=10, help='Run multiple trails for fair evaluation')
# DGC parameters
parser.add_argument('--T', type=float, default=1,
                    help='real-valued diffusion terminal time.')
parser.add_argument('--K', type=int, default=300,
                    help='number of propagation steps (larger K implies better numerical precision).')
# Model selection. The help text lists only the names that main() dispatches on.
parser.add_argument('--model', type=str, default='Label',
                    help='SGC, Sign, Label, CN, LN, PN, BN, drop, res, appnp, jk-net, dagnn, mlp')
# Index of the CUDA device to use when CUDA is available.
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--a', type=float, default=1.0)  # for positive adj
parser.add_argument('--b', type=float, default=1.0)  # for negative adj
# Used to build the cSBM dataset name ('cSBM_phi_<phi>').
parser.add_argument('--phi', type=float, default=0.25)
# Data split
parser.add_argument('--train_set', type=float, default=0., help='this is for the train label rate')
# Parse the command line once at import time; everything below reads `args`.
args = parser.parse_args()

# Seed via the project helper before any data loading or model construction.
set_seed(args.seed)

# Use the requested CUDA device when CUDA is available, otherwise the CPU.
# NOTE(review): `args.no_cuda` is not consulted here - confirm whether the
# flag is meant to force CPU execution.
if torch.cuda.is_available():
    device = torch.device("cuda:" + str(args.device))
else:
    device = torch.device("cpu")

# Data directory and cSBM dataset name; in this file they are referenced only
# by the commented-out `load_csbm_v2` call inside main().
path = '../data/'
name = f'cSBM_phi_{args.phi}'
def main(args):
    """Run one trial: load data, precompute features, train the head, evaluate.

    Args:
        args: Parsed command-line namespace. Reads ``model``, ``dataset``,
            ``normalization``, ``device``, ``train_set``, ``T``, ``K``, ``b``,
            ``epochs``, ``weight_decay`` and ``lr``.

    Returns:
        The test accuracy computed by ``test`` on the ``idx_test`` split.

    Raises:
        ValueError: If ``args.model`` is not one of the models dispatched
            below. Without this check ``precompute_time`` would be unbound
            and the run would crash only after training finished.
    """
    base_models = ('CN', 'LN', 'PN', 'BN', 'drop', 'res', 'appnp', 'jk-net', 'dagnn')
    supported_models = ('SGC', 'Sign', 'Label', 'mlp') + base_models
    # Fail fast, before any data is loaded or any training time is spent.
    if args.model not in supported_models:
        raise ValueError(
            f"Unsupported model {args.model!r}; expected one of: "
            + ", ".join(supported_models)
        )

    # load data
    if args.dataset == 'sbm':
        adj, features, labels, idx_train, idx_val, idx_test = load_sbm(args.normalization, args.device, args.train_set)
    elif args.dataset in ['cora', 'citeseer', 'pubmed']:
        adj, features, labels, idx_train, idx_val, idx_test = load_citation(args.dataset, args.normalization, args.device)
    else:
        adj, features, labels, idx_train, idx_val, idx_test = load_heterdata(args.dataset, args.device)

    # Alternative cSBM loader, kept for reference:
    # adj, features, labels, idx_train, idx_val, idx_test = load_csbm_v2(path, name=name,train_set=0.6,device= args.device, normalization=args.normalization)

    # val 500, test 1000
    # Preprocessing: run the propagation selected by --model. Each routine
    # returns the transformed features together with the time it took.
    # Other variants (Center, Sys, CN_neg, op, repel) are not wired up here.
    if args.model == 'SGC':
        features, precompute_time = sgc_precompute(features, adj, args.T, args.K, labels)
    elif args.model == 'Sign':
        features, precompute_time = sign_precompute(features, adj, args.T, args.K, args.b, labels)
    elif args.model == 'Label':
        features, precompute_time = label_precompute(features, adj, labels, idx_train, args.T, args.K, args.b)
    elif args.model in base_models:
        features, precompute_time = base_precompute(features, adj, args.T, args.K, args.model, labels)
    else:
        # 'mlp': train directly on the raw features, no propagation step.
        precompute_time = 0

    # Initialize the model: input width is the feature dimension, output
    # width is the number of classes inferred from the largest label id.
    model = DGC(features.size(1), labels.max().item() + 1)
    model = model.to(device)

    # Train on the training split and collect test accuracy.
    model, train_time = train(model, features[idx_train], labels[idx_train], args.epochs, args.weight_decay, args.lr)
    acc_test = test(model, features[idx_test], labels[idx_test])

    print("Test accuracy: {:.4f},  pre-compute time: {:.4f}s, train time: {:.4f}s, total: {:.4f}s".format(acc_test, precompute_time, train_time, precompute_time+train_time))

    return acc_test


def train(model,
        train_features, train_labels,
        epochs=100, weight_decay=5e-6, lr=0.2):
    """Fit ``model`` with Adam on the full training set, one step per epoch.

    Args:
        model: Module mapping ``train_features`` to per-class scores.
        train_features: Inputs passed to ``model`` in a single batch.
        train_labels: Class targets for the cross-entropy loss.
        epochs: Number of optimizer steps to take.
        weight_decay: Weight decay handed to Adam.
        lr: Learning rate handed to Adam.

    Returns:
        ``(model, train_time)``: the same model object, updated in place, and
        the time spent in the training loop in seconds.
    """
    adam = optim.Adam(model.parameters(), weight_decay=weight_decay, lr=lr)

    started = perf_counter()
    for _ in range(epochs):
        model.train()
        adam.zero_grad()
        # Full-batch step: every epoch sees all training rows at once.
        scores = model(train_features)
        F.cross_entropy(scores, train_labels).backward()
        adam.step()
    elapsed = perf_counter() - started

    return model, elapsed


def test(model, test_features, test_labels):
    """Evaluate ``model`` on held-out data and return its accuracy.

    Args:
        model: Trained module; it is switched to eval mode here.
        test_features: Inputs passed to ``model`` in a single batch.
        test_labels: Ground-truth class ids compared against the predictions.

    Returns:
        The fraction of correct predictions, as computed by ``accuracy``.
    """
    model.eval()
    # Inference only: skip autograd bookkeeping so no graph is built or kept
    # alive for the evaluation forward pass. The predictions are unchanged.
    with torch.no_grad():
        logits = model(test_features)
    return accuracy(logits, test_labels)


def accuracy(output, labels):
    """Return the fraction of rows whose arg-max over dim 1 equals the label.

    Args:
        output: Per-class scores; the predicted class is the arg-max over
            dimension 1.
        labels: Ground-truth class ids, one per row of ``output``.

    Returns:
        Number of matching predictions divided by ``len(labels)``.
    """
    predicted = output.argmax(dim=1).type_as(labels)
    n_hits = (predicted == labels).double().sum().item()
    return n_hits / len(labels)

# Repeat the whole pipeline `--trials` times and report mean/std accuracy.
print(args)
accu_acc = np.array([main(args) for _ in range(args.trials)])

acc_mean = accu_acc.mean()
acc_std = accu_acc.std()

print('=' * 20)
print(f'Dataset: {args.dataset} Test accuracy of {args.trials} runs: mean {acc_mean:.5f}, std {acc_std:.5f}')