from __future__ import division
from __future__ import print_function

import os
import glob
import time
import random
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from utils import load_data, accuracy, coverage, flip, cost_loss
from models import GAT
from loss import CwRLoss

# Training settings
parser = argparse.ArgumentParser()

# Run-control / hardware flags.
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='Disables CUDA training.')
parser.add_argument('--fastmode', action='store_true', default=False,
                    help='Validate during training pass.')
parser.add_argument('--seed', type=int, default=42,
                    help='Random seed.')
parser.add_argument('--epochs', type=int, default=10000,
                    help='Number of epochs to train.')

# Optimizer hyper-parameters.
parser.add_argument('--lr', type=float, default=0.005,
                    help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=5e-4,
                    help='Weight decay (L2 loss on parameters).')

# GAT architecture hyper-parameters.
parser.add_argument('--hidden', type=int, default=8,
                    help='Number of hidden units.')
parser.add_argument('--nb_heads', type=int, default=8,
                    help='Number of head attentions.')
parser.add_argument('--dropout', type=float, default=0.6,
                    help='Dropout rate (1 - keep probability).')
parser.add_argument('--alpha', type=float, default=0.2,
                    help='Alpha for the leaky_relu.')

# Early stopping: number of non-improving epochs tolerated.
parser.add_argument('--patience', type=int, default=100,
                    help='Patience')

# Rejection-option experiment knobs (rejection cost, label smoothing,
# and the fraction of training labels to corrupt).
parser.add_argument('--cost', type=float, default=0.7,
                    help='Cost of rejection')
parser.add_argument('--ls', type=float, default=0.0,
                    help='Label Smoothing')
parser.add_argument('--noise', type=float, default=0.0,
                    help='Label Noise')

args = parser.parse_args()
# Use the GPU only if it was not disabled on the command line and one exists.
args.cuda = (not args.no_cuda) and torch.cuda.is_available()

# Seed every RNG the script touches so runs are reproducible.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Load data
adj, features, labels, idx_train, idx_val, idx_test = load_data()

# Inject label noise into the training split only; the validation and test
# labels are left untouched.
if args.noise != 0:
    labels[idx_train] = flip(labels[idx_train], args.noise)

# Give an extra output class for the reject option (one unit per observed
# class plus one). NOTE(review): assumes labels are 0-based class ids.
model = GAT(nfeat=features.shape[1],
            nhid=args.hidden,
            nclass=int(labels.max()) + 2,
            dropout=args.dropout,
            nheads=args.nb_heads,
            alpha=args.alpha)

# Move everything to the GPU *before* building the optimizer, as recommended
# by the torch.optim documentation, so the optimizer is constructed over the
# parameters in their final location.
if args.cuda:
    model.cuda()
    features = features.cuda()
    adj = adj.cuda()
    labels = labels.cuda()
    idx_train = idx_train.cuda()
    idx_val = idx_val.cuda()
    idx_test = idx_test.cuda()

optimizer = optim.Adam(model.parameters(),
                       lr=args.lr,
                       weight_decay=args.weight_decay)

# Rejection-aware training loss, parameterised by the rejection cost and the
# label-smoothing factor.
criterion = CwRLoss(args.cost, args.ls)

# Tensors are used directly: torch.autograd.Variable has been a deprecated
# no-op wrapper since PyTorch 0.4.

def train(epoch):
    """Run one optimisation step on the training nodes and log metrics.

    Performs a full-graph forward pass in training mode, back-propagates
    ``criterion`` on the ``idx_train`` nodes and steps the optimizer. It then
    computes validation metrics on ``idx_val``. Unless ``--fastmode`` is set,
    a second forward pass is run in eval mode, which deactivates dropout. In
    fast mode the training-pass output is reused instead.

    Args:
        epoch: zero-based epoch index, used only for logging.

    Returns:
        The validation loss as a Python float (used for early stopping).
    """
    t = time.time()
    model.train()
    optimizer.zero_grad()
    output = model(features, adj)
    loss_train = criterion(output[idx_train], labels[idx_train])
    # Reporting-only metrics: keep them out of the autograd graph.
    with torch.no_grad():
        acc_train = accuracy(output[idx_train], labels[idx_train])
        cov_train = coverage(output[idx_train], labels[idx_train])
        cost_loss_train = cost_loss(output[idx_train], labels[idx_train],
                                    args.cost)
    loss_train.backward()
    optimizer.step()

    # Validation never back-propagates, so no graph needs to be recorded.
    with torch.no_grad():
        if not args.fastmode:
            # Evaluate validation set performance separately,
            # deactivates dropout during validation run.
            model.eval()
            output = model(features, adj)

        loss_val = criterion(output[idx_val], labels[idx_val])
        acc_val = accuracy(output[idx_val], labels[idx_val])
        cov_val = coverage(output[idx_val], labels[idx_val])
        cost_loss_val = cost_loss(output[idx_val], labels[idx_val], args.cost)

    print('Epoch: {:04d}'.format(epoch+1),
          '0-d-1 Loss Train: {:.3f}'.format(cost_loss_train.item()),
          'Acc Train: {:.1f}'.format(acc_train.item() * 100),
          'Cov train: {:.1f}'.format(cov_train.item() * 100))
    print('0-d-1 Loss Val: {:.3f}'.format(cost_loss_val.item()),
          'Acc Val: {:.1f}'.format(acc_val.item() * 100),
          'Cov Val: {:.1f}'.format(cov_val.item() * 100),
          'Time: {:.3f}s'.format(time.time() - t))

    return loss_val.item()


def compute_test():
    """Evaluate the current model on the test nodes and print the results.

    Runs one eval-mode (dropout disabled) forward pass over the full graph.
    It then prints the input parameters (rejection cost, label smoothing,
    noise) followed by the test-set cost loss, accuracy and coverage.
    Accuracy and coverage are shown as percentages. The cost loss is computed
    by ``cost_loss`` with ``args.cost``; it is labelled "0-d-1 Loss", presumably
    0 for correct, d for reject, 1 for error.
    """
    model.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(features, adj)
        acc_test = accuracy(output[idx_test], labels[idx_test])
        cov_test = coverage(output[idx_test], labels[idx_test])
        cost_loss_test = cost_loss(output[idx_test], labels[idx_test],
                                   args.cost)
    print("Input parameters:",
          "Cost (d)= {:.2f}".format(args.cost),
          "Label Smoothing= {:.2f}".format(args.ls),
          "Noise= {:.2f}".format(args.noise))
    print("Test set results:",
          "0-d-1 Loss= {:.3f}".format(cost_loss_test.item()),
          "Accuracy= {:.2f}".format(acc_test.item() * 100),
          "Coverage= {:.2f}".format(cov_test.item() * 100))
    

# Train model with early stopping on the validation loss returned by train().
# Each epoch's weights are saved as '<epoch>.pkl' in the working directory.
# Checkpoints older than the current best are pruned as training proceeds,
# and those newer than the best are pruned at the end.
# NOTE(review): any pre-existing '<int>.pkl' in the working directory (e.g. from
# another run) is subject to this pruning too.


def _checkpoint_epochs():
    """Yield ``(path, epoch)`` for every ``<int>.pkl`` file in the working dir.

    Files whose stem is not an integer (e.g. an unrelated ``model.pkl``) are
    skipped instead of crashing the ``int()`` parse.
    """
    for path in glob.glob('*.pkl'):
        stem = os.path.splitext(os.path.basename(path))[0]
        try:
            epoch_nb = int(stem)
        except ValueError:
            continue
        yield path, epoch_nb


t_total = time.time()
loss_values = []
bad_counter = 0
# Start from +inf so the first epoch always becomes the best. The previous
# sentinel (args.epochs + 1) tied the threshold to the epoch budget and could
# be smaller than a real validation loss.
best = float('inf')
best_epoch = 0
for epoch in range(args.epochs):
    loss_values.append(train(epoch))

    torch.save(model.state_dict(), '{}.pkl'.format(epoch))
    if loss_values[-1] < best:
        best = loss_values[-1]
        best_epoch = epoch
        bad_counter = 0
    else:
        bad_counter += 1

    if bad_counter == args.patience:
        break

    # Checkpoints older than the current best can never be restored.
    for path, epoch_nb in _checkpoint_epochs():
        if epoch_nb < best_epoch:
            os.remove(path)

# Drop checkpoints written after the best epoch.
for path, epoch_nb in _checkpoint_epochs():
    if epoch_nb > best_epoch:
        os.remove(path)

print("Optimization Finished!")
print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

if loss_values:
    # Restore best model
    print('Loading {}th epoch'.format(best_epoch))
    model.load_state_dict(torch.load('{}.pkl'.format(best_epoch)))
else:
    # No epoch ran (e.g. --epochs 0), so no checkpoint from this run exists;
    # never load a stale one left over from an earlier run.
    print('No training epochs were run; testing the untrained model.')

# Testing
compute_test()
