import copy
import csv
import os
import re

import matplotlib.cm as cm
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np

from src.utils.evaluation_functions import calc_KE

def sidebyside_comp():
    """
    Placeholder for a side-by-side True vs. Rollout comparison plot.

    NOTE(review): unfinished. It only creates an empty 4x2 grid of axes with
    zero spacing and returns None; nothing is drawn yet.
    """
    grid_spacing = {'hspace': 0, 'wspace': 0}
    _fig, _axes = plt.subplots(4, 2, figsize=(18, 16), gridspec_kw=grid_spacing)
    plt.subplots_adjust(hspace=0, wspace=0)

    images = []  # intended to collect imshow handles once plotting is implemented

    return None


# def plot_residual_divergence_over_time(roll_list, true_list, x_grid, y_grid, nr_steps, output_path=None):
#     """
#     Plots the evolution of the divergence residual of rollout compared to ground truth.
#     """
#     fig = plt.figure()
#     endtime = np.shape(roll_list)[0]
#     time_indices = np.linspace(0, endtime, nr_steps, endpoint=False, dtype=int)
    
#     rms_residuals_roll_list = []
#     rms_residuals_true_list = []

#     for time in time_indices:

#         divergence_residual, rms_residual_divergence = calc_divergence_residual(x_grid, y_grid, roll_list[time, :, :, 0], roll_list[time, :, :, 1])
#         rms_residuals_roll_list.append(rms_residual_divergence)

#         divergence_residual, rms_residual_divergence = calc_divergence_residual(x_grid, y_grid, true_list[time, :, :, 0], true_list[time, :, :, 1])
#         rms_residuals_true_list.append(rms_residual_divergence)

#     plt.plot(time_indices, rms_residuals_roll_list, color='tab:blue', label="Rollout")
#     plt.plot(time_indices, rms_residuals_true_list, color='tab:red', label="Ground Truth")
#     plt.legend()
#     plt.xlabel('Prediction Step')
#     plt.ylabel('Residual of Divergence')
#     plt.title('Evolution of Residual of Divergence over Rollout')

#     if output_path is not None:
#         fig.savefig(output_path, dpi=600, bbox_inches='tight')
#         print(f"Saved figure to: \n{output_path}")

def plot_residual_divergence_over_time(roll_list, true_list, x_grid, y_grid, nr_steps, output_path=None):
    """
    Plot the RMS divergence residual of rollout and ground truth over prediction steps.

    Supported layouts (``roll_list`` and ``true_list`` must use the same one):
        - (timesteps, H, W, C)        a single run
        - (runs, timesteps, H, W, C)  several runs, one pair of curves per run
    Channels 0 and 1 of the last axis are passed to ``calc_divergence_residual``
    together with ``x_grid``/``y_grid``; its second return value is plotted.

    Args:
        roll_list: rollout fields (array-like) in one of the layouts above.
        true_list: ground-truth fields in the same layout.
        x_grid: forwarded unchanged to ``calc_divergence_residual``.
        y_grid: forwarded unchanged to ``calc_divergence_residual``.
        nr_steps: number of evenly spaced time indices (starting at 0) to evaluate.
        output_path: if not None, the figure is saved there at 600 dpi.

    Raises:
        ValueError: if the inputs are not 4-D or 5-D, or their ranks differ.
    """
    # NOTE(review): ``calc_divergence_residual`` is not imported in this module (only
    # ``calc_KE`` is). Confirm where it is defined and import it; otherwise calling
    # this function raises NameError.
    roll_list = np.asarray(roll_list)
    true_list = np.asarray(true_list)

    if roll_list.ndim not in (4, 5):
        raise ValueError(
            "Expected (timesteps, H, W, C) or (runs, timesteps, H, W, C) input, "
            f"got shape {roll_list.shape}."
        )
    if true_list.ndim != roll_list.ndim:
        raise ValueError(
            f"Rollout shape {roll_list.shape} and ground-truth shape {true_list.shape} "
            "must have the same number of dimensions."
        )

    if roll_list.ndim == 4:
        # Single run: add a leading runs axis so both layouts share one code path.
        roll_list = roll_list[None, ...]
        true_list = true_list[None, ...]

    n_runs, n_timesteps = roll_list.shape[:2]
    time_indices = np.linspace(0, n_timesteps, nr_steps, endpoint=False, dtype=int)

    fig, ax = plt.subplots()

    for run in range(n_runs):
        rms_residuals_roll = []
        rms_residuals_true = []

        for t in time_indices:
            _, rms_roll = calc_divergence_residual(
                x_grid, y_grid, roll_list[run, t, :, :, 0], roll_list[run, t, :, :, 1]
            )
            rms_residuals_roll.append(rms_roll)

            _, rms_true = calc_divergence_residual(
                x_grid, y_grid, true_list[run, t, :, :, 0], true_list[run, t, :, :, 1]
            )
            rms_residuals_true.append(rms_true)

        # Only the first run is labelled so the legend has one entry per series.
        ax.plot(time_indices, rms_residuals_roll, color='tab:blue',
                label="Rollout" if run == 0 else "_nolegend_")
        ax.plot(time_indices, rms_residuals_true, color='tab:red',
                label="Ground Truth" if run == 0 else "_nolegend_")

    ax.legend()
    ax.set_xlabel('Prediction Step')
    ax.set_ylabel('Residual of Divergence')
    ax.set_title('Evolution of Residual of Divergence over Rollout')

    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: \n{output_path}")

    plt.show()



def plot_histogramm_over_time(roll_list, true_list, nr_steps, output_path=None, bin_width=0.00075):
    """
    Plot log-scale histograms of the rollout's kinetic energy at evenly spaced snapshots.

    Kinetic energy per entry is ``0.5 * (u**2 + v**2)`` with ``u``/``v`` read as
    ``fields[t, :, 0]`` and ``fields[t, :, 1]``. Only the rollout is drawn; the
    ground truth contributes to the common bin range, so the bins match those of
    ground-truth histograms built the same way.

    Args:
        roll_list: rollout snapshots, indexable as ``[time, :, channel]``.
        true_list: ground-truth snapshots with the same indexing.
        nr_steps: number of evenly spaced snapshots (starting at 0) to draw.
        output_path: if not None, the figure is saved there at 600 dpi.
        bin_width: width of the histogram bins (default 0.00075).
    """
    endtime = np.shape(roll_list)[0]
    time_indices = np.linspace(0, endtime, nr_steps, endpoint=False, dtype=int)
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
    cmap = plt.get_cmap('Reds', len(time_indices))

    def kinetic_energy(fields, t):
        return 0.5 * (fields[t, :, 0] ** 2 + fields[t, :, 1] ** 2)

    ke_roll = [kinetic_energy(roll_list, t) for t in time_indices]
    ke_true = [kinetic_energy(true_list, t) for t in time_indices]
    all_ke_roll = np.concatenate(ke_roll)
    all_ke_true = np.concatenate(ke_true)
    ke_min = min(np.nanmin(all_ke_roll), np.nanmin(all_ke_true))
    ke_max = max(np.nanmax(all_ke_roll), np.nanmax(all_ke_true))
    bins = np.arange(ke_min, ke_max + bin_width, bin_width)

    fig = plt.figure(figsize=(8, 6))
    for i, (t, ke) in enumerate(zip(time_indices, ke_roll)):
        plt.hist(
            ke,
            bins=bins,
            color=cmap(i),
            label=f't={t}',
            log=True,
            edgecolor='black',
            linewidth=0.5,
            histtype='stepfilled')

    plt.legend()
    # 0.5 * v**2 of a velocity in m/s has units of m²/s².
    plt.xlabel('Kinetic Energy (m²/s²)')
    plt.title('Histogram of Kinetic Energy Over Rollout')

    # Save before show: with GUI backends show() blocks and window resizes would
    # otherwise change the saved figure.
    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: \n{output_path}")

    plt.show()



def plot_histogramm_over_time_comp(roll_list, true_list, nr_steps, output_path=None, bin_width=0.00075):
    """
    Plot kinetic-energy histograms of rollout (left) and ground truth (right) over time.

    For each of ``nr_steps`` evenly spaced snapshots, kinetic energy per entry is
    ``0.5 * (u**2 + v**2)`` with ``u``/``v`` read as ``fields[t, :, 0]`` and
    ``fields[t, :, 1]``. Both panels share the same bins and y axis (log scale).

    Args:
        roll_list: rollout snapshots, indexable as ``[time, :, channel]``.
        true_list: ground-truth snapshots with the same indexing.
        nr_steps: number of evenly spaced snapshots (starting at 0) to draw.
        output_path: if not None, the figure is saved there at 600 dpi.
        bin_width: width of the histogram bins (default 0.00075).
    """
    endtime = np.shape(roll_list)[0]
    # Integer indices are required for array indexing; the former float-step
    # np.arange produced float indices, which numpy rejects with IndexError.
    time_indices = np.linspace(0, endtime, nr_steps, endpoint=False, dtype=int)
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
    cmap = plt.get_cmap('Reds', len(time_indices))

    def kinetic_energy(fields, t):
        return 0.5 * (fields[t, :, 0] ** 2 + fields[t, :, 1] ** 2)

    ke_roll = [kinetic_energy(roll_list, t) for t in time_indices]
    ke_true = [kinetic_energy(true_list, t) for t in time_indices]
    all_ke_roll = np.concatenate(ke_roll)
    all_ke_true = np.concatenate(ke_true)
    ke_min = min(np.nanmin(all_ke_roll), np.nanmin(all_ke_true))
    ke_max = max(np.nanmax(all_ke_roll), np.nanmax(all_ke_true))
    bins = np.arange(ke_min, ke_max + bin_width, bin_width)

    fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)

    panels = ((axes[0], ke_roll, 'Rollout'), (axes[1], ke_true, 'Ground Truth'))
    for ax, ke_snapshots, panel_title in panels:
        for i, (t, ke) in enumerate(zip(time_indices, ke_snapshots)):
            ax.hist(
                ke,
                bins=bins,
                color=cmap(i),
                label=f't={t}',
                log=True,
                edgecolor='black',
                linewidth=0.5,
                histtype='stepfilled')
        ax.legend()
        # 0.5 * v**2 of a velocity in m/s has units of m²/s².
        ax.set_xlabel('Kinetic Energy (m²/s²)')
        ax.set_title(panel_title)

    fig.suptitle('Histogram of Kinetic Energy Over Rollout')
    fig.tight_layout()

    # Save before show: with GUI backends show() blocks and window resizes would
    # otherwise change the saved figure.
    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: \n{output_path}")

    plt.show()


def plot_ke_histograms(arr1, arr2, bins=50, output_path=None, labels=("Array 1", "Array 2")):
    """
    Plot two overlapping, semi-transparent histograms on the same axes.

    Args:
        arr1: first array (plotted in red).
        arr2: second array (plotted in blue).
        bins: number of bins or explicit bin edges, passed to ``plt.hist`` (default 50).
        output_path: if not None, the figure is saved there at 600 dpi.
        labels: legend labels for ``arr1`` and ``arr2`` (default "Array 1", "Array 2").
    """
    label1, label2 = labels

    fig = plt.figure(figsize=(8, 6))
    plt.hist(arr1, bins=bins, color="red", alpha=0.5, label=label1)
    plt.hist(arr2, bins=bins, color="blue", alpha=0.5, label=label2)

    plt.xlabel("Value")
    plt.ylabel("Frequency")
    plt.title("Comparison of Two Histograms")
    plt.legend()
    plt.grid(True, linestyle="--", linewidth=0.5)

    # Save before show: with GUI backends show() blocks and window resizes would
    # otherwise change the saved figure.
    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: \n{output_path}")

    plt.show()


def plot_mean_ke_spectra_over_wavenumber(E_spectrum_roll, k_center_roll, E_spectrum_true, k_center_true, output_path=None):
    """
    Plot one prediction and one ground-truth energy spectrum E(k) on log-log axes.

    Args:
        E_spectrum_roll: prediction spectrum values, plotted against ``k_center_roll``.
        k_center_roll: wavenumbers for the prediction spectrum.
        E_spectrum_true: ground-truth spectrum values, plotted against ``k_center_true``.
        k_center_true: wavenumbers for the ground-truth spectrum.
        output_path: if not None, the figure is saved there at 600 dpi.
    """
    fig = plt.figure()

    plt.loglog(k_center_roll, E_spectrum_roll, color='tab:blue', label='Prediction')
    plt.loglog(k_center_true, E_spectrum_true, color='tab:red', label='Ground Truth')

    plt.legend()
    plt.ylabel('E(k)')
    plt.xlabel('k')
    plt.title('Mean Kinetic Energy Spectra during Prediction')

    # Save before show: with GUI backends show() blocks and window resizes would
    # otherwise change the saved figure.
    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: \n{output_path}")

    plt.show()



def plot_ke_spectra_over_wavenumber(E_spectrum_roll, k_center_roll, E_spectrum_true, k_center_true, output_path=None):
    """
    Plot a sequence of rollout and ground-truth energy spectra on log-log axes.

    Spectrum ``i`` of each series is drawn with increasing colour depth and opacity
    (alpha ramps from 0.2 for the first to 1.0 for the last), blue for the rollout and
    red for the ground truth. Only one curve per series (index ``n // 3 * 2``) carries a
    legend label.

    Args:
        E_spectrum_roll: array whose first axis indexes the rollout spectra.
        k_center_roll: wavenumbers for the rollout spectra.
        E_spectrum_true: ground-truth spectra, indexed like ``E_spectrum_roll``.
        k_center_true: wavenumbers for the ground-truth spectra.
        output_path: if not None, the figure is saved there at 600 dpi.
    """
    fig = plt.figure()
    n = E_spectrum_roll.shape[0]
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
    # Two extra LUT entries plus the ``i + 2`` offset below skip the two lightest shades.
    cmap_blue = plt.get_cmap('Blues', n + 2)
    cmap_red = plt.get_cmap('Reds', n + 2)
    alphas = np.linspace(1.0, 0.2, n)[::-1]
    labelled_index = n // 3 * 2

    for i in range(n):
        is_labelled = i == labelled_index
        plt.loglog(
            k_center_roll,
            E_spectrum_roll[i],
            color=cmap_blue(i + 2),
            alpha=alphas[i],
            label='Rollout' if is_labelled else None)
        plt.loglog(
            k_center_true,
            E_spectrum_true[i],
            color=cmap_red(i + 2),
            alpha=alphas[i],
            label='Ground Truth' if is_labelled else None)

    plt.legend()
    plt.ylabel('E(k)')
    plt.xlabel('k')
    plt.title('Evolution of Kinetic Energy Spectra during Rollout')

    # Save before show: with GUI backends show() blocks and window resizes would
    # otherwise change the saved figure.
    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: \n{output_path}")

    plt.show()



def plot_momentum_conservation(time_rollout, Px_roll, Px_true, Py_roll, Py_true, Ke_roll, Ke_true, output_paths=None):
    """
    Create three line plots comparing rollout and ground truth over time.

    The figures show, in order, total momentum in x, total momentum in y and
    total kinetic energy. Ground truth is drawn in red, rollout in blue.

    Args:
        time_rollout: x values shared by all three plots.
        Px_roll, Px_true: rollout / ground-truth total momentum in x.
        Py_roll, Py_true: rollout / ground-truth total momentum in y.
        Ke_roll, Ke_true: rollout / ground-truth total kinetic energy.
        output_paths: optional sequence of three file paths, used in the order
            (x-momentum figure, y-momentum figure, kinetic-energy figure).
    """
    panels = (
        (Px_roll, Px_true, 'Total Momentum in X (m³/s)', 'Total Momentum in X during Rollout'),
        (Py_roll, Py_true, 'Total Momentum in Y (m³/s)', 'Total Momentum in Y during Rollout'),
        (Ke_roll, Ke_true, 'Total Kinetic Energy (m⁴/s²)', 'Total Kinetic Energy during Rollout'),
    )

    for panel_idx, (roll, true, ylabel, title) in enumerate(panels):
        fig = plt.figure()
        plt.plot(time_rollout, true, label='Ground Truth', color='tab:red')
        plt.plot(time_rollout, roll, label='Rollout', color='tab:blue')
        plt.ylabel(ylabel)
        plt.xlabel('Prediction Step')
        plt.title(title)
        plt.legend()
        if output_paths is not None:
            output_path = output_paths[panel_idx]
            fig.savefig(output_path, dpi=600, bbox_inches='tight')
            print(f"Saved figure to: \n{output_path}")


def plot_comp_diff(pred, true, title, probe_idcs=None, output_path=None, return_fig=False):
    """
    Plot prediction, ground truth and their difference as three stacked magnitude images.

    Each panel shows the L2 norm over the last axis. ``pred`` and ``true`` must have
    the same shape and must reduce to 2-D images (presumably (H, W, C); TODO confirm).

    Args:
        pred: predicted field.
        true: ground-truth field with the same shape as ``pred``.
        title: figure suptitle.
        probe_idcs: optional iterable of (y, x) pixel indices, marked with red crosses
            on every panel.
        output_path: if not None, the figure is saved there at 600 dpi.
        return_fig: if True, the Figure is returned.

    Returns:
        The Figure when ``return_fig`` is True, otherwise None.
    """
    diff = pred - true
    pred_mag = np.linalg.norm(pred, axis=-1)
    true_mag = np.linalg.norm(true, axis=-1)
    diff_mag = np.linalg.norm(diff, axis=-1)

    fig, axes = plt.subplots(3, 1, figsize=(20, 10), gridspec_kw={'wspace': 0, 'hspace': 0.1})

    # Copy before set_bad so the globally shared viridis colormap is not modified.
    cmap_viridis = copy.copy(plt.cm.viridis)
    cmap_viridis.set_bad(color='white', alpha=0)  # masked/invalid pixels become transparent
    cmap_coolwarm = plt.cm.coolwarm

    # Limits come from the plotted magnitudes. The largest single component can be
    # smaller than the largest magnitude, which clipped the top of the colour scale;
    # nanmax also tolerates NaN pixels, which np.max propagated into the limits.
    vmax = max(np.nanmax(pred_mag), np.nanmax(true_mag))
    norm = mcolors.Normalize(vmin=0, vmax=vmax)
    vmax_diff = np.nanmax(diff_mag)
    norm_diff = mcolors.Normalize(vmin=-vmax_diff, vmax=vmax_diff)

    panels = (
        (pred_mag, cmap_viridis, norm),
        (true_mag, cmap_viridis, norm),
        (diff_mag, cmap_coolwarm, norm_diff),
    )
    ims = []
    for ax, (image, cmap, image_norm) in zip(axes, panels):
        ims.append(ax.imshow(image, cmap=cmap, origin='lower', norm=image_norm))
        ax.axis('off')
        if probe_idcs is not None:
            for (y, x) in probe_idcs:
                ax.plot(x, y, 'rx', markersize=4)  # red little crosses

    cbar_top = fig.colorbar(
        ims[0], ax=axes, orientation='horizontal', location='top', fraction=0.03, pad=0.04, aspect=25)
    cbar_top.ax.tick_params(labelsize=10)
    cbar_top.set_label('Velocity Magnitude (m/s)', fontsize=12)

    cbar_bottom = fig.colorbar(
        ims[2], ax=axes, orientation='horizontal', location='bottom', fraction=0.03, pad=0.04, aspect=25)
    cbar_bottom.ax.tick_params(labelsize=10)
    cbar_bottom.set_label('Difference (m/s)', fontsize=12)

    fig.suptitle(title, fontsize=14, y=0.95)

    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: \n{output_path}")

    if return_fig:
        return fig


def plot_sequence_movie(sequ, vmax, output_dir=None):
    """
    Render each frame of a sequence as a magnitude image, one figure per frame.

    ``sequ`` is squeezed, the L2 norm over its last axis gives the magnitudes, and
    frames are taken along the first remaining axis.
    NOTE(review): ``np.squeeze`` also drops a length-1 time axis, so a single-frame
    sequence would be iterated over its rows instead. Confirm callers never pass T=1.

    Args:
        sequ: sequence of vector fields (presumably (T, H, W, C) after squeezing; TODO confirm).
        vmax: upper limit of the colour scale (the lower limit is 0).
        output_dir: if not None, frame ``i`` is written to ``image_{i:04d}.png`` there
            at 300 dpi and its figure is closed. If None, nothing is saved and the
            figures are left open.
    """
    sequ = np.squeeze(sequ)
    sequ_mag = np.linalg.norm(sequ, axis=-1)
    norm = mcolors.Normalize(vmin=0, vmax=vmax)
    cmap = plt.cm.viridis
    n_frames = sequ.shape[0]

    for i in range(n_frames):
        fig, ax = plt.subplots(figsize=(5, 5))
        im = ax.imshow(sequ_mag[i], cmap=cmap, origin='lower', norm=norm)
        ax.axis('off')

        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label("Magnitude", fontsize=10)

        plt.tight_layout()

        if output_dir is not None:
            output_path = os.path.join(output_dir, f'image_{i:04d}.png')
            fig.savefig(output_path, dpi=300, bbox_inches='tight', pad_inches=0)
            plt.close(fig)

    # Only report a save when something was actually written.
    if output_dir is not None:
        print(f"Saved {n_frames} images to: \n{output_dir}")


def plot_two_sequences_movie(pred_sequ, true_sequ, probe_idcs=None, output_dir=None, n_context=16):
    """
    Plot prediction, truth and their magnitude difference side by side for each timestep.

    Magnitudes are L2 norms over the last axis; the difference panel shows
    ``|pred| - |true|`` on a diverging colour scale symmetric around zero.

    Args:
        pred_sequ (ndarray): predicted sequence; frames along the first axis, and the
            norm over the last axis must leave 2-D images (presumably (T, H, W, C)).
        true_sequ (ndarray): ground-truth sequence with the same length as ``pred_sequ``.
        probe_idcs: optional iterable of (y, x) pixel indices, marked with red crosses on
            the prediction and truth panels of the first ``n_context`` frames.
        output_dir (str): if not None, frame ``i`` is written to ``image_{i:04d}.png``
            there at 300 dpi and its figure is closed. If None, nothing is saved and
            the figures stay open.
        n_context (int): number of leading frames whose first panel is titled
            "Reconstruction"; later frames are titled "Prediction". Defaults to 16.
    """
    assert pred_sequ.shape[0] == true_sequ.shape[0], "Both sequences must have the same length."

    pred_mag = np.linalg.norm(pred_sequ, axis=-1)
    true_mag = np.linalg.norm(true_sequ, axis=-1)
    diff = pred_mag - true_mag

    # Scale from the plotted magnitudes (a raw component maximum can be smaller,
    # which clipped the images).
    vmax = max(np.nanmax(pred_mag), np.nanmax(true_mag))
    norm = mcolors.Normalize(vmin=0, vmax=vmax)
    cmap = plt.cm.viridis

    # Symmetric limits for the diverging map. The former |min| - max could be zero
    # or negative, which produced an empty or inverted colour range.
    max_diff = np.nanmax(np.abs(diff))
    norm_diff = mcolors.Normalize(vmin=-max_diff, vmax=max_diff)
    cmap_diff = plt.cm.coolwarm

    n_frames = pred_sequ.shape[0]
    for i in range(n_frames):
        is_context = i < n_context
        fig, axes = plt.subplots(1, 3, figsize=(12, 5))

        axes[0].imshow(pred_mag[i], cmap=cmap, origin='lower', norm=norm)
        axes[0].set_title("Reconstruction" if is_context else "Prediction")

        im_true = axes[1].imshow(true_mag[i], cmap=cmap, origin='lower', norm=norm)
        axes[1].set_title("Truth")

        axes[2].imshow(diff[i], cmap=cmap_diff, origin='lower', norm=norm_diff)
        axes[2].set_title("Difference")

        for ax in axes:
            ax.axis('off')

        # Shared magnitude colorbar (the difference panel has none).
        cbar = fig.colorbar(im_true, ax=axes, orientation='horizontal', fraction=0.046, pad=0.04)
        cbar.set_label("Magnitude", fontsize=10)

        if probe_idcs is not None and is_context:
            for (y, x) in probe_idcs:
                axes[0].plot(x, y, 'rx', markersize=2)  # red little crosses
                axes[1].plot(x, y, 'rx', markersize=2)

        if output_dir is not None:
            output_path = os.path.join(output_dir, f'image_{i:04d}.png')
            fig.savefig(output_path, dpi=300, bbox_inches='tight', pad_inches=0)
            plt.close(fig)

    # Only report a save when something was actually written.
    if output_dir is not None:
        print(f"Saved {n_frames} side-by-side images to: \n{output_dir}")



def plot_comp_diff_history(pred, true, n_plots=None, title=None, output_path=None, return_fig=False):
    """
    Plot selected frames of prediction, ground truth and their difference in a 3-row grid.

    Column ``i`` shows frame ``n_plots[i]``. Its rows are the prediction, the ground
    truth and the difference, each as the L2 norm over the last axis. Every column
    gets its own velocity colorbar on top and difference colorbar at the bottom, and
    all columns share the same colour limits.

    Args:
        pred: predicted frames indexed along the first axis.
        true: ground-truth frames with the same shape as ``pred``.
        n_plots: frame indices to show, defaulting to [0]. An empty selection also
            falls back to [0]. If the largest index is out of range, a warning is
            printed and [0] is used.
        title: optional figure suptitle.
        output_path: if not None, the figure is saved there at 600 dpi.
        return_fig: if True, the Figure is returned.

    Returns:
        The Figure when ``return_fig`` is True, otherwise None.
    """
    # None replaces the former mutable default ``[0]``; an empty list used to crash in max().
    n_plots = [0] if n_plots is None else list(n_plots)
    if not n_plots:
        n_plots = [0]
    elif max(n_plots) >= pred.shape[0]:
        print(f"Warning: n_plots contains index {max(n_plots)} but pred has only {pred.shape[0]} frames. Resetting to [0].")
        n_plots = [0]

    diff = pred - true
    pred_mag = np.linalg.norm(pred, axis=-1)
    true_mag = np.linalg.norm(true, axis=-1)
    diff_mag = np.linalg.norm(diff, axis=-1)

    n_cols = len(n_plots)
    # squeeze=False keeps ``axes`` 2-D even for a single column.
    fig, axes = plt.subplots(3, n_cols, figsize=(4 * n_cols, 12),
                             gridspec_kw={'wspace': 0, 'hspace': 0.1}, squeeze=False)

    # Copy before set_bad so the globally shared viridis colormap is not modified.
    cmap_viridis = copy.copy(plt.cm.viridis)
    cmap_viridis.set_bad(color='white', alpha=0)  # masked/invalid pixels become transparent
    cmap_coolwarm = plt.cm.coolwarm

    # Limits come from the plotted magnitudes, not the raw components, so the images are not clipped.
    vmax = max(np.nanmax(pred_mag), np.nanmax(true_mag))
    norm = mcolors.Normalize(vmin=0, vmax=vmax)
    vmax_diff = np.nanmax(diff_mag)
    norm_diff = mcolors.Normalize(vmin=-vmax_diff, vmax=vmax_diff)

    for i, idx in enumerate(n_plots):
        im_pred = axes[0, i].imshow(pred_mag[idx], cmap=cmap_viridis, origin='lower', norm=norm)
        axes[1, i].imshow(true_mag[idx], cmap=cmap_viridis, origin='lower', norm=norm)
        im_diff = axes[2, i].imshow(diff_mag[idx], cmap=cmap_coolwarm, origin='lower', norm=norm_diff)
        # NOTE(review): hard-coded red dot at (x=0, y=64) on the prediction panel; it looks
        # like a leftover probe/debug marker. Confirm it is still wanted.
        axes[0, i].plot(0, 64, 'ro', markersize=5)

        for row, prefix in enumerate(("Pred", "True", "Diff")):
            axes[row, i].axis('off')
            axes[row, i].set_title(f"{prefix} frame {idx}", fontsize=12)

        # Top colorbar spanning the whole column.
        cbar_top = fig.colorbar(
            im_pred, ax=axes[:, i], orientation='horizontal', location='top',
            fraction=0.01, pad=0.04, aspect=25
        )
        cbar_top.ax.tick_params(labelsize=8)
        cbar_top.set_label('Velocity Magnitude (m/s)', fontsize=10)

        # Bottom colorbar spanning the whole column.
        cbar_bottom = fig.colorbar(
            im_diff, ax=axes[:, i], orientation='horizontal', location='bottom',
            fraction=0.01, pad=0.04, aspect=25
        )
        cbar_bottom.ax.tick_params(labelsize=8)
        cbar_bottom.set_label('Difference (m/s)', fontsize=10)

    if title is not None:
        fig.suptitle(title, fontsize=14, y=0.95)

    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: \n{output_path}")

    if return_fig:
        return fig


def plot_comp_evolution(roll_norm_list, true_norm_list, time_indices, probe_idcs=None, output_path=None):
    """
    Plot rollout (left) and ground-truth (right) magnitude images for selected times.

    Both inputs are squeezed. Row ``i`` shows frame ``time_indices[i] - 1`` of each as
    the L2 norm over the last axis, and all panels share one colour scale starting at 0.
    NOTE(review): the ``t - 1`` offset treats the indices as 1-based, so t=0 would wrap
    around to the last frame. Confirm with callers.

    Args:
        roll_norm_list: rollout frames.
        true_norm_list: ground-truth frames with the same layout.
        time_indices: frame numbers to show, one row each.
        probe_idcs: optional iterable of (y, x) pixel indices, marked with red crosses.
        output_path: if not None, the figure is saved there at 600 dpi.
    """
    n_rows = len(time_indices)
    # squeeze=False keeps ``axes`` 2-D when only one time index is requested.
    fig, axes = plt.subplots(n_rows, 2, figsize=(6, 3.5 * n_rows),
                             gridspec_kw={'wspace': 0.1, 'hspace': 0}, squeeze=False)

    # Squeeze and take magnitudes once instead of once per row.
    roll_mag = np.linalg.norm(np.squeeze(roll_norm_list), axis=-1)
    true_mag = np.linalg.norm(np.squeeze(true_norm_list), axis=-1)

    # Limits come from the plotted magnitudes, not the raw components, so the images are not clipped.
    vmax = max(np.nanmax(roll_mag), np.nanmax(true_mag))
    norm = mcolors.Normalize(vmin=0, vmax=vmax)

    # Copy before set_bad so the globally shared viridis colormap is not modified.
    cmap_viridis = copy.copy(plt.cm.viridis)
    cmap_viridis.set_bad(color='white', alpha=0)  # masked/invalid pixels become transparent

    ims = []
    for i, t in enumerate(time_indices):
        for j, (frames, prefix) in enumerate(((roll_mag, "Pred"), (true_mag, "True"))):
            ax = axes[i, j]
            ims.append(ax.imshow(frames[t - 1], cmap=cmap_viridis, origin='lower', norm=norm))
            ax.set_title(f"{prefix} t={t}", fontsize=12)
            # axis('off') also hides the spines, so no spine styling is applied.
            ax.axis('off')
            if probe_idcs is not None:
                for (y, x) in probe_idcs:
                    ax.plot(x, y, 'rx', markersize=4)  # red little crosses

    cbar = fig.colorbar(ims[0], ax=axes, orientation='horizontal', fraction=0.015, pad=0.02)
    cbar.ax.tick_params(labelsize=9)
    cbar.set_label('Velocity Magnitude (m/s)', fontsize=12)

    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: \n{output_path}")



def plot_probepoints(timespan, rollout, groundtruth, model_names=None, output_path=None):
    """
    Plot velocity magnitude at probe points over time, rollout(s) against ground truth.

    Without ``model_names``, ``rollout[i]`` and ``groundtruth[i]`` are drawn as pairs
    for every ``i``: all rollouts in blue and all ground truths in red, with one legend
    entry each. With ``model_names``, each ``rollout[i]`` is one model, drawn in its own
    tab10 colour (red is skipped because it marks the ground truth) and labelled
    ``model_names[i]``. Only ``groundtruth[0]`` is drawn as the reference.

    Args:
        timespan: x values shared by all curves.
        rollout: array whose first axis indexes the curves.
        groundtruth: ground-truth curves indexed like ``rollout``.
        model_names: optional per-curve labels; switches to the per-model mode above.
        output_path: if not None, the figure is saved there at 600 dpi.
    """
    fig = plt.figure()
    n_lines = rollout.shape[0]

    if model_names is None:
        for i in range(n_lines):
            plt.plot(timespan, rollout[i], color='tab:blue', label='Rollout' if i == 0 else None)
            plt.plot(timespan, groundtruth[i], color='tab:red', label='Ground Truth' if i == 0 else None)
    else:
        # pyplot.get_cmap replaces plt.cm.get_cmap (removed in Matplotlib 3.9).
        # Index 3 of tab10 is red, reserved for the ground truth.
        tab10 = plt.get_cmap('tab10').colors
        colors_without_red = [c for k, c in enumerate(tab10) if k != 3]
        for i in range(n_lines):
            # Wrap around instead of raising IndexError for more than 9 models.
            color = colors_without_red[i % len(colors_without_red)]
            plt.plot(timespan, rollout[i], color=color, label=model_names[i])
        # The reference curve is the same for every model, so draw it once, on top.
        if n_lines:
            plt.plot(timespan, groundtruth[0], color='tab:red', label='Ground Truth')

    plt.xlabel('Prediction Step')
    plt.ylabel('Velocity Magnitude (m/s)')
    plt.legend()
    plt.title('Velocity at Probe Point')

    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: \n{output_path}")

def plot_metrics(csv_file, plotmeta, output_path):
    """
    Plot a metric against the number of probe points for random vs. grid probe placement.

    Each CSV row's ``inference_probes`` value is searched for ``rand<N>`` (random
    placing) or ``ord<N>`` (grid placing). ``N`` becomes the x value and the column
    named ``plotmeta["metric"]`` the y value; rows matching neither are ignored. The
    title shows ``nr_of_inference_runs`` and ``sequence_length`` from the last CSV row.

    Args:
        csv_file: path to the CSV file.
        plotmeta: dict with keys "metric", "xlabel" and "ylabel".
        output_path: if not None, the figure is saved there at 600 dpi.

    Raises:
        ValueError: if no row contains a rand/ord entry (including an empty CSV,
            which previously failed with a NameError).
    """
    rand_pattern = re.compile(r"rand(\d+)")
    ord_pattern = re.compile(r"ord(\d+)")
    rand_points = []  # (number of probes, metric value)
    ord_points = []
    last_row = None

    # newline='' is what the csv module expects for files it reads.
    with open(csv_file, "r", newline="") as f:
        for row in csv.DictReader(f):
            last_row = row
            col_value = row["inference_probes"]
            mse = float(row[plotmeta["metric"]])

            match_rand = rand_pattern.search(col_value)
            match_ord = ord_pattern.search(col_value)
            if match_rand:
                rand_points.append((int(match_rand.group(1)), mse))
            elif match_ord:
                ord_points.append((int(match_ord.group(1)), mse))

    # Validate before creating a figure so a failure does not leave an empty one behind.
    if not rand_points and not ord_points:
        raise ValueError("No rand or ord values found in the CSV!")

    nr_of_inference_runs = last_row["nr_of_inference_runs"]
    sequence_length = last_row["sequence_length"]

    fig = plt.figure(figsize=(7, 5))

    # Same descriptive labels whether one or both series are present.
    series = ((rand_points, 'o', 'random placing'), (ord_points, 'x', 'grid placing'))
    for points, marker, label in series:
        if points:
            xs, ys = zip(*sorted(points))
            plt.plot(xs, ys, marker=marker, label=label)

    plt.xlabel(plotmeta['xlabel'])
    plt.ylabel(plotmeta['ylabel'])
    plt.title(f"{plotmeta['metric']} vs Nr. Probe Points\nRuns={nr_of_inference_runs}, SeqLen={sequence_length}")
    plt.legend()
    plt.grid(True)

    # Save before show: with GUI backends show() blocks and window resizes would
    # otherwise change the saved figure.
    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: \n{output_path}")

    plt.show()

