import torch
import torch.backends.cudnn as cudnn
import numpy as np
from argument_parser import argument_parser
from datasets.datasets import Datasets
from torch.utils.data.dataloader import DataLoader
import pickle
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import os
from os.path import join, isfile
from utils import model_loader
from metrics.manifold import tsne
from analysis.plot import check_plot_save


def extract_features(data_loader, model, model_path, device, save, output_folder, debug, n_samples, train=True):
    """Encode every sample yielded by ``data_loader`` with the autoencoder.

    Each batch is moved to ``device`` and passed through
    ``model.autoencoder_extract_features``. Every row of the result becomes one
    NumPy feature vector, and the matching label becomes a Python ``int``.

    Args:
        data_loader: iterable of ``(samples, labels)`` batches.
        model: model exposing ``autoencoder_extract_features``.
        model_path: checkpoint path. Its basename without extension prefixes
            the cache filenames.
        device: torch device the input batches are moved to.
        save: if true, pickle the features and the targets into ``output_folder``.
        output_folder: directory that receives the pickle files.
        debug, n_samples, train: forwarded to ``filename_pattern`` to build the
            filename suffix.

    Returns:
        ``(extracted_features, targets)``: a list of per-sample NumPy arrays and
        a list of ``int`` labels, both in loader order.
    """
    extracted_features = []
    targets = []

    # Inference only: no autograd graph is needed for the extracted features.
    with torch.no_grad():
        for samples, labels in tqdm(data_loader):
            encoded = model.autoencoder_extract_features(samples.to(device))
            encoded = encoded.detach().cpu().numpy()
            # Keep every sample of the batch (the loader may use batch_size > 1),
            # not only the first one.
            for feature, label in zip(encoded, labels.cpu().numpy()):
                extracted_features.append(feature)
                targets.append(int(label))

    if save:
        suffix = filename_pattern(debug, n_samples, train)
        model_name = os.path.splitext(os.path.basename(model_path))[0]

        with open(join(output_folder, model_name + "_encoded_features" + suffix + ".pickle"), 'wb') as f:
            pickle.dump(extracted_features, f)

        with open(join(output_folder, model_name + "_targets" + suffix + ".pickle"), 'wb') as f:
            pickle.dump(targets, f)

    return extracted_features, targets


def load_tsne_data(data_loader, model_path, model, debug, n_samples, output_folder, train=True):
    """Return cached encoded features and targets, or extract them if no cache exists.

    The cache files are
    ``<output_folder>/<model_name>_encoded_features<suffix>.pickle`` and
    ``<output_folder>/<model_name>_targets<suffix>.pickle``. Here
    ``model_name`` is the basename of ``model_path`` without its extension,
    and ``suffix`` comes from ``filename_pattern``. Both files must exist to
    use the cache.

    Otherwise ``extract_features`` is run. It reads the module-level globals
    ``device`` and ``save`` that the script section of this file defines.

    Returns:
        ``(features, targets, suffix)``.
    """
    suffix = filename_pattern(debug, n_samples, train)
    model_name = os.path.splitext(os.path.basename(model_path))[0]

    features_file = join(output_folder, model_name + "_encoded_features" + suffix + ".pickle")
    targets_file = join(output_folder, model_name + "_targets" + suffix + ".pickle")

    if isfile(features_file) and isfile(targets_file):
        # NOTE: pickle.load can execute arbitrary code. These files are trusted
        # only because this script writes them itself.
        with open(features_file, 'rb') as f:
            features = pickle.load(f)

        with open(targets_file, 'rb') as f:
            targets = pickle.load(f)
    else:
        # Save into the folder this function was given, not the module-level
        # ``out_folder``, so the cache is written where it is later looked up.
        features, targets = extract_features(data_loader, model, model_path,
                                             device, save, output_folder, debug, n_samples, train)

    return features, targets, suffix


def filename_pattern(debug, n_samples, train=True):
    """Build the filename suffix shared by the cache and plot files.

    The suffix is ``"_train"`` or ``"_test"`` depending on ``train``. When
    ``debug`` is truthy, ``"_n<n_samples>"`` is appended.
    """
    pieces = ["_train" if train else "_test"]
    if debug:
        pieces.append("_n%s" % (n_samples,))
    return "".join(pieces)


def run_tsne(features, prototypes, suffix):
    """Compute (or load from cache) a t-SNE embedding of features plus prototypes.

    Rows of ``features`` come first and rows of ``prototypes`` last in the
    array passed to ``tsne``.

    This function reads the module-level globals ``out_folder`` (cache
    directory) and ``save`` (whether to write a newly computed embedding).

    NOTE(review): the cache filename depends only on ``suffix``, not on the
    model. A result cached for one model is reused for another one; delete
    the pickle when switching models.

    Returns:
        ``(title, tsne_output)``, where ``title`` is ``"tsne_output" + suffix``.
    """
    title = "tsne_output" + suffix
    filename = join(out_folder, title + ".pickle")

    if isfile(filename):
        # Cache hit: return it as-is. Do not rewrite the file we just read.
        with open(filename, 'rb') as f:
            tsne_output = pickle.load(f)
        return title, tsne_output

    tsne_features = np.concatenate((features, prototypes), axis=0)
    tsne_output = tsne(tsne_features)

    if save:
        with open(filename, 'wb') as f:
            pickle.dump(tsne_output, f)

    return title, tsne_output


def plot_tsne_output(title, tsne_output, model, targets, neighbors=True, plot=True, save=False, path=None):
    """Scatter-plot a 2-D t-SNE embedding of dataset points and SOM prototypes.

    ``tsne_output`` must hold the dataset points first and the active SOM
    prototypes last, in the order ``run_tsne`` concatenates them. Column 0 is
    plotted as x and column 1 as y. The number of trailing prototype rows is
    the count of entries selected by ``model.som.node_control``.

    Args:
        title: plot title.
        tsne_output: t-SNE coordinates (supports ``.transpose()``).
        model: model whose ``som`` provides ``node_control`` and ``weights``.
        targets: class label of each dataset point, used as the hue.
        neighbors: if true, draw prototype neighbourhood lines via ``plot_neighbors``.
        plot, save, path: forwarded to ``check_plot_save``.

    NOTE(review): no new figure is created here. Whether consecutive calls
    overlay each other depends on ``check_plot_save`` — confirm.
    """
    node_control_mask = model.som.node_control.bool()
    n_prototypes = len(model.som.weights[node_control_mask])

    coords = tsne_output.transpose()
    all_x, all_y = coords[0], coords[1]

    # Split by an explicit index: ``[:-n]`` would yield an empty slice when n == 0.
    n_data = len(all_x) - n_prototypes
    tsne_dataset_output_x, tsne_dataset_output_y = all_x[:n_data], all_y[:n_data]
    tsne_som_output_x, tsne_som_output_y = all_x[n_data:], all_y[n_data:]

    # One colour per class actually present. A fixed size of 10 breaks (or
    # warns, depending on the seaborn version) when a different number of
    # classes is present, e.g. in small debug subsets.
    n_classes = len(set(targets)) or 1
    palette = sns.color_palette("bright", n_classes)
    plt.title(title)

    # x/y are passed as keywords: seaborn >= 0.12 rejects them positionally.
    sns.scatterplot(x=tsne_dataset_output_x,
                    y=tsne_dataset_output_y,
                    hue=targets,
                    legend='full', palette=palette)

    if n_prototypes > 0:
        sns.scatterplot(x=tsne_som_output_x,
                        y=tsne_som_output_y,
                        hue=["Prototypes"] * n_prototypes,
                        s=130,
                        legend='full',
                        palette=sns.dark_palette("purple", 1),
                        alpha=0.7)

        if neighbors:
            plot_neighbors(model, tsne_som_output_x, tsne_som_output_y)

    check_plot_save(path, save, plot)


def plot_neighbors(model, tsne_som_output_x, tsne_som_output_y, threshold=0.991):
    """Draw lines between active SOM prototypes whose activation reaches ``threshold``.

    The activation of every active prototype is computed against every other
    active prototype with ``model.som.activation``. For each prototype, the
    t-SNE coordinates of all prototypes with activation ``>= threshold`` are
    joined by a black line.

    NOTE(review): ``plt.plot`` draws a polyline through the selected points in
    index order, not a star from prototype i to each neighbour. Confirm this
    is the intended picture.

    Args:
        model: model whose ``som`` exposes ``node_control``, ``weights`` and ``activation``.
        tsne_som_output_x, tsne_som_output_y: NumPy t-SNE coordinates of the
            active prototypes, in mask order.
        threshold: minimum activation for two prototypes to be connected.
            The default of 0.991 matches the previous hard-coded value.
    """
    # Derive the mask from ``model`` instead of a module-level global.
    node_control_mask = model.som.node_control.bool()

    with torch.no_grad():
        active_weights = model.som.weights[node_control_mask]
        activations_nn = model.som.activation(active_weights)
        activations_nn = activations_nn.t()[node_control_mask].t()
        # Convert to NumPy so the mask indexes the NumPy coordinate arrays directly.
        activations_connect = (activations_nn >= threshold).cpu().numpy()

    for prototype_nn in activations_connect:
        plt.plot(tsne_som_output_x[prototype_nn], tsne_som_output_y[prototype_nn],
                 'black', lw=1, alpha=0.7)

    # One legend call suffices; repeated calls just replaced the same legend.
    if len(activations_connect):
        plt.legend(loc='best', bbox_to_anchor=(0.7, 0.3, 0.5, 0.5))


# Script section: parse arguments, load the pre-trained model, extract (or load
# cached) features, run t-SNE and plot the train and test embeddings.
# NOTE: ``device``, ``save`` and ``out_folder`` defined here are also read as
# module-level globals by the helper functions above.
args = argument_parser()

# Normalise to a trailing slash and make sure the output directory exists.
out_folder = args.out_folder if args.out_folder.endswith("/") else args.out_folder + "/"
os.makedirs(os.path.dirname(out_folder), exist_ok=True)

use_cuda = torch.cuda.is_available() and args.cuda

if use_cuda:
    torch.cuda.init()

device = torch.device('cuda:0' if use_cuda else 'cpu')
ngpu = int(args.ngpu)

root = args.root
test_root = args.test_root
dataset_path = args.dataset
debug = args.debug
n_samples = args.n_samples

save = args.save
load = args.load
model = args.model

if not load or model is None:
    # TODO: support visualizing the topology without a pre-trained model.
    print("For now, you must have a pre-trained model beforehand to visualize the topology.")
    # ``exit`` is injected by the ``site`` module and may be missing; SystemExit is always available.
    raise SystemExit(0)

autoencoder_som, _, _, _, _, _ = model_loader.load_autoencodersom_model(model, device)

if use_cuda:
    autoencoder_som.cuda()
    cudnn.benchmark = True

# Inference only: put any dropout / batch-norm layers into evaluation mode so
# the extracted features are deterministic.
autoencoder_som.eval()

dataset = Datasets(dataset=dataset_path, root_folder=root,
                   debug=debug, n_samples=n_samples)

train_loader = DataLoader(dataset.train_data, shuffle=True)
test_loader = DataLoader(dataset.test_data, shuffle=False)

train_features, train_targets, train_suffix = load_tsne_data(train_loader, model, autoencoder_som,
                                                             debug, n_samples, out_folder)
test_features, test_targets, test_suffix = load_tsne_data(test_loader, model, autoencoder_som,
                                                          debug, n_samples, out_folder, train=False)

# Kept at module level: plot helpers may rely on ``node_control_mask`` as a global.
node_control_mask = autoencoder_som.som.node_control.bool()
som_prototypes = autoencoder_som.som.weights[node_control_mask]
som_prototypes = som_prototypes.cpu().data.numpy()

train_title, tsne_train_output = run_tsne(train_features, som_prototypes, train_suffix)
test_title, tsne_test_output = run_tsne(test_features, som_prototypes, test_suffix)

plot_tsne_output(train_title, tsne_train_output, autoencoder_som, train_targets,
                 neighbors=True, plot=True, save=save, path=join(out_folder, train_title + ".png"))
plot_tsne_output(test_title, tsne_test_output, autoencoder_som, test_targets,
                 neighbors=True, plot=True, save=save, path=join(out_folder, test_title + ".png"))
