import os
import argparse
import sys
import pickle
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from matplotlib.pyplot import cm
import matplotlib.mlab as mlab
from matplotlib.ticker import NullFormatter
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from matplotlib import cm

import torch
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix, plot_confusion_matrix
import pandas as pd
import seaborn as sns

from init import get_params, set_model


def str2bool(v):
    """Interpret a command-line string as a boolean.

    Returns True when the lower-cased value is 'true' or '1'; every other
    string yields False.
    """
    lowered = v.lower()
    return lowered == 'true' or lowered == '1'


def get_data(args):
    """Build the test loader and plotting metadata for ``args.dataset``.

    Reads ``args.dataset`` (must be 'CIFAR10', 'SVHN' or 'STL10') and
    ``args.valbatchsize``.

    Returns a tuple ``(dltest, labels, colors)``: the object returned by the
    selected dataset module's ``get_test_loader``, the list of class-name
    strings for that dataset, and a list of ten matplotlib color names.
    """
    name = args.dataset
    assert name in ['CIFAR10', 'SVHN', 'STL10']

    # Import lazily so that only the module of the requested dataset is loaded.
    if name == 'CIFAR10':
        from datasets.cifar import get_test_loader
        labels = [
            'airplane', 'automobile', 'bird', 'cat', 'deer',
            'dog', 'frog', 'horse', 'ship', 'truck',
        ]
    elif name == 'SVHN':
        from datasets.svhn import get_test_loader
        labels = [str(digit) for digit in range(10)]
    elif name == 'STL10':
        from datasets.stl10 import get_test_loader
        labels = [
            'airplane', 'bird', 'car', 'cat', 'deer',
            'dog', 'horse', 'monkey', 'ship', 'truck',
        ]

    # One color per class index, shared by every dataset.
    colors = [
        'red', 'green', 'blue', 'cyan', 'magenta',
        'yellow', 'pink', 'gray', 'orange', 'purple',
    ]

    loader = get_test_loader(dataset=name, batch_size=args.valbatchsize, num_workers=2)
    return loader, labels, colors


def _finish_tsne_figure(args, ax, dpi):
    """Hide tick labels, add title and legend, then save and show the current figure."""
    axes = [ax.xaxis, ax.yaxis]
    if args.num_dimensions == 3:
        axes.append(ax.zaxis)
    for axis in axes:
        axis.set_major_formatter(NullFormatter())
    plt.title(args.ckpt_name)
    plt.axis('tight')
    plt.legend(loc='best', scatterpoints=1, fontsize=5.0, markerscale=1.0)
    plt.savefig(args.vis_dir + '/' + args.ckpt_name + '.png', format='png', dpi=dpi)
    plt.show()


def t_sne_vis(args, logit_list, ground_list, labels, colors, images=None):
    """Embed the samples with t-SNE and save a per-class scatter plot.

    The figure is written to ``<args.vis_dir>/<args.ckpt_name>.png`` and then
    shown with ``plt.show()``.

    args:
      args:        namespace providing ``compute_embeddings``, ``num_dimensions``
                   (2 or 3), ``with_images``, ``num_classes``, ``ckpt_name`` and
                   ``vis_dir``.
      logit_list:  per-sample features; must support ``.cpu().numpy()``.
                   Everything after the first axis is flattened before t-SNE.
      ground_list: per-sample class indices; must support ``.cpu().numpy()``.
      labels:      legend text, indexed by class index.
      colors:      matplotlib colors, indexed by class index.
      images:      optional per-sample images, indexed by sample position and
                   only used when ``args.with_images`` is set on a 2D plot.
                   Each item is handed to ``OffsetImage`` unchanged, so it is
                   presumably expected in a layout matplotlib can draw
                   (TODO confirm against the caller).

    Exits via ``sys.exit`` with an explanatory message when the requested
    configuration cannot be plotted.
    """
    # Reject unusable configurations before spending time on t-SNE.
    if args.num_dimensions not in (2, 3):
        sys.exit("t-SNE plots support 2 or 3 dimensions only.")
    plot_3d = args.num_dimensions == 3
    if args.with_images:
        if plot_3d:
            sys.exit("Cannot plot images with 3D plots.")
        if images is None:
            sys.exit("Plotting with images requires the images to be passed to t_sne_vis.")
    if not args.compute_embeddings:
        # Loading previously computed embeddings is not implemented; without
        # this guard the plotting code fails on an undefined variable.
        sys.exit("Loading precomputed embeddings is not implemented; run with --compute_embeddings true.")

    X_sample = logit_list.cpu().numpy()
    y_sample = ground_list.cpu().numpy()

    print("X_sample: {}".format(X_sample.shape))
    print("y_sample: {}".format(y_sample.shape))

    # flatten samples to (N, D) for feeding to t-SNE
    X_sample_flat = np.reshape(X_sample, [X_sample.shape[0], -1])
    # compute t-SNE embeddings
    embeddings = TSNE(n_components=args.num_dimensions, init='pca', verbose=2, perplexity=50).fit_transform(X_sample_flat)

    print('Plotting...')
    fig = plt.figure()
    # One coordinate array per embedding axis.
    coords = [embeddings[:, axis] for axis in range(args.num_dimensions)]

    if plot_3d:
        ax = fig.add_subplot(111, projection='3d')
        dpi = 600
    else:
        ax = fig.add_subplot(111)
        dpi = 1000

        # overlay the sample images at their embedded positions
        if args.with_images:
            xx, yy = coords
            for i, (x, y) in enumerate(zip(xx, yy)):
                thumb = OffsetImage(images[i], zoom=0.1, cmap='gray')
                ax.add_artist(AnnotationBbox(thumb, (x, y), xycoords='data', frameon=False))
            # Artists added with add_artist do not update the data limits.
            ax.update_datalim(np.column_stack([xx, yy]))
            ax.autoscale()

    # plot the data points, one scatter call per class so each gets a legend entry
    for i in range(args.num_classes):
        in_class = y_sample == i
        ax.scatter(*[c[in_class] for c in coords], color=colors[i], label=labels[i], s=10)

    _finish_tsne_figure(args, ax, dpi)


def cmatrix(y_true, y_pred, filename, labels, ymap=None, figsize=(10,10)):
    """
    Generate matrix plot of confusion matrix with pretty annotations.
    The plot image is saved to disk and the figure is closed afterwards.

    Diagonal cells show "<row percent>%" over "<count>/<row total>";
    off-diagonal cells show "<row percent>%" over "<count>", or nothing when
    the count is zero. Rows without any true samples are reported as 0.0%.

    args:
      y_true:    true label of the data, with shape (nsamples,)
      y_pred:    prediction of the data, with shape (nsamples,)
      filename:  filename of figure file to save
      labels:    string array, name the order of class labels in the confusion matrix.
                 use `clf.classes_` if using scikit-learn models.
                 with shape (nclass,).
      ymap:      dict: any -> string, length == nclass.
                 if not None, map the labels & ys to more understandable strings.
                 Caution: original y_true, y_pred and labels must align.
      figsize:   the size of the figure plotted.
    """
    if ymap is not None:
        y_pred = [ymap[yi] for yi in y_pred]
        y_true = [ymap[yi] for yi in y_true]
        labels = [ymap[yi] for yi in labels]
    counts = confusion_matrix(y_true, y_pred, labels=labels)

    row_totals = np.sum(counts, axis=1, keepdims=True)
    # Divide only where the row has samples; empty rows stay at 0 instead of
    # producing NaN and a divide-by-zero warning.
    percents = np.divide(
        counts,
        row_totals,
        out=np.zeros(counts.shape, dtype=float),
        where=row_totals > 0,
    ) * 100

    # Build the annotations as Python strings first so the array is sized to
    # the longest one; a buffer derived from the integer dtype has a fixed
    # width and silently truncates long annotations.
    nrows, ncols = counts.shape
    rows = []
    for i in range(nrows):
        row = []
        for j in range(ncols):
            c = counts[i, j]
            p = percents[i, j]
            if i == j:
                row.append('%.1f%%\n%d/%d' % (p, c, row_totals[i, 0]))
            elif c == 0:
                row.append('')
            else:
                row.append('%.1f%%\n%d' % (p, c))
        rows.append(row)
    annot = np.array(rows, dtype=str).reshape(counts.shape)

    frame = pd.DataFrame(counts, index=labels, columns=labels)
    frame.index.name = 'Actual'
    frame.columns.name = 'Predicted'
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(frame, annot=annot, fmt='', ax=ax)
    fig.savefig(filename)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


#############################################################################################
# Options
#############################################################################################
parser = argparse.ArgumentParser(
    description='Semi-supervised Learning Visualization',
)

# Hardware selection.
parser.add_argument(
    '--gpu_ids',
    type=str,
    default='0',
    help='gpu_ids: e.g. 0  0,1,2  0,2',
)

# Model, dataset and checkpoint.
parser.add_argument(
    '--backbone',
    type=str,
    default='wideresnet',
    help='Wideresnet',
)
parser.add_argument(
    '--wresnet-k',
    type=int,
    default=2,
    help='width factor of wide resnet',
)
parser.add_argument(
    '--wresnet-n',
    type=int,
    default=28,
    help='depth of wide resnet',
)
parser.add_argument(
    '--large_model',
    action='store_true',
    help='default is False. If True, using WideResnetLarge model',
)
parser.add_argument(
    '--dataset',
    type=str,
    default='CIFAR10',
    help='CIFAR10, SVHN, or STL10',
)
parser.add_argument(
    '--valbatchsize',
    type=int,
    default=100,
    help='validation batch size',
)
parser.add_argument(
    '--checkpoint',
    type=str,
    default='model_normal',
    help='checkpoint to run evaluation',
)
parser.add_argument(
    '--random_seed',
    type=int,
    default=42,
    help='seed to ensure reproducibility',
)

# Visualization output.
parser.add_argument(
    '--vis_dir',
    type=str,
    default='./vis_dir',
    help='directory where visualizations are saved',
)
parser.add_argument(
    '--with_images',
    type=str2bool,
    default=False,
    help='whether to overlay images on data points. Only works with 2D plots.',
)
parser.add_argument(
    '--num_dimensions',
    type=int,
    default=2,
    help='t-SNE dimension. Can be 2 or 3.',
)
parser.add_argument(
    '--compute_embeddings',
    type=str2bool,
    default=True,
    help='Whether to compute embeddings. Do this once per sample size.',
)
#############################################################################################


def main():
    """Evaluate a checkpoint on the test set and write its visualizations.

    Parses the command line, builds the model, loads ``args.checkpoint`` into
    it, runs the model over every test batch, and then writes the t-SNE plot
    of the logits and the confusion matrix of the predictions into
    ``args.vis_dir``.

    Exits via ``sys.exit`` when the test loader yields no batches.
    """
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids

    np.random.seed(args.random_seed)

    args.num_classes, _ = get_params(args.dataset)

    if args.dataset == 'STL10':
        model = set_model(args.num_classes, args.wresnet_k, args.wresnet_n, stl=True)
    else:
        model = set_model(args.num_classes, args.wresnet_k, args.wresnet_n, large=args.large_model)
    print("Total params: {:.2f}M".format(sum(p.numel() for p in model.parameters()) / 1e6))

    # The checkpoint's file name doubles as plot title and output file stem.
    args.ckpt_name = os.path.basename(args.checkpoint)
    state = torch.load(args.checkpoint)
    model.load_state_dict(state)
    print('{} loaded'.format(args.checkpoint))
    model.eval()

    # exist_ok avoids the race between checking for and creating the directory.
    os.makedirs(args.vis_dir, exist_ok=True)

    # load data
    dltest, labels, colors = get_data(args)

    logit_batches = []
    pred_batches = []
    label_batches = []

    # NOTE(review): the inputs are moved to CUDA but the model is not moved
    # here -- presumably set_model already places it on the GPU; confirm.
    # Inference only: no_grad keeps autograd from retaining a graph per batch.
    with torch.no_grad():
        for ims, lbs in dltest:
            ims = ims.cuda()
            lbs = lbs.cuda()
            logits, _ = model(ims)
            _, pred = logits.max(1)
            logit_batches.append(logits)
            pred_batches.append(pred)
            label_batches.append(lbs)

    if not logit_batches:
        sys.exit("The test loader for {} produced no batches.".format(args.dataset))

    # Concatenate once instead of re-copying the accumulated tensor per batch.
    logit_list = torch.cat(logit_batches, dim=0)
    pred_list = torch.cat(pred_batches, dim=0)
    ground_list = torch.cat(label_batches, dim=0)

    if labels is None:
        labels = np.arange(args.num_classes)

    t_sne_vis(args, logit_list, ground_list, labels, colors)
    cmatrix(ground_list.cpu(), pred_list.cpu(), args.vis_dir + '/' + args.ckpt_name +'_cm.png', np.arange(args.num_classes))


# Run the evaluation and visualization only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()
