import time
import numpy as np
from scipy.interpolate import griddata
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

def get_boundary_mesh_from_ofpp_volume(ofpp_mesh):
    """Extract the boundary surface of a volume mesh as a compact vertex/face mesh.

    The faces of every patch in ``ofpp_mesh.boundary`` are merged. Only the points
    referenced by at least one boundary face are kept, and the face indices are
    remapped into that compacted vertex array.

    Parameters:
    - ofpp_mesh: mesh object exposing ``points`` (indexable by an integer array),
      ``faces`` (sliceable sequence of faces) and ``boundary`` (mapping whose values
      have ``start`` and ``num`` attributes), e.g. an Ofpp ``FoamMesh``.
      NOTE(review): all boundary faces must have the same vertex count (e.g. all
      quads), because they are stacked into one 2D array.

    Returns:
    - boundary_vertices: coordinates of the points used by boundary faces, ordered by
      ascending original point index.
    - new_boundary_faces: (F, k) integer array of faces indexing ``boundary_vertices``.

    Raises:
    - ValueError: if no boundary patch contains any face.
    """
    # Patches with zero faces (which OpenFOAM meshes can contain) may slice to an empty
    # 1D array that cannot be concatenated with the 2D face arrays, so skip them.
    patch_faces = [
        np.asarray(ofpp_mesh.faces[info.start:info.start + info.num])
        for info in ofpp_mesh.boundary.values()
        if info.num > 0
    ]
    if not patch_faces:
        raise ValueError("ofpp_mesh has no boundary faces")
    boundary_faces = np.concatenate(patch_faces)

    # np.unique gives the sorted used point indices and, for every face entry, its
    # position in that sorted array -- exactly the old->new index remapping.
    unique_boundary_points_indices, inverse = np.unique(boundary_faces, return_inverse=True)
    boundary_vertices = ofpp_mesh.points[unique_boundary_points_indices]
    # The shape of the inverse array has differed between NumPy versions; normalise it.
    new_boundary_faces = inverse.reshape(boundary_faces.shape)

    return boundary_vertices, new_boundary_faces

def get_planepoints_in_convex_hull(vertices, slice_axis, axis_value, grid_resolution):
    """Build a regular grid of 3D points on an axis-parallel plane.

    The grid spans the axis-aligned bounding box of ``vertices`` along the two in-plane
    axes. Despite the function name, this is the bounding box, not the convex hull. The
    sliced coordinate is fixed to ``axis_value``.

    Parameters:
    - vertices: (N, 3) array of 3D points.
    - slice_axis: axis normal to the plane, 'x', 'y' or 'z'.
    - axis_value: coordinate of the plane along ``slice_axis``, e.g. x=0.3.
    - grid_resolution: number of grid points along each in-plane axis.

    Returns:
    - plane_points: (grid_resolution, grid_resolution, 3) array of xyz points. Rows
      follow the second in-plane axis and columns the first (``np.meshgrid`` 'xy'
      order).
    - extent: (first_min, first_max, second_min, second_max) of the in-plane axes,
      in the form expected by ``imshow(extent=...)``.

    Raises:
    - ValueError: if ``slice_axis`` is not 'x', 'y' or 'z'.
    """
    axis_names = ('x', 'y', 'z')
    if slice_axis not in axis_names:
        raise ValueError("Axis must be 'x', 'y', or 'z'.")
    normal = axis_names.index(slice_axis)
    # The two remaining axes in ascending order: x -> (y, z), y -> (x, z), z -> (x, y).
    first, second = (i for i in range(3) if i != normal)

    vertices = np.asarray(vertices)
    grid_first, grid_second = np.meshgrid(
        np.linspace(vertices[:, first].min(), vertices[:, first].max(), grid_resolution),
        np.linspace(vertices[:, second].min(), vertices[:, second].max(), grid_resolution),
    )

    # Assemble the xyz components in axis order, with the normal axis held constant.
    components = [None, None, None]
    components[first] = grid_first
    components[second] = grid_second
    components[normal] = np.ones_like(grid_first) * axis_value
    plane_points = np.stack(components, axis=-1)

    extent = (grid_first.min(), grid_first.max(), grid_second.min(), grid_second.max())
    return plane_points, extent

def get_neighbor_plot_indices_for_plane(vertices, slice_axis, axis_value, grid_resolution=100):
    """Find, for every point of a regular plane grid, the index of the nearest vertex.

    The grid is built by ``get_planepoints_in_convex_hull``. Each grid point is mapped
    to the index of the closest point in ``vertices`` (nearest-neighbour lookup via
    ``scipy.interpolate.griddata``). The result can then be used to sample per-vertex
    data onto the plane, e.g. ``data[indices]``.

    Parameters:
    - vertices: (N, 3) array of 3D points.
    - slice_axis: axis normal to the plane, 'x', 'y' or 'z'.
    - axis_value: coordinate of the plane along ``slice_axis``, e.g. x=0.3.
    - grid_resolution: number of grid points along each in-plane axis.

    Returns:
    - plane_points: (grid_resolution, grid_resolution, 3) array of grid points.
    - indices: (grid_resolution, grid_resolution) int array of nearest-vertex indices.
    - extent: in-plane bounds of the grid, as returned by
      ``get_planepoints_in_convex_hull``.
    """
    plane_points, extent = get_planepoints_in_convex_hull(vertices, slice_axis, axis_value, grid_resolution)

    # Nearest-neighbour "interpolation" of the vertex indices themselves yields, for each
    # grid point, the index of its closest vertex.
    indices = griddata(
        points=vertices,
        values=np.arange(len(vertices)),
        xi=plane_points,
        method='nearest',
    )
    return plane_points, indices.astype(int), extent

def plot_scalar_at_plane(coords, values, slice_axis, axis_value, grid_resolution=100, method='nearest', take_log=False, color_label='norm(u)'):
    """Interpolate a 3D scalar field onto an axis-parallel plane and show it as an image.

    Parameters:
    - coords: (N, 3) array of 3D sample points.
    - values: (N,) array of scalar values at ``coords``.
    - slice_axis: axis normal to the plane, 'x', 'y' or 'z'.
    - axis_value: coordinate of the plane along ``slice_axis``, e.g. x=0.3.
    - grid_resolution: number of grid points along each in-plane axis.
    - method: ``scipy.interpolate.griddata`` method; for 3D data only 'nearest' and
      'linear' are supported ('linear' yields NaN outside the convex hull of ``coords``).
    - take_log: if True, plot ``log10`` of the interpolated values. Non-positive values
      are masked and drawn as "bad" pixels.
    - color_label: colorbar label; wrapped as ``log(...)`` when ``take_log`` is True.
    """
    plane_points, extent = get_planepoints_in_convex_hull(coords, slice_axis, axis_value, grid_resolution)

    # Interpolate the values on the grid.
    values_at_plane = griddata(
        points=coords,
        values=values,
        xi=plane_points,
        method=method,
    )

    if take_log:
        # log10 of zeros/negatives gives -inf/NaN. An infinite minimum would make
        # imshow's automatic colour limits infinite and ruin the image, so mask
        # non-finite results instead of plotting them.
        with np.errstate(divide='ignore', invalid='ignore'):
            values_at_plane = np.ma.masked_invalid(np.log10(values_at_plane))
        color_label = f'log({color_label})'

    plt.figure(figsize=(8, 6))
    sc = plt.imshow(values_at_plane, cmap='viridis', origin='lower', extent=extent)
    plt.colorbar(sc, label=color_label)
    plt.show()
    
    
def plot_stream_at_plane(points, vectors, axis, min_value, max_value, resolution=50, interpolation='nearest', density=1, measure_time=False):
    """Draw streamlines of the in-plane part of a 3D vector field inside a slab.

    All points whose ``axis`` coordinate lies in [min_value, max_value] are projected
    onto the plane spanned by the two remaining axes. Their in-plane vector components
    are interpolated onto a regular ``resolution`` x ``resolution`` grid and drawn with
    ``plt.streamplot``, coloured by the in-plane magnitude.

    Parameters:
    - points: (N, 3) array of 3D points.
    - vectors: (N, 3) array of vectors at ``points``.
    - axis: the axis to slice along ('x', 'y', or 'z').
    - min_value: inclusive lower bound of the slab along ``axis``.
    - max_value: inclusive upper bound of the slab along ``axis``.
    - resolution: number of grid points along each in-plane axis.
    - interpolation: ``scipy.interpolate.griddata`` method ('nearest', 'linear', 'cubic').
    - density: forwarded to ``plt.streamplot``.
    - measure_time: if True, print the elapsed time since the call started after each
      stage (the printed times are cumulative, not per stage).

    Raises:
    - ValueError: if ``axis`` is invalid or no point lies inside the slab.
    """
    start_t = time.time()

    axis_names = ('x', 'y', 'z')
    if axis not in axis_names:
        raise ValueError("Axis must be 'x', 'y', or 'z'.")
    normal = axis_names.index(axis)
    # The two remaining (in-plane) axes in ascending order.
    x_axis, y_axis = (i for i in range(3) if i != normal)

    # Keep only the points (and their vectors) inside the slab.
    mask = (points[:, normal] >= min_value) & (points[:, normal] <= max_value)
    slice_points = points[mask]
    slice_vectors = vectors[mask]
    if len(slice_points) == 0:
        raise ValueError(f"No points with {axis} in [{min_value}, {max_value}].")

    if measure_time:
        print(f'needed for slice {time.time() - start_t:0.2f}')

    # Regular grid over the in-plane bounding box of the slab points.
    x = np.linspace(slice_points[:, x_axis].min(), slice_points[:, x_axis].max(), resolution)
    y = np.linspace(slice_points[:, y_axis].min(), slice_points[:, y_axis].max(), resolution)
    X, Y = np.meshgrid(x, y)

    # Interpolate the in-plane vector components onto the grid.
    plane_coords = slice_points[:, [x_axis, y_axis]]
    U = griddata(plane_coords, slice_vectors[:, x_axis], (X, Y), method=interpolation)
    V = griddata(plane_coords, slice_vectors[:, y_axis], (X, Y), method=interpolation)
    if measure_time:
        print(f'needed for griddata {time.time() - start_t:0.2f}')

    magnitude = np.sqrt(U**2 + V**2)

    # Plot the streamlines, coloured by in-plane magnitude.
    plt.figure(figsize=(8, 6))
    strm = plt.streamplot(X, Y, U, V, color=magnitude, cmap='viridis', density=density)
    if measure_time:
        print(f'needed for streamplot {time.time() - start_t:0.2f}')

    plt.colorbar(strm.lines, label='Magnitude')
    plt.title(f'Vector Field Slice interval on {axis}-axis: [{min_value}, {max_value}]')
    plt.show()

    if measure_time:
        print(f'needed for vis {time.time() - start_t:0.2f}')
        

def plot_vector_at_plane(points, vectors, slice_axis, axis_value, eps, max_norm_factor=10, arrow_scale=1, interpolation='nearest', resolution=50):
    """Draw a quiver plot of the in-plane part of a 3D vector field near a plane.

    Points within ``eps`` of the plane ``slice_axis = axis_value`` are projected onto
    it. Their in-plane vector components are interpolated onto a regular grid, with 0
    used where the interpolation method yields no value. Vectors whose in-plane
    magnitude exceeds ``max_norm_factor`` grid-cell diagonals are rescaled to that
    magnitude, so a few large vectors do not hide the rest. Arrow colour still shows
    the unclipped magnitude.

    Parameters:
    - points: (N, 3) array of 3D points.
    - vectors: (N, 3) array of vectors at ``points``.
    - slice_axis: axis normal to the plane ('x', 'y', or 'z').
    - axis_value: coordinate of the plane along ``slice_axis``.
    - eps: half-thickness of the slab [axis_value - eps, axis_value + eps].
    - max_norm_factor: magnitude cap, in multiples of one grid-cell diagonal.
    - arrow_scale: forwarded to ``plt.quiver`` as ``scale``.
    - interpolation: ``scipy.interpolate.griddata`` method ('nearest', 'linear', 'cubic').
    - resolution: number of grid points along each in-plane axis (at least 2).

    Raises:
    - ValueError: if ``slice_axis`` is invalid, ``resolution < 2``, or no point lies
      inside the slab.
    """
    min_value = axis_value - eps
    max_value = axis_value + eps

    axis_names = ('x', 'y', 'z')
    if slice_axis not in axis_names:
        raise ValueError("Axis must be 'x', 'y', or 'z'.")
    if resolution < 2:
        # The magnitude cap below needs at least one grid spacing per axis.
        raise ValueError("resolution must be at least 2.")
    normal = axis_names.index(slice_axis)
    # The two remaining (in-plane) axes in ascending order.
    x_axis, y_axis = (i for i in range(3) if i != normal)

    # Keep only the points (and their vectors) inside the slab.
    mask = (points[:, normal] >= min_value) & (points[:, normal] <= max_value)
    slice_points = points[mask]
    slice_vectors = vectors[mask]
    if len(slice_points) == 0:
        raise ValueError(f"No points with {slice_axis} in [{min_value}, {max_value}].")

    # Regular grid over the in-plane bounding box of the slab points.
    x = np.linspace(slice_points[:, x_axis].min(), slice_points[:, x_axis].max(), resolution)
    y = np.linspace(slice_points[:, y_axis].min(), slice_points[:, y_axis].max(), resolution)
    X, Y = np.meshgrid(x, y)

    # Interpolate the in-plane vector components onto the grid (0 where undefined).
    plane_coords = slice_points[:, [x_axis, y_axis]]
    U = griddata(plane_coords, slice_vectors[:, x_axis], (X, Y), method=interpolation, fill_value=0)
    V = griddata(plane_coords, slice_vectors[:, y_axis], (X, Y), method=interpolation, fill_value=0)

    magnitude = np.sqrt(U**2 + V**2)

    # Cap the in-plane magnitude at max_norm_factor grid-cell diagonals.
    dx = np.min(np.diff(x))
    dy = np.min(np.diff(y))
    max_magnitude = np.sqrt(dx**2 + dy**2) * max_norm_factor

    # Rescale only the over-long vectors; the 1e-9 guards against division by zero.
    U_normalized = np.where(magnitude <= max_magnitude, U, U / (magnitude + 1e-9) * max_magnitude)
    V_normalized = np.where(magnitude <= max_magnitude, V, V / (magnitude + 1e-9) * max_magnitude)

    # Plot the capped arrows, coloured by the original (unclipped) magnitude.
    plt.figure(figsize=(8, 6))
    plt.quiver(X, Y, U_normalized, V_normalized, magnitude, cmap='viridis', scale=arrow_scale)
    plt.colorbar(label='Magnitude')
    plt.title(f'Vector Field Slice interval on {slice_axis}-axis: [{min_value}, {max_value}]')
    plt.show()
    
    
    
# def plot_frames(idx, true_list, pred_list, inside_mask, extent):
#     fig, axes = plt.subplots(1, 3, figsize=(24, 6))

#     # Determine the minimum and maximum values across both datasets
#     vmin = min(np.min(true_list[idx]), np.min(true_list[idx]))
#     vmax = max(np.max(true_list[idx]), np.max(true_list[idx]))

#     # Create a normalization instance
#     norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

#     # Create a custom colormap that makes the masked areas transparent
#     cmap_viridis = plt.cm.viridis
#     cmap_viridis.set_bad(color='white', alpha=0)  # Set the color for masked values to transparent white

#     # Apply the mask to the data
#     masked_true = np.ma.masked_where(inside_mask == 0, true_list[idx])
#     masked_pred = np.ma.masked_where(inside_mask == 0, pred_list[idx])

#     # Plot the true data with the custom colormap
#     sc1 = axes[0].imshow(masked_true, cmap=cmap_viridis, origin='lower', extent=extent, norm=norm)
#     axes[0].set_title('True')

#     # Plot the predicted data with the custom colormap
#     sc2 = axes[1].imshow(masked_pred, cmap=cmap_viridis, origin='lower', extent=extent, norm=norm)
#     axes[1].set_title('Predicted')

#     # Calculate the difference
#     diff = true_list[idx] - pred_list[idx]

#     # Create a custom colormap for the difference plot
#     cmap_diff = plt.cm.bwr
#     cmap_diff.set_bad(color='white', alpha=0)  # Set the color for masked values to transparent white

#     # Apply the mask to the difference data
#     masked_diff = np.ma.masked_where(inside_mask == 0, diff)

#     # Plot the difference with the custom colormap
#     sc3 = axes[2].imshow(masked_diff, cmap=cmap_diff, origin='lower', extent=extent)
#     axes[2].set_title('Difference')

#     # Create a single colorbar for the first two plots
#     cbar = fig.colorbar(sc2, ax=axes[:2], orientation='vertical', fraction=0.015, pad=0.05)
#     cbar.set_label('norm(u)')

#     # Create a separate colorbar for the difference plot
#     cbar_diff = fig.colorbar(sc3, ax=axes[2], orientation='vertical', fraction=0.015, pad=0.05)
#     cbar_diff.set_label('Difference')

#     # Show the plot
#     plt.show()

def plot_frames_3d(idx, true_list, pred_list, inside_mask, extent, logify=False):
    """Show ground truth, prediction and their difference for one frame side by side.

    Pixels where ``inside_mask == 0`` are hidden (drawn transparent) in all three
    panels. The true and predicted panels share one colour scale, derived from the
    finite, visible values of both frames. The difference panel uses the diverging
    ``bwr`` colormap with limits symmetric around zero.

    Parameters:
    - idx: index of the frame in ``true_list`` / ``pred_list``.
    - true_list: indexable collection of 2D ground-truth arrays.
    - pred_list: indexable collection of 2D predicted arrays.
    - inside_mask: 2D array with the same shape as each frame (required by
      ``np.ma.masked_where``); 0 marks pixels to hide.
    - extent: (left, right, bottom, top) passed to ``imshow``.
    - logify: if True, use a logarithmic colour scale for the true/predicted panels.
    """
    import copy

    fig, axes = plt.subplots(1, 3, figsize=(24, 6))

    hidden = np.asarray(inside_mask) == 0
    masked_true = np.ma.masked_where(hidden, true_list[idx])
    masked_pred = np.ma.masked_where(hidden, pred_list[idx])

    # Colour limits come from visible, finite values only. Hidden pixels should not
    # stretch the scale, and a single NaN (e.g. from linear griddata outside the hull)
    # would otherwise make np.min/np.max NaN and break the normalisation.
    visible = np.ma.masked_invalid(
        np.ma.concatenate([masked_true.ravel(), masked_pred.ravel()])
    ).compressed()
    if visible.size:
        vmin, vmax = visible.min(), visible.max()
    else:
        vmin = vmax = None  # nothing visible: let matplotlib autoscale

    if logify:
        # LogNorm needs positive limits. Use the smallest positive visible value rather
        # than a fixed tiny floor, which would stretch the scale over many empty decades
        # whenever the data contains zeros.
        positive = visible[visible > 0]
        norm = mcolors.LogNorm(vmin=positive.min(), vmax=positive.max()) if positive.size else mcolors.LogNorm()
    else:
        norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

    # Copy before set_bad: plt.cm.viridis / plt.cm.bwr are shared module-level objects,
    # and mutating them would alter every later plot in the session.
    cmap_viridis = copy.copy(plt.cm.viridis)
    cmap_viridis.set_bad(color='white', alpha=0)  # masked values are transparent
    cmap_diff = copy.copy(plt.cm.bwr)
    cmap_diff.set_bad(color='white', alpha=0)

    axes[0].imshow(masked_true, cmap=cmap_viridis, origin='lower', extent=extent, norm=norm)
    axes[0].set_title('True')

    sc_pred = axes[1].imshow(masked_pred, cmap=cmap_viridis, origin='lower', extent=extent, norm=norm)
    axes[1].set_title('Predicted')

    # Masked arithmetic keeps the hidden pixels masked in the difference.
    masked_diff = masked_true - masked_pred
    # Symmetric limits keep zero difference at the centre (white) of the bwr colormap.
    finite_diff = np.ma.masked_invalid(masked_diff).compressed()
    diff_lim = np.abs(finite_diff).max() if finite_diff.size else 0.0
    if diff_lim == 0:
        diff_lim = 1.0  # any symmetric range still maps zero to the centre
    sc_diff = axes[2].imshow(
        masked_diff, cmap=cmap_diff, origin='lower', extent=extent,
        vmin=-diff_lim, vmax=diff_lim,
    )
    axes[2].set_title('Difference')

    # One colorbar shared by the first two panels.
    cbar = fig.colorbar(sc_pred, ax=axes[:2], orientation='vertical', fraction=0.015, pad=0.05)
    cbar.set_label('norm(u)')

    # A separate colorbar for the difference panel.
    cbar_diff = fig.colorbar(sc_diff, ax=axes[2], orientation='vertical', fraction=0.015, pad=0.05)
    cbar_diff.set_label('Difference')

    plt.show()