import argparse

from model import  Model
from utils import load, remove_edges
import torch
import torch as th
import torch.nn as nn

import warnings
from time import time

# Suppress every Python warning for the remainder of the run.
warnings.filterwarnings('ignore')


def _build_parser():
    """Return the argument parser describing all command-line options.

    Options are registered in a fixed order so that the generated ``--help``
    text lists them in the same sequence every time.
    """
    cli = argparse.ArgumentParser(description='OrthoReg')
    add = cli.add_argument

    # Dataset and hardware selection (a --gpu value of -1 forces CPU).
    add('--dataname', type=str, default='cora', help='Name of dataset.')
    add('--gpu', type=int, default=0, help='GPU index.')

    # Optimisation schedule.
    add('--epochs', type=int, default=2000, help='Training epochs.')
    add('--lr', type=float, default=1e-3, help='Learning rate of pretraining.')
    add('--wd', type=float, default=1e-5, help='Weight decay of pretraining.')

    # Model shape and loss weighting.
    add('--n_layers', type=int, default=2, help='Number of GNN layers')
    add('--beta', type=float, default=1e-6, help='Trade-off hyperparameter.')
    add('--dropout', type=float, default=0.5, help='Dropout rate.')
    add('--alpha', type=float, default=1e-3, help='Trade-off.')
    add('--k', type=int, default=1, help='Maximum hops.')
    add('--hid_dim', type=int, default=512, help='Hidden layer dim.')
    add('--use_bn', action='store_true', default=False, help='Use batch normalization.')

    # Edge-removal experiments.
    add(
        '--perturb',
        type=int,
        default=0,
        help='0: no perturbation, 1: remove edges when training, 2: remove edges when testing',
    )
    add('--p_rate', type=float, default=0.0, help='perturbation rate')

    return cli


parser = _build_parser()
args = parser.parse_args()

# Pick the compute device: the requested GPU when one was asked for
# (--gpu != -1) and CUDA is actually available, otherwise fall back to CPU.
args.device = (
    f'cuda:{args.gpu}'
    if args.gpu != -1 and th.cuda.is_available()
    else 'cpu'
)

if __name__ == '__main__':
    # Script entry point: build the model, train it on the whole graph for
    # `epochs` optimizer steps and evaluate on `test_graph` after every step.
    # The test accuracy reported at the end is the one observed at the epoch
    # with the best validation accuracy (ties keep the higher test accuracy).

    print(args)
    # load hyperparameters
    dataname = args.dataname
    hid_dim = args.hid_dim
    n_layers = args.n_layers
    alpha = args.alpha  # weight of the first regularisation term (loss_inv)
    beta = args.beta    # weight of the second regularisation term (loss_dec)

    epochs = args.epochs
    lr = args.lr
    wd = args.wd

    device = args.device
    perturb = args.perturb
    p_rate = args.p_rate

    graph, feat, labels, num_class, train_idx, val_idx, test_idx = load(dataname)
    in_dim = feat.shape[1]
    out_dim = num_class

    model = Model(in_dim, hid_dim, out_dim, n_layers, beta, args.dropout, args.k, args.use_bn)
    model = model.to(device)

    print(graph)
    print(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    labels = labels.to(device)
    train_labels = labels[train_idx]
    val_labels = labels[val_idx]
    test_labels = labels[test_idx]

    best_val_acc = 0
    eval_acc = 0

    # Optional edge removal: mode 1 perturbs the graph used for both training
    # and evaluation, mode 2 perturbs only the evaluation graph.
    if perturb == 1:
        graph = remove_edges(graph, p_rate)
        test_graph = graph
        print('disturbing training and testing graph')
    elif perturb == 2:
        test_graph = remove_edges(graph, p_rate)
        print('disturbing testing graph')
    else:
        test_graph = graph
        print('no disturbing')

    graph = graph.to(device)
    test_graph = test_graph.to(device)
    feat = feat.to(device)

    # Replace any existing self-loops on the training graph with freshly added
    # ones. The result is the same for every epoch, so it is computed once
    # here instead of being rebuilt inside the training loop.
    # NOTE(review): test_graph does not receive this treatment (unchanged from
    # the previous behaviour) — confirm the train/eval mismatch is intended.
    graph = graph.remove_self_loop().add_self_loop()

    print(test_graph)

    # CUDA work is queued asynchronously, so the device has to be synchronized
    # around the timed forward pass for the wall-clock figure to be meaningful.
    cuda_device = torch.device(device) if device.startswith('cuda') else None

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        # The model returns its regularisation terms together with the logits.
        loss_reg, logits = model(graph, feat)

        train_logits = logits[train_idx]
        loss_ce = loss_fn(train_logits, train_labels)

        loss_inv, loss_dec = loss_reg

        loss = loss_ce + loss_inv * alpha + loss_dec * beta

        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():

            if cuda_device is not None:
                torch.cuda.synchronize(cuda_device)
            t0 = time()
            _, logits = model(test_graph, feat)
            if cuda_device is not None:
                torch.cuda.synchronize(cuda_device)
            t1 = time()
            print(f'Inference time = {t1 - t0}')

            preds = logits.argmax(dim=1)

            train_preds = preds[train_idx]
            val_preds = preds[val_idx]
            test_preds = preds[test_idx]

            train_acc = torch.sum(train_preds == train_labels).float() / train_labels.shape[0]
            val_acc = torch.sum(val_preds == val_labels).float() / val_labels.shape[0]
            test_acc = torch.sum(test_preds == test_labels).float() / test_labels.shape[0]

            # Model selection: remember the test accuracy at the best
            # validation accuracy; on a validation tie keep the better test.
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                eval_acc = test_acc
            elif val_acc == best_val_acc and test_acc > eval_acc:
                eval_acc = test_acc

            # All three accuracies use four decimal places ('.4f'); a bare
            # '4f' would only set a minimum field width.
            print(
                'Epoch:{}, train_acc:{:.4f}, val_acc:{:.4f}, test_acc:{:.4f}'.format(
                    epoch, train_acc, val_acc, test_acc))

    print(f'Validation Accuracy: {best_val_acc}, Test Accuracy: {eval_acc}')
