import os
import sys
import argparse
import random
import json
import collections
import math

from distutils.version import LooseVersion as LV

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader
import torch.backends.cudnn as cudnn
from gensim.utils import simple_preprocess
from gensim.corpora import Dictionary
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

from models import Model_CNN

# Import metrics to compute
from Metrics.metrics import test_classification_net_logits
from Metrics.metrics import ECELoss, AdaptiveECELoss, ClasswiseECELoss

# Import temperature scaling and NLL utilities
from temperature_scaling import ModelWithTemperature

# From Kumar et. al.
import calibration as cal


def parseArgs():
    """Parse the command line for the single-model calibration evaluation.

    Returns:
        argparse.Namespace with the attributes ``save_loc``, ``saved_model_name``,
        ``num_bins``, ``test_batch_size``, ``cross_validation_error`` and ``log``.
    """
    # (flag, add_argument keyword arguments) in the order they are registered.
    option_specs = (
        ("--save-path", dict(type=str, default='./', dest="save_loc",
                             help='Path to import the model')),
        ("--saved_model_name", dict(type=str, default="cnn_cross_entropy_50.model",
                                    dest="saved_model_name",
                                    help="file name of the pre-trained model")),
        ("--num-bins", dict(type=int, default=15, dest="num_bins",
                            help='Number of bins')),
        ("-tb", dict(type=int, default=128, dest="test_batch_size",
                     help="Test Batch size")),
        ("--cverror", dict(type=str, default='ece', dest="cross_validation_error",
                           help='Error function to do temp scaling')),
        ("-log", dict(action="store_true", dest="log",
                      help="whether to print log data")),
    )

    cli = argparse.ArgumentParser(
        description="Evaluating a single model on calibration metrics.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()


def get_logits_labels(data_loader, net, device=None):
    """Run ``net`` over every batch of ``data_loader`` and collect outputs and labels.

    ``net`` is switched to evaluation mode, and inference runs under
    ``torch.no_grad()``.

    Args:
        data_loader: iterable yielding ``(data, label)`` batch pairs.
        net: module mapping a batch of inputs to logits; must support ``.eval()``.
        device: device to evaluate on and to return results on. When ``None``
            (the default), CUDA is used if available, otherwise the CPU.

    Returns:
        ``(logits, labels)``: all batch outputs and all labels, each
        concatenated along dimension 0 and placed on ``device``.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    logits_list = []
    labels_list = []
    net.eval()
    with torch.no_grad():
        for data, label in data_loader:
            logits_list.append(net(data.to(device)))
            labels_list.append(label)
        # torch.cat on an empty list gives an obscure error; fail clearly instead.
        if not logits_list:
            raise ValueError("data_loader yielded no batches")
        # Logits already live on `device`; .to() is a no-op for them.
        logits = torch.cat(logits_list).to(device)
        labels = torch.cat(labels_list).to(device)
    return logits, labels



if __name__ == "__main__":
    # Script entry point: rebuild the 20 Newsgroups validation/test splits and
    # the GloVe embedding matrix, load a pre-trained Model_CNN, and report
    # calibration metrics before and after temperature scaling (the temperature
    # is fitted on the validation split).

    # CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is initialised.
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    cuda = torch.cuda.is_available()
    # Setting additional parameters
    torch.manual_seed(1)
    # Single source of truth for placement of the model, checkpoint and losses.
    device = torch.device("cuda" if cuda else "cpu")

    args = parseArgs()

    DATADIR = "./"
    GLOVE_DIR = os.path.join(DATADIR, "glove.6B")
    TEXT_DATA_DIR = os.path.join(DATADIR, "20_newsgroup")
    MAX_SEQUENCE_LENGTH = 1000
    MAX_NUM_WORDS = 20000  # 2 words reserved: 0=pad, 1=oov
    EMBEDDING_DIM = 100
    NUM_CLASSES = 20

    # GloVe word embeddings
    # The datafile contains 100-dimensional embeddings for 400,000 English words;
    # each line is a word followed by its vector components.
    print('Indexing word vectors.')
    embeddings_index = {}
    # Decode explicitly as UTF-8 so reading does not depend on the platform's
    # default locale encoding (e.g. cp1252 on Windows).
    with open(os.path.join(GLOVE_DIR, 'glove.6B.100d.txt'), encoding='utf-8') as f:
        for line in f:
            values = line.split()
            word = values[0]
            coefs = np.asarray(values[1:], dtype='float32')
            embeddings_index[word] = coefs
    print('Found %s word vectors.' % len(embeddings_index))

    # 20 Newsgroups data set: one sub-directory per class; only files whose
    # names are purely numeric are read as messages.
    print('Processing text dataset')
    texts = []  # list of text samples
    labels_index = {}  # dictionary mapping label name to numeric id
    labels = []  # list of label ids
    for name in sorted(os.listdir(TEXT_DATA_DIR)):
        path = os.path.join(TEXT_DATA_DIR, name)
        if os.path.isdir(path):
            label_id = len(labels_index)
            labels_index[name] = label_id
            for fname in sorted(os.listdir(path)):
                if fname.isdigit():
                    fpath = os.path.join(path, fname)
                    fargs = {} if sys.version_info < (3,) else {'encoding': 'latin-1'}
                    with open(fpath, **fargs) as f:
                        t = f.read()
                        i = t.find('\n\n')  # skip header
                        if 0 < i:
                            t = t[i:]
                        texts.append(t)
                    labels.append(label_id)
    print('Found %s texts.' % len(texts))

    # Tokenize the texts using gensim.
    tokens = [simple_preprocess(text) for text in texts]

    # Vectorize the text samples into a 2D integer tensor.
    # Keep at most MAX_NUM_WORDS-2 tokens so that, after the +2 shift below,
    # every index fits into MAX_NUM_WORDS with 0 and 1 reserved.
    dictionary = Dictionary(tokens)
    dictionary.filter_extremes(no_below=0, no_above=1.0, keep_n=MAX_NUM_WORDS-2)
    word_index = dictionary.token2id
    print('Found %s unique tokens.' % len(word_index))

    # doc2idx marks tokens outside the dictionary with -1 by default (gensim),
    # which becomes 1 (=oov) after the +2 shift.
    data = [dictionary.doc2idx(t) for t in tokens]

    # Truncate and pad sequences. Padding uses -2 so it becomes 0 (=pad)
    # after the +2 shift.
    data = [i[:MAX_SEQUENCE_LENGTH] for i in data]
    data = np.array([np.pad(i, (0, MAX_SEQUENCE_LENGTH-len(i)), mode='constant', constant_values=-2)
                     for i in data], dtype=int)
    data = data + 2
    print('Shape of data tensor:', data.shape)
    print('Length of label vector:', len(labels))

    # Split the data into training, validation and test sets. The fixed
    # random_state presumably reproduces the split used at training time
    # — TODO confirm against the training script.
    VALIDATION_SET, TEST_SET = 900, 3999
    x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=TEST_SET, shuffle=True, random_state=42)
    x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=VALIDATION_SET, shuffle=False)
    print('Shape of training data tensor:', x_train.shape)
    print('Length of training label vector:', len(y_train))
    print('Shape of validation data tensor:', x_val.shape)
    print('Length of validation label vector:', len(y_val))
    print('Shape of test data tensor:', x_test.shape)
    print('Length of test label vector:', len(y_test))

    # Create PyTorch DataLoaders for all data sets
    print('Validation: ', end="")
    validation_dataset = TensorDataset(torch.LongTensor(x_val), torch.LongTensor(y_val))
    val_loader = DataLoader(validation_dataset, batch_size=args.test_batch_size, shuffle=False, num_workers=4)
    print(len(validation_dataset), 'messages')

    print('Test: ', end="")
    test_dataset = TensorDataset(torch.LongTensor(x_test), torch.LongTensor(y_test))
    test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False, num_workers=4)
    print(len(test_dataset), 'messages')

    # Prepare the embedding matrix. Row i+2 holds the GloVe vector of the word
    # with dictionary id i (rows 0 and 1 are reserved for pad/oov). Words not
    # found in the embedding index stay all-zeros.
    print('Preparing embedding matrix.')
    embedding_matrix = np.zeros((MAX_NUM_WORDS, EMBEDDING_DIM))
    n_not_found = 0
    for word, i in word_index.items():
        if i >= MAX_NUM_WORDS-2:
            continue
        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            embedding_matrix[i+2] = embedding_vector
        else:
            n_not_found += 1
    embedding_matrix = torch.FloatTensor(embedding_matrix)
    print('Shape of embedding matrix:', embedding_matrix.shape)
    print('Words not found in pre-trained embeddings:', n_not_found)

    # Model definition
    net = Model_CNN(embedding_matrix)
    net.to(device)
    cudnn.benchmark = True
    # map_location lets a checkpoint saved on a GPU be restored on `device`.
    model_path = os.path.join(args.save_loc, args.saved_model_name)
    net.load_state_dict(torch.load(model_path, map_location=device))

    nll_criterion = nn.CrossEntropyLoss().to(device)
    ece_criterion = ECELoss().to(device)
    adaece_criterion = AdaptiveECELoss().to(device)
    cece_criterion = ClasswiseECELoss().to(device)

    # Metrics of the uncalibrated model on the test set.
    logits, labels = get_logits_labels(test_loader, net)
    conf_matrix, p_accuracy, _, _, _ = test_classification_net_logits(logits, labels)

    p_ece = ece_criterion(logits, labels).item()
    p_adaece = adaece_criterion(logits, labels).item()
    p_cece = cece_criterion(logits, labels).item()
    p_nll = nll_criterion(logits, labels).item()

    # Printing the required evaluation metrics
    if args.log:
        # print (conf_matrix)
        print ('ECE: {:.2f}'.format(p_ece*100))
        print ('AdaECE: {:.2f}'.format(p_adaece*100))
        print ('Classwise ECE: {:.2f}'.format(p_cece*100))
        print ('Test error: {:.2f}'.format((1 - p_accuracy)*100), '\n')
        # print ('Test NLL: {:.2f}'.format(p_nll))

    # Fit the temperature on the validation set, then re-evaluate on the test set.
    scaled_model = ModelWithTemperature(net, args.log)
    scaled_model.set_temperature(val_loader, cross_validate=args.cross_validation_error)
    T_opt = scaled_model.get_temperature()
    logits, labels = get_logits_labels(test_loader, scaled_model)
    conf_matrix, accuracy, _, _, _ = test_classification_net_logits(logits, labels)

    ece = ece_criterion(logits, labels).item()
    adaece = adaece_criterion(logits, labels).item()
    cece = cece_criterion(logits, labels).item()
    nll = nll_criterion(logits, labels).item()

    if args.log:
        print ('\nOptimal temperature: {:.2f}'.format(T_opt))
        # print (conf_matrix)
        print ('ECE: {:.2f}'.format(ece*100))
        print ('AdaECE: {:.2f}'.format(adaece*100))
        print ('Classwise ECE: {:.2f}'.format(cece*100))
        # print ('Test error: {:.2f}'.format((1 - accuracy)*100))
        # print ('Test NLL: {:.2f}'.format(nll))