import os
import numpy as np
import matplotlib.pyplot as plt
# from sklearn.manifold import TSNE
from cuml import TSNE  # much faster
from matplotlib.collections import LineCollection
from sklearn.metrics import r2_score



def compute_tsne(x, tsne_kwargs=None):
    """
    Fit a t-SNE model on ``x`` and return it together with the embedding.

    Parameters:
    - x: (samples, features) data to embed.
    - tsne_kwargs (dict | None): keyword arguments forwarded to ``TSNE``.
      Defaults to ``{"perplexity": 30, "random_state": 0}``. A dict passed here
      replaces the defaults entirely (it is not merged with them).

    Returns:
    - (tsne, emb): the TSNE object after ``fit_transform`` and the embedding
      it returned.
    """
    # Build the default per call instead of sharing one mutable dict object
    # across every call through the default argument.
    if tsne_kwargs is None:
        tsne_kwargs = {"perplexity": 30, "random_state": 0}
    tsne = TSNE(**tsne_kwargs)
    emb = tsne.fit_transform(x)
    return tsne, emb


def plot_emb_ax(emb, labels, label_names, ax, alpha=0.1, show_legend=True):
    """
    Scatter a 2-D embedding onto an existing axes, colored by label.

    This one assumes that you are passing in an ``ax`` that is already set up;
    it only draws onto it (and forces an equal aspect ratio).

    Parameters:
    - emb (numpy.ndarray): (samples, 2) embedding coordinates.
    - labels (array-like): (samples,) numeric label per point. Colors come from
      the "tab10" colormap normalized to [0, 10], so label ``u`` in 0..9 is
      drawn with tab10 color ``u``.
    - label_names: list or dict indexed by label value, giving the legend text.
    - ax (matplotlib.axes.Axes): axes to draw on.
    - alpha (float): point transparency.
    - show_legend (bool): draw a legend anchored to the right of the axes.
    """
    emb = np.asarray(emb)
    labels = np.asarray(labels)
    cmap_f = plt.get_cmap("tab10")
    # Same value->color mapping the scatter calls apply through vmin/vmax, so
    # each legend swatch matches the color its points are actually drawn in.
    norm = plt.Normalize(vmin=0, vmax=10)

    unique_labels = np.unique(labels)

    ax.set_aspect(1)
    for u in unique_labels:
        mask = labels == u
        ax.scatter(
            *emb[mask].T,
            alpha=alpha,
            c=labels[mask],
            cmap="tab10",
            vmin=0,
            vmax=10,
            label=u,
        )

    if show_legend:
        # One handle per label, colored from the label *value* (not its
        # position among the unique labels) and with no 10-entry cap.
        handles = [
            plt.Line2D(
                [0],
                [0],
                marker="o",
                color="w",
                markerfacecolor=cmap_f(norm(u)),
                markersize=10,
                label=label_names[u],
            )
            for u in unique_labels
        ]

        ax.legend(handles=handles, bbox_to_anchor=(1, 0.75))


def plot_colored_1darray(
    array,
    colors,
    cmap="viridis",
    colorbar=True,
    figsize=(12, 2),
    x_margin=100,
    y_margin=1,
    vmax=10,
):
    """
    Plot a line whose segments are colored according to a second vector.

    Parameters:
    - array (array-like): either 1D values (plotted against their index), a
      single column of shape (n, 1) (treated like 1D), or (n, 2) (x, y) points.
    - colors (array-like): 1D integers, same length as the first dimension of
      `array`. ``colors[i]`` colors the segment from point ``i`` to ``i + 1``,
      so the last entry is unused. Values >= 0 are mapped through `cmap` on the
      fixed range [0, vmax]; -1 is drawn light grey.
    - cmap (str): Colormap for the plot.
    - colorbar (bool): Whether to show a colorbar (spanning [0, vmax]).
    - figsize (tuple): Size of the new figure.
    - x_margin, y_margin (float): Padding added around the data limits.
    - vmax (float): Color value mapped to the top of the colormap. Defaults to
      10, the scale previously hard-coded here.

    Returns:
    - None. A new figure is created and left as the current figure.

    Raises:
    - ValueError: if `array` and `colors` differ in length, `array` is empty,
      or `array` is not 1D / (n, 1) / (n, 2).
    """
    array = np.asarray(array)
    colors = np.asarray(colors)
    if len(array) != len(colors):
        raise ValueError("`array` and `colors` must have the same length.")
    if len(array) == 0:
        raise ValueError("`array` must not be empty.")

    # A single column holds plain values; a 1-column point array cannot form
    # LineCollection segments, so route it through the 1D path.
    if array.ndim == 2 and array.shape[1] == 1:
        array = array[:, 0]

    values_only = array.ndim == 1
    if values_only:
        # Plot the values against their sample index.
        array = np.c_[np.arange(len(array)), array]

    if array.ndim != 2 or array.shape[1] != 2:
        raise ValueError("`array` must be 1D or 2D with at most 2 columns.")

    # Segment i joins point i to point i + 1: shape (n - 1, 2, 2).
    segments = np.stack([array[:-1], array[1:]], axis=1)

    # One color per segment; -1 is drawn grey, everything else via the colormap.
    color_map = plt.get_cmap(cmap)
    color_list = [
        "lightgrey" if c == -1 else color_map(c / vmax) for c in colors[:-1]
    ]

    fig = plt.figure(figsize=figsize)
    ax = fig.gca()
    ax.add_collection(LineCollection(segments, colors=color_list, linewidths=2))

    # Limits are set by hand around the data, padded by the margins.
    ax.set_xlim(array[:, 0].min() - x_margin, array[:, 0].max() + x_margin)
    ax.set_ylim(array[:, 1].min() - y_margin, array[:, 1].max() + y_margin)

    if colorbar:
        # Colorbar for the values >= 0, on the same [0, vmax] scale as above.
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=vmax))
        sm.set_array([])
        fig.colorbar(sm, ax=ax, label="Color values")

    ax.set_xlabel("Time")
    ax.set_ylabel("Values" if values_only else "")


def plot_loss_curves(config, metrics_dict, fig_kwargs=None):
    """
    Plot train and validation loss curves on a new figure.

    Parameters:
    - config: object exposing ``eval_every_n``; validation point ``k`` is placed
      at x = ``k * eval_every_n`` (presumably the evaluation interval in
      training steps — confirm against the training loop).
    - metrics_dict (dict): must contain "train_loss" and "val_loss" sequences.
      NOTE(review): train_loss is plotted against its own index, which assumes
      one train entry per step — verify this matches how it is recorded.
    - fig_kwargs (dict | None): forwarded to ``plt.subplots``; defaults to
      ``{"figsize": (5, 2)}``.

    Returns:
    - (fig, ax)
    """
    # Fresh dict per call instead of a shared mutable default.
    if fig_kwargs is None:
        fig_kwargs = {"figsize": (5, 2)}

    fig, ax = plt.subplots(1, 1, **fig_kwargs)
    val_loss = metrics_dict["val_loss"]
    val_steps = np.arange(len(val_loss)) * config.eval_every_n
    ax.plot(metrics_dict["train_loss"], label="train")
    ax.plot(val_steps, val_loss, label="val")
    # The curves carry labels; show them.
    ax.legend()
    return fig, ax


def plot_reconstruction(config, true, pred, labels, save_path, save=True, show=False):
    """
    Plot true vs. predicted sequences for a fixed random sample of examples.

    Draws up to 10 rows. Each row shows the true sequence, the predicted
    sequence (both on a shared y-range), and a pred-vs-true scatter titled with
    its R^2 and overlaid with the y = x line. Examples are chosen by a
    permutation seeded with 1234, so repeated calls on the same number of
    examples show the same ones.

    Parameters:
    - config (mapping): read for ``config["data_args"]["subseq_size"]``; only
      the first ``subseq_size`` steps of each example are plotted.
    - true, pred: indexable per-example arrays (anything with ``.flatten()``).
    - labels: per-example label shown in each row's title.
    - save_path (str): destination used when ``save`` is True.
    - save (bool): save the figure to ``save_path``.
    - show (bool): call ``plt.show()``.

    Returns:
    - (fig, axs): the figure and its (n_rows, 3) array of axes, so the caller
      can edit or close it.

    Raises:
    - ValueError: if ``true`` is empty.
    """
    n_examples = len(true)
    if n_examples == 0:
        raise ValueError("`true` must contain at least one example.")

    # A private RandomState produces exactly the permutation that seeding the
    # global RNG with 1234 did, without clobbering the caller's global state.
    ixs = np.random.RandomState(1234).permutation(np.arange(n_examples))

    n_show = min(10, n_examples)
    subseq_size = config["data_args"]["subseq_size"]

    # squeeze=False keeps axs 2-D even when only a single row is drawn.
    fig, axs = plt.subplots(
        n_show, 3, figsize=(10, 2 * n_show), dpi=150, squeeze=False
    )
    for row, ix in enumerate(ixs[:n_show]):
        t = true[ix][:subseq_size]
        p = pred[ix][:subseq_size]
        y = labels[ix]
        ax_true, ax_pred, ax_scatter = axs[row]

        ax_true.plot(t)
        ax_true.set_title(f"ix: {ix} y: {y}. True")

        ax_pred.set_title("Pred")
        ax_pred.plot(p)

        # Share the y-range between the two traces only: the scatter axis is
        # still empty here and its default (0, 1) would widen the range.
        ylim = get_ylim([ax_true, ax_pred])
        ax_true.set_ylim(ylim)
        ax_pred.set_ylim(ylim)

        t_flat, p_flat = t.flatten(), p.flatten()
        r2 = r2_score(t_flat, p_flat)

        ax_scatter.scatter(t_flat, p_flat, alpha=0.1)
        # Identity line (perfect reconstruction) spanning the plotted data.
        lo = min(t_flat.min(), p_flat.min())
        hi = max(t_flat.max(), p_flat.max())
        ax_scatter.plot([lo, hi], [lo, hi], "r--")
        ax_scatter.set_aspect(1)
        ax_scatter.set_title(f"r2: {r2:.3f}")

    fig.tight_layout()

    if save:
        fig.savefig(save_path)
    if show:
        plt.show()
    return fig, axs


def get_ylim(axs):
    """
    Return the union of the y-limits of one or more axes.

    Parameters:
    - axs: a single Axes, an iterable of Axes, or any nesting of iterables of
      Axes (e.g. the 2-D array returned by ``plt.subplots``).

    Returns:
    - (ymin, ymax): the smallest lower limit and the largest upper limit.

    Raises:
    - ValueError: if no axes are given.
    """
    ylims = [ax.get_ylim() for ax in _iter_axes(axs)]
    if not ylims:
        raise ValueError("get_ylim needs at least one axes.")
    ymin = min(lo for lo, _ in ylims)
    ymax = max(hi for _, hi in ylims)
    return ymin, ymax


def _iter_axes(obj):
    """Yield every object with ``get_ylim`` from ``obj``, flattening containers."""
    if hasattr(obj, "get_ylim"):
        yield obj
    else:
        for item in obj:
            yield from _iter_axes(item)
