import warnings

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Runs at import time: installs a process-wide filter that hides every
# FutureWarning.  NOTE(review): presumably aimed at deprecation notices from the
# plotting / scikit-learn calls below — confirm before narrowing or removing.
warnings.simplefilter(action="ignore", category=FutureWarning)


def scatter_features(features: list, labels: list, label_names: list, colors_edge: list, filename="", show=False, method="PCA"):
    """Embed several feature groups in 2-D and draw them as one scatter plot.

    All groups are concatenated and projected together, so they share one
    coordinate system.  Each point is filled with the colour of its class and
    outlined with the edge colour of the group it belongs to.  The plot is
    drawn on the current matplotlib axes, which are cleared first.

    Args:
        features: Feature groups.  The entries are joined with
            ``np.concatenate`` and handed to the projection model, so each is
            expected to be a 2-D array-like of shape ``(n_samples, n_dims)``.
        labels: One entry per feature group holding the class index of every
            sample in that group.  Entries may be NumPy arrays or plain
            sequences; they are converted with ``np.asarray`` before comparison.
        label_names: Class names.  Class ``i`` is shown under
            ``label_names[i]``; the number of names fixes the number of classes.
        colors_edge: One marker edge colour per feature group.
        filename: Output path for ``plt.savefig``; nothing is saved when empty.
        show: When true, ``plt.show()`` is called after drawing.
        method: ``"TSNE"`` selects t-SNE; every other value falls back to PCA.
    """
    plt.cla()
    plt.axis("off")

    n_classes = len(label_names)
    # Up to ten classes use the qualitative tab10 palette; beyond that the
    # colours are sampled from the continuous gist_rainbow map.
    colormap = plt.cm.tab10 if n_classes <= 10 else plt.cm.gist_rainbow
    colors = colormap(np.linspace(0.0, 1.0, n_classes))

    if method == "TSNE":
        model = TSNE(n_components=2, learning_rate="auto", init="pca", random_state=0)
    else:
        model = PCA(n_components=2)

    data = model.fit_transform(np.concatenate(features))

    legend_handles = []
    start = 0
    for j, group in enumerate(features):
        # Rows [start, stop) of the joint embedding belong to group j.
        stop = start + len(group)
        xs = data[start:stop, 0]
        ys = data[start:stop, 1]
        # Coerce to an array: comparing a plain list with an int gives a single
        # False, which would silently select no points at all.
        group_labels = np.asarray(labels[j])
        for i in range(n_classes):
            selected = group_labels == i
            handle = plt.scatter(xs[selected], ys[selected], color=colors[i], s=8.0, alpha=0.8, edgecolors=colors_edge[j])
            # The legend only needs one handle per class; take them from the first group.
            if j == 0:
                legend_handles.append(handle)
        start = stop

    # A legend with more than ten entries is not drawn.
    if n_classes <= 10:
        plt.legend(legend_handles, label_names, loc="center left", bbox_to_anchor=(-0.1, 0.5), borderaxespad=0, markerscale=2.0)
    if filename:
        plt.savefig(filename)
    if show:
        plt.show()


def _build_evolution_grid(indices):
    """Return a ``(longest_sequence, n_sequences)`` int32 grid of image indices padded with -1."""
    # default=0 keeps an empty ``indices`` from raising inside max().
    longest = max((len(sequence) for sequence in indices), default=0)
    grid = np.full((longest, len(indices)), -1, dtype=np.int32)
    for column, sequence in enumerate(indices):
        for row, image_index in enumerate(sequence):
            grid[row, column] = image_index
    return grid


def draw_label_evolution(images, indices, filename, show=False):
    """Tile the images referenced by several index sequences into one picture.

    Sequence ``k`` occupies tile row ``k`` of the canvas; its ``t``-th entry is
    placed in tile column ``t``.  Sequences shorter than the longest one leave
    their remaining tiles blank.  An entry equal to -1 is treated as padding
    and is skipped as well.

    Args:
        images: Image batch indexed as ``images[k]``.  The code relies on
            ``.shape``, ``.permute`` and ``.numpy()`` and scales values by 255,
            so this is presumably a torch tensor laid out ``(N, C, H, W)`` with
            values in ``[0, 1]`` — TODO confirm with the callers.  A channel
            size of 3 produces an RGB canvas, anything else a greyscale one.
        indices: List (or tuple) of index sequences into ``images``.
        filename: Path the finished canvas is saved to.
        show: When true, the canvas is opened in an image viewer before saving.

    Returns:
        None.  If there is nothing to draw (no sequences, or only empty ones)
        the function returns without writing ``filename``.

    Raises:
        TypeError: If ``indices`` is neither a list nor a tuple.
    """
    if not isinstance(indices, (list, tuple)):
        # Previously this path failed later with an unrelated UnboundLocalError.
        raise TypeError(f"indices must be a list of index sequences, got {type(indices).__name__}")

    evolution = _build_evolution_grid(indices)
    if evolution.shape[0] == 0:
        return

    n_steps, n_sequences = evolution.shape
    # PIL sizes and offsets are (x, y) = (width, height), whereas an array turned
    # into an image is (rows, cols): axis 2 is the tile height, axis 3 its width.
    tile_height, tile_width = images.shape[2], images.shape[3]
    mode = "RGB" if images.shape[1] == 3 else "L"
    canvas = Image.new(mode, (n_steps * tile_width, n_sequences * tile_height))

    for step in range(n_steps):
        for sequence in range(n_sequences):
            image_index = evolution[step, sequence]
            if image_index == -1:
                continue
            # Channels-last layout for PIL; squeeze() drops a singleton channel axis.
            pixels = (images[image_index].permute((1, 2, 0)).squeeze().numpy() * 255).astype(np.uint8)
            canvas.paste(Image.fromarray(pixels), (step * tile_width, sequence * tile_height))

    if show:
        canvas.show()
    canvas.save(filename)
