from __future__ import division
from __future__ import print_function
import time
import random
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from process import *
from utils import *
from model_c import *
from model_GeomGCN import *
from torch_geometric.data import Data
import dgl
from h2gcn_model import H2GCN



import uuid

# Training settings
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--epochs', type=int, default=1500, help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate.')
parser.add_argument('--weight_decay', type=float, default=0.01, help='weight decay (L2 loss on parameters).')
parser.add_argument('--layer', type=int, default=2, help='Number of layers.')
parser.add_argument('--hidden', type=int, default=64, help='hidden dimensions.')
parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).')

parser.add_argument('--patience', type=int, default=200, help='Patience')
parser.add_argument('--data', default='cora', help='dataset')

# train() only builds the OrderTransformer (MPformer) model, so it is the default.
parser.add_argument('--model', type=str, default="OrderTransformer",
                    help='choose models (this script builds OrderTransformer only)')
parser.add_argument('--hops', type=int, default=1, help='number of hops for order transformer')

# Options that the rest of the script reads from `args` but that were never
# declared, which made it crash with AttributeError.
parser.add_argument('--dev', type=int, default=0, help='CUDA device index.')
# NOTE(review): default of 8 heads is a placeholder — confirm against MPformer.
parser.add_argument('--nb_heads', type=int, default=8,
                    help='number of attention heads for the order transformer')
parser.add_argument('--row_normalized_adj', action='store_true',
                    help='flag forwarded to full_load_data as row_normalized_adj')
parser.add_argument('--get_degree', action='store_true',
                    help='forwarded to full_load_data as get_degree (only honoured for --model GCN)')
# NOTE(review): None is assumed to mean "no embedding" in full_load_data — verify.
parser.add_argument('--emb', type=str, default=None,
                    help='embedding method forwarded to full_load_data')

args = parser.parse_args()
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

cudaid = "cuda:" + str(args.dev)
device = torch.device(cudaid if torch.cuda.is_available() else "cpu")

import os  # local import keeps this script block self-contained

current_time = time.strftime("%d_%H_%M_%S", time.localtime(time.time()))
# torch.save does not create directories, so make sure the target exists.
os.makedirs('pretrained', exist_ok=True)
# A short random suffix keeps runs started within the same second from
# overwriting each other's checkpoints.
checkpt_file = os.path.join(
    'pretrained',
    "{}_{}_{}_{}".format(args.model, args.data, current_time, uuid.uuid4().hex[:8]) + '.pt',
)
print(cudaid, checkpt_file)
print('device:', device)


def train_step(model, optimizer, features, labels, adj, idx_train, use_geom):
    """Perform a single optimisation step using only the training nodes.

    The model is called with ``features`` alone when ``use_geom`` is true and
    with ``(features, adj)`` otherwise. The loss is the project's
    ``focal_loss`` helper; accuracy comes from the project's ``accuracy`` helper.

    Returns:
        ``(loss, accuracy)`` on the training nodes, as Python floats.
    """
    model.train()
    optimizer.zero_grad()

    model_inputs = (features,) if use_geom else (features, adj)
    logits = model(*model_inputs)

    train_logits = logits[idx_train]
    train_targets = labels[idx_train].to(device)
    acc_train = accuracy(train_logits, train_targets)
    loss_train = focal_loss(train_logits, train_targets)

    loss_train.backward()
    optimizer.step()
    return loss_train.item(), acc_train.item()


def validate_step(model, features, labels, adj, idx_val, use_geom):
    """Score the model on the validation nodes with gradients disabled.

    Uses ``F.nll_loss``, so the model output is presumably log-probabilities
    (TODO confirm for each model).

    Returns:
        ``(loss, accuracy, output)``: validation NLL loss and accuracy as
        Python floats, plus the full model output for every node. The caller
        reuses that output, e.g. to score other index sets.
    """
    model.eval()
    with torch.no_grad():
        output = model(features) if use_geom else model(features, adj)
        val_logits = output[idx_val]
        val_targets = labels[idx_val].to(device)
        loss_val = F.nll_loss(val_logits, val_targets)
        acc_val = accuracy(val_logits, val_targets)
    return loss_val.item(), acc_val.item(), output

def test_step(model, features, labels, adj, idx_test, use_geom, deg_vec, raw_adj):
    """Reload the saved checkpoint into ``model`` and score it on the test nodes.

    The weights come from the module-level ``checkpt_file``. ``deg_vec`` and
    ``raw_adj`` are accepted for call-site compatibility but are not used.

    Returns:
        ``(loss, accuracy)`` on the test nodes, as Python floats, where the loss
        is ``F.nll_loss``.
    """
    saved_state = torch.load(checkpt_file)
    model.load_state_dict(saved_state)
    model.eval()
    with torch.no_grad():
        if not use_geom:
            output = model(features, adj)
        else:
            output = model(features)
        test_logits = output[idx_test]
        test_targets = labels[idx_test].to(device)
        loss_test = F.nll_loss(test_logits, test_targets)
        acc_test = accuracy(test_logits, test_targets)
    return loss_test.item(), acc_test.item()
    

def train(datastr, splitstr):
    """Train an MPformer ("OrderTransformer") model on one data split.

    Steps:
    - Load the split with ``full_load_data``.
    - Append a Laplacian positional encoding to the node features.
    - Train with early stopping on validation loss, checkpointing the best
      model to ``checkpt_file``.
    - Restore that checkpoint and evaluate it on the test nodes.

    Args:
        datastr: dataset name, forwarded to ``full_load_data``.
        splitstr: path of the split file, forwarded to ``full_load_data``.

    Returns:
        ``(test_acc, best_epoch_test_acc)``, both in percent.
        - ``test_acc`` is the test accuracy of the checkpoint with the lowest
          validation loss.
        - ``best_epoch_test_acc`` is the highest test accuracy observed at any
          epoch. It is selected on the test set, so it is an optimistic figure.

    Raises:
        ValueError: if ``args.model`` is not ``'OrderTransformer'`` (the only
            model built here) or ``args.epochs`` is less than 1.
    """
    # Fail fast, before the (potentially slow) data load.
    if args.model != 'OrderTransformer':
        raise ValueError(
            "unsupported --model {!r}; this script only builds 'OrderTransformer'".format(args.model))
    # With zero epochs no checkpoint is written and test_step has nothing to load.
    if args.epochs < 1:
        raise ValueError("--epochs must be at least 1, got {}".format(args.epochs))

    # Both flags are False for OrderTransformer. They are kept so that further
    # model branches can reuse the same plumbing.
    use_geom = (args.model == 'GEOMGCN')
    get_degree = bool(args.get_degree) and (args.model == "GCN")
    (adj, features, labels, idx_train, idx_val, idx_test,
     num_features, num_labels, deg_vec, raw_adj) = full_load_data(
        datastr, splitstr, args.row_normalized_adj, model_type=args.model,
        embedding_method=args.emb, get_degree=get_degree)

    features = features.to(device)
    adj = adj.to(device)
    labels = labels.to(device)
    idx_train = idx_train.to(device)
    idx_val = idx_val.to(device)
    idx_test = idx_test.to(device)

    # Concatenate the positional encoding as extra feature columns (dim=1).
    # NOTE(review): the meaning of the argument 3 is defined by
    # laplacian_positional_encoding (presumably the number of eigenvectors).
    lpe = laplacian_positional_encoding(adj, 3)
    # .to(device) instead of .cuda(): runs on CPU and honours --dev.
    features = torch.cat((features, lpe.to(device)), dim=1)

    model = MPformer(args.hops, num_labels, features.shape[1], args.layer, args.hidden,
                     args.nb_heads, features.shape[0],
                     tran_dropout=0.5, feat_dropout=0.5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           weight_decay=args.weight_decay)

    bad_counter = 0
    best = float('inf')
    # Kept separate from the final test accuracy; the original code overwrote
    # this list and then called max() on a float.
    epoch_test_accs = []
    for epoch in range(args.epochs):
        loss_tra, acc_tra = train_step(model, optimizer, features, labels, adj, idx_train, use_geom)
        loss_val, acc_val, out = validate_step(model, features, labels, adj, idx_val, use_geom)
        epoch_test_accs.append(accuracy(out[idx_test], labels[idx_test]).item())
        print('Epoch:{:04d}'.format(epoch+1),
              'train',
              'loss:{:.3f}'.format(loss_tra),
              'acc:{:.2f}'.format(acc_tra*100),
              '| val',
              'loss:{:.3f}'.format(loss_val),
              'acc:{:.2f}'.format(acc_val*100))
        if loss_val < best:
            best = loss_val
            torch.save(model.state_dict(), checkpt_file)
            bad_counter = 0
        else:
            bad_counter += 1

        if bad_counter >= args.patience:
            break

    _, acc_test = test_step(model, features, labels, adj, idx_test, use_geom, deg_vec, raw_adj)

    return acc_test * 100, max(epoch_test_accs) * 100

# Train and evaluate once per pre-generated 60/20/20 split, then summarise.
NUM_SPLITS = 10

t_total = time.time()
acc_list = []          # test acc. of the validation-selected checkpoint, per split
best_test_list = []    # best-over-epochs test acc. (selected on test), per split
for i in range(NUM_SPLITS):
    datastr = args.data
    splitstr = 'splits/{}_split_0.6_0.2_{}.npz'.format(args.data, i)
    acc, test_best = train(datastr, splitstr)
    acc_list.append(acc)
    best_test_list.append(test_best)
    print(i, ": {:.2f}".format(acc_list[-1]))
    print(i, ": {:.2f}".format(best_test_list[-1]))
print("Train cost: {:.4f}s".format(time.time() - t_total))
# Mean and std now describe the same quantity: validation-selected test
# accuracy. Previously the mean came from best_test_list and the std from
# acc_list.
print("Test acc.:{:.2f}".format(np.mean(acc_list)))
print("Test std.:{:.2f}".format(np.std(acc_list)))
# The test-selected figure is optimistic, so it is reported separately.
print("Best-epoch test acc.:{:.2f} (std {:.2f})".format(
    np.mean(best_test_list), np.std(best_test_list)))
        
        
