# This module contains the plotting utility functions for visualizing Euler-Bernoulli beam equation solutions.
import matplotlib.pyplot as plt
from matplotlib import ticker
import numpy as np
from mpl_toolkits.mplot3d.art3d import Line3DCollection

def plot_heatmap(data, extent, aspect, cmap, xlabel, ylabel, output_file, fontsize=24,
                 figsize=(6, 3)):
    """
    Visualize a 2D array as a heatmap with a colorbar underneath and save it to disk.

    Parameters:
    - data (numpy array): The data to visualize. Row 0 is drawn at the bottom (origin='lower').
    - extent (list): [x_min, x_max, t_min, t_max] for the plot.
    - aspect (float or str): Aspect ratio passed to ``imshow`` (a number, 'auto' or 'equal').
    - cmap (str or Colormap): Colormap for the heatmap.
    - xlabel (str): Label for the x-axis.
    - ylabel (str): Label for the y-axis.
    - output_file (str): Path to save the resulting plot.
    - fontsize (int, optional): Font size for labels and ticks. Default is 24.
    - figsize (tuple, optional): Figure size in inches. Default is (6, 3).

    Notes:
    - The figure is not closed after saving. When generating many plots in a loop,
      close them (e.g. ``plt.close('all')``) to release memory.
    """
    fig, ax = plt.subplots(1, 1, figsize=figsize, constrained_layout=True)
    sol_img = ax.imshow(data, extent=extent, origin='lower', aspect=aspect, cmap=cmap)
    cb = fig.colorbar(sol_img, ax=ax, location='bottom', aspect=20)

    # Format colorbar labels in scientific notation with no decimal digits (e.g. "3e-04")
    cb.formatter = ticker.FuncFormatter(lambda x, pos: f'{x:.0e}')
    cb.update_ticks()

    # Set font sizes
    cb.ax.tick_params(labelsize=fontsize)
    ax.set_xlabel(xlabel, fontsize=fontsize)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    ax.tick_params(axis='both', labelsize=fontsize)

    # Save this specific figure rather than relying on pyplot's implicit "current figure"
    fig.savefig(output_file)

def plot_solution_comparison(x, t, u_true, u_pred, save_path=None, cmap='jet', aspect_ratio=4,
                             pred_title='Frozen-PINN-elm solution'):
    """
    Plots the true solution, predicted solution, and absolute error side-by-side.

    Parameters:
        x (np.ndarray): 1D array of spatial coordinates.
        t (np.ndarray): 1D array of time coordinates.
        u_true (np.ndarray): 2D array of true solution, shape (len(t), len(x)).
        u_pred (np.ndarray): 2D array of predicted solution, same shape as u_true.
        save_path (str, optional): If provided, saves the figure to the given path.
        cmap (str): Colormap to use for plots.
        aspect_ratio (float): Aspect ratio for imshow.
        pred_title (str, optional): Title of the middle (prediction) panel.
            Default is 'Frozen-PINN-elm solution'.

    Raises:
        ValueError: If u_true and u_pred do not have the same shape.
    """
    # Guard against silent broadcasting in the error panel.
    if np.shape(u_true) != np.shape(u_pred):
        raise ValueError(
            f"u_true and u_pred must have the same shape, "
            f"got {np.shape(u_true)} and {np.shape(u_pred)}"
        )

    extent = [x.min(), x.max(), t.min(), t.max()]

    fig, axes = plt.subplots(1, 3, figsize=(10, 3))
    panels = [
        ('True solution', u_true),
        (pred_title, u_pred),
        ('Absolute difference', np.abs(u_pred - u_true)),
    ]

    for ax, (title, img) in zip(axes, panels):
        # cmap is passed per call instead of via plt.rcParams so the caller's global
        # matplotlib defaults are not modified as a side effect.
        im = ax.imshow(img, extent=extent, origin='lower', aspect=aspect_ratio, cmap=cmap)
        ax.set_title(title)
        fig.colorbar(im, ax=ax, location='bottom')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, bbox_inches='tight')
    plt.show()

def plot_3d_surface_with_contour(x, t, u, x_lim, time_slice=None, save_path='3d_solution.png',
                                 cmap='jet', offset=-2.0, fontsize=18):
    """
    Render u(x, t) as a 3D surface with a filled-contour projection on the plane z = offset,
    optionally marking a vertical slice at one time, then save and show the figure.

    Parameters:
        x (np.ndarray): 1D spatial grid.
        t (np.ndarray): 1D temporal grid.
        u (np.ndarray): 2D solution array of shape (len(t), len(x)).
        x_lim (tuple): (x_min, x_max) used for the slice frame; only read when time_slice is given.
        time_slice (float, optional): Time at which to draw the vertical slice. Default is None.
        save_path (str): Output file for the figure.
        cmap (str): Colormap shared by the surface and the contour projection.
        offset (float): z-level of the projected contour (also the slice base).
        fontsize (int): Font size of the colorbar tick labels.
    """
    figure = plt.figure(figsize=(4, 3))
    axis3d = figure.add_subplot(111, projection='3d')

    grid_x, grid_t = np.meshgrid(x, t)
    u_lo, u_hi = u.min(), u.max()

    # Opaque surface, plus its filled-contour "shadow" drawn flat at z = offset
    axis3d.plot_surface(grid_x, grid_t, u, alpha=1, cmap=cmap)
    shadow = axis3d.contourf(grid_x, grid_t, u, zdir='z', offset=offset,
                             cmap=cmap, alpha=0.8, levels=500)

    if time_slice is not None:
        _add_time_slice(axis3d, x, t, u, time_slice, x_lim, offset)

    # Tight framing; the z-range starts just below the contour plane so it stays visible
    axis3d.set_xlim([grid_x.min(), grid_x.max()])
    axis3d.set_ylim([grid_t.min(), grid_t.max()])
    axis3d.set_zlim([offset - 0.01, u_hi])
    axis3d.view_init(elev=20, azim=210)
    for hide_ticks in (axis3d.set_xticks, axis3d.set_yticks, axis3d.set_zticks):
        hide_ticks([])

    # Colorbar shows only the extreme values of u
    colorbar = figure.colorbar(shadow, ax=axis3d, shrink=0.5, aspect=10)
    colorbar.ax.tick_params(labelsize=fontsize)
    colorbar.set_ticks([u_lo, u_hi])

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.show()


def _add_time_slice(ax, x, t, u, t_fixed, x_lim, offset, color='k'):
    """
    Draw the profile u(x, t_fixed) and a rectangular frame lying in the plane t = t_fixed.

    The profile comes from the row of ``u`` whose time ``t[idx]`` is closest to ``t_fixed``.
    Both the profile and the frame are drawn at y = t_fixed.

    Parameters:
        ax (Axes3D): Matplotlib 3D axes.
        x (np.ndarray): Spatial grid (flattened before plotting).
        t (np.ndarray): 1D temporal grid.
        u (np.ndarray): 2D solution array; rows are indexed by time.
        t_fixed (float): Time value at which to draw the slice.
        x_lim (tuple): (x_min, x_max) horizontal extent of the frame.
        offset (float): z-coordinate of the frame's bottom edge.
        color (str): Line color for the slice and the frame.
    """
    x_flat = np.asarray(x).reshape(-1)
    idx = np.argmin(np.abs(np.asarray(t) - t_fixed))
    z_top = u.max()

    # Frame: bottom, right, top, left edges, spanning z in [offset, max(u)]
    edges = [
        [[x_lim[0], t_fixed, offset], [x_lim[1], t_fixed, offset]],
        [[x_lim[1], t_fixed, offset], [x_lim[1], t_fixed, z_top]],
        [[x_lim[1], t_fixed, z_top], [x_lim[0], t_fixed, z_top]],
        [[x_lim[0], t_fixed, z_top], [x_lim[0], t_fixed, offset]],
    ]
    ax.add_collection3d(Line3DCollection(edges, colors=color, linewidths=2))

    # Build the constant-t coordinates as float with the flattened x's shape.
    # np.full_like(x, t_fixed) would inherit x's dtype, truncating t_fixed for an
    # integer grid (0.5 -> 0), and would keep x's original, possibly non-1D, shape.
    t_line = np.full(x_flat.shape, t_fixed, dtype=float)
    ax.plot(x_flat, t_line, u[idx, :], color=color, linewidth=2, alpha=0.99)
