import sys; sys.path.append("..")

import torch
import argparse
import numpy as np

from utility.log import Log
from model.geometrics import *
from utility.initialize import initialize
from torch_geometric.utils import to_undirected
from torch_geometric.transforms import RandomNodeSplit
from torch_geometric.datasets import CitationFull
from utility.utils import *
from torch.optim.lr_scheduler import ReduceLROnPlateau

if __name__ == "__main__":
    # Script entry point: train a GNN node classifier on a CitationFull
    # dataset once per pre-computed random split, then print the mean of
    # [last test acc, best validation acc, test acc at best validation]
    # over all splits and save the per-split values to a .npy file.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default='gcn', type=str, help="select model")
    parser.add_argument("--epochs", default=2000, type=int, help="Total number of epochs.")
    parser.add_argument("--learning_rate", '-lr', default=1e-1, type=float, help="Base learning rate at the start of the training.")
    parser.add_argument("--dataset", default="cora_ml", type=str, help="dataset name")
    parser.add_argument("--weight_decay", default=0.0000, type=float, help="L2 weight decay.")
    parser.add_argument("--seed", default=42, type=int, help="Random seed.")
    parser.add_argument("--noise_level", '-nl', type=float, default=0.0, help="noise level for PGD")
    parser.add_argument("--dropout", default=0.5, type=float, help="dropout.")
    parser.add_argument("--train_rest", action='store_true', help="if training on all samples with labels")
    args = parser.parse_args()

    initialize(args, seed=args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The split is a pre_transform, so it is computed when the dataset is
    # first processed; it yields 10 train/val/test mask columns per node.
    split_transform = RandomNodeSplit(
        split='test_rest', num_splits=10,
        num_val=500, num_train_per_class=20)
    dataset = CitationFull(root='./data', name=args.dataset,
                           pre_transform=split_transform)
    data = dataset[0].to(device)

    # `is_undirected` is a method: it has to be called. Comparing the bound
    # method itself with False is never true and would skip symmetrization.
    if not data.is_undirected():
        data.edge_index = to_undirected(data.edge_index)

    # Plain Python int so it can be used directly as a layer dimension.
    num_classes = int((data.y.max() - data.y.min() + 1).item())
    num_features = data.x.shape[1]

    model_name = args.model.lower()
    if model_name == 'gcn':
        model = GCN_Model(input_dim=num_features, out_dim=num_classes,
                          filter_num=32, dropout=args.dropout).to(device)
    elif model_name == 'gat':
        model = GAT_Model(input_dim=num_features, out_dim=num_classes,
                          filter_num=16, dropout=args.dropout).to(device)
    elif model_name == 'sage':
        model = SAGE_Model(input_dim=num_features, out_dim=num_classes,
                           filter_num=32, dropout=args.dropout).to(device)
    else:
        # NOTE: any unrecognised --model value falls back to APPNP.
        model = APPNP_Model(input_dim=num_features, out_dim=num_classes,
                            filter_num=32, K=10, alpha=0.1, dropout=args.dropout).to(device)

    # Run identifier encoding the hyper-parameters; the split index is
    # appended per split for the log file name.
    file_name = (args.dataset+'lr'+str(args.learning_rate)
                  +'model'+str(args.model)
                  +'seed'+str(args.seed)
                  +'nl'+str(args.noise_level)
                  +'wd'+str(args.weight_decay)
                  +'dp'+str(args.dropout)+'_'
                  )
    log_folder = './graph/'+args.dataset+'/'+args.model

    # Per-sample losses (no reduction); the mean is taken explicitly before
    # backward. `reduction='none'` replaces the deprecated `reduce=False`.
    criterion = torch.nn.NLLLoss(reduction='none')

    acc = []
    for split in range(data.train_mask.shape[1]):
        achieve_target_acc = 0
        log = Log(log_each=1, file_name=file_name+str(split), logs=log_folder)
        train_mask = data.train_mask[:, split]
        val_mask = data.val_mask[:, split]
        test_mask = data.test_mask[:, split]

        if args.train_rest:
            # Out-of-place union: `+=` would write through the view into
            # data.train_mask itself.
            train_mask = train_mask | val_mask

        model.reset_parameters()
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate,
                                        weight_decay=args.weight_decay)
        scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=20)

        # Reset per split so that an early break (before the first
        # validation pass) or a validation accuracy that never exceeds 0
        # cannot raise NameError or leak values from the previous split.
        best_val_acc = 0
        best_test_acc = 0.0
        test_acc = 0.0
        for epoch in range(args.epochs):
            model.train()
            log.train(len_dataset=1)

            # Optionally perturb the weights for this step; the injected
            # noise is removed again after the optimizer update.
            # NOTE(review): exact semantics of noise_injection /
            # rm_injected_noises live in a project module - confirm there.
            if args.noise_level > 0:
                _, noises_scaled = noise_injection(model, args.noise_level)

            optimizer.zero_grad()
            predictions = model(data.x, data.edge_index)
            loss = criterion(predictions[train_mask], data.y[train_mask])

            loss.mean().backward()
            optimizer.step()

            if args.noise_level > 0:
                rm_injected_noises(model, noises_scaled)

            current_lr = optimizer.param_groups[0]['lr']
            predicted = predictions.max(dim=1)[1]
            correct = (data.y[train_mask].cpu() == predicted[train_mask].cpu())
            log(model, loss.cpu(), correct, current_lr)

            # Stop once the training accuracy has reached >= 99.9% on more
            # than 20 epochs (cumulative count, not necessarily consecutive).
            train_acc = log.epoch_state["accuracy"] / log.epoch_state["steps"]
            if train_acc >= 0.999:
                achieve_target_acc += 1
                if achieve_target_acc > 20:
                    break

            # no need to keep training
            if current_lr < 1e-5:
                break

            # validation
            model.eval()
            log.eval(len_dataset=1)
            with torch.no_grad():
                predictions = model(data.x, data.edge_index)
                predicted = predictions.max(dim=1)[1]

                loss = criterion(predictions[val_mask], data.y[val_mask])
                correct = (data.y[val_mask].cpu() == predicted[val_mask].cpu())
                log(model, loss.cpu(), correct, current_lr)
                accuracy = (torch.sum(correct) / torch.sum(val_mask).cpu()).numpy()

                correct = (data.y[test_mask].cpu() == predicted[test_mask].cpu())
                test_acc = (torch.sum(correct) / torch.sum(test_mask).cpu()).numpy()

            # Model selection: remember the test accuracy at the epoch with
            # the best validation accuracy seen so far.
            if accuracy > best_val_acc:
                best_val_acc = accuracy
                best_test_acc = test_acc

            # The plateau scheduler is driven by training accuracy.
            scheduler.step(train_acc)

        acc.append([test_acc, best_val_acc, best_test_acc])
        log.flush()
    print(np.mean(acc, axis=0))
    np.save(log_folder+'/'+file_name+'.npy', acc)