# These functions are heavily inspired by Sabrina Kern's plotting code for the publication:
# Sabrina Kern, Michael Eberhard, and Gerta Köster. Detecting dynamical patterns in pedestrian bottlenecks:
# Koopman-based quantification of crowd dynamics. In European Physical Journal Web of Conferences,
# volume 334, pp. 04020, 2025.
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datafold import TSCDataFrame, EDMD
from matplotlib.axes import Axes
from matplotlib.colors import LinearSegmentedColormap, Colormap
from matplotlib.image import AxesImage
from matplotlib.quiver import Quiver

from kirnn import KIRNN

# anchor points (position in [0, 1], color) of the sequential colormap used for density snapshots
# NOTE(review): 'lightskyblue' was noted here, presumably as an alternative start color
COLORS = [
    (0, 'steelblue'),
    (0.2, 'aqua'),
    (0.4, 'teal'),
    (0.7, 'darkslategray'),
    (1.0, 'black'),
]
CMAP_NAME = "custom_density_map"
CUSTOM_CMAP = LinearSegmentedColormap.from_list(CMAP_NAME, COLORS, N=1000)

# cyclic colormap: the first and last color coincide, so the map wraps around seamlessly
# (palette from https://stackoverflow.com/a/48055435)
CYCLIC_CMAP = LinearSegmentedColormap.from_list(
    'custom_cyclic_map',
    [
        '#D79A12', '#E55C4B', '#EC2CA3', '#BF40E5', '#8270F1',
        '#35ADE0', '#3BC4A3', '#71C22F', '#D79A12',
    ],
    N=1000,
)


def plot_snapshot(grid_data: np.ndarray,
                  title: str = '',
                  vmin: float = -0.8,
                  vmax: float = 0.8,
                  ax: Axes | None = None,
                  cmap: Colormap | str = CUSTOM_CMAP,
                  grid_resolution: list[float] | tuple[float, float] = (0.2, 0.2),
                  flip_ud: bool = True,
                  transpose: bool = False,
                  remove_axis_labels: bool = False,
                  vector_scaling: float = 0.5,
                  ) -> tuple[Axes, AxesImage | Quiver]:
    '''
    Plot a single snapshot of scalar grid data as an image and save it to ``results/density_{title}.pdf``.

    Only scalar data is supported: ``grid_data`` must have shape ``(rows, cols)`` or ``(1, rows, cols)``.
    The figure that owns ``ax`` (default: the current axis) is saved and then closed.

    :param grid_data: one value per grid cell.
    :param title: axis title; also part of the output file name.
    :param vmin: currently unused -- the color range is derived from the data.
    :param vmax: currently unused -- the color range is derived from the data.
    :param ax: axis to draw on; defaults to ``plt.gca()``.
    :param cmap: colormap of the image.
    :param grid_resolution: cell size ``(x, y)``; used for the aspect ratio and the tick labels.
    :param flip_ud: invert the y-axis so that the first row of the grid is drawn at the top.
    :param transpose: transpose the grid before plotting.
    :param remove_axis_labels: hide the tick labels of both axes.
    :param vector_scaling: currently unused.
    :return: the axis and the image artist (their figure has already been saved and closed).
    :raises ValueError: if ``grid_data`` does not have one of the supported shapes.
    '''
    ax = ax if ax is not None else plt.gca()
    resolution_x, resolution_y = grid_resolution

    # one tick roughly every 2 length units, but never less than one cell apart
    tick_interval_x = max(1, int(2 / resolution_x))
    tick_interval_y = max(1, int(2 / resolution_y))

    if not (grid_data.ndim == 2 or (grid_data.ndim == 3 and grid_data.shape[0] == 1)):
        raise ValueError(f"Could not plot grid data with shape {grid_data.shape}. Aborting...")

    grid_data = grid_data[0] if grid_data.ndim == 3 else grid_data
    if transpose:
        grid_data = grid_data.T
    y_length, x_length = grid_data.shape

    # vmin/vmax are not forwarded, so the color range follows the data
    im = ax.imshow(grid_data,
                   cmap=cmap,
                   origin='lower',  # origin is the lower left corner, where (x, y) = (0, 0)
                   extent=(0, x_length, 0, y_length),
                   aspect=resolution_y / resolution_x)  # aspect ratio of the grid

    ax.set_title(title)
    x_ticks = np.arange(0, x_length + 1, tick_interval_x)
    y_ticks = np.arange(0, y_length + 1, tick_interval_y)
    # label ticks with the physical distance; ':g' avoids int() truncation such as 6 * 0.3 = 1.7999... -> "1"
    ax.set_xticks(ticks=x_ticks, labels=[f"{v:g}" for v in x_ticks * resolution_x])
    ax.set_yticks(ticks=y_ticks, labels=[f"{v:g}" for v in y_ticks * resolution_y])

    if remove_axis_labels:
        ax.tick_params(labelbottom=False, labelleft=False)  # remove x-axis and y-axis labels

    ax.xaxis.tick_top()  # move the x-axis to the top of the grid
    if flip_ud:
        ax.invert_yaxis()  # turn the contents including labels upside down

    print(f"Plotted grid for single timestep with {title=}")
    # save/close the figure that owns `ax`, which is not necessarily the current figure
    fig = ax.figure
    fig.savefig(f"results/density_{title}.pdf")
    plt.close(fig)
    return ax, im


def plot_predictions_per_cell(df: pd.DataFrame,
                              df_predicted: pd.DataFrame,
                              series_id: int,
                              observable: str,
                              visualization_timesteps: list[int],
                              measurement_cells: list[str],
                              temporal_resolution: float = 0.2,
                              path: str | None = None,
                              is_test: bool = False
                              ) -> None:
    '''
    Compare true and predicted values of one observable in selected grid cells.

    One subplot per measurement cell shows the true series (black), the predicted series (blue),
    a gray band from t=0 to the first predicted time (labelled "state embedding depth") and dashed
    vertical lines at ``visualization_timesteps``. The time axis is the index times ``temporal_resolution``.
    The figure is written to ``./results/{observable}_per_cell_{series_id}_{train|test}.pdf`` and closed.

    :param df: true values; needs a ``{cell}_{observable}`` column for every measurement cell.
    :param df_predicted: predicted values with the same columns.
    :param path: currently unused; the output location is fixed.
    :raises KeyError: if a required column is missing (diagnostics are printed before re-raising).
    '''
    split_name = 'test' if is_test else 'train'
    n_cells = len(measurement_cells)
    time_true = df.index * temporal_resolution
    time_predicted = df_predicted.index * temporal_resolution
    # column names are the measurement cell followed by the observable postfix
    visualization_columns = [f'{cell}_{observable}' for cell in measurement_cells]

    # squeeze=False always yields a 2d array of axes, so a single cell needs no special case
    fig, axes_grid = plt.subplots(nrows=n_cells, ncols=1,
                                  figsize=[4, n_cells * 3 + 1], squeeze=False)
    cell_axes = axes_grid[:, 0]

    for ax, column in zip(cell_axes, visualization_columns):
        try:
            truth = df[column]
            prediction = df_predicted[column]
        except KeyError as err:
            print(f"Error while plotting predictions per cell: {err}")
            print(f"Please validate the {measurement_cells=} w.r.t. the grid size.")
            print(f"Available columns are: {df.columns}")
            print(f"{visualization_columns=}")
            raise

        ax.plot(time_true, truth, c="black", label="true", linewidth=0.7, alpha=0.6)
        ax.plot(time_predicted, prediction, c="steelblue", label="predicted", linewidth=1.2)
        ax.axvspan(0, time_predicted[0], color='gray', alpha=0.3, label='state embedding depth')
        for step in visualization_timesteps:
            ax.axvline(step * temporal_resolution, color='black', linestyle='--', alpha=0.5, linewidth=0.7)
        ax.set_xlabel('time in seconds')
        ax.set_ylabel(observable)
        ax.grid()
        cell_name = column.split(')')[0]
        ax.set_title(f"{observable.capitalize()} value in grid cell {cell_name})", wrap=True)

    fig.suptitle(f"Predictions per cell for {observable} as system observable", wrap=True)
    fig.tight_layout()

    # leave room below the last subplot, which carries the shared legend
    fig.subplots_adjust(bottom=0.2)
    left, bottom, _, _ = cell_axes[-1].get_position().bounds  # (left, bottom, width, height)
    cell_axes[-1].legend(loc='upper left',
                         bbox_to_anchor=(left - 0.3, bottom - 1.8, 1, 1),
                         ncol=1)

    fig.savefig(f'./results/{observable}_per_cell_{series_id}_{split_name}.pdf', dpi=200)
    plt.close(fig)


def plot_eigenvalues(model: KIRNN,
                     selection: list | None = None,
                     xlim: tuple[float, float] = (-1.4, 1.4),
                     ylim: tuple[float, float] = (-1.4, 1.4),
                     t_max: int | None = None,
                     compute_pseudospectrum: bool = False,
                     temporal_resolution: float = 0.2,
                     embedding_timeshift: int = 0,
                     scenario: str | None = None) -> None:
    '''
    Visualize the eigenvalues of the Koopman operator of ``model.edmd`` in up to three figures.

    1. ``./results/eigenvalues_unit_circle.pdf``: eigenvalues in the complex plane with the unit circle.
       With ``selection`` (1-based indices) the selected eigenvalues are highlighted; otherwise, with
       ``compute_pseudospectrum``, they are colored by their spectral residual; otherwise all
       eigenvalues are drawn plainly.
    2. ``./results/eigenvalues_2.pdf`` (only if ``t_max`` is given): ``|lambda|**t`` over time, shortened by
       ``embedding_timeshift`` steps and scaled by ``temporal_resolution``.
    3. ``./results/eigenvalues.pdf``: bar chart of ``|lambda|`` sorted by magnitude, labelled with
       1-based eigenvalue indices (consistent with ``selection``).
    '''
    edmd = model.edmd
    evs = edmd.koopman_eigenvalues.to_numpy()

    _plot_eigenvalues_unit_circle(edmd, evs, selection, xlim, ylim, compute_pseudospectrum, scenario)
    if t_max is not None:
        _plot_eigenvalue_evolution(evs, t_max, temporal_resolution, embedding_timeshift)
    _plot_eigenvalue_magnitudes(evs)


def _plot_eigenvalues_unit_circle(edmd, evs, selection, xlim, ylim, compute_pseudospectrum, scenario) -> None:
    '''Scatter the eigenvalues in the complex plane together with the unit circle and save the figure.'''
    fig = plt.figure(figsize=(4, 4))
    ax = fig.add_subplot(111)

    # unit circle for reference
    t = np.linspace(0, 2 * np.pi, 100)
    ax.plot(-np.cos(t), np.sin(t), linewidth=1, zorder=0, color='black')

    im = None  # only set when the eigenvalues are colored by residual, which requires a color bar
    if selection is not None:
        ax.scatter(evs.real, evs.imag, label="eigenvalue", s=3, zorder=5, alpha=0.7, marker='x')
        selected = [s - 1 for s in selection]  # selection is 1-based
        ax.scatter(evs[selected].real, evs[selected].imag, zorder=10, color='orange', label='selected eigenvalue')
    elif compute_pseudospectrum and edmd.dmd_model.compute_pseudospectrum is not None:
        try:
            ev_residuals = edmd.dmd_model.residuals_
        except AttributeError:
            print("Residuals were not stored for the model. They are only stored when removing spectral pollution.")
            print("Calculating them...")
            ev_residuals = edmd.dmd_model._compute_spectral_residuals(
                eigenvalues=evs,
                eigenvectors_right=edmd.dmd_model.eigenvectors_right_)
        im = ax.scatter(np.real(evs),
                        np.imag(evs),
                        s=20, c=np.real(ev_residuals[:len(evs)]),
                        cmap='jet',
                        marker='.',
                        zorder=5,
                        vmax=1,
                        vmin=0)
    else:
        if compute_pseudospectrum:
            print("Pseudospectrum is not available for this model.")
        # without this fallback the figure would only show the unit circle
        ax.scatter(evs.real, evs.imag, label="eigenvalue", s=10, zorder=5, marker='x')

    ax.grid()
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    if scenario is not None:
        ax.set_title(f'{scenario}', size=8)
    else:
        ax.set_title("Eigenvalues in unit circle", wrap=True)
    ax.set_xlabel(r"$Re(\lambda)$")
    ax.set_ylabel(r"$Im(\lambda)$")
    fig.tight_layout()

    if im is not None:
        fig.subplots_adjust(right=0.85)
        fig.set_size_inches(4, 4 * 0.87)
        _, bottom, _, height = ax.get_position().bounds  # (left, bottom, width, height)
        cbar_ax = fig.add_axes((0.9, bottom, 0.01, height))
        fig.colorbar(im, cax=cbar_ax)

    fig.savefig('./results/eigenvalues_unit_circle.pdf', dpi=200)
    plt.close(fig)


def _plot_eigenvalue_evolution(evs, t_max, temporal_resolution, embedding_timeshift) -> None:
    '''Plot |lambda|**t over time for every eigenvalue and save the figure.'''
    fig = plt.figure(figsize=(12, 5))
    ax = fig.add_subplot(111)
    # shorten by the embedding timeshift to have a matching time scale; t = 0 gives |lambda|**0 = 1
    timesteps = np.arange(0, t_max - embedding_timeshift, step=1)
    for j, ev in enumerate(evs, start=1):
        ax.plot(timesteps * temporal_resolution, np.abs(ev ** timesteps),
                linestyle='solid',
                linewidth=0.5,
                label=rf'$\lambda_{{{j}}}$')

    ax.set_title("Koopman eigenvalues over time")
    ax.set_xlabel('time in seconds')
    ax.set_ylabel('Absolute eigenvalue')
    # NOTE(review): this legend is replaced below; it presumably only influences tight_layout -- confirm
    ax.legend()
    ax.grid()
    fig.tight_layout()

    # move the legend to the right of the axis
    fig.subplots_adjust(right=0.65)
    _, bottom, _, height = ax.get_position().bounds  # (left, bottom, width, height)
    ax.legend(loc='upper left',
              bbox_to_anchor=(1.05, bottom + height + 0.2, 0, 0),
              ncol=3)

    fig.savefig('./results/eigenvalues_2.pdf', dpi=200)
    plt.close(fig)


def _plot_eigenvalue_magnitudes(evs) -> None:
    '''Draw a bar chart of the absolute eigenvalues in descending order and save the figure.'''
    fig = plt.figure(figsize=(7, 4))
    ax = fig.add_subplot(111)
    order = np.argsort(np.abs(evs))[::-1]  # indices by descending magnitude
    ax.bar(range(len(evs)), np.abs(evs[order]), color='steelblue', alpha=0.7)
    ax.set_xlabel('Eigenvalue index')
    ax.set_ylabel('Absolute value')
    ax.set_title('Eigenvalues sorted by magnitude')
    ax.set_xticks(range(len(evs)))
    # 1-based labels, matching `selection` and the lambda_j labels of the other plots
    ax.set_xticklabels(order + 1, rotation=70, fontsize=5)
    ax.grid()
    fig.tight_layout()

    fig.savefig('./results/eigenvalues.pdf', dpi=200)
    plt.close(fig)


def _get_koopman_triplet(tsc: TSCDataFrame | None,
                         model: EDMD,
                         observable_type: str,
                         selection: list[int] | None = None,
                         prediction_timeshift: int | None = None,
                         ) -> tuple[np.ndarray, np.ndarray, TSCDataFrame | None, int | None, list[int]]:
    '''
    Get the Koopman triplet (modes, eigenvalues, eigenfunction values) for one system observable.

    :param tsc: time series used to evaluate the eigenfunctions at the initial condition. If ``None``,
        no eigenfunctions are evaluated and ``None`` is returned for them and for the initial state index.
    :param model: fitted EDMD model.
    :param observable_type: pattern selecting the feature rows of ``model.koopman_modes``
        (matched with pandas ``str.contains``).
    :param selection: 1-based indices of the eigenvalues/modes to keep; defaults to the first (up to) 20.
    :param prediction_timeshift: time value at which the initial condition window starts (default 0).
    :return: ``(modes, eigenvalues, eigenfunctions_per_series, initial_state_idx, selection)``, where
        ``modes`` has one row per selected mode and one column per matching feature, and
        ``eigenfunctions_per_series`` has one column ``koop_eigfunc{i}`` per ``i`` in ``selection``.
        ``initial_state_idx`` is the last time value of the initial condition window.
    :raises ValueError: if the model has no Koopman modes or ``selection`` contains an index outside
        ``[1, #eigenvalues]``.
    :raises TypeError: if selecting the initial conditions or evaluating the eigenfunctions does not
        return a ``TSCDataFrame``.
    '''
    modes = model.koopman_modes  # row per feature column, column per mode
    if modes is None:
        raise ValueError(f"No Koopman modes available for {observable_type=}. Aborting...")
    modes = modes[modes.index.to_series().str.contains(observable_type)].to_numpy()  # shape (#cells, #modes)

    eigenvalues = model.koopman_eigenvalues.to_numpy()

    if selection is None:
        selection = list(range(1, min(21, len(eigenvalues) + 1)))
    # reject 0 and out-of-range indices: 0 would silently wrap around to the last mode via index -1
    out_of_range = [v for v in selection if not 1 <= v <= len(eigenvalues)]
    if out_of_range:
        raise ValueError(f"Selection indices are 1-based and must lie in [1, {len(eigenvalues)}], "
                         f"got {out_of_range}.")
    zero_based = [v - 1 for v in selection]
    modes = modes[:, zero_based].T  # shape (len(selection), #grid cells)
    eigenvalues = eigenvalues[zero_based]  # shape (len(selection),)

    if tsc is None:
        return modes, eigenvalues, None, None, selection

    frames_for_initial_condition = model.n_samples_ic_ if model.n_samples_ic_ is not None else 1
    initial_state = 0 if prediction_timeshift is None else prediction_timeshift
    # last time value of the initial condition window; equals initial_state when one frame suffices
    initial_state_idx = initial_state + frames_for_initial_condition - 1

    initial_conditions_per_series = tsc.select_time_values(np.arange(initial_state, initial_state_idx + 1))
    if not isinstance(initial_conditions_per_series, TSCDataFrame):
        raise TypeError("Could not select the initial conditions from the TSCDataFrame.")

    eval_eigenfunctions = model.koopman_eigenfunction(initial_conditions_per_series)
    if not isinstance(eval_eigenfunctions, TSCDataFrame):
        raise TypeError("Could not evaluate the eigenfunctions for the initial conditions.")

    # positional column selection keeps the series ids, so the result is still a TSCDataFrame
    eigenfunctions_per_series = eval_eigenfunctions.iloc[:, zero_based]
    eigenfunctions_per_series.columns = pd.Index([f'koop_eigfunc{i}' for i in selection])

    return modes, eigenvalues, eigenfunctions_per_series, initial_state_idx, selection


def _plot_single_grid(grid_data: np.ndarray,
                      title: str = '',
                      vmin: float = -0.8,
                      vmax: float = 0.8,
                      ax: Axes | None = None,
                      cmap: Colormap | str = CUSTOM_CMAP,
                      grid_resolution: list[float] | tuple[float, float] = (0.2, 0.2),
                      transpose: bool = False,
                      remove_axis_labels: bool = False,
                      ) -> tuple[Axes, AxesImage | Quiver]:
    '''
    Draw a 2d grid of scalar values as an image and return ``(ax, image)``.

    A 3d array is accepted, but only its first slice ``grid_data[0]`` is drawn; vector fields are not rendered.
    Tick labels show the physical distance (cell index times ``grid_resolution``) roughly every 2 length
    units. The x-axis is placed on top and the y-axis is inverted, so the first row is drawn at the top.

    :param grid_data: values per cell, shape ``(rows, cols)`` or ``(k, rows, cols)``.
    :param title: axis title.
    :param vmin: lower limit of the color scale.
    :param vmax: upper limit of the color scale.
    :param ax: axis to draw on; defaults to the current axis.
    :param cmap: colormap of the image.
    :param grid_resolution: cell size ``(x, y)``; used for the aspect ratio and the tick labels.
    :param transpose: transpose the grid before plotting.
    :param remove_axis_labels: hide the tick labels of both axes.
    '''
    ax = ax if ax is not None else plt.gca()
    resolution_x, resolution_y = grid_resolution

    # one tick roughly every 2 length units, but never less than one cell apart
    tick_interval_x = max(1, int(2 / resolution_x))
    tick_interval_y = max(1, int(2 / resolution_y))

    grid_data = grid_data[0] if grid_data.ndim == 3 else grid_data
    if transpose:
        grid_data = grid_data.T
    y_length, x_length = grid_data.shape

    im = ax.imshow(grid_data,
                   cmap=cmap,
                   origin='lower',  # origin is the lower left corner, where (x, y) = (0, 0)
                   extent=(0, x_length, 0, y_length),
                   aspect=resolution_y / resolution_x,  # aspect ratio of the grid
                   vmin=vmin,
                   vmax=vmax)

    ax.set_title(title)
    x_ticks = np.arange(0, x_length + 1, tick_interval_x)
    y_ticks = np.arange(0, y_length + 1, tick_interval_y)
    # label ticks with the physical distance; ':g' avoids int() truncation such as 6 * 0.3 = 1.7999... -> "1"
    ax.set_xticks(ticks=x_ticks, labels=[f"{v:g}" for v in x_ticks * resolution_x])
    ax.set_yticks(ticks=y_ticks, labels=[f"{v:g}" for v in y_ticks * resolution_y])

    if remove_axis_labels:
        ax.tick_params(labelbottom=False, labelleft=False)  # remove x-axis and y-axis labels

    ax.xaxis.tick_top()  # move the x-axis to the top of the grid
    ax.invert_yaxis()  # flip up/down, including the tick labels

    return ax, im


def _plot_mode_row(vmin: float | None,
                   vmax: float | None,
                   grid_resolution,
                   axs,
                   grid_real,
                   grid_imag,
                   mode_name,
                   include_negative=False,
                   transpose: bool = False,
                   vector_scaling: float = 0.5, ):
    '''
    Plot one Koopman mode as a row of grids and return the axis and image of the real-part panel.

    Panels: real part on ``axs[0]``, imaginary part on ``axs[1]`` (if ``grid_imag`` is given) and, with
    ``include_negative`` and at least four axes, the negated real/imaginary parts on ``axs[2]``/``axs[3]``.
    All panels share the symmetric color range ``[-m, m]`` with ``m = max(|vmin|, |vmax|)``; a missing
    ``vmin``/``vmax`` is replaced by the minimum/maximum over the real (and imaginary) part.
    ``vector_scaling`` is currently unused.
    '''
    parts = [grid_real] if grid_imag is None else [grid_real, grid_imag]
    grid_max = vmax if vmax is not None else max(part.max() for part in parts)
    grid_min = vmin if vmin is not None else min(part.min() for part in parts)
    max_abs = max(abs(grid_max), abs(grid_min))  # this makes the color scale symmetric for pairs of modes

    cmap = plt.get_cmap('Spectral')

    def draw(data, title, ax):
        # every panel of the row uses the same symmetric color range and colormap
        return _plot_single_grid(data, title,
                                 vmin=-max_abs,
                                 vmax=max_abs,
                                 ax=ax,
                                 cmap=cmap,
                                 grid_resolution=grid_resolution,
                                 transpose=transpose)

    ax1, cax1 = draw(grid_real, f'{mode_name} (real)', axs[0])
    if grid_imag is not None:
        draw(grid_imag, f'{mode_name} (imag)', axs[1])

    if include_negative and len(axs) >= 4:
        draw(-grid_real, f'- {mode_name} (real)', axs[2])
        if grid_imag is not None:
            draw(-grid_imag, f'- {mode_name} (imag)', axs[3])

    return ax1, cax1


def plot_modes(model: KIRNN,
               tsc: TSCDataFrame | None,
               observable_type: str,
               prediction_timeshift: int,
               grid_shape: list[int],
               selection: list[int] | None = None,
               vmin: float | None = None,
               vmax: float | None = None,
               grid_resolution: list[float] | tuple[float, float] = (0.2, 0.2),
               transpose: bool = False,
               vector_scaling: float = 0.5,
               ):
    '''
    Visualize the Koopman modes of ``model.edmd`` for the given system observable.

    Each selected mode gets one row of four grids (real part, imaginary part and their negations) plus a
    color bar. The figure is saved to ``./results/{observable_type}_modes.pdf`` and then closed; the font
    size of 18 is applied to this figure only.

    :param grid_shape: ``(rows, cols)`` of the spatial grid each mode is reshaped to.
    :param selection: 1-based mode indices, forwarded to ``_get_koopman_triplet``.
    :param vmin: forwarded to ``_plot_mode_row``.
    :param vmax: forwarded to ``_plot_mode_row``.
    :return: ``(selection, modes, eigenvalues, eigenfunctions_per_series, initial_state_idx)`` as returned
        by ``_get_koopman_triplet``.
    :raises ValueError: if a mode cannot be reshaped to ``grid_shape``.
    '''
    print(f"Visualizing Koopman modes for {observable_type=}")

    # the eigenfunction values are series specific; they are not plotted, only returned
    modes, eigenvalues, eigenfunctions_per_series, initial_state_idx, selection = _get_koopman_triplet(
        tsc=tsc,
        model=model.edmd,
        observable_type=observable_type,
        selection=selection,
        prediction_timeshift=prediction_timeshift,
    )

    n_rows, n_cols = grid_shape[0], grid_shape[1]
    figsize = (4 * n_cols * grid_resolution[0] * 0.3,
               len(selection) * n_rows * grid_resolution[1] * 0.4)

    # scope the large font to this figure instead of changing the global rcParams for all later plots
    with matplotlib.rc_context({'font.size': 18}):
        # squeeze=False keeps axs two-dimensional (one row per mode) even for a single mode
        fig, axs = plt.subplots(len(selection), 4, figsize=figsize, squeeze=False)

        images = []
        for row_axes, j, mode, ev in zip(axs, selection, modes, eigenvalues):
            try:
                grid_real = mode.real.reshape(-1, n_rows, n_cols)
                grid_imag = mode.imag.reshape(-1, n_rows, n_cols)
            except ValueError as err:
                raise ValueError(f'Cannot reshape array of size {mode.shape} into {(n_rows, n_cols)}.') from err

            ax1, cax1 = _plot_mode_row(vmin=vmin,
                                       vmax=vmax,
                                       grid_resolution=grid_resolution,
                                       axs=row_axes,
                                       grid_real=grid_real,
                                       grid_imag=grid_imag,
                                       mode_name=f'$v_{{{j}}}$',
                                       include_negative=True,
                                       transpose=transpose,
                                       vector_scaling=vector_scaling, )

            sign = '-' if ev.imag < 0 else '+'
            ax1.set_ylabel(rf'$\lambda_{{{j}}}$ = {ev.real:.2f} {sign} {abs(ev.imag):.2f}i')
            images.append(cax1)  # one image per row, used for that row's color bar

        fig.tight_layout()

        # add one color bar for each row
        fig.subplots_adjust(right=0.82, top=0.95)
        fig.suptitle(f'{observable_type.capitalize()} Koopman modes', wrap=True)
        for row_axes, im in zip(axs, images):
            _, bottom, _, height = row_axes[0].get_position().bounds  # (left, bottom, width, height)
            cbar_ax = fig.add_axes((0.85, bottom, 0.01, height))
            if isinstance(im, Quiver):
                im.set_clim(0.0, 1.0)
                cbar = fig.colorbar(im, cax=cbar_ax, ticks=[0.0, 0.5, 1.0])  # angular coloring
                cbar.ax.set_yticklabels(['$- \\pi$', '0', '$\\pi$'])
            else:
                fig.colorbar(im, cax=cbar_ax)

        fig.savefig(f'./results/{observable_type}_modes.pdf', dpi=200)
    plt.close(fig)

    return selection, modes, eigenvalues, eigenfunctions_per_series, initial_state_idx
