import sys; sys.path.append("..")

import torch
import argparse, math
import numpy as np

from utility.log import Log
from scipy.optimize import fsolve
from torch.optim import SGD, Adam
from model.geometrics import *
from utility.initialize import initialize
from torch_geometric.transforms import RandomNodeSplit
from torch_geometric.datasets import CitationFull
from utility.utils import *
from torch.optim.lr_scheduler import ReduceLROnPlateau

if __name__ == "__main__":
    # Train a node classifier on a CitationFull graph over 10 random splits.
    # Stage 1 (epoch < --shift) adds a bound-based regularizer learned through the
    # log-scale noise parameters `p` and prior scale `b`; stage 2 trains on the
    # data loss only.
    #
    # Local imports keep this script section self-contained; both come from
    # packages the file already depends on.
    import os
    from torch_geometric.utils import to_undirected

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default='gcn', type=str, help="select model")
    parser.add_argument("--epochs", default=2000, type=int, help="Total number of epochs.")
    parser.add_argument("--shift", default=500, type=int, help="the end epoch of stage1.")
    parser.add_argument("--learning_rate", '-lr', default=1e-2, type=float, help="Base learning rate at the start of the training.")
    parser.add_argument("--dataset", default="cora_ml", type=str, help="dataset name")
    parser.add_argument("--seed", default=42, type=int, help="Random seed.")
    parser.add_argument("--train_rest", action='store_true', help="if training on all samples with labels")
    args = parser.parse_args()

    initialize(args, seed=args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # RandomNodeSplit is a pre_transform: the 10 train/val/test splits are
    # generated when the dataset is processed, giving (num_nodes, 10) masks.
    data = CitationFull(root='./data', name=args.dataset,
                        pre_transform=RandomNodeSplit(
                            split='test_rest', num_splits=10,
                            num_val=500, num_train_per_class=20))[0].to(device)
    # Plain Python int so model constructors receive an ordinary class count.
    num_classes = int(data.y.max() - data.y.min() + 1)

    # `is_undirected` is a method. The old `data.is_undirected == False` compared
    # the bound method to False, which is never true, so this never ran.
    if not data.is_undirected():
        data.edge_index = to_undirected(data.edge_index)

    model_name = args.model.lower()
    input_dim = data.x.shape[1]
    if model_name == 'gcn':
        model = GCN_Model(input_dim=input_dim, out_dim=num_classes,
                          filter_num=32, dropout=0.0)
    elif model_name == 'gat':
        model = GAT_Model(input_dim=input_dim, out_dim=num_classes,
                          filter_num=8, dropout=0.0)
    elif model_name == 'sage':
        model = SAGE_Model(input_dim=input_dim, out_dim=num_classes,
                           filter_num=32, dropout=0.0)
    else:
        # Any other --model value falls back to APPNP.
        model = APPNP_Model(input_dim=input_dim, out_dim=num_classes,
                            filter_num=32, K=10, alpha=0.1, dropout=0.0)
    model = model.to(device)

    file_name = (args.dataset+'lr'+str(np.round(args.learning_rate, 5))
                 +'model'+str(args.model)
                 +'seed'+str(args.seed))
    log_folder = './graph_auto1_K_r_all/'+args.dataset+'/'+args.model
    # Per-node losses: the unreduced vector is logged, and .mean() is backpropagated.
    criterion = torch.nn.NLLLoss(reduction='none')

    # Bounds for the clipped trade-off coefficient gamma1 used in stage 1.
    max_gamma = 10
    min_gamma = 0.5
    acc = []  # one [final test acc, best val acc, test acc at best val] per split
    for split in range(data.train_mask.shape[1]):
        achieve_target_acc = 0
        log = Log(log_each=1, file_name=file_name+str(split), logs=log_folder)

        train_mask = data.train_mask[:, split]
        val_mask = data.val_mask[:, split]
        test_mask = data.test_mask[:, split]

        if args.train_rest:
            # In place on purpose: train_mask is a view of data.train_mask[:, split],
            # so this also updates `data`, which compute_K_sample receives along with
            # `split` (presumably reading that column — confirm before changing).
            train_mask |= val_mask
        n_train = torch.sum(train_mask)

        model.reset_parameters()
        w0, p, layers = initialization(model)

        # Start the log-scale parameters p and b at log(mean |w0|), floored at log(0.1).
        init = max(torch.log(w0.abs().mean()).item(), -math.log(10))
        p.data = init * torch.ones(len(p), device=device)
        b = torch.nn.Parameter(init * torch.ones(1, device=device))

        prior_list, K_list = compute_K_sample(model, data, criterion, min_gamma, max_gamma,
                                              max_nu=0, gcn=True, split=split)

        print(prior_list, K_list)
        opt1 = Adam(model.parameters(), lr=args.learning_rate)
        opt2 = Adam([p], lr=args.learning_rate)
        opt3 = Adam([b], lr=args.learning_rate)
        scheduler = ReduceLROnPlateau(opt1, mode='max', factor=0.1, patience=20)

        best_val_acc = 0
        # Defined up front so a split that stops before its first evaluation
        # (e.g. --epochs 0 or a tiny learning rate) still records a result and
        # never reuses the previous split's values.
        test_acc = best_test_acc = 0.0
        for epoch in range(args.epochs):
            model.train()
            log.train(len_dataset=1)
            stage1 = epoch < args.shift

            # noise injection
            wdecay = weight_decay(model, w0)
            noises, noises_scaled = noise_injection(model, p)

            opt1.zero_grad()
            opt2.zero_grad()
            opt3.zero_grad()
            predictions = model(data.x, data.edge_index)
            loss1 = criterion(predictions[train_mask], data.y[train_mask])

            # loss 2
            if stage1:
                kl = get_kl_term_with_b(wdecay, p, b)
                k_auto = fun_K_auto(torch.exp(b), prior_list, K_list)
                gamma1 = k_auto**(-1) * (2*(kl+10)/n_train/3)**0.5
                gamma1 = torch.clip(gamma1, max=max_gamma, min=min_gamma)
                loss2 = 3*k_auto**2*gamma1/2 + (kl+10)/n_train/gamma1
            else:
                # Stage 2 has no bound term. A fresh zero replaces `0*loss2`, which
                # raised NameError with --shift 0 and chained onto the old graph.
                loss2 = torch.zeros(1, device=device)

            # backward
            loss1.mean().backward(retain_graph=True)
            if stage1:
                kl_term_backward(loss2, model, p, noises)

            # remove noises
            rm_injected_noises(model, noises_scaled)

            # update
            opt1.step()
            if stage1:
                opt2.step()
                opt3.step()

            correct = (data.y[train_mask].cpu() == predictions.max(dim=1)[1][train_mask].cpu())
            log(model, loss1.cpu()+loss2.cpu(), correct.cpu(),
                fun_K_auto(torch.exp(b), prior_list, K_list).detach().cpu().item())

            train_acc = log.epoch_state["accuracy"] / log.epoch_state["steps"]
            # Stop once more than 20 post-shift epochs (counted cumulatively, not
            # necessarily consecutive) reach ~100% training accuracy.
            if train_acc >= 0.999 and epoch > args.shift:
                achieve_target_acc += 1
                if achieve_target_acc > 20:
                    break

            # no need to keep training
            if opt1.param_groups[0]['lr'] < 1e-5:
                break

            # prediction
            model.eval()
            log.eval(len_dataset=1)
            with torch.no_grad():
                predictions = model(data.x, data.edge_index)
                loss = criterion(predictions[val_mask], data.y[val_mask])
                correct = (data.y[val_mask].cpu() == predictions.max(dim=1)[1][val_mask].cpu())
                log(model, loss.cpu(), correct.cpu(),
                    fun_K_auto(torch.exp(b), prior_list, K_list).detach().cpu().item())
                accuracy = (torch.sum(correct) / torch.sum(val_mask).cpu()).numpy()

                correct = (data.y[test_mask].cpu() == predictions.max(dim=1)[1][test_mask].cpu())
                test_acc = (torch.sum(correct) / torch.sum(test_mask).cpu()).numpy()

            # Model selection on validation accuracy.
            if accuracy > best_val_acc:
                best_val_acc = accuracy
                best_test_acc = test_acc

            # The LR schedule is driven by training accuracy, only after stage 1.
            if epoch > args.shift:
                scheduler.step(train_acc)

        acc.append([test_acc, best_val_acc, best_test_acc])
        log.flush()

    print(np.mean(acc, axis=0))
    os.makedirs(log_folder, exist_ok=True)
    np.save(log_folder+'/'+file_name+'.npy', acc)
