import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy.typing as npt
import numpy as np
from scipy.stats.qmc import LatinHypercube
from matplotlib.ticker import ScalarFormatter
from mpl_toolkits.mplot3d import Axes3D  # Required for 3D plots
from mpl_toolkits.mplot3d.art3d import Line3DCollection

def plot_3d_solution_surface(ax, X, T, u_true_test, alpha=0.1):
    """
    Draw the solution u(x, t) as a translucent surface on a 3D axis.

    Parameters:
        ax (Axes3D): Target 3D axis; ``ax.plot_surface`` is called once.
        X (ndarray): Spatial meshgrid coordinates.
        T (ndarray): Temporal meshgrid coordinates.
        u_true_test (ndarray): Solution values on the (X, T) grid.
        alpha (float): Surface opacity passed straight to ``plot_surface``.
    """
    surface_style = {'alpha': alpha}
    ax.plot_surface(X, T, u_true_test, **surface_style)


def plot_collocation_points_1(ax, burgers_solver_swim, time_blocks, offset=-1.8):
    """
    Scatter every fifth collocation point of each time block in 3D, plus a
    faint shadow copy of the same points on the plane z = ``offset``.

    Blocks ``0 .. time_blocks - 2`` are drawn; block ``k`` is placed at
    time ``(k + 1) / time_blocks``.

    Parameters:
        ax (Axes3D): Matplotlib 3D axis.
        burgers_solver_swim: Object exposing ``collocation_point_collection``
            (x-coordinates read from column 0) and ``sol_collection``
            (solution values read from column 0), both indexed by block.
        time_blocks (int): Number of time segments.
        offset (float): Z value of the projection plane.
    """
    stride = 5
    for block in range(1, time_blocks):
        xs = burgers_solver_swim.collocation_point_collection[block - 1][::stride, 0]
        us = burgers_solver_swim.sol_collection[block - 1][::stride, 0]
        ts = np.full_like(xs, block / time_blocks)

        # Points lifted to their solution value.
        ax.scatter(xs, ts, us, color='red', linewidth=2, alpha=0.3)
        # Same points flattened onto the bottom plane.
        floor = np.full_like(us, offset)
        ax.scatter(xs, ts, floor, color='k', linewidth=2, alpha=0.01)


def add_slice_box(ax, x_eval, y_fixed, u_true_test, x_lim, t_eval, offset, color='r'):
    """
    Outline a vertical rectangle in the plane t = ``y_fixed``.

    The rectangle spans ``x_lim[0]..x_lim[1]`` horizontally and
    ``offset..-offset`` vertically (e.g. -1.8..1.8 for the default offset).

    Parameters:
        ax (Axes3D): Matplotlib 3D axis.
        x_eval (ndarray): Spatial evaluation points (currently unused).
        y_fixed (float): Time value at which the rectangle is drawn.
        u_true_test (ndarray): Ground truth solution values (currently unused).
        x_lim (tuple): (x_min, x_max) bounds of the rectangle.
        t_eval (ndarray): Time evaluation array.
        offset (float): Z value of the bottom edge; the top edge is at -offset.
        color (str): Outline color.
    """
    # NOTE(review): the nearest time index is computed but never used below;
    # kept so behavior (including errors on an empty t_eval) is unchanged.
    _nearest_t = np.argmin(np.abs(t_eval - y_fixed))

    bottom, top = offset, -offset
    corners = [
        [x_lim[0], y_fixed, bottom],
        [x_lim[1], y_fixed, bottom],
        [x_lim[1], y_fixed, top],
        [x_lim[0], y_fixed, top],
    ]
    # Connect consecutive corners, closing the loop back to the first one.
    outline = [[corners[k], corners[(k + 1) % 4]] for k in range(4)]
    ax.add_collection3d(Line3DCollection(outline, colors=color, linewidths=2))


def plot_3d_collocation_diagnostics(X, T, u_true_test, burgers_solver_swim,
                                    time_blocks, x_eval, t_eval, x_lim,
                                    slices=(0.0, 0.5, 1.0), offset=-1.8,
                                    save_path='burgers_3d_sol.png',
                                    draw_slices=False):
    """
    Create a 3D plot showing the solution surface and collocation points,
    including bottom-plane projections and (optionally) slice boxes.

    Parameters:
        X (ndarray): Spatial meshgrid.
        T (ndarray): Temporal meshgrid.
        u_true_test (ndarray): Ground truth 2D solution array.
        burgers_solver_swim: Solver object with collocation/solution data.
        time_blocks (int): Number of time segments in domain.
        x_eval (ndarray): Spatial evaluation points (used only for slice boxes).
        t_eval (ndarray): Time evaluation points (used only for slice boxes).
        x_lim (tuple): Spatial limits (min, max) for slice outlines.
        slices (sequence of float): Time values to highlight with slice boxes.
            A tuple default avoids the shared-mutable-default pitfall; any
            iterable of floats is accepted.
        offset (float): Z-axis offset for projections and the z lower limit.
        save_path (str): File path to save the figure.
        draw_slices (bool): If True, draw a slice box at each value in
            ``slices``. Defaults to False, matching the previous behavior
            in which the slice boxes were disabled.
    """
    fig = plt.figure(figsize=(4, 3))
    ax = fig.add_subplot(111, projection='3d')

    # Plot surface and points
    plot_3d_solution_surface(ax, X, T, u_true_test, alpha=0.1)
    plot_collocation_points_1(ax, burgers_solver_swim, time_blocks, offset)

    # Optional slice boxes (off by default)
    if draw_slices:
        for y_fixed in slices:
            add_slice_box(ax, x_eval, y_fixed, u_true_test, x_lim, t_eval, offset)

    # Formatting
    ax.set_xlim([X.min(), X.max()])
    ax.set_ylim([T.min(), T.max()])
    ax.set_zlim([offset, u_true_test.max() + 0.1])
    ax.view_init(elev=20, azim=210)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])

    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.show()



def plot_probability_distribution(ax, burgers_solver_swim, n_sample, collocation_points_probabilities, fontsize = 14,
                                  extent=(0, 1, -1, 1)):
    """
    Plot a heatmap showing the resampling probability distribution over space-time.

    The flat ``gradient_collection_sample`` is reshaped to rows of length
    ``n_sample`` (one row per time step), then transposed so that rows map
    to space and columns to time in the image.

    Parameters:
        ax (Axes): Matplotlib axis to plot on. The colorbar is attached to
            ``ax.figure``, so ``ax`` need not belong to the current figure.
        burgers_solver_swim: Object with gradient_collection_sample.
        n_sample (int): Number of spatial samples per time step.
        collocation_points_probabilities (callable): Maps the 2D array of
            absolute gradients to sampling probabilities.
        fontsize (int): Font size for title, labels, and tick labels.
        extent (sequence of 4 floats): (t_min, t_max, x_min, x_max) image
            extent. Defaults to (0, 1, -1, 1), the previously hard-coded domain.
    """
    grad = np.abs(np.asarray(burgers_solver_swim.gradient_collection_sample)).reshape(-1, n_sample).T
    prob = collocation_points_probabilities(grad)

    # NOTE(review): imshow's default origin='upper' puts row 0 at the top
    # (x = extent[3]); confirm this matches the ordering of the samples.
    # norm='log' needs a matplotlib version that accepts string norms.
    im = ax.imshow(prob, extent=list(extent), aspect=0.3, cmap='jet', norm='log')
    # Use the figure that owns ``ax`` rather than plt.gcf().
    cbar = ax.figure.colorbar(im, ax=ax, shrink=0.4, aspect=12)
    cbar.ax.tick_params(labelsize=fontsize)

    ax.set_title('Probability distribution', fontsize=fontsize)
    ax.set_xlabel('t', fontsize=fontsize)
    ax.set_ylabel('x', fontsize=fontsize)
    ax.tick_params(axis='both', labelsize=fontsize)


def plot_sampled_collocation_points(ax, burgers_solver_swim, time_blocks, 
                                    fontsize = 14, alpha=0.005):
    """
    Scatter the x-coordinates of all collocation points against the time of
    the block they belong to, on the fixed window t in [0, 1], x in [-1, 1].

    Parameters:
        ax (Axes): Matplotlib axis to plot on.
        burgers_solver_swim: Object with ``collocation_point_collection``;
            column 0 of each block is used as x.
        time_blocks (int): Number of temporal blocks; blocks
            ``0 .. time_blocks - 2`` are drawn at t = (k + 1) / time_blocks.
        fontsize (int): Font size for title, axis labels and ticks.
        alpha (float): Marker opacity.
    """
    for block in range(1, time_blocks):
        xs = burgers_solver_swim.collocation_point_collection[block - 1][:, 0]
        ts = np.full_like(xs, block / time_blocks)
        ax.scatter(ts, xs.reshape(-1), color='k', linewidth=1, alpha=alpha)

    ax.set_xlim(0, 1)
    ax.set_ylim(-1, 1)
    ax.set_aspect(0.3)

    label_kw = {'fontsize': fontsize}
    ax.set_title('Sampled points', **label_kw)
    ax.set_xlabel('t', **label_kw)
    ax.set_ylabel('x', **label_kw)
    ax.tick_params(axis='both', labelsize=fontsize)


def plot_sampling_diagnostics(burgers_solver_swim, time_blocks, n_sample, collocation_points_probabilities, 
                              fontsize = 14, figsize=(3.5, 4), alpha=0.005):
    """
    Build a two-row figure (probability heatmap on top, sampled points
    below), save it as ``burgers_sampling.pdf`` and ``burgers_sampling.png``
    in the working directory, and show it.

    Parameters:
        burgers_solver_swim: Object with sampling and gradient data.
        time_blocks (int): Number of time intervals.
        n_sample (int): Number of samples per time block.
        collocation_points_probabilities (callable): Probability mapping function.
        fontsize (int): Font size forwarded to both panels.
        figsize (tuple): Figure size in inches.
        alpha (float): Marker opacity for the sampled-points panel.
    """
    fig, axes = plt.subplots(2, 1, figsize=figsize, constrained_layout=True)
    prob_ax, points_ax = axes

    plot_probability_distribution(prob_ax, burgers_solver_swim, n_sample,
                                  collocation_points_probabilities, fontsize=fontsize)
    plot_sampled_collocation_points(points_ax, burgers_solver_swim, time_blocks,
                                    alpha=alpha, fontsize=fontsize)

    for out_path in ('burgers_sampling.pdf', 'burgers_sampling.png'):
        fig.savefig(out_path)
    plt.show()


def plot_collocation_points(burgers_solver_swim, time_blocks, ax):
    """
    Scatter collocation-point x-coordinates against their block's time.

    Blocks ``0 .. time_blocks - 2`` are drawn; block ``k`` sits at
    t = (k + 1) / time_blocks. Markers are black and nearly transparent.

    Parameters:
        burgers_solver_swim: Object with ``collocation_point_collection``;
            column 0 of each block is used as x.
        time_blocks (int): Number of time blocks in the domain.
        ax (Axes): Matplotlib axis to plot on.
    """
    for block in range(1, time_blocks):
        xs = burgers_solver_swim.collocation_point_collection[block - 1][:, 0]
        ts = np.full_like(xs, block / time_blocks)
        ax.scatter(ts, xs.reshape(-1), color='k', linewidth=1, alpha=0.008)
    ax.set_title("Collocation Points")


def plot_gradient_magnitudes(burgers_solver_swim, time_blocks, ax):
    """
    Plot the absolute value of gradient magnitudes at collocation points.

    All time blocks are scattered with one shared colour scale (vmin/vmax
    over every plotted block). Without it, each block would be normalised
    independently, and the returned scatter (the last block) would give a
    colorbar that does not match the colours of the other blocks.

    Parameters:
        burgers_solver_swim: Object containing ``collocation_point_collection``
            and ``gradient_collection_collocation``; column 0 of each block is used.
        time_blocks (int): Number of time blocks in the domain; blocks
            ``0 .. time_blocks - 2`` are drawn at t = (k + 1) / time_blocks.
        ax (Axes): Matplotlib axis to plot on.
    Returns:
        PathCollection: Scatter of the last plotted block, for colorbar usage.
            Its colour limits cover all plotted blocks.
    Raises:
        ValueError: If ``time_blocks < 2`` (nothing would be plotted).
    """
    n_plotted = time_blocks - 1
    if n_plotted < 1:
        raise ValueError(f"time_blocks must be >= 2 to plot anything, got {time_blocks}")

    positions = []
    magnitudes = []
    for t_s in range(n_plotted):
        positions.append(burgers_solver_swim.collocation_point_collection[t_s][:, 0])
        magnitudes.append(np.abs(burgers_solver_swim.gradient_collection_collocation[t_s][:, 0]))

    # Shared colour limits across all blocks; fall back to autoscaling when empty.
    stacked = np.concatenate([np.ravel(m) for m in magnitudes])
    vmin, vmax = (stacked.min(), stacked.max()) if stacked.size else (None, None)

    sp = None
    for t_s, (pts, mags) in enumerate(zip(positions, magnitudes)):
        time_val = (t_s + 1) / time_blocks
        sp = ax.scatter(
            np.full_like(pts, time_val),
            pts.reshape(-1),
            c=mags,
            vmin=vmin, vmax=vmax,
            linewidth=1, alpha=1
        )
    ax.set_title("Gradient Magnitudes")
    return sp


def plot_sampling_probabilities(burgers_solver_swim, time_blocks, ax, collocation_points_probabilities):
    """
    Plot sampling probabilities derived from gradients.

    Probabilities are computed per block. All blocks are drawn with one
    shared colour scale so the returned scatter's colorbar is valid for
    every block, not just the last one.

    Parameters:
        burgers_solver_swim: Object containing ``collocation_point_collection``
            and ``gradient_collection_collocation``; column 0 of each block is used.
        time_blocks (int): Number of time blocks in the domain; blocks
            ``0 .. time_blocks - 2`` are drawn at t = (k + 1) / time_blocks.
        ax (Axes): Matplotlib axis to plot on.
        collocation_points_probabilities (callable): Function mapping a 1D
            gradient array to probabilities; called once per plotted block.
    Returns:
        PathCollection: Scatter of the last plotted block, for colorbar usage.
            Its colour limits cover all plotted blocks.
    Raises:
        ValueError: If ``time_blocks < 2`` (nothing would be plotted).
    """
    n_plotted = time_blocks - 1
    if n_plotted < 1:
        raise ValueError(f"time_blocks must be >= 2 to plot anything, got {time_blocks}")

    positions = []
    colors = []
    for t_s in range(n_plotted):
        positions.append(burgers_solver_swim.collocation_point_collection[t_s][:, 0])
        grad_pts = burgers_solver_swim.gradient_collection_collocation[t_s][:, 0]
        colors.append(np.abs(collocation_points_probabilities(grad_pts)))

    # Shared colour limits across all blocks; fall back to autoscaling when empty.
    stacked = np.concatenate([np.ravel(c) for c in colors])
    vmin, vmax = (stacked.min(), stacked.max()) if stacked.size else (None, None)

    sp = None
    for t_s, (pts, prob_vals) in enumerate(zip(positions, colors)):
        time_val = (t_s + 1) / time_blocks
        sp = ax.scatter(
            np.full_like(pts, time_val),
            pts.reshape(-1),
            c=prob_vals,
            vmin=vmin, vmax=vmax,
            linewidth=0.1, alpha=0.2
        )
    ax.set_title("Sampling Probabilities")
    return sp


def plot_collocation_diagnostics(burgers_solver_swim, time_blocks, collocation_points_probabilities, figsize=(9, 2)):
    """
    Show three side-by-side panels: raw collocation point positions,
    gradient magnitudes (with colorbar), and sampling probabilities
    (with colorbar). The figure is displayed, not saved.

    Parameters:
        burgers_solver_swim: Object with collocation and gradient data.
        time_blocks (int): Number of time segments.
        collocation_points_probabilities (callable): Function to compute sampling probabilities.
        figsize (tuple): Size of the overall figure.
    """
    fig = plt.figure(figsize=figsize)
    colorbar_style = dict(shrink=0.5, aspect=10)

    # Left: where the points are.
    points_ax = fig.add_subplot(131)
    plot_collocation_points(burgers_solver_swim, time_blocks, points_ax)

    # Middle: |gradient| at each point.
    grad_ax = fig.add_subplot(132)
    grad_scatter = plot_gradient_magnitudes(burgers_solver_swim, time_blocks, grad_ax)
    fig.colorbar(grad_scatter, ax=grad_ax, **colorbar_style)

    # Right: resulting sampling probabilities.
    prob_ax = fig.add_subplot(133)
    prob_scatter = plot_sampling_probabilities(burgers_solver_swim, time_blocks, prob_ax,
                                               collocation_points_probabilities)
    fig.colorbar(prob_scatter, ax=prob_ax, **colorbar_style)

    plt.tight_layout()
    plt.show()



def plot_3d_surface_with_slices(X, T, u_data, x_eval, t_eval, x_lim, output_file='burgers_3d_sol.png'):
    """
    Render u(x, t) as a translucent 3D surface with a filled contour on the
    floor (z = -1.8) and red slice curves plus black frames at t = 1, 0.5, 0.
    The figure is saved to ``output_file`` and shown.

    Parameters:
        X (ndarray): 2D array of x coordinates.
        T (ndarray): 2D array of t coordinates.
        u_data (ndarray): 2D array of solution values u(x, t).
        x_eval (ndarray): 1D array of x values for slicing.
        t_eval (ndarray): 1D array of t values corresponding to T.
        x_lim (tuple): Limits of x-axis (min, max).
        output_file (str): Path to save the output figure.
    """
    base_z = -1.8

    fig = plt.figure(figsize=(4, 3))
    ax = fig.add_subplot(111, projection='3d')

    ax.plot_surface(X, T, u_data, alpha=0.4)
    filled = ax.contourf(X, T, u_data, zdir='z', offset=base_z, cmap='jet', alpha=0.8, levels=500)

    for t_val in (1, 0.5, 0):
        add_slice(ax, x_eval, t_eval, t_val, u_data, x_lim, base_z)

    ax.set_xlim([X.min(), X.max()])
    # Set the t-range, then flip it by reading back the axis' own limits
    # (3D axes may pad limits, so the read-back values are what get reversed).
    ax.set_ylim([T.min(), T.max()])
    ax.set_ylim(ax.get_ylim()[::-1])
    ax.set_zlim([base_z - 0.01, u_data.max()])
    ax.view_init(elev=20, azim=210)
    for hide_ticks in (ax.set_xticks, ax.set_yticks, ax.set_zticks):
        hide_ticks([])

    cbar = fig.colorbar(filled, ax=ax, shrink=0.5, aspect=10)
    cbar.ax.tick_params(labelsize=18)
    cbar.set_ticks([u_data.min(), u_data.max()])

    plt.tight_layout()
    plt.savefig(output_file, bbox_inches='tight', pad_inches=0)
    plt.show()


def add_slice(ax, x_eval, t_eval, t_fixed, u_data, x_lim, offset, color='r'):
    """
    Draw a black rectangular frame in the plane t = ``t_fixed`` (from
    z = ``offset`` up to z = 1) and the solution curve u(x, t_fixed).

    The curve uses the row of ``u_data`` whose entry in ``t_eval`` is
    closest to ``t_fixed``.

    Parameters:
        ax (Axes3D): The 3D axes object to plot on.
        x_eval (ndarray): x values of the curve.
        t_eval (ndarray): 1D array of t values indexing the rows of u_data.
        t_fixed (float): The time value where the slice is made.
        u_data (ndarray): 2D solution array; ``u_data[i, :]`` is the curve at t_eval[i].
        x_lim (tuple): (xmin, xmax) extent of the frame.
        offset (float): Z value of the frame's bottom edge.
        color (str): Colour of the solution curve (the frame is always black).
    """
    nearest_row = np.argmin(np.abs(t_eval - t_fixed))

    x_lo, x_hi = x_lim[0], x_lim[1]
    corners = [
        [x_lo, t_fixed, offset],
        [x_hi, t_fixed, offset],
        [x_hi, t_fixed, 1],
        [x_lo, t_fixed, 1],
    ]
    # Close the loop: each corner connects to the next, the last to the first.
    frame = [[corners[k], corners[(k + 1) % 4]] for k in range(4)]
    ax.add_collection3d(Line3DCollection(frame, colors='k', linewidths=2))

    ax.plot(
        x_eval.reshape(-1),
        np.full_like(x_eval, t_fixed),
        u_data[nearest_row, :],
        color=color,
        linewidth=2,
        alpha=0.9
    )

def compute_relative_l2_error(u_true, u_pred):
    """
    Compute the relative L2 error ||u_true - u_pred||_2 / ||u_true||_2.

    Both inputs are flattened before comparison, so a column vector (n, 1),
    a flat vector (n,), or a 2D grid can be compared with a prediction of
    any shape holding the same number of elements. Flattening also ensures
    the Euclidean norm is used. The previous ``[:, None]`` broadcast turned a
    1D ``u_true`` into an (n, n) difference matrix, and ``norm(..., 2)`` on a
    matrix is the spectral norm, so the result was wrong.

    Parameters:
        u_true (array_like): Ground truth solution.
        u_pred (array_like): Predicted solution (any shape, same size as u_true).

    Returns:
        float: Relative L2 error. If ``u_true`` is all zeros the division
        yields inf/nan, as before.

    Raises:
        ValueError: If the inputs contain different numbers of elements.
    """
    true_flat = np.ravel(np.asarray(u_true))
    pred_flat = np.ravel(np.asarray(u_pred))
    if true_flat.shape != pred_flat.shape:
        raise ValueError(
            f"u_true and u_pred must have the same number of elements, "
            f"got {true_flat.size} and {pred_flat.size}"
        )
    return np.linalg.norm(true_flat - pred_flat) / np.linalg.norm(true_flat)


def plot_spatiotemporal_solutions(u_true, u_model, x_space, filename="burgers_1.pdf", title='Frozen-PINN-swim'):
    """
    Show ground truth, prediction and pointwise absolute error as three
    side-by-side heatmaps over t in [0, 1] (horizontal) and x (vertical),
    each with a bottom colorbar, and save the figure to ``filename``.

    The ground-truth panel gets dashed lines at t = 0.25, 0.5, 0.75; the
    prediction panel gets dotted lines at the 8 interior points of
    ``np.linspace(0, 1, 10)``.

    Parameters:
        u_true (ndarray): Ground truth solution (transposed for display).
        u_model (ndarray): Model-predicted solution, same shape as u_true.
        x_space (ndarray): Spatial grid; its min/max set the vertical extent.
        filename (str): Output filename for the figure.
        title (str): Title of the prediction panel.
    """
    fontsize = 14
    fig, ax = plt.subplots(1, 3, figsize=(7, 3), constrained_layout=True)
    extent = [0, 1, np.min(x_space), np.max(x_space)]

    fields = (u_true.T, u_model.T, np.abs(u_model - u_true).T)
    images = [
        panel.imshow(field, extent=extent, origin='lower', aspect=0.3)
        for panel, field in zip(ax, fields)
    ]

    for panel in ax:
        panel.set_xlabel('t', fontsize=fontsize)
    ax[0].set_ylabel('x', fontsize=fontsize)

    for quarter in (0.25, 0.5, 0.75):
        ax[0].axvline(x=quarter, color='k', linestyle='--', linewidth=2)

    for s_t in np.linspace(0, 1, 10)[1:-1]:
        ax[1].axvline(x=s_t, color='gray', linestyle='dotted', linewidth=3)

    colorbar_specs = (("Ground truth", True), ("Prediction", True), ("Error", False))
    for img, panel, (cbar_label, use_sci) in zip(images, ax, colorbar_specs):
        add_colorbar(fig, img, panel, label=cbar_label, scientific=use_sci)

    for panel, heading in zip(ax, ('Ground truth', title, 'Absolute error')):
        panel.set_title(heading, fontsize=fontsize)

    fig.savefig(filename)


def add_colorbar(fig, img, ax, label="", scientific=False):
    """
    Attach a horizontal colorbar below ``ax`` with at most ~2 tick bins.

    With ``scientific=True`` a ``ScalarFormatter`` (scientific mode, power
    limits (-8, 8), no MathText) formats the ticks. Otherwise the ticks use
    the printf format ``'%.0e'``, i.e. exponent notation with no decimals.

    Parameters:
        fig (Figure): Matplotlib figure object.
        img (AxesImage): The image to attach the colorbar to.
        ax (Axes): Axis to place the colorbar under.
        label (str): Accepted for API compatibility.
            NOTE(review): currently not applied to the colorbar.
        scientific (bool): Choose the ScalarFormatter branch described above.
    """
    if not scientific:
        tick_format = '%.0e'
    else:
        tick_format = ScalarFormatter()
        tick_format.set_scientific(True)
        tick_format.set_useMathText(False)
        tick_format.set_powerlimits((-8, 8))

    cbar = fig.colorbar(img, ax=ax, location='bottom', format=tick_format, fraction=0.049)
    cbar.locator = ticker.MaxNLocator(nbins=2)
    cbar.update_ticks()


def plot_temporal_slices(u_true, u_model, x_eval, filename="burgers_2.pdf"):
    """
    Compare model and ground truth along x at three fixed time rows
    (indices 25, 50, 75, titled t = 0.25, 0.5, 0.75), one square panel each,
    with a shared legend on top; save to ``filename``.

    Parameters:
        u_true (ndarray): Ground truth solution; ``u_true[i, :]`` is plotted.
        u_model (ndarray): Predicted solution, indexed like u_true.
        x_eval (ndarray): Evaluation grid in space.
        filename (str): Output filename for the figure.
    """
    fontsize = 14
    fig, axes = plt.subplots(1, 3, figsize=(7, 3), constrained_layout=True)

    # (row index, time shown in the title). NOTE(review): assumes row k is
    # t = k / 100 (e.g. 101 time steps on [0, 1]) — confirm against data.
    slice_spec = ((25, 0.25), (50, 0.50), (75, 0.75))

    for panel, (row, t_value) in zip(axes, slice_spec):
        panel.plot(x_eval, u_true[row, :], 'b-', linewidth=2, label='Ground truth')
        panel.plot(x_eval, u_model[row, :], 'r--', linewidth=2, label='Frozen-PINN-swim (resampling)')
        panel.set_xlabel('$x$', fontsize=fontsize)
        panel.set_ylabel('$u(t,x)$', fontsize=fontsize)
        panel.set_title(f'$t = {t_value}$', fontsize=fontsize)
        panel.axis('square')
        panel.set_xlim([-1.1, 1.1])
        panel.set_ylim([-1.1, 1.1])

    handles, labels = axes[1].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', ncol=2, fontsize=fontsize, frameon=False)
    fig.savefig(filename, bbox_inches='tight')


def plot(x:npt.NDArray[np.float64], u:npt.NDArray[np.float64], 
        title:str=None, savefig:bool=False, fontsize:int=14, 
        timesteps:tuple=(0, 30, 60, 99), figsize:tuple=(6,2), 
        cmap_offset:float=0., figname:str='fig.pdf',
        marker_size:float=0.4, extent:bool=True):
    """ Make a 1x4 scatter plot of the solution at four time steps.

    Args:
        x (npt.NDArray[np.float64]): Spatial data points (dimensions: n_points * n_dim);
            columns 0 and 1 are used as plot coordinates.
        u (npt.NDArray[np.float64]): Solution; ``u[k]`` is the colour array for time index k.
        title (str): Optional figure title.
        savefig (bool, optional): Save the figure to ``figname``. Defaults to False.
        fontsize (int, optional): Font size for ticks, titles and colorbar. Defaults to 14.
        timesteps (sequence of int): Time indices shown in the four panels (first four used).
            A tuple default avoids the shared-mutable-default pitfall; lists are accepted.
        figsize (tuple, optional): Figure size. Defaults to (6,2).
        cmap_offset (float, optional): Shrinks the shared colour range by this amount
            at both ends (only when ``extent`` is True). Defaults to 0.
        figname (str, optional): Name of the saved figure. Defaults to 'fig.pdf'.
        marker_size (float, optional): Scatter marker size. Defaults to 0.4.
        extent (bool, optional): If True, all panels share colour limits; otherwise
            each panel autoscales. The colorbar always reflects panel 0.

    Raises:
        IndexError: If ``timesteps`` has fewer than four entries.
    """
    fig, ax = plt.subplots(1, 4, figsize=figsize, sharex=True, sharey=True, constrained_layout=True)
    cmap = 'jet'
    # NOTE(review): titles are fixed and match the default timesteps only.
    panel_titles = ('t = 0', 't = 0.3', 't = 0.6', 't = 1')

    color_limits = {}
    if extent:
        # NOTE(review): limits come from u[:, 0] (column 0 across all time
        # indices), not from a single time slice — confirm this is intended.
        color_limits = {'vmin': np.min(u[:, 0] + cmap_offset),
                        'vmax': np.max(u[:, 0] - cmap_offset)}

    images = []
    for i, axis in enumerate(ax):
        images.append(axis.scatter(x[:, 0], x[:, 1], c=u[timesteps[i]], cmap=cmap,
                                   s=marker_size, **color_limits))
        axis.tick_params(axis='both', labelsize=fontsize)
        axis.set_title(panel_titles[i], fontsize=fontsize)

    fig.supxlabel(r'$x_1$')
    fig.supylabel(r'$x_2$')
    cbar = fig.colorbar(images[0], ax=ax, location='right', aspect=8)
    cbar.ax.tick_params(labelsize=fontsize)
    if title is not None:
        fig.suptitle(title, fontsize=fontsize)
    if savefig:
        fig.savefig(figname)
    plt.show()


def plot_minimal(x:npt.NDArray[np.float64], u:npt.NDArray[np.float64], 
        title:str=None, savefig:bool=True, fontsize:int=14, 
        timesteps:tuple=(0, 30, 60, 99), figsize:tuple=(6,2), 
        cmap_offset:float=0., figname:str='fig.pdf',
        marker_size:float=0.4):
    """ Make a tick-free 1x4 scatter plot of the solution with shared colour limits.

    Args:
        x (npt.NDArray[np.float64]): Spatial data points (dimensions: n_points * n_dim);
            columns 0 and 1 are used as plot coordinates.
        u (npt.NDArray[np.float64]): Solution; ``u[k]`` is the colour array for time index k.
        title (str): Optional figure title.
        savefig (bool, optional): Save the figure to ``figname``. Defaults to True.
            The old body forced ``savefig = True`` and ignored the argument. The
            default now states that effective behavior, and ``savefig=False`` is
            respected.
        fontsize (int, optional): Font size for titles and colorbar. Defaults to 14.
        timesteps (sequence of int): Time indices shown in the four panels (first four used).
        figsize (tuple, optional): Figure size. Defaults to (6,2).
        cmap_offset (float, optional): Shrinks the shared colour range by this amount
            at both ends. Defaults to 0.
        figname (str, optional): Name of the saved figure. Defaults to 'fig.pdf'.
        marker_size (float, optional): Scatter marker size. Defaults to 0.4.

    Raises:
        IndexError: If ``timesteps`` has fewer than four entries.
    """
    fig, ax = plt.subplots(1, 4, figsize=figsize, sharex=True, sharey=True, constrained_layout=True)
    cmap = 'jet'
    # NOTE(review): titles are fixed and match the default timesteps only.
    panel_titles = ('t = 0', 't = 0.3', 't = 0.6', 't = 1')

    # Common colour limits for all panels.
    # NOTE(review): computed from u[:, 0] (column 0 across all time indices).
    vmin = np.min(u[:, 0] + cmap_offset)
    vmax = np.max(u[:, 0] - cmap_offset)

    images = []
    for i, axis in enumerate(ax):
        images.append(axis.scatter(x[:, 0], x[:, 1], c=u[timesteps[i]], cmap=cmap,
                                   vmin=vmin, vmax=vmax, s=marker_size))
        axis.set_title(panel_titles[i], fontsize=fontsize)

    fig.supxlabel(r'$x_1$')
    fig.supylabel(r'$x_2$')
    cbar = fig.colorbar(images[0], ax=ax, location='right', aspect=8)
    cbar.ax.tick_params(labelsize=fontsize)

    for axis in ax:
        axis.set_xticks([])
        axis.set_yticks([])

    if title is not None:
        fig.suptitle(title, fontsize=fontsize)
    if savefig:
        fig.savefig(figname)
    plt.show()

def plot_error(x:npt.NDArray[np.float64], u_true:npt.NDArray[np.float64], u_nn:npt.NDArray[np.float64], 
               title:str, timesteps:tuple=(0, 30, 60, 99), figsize:tuple=(8,3), fontsize:int=14, 
               savefig:bool=False, figname:str='fig.pdf',
               marker_size:float=0.4):
    """ Plot the pointwise absolute error |u_true - u_nn| at four time steps.

    Each panel autoscales its own colours and gets its own horizontal
    colorbar, with scientific tick labels below it.

    Args:
        x (npt.NDArray[np.float64]): Spatial data points (dimensions: n_points * n_dim);
            columns 0 and 1 are used as plot coordinates.
        u_true (npt.NDArray[np.float64]): True solution; ``u_true[k]`` is time index k.
        u_nn (npt.NDArray[np.float64]): Neural network solution, indexed like u_true.
        title (str): Figure title.
        timesteps (sequence of int): Time indices shown in the four panels (first four used).
            A tuple default avoids the shared-mutable-default pitfall; lists are accepted.
        figsize (tuple, optional): Figure size. Defaults to (8,3).
        fontsize (int, optional): Font size for ticks, labels and colorbars. Defaults to 14.
        savefig (bool, optional): Save the figure to ``figname``. Defaults to False.
        figname (str, optional): Name of the saved figure. Defaults to 'fig.pdf'.
        marker_size (float, optional): Scatter marker size. Defaults to 0.4.

    Raises:
        IndexError: If ``timesteps`` has fewer than four entries.
    """
    fig, ax = plt.subplots(1, 4, figsize=figsize, sharex=True, sharey=True, constrained_layout=True)
    cmap = 'jet'
    # NOTE(review): titles are fixed and match the default timesteps only.
    panel_titles = ('t = 0', 't = 0.3', 't = 0.6', 't = 1')

    for i, axis in enumerate(ax):
        step = timesteps[i]
        abs_err = np.abs(u_true[step] - u_nn[step])
        img = axis.scatter(x[:, 0], x[:, 1], c=abs_err, cmap=cmap, s=marker_size)
        axis.tick_params(axis='both', labelsize=fontsize)
        axis.set_xlabel(r'$x_1$', fontsize=fontsize)
        axis.set_title(panel_titles[i], fontsize=fontsize)

        cbar = fig.colorbar(img, ax=axis, location='bottom', aspect=8)
        # Horizontal colorbar: ticks live on its x-axis; force a shared exponent.
        cbar.ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
        cbar.ax.tick_params(labelsize=fontsize)
        cbar.ax.xaxis.get_offset_text().set_fontsize(fontsize)

    fig.supylabel(r'$x_2$', y=0.65)
    fig.suptitle(title, fontsize=fontsize)
    if savefig:
        fig.savefig(figname)
    plt.show()


def sample_boundary_lhs(d, n_samples, bounds=(-1, 1)):
    """
    Sample points on the boundary of the hypercube [lower, upper]^d.

    A stratified (Latin-Hypercube-style) sample is first drawn in the cube
    using the global ``np.random`` state. Then, for each point, one random
    coordinate is snapped to a randomly chosen face (``lower`` or
    ``upper``). The snapped coordinate no longer follows the stratification.

    Parameters:
        d (int): Number of dimensions of the hypercube.
        n_samples (int): Number of points to sample.
        bounds (tuple): (lower, upper) bounds of the hypercube (default (-1, 1)).

    Returns:
        np.ndarray: An (n_samples, d) array of points on the boundary.
        np.ndarray: One (dimension, side) pair per point, identifying the face.
    """
    low, high = bounds

    # One stratified column per dimension: a random permutation of the
    # n strata plus uniform jitter inside each stratum, scaled to [0, 1).
    unit = np.zeros((n_samples, d))
    for col in range(d):
        strata = np.random.permutation(n_samples)
        jitter = np.random.uniform(size=n_samples)
        unit[:, col] = (strata + jitter) / n_samples

    scaled = low + (high - low) * unit

    def _snap_to_face(row):
        # Pick the face first (dimension, then side), preserving RNG call order.
        face_dim = np.random.randint(d)
        face_side = np.random.choice([low, high])
        snapped = row.copy()
        snapped[face_dim] = face_side
        return snapped, (face_dim, face_side)

    snapped_pairs = [_snap_to_face(row) for row in scaled]
    points = np.array([pt for pt, _ in snapped_pairs])
    labels = np.array([lab for _, lab in snapped_pairs])
    return points, labels


def sample_boundary_lhs_ball(d, n_samples):
    """
    Sample points uniformly on the surface of the d-dimensional unit sphere.

    Despite the function name, no Latin Hypercube stratification is applied.
    Standard-normal vectors drawn from the global ``np.random`` state are
    normalised to unit length, which gives a uniform distribution on the
    sphere. A previous version also drew Latin Hypercube samples and then
    discarded them. That dead computation has been removed; the output is
    unchanged because those samples used a separate RNG.

    Parameters:
        d (int): Number of dimensions of the unit ball.
        n_samples (int): Number of points to sample.

    Returns:
        np.ndarray: An (n_samples, d) array of points with unit Euclidean norm.
    """
    normal_samples = np.random.normal(size=(n_samples, d))
    norms = np.linalg.norm(normal_samples, axis=1, keepdims=True)

    # Normalize each row to lie on the unit sphere.
    boundary_points = normal_samples / norms

    return boundary_points

def sample_interior_lhs_ball(d, n_samples):
    """
    Sample points inside a d-dimensional unit ball using
    Latin Hypercube Sampling.

    A single Latin Hypercube design in [0, 1)^(d+1) drives both the direction
    and the radius of every point:

    * the first d columns go through the inverse standard-normal CDF and are
      normalised, giving a direction on the unit sphere;
    * the last column U gives the radius R = U**(1/d). Since P(R <= r) = r**d
      for a uniform ball, this yields uniform density while keeping the
      radial shells stratified.

    Randomness is drawn from the global ``np.random`` state, so
    ``np.random.seed`` makes the result reproducible.

    Parameters:
        d (int): Number of dimensions of the unit ball.
        n_samples (int): Number of points to sample.

    Returns:
        np.ndarray: An (n_samples, d) array of sampled points inside the unit ball.
    """
    from scipy.stats import norm

    # Latin Hypercube design in [0, 1)^(d+1): one stratum per point per axis.
    lhs = np.empty((n_samples, d + 1))
    for i in range(d + 1):
        perm = np.random.permutation(n_samples)
        lhs[:, i] = (perm + np.random.uniform(size=n_samples)) / n_samples

    # Directions: map the first d columns to a normal distribution. Clip away
    # from 0 and 1, where the inverse CDF diverges.
    eps = np.finfo(float).eps
    normal_samples = norm.ppf(np.clip(lhs[:, :d], eps, 1.0 - eps))
    norms = np.linalg.norm(normal_samples, axis=1, keepdims=True)
    unit_vectors = normal_samples / norms

    # Radii: inverse-CDF transform of the last (stratified) column, shape (n, 1).
    radii = lhs[:, d:] ** (1.0 / d)

    # Scale unit vectors by radii to obtain uniform distribution inside the ball.
    return unit_vectors * radii

def visualize_samples(samples):
    """
    Draw a 3-D scatter plot of the given samples (e.g. points on a unit sphere).

    If the second axis of ``samples`` is not of length 3, a message is printed
    and nothing is drawn.

    Parameters:
        samples (np.ndarray): Array of shape (n, 3) holding the points to draw.
    """
    n_cols = samples.shape[1]
    if n_cols == 3:
        xs, ys, zs = samples[:, 0], samples[:, 1], samples[:, 2]

        figure = plt.figure(figsize=(3, 3))
        axes3d = figure.add_subplot(111, projection='3d')
        axes3d.scatter(xs, ys, zs, c='b', marker='o', alpha=0.2)

        axis_setters = (axes3d.set_xlabel, axes3d.set_ylabel, axes3d.set_zlabel)
        for setter, text in zip(axis_setters, ("X", "Y", "Z")):
            setter(text)

        plt.show()
    else:
        print("Visualization only supported for 3D samples.")

def plot_boundary_points_3d(points, labels):
    """
    Scatter boundary points in 3-D, one colour per distinct boundary label.

    Parameters:
        points (np.ndarray): Sampled boundary points; the first three columns
            are used as x, y and z.
        labels (np.ndarray): Per-point boundary labels. Rows are compared
            element-wise, so every distinct row forms one colour group.
    """
    figure = plt.figure(figsize=(3,3))
    axes3d = figure.add_subplot(111, projection='3d')

    # Spread the tab10 colormap evenly over the distinct label rows.
    distinct = np.unique(labels, axis=0)
    palette = plt.cm.tab10(np.linspace(0, 1, len(distinct)))

    for colour, label in zip(palette, distinct):
        in_group = (labels == label).all(axis=1)
        group = points[in_group]
        axes3d.scatter(
            group[:, 0], group[:, 1], group[:, 2],
            label=f"Boundary {label[0]}: {label[1]}",
            color=colour, s=30
        )

    for setter, text in ((axes3d.set_xlabel, 'X'),
                         (axes3d.set_ylabel, 'Y'),
                         (axes3d.set_zlabel, 'Z')):
        setter(text)

    # Each scatter carries a legend label, but no legend is drawn.
    plt.show()
    
def sample_interior_lhs(d, n_samples, bounds=(-1, 1)):
    """
    Draw Latin Hypercube samples from the axis-aligned box spanned by ``bounds``.

    Each axis is divided into ``n_samples`` equal strata. A random permutation
    assigns one stratum per point, and a uniform jitter places the point
    inside it. Randomness comes from the global ``np.random`` state.

    Parameters:
        d (int): Dimension of the hypercube.
        n_samples (int): Number of points to draw.
        bounds (tuple): ``(lower, upper)`` applied to every axis
            (default is (-1, 1)).

    Returns:
        np.ndarray: An (n_samples, d) array of sampled points.
    """
    low, high = bounds

    unit_cube = np.empty((n_samples, d))
    for axis in range(d):
        strata = np.random.permutation(n_samples)
        jitter = np.random.uniform(size=n_samples)
        unit_cube[:, axis] = (strata + jitter) / n_samples

    # Affine map from the unit cube onto the requested box.
    return low + (high - low) * unit_cube

# A former [0, 1]^d-only variant of ``sample_interior_lhs`` was kept here as an
# inert string literal. It is exactly equivalent to
# ``sample_interior_lhs(d, n_samples, bounds=(0, 1))`` and has been removed.

def plot_interior_points(data_interior):
    """
    Draw a 3-D scatter plot of interior sample points.

    Only the first three columns of ``data_interior`` are plotted, as x, y
    and z respectively.

    Parameters:
        data_interior (np.ndarray): Sampled points with at least 3 columns.
    """
    figure = plt.figure(figsize=(3,3))
    axes3d = figure.add_subplot(111, projection='3d')

    coords = [data_interior[:, k] for k in range(3)]
    axes3d.scatter(*coords,
                   color='blue', s=30, alpha=0.7, label='Sampled Points')

    axes3d.set_title('Sampled Points in the Interior of a 3D Unit Cube')
    axis_setters = (axes3d.set_xlabel, axes3d.set_ylabel, axes3d.set_zlabel)
    for setter, text in zip(axis_setters, ('X-axis', 'Y-axis', 'Z-axis')):
        setter(text)
    axes3d.grid(True)
    axes3d.legend()

    plt.show()