from __future__ import absolute_import, print_function
import warnings
import argparse
from torch.backends import cudnn
from torch.nn import functional as F
import torchvision.transforms as transforms
from ImageFolder import *
from utils import *
from sklearn.metrics.pairwise import euclidean_distances, cosine_distances, cosine_similarity
import torchvision.models as models
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import itertools
import pandas as pd
warnings.filterwarnings("ignore", category=DeprecationWarning)


def extract_features(model, data_loader, device='cuda'):
    """Embed every batch of ``data_loader`` with ``model`` and L2-normalise it.

    Args:
        model: network mapping a batch of images to features. Everything after
            the batch axis is flattened, so e.g. an ``(N, C, 1, 1)`` output of a
            ResNet trunk without its fc layer becomes ``(N, C)``.
        data_loader: iterable yielding ``(imgs, pids)`` batches.
        device: device the model and the inputs are moved to (default
            ``'cuda'``, matching the previous hard-coded behaviour).

    Returns:
        ``(features, labels)``: a ``(num_samples, feature_dim)`` NumPy array of
        unit-L2-norm rows and a ``(num_samples,)`` NumPy array of labels, in
        loader order. Both are empty arrays when the loader yields nothing.
    """
    model = model.to(device)
    model.eval()

    feature_batches = []
    label_batches = []
    with torch.no_grad():
        for imgs, pids in data_loader:
            outputs = model(imgs.to(device))
            # flatten from dim 1 rather than squeeze(): squeeze() would also
            # drop the batch axis for a single-sample batch and break dim=1 below.
            outputs = torch.flatten(outputs, start_dim=1)
            outputs = F.normalize(outputs, p=2, dim=1)
            feature_batches.append(outputs.cpu().numpy())
            label_batches.append(np.asarray(pids))

    if not feature_batches:
        return np.empty((0, 0), dtype=np.float32), np.empty((0,), dtype=np.int64)
    # Concatenate once at the end instead of re-stacking on every batch.
    return (np.concatenate(feature_batches, axis=0),
            np.concatenate(label_batches, axis=0))


def plot_confusion_matrix(true_labels, predicted_labels, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues,
                          fontsize=12):  # Added fontsize parameter
    """Draw and show a confusion-matrix heat-map annotated with each cell's value.

    Args:
        true_labels: ground-truth labels.
        predicted_labels: predicted labels, same length as ``true_labels``.
        classes: tick labels for both axes. ``confusion_matrix`` is called
            without ``labels=``, so its rows/columns follow the sorted union of
            ``true_labels`` and ``predicted_labels``; ``classes`` should list
            exactly those labels in that order.
        normalize: if True, divide each row by its total so a cell shows the
            fraction of that true class predicted as the column class (rows with
            no samples stay 0). Cells are then printed with two decimals.
        title: figure title.
        cmap: matplotlib colormap for the heat-map.
        fontsize: font size for the tick labels and the cell annotations.
    """
    cm = confusion_matrix(true_labels, predicted_labels)
    if normalize:
        # Previously this flag only changed the number format; actually
        # normalise, guarding against division by zero for empty rows.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, fontsize=fontsize)
    plt.yticks(tick_marks, classes, fontsize=fontsize)

    fmt = '.2f' if normalize else 'd'
    # Switch annotation colour at half the maximum so text stays readable on dark cells.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black",
                 fontsize=fontsize)

    plt.ylabel('True label', fontsize=12)
    plt.xlabel('Predicted label', fontsize=12)
    plt.tight_layout()
    plt.show()


# Command-line interface. The seed, epochs, task count and base-class count
# must match the values used for training so the class split and the
# checkpoint file names line up.
parser = argparse.ArgumentParser(description='MLRPTM Testing')

parser.add_argument('-data', type=str, default='tiny-imagenet-200',
                    help='name of the data set (cifar10, cifar100, '
                         'mini-imagenet-100, tiny-imagenet-200)')

# os.path.join instead of a literal backslash: '\T' was an invalid escape
# sequence (SyntaxWarning on Python 3.12+) and the path only worked on Windows.
parser.add_argument('-r', type=str, default=os.path.join('checkpoints', 'Tiny-imagenet-200'),
                    metavar='PATH', help='path that the trained models had been saved')

parser.add_argument("-gpu", type=str, default='0', help='which gpu to choose')

parser.add_argument('-seed', default=1993, type=int, metavar='N',
                    help='the same seed as training process')

parser.add_argument('-epochs', default=50, type=int,
                    metavar='N', help='the same epochs as training process')

parser.add_argument('-task', default=20, type=int,
                    help='number of tasks')

parser.add_argument('-base', default=100, type=int,
                    help='number of classes in non_incremental_state')

args = parser.parse_args()

# Let cuDNN benchmark convolution algorithms, and restrict the visible CUDA
# devices to the one(s) named by -gpu.
cudnn.benchmark = True
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

# Checkpoint files saved at the final training epoch, as full paths in sorted order.
_ckpt_suffix = "%d_model.pt" % args.epochs
models_f = sorted(
    os.path.join(args.r, entry)
    for entry in os.listdir(args.r)
    if entry.endswith(_ckpt_suffix)
)

if args.task > 10:
    # NOTE(review): moves the second sorted checkpoint to the end, presumably
    # to correct lexicographic ordering of the file names for runs with more
    # than 10 tasks. Whether this yields task order depends on the naming
    # scheme used at training time; confirm.
    models_f.append(models_f.pop(1))

# Per-dataset preprocessing and class count. Every dataset lives under
# DataSet/<name>/{train,test}.
if args.data == "cifar100":
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    num_classes = 100

elif args.data == "cifar10":
    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
                             (0.24703233, 0.24348505, 0.26158768)), ])

    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
                             (0.24703233, 0.24348505, 0.26158768)),
    ])
    num_classes = 10

elif args.data == 'mini-imagenet-100':
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])
    num_classes = 100

elif args.data == 'tiny-imagenet-200':
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])
    num_classes = 200

else:
    # Fail fast with a clear CLI error; an unknown name used to surface only
    # later as a NameError on num_classes / traindir.
    parser.error("unsupported data set %r (expected one of: cifar100, cifar10, "
                 "mini-imagenet-100, tiny-imagenet-200)" % args.data)

root = 'DataSet' + '/' + args.data
traindir = os.path.join(root, 'train')
testdir = os.path.join(root, 'test')

num_task = args.task
# Classes introduced by each incremental task (task 0 holds the args.base base
# classes). With a single task there are no incremental tasks, so skip the
# division that previously raised ZeroDivisionError.
if num_task > 1:
    num_class_per_task = int((num_classes - args.base) / (num_task - 1))
else:
    num_class_per_task = 0
# Same seed as the training process, so the class permutation matches.
np.random.seed(args.seed)
random_perm = np.random.permutation(num_classes)

print('Test starting...\t')


def _task_classes(k):
    """Return the slice of ``random_perm`` holding the classes introduced in task ``k``."""
    if k == 0:
        return random_perm[:args.base]
    start = args.base + (k - 1) * num_class_per_task
    return random_perm[start:start + num_class_per_task]


# Class-mean prototypes for every class seen so far, in random_perm order.
# Each task appends the means of its new classes, computed with that task's
# model; earlier prototypes are not recomputed.
class_mean = []
class_label = []

for task_id in range(num_task):

    # Test on every class seen so far; build prototypes only for the new ones.
    index = random_perm[:args.base + task_id * num_class_per_task]
    index_train = _task_classes(task_id)

    trainfolder = ImageFolder(traindir, transform_train, index=index_train)
    testfolder = ImageFolder(testdir, transform_test, index=index)

    train_loader = torch.utils.data.DataLoader(
        trainfolder, batch_size=128, shuffle=False, drop_last=False)
    test_loader = torch.utils.data.DataLoader(
        testfolder, batch_size=128, shuffle=False, drop_last=False)
    if task_id != 0:
        print('Test %d\t' % task_id)

    # ResNet-18 trunk without its final fc layer, loaded from this task's checkpoint.
    model_res = models.resnet18(pretrained=False)
    model = nn.Sequential(*list(model_res.children())[:-1])
    state = torch.load(models_f[task_id])
    model.load_state_dict(state['state_dict'])

    train_embeddings_cl, train_labels_cl = extract_features(
        model, train_loader)
    val_embeddings_cl, val_labels_cl = extract_features(
        model, test_loader)

    # Per-class mean (and variance) of the new classes' train embeddings.
    # mean_data_csv / var_data_csv are collected but not written anywhere in this script.
    mean_data_csv = []
    var_data_csv = []
    for i in index_train:
        ind_cl = np.where(i == train_labels_cl)[0]
        embeddings_tmp = train_embeddings_cl[ind_cl]
        emb_mean = np.mean(embeddings_tmp, axis=0)
        class_label.append(i)
        class_mean.append(emb_mean)
        mean_data_csv.append([i] + list(emb_mean.flatten()))
        var_data_csv.append([i] + list(np.var(embeddings_tmp, axis=0).flatten()))

    # Nearest-class-mean prediction against every prototype. It does not depend
    # on the task k being scored below, so compute it once per model.
    pairwise_distance = euclidean_distances(
        val_embeddings_cl, np.asarray(class_mean))
    estimate_label = np.asarray(index)[np.argmin(pairwise_distance, axis=1)]

    num_seen = float(args.base + task_id * num_class_per_task)
    gt_all = []
    estimate_all = []
    acc_ave = 0
    for k in range(task_id + 1):
        gt = np.isin(val_labels_cl, _task_classes(k))
        estimate_tmp = estimate_label[gt]
        true_tmp = val_labels_cl[gt]
        if task_id == num_task - 1:
            # Collect in lists; the old `estimate_all == []` test compared a
            # NumPy array with a list, which errors on NumPy >= 1.25.
            estimate_all.append(estimate_tmp)
            gt_all.append(true_tmp)

        acc = np.sum(estimate_tmp == true_tmp) / float(len(estimate_tmp))
        # Weight each task's accuracy by its share of the classes seen so far.
        if k == 0:
            acc_ave += acc * (float(args.base) / num_seen)
        else:
            acc_ave += acc * (float(num_class_per_task) / num_seen)

            # confusion_matrix orders rows/columns by the sorted union of true
            # and predicted labels; use exactly that set for the tick labels
            # (np.unique(index) did not match the matrix dimensions).
            classes = np.unique(np.concatenate((true_tmp, estimate_tmp)))
            plot_confusion_matrix(true_tmp, estimate_tmp, classes)
        if task_id != 0:
            print("Accuracy of Model %d on Task %d is %.3f" % (task_id, k, acc))
    if gt_all:
        estimate_all = np.concatenate(estimate_all)
        gt_all = np.concatenate(gt_all)
    if task_id != 0:
        print("Weighted Accuracy of Model %d is %.3f" % (task_id, acc_ave))