import sys
sys.path.insert(0, "..")

from datasets.datasets import Datasets
import random
from torch.utils.data.dataloader import DataLoader
import torch
import os
from os.path import join
from analysis_argument_parser import argument_parser
from utils import model_loader
import matplotlib.pyplot as plt
from analysis.plot import concat_images
from metrics.cluster import map_confusion_matrix_labels

def plot_imgs(data, save_path):
    """Render ``data`` as a grayscale image without axes and save it to disk.

    Args:
        data: Image array accepted by ``Axes.imshow`` (callers in this script
            pass a CPU tensor or a NumPy array).
        save_path: Destination path passed to ``Figure.savefig``.
    """
    fig = plt.figure()
    try:
        ax = fig.add_subplot(1, 1, 1)
        ax.imshow(data, cmap='gray')
        ax.axis('off')
        # savefig renders the figure itself, so no explicit draw is needed.
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Close exactly the figure created above, even when rendering or
        # saving fails; this function is called many times in a loop and
        # unclosed figures would accumulate.
        plt.close(fig)

# Command-line arguments
args = argument_parser()

# Make sure the output folder path ends with a slash and that the folder exists.
out_folder = args.out_folder
if not out_folder.endswith("/"):
    out_folder = out_folder + "/"
if not os.path.exists(os.path.dirname(out_folder)):
    os.makedirs(os.path.dirname(out_folder), exist_ok=True)

# Use the first GPU only when CUDA is available and was requested via args.
use_cuda = torch.cuda.is_available() and args.cuda
if use_cuda:
    torch.cuda.init()
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Short module-level aliases for the parsed options used below.
topk = args.topk
model = args.model
n_samples = args.n_samples
debug = args.debug
dataset_path = args.dataset
root = args.root

# Load the trained model and reuse the seed returned by the loader so the
# Python and torch RNGs are in a known state before the dataset is built.
autoencoder_som, epochs, manual_seed, _, _, _ = model_loader.load_autoencodersom_model(model, device)
random.seed(manual_seed)
torch.manual_seed(manual_seed)

dataset = Datasets(dataset=dataset_path, root_folder=root, debug=debug,
                   n_samples=n_samples)

# Default DataLoader batching (one sample per batch), in dataset order.
test_loader = DataLoader(dataset.test_data, shuffle=False)

autoencoder_som.eval()

predicted_clusters, predicted_labels, true_labels, cluster_result = autoencoder_som.cluster(test_loader)

# Each entry is indexed below as (prototype index, target label).
# NOTE(review): pair layout inferred from how the entries are used later — confirm.
map_labels = map_confusion_matrix_labels(true_labels, predicted_clusters)

prototypes, relevance, moving_avg = autoencoder_som.som.get_prototypes()
# This script only evaluates the model: decode without recording an autograd
# graph so the decoded prototypes (and every distance computed from them)
# do not keep gradient buffers alive.
with torch.no_grad():
    prototypes_decoded = autoencoder_som.decoder(prototypes)

# Raw test samples as a NumPy array, used when exporting neighbour images.
test_set_array = dataset.test_data.data.numpy()

# For each (prototype index, target label) entry of map_labels, save the
# decoded prototype and the `topk` test samples closest to it (mean squared
# error) as PNG files in out_folder.

# Only the first 12 entries of map_labels are processed.
# NOTE(review): confirm whether 12 (rather than 10) is the intended limit.
max_entries = 12

# Dataset name without its extension, used as the file-name prefix.
file_prefix = args.dataset.split('.')[0]

# Flatten the test set once instead of re-iterating the loader for every
# prototype. Each loader item is flattened into a single vector, exactly as
# the per-sample distance computation expects.
with torch.no_grad():
    flat_samples = torch.stack(
        [sample.reshape(-1).to(device) for sample, _ in test_loader])

# topk() raises when asked for more neighbours than there are samples.
n_neighbours = min(topk, flat_samples.shape[0])

for i, element in enumerate(map_labels):
    if i >= max_entries:
        break

    prototype = prototypes_decoded[element[0]]

    # Mean squared error between the prototype and every test sample at once.
    with torch.no_grad():
        dists_res = ((flat_samples - prototype.reshape(-1)) ** 2).mean(
            dim=1, dtype=torch.float)
    knn = dists_res.topk(n_neighbours, largest=False)

    target_tag = file_prefix + '_target_' + str(element[1])

    # NOTE(review): 28x28 is hard-coded — confirm for datasets other than
    # 28x28 single-channel images.
    plot_imgs(data=prototype.view(28, 28).cpu().detach(),
              save_path=join(out_folder, target_tag + '_prototype_.png'))

    # Convert indices to plain Python ints before indexing the NumPy array;
    # the index tensor may live on the GPU.
    for k, sample_idx in enumerate(knn.indices.tolist()):
        plot_imgs(data=test_set_array[sample_idx],
                  save_path=join(out_folder,
                                 target_tag + '_topk_' + str(k) + '.png'))

