import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import numpy as np

from src.utils.evaluation_functions import calc_KE
from src.utils.postprocessing import calc_ke_spectra_over_wavenumber


### ICLR PAPER PLOTS ###

def plot_velocity_snapshots(pred, true, plot_indices, u_rescale, probe_idcs=None, t_offset=15, gt_label='LES'):
    """
    Plot predicted and groundtruth velocity-magnitude fields at selected timesteps.

    Layout
    ------
    - One column per entry of ``plot_indices``.
    - One row per prediction sample: sample 0 sits directly above the
      groundtruth row, the last sample is the top row.
    - The bottom row shows the first groundtruth sample, labelled ``gt_label``.
    - Every panel shows the Euclidean norm of the 2-component velocity divided
      by ``u_rescale``. All panels share one colour scale from 0 to the largest
      magnitude found over all timesteps of the drawn arrays (NaNs ignored;
      NaN cells are drawn transparent).
    - Probe locations, if given, are drawn as hollow red circles on prediction
      panels and red crosses on groundtruth panels, in every column except the
      last one.

    Parameters
    ----------
    pred : np.ndarray
        Predicted velocities of shape (n_samples, n_timesteps, H, W, 2). Extra
        leading axes (e.g. a batch axis of size 1) are folded into the sample axis.
    true : np.ndarray
        Groundtruth velocities of shape (1, n_timesteps, H, W, 2) or
        (n_timesteps, H, W, 2). Only the first sample is drawn.
    plot_indices : sequence of int
        Timestep indices to plot, one per column.
    u_rescale : float
        Reference velocity used for normalisation (e.g. inlet velocity).
    probe_idcs : sequence of (int, int), optional
        Probe positions as (y, x) pixel indices. No markers if None.
    t_offset : int, optional
        Subtracted from each timestep index in the column titles (default 15).
    gt_label : str, optional
        Text placed next to the groundtruth row (default 'LES').

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing all subplots.

    Raises
    ------
    ValueError
        If ``plot_indices`` is empty.
    """
    import copy

    plot_indices = list(plot_indices)
    if not plot_indices:
        raise ValueError("plot_indices must contain at least one timestep index")

    plt.rcParams.update({'font.size': 16})
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.sans-serif'] = 'Times Roman'

    # Bring both arrays to (n, T, H, W, 2). Unlike a blind np.squeeze this keeps
    # the sample axis when n == 1 and never collapses T/H/W of size 1.
    pred = np.asarray(pred)
    true = np.asarray(true)
    pred = pred.reshape(-1, *pred.shape[-4:])
    true = true.reshape(-1, *true.shape[-4:])

    # Normalised velocity magnitude for every timestep.
    pred_mag = np.linalg.norm(pred / u_rescale, axis=-1)    # (n_samples, T, H, W)
    true_mag = np.linalg.norm(true[0] / u_rescale, axis=-1)  # (T, H, W)

    n_samples = pred_mag.shape[0]
    n_rows = n_samples + 1
    n_plots = len(plot_indices)
    gt_row = n_samples  # groundtruth is always the bottom row

    # squeeze=False keeps `axes` 2-D even for a single column.
    fig, axes = plt.subplots(
        n_rows, n_plots,
        figsize=(3.5 * n_plots, 3 * n_rows),
        gridspec_kw={'wspace': 0.1, 'hspace': 0},
        squeeze=False,
    )

    # Copy before set_bad so the globally registered colormap is not mutated.
    cmap = copy.copy(plt.cm.viridis)
    cmap.set_bad(color='white', alpha=0)
    # The scale is based on the plotted magnitudes, not raw components, so no panel saturates.
    vmax = max(np.nanmax(pred_mag), np.nanmax(true_mag))
    norm = mcolors.Normalize(vmin=0, vmax=vmax)

    ims = []
    for col, idx in enumerate(plot_indices):
        gt_ax = axes[gt_row, col]
        ims.append(gt_ax.imshow(true_mag[idx], cmap=cmap, origin='lower', norm=norm))
        gt_ax.axis('off')
        axes[0, col].set_title(f"t = {idx - t_offset}")

        for s in range(n_samples):
            ax = axes[n_samples - 1 - s, col]
            ims.append(ax.imshow(pred_mag[s, idx], cmap=cmap, origin='lower', norm=norm))
            ax.axis('off')

        # Probe markers are drawn in every column except the last.
        # NOTE(review): confirm the last-column exclusion is intended.
        if probe_idcs is not None and col < n_plots - 1:
            for (y, x) in probe_idcs:
                for row in range(n_samples):
                    axes[row, col].plot(x, y, 'ro', markersize=4, markerfacecolor='none')
                gt_ax.plot(x, y, 'rx', markersize=4, markerfacecolor='none')

    # All images share `norm`, so one colorbar describes every panel.
    fig.colorbar(ims[0], ax=axes.ravel().tolist(), orientation='vertical', pad=0.025, shrink=0.8, aspect=40)

    # Hand-tuned vertical placement of the label beside the groundtruth row.
    r = n_rows - 1
    y_text = (n_rows - 0.77 * r - 0.85) / n_rows
    fig.text(0.12, y_text, gt_label, va='center', ha='right', rotation=90)

    return fig



def plot_mean_variance_diff(mean_true, mean_pred, var_true, var_pred, u_rescale, probe_idcs=None, output_path=None, return_fig=False):
    """
    Compare mean and variance velocity fields of prediction and groundtruth.

    Parameters
    ----------
    mean_true : ndarray
        True mean velocity field, shape (H, W, 2).
    mean_pred : ndarray
        Predicted mean velocity field, shape (H, W, 2).
    var_true : ndarray
        True variance velocity field, shape (H, W, 2).
    var_pred : ndarray
        Predicted variance velocity field, shape (H, W, 2).
    u_rescale : float
        Scaling factor used to normalise the fields (e.g. inlet velocity).
    probe_idcs : list of tuples, optional
        (y, x) probe indices to mark on every panel.
    output_path : str, optional
        If provided, the figure is saved there (600 dpi, tight bounding box).
    return_fig : bool, optional
        If True, the figure object is returned.

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The figure if ``return_fig`` is True.

    Notes
    -----
    - Layout is a 2x3 grid: [mean_pred, mean_true, mean_diff; var_pred, var_true, var_diff].
    - The last axis of every input is collapsed with the Euclidean norm and the
      result divided by ``u_rescale`` (the variance fields included).
    - Differences are prediction minus groundtruth, drawn with a coolwarm map
      on a scale symmetric about zero.
    - Colour limits use NaN-aware reductions, so NaN cells do not invalidate
      the colour scales.
    - Colorbars are added as separate axes to the left of column 0 (mean and
      variance scales) and to the right of column 2 (difference scales), so
      the image axes are not resized.
    - Probe positions are drawn as hollow red circles.
    """

    plt.rcParams.update({'font.size': 16})
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.sans-serif'] = 'Times Roman'

    # velocity magnitude, rescaled; then prediction - groundtruth differences
    mean_pred = np.linalg.norm(mean_pred, axis=-1) / u_rescale
    mean_true = np.linalg.norm(mean_true, axis=-1) / u_rescale
    var_pred = np.linalg.norm(var_pred, axis=-1) / u_rescale
    var_true = np.linalg.norm(var_true, axis=-1) / u_rescale

    mean_diff = mean_pred - mean_true
    var_diff = var_pred - var_true

    fig, axes = plt.subplots(2, 3, figsize=(15, 10), gridspec_kw={'wspace': 0, 'hspace': 0.1})

    # Shared scales; nan-reductions keep a single NaN from making the limits NaN.
    norm_mean = mcolors.Normalize(vmin=0, vmax=max(np.nanmax(mean_pred), np.nanmax(mean_true)))
    norm_var = mcolors.Normalize(vmin=0, vmax=max(np.nanmax(var_pred), np.nanmax(var_true)))

    # Symmetric limits so zero difference maps to the centre of coolwarm.
    mean_diff_lim = np.nanmax(np.abs(mean_diff))
    var_diff_lim = np.nanmax(np.abs(var_diff))
    norm_mean_diff = mcolors.Normalize(vmin=-mean_diff_lim, vmax=mean_diff_lim)
    norm_var_diff = mcolors.Normalize(vmin=-var_diff_lim, vmax=var_diff_lim)

    # Panels in row-major order, matching axes.ravel().
    panels = (
        (mean_pred, plt.cm.viridis, norm_mean),
        (mean_true, plt.cm.viridis, norm_mean),
        (mean_diff, plt.cm.coolwarm, norm_mean_diff),
        (var_pred, plt.cm.viridis, norm_var),
        (var_true, plt.cm.viridis, norm_var),
        (var_diff, plt.cm.coolwarm, norm_var_diff),
    )
    ims = [
        ax.imshow(data, cmap=cmap, origin='lower', norm=norm)
        for ax, (data, cmap, norm) in zip(axes.ravel(), panels)
    ]

    axes[0, 0].set_title("Prediction")
    axes[0, 1].set_title("Groundtruth")
    axes[0, 2].set_title("Difference")

    # hide axes decorations; mark probe points if provided
    for ax in axes.ravel():
        ax.axis('off')
        if probe_idcs is not None:
            for (y, x) in probe_idcs:
                ax.plot(x, y, 'ro', markersize=4, markerfacecolor='none')

    def _add_side_colorbar(ax, mappable, side):
        """Attach a thin colorbar axis to the left or right edge of `ax`."""
        pos = ax.get_position()
        if side == 'left':
            cax = fig.add_axes([pos.x0 - 0.025, pos.y0, 0.015, pos.height])
        else:
            cax = fig.add_axes([pos.x1 + 0.01, pos.y0, 0.015, pos.height])
        cbar = fig.colorbar(mappable, cax=cax)
        if side == 'left':
            # ticks/labels face away from the image they belong to
            cbar.ax.yaxis.set_ticks_position('left')
            cbar.ax.yaxis.set_label_position('left')
        return cbar

    _add_side_colorbar(axes[0, 0], ims[0], 'left')   # mean magnitude
    _add_side_colorbar(axes[1, 0], ims[3], 'left')   # variance magnitude
    _add_side_colorbar(axes[0, 2], ims[2], 'right')  # mean difference
    _add_side_colorbar(axes[1, 2], ims[5], 'right')  # variance difference

    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: \n{output_path}")

    if return_fig:
        return fig


def plot_taverage_KE_histogram(prediction, groundtruth, bins=100, output_path=None, return_fig=False):
    """
    Overlay density-normalised kinetic-energy histograms of prediction and groundtruth.

    The kinetic energy of each input is computed with ``calc_KE`` and flattened
    before binning. Both histograms are semi-transparent, density-normalised and
    drawn with a logarithmic y-axis.

    Parameters
    ----------
    prediction : np.ndarray
        Predicted velocity field array (any shape that calc_KE supports).
    groundtruth : np.ndarray
        Groundtruth velocity field array (any shape that calc_KE supports).
    bins : int or sequence, optional
        Bin specification passed to ``hist`` for both histograms (default 100).
    output_path : str, optional
        If given, the figure is saved there (600 dpi, tight bounding box).
    return_fig : bool, optional
        If True, the Figure object is returned.

    Returns
    -------
    matplotlib.figure.Figure or None
    """
    plt.rcParams.update({'font.size': 16})
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.sans-serif'] = 'Times Roman'

    ke_pred = calc_KE(prediction)
    ke_true = calc_KE(groundtruth)

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot()  # the same single axes pyplot's gca() would create

    shared = dict(bins=bins, alpha=0.5, density=True, log=True)
    # Groundtruth first so it is the first legend entry.
    for values, colour, name in (
        (ke_true, 'tab:blue', 'Groundtruth'),
        (ke_pred, 'tab:red', 'Prediction'),
    ):
        ax.hist(values.ravel(), color=colour, label=name, **shared)

    ax.set_xlabel('Kinetic Energy')
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=2)
    fig.tight_layout()

    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: {output_path}")

    return fig if return_fig else None


def plot_taverage_KE_spectra(
        TKE_prediction,
        wavenumbers_prediction,
        TKE_groundtruth,
        wavenumbers_groundtruth,
        output_path=None,
        return_fig=None
        ):
    """
    Plot time-averaged kinetic energy spectra E(k) of prediction and ground
    truth on log-log axes, with wavenumber k on x and energy E(k) on y.

    Parameters
    ----------
    TKE_prediction : array-like
        1D time-averaged kinetic energy spectrum of the prediction.
    wavenumbers_prediction : array-like
        1D wavenumbers corresponding to TKE_prediction.
    TKE_groundtruth : array-like
        1D time-averaged kinetic energy spectrum of the ground truth.
    wavenumbers_groundtruth : array-like
        1D wavenumbers corresponding to TKE_groundtruth.
    output_path : str, optional
        If provided, the figure is saved there (600 dpi, tight bounding box).
    return_fig : bool, optional
        If True, the Figure object is returned.

    Returns
    -------
    matplotlib.figure.Figure or None
        The created figure if ``return_fig`` is truthy.

    Notes
    -----
    - The legend is placed above the axes.
    - The figure is saved (if requested) before ``plt.show()`` is called.
    """
    plt.rcParams.update({'font.size': 16})
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.sans-serif'] = 'Times Roman'

    fig = plt.figure(figsize=(8, 6))

    # x = wavenumber, y = spectral energy, matching the axis labels below.
    plt.loglog(
            wavenumbers_prediction,
            TKE_prediction,
            color='tab:red',
            label='Prediction'
            )

    plt.loglog(
            wavenumbers_groundtruth,
            TKE_groundtruth,
            color='tab:blue',
            label='Ground Truth'
            )

    plt.xlabel('k')
    plt.ylabel('E(k)')
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=2)
    # Lay out only after the labels exist so they are taken into account.
    plt.tight_layout()

    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: {output_path}")

    # Show after saving: interactive backends may block here.
    plt.show()

    if return_fig:
        return fig
    

def plot_probes_comparison(
        probes_data,
        output_path=None,
        return_fig=None
        ):
    """
    Plot the per-sequence MSE vs. number of probe points used,
    for different kinds of probe point positioning.

    Parameters
    ----------
    probes_data : dict[str, array-like]
        One entry per positioning strategy. The key is the legend label. The
        value is indexed as ``value[:, 0]`` (number of probe points, x-axis)
        and ``value[:, 1]`` (MSE, y-axis).
    output_path : str, optional
        If provided, the figure is saved there (600 dpi, tight bounding box).
    return_fig : bool, optional
        If True, the Figure object is returned.

    Returns
    -------
    matplotlib.figure.Figure or None

    Raises
    ------
    ValueError
        If ``probes_data`` is empty.
    """
    if not probes_data:
        raise ValueError("probes_data must contain at least one entry")

    plt.rcParams.update({'font.size': 16})
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.sans-serif'] = 'Times Roman'

    fig = plt.figure(figsize=(8, 6))

    x_max = -np.inf
    for label, data in probes_data.items():
        data = np.asarray(data)
        plt.scatter(data[:, 0], data[:, 1], marker='x', label=label)
        x_max = max(x_max, np.max(data[:, 0]))

    # Span up to the largest probe count of any strategy, not just the last one's last row.
    plt.xlim(0, x_max)
    plt.ylabel('MSE')
    plt.xlabel('Nr. of Probe Points')
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=len(probes_data))
    # Lay out after labels/legend exist so they are taken into account.
    plt.tight_layout()

    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: {output_path}")

    # Show after saving: interactive backends may block here.
    plt.show()

    if return_fig:
        return fig
    
def plot_probes_mse_distance(
        mse,
        distances,
        n_bins=100,
        output_path=None,
        return_fig=None
        ):
    """
    Plot per-pixel MSE vs. distance to the nearest probe point.

    ``distances`` is split into ``n_bins`` equal-width bins spanning its
    min..max range. For each bin, the mean MSE is drawn as a line and +/- one
    sample standard deviation (ddof=1) as a shaded band. Empty bins get a NaN
    mean and std, and single-sample bins a NaN std, which leaves gaps in the
    plot.

    Parameters
    ----------
    mse : array-like
        MSE values. Its leading axes must match the shape of ``distances``
        (typically the two arrays have identical shapes).
    distances : array-like
        Distance of each entry to the nearest probe point.
    n_bins : int, optional
        Number of distance bins (default 100).
    output_path : str, optional
        If provided, the figure is saved there (600 dpi, tight bounding box).
    return_fig : bool, optional
        If True, the Figure object is returned.

    Returns
    -------
    matplotlib.figure.Figure or None

    Raises
    ------
    ValueError
        If ``n_bins`` < 1, or if ``distances`` is empty or not aligned with ``mse``.
    """
    if n_bins < 1:
        raise ValueError("n_bins must be a positive integer")

    plt.rcParams.update({'font.size': 16})
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.sans-serif'] = 'Times Roman'

    mse = np.asarray(mse)
    distances = np.asarray(distances)

    bin_edges = np.linspace(np.min(distances), np.max(distances), n_bins + 1)
    # Interior edges only -> labels 0..n_bins-1. right=True puts the maximum in the last bin.
    bin_idx = np.digitize(distances, bin_edges[1:-1], right=True)

    # `distances` may label only the leading axes of `mse`; broadcast the labels
    # over any trailing axes so every mse entry gets the bin of its distance.
    bin_idx = bin_idx.reshape(bin_idx.shape + (1,) * (mse.ndim - bin_idx.ndim))
    bin_idx = np.broadcast_to(bin_idx, mse.shape).ravel()
    values = mse.ravel()

    counts = np.bincount(bin_idx, minlength=n_bins)
    sums = np.bincount(bin_idx, weights=values, minlength=n_bins)
    mean_mse = np.full(n_bins, np.nan)
    np.divide(sums, counts, out=mean_mse, where=counts > 0)

    # Two-pass variance (deviations from the bin mean) for numerical stability.
    sq_dev = (values - mean_mse[bin_idx]) ** 2
    sq_dev_sums = np.bincount(bin_idx, weights=sq_dev, minlength=n_bins)
    var_mse = np.full(n_bins, np.nan)
    np.divide(sq_dev_sums, counts - 1, out=var_mse, where=counts > 1)
    std_mse = np.sqrt(var_mse)

    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    fig, ax = plt.subplots(figsize=(8, 6))

    ax.plot(bin_centers, mean_mse, 'b-', linewidth=2, label='Mean MSE')
    ax.fill_between(bin_centers, mean_mse - std_mse, mean_mse + std_mse, alpha=0.2, color='blue', label='Std Dev')

    ax.set_xlabel('Distance to Nearest Probe Point')
    ax.set_ylabel('MSE')
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=2)

    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: {output_path}")

    if return_fig:
        return fig



# A first definition of `plot_distribution_MSE` (identical apart from missing
# axis labels) was shadowed by the later definition of the same name and has
# been removed as dead code.
    

def plot_distribution_MSE(
        mse,
        bins=100,
        output_path=None,
        return_fig=None
        ):
    """
    Plot a density-normalised, log-scaled histogram of MSE values between
    groundtruth and prediction sequences.

    Parameters
    ----------
    mse : array-like
        MSE values; passed directly to ``hist``.
    bins : int or sequence, optional
        Histogram bin specification (default 100).
    output_path : str, optional
        If given, the figure is saved there (600 dpi, tight bounding box).
    return_fig : bool, optional
        If truthy, the Figure object is returned.

    Returns
    -------
    matplotlib.figure.Figure or None
    """
    plt.rcParams.update({'font.size': 16})
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.sans-serif'] = 'Times Roman'

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot()  # the same single axes pyplot's gca() would create
    ax.hist(mse, bins=bins, alpha=0.5, color='tab:red', label='Prediction', density=True, log=True)
    ax.set_xlabel('MSE')
    ax.set_ylabel("Probability Density")

    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: {output_path}")

    return fig if return_fig else None
    
def plot_distribution_wasserstein(
        wasserstein,
        bins=100,
        output_path=None,
        return_fig=None
        ):
    """
    Plot a histogram of Wasserstein distances between groundtruth and
    prediction (of kinetic energy, per the x-axis label).

    The histogram is density-normalised and drawn with a logarithmic y-axis.

    Parameters
    ----------
    wasserstein : array-like
        Wasserstein distance values; passed directly to ``plt.hist``.
    bins : int or sequence, optional
        Histogram bin specification (default 100).
    output_path : str, optional
        If given, the figure is saved there (600 dpi, tight bounding box).
    return_fig : bool, optional
        If truthy, the Figure object is returned.

    Returns
    -------
    matplotlib.figure.Figure or None
    """

    plt.rcParams.update({'font.size': 16})
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.sans-serif'] = 'Times Roman'

    fig = plt.figure(figsize=(8, 6))
    plt.hist(wasserstein, bins=bins, alpha=0.5, color='tab:red', label='Prediction', density=True, log=True)
    plt.xlabel("Wasserstein Distance (Kinetic Energy)")
    plt.ylabel("Probability Density")

    if output_path is not None:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"Saved figure to: {output_path}")

    if return_fig:
        return fig