from typing import List

from BACKEND import cp, sp, to_cpu

import numpy as np_fix

import matplotlib.pyplot as plt

# Matplotlib "tab:" (Tableau 10) colour names; indexed per neuron when
# plotting voltage traces so each neuron gets a distinct colour.
colors = [
    f"tab:{name}"
    for name in (
        "blue", "orange", "green", "red", "purple",
        "brown", "pink", "gray", "olive", "cyan",
    )
]



def plot_raster_binned(
        spike_trains: cp.ndarray,
        title: str = "Spike Train",
        spike_color: str = "black",
        plt_ax: plt.Axes = None,
        figsize: tuple = (10, 6),
        put_time=False
):
    """
    Placeholder for a raster plot of binned spike trains.

    Not implemented yet: every argument is ignored, nothing is drawn and
    None is returned. The signature mirrors ``plot_raster_event``.
    TODO: implement, or remove if unused.
    """
    return None

def plot_raster_event(
        spike_trains: cp.ndarray,
        title: str = "Spike Train",
        spike_color: str = "black",
        plt_ax: plt.Axes = None,
        figsize: tuple = (10, 6),
        put_time=False
):
    """
    Draw a raster plot of spike trains encoded by spike times.

    Row ``i`` of ``spike_trains`` is drawn at y = i + 1, with one marker per
    entry placed at x = that entry's value.

    :param spike_trains: 2-D array (num_trains, max_events) of spike times.
        It is moved to the CPU via ``to_cpu`` before plotting.
        NOTE(review): how unused/padding slots are encoded is not visible
        here; matplotlib skips non-finite values, anything else is drawn.
    :param title: Axes title.
    :param spike_color: Colour of the spike markers.
    :param plt_ax: Axes to draw into. If None, a new figure is created and
        shown with ``plt.show()`` at the end.
    :param figsize: Size of the new figure (only used when ``plt_ax`` is None).
    :param put_time: If True, label the x axis with a time label.
    :return: The Axes that was drawn into.
    """
    spike_trains = np_fix.asarray(to_cpu(spike_trains))
    # Unpacking (rather than shape[0]) keeps the original 2-D requirement.
    num_trains, _ = spike_trains.shape

    if plt_ax is None:
        _, ax = plt.subplots(figsize=figsize)
    else:
        ax = plt_ax

    # Pair every spike time with its 1-based train index so all spikes are
    # drawn by a single scatter call instead of one artist per train.
    if spike_trains.size:
        train_ids = np_fix.broadcast_to(
            np_fix.arange(1, num_trains + 1)[:, None], spike_trains.shape
        )
        ax.scatter(
            spike_trains.ravel(),
            train_ids.ravel(),
            color=spike_color,
            marker=",",
            s=1,
        )

    # Customize plot appearance
    ax.set_title(title)
    if put_time:
        ax.set_xlabel("Time (arb. units)")
    ax.set_ylabel("Train Index")
    # Only label every train when there are few enough to stay readable.
    if num_trains <= 25:
        ax.set_yticks(range(1, num_trains + 1))
    ax.set_ylim(0.5, num_trains + 0.5)
    ax.grid(True, which="both", linestyle="--", alpha=0.5)

    if plt_ax is None:
        plt.show()
    return ax

def plot_spikes_voltages_grid(trains_in: cp.ndarray,
                              trains_out: cp.ndarray,
                              ts: cp.ndarray,
                              vs_list: List[cp.ndarray],
                              label_list: List[str]=None,
                              plot_samples = 3,
                              plot_neurons=3,
                              figsize=(12, 12),
                              plot_thresh=True,
                              save_name=None):
    """
    Plot input spikes, voltage traces and output spikes for a few samples.

    The figure has three rows and one column per plotted sample:
    row 0 is the input raster, row 1 the voltage traces of the first
    ``plot_neurons`` output neurons (one line per array in ``vs_list``;
    output spikes are marked as stars at y = 1), and row 2 the output raster.

    :param trains_in: (N, n_in, n_spikes) Input spike trains encoded by spike times.
    :param trains_out: (N, n_out, n_spikes) Output spike trains encoded by spike times.
    :param ts: (n_ts) Time points used as the x values of the voltage traces.
    :param vs_list: Voltage arrays indexed as ``v[sample, neuron, :]``,
        presumably each (N, n_out, n_ts). The first array is drawn solid,
        the others dash-dotted.
    :param label_list: Optional legend label for each entry of ``vs_list``.
        If omitted, only the first array of each neuron is labelled (with
        the neuron index).
    :param plot_samples: How many samples (columns) to plot.
    :param plot_neurons: How many output neurons to plot.
    :param figsize: Figure size.
    :param plot_thresh: If True, draw a dashed red line at the threshold 1.
    :param save_name: If given, the figure is saved to this path before it is shown.
    """
    trains_in = to_cpu(trains_in)
    trains_out = to_cpu(trains_out)
    ts = to_cpu(ts)
    vs_list = [to_cpu(v) for v in vs_list]

    num_samples, _, _ = trains_in.shape
    _, n_neurons, _ = trains_out.shape

    plot_samples = min(num_samples, plot_samples)
    plot_neurons = min(n_neurons, plot_neurons)

    # squeeze=False keeps `ax` 2-D even when only one sample column is plotted.
    fig, ax = plt.subplots(nrows=3, ncols=plot_samples, sharex=True,
                           figsize=figsize, squeeze=False)

    for k in range(plot_samples):
        plot_raster_event(trains_in[k], title="Input Spike Trains", plt_ax=ax[0, k], put_time=False)

        a = ax[1, k]
        a.set_title("Voltages")
        a.set_ylabel("Voltage")
        if plot_thresh:
            a.axhline(1, linestyle='--', linewidth=1.0, color='red', label=r'Threshold $\theta=1$')
        for i in range(plot_neurons):
            # Cycle the palette so more than len(colors) neurons don't raise IndexError.
            color = colors[i % len(colors)]
            for vi, v in enumerate(vs_list):
                if label_list:
                    label = f"{i}: {label_list[vi]}"
                else:
                    label = f"{i}" if vi == 0 else None
                a.plot(ts, v[k, i, :], label=label,
                       linestyle="-" if vi == 0 else "-.", color=color)
            # Mark all output spikes of this neuron at the threshold level in one call.
            spikes = np_fix.asarray(trains_out[k, i, :])
            if spikes.size:
                a.scatter(spikes, np_fix.ones_like(spikes), marker='*', s=30,
                          color=color, zorder=1000000, edgecolors='black')
        with plt.rc_context({"text.usetex": True, "font.family": "serif"}):
            a.legend()

        plot_raster_event(trains_out[k], title="Output Spike Trains", plt_ax=ax[2, k], put_time=True)

    if save_name is not None:
        fig.savefig(save_name, pad_inches=0)
    plt.show()

def plot_spikes_single(trains_in, ts, plot_samples):
    """
    Plot the input spike trains of the first few samples, one raster per column.

    :param trains_in: (N, n_in, n_spikes) Input spike trains encoded by spike times.
    :param ts: (n_ts) Time points. Currently unused; kept for interface compatibility.
    :param plot_samples: Maximum number of samples to plot.
    """
    trains_in = to_cpu(trains_in)
    num_samples, _, _ = trains_in.shape

    plot_samples = min(num_samples, plot_samples)

    # squeeze=False keeps `ax` 2-D; without it a single column returns a bare
    # Axes and indexing it would raise TypeError.
    fig, ax = plt.subplots(nrows=1, ncols=plot_samples, sharex=True,
                           figsize=(12, 12), squeeze=False)

    for k in range(plot_samples):
        plot_raster_event(trains_in[k], title="Train in", plt_ax=ax[0, k])
    plt.show()

def plot_traces_grid(traces_in: cp.ndarray,
                     traces_out: cp.ndarray,
                     plot_samples=3,
                     plot_dims=3,
                     plot_steps = None,
                     cmap='cividis'):
    """
    Show input and output traces of a few samples as heatmaps with one shared colour scale.

    Row 0 shows ``traces_in[k]`` and row 1 shows ``traces_out[k]`` for each
    plotted sample ``k``. Each image has one row per dimension/neuron and
    one column per time step. A single colorbar on the right covers all panels.

    :param traces_in: (N, n_in, n_t) input traces.
    :param traces_out: (N, n_out, n_t) output traces.
    :param plot_samples: Number of samples (columns) to plot; -1 plots all samples.
    :param plot_dims: Number of leading dimensions/neurons to show; -1 shows all.
    :param plot_steps: Number of leading time steps to show; None shows all.
    :param cmap: Colormap shared by all images and the colorbar.
    """
    # -1 means "all samples"; slicing with [:-1] would silently drop the last one.
    sample_slice = slice(None) if plot_samples == -1 else slice(plot_samples)
    traces_in = to_cpu(traces_in[sample_slice])
    traces_out = to_cpu(traces_out[sample_slice])

    # Shared colour limits so every panel and the colorbar are comparable.
    vmin = min(traces_in.min(), traces_out.min())
    vmax = max(traces_in.max(), traces_out.max())

    num_samples, _, _ = traces_in.shape
    _, n_neurons, _ = traces_out.shape

    plot_samples = min(num_samples, plot_samples) if plot_samples != -1 else num_samples
    plot_dims = min(n_neurons, plot_dims) if plot_dims != -1 else n_neurons

    # squeeze=False keeps `ax` 2-D even when only one sample column is plotted.
    fig, ax = plt.subplots(nrows=2, ncols=plot_samples, figsize=(12, 8), squeeze=False)

    for k in range(plot_samples):
        # plot_steps=None slices the full time axis.
        ax[0, k].imshow(traces_in[k, :plot_dims, :plot_steps], vmin=vmin, vmax=vmax, cmap=cmap)
        ax[1, k].imshow(traces_out[k, :plot_dims, :plot_steps], vmin=vmin, vmax=vmax, cmap=cmap)
    # Leave room on the right for the shared colorbar.
    fig.subplots_adjust(right=0.85)

    # Add a single colorbar to the right of the figure
    cbar_ax = fig.add_axes([0.88, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    norm = plt.Normalize(vmin=vmin, vmax=vmax)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    fig.colorbar(sm, cax=cbar_ax)

    # plt.show() like the sibling plotting helpers; fig.show() is non-blocking
    # and only warns on non-interactive backends.
    plt.show()