from __future__ import division
from __future__ import print_function

import os
import glob
import json
import time
import random
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from sklearn.manifold import TSNE

from utils import load_data, accuracy, sel_loss, sel_accuracy, true_cov, find_tres, flip
from models import GAT

# ---------------------------------------------------------------------------
# Command-line configuration
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

# Hardware / run-mode switches.
parser.add_argument('--no-cuda',
                    default=False, action='store_true',
                    help='Disables CUDA training.')
parser.add_argument('--fastmode',
                    default=False, action='store_true',
                    help='Validate during training pass.')

# Optimisation schedule.
parser.add_argument('--seed',
                    default=72, type=int,
                    help='Random seed.')
parser.add_argument('--epochs',
                    default=10000, type=int,
                    help='Number of epochs to train.')
parser.add_argument('--lr',
                    default=0.005, type=float,
                    help='Initial learning rate.')
parser.add_argument('--weight_decay',
                    default=1e-4, type=float,
                    help='Weight decay (L2 loss on parameters).')

# GAT architecture.
parser.add_argument('--hidden',
                    default=8, type=int,
                    help='Number of hidden units.')
parser.add_argument('--nb_heads',
                    default=8, type=int,
                    help='Number of head attentions.')
parser.add_argument('--dropout',
                    default=0.6, type=float,
                    help='Dropout rate (1 - keep probability).')
parser.add_argument('--alpha',
                    default=0.2, type=float,
                    help='Alpha for the leaky_relu.')
parser.add_argument('--patience',
                    default=100, type=int,
                    help='Patience')

# Selective-classification objective and label perturbations.
parser.add_argument('--coverage',
                    default=0.8, type=float,
                    help='Coverage')
parser.add_argument('--lamda',
                    default=32, type=int,
                    help='Lambda')
parser.add_argument('--alphaloss',
                    default=0.5, type=float,
                    help='Alpha')
parser.add_argument('--ls',
                    default=0.0, type=float,
                    help='Label Smoothing')
parser.add_argument('--noise',
                    default=0.0, type=float,
                    help='Label Noise')


args = parser.parse_args()
# Use CUDA only when it was not disabled on the command line and a device is
# available; is_available() is not even queried when --no-cuda is given.
args.cuda = False if args.no_cuda else torch.cuda.is_available()

# Seed every random number generator used by the script for reproducibility.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Load data
adj, features, labels, idx_train, idx_val, idx_test = load_data()

# Inject label noise into the training labels only (val/test labels are left
# untouched); how the noise rate is applied is defined by utils.flip.
if args.noise != 0:
    labels[idx_train] = flip(labels[idx_train], args.noise)

# Model
model = GAT(nfeat=features.shape[1],
            nhid=args.hidden,
            nclass=int(labels.max()) + 1,
            dropout=args.dropout,
            nheads=args.nb_heads,
            alpha=args.alpha,
            coverage=args.coverage)

# Move the model and all tensors to the GPU *before* constructing the
# optimizer, as recommended by the torch.optim documentation, so the
# optimizer is built over parameters that already live on their final device.
if args.cuda:
    model.cuda()
    features = features.cuda()
    adj = adj.cuda()
    labels = labels.cuda()
    idx_train = idx_train.cuda()
    idx_val = idx_val.cuda()
    idx_test = idx_test.cuda()

optimizer = optim.Adam(model.parameters(),
                       lr=args.lr,
                       weight_decay=args.weight_decay)

# NOTE: torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4;
# plain tensors are used directly.

# Cross-entropy term of the objective, optionally with label smoothing.
loss_fn = nn.CrossEntropyLoss(label_smoothing=args.ls)

def _combined_loss(output, idx):
    """Return the weighted training objective restricted to the nodes in ``idx``.

    ``output`` is the value of ``model(features, adj)[1]``: ``output[0]`` is
    fed to ``sel_loss`` and ``output[1]`` to ``loss_fn`` (cross-entropy). The
    two terms are mixed as
    ``alphaloss * selective + (1 - alphaloss) * cross_entropy``.
    """
    selective = sel_loss(output[0][idx], labels[idx], args.coverage,
                         lamda=args.lamda, ls=args.ls)
    auxiliary = loss_fn(output[1][idx], labels[idx])
    return (args.alphaloss * selective) + ((1 - args.alphaloss) * auxiliary)


def train(epoch):
    """Run one optimisation step on the training nodes and report metrics.

    Args:
        epoch: zero-based epoch index (printed one-based).

    Returns:
        The validation loss as a Python float, used by the caller for
        early stopping.
    """
    t = time.time()
    model.train()
    optimizer.zero_grad()
    output = model(features, adj)[1]
    loss_train = _combined_loss(output, idx_train)
    sel_acc_train = sel_accuracy(output[0][idx_train], labels[idx_train])
    true_cov_train = true_cov(output[0][idx_train], labels[idx_train])
    loss_train.backward()
    optimizer.step()

    # Validation needs no gradients; skip building an autograd graph.
    with torch.no_grad():
        if not args.fastmode:
            # Evaluate validation set performance separately,
            # deactivates dropout during validation run.
            model.eval()
            output = model(features, adj)[1]

        # Loss and metrics are all computed on the held-out validation nodes,
        # so early stopping tracks generalisation rather than training fit.
        loss_val = _combined_loss(output, idx_val)
        sel_acc_val = sel_accuracy(output[0][idx_val], labels[idx_val])
        true_cov_val = true_cov(output[0][idx_val], labels[idx_val])

    print('Epoch: {:04d}'.format(epoch+1),
          'Loss Train: {:.4f}'.format(loss_train.item()),
          'Sel Acc Train: {:.4f}'.format(sel_acc_train.item()),
          'Cov Train: {:.4f}'.format(true_cov_train.item()),
          'Time: {:.4f}s'.format(time.time() - t))

    print('Epoch: {:04d}'.format(epoch+1),
          'Loss Val: {:.4f}'.format(loss_val.item()),
          'Sel Acc Val: {:.4f}'.format(sel_acc_val.item()),
          'Cov Val: {:.4f}'.format(true_cov_val.item()),
          'Time: {:.4f}s'.format(time.time() - t))

    return loss_val.item()


def compute_test():
    """Evaluate the current model on the test nodes and print the results.

    Prints the selective loss and plain accuracy on ``idx_test``. It then
    obtains a threshold from ``find_tres`` for the target ``args.coverage``
    and prints the selective accuracy and the coverage reported by
    ``true_cov`` at that threshold.
    """
    model.eval()
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        output = model(features, adj)[1]
        scores = output[0][idx_test]
        targets = labels[idx_test]
        loss_test = sel_loss(scores, targets, args.coverage,
                             lamda=args.lamda, ls=args.ls)
        acc_test = accuracy(scores, targets)
        tres = find_tres(scores, args.coverage)
        sel_acc_test = sel_accuracy(scores, targets, t=tres)
        true_cov_test = true_cov(scores, targets, t=tres)

    print("Test set results:")
    print("Loss = {:.4f}".format(loss_test.item()),
          "Accuracy = {:.4f}".format(acc_test.item()))
    print("Treshold = ", round(tres, 3),
          "Selective Accuracy = {:.4f}".format(sel_acc_test.item()),
          "Coverage = {:.4f}".format(true_cov_test.item()))

# Train model with early stopping on the validation loss returned by train().
# Each epoch's weights are checkpointed to '<epoch>.pkl' in the working
# directory; only checkpoints this run wrote are ever deleted, so unrelated
# .pkl files there are left alone.
t_total = time.time()
loss_values = []
bad_counter = 0
best = float('inf')  # lowest validation loss seen so far
best_epoch = 0
saved_epochs = []  # epochs whose checkpoint this run wrote and still keeps
for epoch in range(args.epochs):
    loss_values.append(train(epoch))

    torch.save(model.state_dict(), '{}.pkl'.format(epoch))
    saved_epochs.append(epoch)
    if loss_values[-1] < best:
        best = loss_values[-1]
        best_epoch = epoch
        bad_counter = 0
    else:
        bad_counter += 1

    if bad_counter == args.patience:
        break

    # Checkpoints older than the current best can never be restored.
    for stale_epoch in saved_epochs:
        if stale_epoch < best_epoch:
            os.remove('{}.pkl'.format(stale_epoch))
    saved_epochs = [e for e in saved_epochs if e >= best_epoch]

# Drop checkpoints written after the last improvement, keeping only the best.
for stale_epoch in saved_epochs:
    if stale_epoch > best_epoch:
        os.remove('{}.pkl'.format(stale_epoch))

print("Optimization Finished!")
print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

# Restore best model
print('Loading {}th epoch'.format(best_epoch))
model.load_state_dict(torch.load('{}.pkl'.format(best_epoch)))

# Testing
compute_test()

