import sys
sys.path.insert(0, "..")

from datasets.datasets import Datasets
import random
from torch.utils.data.dataloader import DataLoader
import metrics
import torch
import os
from os.path import join
from analysis.analysis_argument_parser import argument_parser
from utils import model_loader
import matplotlib.pyplot as plt
import numpy as np
from metrics.cluster import map_confusion_matrix_labels
from plot import plot_image as custom_plot
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt



def plot_image(title='My Title', data=None, save_path='./data.png', figsize=(20, 20), constrained_layout=False, cmap='gray', save=False,handles=None):
    """Render ``data`` as a single borderless image, then save or show it.

    Args:
        title: Accepted for interface compatibility; it is currently not
            drawn (titles are deliberately left off the exported figures).
        data: Array handed to ``Axes.imshow`` -- e.g. a 2-D array mapped
            through ``cmap`` or an (H, W, 3) RGB array.
        save_path: Destination file, used only when ``save`` is true.
        figsize: Figure size in inches.
        constrained_layout: Forwarded to ``plt.figure``.
        cmap: Colormap used for 2-D data.
        save: If true, write the figure to ``save_path`` and close it;
            otherwise display it with ``plt.show()``.
        handles: Optional legend handles; no legend is drawn when this is
            ``None`` or empty.
    """
    fig = plt.figure(figsize=figsize, constrained_layout=constrained_layout)
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(data, cmap=cmap)

    # Only add a legend when there is something to put in it: calling
    # legend() without handles on an image-only axes finds no labelled
    # artists, emits a warning and creates an empty legend.
    if handles:
        ax.legend(handles=handles, fontsize="x-large")
    ax.axis('off')

    if save:
        # savefig renders the figure itself; no explicit draw is needed.
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
        # Close this specific figure so repeated calls don't accumulate
        # open figures.
        plt.close(fig)
    else:
        plt.show()

# Argument Parser
args = argument_parser()

# Normalise the output folder so it always ends with "/", then make sure the
# directory exists. exist_ok=True already tolerates an existing directory,
# so a separate os.path.exists() check would be redundant (and racy).
out_folder = args.out_folder if args.out_folder.endswith("/") else args.out_folder + "/"
os.makedirs(os.path.dirname(out_folder), exist_ok=True)

# Run on the first GPU only when CUDA is available AND args.cuda is set;
# otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available() and args.cuda
if use_cuda:
    torch.cuda.init()
device = torch.device('cuda:0' if use_cuda else 'cpu')

# Composition mode used further below:
#   'union' -> each prototype is added into the colour channel(s) of its
#              palette entry;
#   'inter' -> the prototypes are multiplied together pixel-wise.
# (Previously this came from args.comp_flag.)
flag = 'inter'
# Indices of the SOM prototypes to compose. Another set that was tried:
# [51, 36, 44]. (Previously this came from args.cluster_list.)
cluster_list = [3, 9, 49]

# Load the model; the loader also returns a seed (presumably the one used in
# training -- confirm in model_loader) that is reused for reproducibility.
autoencoder_som, epochs, manual_seed, _, _, _ = model_loader.load_autoencodersom_model(args.model, device)
random.seed(manual_seed)
torch.manual_seed(manual_seed)

dataset = Datasets(dataset=args.dataset, root_folder=args.root, debug=args.debug,
                   n_samples=args.n_samples)

# Keep the test set in its original order (no shuffling).
test_loader = DataLoader(dataset.test_data, shuffle=False)

autoencoder_som.eval()

predicted_clusters, predicted_labels, true_labels, cluster_result = autoencoder_som.cluster(test_loader)

# Mapping between true labels and predicted clusters, printed for inspection.
map_labels = map_confusion_matrix_labels(true_labels, predicted_clusters)

print(map_labels)

prototypes, relevance, moving_avg = autoencoder_som.som.get_prototypes()
# eval() only switches layer behaviour; it does not stop autograd from
# recording. The decoded prototypes are only visualised (and detached before
# use), so skip building a computation graph for this forward pass.
with torch.no_grad():
    prototypes_decoded = autoencoder_som.decoder(prototypes)

# Disabled: export every decoded prototype as its own image, titled with its
# entry from map_labels. Uncomment to re-enable.
#
# for i in range(len(prototypes_decoded)):
#     custom_plot(title='{}:{}'.format(map_labels[i][0], map_labels[i][1]),
#         data=prototypes_decoded[i].cpu().detach().view(dataset.hw_in, dataset.hw_in).numpy(), save_path=join(args.out_folder, "{0}.png".format(i)), figsize=(20, 20), constrained_layout=False, cmap='gray', save=True)

# Accumulator for the composite RGB image of the selected prototypes:
#   'union' starts at black and adds each prototype into its palette colour;
#   'inter' starts at white and multiplies every channel by each prototype.
if flag == 'union':
    cluster_composition = np.zeros((dataset.hw_in, dataset.hw_in, 3))
elif flag == 'inter':
    cluster_composition = np.ones((dataset.hw_in, dataset.hw_in, 3))
else:
    # Fail early with a clear message instead of crashing later on None.
    raise ValueError("Unknown composition flag {!r}; expected 'union' or 'inter'".format(flag))


# RGB colour for each position in cluster_list (blue, green, red). The last
# entry (white) is the neutral multiplier used in 'inter' mode.
pallete = [[0,0,1], [0,1,0], [1,0,0], [1,1,1]]
target = 4  # label shown in the legend entries; hard-coded for this run

# One legend patch per composed cluster, coloured like its contribution.
handles = [
    mpatches.Patch(color=colour, label='Target {} : Cluster index {}'.format(target, cluster_idx))
    for colour, cluster_idx in zip(pallete[:-1], cluster_list)
]

if flag == 'inter':
    # Every channel receives the same product, so the intersection image is
    # grayscale and the colour legend does not apply to any plot.
    handles = None

for i, cluster_idx in enumerate(cluster_list):
    cluster_idx = int(cluster_idx)
    # Convert this prototype to a (hw_in, hw_in) NumPy array once.
    prototype = prototypes_decoded[cluster_idx].cpu().detach().view(dataset.hw_in, dataset.hw_in).numpy()

    if flag == 'union':
        # Broadcast the prototype across the RGB channels of its colour.
        cluster_composition += prototype[:, :, np.newaxis] * np.asarray(pallete[i])
    else:  # 'inter'
        cluster_composition *= prototype[:, :, np.newaxis] * np.asarray(pallete[-1])
        # Also save the raw grayscale prototype for reference.
        plot_image(title='Cluster Index: {label}'.format(label=cluster_idx),
            data=prototype,
            save_path=join(args.out_folder,"_raw_{0}.png".format(cluster_idx)),save=True, handles=handles)

    # Save the composition accumulated so far. imshow clips float RGB data to
    # [0, 1] anyway; clipping here keeps the image identical without the
    # warning, and leaves the accumulator itself untouched.
    plot_image(title='Cluster Index: {label}'.format(label=cluster_idx), data=np.clip(cluster_composition, 0, 1),
        save_path=join(args.out_folder,"{0}.png".format(cluster_idx)),save=True, handles=handles)