# This script contains all the utility functions for the advection equation.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import ticker
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from matplotlib import cm
from matplotlib.ticker import LogLocator, MaxNLocator


def update_plot_style():
    """
    Apply the shared global matplotlib style used by the plots in this module.

    Currently this only turns off LaTeX text rendering (``text.usetex``), so labels are
    drawn with matplotlib's built-in mathtext instead of an external LaTeX install.
    """
    plt.rcParams["text.usetex"] = False

def extract_errors(res, n_s, jump, s_i):
    """
    Collect, for every group of experiment rows, the minimum of several error metrics.

    The rows of ``res`` are split into consecutive groups of ``jump`` rows, one group per
    entry of ``n_s`` (group ``g`` covers rows ``g*jump`` up to, but excluding, ``(g+1)*jump``).
    Only ``len(n_s)`` is used; the values in ``n_s`` are not read.

    Parameters:
        res (np.ndarray): Stacked experiment results (2D, rows = runs, columns = metrics).
        n_s (np.ndarray): Array of collocation point counts; one row group per entry.
        jump (int): Number of rows in each group.
        s_i (int): Column index of the first metric; the other metrics are read at fixed
            offsets from it (offsets 2 and 3 are not reported).

    Returns:
        dict: Maps each metric name to a 1D array holding the per-group minimum,
        in the order rmse_elm, rmse_swim, l2_rel_err_elm, l2_rel_err_swim,
        l2_rel_err_elm_min, l2_rel_err_swim_min.
    """
    # Metric name -> column offset relative to s_i.
    column_offsets = {
        'rmse_elm': 0,
        'rmse_swim': 1,
        'l2_rel_err_elm': 4,
        'l2_rel_err_swim': 5,
        'l2_rel_err_elm_min': 6,
        'l2_rel_err_swim_min': 7,
    }

    def _group_minima(offset):
        column = s_i + offset
        minima = []
        for group in range(len(n_s)):
            start = group * jump
            minima.append(min(res[start:start + jump, column]))
        return np.array(minima)

    return {name: _group_minima(offset) for name, offset in column_offsets.items()}

def compute_spectral_line(x_vals, y_vals, shift=3):
    """
    Build an exponential reference curve y = exp(m * x + c) through two sample points.

    On a semilog-y plot such a curve is a straight line, which is the shape expected for
    spectral (exponential) convergence.

    Parameters:
        x_vals (tuple): Two x-values (e.g., collocation point counts) defining the slope.
        y_vals (tuple): The matching y-values (e.g., L2 errors); must be positive for the log.
        shift (float): Amount added to the intercept in natural-log space, i.e. the curve is
            multiplied by exp(shift) to offset it from the data.

    Returns:
        callable: f(x) = exp(m * x + c) with m the log-slope between the two points.
    """
    (xa, xb), (ya, yb) = x_vals, y_vals
    log_ya = np.log(ya)
    slope = (np.log(yb) - log_ya) / (xb - xa)
    intercept = log_ya - slope * xa + shift

    def spectral(x):
        return np.exp(slope * x + intercept)

    return spectral

def plot_convergence(x, y1, y2, x_spec, y_spec, xlabel, ylabel, filename, legend_labels, fontsize=12):
    """
    Draw two error curves and a spectral reference line on a semilog-y axis and save the figure.

    The first curve uses the default solid line, the second a dashed line ('--') and the
    reference a dotted line (':'). The legend is placed in the upper right without a frame.

    Parameters:
        x (array-like): X-values shared by y1 and y2.
        y1 (array-like): First error curve (e.g., ELM).
        y2 (array-like): Second error curve (e.g., SWIM).
        x_spec (array-like): X-values of the spectral reference line.
        y_spec (array-like): Y-values of the spectral reference line.
        xlabel (str): X-axis label.
        ylabel (str): Y-axis label.
        filename (str): Output path passed to savefig.
        legend_labels (list): At least three labels, in the order y1, y2, reference.
        fontsize (int): Font size for axis labels, legend and tick labels.
    """
    fig, ax = plt.subplots(figsize=(4, 4))

    curves = [
        ((x, y1), legend_labels[0]),
        ((x, y2, '--'), legend_labels[1]),
        ((x_spec, y_spec, ':'), legend_labels[2]),
    ]
    for plot_args, label in curves:
        ax.semilogy(*plot_args, label=label)

    ax.set_xlabel(xlabel, fontsize=fontsize)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    ax.legend(frameon=False, loc='upper right', fontsize=fontsize)
    ax.tick_params(axis='both', labelsize=fontsize)

    fig.tight_layout()
    plt.savefig(filename)



def plot_field(ax, field, extent, cmap, aspect, fontsize, title, use_sci_notation=False, ticks=None):
    """
    Draw a 2D field on an existing axis with imshow, attach a right-hand colorbar and apply
    the shared t/x axis formatting.

    The field is transposed before display, so its first axis runs horizontally (labelled 't')
    and its second axis vertically (labelled 'x'). Axis ticks are fixed at t in {0, 1} and
    x in {0, 6.28}.

    Parameters:
        ax (matplotlib.axes.Axes): Axis to draw into; its figure receives the colorbar.
        field (np.ndarray): 2D array to display.
        extent (list): Image extent [xmin, xmax, ymin, ymax] in data coordinates.
        cmap (str or Colormap): Colormap for imshow.
        aspect (float): Aspect ratio passed to imshow.
        fontsize (int): Font size for labels, title and all tick labels.
        title (str): Axis title.
        use_sci_notation (bool): If True, colorbar tick labels are formatted as '{:.0e}'.
        ticks (list): Explicit colorbar tick positions; left to matplotlib when None.
    """
    image = ax.imshow(field.T, extent=extent, origin='lower', aspect=aspect, cmap=cmap)
    colorbar = ax.figure.colorbar(image, ax=ax, location='right', aspect=10, shrink=0.3)

    # Optional colorbar labelling: scientific-notation formatter and explicit tick positions.
    if use_sci_notation:
        colorbar.formatter = ticker.FuncFormatter(lambda value, _pos: f'{value:.0e}')
    if ticks is not None:
        colorbar.set_ticks(ticks)
    colorbar.ax.tick_params(labelsize=fontsize)
    colorbar.update_ticks()

    # Axis labels, fixed tick positions and title.
    for set_label, text in ((ax.set_xlabel, 't'), (ax.set_ylabel, 'x')):
        set_label(text, fontsize=fontsize)
    ax.tick_params(axis='both', labelsize=fontsize)
    ax.set_xticks([0., 1.])
    ax.set_yticks([0., 6.28])
    ax.set_title(title, fontsize=fontsize)



def plot_heatmap(
    data,
    t_eval,
    x_space,
    cmap,
    filename=None,
    title=None,
    colorbar_location="right",
    fontsize=20,
    aspect=0.07,
    scientific_format=True,
):
    """
    Plot a heatmap of the solution or error with consistent formatting across notebook and saved file.

    Parameters:
        data (np.ndarray): 2D array. It is transposed before imshow, so its first axis is drawn
            horizontally against the t extent and its second axis vertically against the
            x extent (i.e. it is indexed as data[t, x]).
        t_eval (np.ndarray): Array of time points; its min/max set the horizontal extent.
        x_space (np.ndarray): Array of spatial points; its min/max set the vertical extent.
        cmap (str or Colormap): Colormap to use for the heatmap.
        filename (str or os.PathLike, optional): File to save the figure to. The output format is
            taken from the file extension (case-insensitive, e.g. '.pdf', '.png', '.svg'); a name
            without an extension is written as PNG. If None or empty, the figure is not saved.
        title (str, optional): Title of the plot.
        colorbar_location (str): Location of colorbar ('right', 'bottom', etc.).
        fontsize (int): Font size for labels and ticks.
        aspect (float): Aspect ratio for imshow.
        scientific_format (bool): Whether to format colorbar tick labels as '{:.1e}'.
    """
    import os

    extent = [np.min(t_eval), np.max(t_eval), np.min(x_space), np.max(x_space)]

    fig, ax = plt.subplots(figsize=(6, 5), constrained_layout=True)
    img = ax.imshow(data.T, extent=extent, origin="lower", aspect=aspect, cmap=cmap)

    # A bottom colorbar spans the full width (shrink=1, aspect=20); other locations are
    # shortened to half length with aspect 10.
    at_bottom = colorbar_location == 'bottom'
    cb = fig.colorbar(img, ax=ax, location=colorbar_location,
                      aspect=20 if at_bottom else 10,
                      shrink=1 if at_bottom else 0.5)

    # Only label the extremes of the data on the colorbar.
    cb.set_ticks([data.min(), data.max()])
    if scientific_format:
        cb.formatter = ticker.FuncFormatter(lambda x, pos: f'{x:.1e}')
    cb.update_ticks()
    cb.ax.tick_params(labelsize=fontsize)

    # Axis labels and ticks
    ax.set_xlabel('t', fontsize=fontsize)
    ax.set_ylabel('x', fontsize=fontsize)
    ax.set_xticks([0.0, 1.0])
    ax.set_yticks([0.0, 6.28])
    ax.tick_params(axis='both', labelsize=fontsize)

    if title:
        ax.set_title(title, fontsize=fontsize)

    # Save or display
    if filename:
        # Use the real extension instead of only recognising '.pdf': previously names such as
        # 'fig.svg' or 'fig.PDF' were written as PNG data. Extension-less names keep the old
        # PNG behaviour; passing the format explicitly keeps the filename exactly as given.
        # os.fspath also lets pathlib.Path filenames work (str.endswith did not).
        extension = os.path.splitext(os.fspath(filename))[1][1:].lower()
        fig.savefig(filename, format=extension or 'png', bbox_inches='tight')
    plt.show()


def plot_3d_solution_with_slices(xx, yy, u_data, x_vals, t_vals, slices, filename='adv_true_sol_3d.png'):
    """
    Draw a 3D view of a space-time solution and save it.

    The view contains three elements:
    - a filled-contour map projected onto the floor of the box (z = min of u_data);
    - for each requested x-location, a blue curve of the solution along t;
    - a black rectangular frame, from min to max of u_data, marking each slicing plane.

    The figure is saved to ``filename`` and then shown.

    Parameters:
        xx (ndarray): Meshgrid of spatial coordinates (3D x-axis).
        yy (ndarray): Meshgrid of temporal coordinates (3D y-axis).
        u_data (ndarray): Solution values on the xx/yy grid; columns are indexed by position
            in x_vals (shape: [len(t_vals), len(x_vals)]).
        x_vals (ndarray): 1D spatial coordinates, used to find the column nearest each slice.
        t_vals (ndarray): 1D temporal coordinates along which each slice is drawn.
        slices (list of dict): One dict per slice with keys
            - 'x': requested x-location (the data column is the nearest entry of x_vals,
              while the curve and frame are drawn at exactly this x)
            - 'line_zorder': z-order of the slice curve
            - 'rect_zorder': z-order of the frame
        filename (str): Path the figure is saved to.
    """
    fig = plt.figure(figsize=(6, 4.5))
    ax = fig.add_subplot(111, projection='3d')

    u_lo = u_data.min()
    u_hi = u_data.max()

    # Filled contours projected onto the plane z = u_lo.
    floor_map = ax.contourf(
        xx, yy, u_data,
        zdir='z', offset=u_lo, cmap='jet', alpha=0.8, levels=500, zorder=1,
    )

    for cfg in slices:
        x_pos, curve_z, frame_z = cfg['x'], cfg['line_zorder'], cfg['rect_zorder']
        # Grid column closest to the requested x-location.
        nearest_col = np.argmin(np.abs(x_vals - x_pos))
        ax.plot(
            np.full_like(t_vals, x_pos), t_vals, u_data[:, nearest_col],
            color='b', linewidth=2, zorder=curve_z,
        )
        add_vertical_slice_rectangle(ax, x_pos, t_vals, u_lo, u_hi, zorder=frame_z)

    # Camera, box proportions and minimal ticks (z ticks hidden).
    ax.view_init(elev=30, azim=230)
    ax.set_box_aspect([1, 1, 0.7])
    ax.set_xticks([0, 2 * np.pi])
    ax.set_yticks([0, 1])
    ax.set_zticks([])
    ax.tick_params(labelsize=14)

    colorbar = fig.colorbar(floor_map, ax=ax, shrink=0.5, aspect=10)
    colorbar.ax.tick_params(labelsize=14)
    colorbar.set_ticks([u_lo, u_hi])

    plt.tight_layout()
    fig.savefig(filename)
    plt.show()


def add_vertical_slice_rectangle(ax, x_fixed, t_vals, z_min, z_max, zorder=4):
    """
    Outline the vertical plane x = x_fixed with a hollow black rectangle on a 3D axis.

    The rectangle spans t from the first to the last entry of t_vals and z from z_min to
    z_max. It is drawn as four black line segments of width 2.

    Parameters:
        ax (Axes3D): Matplotlib 3D axis to add the rectangle to.
        x_fixed (float): x-coordinate of the plane.
        t_vals (ndarray): Time coordinates; only the first and last entries are used.
        z_min (float): Lower z edge of the rectangle.
        z_max (float): Upper z edge of the rectangle.
        zorder (int): z-order of the line collection.
    """
    t_start, t_end = t_vals[0], t_vals[-1]

    # Corners in drawing order; consecutive pairs (wrapping around) form the four edges.
    corners = [
        [x_fixed, t_start, z_min],
        [x_fixed, t_end, z_min],
        [x_fixed, t_end, z_max],
        [x_fixed, t_start, z_max],
    ]
    edges = [[corners[k], corners[(k + 1) % 4]] for k in range(4)]

    ax.add_collection3d(Line3DCollection(edges, colors='k', linewidths=2, zorder=zorder))


def plot_solution_image_minimal(
    data, x_space, t_eval, output_file,
    fontsize=14, cmap=cm.jet, aspect=16,
    xticks=None, yticks=None, colorbar_ticks=None
):
    """
    Plot a minimal 2D solution image (no axis labels or title) with optional custom ticks and save it.

    Parameters:
        data (2D array): Solution data to visualize. It is transposed before imshow, so its
            first axis is drawn horizontally against the t extent and its second axis
            vertically against the x extent.
        x_space (array): Spatial domain values; min/max set the vertical extent.
        t_eval (array): Temporal domain values; min/max set the horizontal extent.
        output_file (str): Filename to save the image.
        fontsize (int): Font size for axis and colorbar tick labels.
        cmap: Matplotlib colormap.
        aspect (float): Aspect ratio of the image.
        xticks (list, optional): Custom x-ticks; matplotlib's default when None.
        yticks (list, optional): Custom y-ticks; matplotlib's default when None.
        colorbar_ticks (list or np.ndarray, optional): Colorbar tick values; matplotlib's
            default when None.
    """
    fig, ax = plt.subplots(figsize=(5, 4), constrained_layout=True)
    extent = [np.min(t_eval), np.max(t_eval), np.min(x_space), np.max(x_space)]

    img = ax.imshow(data.T, extent=extent, origin='lower', aspect=aspect, cmap=cmap)
    cb = fig.colorbar(img, ax=ax, location='right', aspect=5, shrink=0.1)

    # Compare against None, like xticks/yticks below. A truthiness test raises ValueError for
    # multi-element NumPy arrays and silently ignored an explicit empty list.
    if colorbar_ticks is not None:
        cb.set_ticks(colorbar_ticks)
    cb.ax.tick_params(labelsize=fontsize)

    if xticks is not None:
        ax.set_xticks(xticks)
    if yticks is not None:
        ax.set_yticks(yticks)

    ax.tick_params(axis='both', labelsize=fontsize)
    fig.savefig(output_file)


def set_plot_style(fontsize=14, lw=2):
    """
    Disable LaTeX text rendering globally and hand back the shared plot settings.

    Parameters:
        fontsize (int): Font size to return for later use.
        lw (int): Line width to return for later use.

    Returns:
        tuple: (fontsize, lw, cm.jet) — the inputs unchanged plus the jet colormap.
    """
    plt.rcParams["text.usetex"] = False
    colormap = cm.jet
    return fontsize, lw, colormap

def plot_error_vs_time(t_eval, err_dict, linestyles, colors, output_file, fontsize=14, lw=2):
    """
    Plot the relative L2 error of several methods against time on a semilog-y axis and save it.

    One curve is drawn per entry of err_dict, in dict order. Its linestyle and color are
    looked up by the same label. The plot has a grid, a frameless legend in the lower
    right, and at most about 5 major ticks on each axis.

    Parameters:
        t_eval (array): Time points shared by all curves.
        err_dict (dict): Method label -> error array over t_eval.
        linestyles (dict): Method label -> line style; must contain every key of err_dict.
        colors (dict): Method label -> line color; must contain every key of err_dict.
        output_file (str): Path passed to savefig.
        fontsize (int): Font size for labels, legend and major tick labels (minor ticks use 10).
        lw (int): Line width of every curve.
    """
    fig, ax = plt.subplots(figsize=(4, 2.5), sharey=True)

    for method, curve in err_dict.items():
        line_kwargs = dict(linestyle=linestyles[method], color=colors[method], linewidth=lw)
        ax.semilogy(t_eval, curve, label=method, **line_kwargs)

    ax.legend(frameon=False, loc='lower right', fontsize=fontsize)
    ax.grid(True)
    ax.set_xlabel('Time', fontsize=fontsize)
    ax.set_ylabel(r'Relative  $\mathbb{L}_{2}$ error', fontsize=fontsize)

    # Tick placement: decade ticks on the log y-axis, a handful of linear ticks on x.
    ax.yaxis.set_major_locator(LogLocator(base=10.0, numticks=5))
    ax.yaxis.set_minor_locator(LogLocator(base=10.0, subs='auto', numticks=10))
    ax.xaxis.set_major_locator(MaxNLocator(nbins=5))
    for tick_kind, label_size in (('major', fontsize), ('minor', 10)):
        ax.tick_params(axis='both', which=tick_kind, labelsize=label_size)

    fig.tight_layout()
    plt.savefig(output_file)

def plot_solution_image(data, x_space, t_eval, title, output_file, fontsize=14, cmap=cm.jet):
    """
    Render a 2D space-time field (e.g. the true solution) with a colorbar underneath and save it.

    Parameters:
        data (2D array): Field to display. It is transposed before imshow, so its first axis
            is drawn horizontally (labelled 't') and its second axis vertically (labelled 'x').
        x_space (array): Spatial domain values; min/max set the vertical extent.
        t_eval (array): Temporal domain values; min/max set the horizontal extent.
        title (str): Title of the plot.
        output_file (str): Path passed to savefig.
        fontsize (int): Font size for title, axis labels and all tick labels.
        cmap: Colormap used by imshow.
    """
    fig, ax = plt.subplots(figsize=(6, 3), constrained_layout=True)
    t_lo, t_hi = np.min(t_eval), np.max(t_eval)
    x_lo, x_hi = np.min(x_space), np.max(x_space)

    image = ax.imshow(data.T, extent=[t_lo, t_hi, x_lo, x_hi], origin='lower', aspect=16, cmap=cmap)
    colorbar = fig.colorbar(image, ax=ax, location='bottom', aspect=50)

    ax.set_title(title, fontsize=fontsize)
    ax.set_xlabel('t', fontsize=fontsize)
    ax.set_ylabel('x', fontsize=fontsize)
    colorbar.ax.tick_params(labelsize=fontsize)
    ax.tick_params(axis='both', labelsize=fontsize)
    plt.savefig(output_file)
