from collections.abc import Sequence
from typing import Optional

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import numpy as np
import numpy.typing as npt
from scipy.stats.qmc import LatinHypercube


def plot(x: npt.NDArray[np.float64], u: npt.NDArray[np.float64],
         title: Optional[str] = None, savefig: bool = False, fontsize: int = 14,
         timesteps: Sequence[int] = (0, 30, 60, 99), figsize: tuple = (6, 2),
         cmap_offset: float = 0., figname: str = 'fig.pdf',
         marker_size: float = 0.4, extent: bool = True):
    """Scatter-plot the solution at four time snapshots side by side.

    Args:
        x: Spatial coordinates, shape (n_points, n_dim); columns 0 and 1 are
            used as the horizontal and vertical plot axes.
        u: Solution indexed by time step along its first axis; ``u[t]`` must
            supply one colour value per point of ``x``.
        title: Optional figure title. Defaults to None (no title).
        savefig: Save the figure to ``figname`` before showing it.
            Defaults to False.
        fontsize: Font size for panel titles and tick labels. Defaults to 14.
        timesteps: Four indices into the first axis of ``u``, one per panel.
            Defaults to (0, 30, 60, 99).
        figsize: Figure size in inches. Defaults to (6, 2).
        cmap_offset: Added to the lower and subtracted from the upper colour
            limit; only used when ``extent`` is True. Defaults to 0.
        figname: File name used when ``savefig`` is True. Defaults to 'fig.pdf'.
        marker_size: Scatter marker size. Defaults to 0.4.
        extent: If True, all panels share one colour range computed from the
            plotted snapshots; otherwise each panel is scaled on its own.
            Defaults to True.
    """
    fig, ax = plt.subplots(1, 4, figsize=figsize, sharex=True, sharey=True,
                           constrained_layout=True)
    cmap = 'jet'
    # Panel labels are fixed; they assume the default timesteps map to these times.
    panel_titles = ('t = 0', 't = 0.3', 't = 0.6', 't = 1')

    color_limits = {}
    if extent:
        # Shared colour range over the snapshots that are actually displayed,
        # so the four panels are directly comparable and nothing is clipped.
        shown = u[list(timesteps[:4])]
        color_limits = dict(vmin=np.min(shown) + cmap_offset,
                            vmax=np.max(shown) - cmap_offset)

    images = []
    for axis, step, label in zip(ax, timesteps, panel_titles):
        images.append(axis.scatter(x[:, 0], x[:, 1], c=u[step], cmap=cmap,
                                   s=marker_size, **color_limits))
        axis.tick_params(axis='both', labelsize=fontsize)
        axis.set_title(label, fontsize=fontsize)

    fig.supxlabel(r'$x_1$')
    fig.supylabel(r'$x_2$')
    # A single colour bar for the whole row, driven by the first panel; with
    # extent=False it reflects only that panel's scale.
    cbar = fig.colorbar(images[0], ax=ax, location='right', aspect=8)
    cbar.ax.tick_params(labelsize=fontsize)
    if title is not None:
        fig.suptitle(title, fontsize=fontsize)
    # Use the figure object directly rather than pyplot's "current figure".
    if savefig:
        fig.savefig(figname)
    plt.show()


def plot_minimal(x: npt.NDArray[np.float64], u: npt.NDArray[np.float64],
                 title: Optional[str] = None, savefig: bool = False, fontsize: int = 14,
                 timesteps: Sequence[int] = (0, 30, 60, 99), figsize: tuple = (6, 2),
                 cmap_offset: float = 0., figname: str = 'fig.pdf',
                 marker_size: float = 0.4):
    """Scatter-plot four solution snapshots with a shared colour range and no ticks.

    Args:
        x: Spatial coordinates, shape (n_points, n_dim); columns 0 and 1 are
            used as the horizontal and vertical plot axes.
        u: Solution indexed by time step along its first axis; ``u[t]`` must
            supply one colour value per point of ``x``.
        title: Optional figure title. Defaults to None (no title).
        savefig: Save the figure to ``figname`` before showing it.
            Defaults to False.
        fontsize: Font size for panel titles and colour-bar labels.
            Defaults to 14.
        timesteps: Four indices into the first axis of ``u``, one per panel.
            Defaults to (0, 30, 60, 99).
        figsize: Figure size in inches. Defaults to (6, 2).
        cmap_offset: Added to the lower and subtracted from the upper colour
            limit. Defaults to 0.
        figname: File name used when ``savefig`` is True. Defaults to 'fig.pdf'.
        marker_size: Scatter marker size. Defaults to 0.4.
    """
    fig, ax = plt.subplots(1, 4, figsize=figsize, sharex=True, sharey=True,
                           constrained_layout=True)
    cmap = 'jet'
    # Panel labels are fixed; they assume the default timesteps map to these times.
    panel_titles = ('t = 0', 't = 0.3', 't = 0.6', 't = 1')

    # One colour range for all panels, taken from the snapshots actually shown.
    shown = u[list(timesteps[:4])]
    vmin = np.min(shown) + cmap_offset
    vmax = np.max(shown) - cmap_offset

    images = []
    for axis, step, label in zip(ax, timesteps, panel_titles):
        images.append(axis.scatter(x[:, 0], x[:, 1], c=u[step], cmap=cmap,
                                   vmin=vmin, vmax=vmax, s=marker_size))
        axis.set_title(label, fontsize=fontsize)
        # "Minimal" style: no tick marks or tick labels on the panels.
        axis.set_xticks([])
        axis.set_yticks([])

    fig.supxlabel(r'$x_1$')
    fig.supylabel(r'$x_2$')
    cbar = fig.colorbar(images[0], ax=ax, location='right', aspect=8)
    cbar.ax.tick_params(labelsize=fontsize)
    if title is not None:
        fig.suptitle(title, fontsize=fontsize)
    # Save only when the caller asks for it.
    if savefig:
        fig.savefig(figname)
    plt.show()

def plot_error(x: npt.NDArray[np.float64], u_true: npt.NDArray[np.float64], u_nn: npt.NDArray[np.float64],
               title: str, timesteps: Sequence[int] = (0, 30, 60, 99), figsize: tuple = (8, 3), fontsize: int = 14,
               savefig: bool = False, figname: str = 'fig.pdf',
               marker_size: float = 0.4):
    """Plot the point-wise absolute error ``|u_true - u_nn|`` at four time snapshots.

    Each panel is coloured on its own scale and gets its own horizontal colour
    bar underneath, with tick labels in scientific notation.

    Args:
        x: Spatial coordinates, shape (n_points, n_dim); columns 0 and 1 are
            used as the horizontal and vertical plot axes.
        u_true: Reference solution indexed by time step along its first axis.
        u_nn: Neural-network solution with the same layout as ``u_true``.
        title: Figure title.
        timesteps: Four indices into the first axis of ``u_true``/``u_nn``,
            one per panel. Defaults to (0, 30, 60, 99).
        figsize: Figure size in inches. Defaults to (8, 3).
        fontsize: Font size for titles, axis labels and tick labels.
            Defaults to 14.
        savefig: Save the figure to ``figname`` before showing it.
            Defaults to False.
        figname: File name used when ``savefig`` is True. Defaults to 'fig.pdf'.
        marker_size: Scatter marker size. Defaults to 0.4.
    """
    fig, ax = plt.subplots(1, 4, figsize=figsize, sharex=True, sharey=True,
                           constrained_layout=True)
    cmap = 'jet'
    # Panel labels are fixed; they assume the default timesteps map to these times.
    panel_titles = ('t = 0', 't = 0.3', 't = 0.6', 't = 1')

    images = []
    for axis, step, label in zip(ax, timesteps, panel_titles):
        abs_error = np.abs(u_true[step] - u_nn[step])
        images.append(axis.scatter(x[:, 0], x[:, 1], c=abs_error, cmap=cmap,
                                   s=marker_size))
        axis.tick_params(axis='both', labelsize=fontsize)
        axis.set_xlabel(r'$x_1$', fontsize=fontsize)
        axis.set_title(label, fontsize=fontsize)

    # No shared vmin/vmax is used, so every panel needs its own colour bar.
    for axis, image in zip(ax, images):
        cbar = fig.colorbar(image, ax=axis, location='bottom', aspect=8)
        cbar.ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
        cbar.ax.tick_params(labelsize=fontsize)
        # The "x 10^k" multiplier is a separate text object; size it too.
        cbar.ax.xaxis.get_offset_text().set_fontsize(fontsize)

    # Raised above centre so the label sits beside the panels, not the colour bars.
    fig.supylabel(r'$x_2$', y=0.65)
    fig.suptitle(title, fontsize=fontsize)
    if savefig:
        fig.savefig(figname)
    plt.show()


def sample_boundary_lhs(d, n_samples, bounds=(-1, 1)):
    """
    Draw points on the surface of the axis-aligned hypercube [lower, upper]^d.

    A Latin Hypercube design is first drawn inside the cube. Each design point
    is then pushed onto a face by overwriting one randomly chosen coordinate
    with either ``lower`` or ``upper`` (also chosen at random). All randomness
    comes from the global ``np.random`` state.

    Parameters:
        d (int): Dimension of the hypercube.
        n_samples (int): Number of boundary points to generate.
        bounds (tuple): ``(lower, upper)`` limits of the cube. Defaults to (-1, 1).

    Returns:
        np.ndarray: The boundary points, shape (n_samples, d).
        np.ndarray: One ``(face_dimension, face_value)`` pair per point,
            identifying the face each point was placed on.
    """
    lo, hi = bounds
    span = hi - lo

    # Latin Hypercube design in [0, 1]^d: one random point per stratum, with
    # strata shuffled independently in every column.
    design = np.zeros((n_samples, d))
    for col in range(d):
        strata = np.random.permutation(n_samples)
        jitter = np.random.uniform(size=n_samples)
        design[:, col] = (strata + jitter) / n_samples
    design = lo + span * design

    # Project each design point onto a random face; the RNG calls stay
    # interleaved per point so seeded results are reproducible.
    pts = []
    tags = []
    for row in design:
        face_dim = np.random.randint(d)
        face_val = np.random.choice([lo, hi])
        projected = np.copy(row)
        projected[face_dim] = face_val
        pts.append(projected)
        tags.append((face_dim, face_val))
    return np.array(pts), np.array(tags)


def sample_boundary_lhs_ball(d, n_samples):
    """
    Sample points on the boundary (unit sphere) of a d-dimensional unit ball.

    Each point is an i.i.d. standard-normal vector divided by its Euclidean
    norm. This gives points uniformly distributed on the sphere. Despite
    the function name, no Latin Hypercube stratification is applied. Draws
    come from the global ``np.random`` state, so ``np.random.seed`` makes the
    result reproducible.

    Parameters:
        d (int): Number of dimensions of the unit ball.
        n_samples (int): Number of points to sample.

    Returns:
        np.ndarray: An (n_samples, d) array of points with unit Euclidean norm.
    """
    # Isotropic Gaussian directions; normalising them projects onto the sphere.
    normal_samples = np.random.normal(size=(n_samples, d))
    norms = np.linalg.norm(normal_samples, axis=1, keepdims=True)
    return normal_samples / norms

def sample_interior_lhs_ball(d, n_samples):
    """
    Sample points uniformly inside a d-dimensional unit ball.

    A direction is drawn by normalising an i.i.d. standard-normal vector. It
    is then scaled by a radius ``U ** (1/d)`` with ``U ~ Uniform(0, 1)``,
    which makes the points uniform in volume. Despite the function name, no
    Latin Hypercube stratification is applied. Draws come from the global
    ``np.random`` state (normals first, then radii), so ``np.random.seed``
    makes the result reproducible.

    Parameters:
        d (int): Number of dimensions of the unit ball.
        n_samples (int): Number of points to sample.

    Returns:
        np.ndarray: An (n_samples, d) array of points with Euclidean norm <= 1.
    """
    # Uniform directions on the unit sphere.
    normal_samples = np.random.normal(size=(n_samples, d))
    norms = np.linalg.norm(normal_samples, axis=1, keepdims=True)
    unit_vectors = normal_samples / norms

    # The 1/d power compensates for volume growing like r**d, giving a
    # uniform density inside the ball rather than clustering at the centre.
    radii = np.random.uniform(0, 1, size=(n_samples, 1)) ** (1 / d)

    return unit_vectors * radii

def visualize_samples(samples):
    """
    Show a 3-D scatter plot of sample points (e.g. points on a unit sphere).

    Only input with exactly three columns is drawn. For any other column
    count a message is printed and nothing is plotted.

    Parameters:
        samples (np.ndarray): Points to plot, one per row.
    """
    if samples.shape[1] == 3:
        figure = plt.figure(figsize=(3, 3))
        axes3d = figure.add_subplot(111, projection='3d')
        axes3d.scatter(*(samples[:, k] for k in range(3)),
                       c='b', marker='o', alpha=0.2)
        axes3d.set_xlabel("X")
        axes3d.set_ylabel("Y")
        axes3d.set_zlabel("Z")
        plt.show()
    else:
        print("Visualization only supported for 3D samples.")

def plot_boundary_points_3d(points, labels):
    """
    Scatter boundary points in 3-D, one colour per distinct boundary label.

    Parameters:
        points (np.ndarray): Boundary points, shape (n_samples, 3).
        labels (np.ndarray): One label row per point (for example the
            ``(dimension, side)`` pairs produced by ``sample_boundary_lhs``).
            Rows that compare equal share a colour.
    """
    figure = plt.figure(figsize=(3, 3))
    axes3d = figure.add_subplot(111, projection='3d')

    # Distinct label rows, each assigned a colour spread across tab10.
    faces = np.unique(labels, axis=0)
    palette = plt.cm.tab10(np.linspace(0, 1, len(faces)))

    for face, face_color in zip(faces, palette):
        on_face = (labels == face).all(axis=1)
        axes3d.scatter(
            points[on_face, 0],
            points[on_face, 1],
            points[on_face, 2],
            label=f"Boundary {face[0]}: {face[1]}",
            color=face_color,
            s=30,
        )

    axis_setters = (axes3d.set_xlabel, axes3d.set_ylabel, axes3d.set_zlabel)
    for setter, text in zip(axis_setters, ('X', 'Y', 'Z')):
        setter(text)
    # Legend entries are labelled above but the legend itself is not drawn.
    plt.show()
    
def sample_interior_lhs(d, n_samples, bounds=(-1, 1)):
    """
    Draw a Latin Hypercube design inside the hypercube [lower, upper]^d.

    Along every dimension the range is split into ``n_samples`` equal strata.
    Each stratum receives exactly one point, placed uniformly at random within
    it, and the strata are shuffled independently per dimension. Randomness
    comes from the global ``np.random`` state.

    Parameters:
        d (int): Dimension of the hypercube.
        n_samples (int): Number of points to draw.
        bounds (tuple): ``(lower, upper)`` limits of the cube. Defaults to (-1, 1).

    Returns:
        np.ndarray: Sample points, shape (n_samples, d).
    """
    lo, hi = bounds

    unit = np.empty((n_samples, d))
    for col in range(d):
        # Stratum index first, then the in-stratum offset (same RNG order).
        unit[:, col] = np.random.permutation(n_samples)
        unit[:, col] += np.random.uniform(size=n_samples)
    unit /= n_samples

    return lo + (hi - lo) * unit

# NOTE: Earlier draft of ``sample_interior_lhs`` (unit hypercube only, no
# ``bounds`` argument), kept as an unused module-level string literal. It is
# never executed; the live ``sample_interior_lhs`` in this module supersedes it.
'''
def sample_interior_lhs(d, n_samples):
    """
    Sample points in the interior of a d-dimensional unit hypercube using
    Latin Hypercube Sampling.

    Parameters:
        d (int): Number of dimensions of the unit hypercube.
        n_samples (int): Number of points to sample.

    Returns:
        np.ndarray: An (n_samples, d) array of sampled points.
    """
    # Create an empty array for storing the sampled points
    samples = np.zeros((n_samples, d))
    
    # Perform Latin Hypercube Sampling for each dimension
    for i in range(d):
        # Divide the unit interval [0, 1) into n_samples segments
        perm = np.random.permutation(n_samples)
        
        # Sample randomly within each segment
        samples[:, i] = (perm + np.random.uniform(size=n_samples)) / n_samples
    
    return samples
'''

def plot_interior_points(data_interior):
    """
    Show a 3-D scatter plot of interior sample points.

    Parameters:
        data_interior (np.ndarray): Points to plot, one per row; columns 0, 1
            and 2 are used as the X, Y and Z coordinates.
    """
    figure = plt.figure(figsize=(3, 3))
    axes3d = figure.add_subplot(111, projection='3d')

    xs, ys, zs = (data_interior[:, k] for k in range(3))
    axes3d.scatter(xs, ys, zs, color='blue', s=30, alpha=0.7,
                   label='Sampled Points')

    axes3d.set_title('Sampled Points in the Interior of a 3D Unit Cube')
    for setter, text in ((axes3d.set_xlabel, 'X-axis'),
                         (axes3d.set_ylabel, 'Y-axis'),
                         (axes3d.set_zlabel, 'Z-axis')):
        setter(text)
    axes3d.grid(True)
    axes3d.legend()

    plt.show()