import os
import sys
import argparse
import random
import json
import collections
import math

from distutils.version import LooseVersion as LV

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader
import torch.backends.cudnn as cudnn
from gensim.utils import simple_preprocess
from gensim.corpora import Dictionary
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

from models import Model_CNN
from train_utils import train_single_epoch, test_single_epoch

# Import validation metrics
from Metrics.metrics import test_classification_net
from Metrics.metrics import expected_calibration_error, maximum_calibration_error, adaECE_error, ClasswiseECELoss
from Metrics.metrics import l2_error
from Metrics.plots import reliability_plot, bin_strength_plot



# Default to GPU 2, but respect a CUDA_VISIBLE_DEVICES value that the caller
# (or a job scheduler) has already exported instead of overwriting it.
# The variable is set before the first CUDA query below.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", '2')

# Fall back to the CPU when no CUDA device is visible to this process.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Data locations -------------------------------------------------------
DATADIR = "./"
GLOVE_DIR = os.path.join(DATADIR, "glove.6B")
TEXT_DATA_DIR = os.path.join(DATADIR, "20_newsgroup")

# --- Text preprocessing / model dimensions --------------------------------
MAX_SEQUENCE_LENGTH = 1000
MAX_NUM_WORDS = 20000  # 2 words reserved: 0=pad, 1=oov
EMBEDDING_DIM = 100
NUM_CLASSES = 20

# --- Defaults for the command-line options defined below -------------------
# Optimisation
OPTIMISER = "adam"
train_batch_size = test_batch_size = 128
num_epochs = 50
learning_rate = 1e-3
momentum = 0.9
weight_decay = 5e-4
first_milestone = 150  # Milestone for change in lr
second_milestone = 250  # Milestone for change in lr

# Loss configuration
loss = "cross_entropy"
gamma = gamma2 = gamma3 = 1.0
gamma_schedule_step1 = 100
gamma_schedule_step2 = 250
lamda = 1.0

# Logging and checkpointing
log_interval = save_interval = 5
save_loc = load_loc = './'
model_name = None
saved_model_name = "cnn_cross_entropy_50.model"
model = "cnn"

# Command-line interface. Every option falls back to the module-level default
# of the same meaning; the parsed values are exposed as `args`.
parser = argparse.ArgumentParser(
    description="Training for calibration.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# Hardware and optimisation.
parser.add_argument("-g", dest="gpu", action="store_true", default=True,
                    help="Use GPU")
parser.add_argument("--num-epochs", dest="num_epochs", type=int,
                    default=num_epochs,
                    help='Number of training epochs')
parser.add_argument("--opt", dest="optimiser", type=str,
                    default=OPTIMISER,
                    help='Choice of optimisation algorithm')
parser.add_argument("--lr", dest="learning_rate", type=float,
                    default=learning_rate,
                    help='Learning rate')
parser.add_argument("--mom", dest="momentum", type=float,
                    default=momentum,
                    help='Momentum')
parser.add_argument("--nesterov", dest="nesterov", action="store_true",
                    default=False,
                    help="Whether to use nesterov momentum in SGD")
parser.add_argument("--decay", dest="weight_decay", type=float,
                    default=weight_decay,
                    help="Weight Decay")
parser.add_argument("--load", dest="load", action="store_true",
                    default=False,
                    help="Load from pretrained model")
parser.add_argument("-b", dest="train_batch_size", type=int,
                    default=train_batch_size,
                    help="Batch size")
parser.add_argument("-tb", dest="test_batch_size", type=int,
                    default=test_batch_size,
                    help="Test Batch size")

# Model and loss selection.
parser.add_argument("--model", dest="model", type=str,
                    default=model,
                    help='Model to train')
parser.add_argument("--model-name", dest="model_name", type=str,
                    default=model_name,
                    help='name of the model')
parser.add_argument("--loss", dest="loss_function", type=str,
                    default=loss,
                    help="Loss function to be used for training")
parser.add_argument("--loss-mean", dest="loss_mean", action="store_true",
                    default=False,
                    help="whether to take mean of loss instead of sum to train")
parser.add_argument("--gamma", dest="gamma", type=float,
                    default=gamma,
                    help="Gamma for focal components")
parser.add_argument("--gamma2", dest="gamma2", type=float,
                    default=gamma2,
                    help="Gamma for different focal components")
parser.add_argument("--gamma3", dest="gamma3", type=float,
                    default=gamma3,
                    help="Gamma for different focal components")
parser.add_argument("--lamda", dest="lamda", type=float,
                    default=lamda,
                    help="Regularization factor")
parser.add_argument("--gamma-schedule", dest="gamma_schedule", type=int,
                    default=0,
                    help="Schedule gamma or not")
parser.add_argument("--gamma-schedule-step1", dest="gamma_schedule_step1", type=int,
                    default=gamma_schedule_step1,
                    help="1st step for gamma schedule")
parser.add_argument("--gamma-schedule-step2", dest="gamma_schedule_step2", type=int,
                    default=gamma_schedule_step2,
                    help="2nd step for gamma schedule")
parser.add_argument("--first-milestone", dest="first_milestone", type=int,
                    default=first_milestone,
                    help="First milestone to change lr")
parser.add_argument("--second-milestone", dest="second_milestone", type=int,
                    default=second_milestone,
                    help="Second milestone to change lr")

# Logging and checkpoint locations.
parser.add_argument("--log-interval", dest="log_interval", type=int,
                    default=log_interval,
                    help="Log Interval on Terminal")
parser.add_argument("--save-interval", dest="save_interval", type=int,
                    default=save_interval,
                    help="Save Interval on Terminal")
parser.add_argument("--saved_model_name", dest="saved_model_name", type=str,
                    default=saved_model_name,
                    help="file name of the pre-trained model")
parser.add_argument("--save-path", dest="save_loc", type=str,
                    default=save_loc,
                    help='Path to export the model')
parser.add_argument("--load-path", dest="load_loc", type=str,
                    default=load_loc,
                    help='Path to load the model from')

# Calibration metrics and adaptive focal loss.
parser.add_argument("--num-bins", dest="num_bins", nargs="+", type=int,
                    default=[15],
                    help="Number of calibration bins")
parser.add_argument("--gamma_lambda", dest="gamma_lambda", type=float,
                    default=1.0,
                    help="lambda for auto adaptive focal gamma = gamma0*exp(lambda*ece)")
parser.add_argument("--gamma_max", dest="gamma_max", type=float,
                    default=1e10,
                    help="Maximum cutoff value for clipping exploding gammas")
parser.add_argument("--adafocal_start_epoch", dest="adafocal_start_epoch", type=int,
                    default=0,
                    help="Epoch to start the sample adaptive focal calibration")

args = parser.parse_args()

def loss_function_save_name(loss_function,
                            scheduled=False,
                            gamma=1.0,
                            gamma1=1.0,
                            gamma2=1.0,
                            gamma3=1.0,
                            lamda=1.0):
    """Return the file-name tag that identifies a loss configuration.

    Args:
        loss_function: Name of the loss, one of 'cross_entropy',
            'focal_loss', 'focal_loss_sd', 'mmce', 'mmce_weighted',
            'brier_score' or 'adafocal'.
        scheduled: When equal to True and the loss is 'focal_loss', the tag
            lists the three scheduled gammas instead of the single `gamma`.
        gamma: Gamma embedded in the focal / adafocal tags.
        gamma1, gamma2, gamma3: Gammas embedded in the scheduled focal tag.
        lamda: Regularisation factor embedded in the MMCE tags.

    Returns:
        The tag as a string.

    Raises:
        KeyError: If `loss_function` is not one of the known names.
    """
    # The scheduled focal loss has its own naming scheme.
    if loss_function == 'focal_loss' and scheduled == True:
        schedule_tag = '_'.join([str(gamma1), str(gamma2), str(gamma3)])
        return 'focal_loss_scheduled_gamma_' + schedule_tag

    gamma_tag = str(gamma)
    lamda_tag = str(lamda)
    tags = {
        'cross_entropy': 'cross_entropy',
        'focal_loss': 'focal_loss_gamma_' + gamma_tag,
        'focal_loss_sd': 'focal_loss_sd_gamma_' + gamma_tag,
        'mmce': 'mmce_lamda_' + lamda_tag,
        'mmce_weighted': 'mmce_weighted_lamda_' + lamda_tag,
        'brier_score': 'brier_score',
        'adafocal': 'adafocal_' + gamma_tag,
    }
    return tags[loss_function]


# GloVe word embeddings
# The datafile contains 100-dimensional embeddings for 400,000 English words.
print('Indexing word vectors.')
embeddings_index = {}  # word -> float32 vector of its embedding coefficients
# Read with an explicit codec: without it open() uses the locale's default
# encoding, which is not UTF-8 on every platform.
# NOTE(review): assumes the GloVe file is UTF-8 text - confirm for other releases.
with open(os.path.join(GLOVE_DIR, 'glove.6B.100d.txt'), encoding='utf-8') as f:
    for line in f:
        values = line.split()
        if not values:
            # Skip blank lines instead of failing on values[0].
            continue
        # Each line is "<word> <coef_1> ... <coef_n>".
        word = values[0]
        coefs = np.asarray(values[1:], dtype='float32')
        embeddings_index[word] = coefs
print('Found %s word vectors.' % len(embeddings_index))

# 20 Newsgroups data set
print('Processing text dataset')
texts = []  # message bodies, one entry per file read
labels_index = {}  # sub-directory name -> numeric label id
labels = []  # label id of the matching entry in `texts`

# Python 3's open() is given an explicit codec; Python 2's takes no such argument.
fargs = {'encoding': 'latin-1'} if sys.version_info >= (3,) else {}

for name in sorted(os.listdir(TEXT_DATA_DIR)):
    path = os.path.join(TEXT_DATA_DIR, name)
    if not os.path.isdir(path):
        continue
    # Label ids are assigned in sorted directory order: 0, 1, 2, ...
    label_id = len(labels_index)
    labels_index[name] = label_id
    for fname in sorted(os.listdir(path)):
        # Only files whose name is purely numeric are read.
        if not fname.isdigit():
            continue
        with open(os.path.join(path, fname), **fargs) as f:
            body = f.read()
        header_end = body.find('\n\n')  # skip header
        if header_end > 0:
            body = body[header_end:]
        texts.append(body)
        labels.append(label_id)
print('Found %s texts.' % len(texts))

# Tokenize the texts using gensim.
tokens = [simple_preprocess(text) for text in texts]

# Build the vocabulary, keeping at most MAX_NUM_WORDS-2 tokens so that two
# ids stay free for padding and out-of-vocabulary words.
dictionary = Dictionary(tokens)
dictionary.filter_extremes(no_below=0, no_above=1.0, keep_n=MAX_NUM_WORDS-2)
word_index = dictionary.token2id
print('Found %s unique tokens.' % len(word_index))

# Vectorize the text samples into a 2D integer tensor: each document is cut
# to MAX_SEQUENCE_LENGTH ids and right-padded with -2.
padded_rows = []
for doc in tokens:
    ids = dictionary.doc2idx(doc)[:MAX_SEQUENCE_LENGTH]
    fill = MAX_SEQUENCE_LENGTH - len(ids)
    padded_rows.append(np.pad(ids, (0, fill), mode='constant', constant_values=-2))

# Shift every id by 2 so the padding value becomes 0.
# NOTE(review): presumably doc2idx marks unknown tokens with -1, which this
# shift turns into 1 (the reserved oov id) - confirm against gensim.
data = np.array(padded_rows, dtype=int) + 2
print('Shape of data tensor:', data.shape)
print('Length of label vector:', len(labels))

# Split the data into a training set and a validation set
VALIDATION_SET, TEST_SET = 900, 3999
# First carve out the test set with a seeded shuffle, then take the
# validation set from the remainder without shuffling again.
x_train, x_test, y_train, y_test = train_test_split(
    data, labels, test_size=TEST_SET, shuffle=True, random_state=42)
x_train, x_val, y_train, y_val = train_test_split(
    x_train, y_train, test_size=VALIDATION_SET, shuffle=False)
for split_name, split_x, split_y in (('training', x_train, y_train),
                                     ('validation', x_val, y_val),
                                     ('test', x_test, y_test)):
    print('Shape of %s data tensor:' % split_name, split_x.shape)
    print('Length of %s label vector:' % split_name, len(split_y))


def _build_loader(features, targets, batch_size, shuffle):
    """Wrap one split in a TensorDataset and a 4-worker DataLoader."""
    dataset = TensorDataset(torch.LongTensor(features), torch.LongTensor(targets))
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4)
    return dataset, loader


# Create PyTorch DataLoaders for all data sets
print('Train: ', end="")
train_dataset, train_loader = _build_loader(x_train, y_train, args.train_batch_size, True)
print(len(train_dataset), 'messages')

print('Validation: ', end="")
validation_dataset, val_loader = _build_loader(x_val, y_val, args.test_batch_size, False)
print(len(validation_dataset), 'messages')

print('Test: ', end="")
test_dataset, test_loader = _build_loader(x_test, y_test, args.test_batch_size, False)
print(len(test_dataset), 'messages')

# Prepare the embedding matrix:
print('Preparing embedding matrix.')
# Row r holds the vector of the word with dictionary id r-2; rows 0 and 1
# (pad and oov) are never written and stay all-zeros.
embedding_matrix = np.zeros((MAX_NUM_WORDS, EMBEDDING_DIM))
n_not_found = 0
for word, i in word_index.items():
    if i >= MAX_NUM_WORDS-2:
        continue
    embedding_vector = embeddings_index.get(word)
    if embedding_vector is None:
        # words not found in embedding index will be all-zeros.
        n_not_found += 1
        continue
    embedding_matrix[i+2] = embedding_vector
embedding_matrix = torch.FloatTensor(embedding_matrix)
print('Shape of embedding matrix:', embedding_matrix.shape)
print('Words not found in pre-trained embeddings:', n_not_found)

# Setting model name: fall back to the architecture name when none was given.
if args.model_name is None:
    args.model_name = args.model
# Model definition
model = Model_CNN(embedding_matrix).to(device)

# Optimiser definition
if args.optimiser == "rmsprop":
    # NOTE: rmsprop keeps its fixed learning rate of 0.005; --lr is not applied here.
    optimizer = optim.RMSprop(model.parameters(), lr=0.005)
elif args.optimiser == "sgd":
    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay,
                          nesterov=args.nesterov)
elif args.optimiser == "adam":
    # Honour --lr; its default (1e-3) equals Adam's own default, so runs
    # without --lr behave as before. Weight decay is left at Adam's default,
    # i.e. --decay is not applied to Adam.
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
else:
    # Fail here with a clear message instead of a NameError on `optimizer`
    # once training starts.
    raise ValueError("Unknown optimiser %r; expected 'rmsprop', 'sgd' or 'adam'" % (args.optimiser,))


# Initialize the prev_epoch_adabin_dict for epoch=0 (for auto_adaptive_focal_loss)
# The first --num-bins entry fixes the bin count. The bins are equal-width
# over [0, 1], each starting with accuracy == confidence == bin midpoint
# (so 'ece' is zero) and with the user-supplied gamma.
prev_epoch_adabin_dict = collections.defaultdict(dict)
default_num_bins = args.num_bins[0]
bin_width = 1/default_num_bins
for bin_idx in range(default_num_bins):
    lower, upper = bin_idx*bin_width, (bin_idx+1)*bin_width
    accuracy = confidence = (lower+upper)/2.0
    prev_epoch_adabin_dict[bin_idx] = {
        'lower_bound': lower,
        'upper_bound': upper,
        'prop_in_bin': bin_width,
        'accuracy_in_bin': accuracy,
        'avg_confidence_in_bin': confidence,
        'ece': confidence - accuracy,
        'gamma_next_epoch': args.gamma,
    }

best_val_acc = 0


def _checkpoint_path(epoch, gamma, best=False):
    """Return the file path of the checkpoint written after `epoch` (0-based)."""
    loss_tag = loss_function_save_name(args.loss_function, args.gamma_schedule, gamma,
                                       args.gamma, args.gamma2, args.gamma3, args.lamda)
    marker = '_best_' if best else '_'
    file_name = args.model_name + '_' + loss_tag + marker + str(epoch + 1) + '.model'
    return os.path.join(args.save_loc, file_name)


# Create the output directory (including missing parents) once, before any
# log or checkpoint is written.
os.makedirs(args.save_loc, exist_ok=True)

# Training loop
for epoch in range(0, args.num_epochs):
    # Select the gamma for this epoch: a three-step schedule for the
    # scheduled focal loss, otherwise the constant --gamma.
    if (args.loss_function == 'focal_loss' and args.gamma_schedule == 1):
        if (epoch < args.gamma_schedule_step1):
            gamma = args.gamma
        elif (epoch < args.gamma_schedule_step2):
            gamma = args.gamma2
        else:
            gamma = args.gamma3
    else:
        gamma = args.gamma

    train_loss = train_single_epoch(epoch,
                                    model,
                                    train_loader,
                                    optimizer,
                                    device,
                                    loss_function=args.loss_function,
                                    gamma=gamma,
                                    lamda=args.lamda,
                                    loss_mean=args.loss_mean,
                                    prev_epoch_adabin_dict=prev_epoch_adabin_dict,
                                    gamma_lambda=args.gamma_lambda,
                                    adafocal_start_epoch=args.adafocal_start_epoch)

    # Validation and test losses are always reported as cross entropy,
    # whatever loss is used for training.
    val_loss = test_single_epoch(epoch,
                                 model,
                                 val_loader,
                                 device,
                                 loss_function='cross_entropy',
                                 gamma=gamma,
                                 lamda=args.lamda,
                                 prev_epoch_adabin_dict=prev_epoch_adabin_dict,
                                 gamma_lambda=args.gamma_lambda,
                                 adafocal_start_epoch=args.adafocal_start_epoch)

    test_loss = test_single_epoch(epoch,
                                  model,
                                  test_loader,
                                  device,
                                  loss_function='cross_entropy',
                                  gamma=gamma,
                                  lamda=args.lamda,
                                  prev_epoch_adabin_dict=prev_epoch_adabin_dict,
                                  gamma_lambda=args.gamma_lambda,
                                  adafocal_start_epoch=args.adafocal_start_epoch)

    for bins_idx, num_bins in enumerate(args.num_bins):
        # Evaluate val set
        val_confusion_matrix, val_acc, val_labels, val_predictions, val_confidences, val_logits = \
                                        test_classification_net(model, val_loader, device, num_bins=num_bins, num_labels=NUM_CLASSES)
        val_ece, val_bin_dict = expected_calibration_error(val_confidences, val_predictions, val_labels, num_bins=num_bins)
        val_mce = maximum_calibration_error(val_confidences, val_predictions, val_labels, num_bins=num_bins)
        val_adaece, val_adabin_dict = adaECE_error(val_confidences, val_predictions, val_labels, num_bins=num_bins)
        val_classwise_ece = ClasswiseECELoss(n_bins=num_bins)(val_logits, torch.tensor(val_labels))
        # Update the gamma for the next epoch.
        # Only the first --num-bins entry drives the update: it is the bin
        # count prev_epoch_adabin_dict was initialised with. Updating for
        # every entry would index the dict with a mismatched bin count and
        # apply the multiplicative update more than once per epoch.
        if bins_idx == 0 and 'adafocal' in args.loss_function and epoch+1 >= args.adafocal_start_epoch:
            for bin_num in range(num_bins):
                next_gamma = prev_epoch_adabin_dict[bin_num]['gamma_next_epoch'] * math.exp(args.gamma_lambda * val_adabin_dict[bin_num]['ece'])
                # Clip to keep the gamma from exploding.
                val_adabin_dict[bin_num]['gamma_next_epoch'] = min(next_gamma, args.gamma_max)
            prev_epoch_adabin_dict = val_adabin_dict

        # Evaluate test set
        test_confusion_matrix, test_acc, test_labels, test_predictions, test_confidences, test_logits = \
                                        test_classification_net(model, test_loader, device, num_bins=num_bins, num_labels=NUM_CLASSES)
        test_ece, test_bin_dict = expected_calibration_error(test_confidences, test_predictions, test_labels, num_bins=num_bins)
        test_mce = maximum_calibration_error(test_confidences, test_predictions, test_labels, num_bins=num_bins)
        test_adaece, test_adabin_dict = adaECE_error(test_confidences, test_predictions, test_labels, num_bins=num_bins)
        test_classwise_ece = ClasswiseECELoss(n_bins=num_bins)(test_logits, torch.tensor(test_labels))

        # Metric logging (one tab-separated row appended per epoch)
        output_train_file = os.path.join(args.save_loc, "train_log_"+str(num_bins)+"bins.txt")
        with open(output_train_file, "a") as writer:
            # epoch, train_loss, val_loss, test_loss,
            # val_error, val_ece, val_mce, val_classwise_ece, val_adaece,
            # test_error, test_ece, test_mce, test_classwise_ece, test_adaece
            writer.write("%d\t" % (epoch))
            writer.write("%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % (train_loss, val_loss, test_loss, 1 - val_acc, val_ece, val_mce, val_classwise_ece, val_adaece))
            writer.write("%.4f\t%.4f\t%.4f\t%.4f\t%.4f" % (1 - test_acc, test_ece, test_mce, test_classwise_ece, test_adaece))
            writer.write("\n")

        # Save the val_bin_dict, test_bin_dict to json files
        val_bin_dict_file = os.path.join(args.save_loc, "val_bin_dict_"+str(num_bins)+"bins.txt")
        with open(val_bin_dict_file, "a") as write_file:
            json.dump(val_bin_dict, write_file)
            write_file.write("\n")
        test_bin_dict_file = os.path.join(args.save_loc, "test_bin_dict_"+str(num_bins)+"bins.txt")
        with open(test_bin_dict_file, "a") as write_file:
            json.dump(test_bin_dict, write_file)
            write_file.write("\n")

        # Save the val_adabin_dict, test_adabin_dict to json files
        val_adabin_dict_file = os.path.join(args.save_loc, "val_adabin_dict_"+str(num_bins)+"bins.txt")
        with open(val_adabin_dict_file, "a") as write_file:
            json.dump(val_adabin_dict, write_file)
            write_file.write("\n")
        test_adabin_dict_file = os.path.join(args.save_loc, "test_adabin_dict_"+str(num_bins)+"bins.txt")
        with open(test_adabin_dict_file, "a") as write_file:
            json.dump(test_adabin_dict, write_file)
            write_file.write("\n")

    # Save models
    # val_acc is the value from the last --num-bins entry evaluated above.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        print('New best error: %.4f' % (1 - best_val_acc))
        torch.save(model.state_dict(), _checkpoint_path(epoch, gamma, best=True))

    # Periodic checkpoint every --save-interval epochs.
    if (epoch + 1) % args.save_interval == 0:
        torch.save(model.state_dict(), _checkpoint_path(epoch, gamma))