import numpy as np
import cc3d
import vedo
from typing import Tuple, List, Dict
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from pyquaternion import Quaternion


def filter_surface_voxels(binary: np.ndarray):
    """
    Return a mask of the occupied voxels that lie on the surface of ``binary``.

    A voxel counts as a surface voxel when it is occupied and not all six of
    its face-neighbour connections are present in the 6-connectivity graph
    computed by ``cc3d.voxel_connectivity_graph``.

    Args:
        binary: A 3d array of shape [X, Y, Z]. It is cast to bool before use,
                so any nonzero value counts as occupied.

    Returns:
        A bool array of shape [X, Y, Z] marking the surface voxels.
    """
    occupied = binary.astype(bool)
    graph = cc3d.voxel_connectivity_graph(occupied, connectivity=6)
    # For connectivity=6 cc3d stores one bit per face direction, so
    # 0b111111 (== 63) means every face neighbour is connected (interior).
    all_faces_connected = 0b111111
    return occupied & (graph < all_faces_connected)


def plot_binary(
    ax: Axes,
    binary: np.ndarray,
    color: Tuple[float, float, float, float] = (1, 1, 1, 1),
):
    """
    Clear ``ax`` and draw only the surface voxels of a binary volume.

    Every drawn voxel gets the same RGBA face color and no edges. The number
    of surface voxels is printed to stdout.

    Args:
        ax: 3d Axes object to plot on.
        binary: A 3d array of shape [X, Y, Z]. Bool arrays are used as-is;
                otherwise values > 0.5 count as 1 and all others as 0.
        color: RGBA color of the "1" voxels.
    """
    ax.clear()
    occupied = binary if binary.dtype == bool else binary > 0.5
    shell = filter_surface_voxels(occupied)
    print(f"Surface voxels: {np.sum(shell)}")

    face_colors = np.zeros((*binary.shape, 4))
    face_colors[shell] = color
    ax.voxels(shell, facecolors=face_colors, edgecolor="none")

    for axis_name, set_label in zip("XYZ", (ax.set_xlabel, ax.set_ylabel, ax.set_zlabel)):
        set_label(axis_name)
    plt.draw()


# def plot_binary_vedo(binary: np.ndarray, color=(1, 1, 1, 1), smooth=True):
#     """
#     Visualize a 3D binary volume with vedo by extracting an isosurface.
#
#     Args:
#         binary: A binary 3D array of shape [X, Y, Z].
#                 Values > 0.5 are considered '1', else '0'.
#         color:  (r, g, b, a) tuple in [0,1]. Color + transparency for the surface.
#     """
#
#     # Convert input to boolean if necessary
#     bin_bool = (binary > 0.5) if binary.dtype != bool else binary
#
#     # surface = filter_surface_voxels(bin_bool)
#     # print(f"Surface voxels: {np.sum(surface)}")
#     surface = bin_bool
#
#     # Create a volume in vedo.
#     vol = vedo.Volume(surface.astype(np.uint8))
#
#     if not smooth:
#         vol = vol.legosurface(0.5, 1.1).color(color[:3], alpha=color[-1])
#     else:
#         vol.color([(0, (0, 0, 0)), (1, color[:3])], vmin=0, vmax=1).alpha(
#             [(0, 0), (1, color[3])], vmin=0, vmax=1
#         )
#
#     return vol


def plot_binary_vedo(binary: np.ndarray, color=(1, 1, 1, 1), smooth=True):
    """
    Build a vedo object that visualizes a 3D binary volume.

    The voxel-centered mask is first mapped to a vertex-centered grid of shape
    [X + 1, Y + 1, Z + 1]. A grid vertex is 1 when any of the (up to) 8 voxels
    sharing that corner is 1; everything outside the array counts as 0.

    Args:
        binary: A binary 3D array of shape [X, Y, Z]. Bool arrays are used
                as-is; otherwise values > 0.5 are considered '1', else '0'.
        color:  (r, g, b) or (r, g, b, a) tuple in [0, 1]. When alpha is
                omitted it defaults to 1 (opaque).
        smooth: If True, return a ``vedo.Volume`` whose color/opacity transfer
                functions map 0 -> transparent black and 1 -> ``color``.
                If False, return the lego-style surface produced by
                ``Volume.legosurface(0.5, 1.1, boundary=False)``.

    Returns:
        The configured ``vedo.Volume`` (smooth=True) or the object returned by
        ``legosurface(...).color(...)`` (smooth=False).
    """
    occupied = binary if binary.dtype == bool else binary > 0.5
    rgb = color[:3]
    # Accept RGB as well as RGBA; both branches use the same alpha.
    alpha = color[3] if len(color) > 3 else 1.0

    # Zero-pad one voxel on every side, then OR neighbouring pairs along each
    # axis in turn. Three separable 2-wide ORs are exactly the OR over the 8
    # voxels touching each vertex, without materialising an 8x-sized stack.
    grid = np.pad(occupied, 1, mode="constant")
    grid = grid[:-1] | grid[1:]
    grid = grid[:, :-1] | grid[:, 1:]
    grid = grid[:, :, :-1] | grid[:, :, 1:]
    # Hand vedo a uint8 array rather than bool.
    vertex_grid = grid.astype(np.uint8)

    vol = vedo.Volume(vertex_grid)

    if not smooth:
        return vol.legosurface(0.5, 1.1, boundary=False).color(rgb, alpha=alpha)

    # Transfer functions: value 0 is invisible, value 1 takes the requested color.
    vol.color([(0, (0, 0, 0)), (1, rgb)], vmin=0, vmax=1).alpha(
        [(0, 0), (1, alpha)], vmin=0, vmax=1
    )
    return vol



def plot_segment_id_vedo(segments: np.ndarray, colors=None, smooth=True):
    """
    Visualize a 3D segmentation volume with vedo, one object per segment ID.

    Each non-zero ID is rendered through ``plot_binary_vedo`` on the mask
    ``segments == seg_id``. That function maps voxel-centered values to a
    vertex-centered uint8 grid and applies the color / lego styling. ID 0 is
    treated as background and never drawn.

    Args:
        segments: An integer 3D array of shape [X, Y, Z] containing segment IDs.
        colors:   List of (r, g, b, a) tuples in [0, 1]. Segment ``seg_id`` uses
                  ``colors[int(seg_id) % len(colors)]``, so colors cycle by ID
                  (e.g. ID 1 uses ``colors[1]``). Defaults to an 8-color palette.
        smooth:   Whether to use smooth volume rendering (True) or lego-style
                  blocks (False).

    Returns:
        A vedo.Assembly containing one visualization per non-zero segment ID
        (each named ``"Segment_<id>"``), or None if there are no non-zero IDs.

    Raises:
        ValueError: If there is at least one segment to draw and ``colors`` is empty.
    """
    if colors is None:
        colors = [
            (1, 0, 0, 1),  # Red
            (0, 1, 0, 1),  # Green
            (0, 0, 1, 1),  # Blue
            (1, 1, 0, 1),  # Yellow
            (1, 0, 1, 1),  # Magenta
            (0, 1, 1, 1),  # Cyan
            (1, 0.5, 0, 1),  # Orange
            (0.5, 0, 1, 1),  # Purple
        ]

    segment_ids = [seg_id for seg_id in np.unique(segments) if seg_id != 0]
    if not segment_ids:
        return None
    if len(colors) == 0:
        raise ValueError("colors must contain at least one color")

    visualizations = []
    for seg_id in segment_ids:
        # int() keeps list indexing valid for any integer label dtype
        # (older NumPy turns uint64 % int into a float).
        color = colors[int(seg_id) % len(colors)]
        vol = plot_binary_vedo(segments == seg_id, color=color, smooth=smooth)
        vol.name = f"Segment_{seg_id}"
        visualizations.append(vol)

    return vedo.Assembly(visualizations)


def plot_dual_binary(ax: Axes, binary1, binary2, color1, color2):
    """
    Clear ``ax`` and draw the surface voxels of two binary volumes together.

    Voxels on the surface of either volume are drawn. Where both surfaces
    overlap, ``color2`` wins because it is written last.

    Args:
        ax: 3d Axes object to plot on.
        binary1: A 3d array of shape [X, Y, Z]. Bool arrays are used as-is;
                 otherwise values > 0.5 count as 1 and all others as 0.
        binary2: A 3d array of the same shape, thresholded the same way.
        color1: RGBA color of the "1" surface voxels of ``binary1``.
        color2: RGBA color of the "1" surface voxels of ``binary2``.
    """
    ax.clear()

    def _surface(volume):
        return filter_surface_voxels(volume if volume.dtype == bool else volume > 0.5)

    first = _surface(binary1)
    second = _surface(binary2)

    face_colors = np.zeros((*binary1.shape, 4))
    for shell, rgba in ((first, color1), (second, color2)):
        face_colors[shell] = rgba

    ax.voxels(first | second, facecolors=face_colors, edgecolor="none")
    for axis_name, set_label in zip("XYZ", (ax.set_xlabel, ax.set_ylabel, ax.set_zlabel)):
        set_label(axis_name)
    plt.draw()


def plot_colored_rigid_by_segment(
    ax: Axes, is_rigid, segment, segment_colors, require_correspondence=True
):
    """
    Label each voxel by its highest-weight segment channel and plot rigid voxels.

    The label of a voxel is ``argmax`` over dim 0 of ``segment``. The labels
    are forwarded to ``plot_colored_rigid_by_id``, which only draws labels
    >= 1, so channel 0 (the non-rigid channel) is never drawn.

    Args:
        ax: Axes object to plot on.
        is_rigid: A 3d array of shape [X, Y, Z]. Bool arrays are used as-is;
            otherwise values > 0.5 count as rigid.
        segment: A weight array of shape [C, X, Y, Z] whose first dim should
            sum to 1. Channel 0 is the non-rigid channel.
        segment_colors: A list of segment colors.
        require_correspondence: If True, require C - 1 == len(segment_colors).

    Raises:
        ValueError: If ``require_correspondence`` is True and
            C - 1 != len(segment_colors).
    """
    if require_correspondence and len(segment) - 1 != len(segment_colors):
        raise ValueError(
            "Segment channel (dim 0) must have same length as segment colors when minus 1"
        )
    rigid_mask = is_rigid if is_rigid.dtype == bool else is_rigid > 0.5
    labels = np.argmax(segment, axis=0)
    plot_colored_rigid_by_id(ax, rigid_mask, labels, segment_colors)


def plot_colored_rigid_by_id(ax: Axes, is_rigid, segment_id, segment_colors):
    """
    Clear ``ax`` and draw the rigid surface voxels of each labelled segment.

    For every label ``k`` in ``1 .. max(segment_id)``, the surface voxels of
    ``segment_id == k`` that are also rigid are colored
    ``segment_colors[(k - 1) % len(segment_colors)]``. Label 0 is never drawn,
    and neither is any voxel left with an all-zero RGBA value.

    Args:
        ax: Axes object to plot on.
        is_rigid: A 3d array of shape [X, Y, Z]. Bool arrays are used as-is;
            otherwise values > 0.5 count as rigid.
        segment_id: An id 3d array of shape [X, Y, Z]; voxels labeled 0 are
            non-rigid.
        segment_colors: A list of segment colors, cycled when there are more
            labels than colors.
    """
    rigid_mask = is_rigid if is_rigid.dtype == bool else is_rigid > 0.5

    ax.clear()
    face_colors = np.zeros((*segment_id.shape, 4))
    for label in range(1, np.max(segment_id) + 1):
        rigid_shell = filter_surface_voxels(segment_id == label) & rigid_mask
        face_colors[rigid_shell] = segment_colors[(label - 1) % len(segment_colors)]

    drawn = (face_colors != 0).any(axis=-1)
    ax.voxels(drawn, facecolors=face_colors, edgecolor="none")
    for axis_name, set_label in zip("XYZ", (ax.set_xlabel, ax.set_ylabel, ax.set_zlabel)):
        set_label(axis_name)
    plt.draw()


def plot_hinge_joints(ax: Axes, connections: List[Dict[str, np.ndarray]]):
    """
    Draw every hinge joint as a black arrow along its axis plus a black dot.

    The axis is not cleared, so joints are overlaid on whatever is already
    plotted.

    Args:
        ax: 3d Axes object to plot on.
        connections: A list of dicts with two keys, "position" and "axis",
            each holding a numpy array of shape [3].
    """
    for joint in connections:
        pos = joint["position"]
        axis_dir = joint["axis"]
        # normalize=True rescales the direction to unit length, so every
        # arrow is drawn with length 20 regardless of |axis|.
        ax.quiver(
            pos[0],
            pos[1],
            pos[2],
            axis_dir[0],
            axis_dir[1],
            axis_dir[2],
            length=20,
            normalize=True,
            color="black",
        )
        ax.scatter(pos[0], pos[1], pos[2], color="black", s=100)
    plt.draw()


def plot_floor(floor: np.ndarray):
    """
    Show a 2D floor height field as a heat map and as a 3D surface.

    A new 16x8 figure is created. The left panel is an ``imshow`` of the
    heights with a colorbar. The right panel is a 3D ``plot_surface`` over the
    integer grid, with X, Y and Height labels. Both use the 'terrain' colormap.

    Args:
        floor (numpy.ndarray): A 2D numpy array representing the floor height field.

    Returns:
        matplotlib.figure.Figure: The created figure.
    """
    fig = plt.figure(figsize=(16, 8))

    # Left: flat height map.
    heatmap_ax = fig.add_subplot(1, 2, 1)
    height_image = heatmap_ax.imshow(floor, cmap="terrain")
    fig.colorbar(height_image, ax=heatmap_ax)
    heatmap_ax.set_title("Floor Height Map (2D)")

    # Right: the same heights as a 3D surface; rows map to Y, columns to X.
    surface_ax = fig.add_subplot(1, 2, 2, projection="3d")
    n_rows, n_cols = floor.shape
    grid_x, grid_y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    surface_ax.plot_surface(
        grid_x,
        grid_y,
        floor,
        cmap="terrain",
        rstride=1,
        cstride=1,
        linewidth=0,
        antialiased=False,
    )
    surface_ax.set_title("Floor Height Map (3D)")
    surface_ax.set_xlabel("X")
    surface_ax.set_ylabel("Y")
    surface_ax.set_zlabel("Height")

    plt.tight_layout()
    return fig


def plot_voxel_view(
    ax: Axes,
    voxel_positions: np.ndarray,
    voxel_features: np.ndarray,
    point_size: float = 2,
    vmin: float = 0.0,
    vmax: float = 0.1,
):
    """
    Plot a 3D scatter of voxels colored by the magnitude of their pressure.

    Only voxels with non-zero pressure are drawn. A horizontal "Pressure"
    colorbar is attached to ``ax`` on every call. All three axis limits are
    set to the symmetric range [-lim, lim], where lim is the mean over the
    three axes of the largest absolute voxel coordinate.

    Args:
        ax (Axes): Matplotlib 3D axis for plotting.
        voxel_positions (np.ndarray): Array of shape [voxel_num, 3] representing voxel positions.
        voxel_features (np.ndarray): Array of shape [voxel_num, 4]. Column 3 holds the
            pressure; its absolute value is what gets plotted.
        point_size (float): Size of point to represent voxels.
        vmin (float): Lower bound of the pressure color scale (default 0.0).
        vmax (float): Upper bound of the pressure color scale (default 0.1).
    """
    x, y, z = voxel_positions[:, 0], voxel_positions[:, 1], voxel_positions[:, 2]
    pressure = np.abs(voxel_features[:, 3])

    # Zero-pressure voxels are not drawn; only loaded voxels are shown.
    non_zero = pressure != 0

    scatter = ax.scatter(
        x[non_zero],
        y[non_zero],
        z[non_zero],
        c=pressure[non_zero],
        cmap="inferno",
        vmin=vmin,
        vmax=vmax,
        s=point_size,
    )
    plt.colorbar(
        scatter, ax=ax, label="Pressure", orientation="horizontal", shrink=0.7, pad=0.1
    )

    lim = np.mean(np.max(np.abs(voxel_positions), axis=0))
    ax.set_xlim([-lim, lim])
    ax.set_ylim([-lim, lim])
    ax.set_zlim([-lim, lim])
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")


def plot_rigid_body_graph(
    ax: Axes,
    voxel_positions: np.ndarray,
    node_positions: np.ndarray,
    node_features: np.ndarray,
    edges: np.ndarray,
):
    """
    Draw rigid-body nodes, their orientation arrows and the edges between them.

    Drawing order:
    - voxels as a faint dark-blue point cloud;
    - nodes as large red dots, each with a short red arrow (length 0.1). The
      arrow is the unit z vector rotated by the quaternion built from
      ``node_features[i, 3:7]`` via ``pyquaternion.Quaternion``;
    - edges as blue line segments.

    Only the first half of the edge columns is drawn. Presumably the second
    half holds the reversed duplicates (TODO confirm with the graph builder).
    Axis limits are the symmetric range [-lim, lim], where lim is the mean over
    axes of the largest absolute voxel coordinate.

    Args:
        ax (Axes): Matplotlib 3D axis for plotting.
        voxel_positions (np.ndarray): Array of shape [voxel_num, 3] representing voxel positions.
        node_positions (np.ndarray): Array of shape [N*, 3] representing node positions (rigid body positions).
        node_features (np.ndarray): Array of shape [N*, 14] representing node features (including orientation).
        edges (np.ndarray): Array of shape [2, E* * 2] of node indices connecting nodes.
    """
    vx, vy, vz = voxel_positions[:, 0], voxel_positions[:, 1], voxel_positions[:, 2]
    ax.scatter(vx, vy, vz, color=[0, 0, 0.1, 0.1])

    nx, ny, nz = node_positions[:, 0], node_positions[:, 1], node_positions[:, 2]
    ax.scatter(nx, ny, nz, color="red", s=100)

    up = np.array([0, 0, 1])
    for i, pos in enumerate(node_positions):
        heading = Quaternion(node_features[i, 3:7]).rotate(up)
        tip = 0.1 * heading
        ax.quiver(pos[0], pos[1], pos[2], tip[0], tip[1], tip[2], color="red")

    half = edges.shape[1] // 2
    for edge in edges[:, :half].T:
        src, dst = node_positions[edge[0]], node_positions[edge[1]]
        ax.plot([src[0], dst[0]], [src[1], dst[1]], [src[2], dst[2]], color="blue")

    extent = np.mean(np.max(np.abs(voxel_positions), axis=0))
    for set_lim in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
        set_lim([-extent, extent])
    for axis_name, set_label in zip("XYZ", (ax.set_xlabel, ax.set_ylabel, ax.set_zlabel)):
        set_label(axis_name)


def plot_encoder_view(
    axs: Tuple[Axes, Axes],
    voxel_positions: np.ndarray,
    voxel_features: np.ndarray,
    node_positions: np.ndarray,
    node_features: np.ndarray,
    edges: np.ndarray,
    point_size: float = 2,
) -> None:
    """
    Render the encoder input as two panels and refresh the figure.

    ``axs[0]`` receives the voxel pressure view and ``axs[1]`` the rigid-body
    graph. The figure is then redrawn and the GUI event loop is given 0.01 s
    via ``plt.pause``.

    Args:
        axs (Tuple[Axes, Axes]): Matplotlib 3D axes; only the first two are used.
        voxel_positions (np.ndarray): Array of shape [voxel_num, 3] representing voxel positions.
        voxel_features (np.ndarray): Array of shape [voxel_num, 4] representing voxel features.
        node_positions (np.ndarray): Array of shape [N*, 3] representing node positions.
        node_features (np.ndarray): Array of shape [N*, 14] representing node features.
        edges (np.ndarray): Array of shape [2, E* * 2] representing edges connecting nodes.
        point_size (float): Size of point to represent voxels.
    """
    voxel_ax, graph_ax = axs[0], axs[1]
    plot_voxel_view(voxel_ax, voxel_positions, voxel_features, point_size=point_size)
    plot_rigid_body_graph(graph_ax, voxel_positions, node_positions, node_features, edges)
    # NOTE: a plot_rigid_body_topology panel on axs[2] is currently disabled.
    plt.draw()
    plt.pause(0.01)
