# Revoxelize multi-body robot config to single N×N×N voxel grid
# Output: 0=void, 1=soft, 2..M=rigid segments (graph-colored so that connected segments get different IDs)

import argparse
import os
import pickle
import numpy as np
from typing import Dict, List, Optional, Set, Tuple
from matplotlib import pyplot as plt
from scipy.spatial.transform import Rotation
from rise import *
from utils.plot import filter_surface_voxels, plot_binary, plot_colored_rigid_by_id


def build_segment_adjacency_graph(
    structure_config,
) -> Dict[Tuple[int, int], Set[Tuple[int, int]]]:
    """Build an undirected graph whose nodes are rigid segments.

    Nodes are keyed by ``(body_sid, segment_bid)``. Every rigid segment
    appears as a node, even one without constraints. Two nodes are linked
    when some constraint joins them and both of its ends are rigid.

    The dict's insertion order follows body order and then voxel order.
    Downstream coloring shuffles the keys with a seed, so that order is part
    of the reproducible result.
    """
    graph: Dict[Tuple[int, int], Set[Tuple[int, int]]] = {}

    for body in structure_config.bodies:
        for voxel_idx, segment_bid in enumerate(body.segment_bid):
            segment_type = body.segment_type[voxel_idx]
            # A segment type of 0 marks soft material; RS_NULL_INDEX marks an empty voxel.
            is_rigid = segment_bid != RS_NULL_INDEX and segment_type != 0
            if is_rigid:
                graph.setdefault((body.body_sid, segment_bid), set())

    for constraint in structure_config.constraints:
        end_a = (constraint.a_body_sid, constraint.a_segment_bid)
        end_b = (constraint.b_body_sid, constraint.b_segment_bid)
        # Constraints touching a non-rigid segment produce no edge.
        if end_a not in graph or end_b not in graph:
            continue
        graph[end_a].add(end_b)
        graph[end_b].add(end_a)

    return graph


def graph_coloring_with_seed(
    adjacency: Dict[Tuple[int, int], Set[Tuple[int, int]]],
    num_colors: int,
    seed: int = 0,
) -> Dict[Tuple[int, int], int]:
    """Randomized greedy coloring that tries to give neighbors different IDs.

    Nodes are visited in a seeded random order. Each node takes a random color
    from those not already used by a colored neighbor. If every color is taken,
    the node gets the color used by the fewest of its neighbors, with ties going
    to the lowest ID. Conflicts are therefore minimized but not always avoided.

    Args:
        adjacency: Map from segment (body_sid, segment_bid) to set of adjacent segments.
        num_colors: Number of colors available; color IDs are 2 .. num_colors + 1.
        seed: Seed for ``numpy.random.default_rng``; it controls both the visit
            order and the color choices.

    Returns:
        Map from segment (body_sid, segment_bid) to a Python ``int`` color ID
        in [2, num_colors + 1].

    Raises:
        ValueError: if ``adjacency`` has nodes but ``num_colors < 1``.
    """
    rng = np.random.default_rng(seed)
    nodes = list(adjacency.keys())
    if nodes and num_colors < 1:
        raise ValueError(
            f"num_colors must be >= 1 to color {len(nodes)} rigid segment(s), "
            f"got {num_colors}"
        )
    rng.shuffle(nodes)

    palette = range(2, num_colors + 2)
    colors: Dict[Tuple[int, int], int] = {}
    for node in nodes:
        neighbor_colors = [colors[n] for n in adjacency[node] if n in colors]
        taken = set(neighbor_colors)
        available = [c for c in palette if c not in taken]

        if available:
            # rng.choice returns a numpy scalar; store a plain int.
            colors[node] = int(rng.choice(available))
        else:
            # All colors are used by neighbors: pick the least-used one.
            # Ties resolve to the lowest ID because the dict keeps palette order.
            usage = dict.fromkeys(palette, 0)
            for c in neighbor_colors:
                usage[c] += 1
            colors[node] = min(usage, key=usage.get)

    return colors


def compute_body_bounding_box(
    body,
    voxel_size: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Return the world-space axis-aligned bounds ``(min, max)`` of one body.

    The body's local box spans ``[0, n * voxel_size]`` on each axis, where n
    is the body's voxel count along that axis. All eight corners of the box
    are rotated by the body's quaternion (read in x, y, z, w order) and
    translated by its origin. The per-axis extremes of the results are
    returned.
    """
    position = body.relative_origin_position
    orientation = body.relative_orientation
    translation = np.array([position.x, position.y, position.z])
    rotation = Rotation.from_quat(
        (orientation.x, orientation.y, orientation.z, orientation.w)
    )

    extent = voxel_size * np.array([body.x_voxels, body.y_voxels, body.z_voxels])
    # Each corner picks either 0 or the full extent independently on every axis.
    unit_corners = np.array(
        [[i, j, k] for i in (0, 1) for j in (0, 1) for k in (0, 1)],
        dtype=np.float64,
    )
    world_corners = rotation.apply(unit_corners * extent) + translation

    lower = world_corners.min(axis=0)
    upper = world_corners.max(axis=0)
    return lower, upper


def compute_structure_bounding_box(
    structure_config,
    voxel_size: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Return the world-space axis-aligned bounds ``(min, max)`` of all bodies.

    This is the per-axis union of each body's bounds from
    ``compute_body_bounding_box``. With no bodies, numpy raises ``ValueError``
    because the reduction is over an empty array.
    """
    per_body = [
        compute_body_bounding_box(body, voxel_size)
        for body in structure_config.bodies
    ]
    lowers = np.array([lower for lower, _ in per_body])
    uppers = np.array([upper for _, upper in per_body])
    return lowers.min(axis=0), uppers.max(axis=0)


def _rasterize_body(
    grid: np.ndarray,
    body,
    voxel_size: float,
    structure_center: np.ndarray,
    output_voxel_size: float,
    grid_center: np.ndarray,
    segment_colors: Dict[Tuple[int, int], int],
) -> None:
    """Paint one body's occupied voxels into ``grid`` in place."""
    nx, ny, nz = body.x_voxels, body.y_voxels, body.z_voxels

    # Collect occupied voxels in flat-index order (x fastest, then y, then z).
    # They are painted in this same order below, which decides which rigid
    # segment wins a contested cell.
    flat_indices = []
    segment_bids = []
    segment_types = []
    for voxel_idx in range(nx * ny * nz):
        segment_bid = body.segment_bid[voxel_idx]
        if segment_bid == RS_NULL_INDEX:  # empty voxel
            continue
        flat_indices.append(voxel_idx)
        segment_bids.append(segment_bid)
        segment_types.append(body.segment_type[voxel_idx])
    if not flat_indices:
        return

    flat = np.asarray(flat_indices)
    voxel_coords = np.stack(
        [flat % nx, (flat // nx) % ny, flat // (nx * ny)], axis=1
    ).astype(np.float64)
    corner_offsets = np.array(
        [[i, j, k] for i in (0, 1) for j in (0, 1) for k in (0, 1)],
        dtype=np.float64,
    )
    # Local-space corners of every occupied voxel, shape (K, 8, 3).
    corners_local = (
        voxel_coords[:, None, :] * voxel_size + corner_offsets[None, :, :] * voxel_size
    )

    position = body.relative_origin_position
    orientation = body.relative_orientation
    origin = np.array([position.x, position.y, position.z])
    rotation = Rotation.from_quat(
        (orientation.x, orientation.y, orientation.z, orientation.w)
    )
    # One batched rotation per body instead of one scipy call per voxel.
    corners_world = (
        rotation.apply(corners_local.reshape(-1, 3)).reshape(-1, 8, 3) + origin
    )
    corners_grid = (corners_world - structure_center) / output_voxel_size + grid_center

    # Output cell g covers [g, g + 1). A box [a, b] overlaps, with positive
    # volume, the cells floor(a) .. ceil(b) - 1. The tolerance stops a face that
    # sits on a cell boundary (up to float noise) from claiming the next cell.
    boundary_tol = 1e-6
    lo = np.floor(corners_grid.min(axis=1) + boundary_tol).astype(int)
    hi = np.ceil(corners_grid.max(axis=1) - boundary_tol).astype(int) - 1
    hi = np.maximum(hi, lo)  # every voxel covers at least one cell
    lo = np.maximum(lo, 0)
    hi = np.minimum(hi, np.array(grid.shape) - 1)

    for (x0, y0, z0), (x1, y1, z1), segment_type, segment_bid in zip(
        lo, hi, segment_types, segment_bids
    ):
        if x1 < x0 or y1 < y0 or z1 < z0:
            continue  # voxel lies entirely outside the grid; avoid negative slices
        cells = grid[x0 : x1 + 1, y0 : y1 + 1, z0 : z1 + 1]  # view into grid
        if segment_type == 0:  # soft: only claim empty cells
            cells[cells == 0] = 1
        else:  # rigid: overwrite soft and any earlier rigid label
            cells[...] = segment_colors.get((body.body_sid, segment_bid), 2)


def revoxelize_robot_config(
    structure_config,
    N: int = 64,
    N_used: int = 48,
    M: int = 7,
    seed: int = 0,
    voxel_size: Optional[float] = None,
) -> np.ndarray:
    """Revoxelize multi-body robot config to single N×N×N voxel grid.

    Output: 0=void, 1=soft, 2..M=rigid segments. Rigid segments are
    graph-colored so that constraint-connected segments tend to get different
    IDs.

    This function preserves aspect ratio by:
    1. Computing the bounding box of the entire structure
    2. Finding the largest dimension
    3. Scaling uniformly so the largest dimension fits in N_used voxels
    4. Centering in the N×N×N grid

    Each source voxel marks every output cell its rotated box overlaps with
    positive volume. Rigid labels overwrite soft ones; soft labels only fill
    empty cells. Cells that fall outside the grid (e.g. when ``N_used > N``)
    are clipped.

    Args:
        structure_config: Multi-body structure whose ``bodies`` and
            ``constraints`` are read.
        N: Edge length of the output grid.
        N_used: Number of output cells spanned by the largest structure dimension.
        M: Largest rigid ID; M - 1 colors (IDs 2..M) are used.
        seed: Seed for the segment coloring.
        voxel_size: Source voxel edge length; defaults to
            ``structure_config.voxel_size``.

    Returns:
        ``int32`` array of shape ``(N, N, N)``.

    Raises:
        ValueError: if ``N_used`` is not positive.
    """
    if N_used <= 0:
        raise ValueError(f"N_used must be positive, got {N_used}")
    if voxel_size is None:
        voxel_size = structure_config.voxel_size

    grid = np.zeros((N, N, N), dtype=np.int32)

    # A single uniform scale preserves the aspect ratio.
    global_min, global_max = compute_structure_bounding_box(
        structure_config, voxel_size
    )
    structure_center = (global_min + global_max) / 2.0
    max_dimension = np.max(global_max - global_min)
    output_voxel_size = max_dimension / N_used if max_dimension > 0 else voxel_size
    grid_center = np.full(3, N / 2.0)

    adjacency = build_segment_adjacency_graph(structure_config)
    segment_colors = graph_coloring_with_seed(adjacency, M - 1, seed)

    for body in structure_config.bodies:
        _rasterize_body(
            grid,
            body,
            voxel_size,
            structure_center,
            output_voxel_size,
            grid_center,
            segment_colors,
        )

    return grid


def load_robot_config(config_path: str):
    """Unpickle and return the robot structure config stored at ``config_path``.

    Security: unpickling can execute arbitrary code, so only load files from
    trusted sources.
    """
    with open(config_path, mode="rb") as stream:
        config = pickle.load(stream)
    return config


def save_voxel_grid(grid: np.ndarray, output_path: str) -> None:
    """Save ``grid`` as a NumPy ``.npy`` file and print where it was written.

    ``np.save`` appends ``.npy`` to paths that lack it. The suffix is added
    here first, so the printed path always names the file actually created.
    """
    saved_path = os.fspath(output_path)
    if not saved_path.endswith(".npy"):
        saved_path += ".npy"
    np.save(saved_path, grid)
    print(f"Saved to: {saved_path}")


def print_grid_statistics(grid: np.ndarray, M: int) -> None:
    """Print occupancy statistics for a revoxelized label grid.

    Args:
        grid: Label grid where 0 = void, 1 = soft and 2..M = rigid IDs.
        M: Largest rigid ID reported; labels above M are not counted as rigid.
    """
    total_voxels = grid.size
    void_count = int(np.count_nonzero(grid == 0))
    soft_count = int(np.count_nonzero(grid == 1))
    rigid_counts = {
        rid: int(np.count_nonzero(grid == rid)) for rid in range(2, M + 1)
    }
    total_rigid = sum(rigid_counts.values())

    def pct(count: int) -> float:
        # Guard against an empty grid instead of producing nan.
        return 100 * count / total_voxels if total_voxels else 0.0

    print("\nGrid Statistics:")
    print(f"  Size: {grid.shape}")
    print(f"  Void: {void_count:,} ({pct(void_count):.1f}%)")
    print(f"  Soft: {soft_count:,} ({pct(soft_count):.1f}%)")
    print(f"  Rigid: {total_rigid:,} ({pct(total_rigid):.1f}%)")
    for rid, count in rigid_counts.items():
        if count > 0:
            print(f"    ID {rid}: {count:,}")


def main() -> None:
    """Command-line entry point: revoxelize a pickled robot config, save and plot it."""
    parser = argparse.ArgumentParser(
        description="Revoxelize robot config to single voxel grid"
    )
    parser.add_argument(
        "config_path", type=str, help="Path to robot config pickle file"
    )
    parser.add_argument(
        "-o", "--output", type=str, default=None, help="Output .npy file"
    )
    # The default matches revoxelize_robot_config(N=64). The grid must be at
    # least --n-used cells wide, or the scaled structure is clipped at the border.
    parser.add_argument(
        "-N", "--grid-size", type=int, default=64, help="Grid size (N×N×N)"
    )
    parser.add_argument("--n-used", type=int, default=48, help="Used region size")
    parser.add_argument(
        "-M", "--max-rigid-id", type=int, default=7, help="Max rigid ID"
    )
    parser.add_argument("-s", "--seed", type=int, default=0, help="Random seed")

    args = parser.parse_args()

    if args.grid_size <= 0:
        parser.error(f"--grid-size must be positive, got {args.grid_size}")
    if not 0 < args.n_used <= args.grid_size:
        parser.error(
            f"--n-used must be between 1 and --grid-size ({args.grid_size}), "
            f"got {args.n_used}"
        )

    if args.output is None:
        base_name = os.path.splitext(os.path.basename(args.config_path))[0]
        output_dir = os.path.dirname(args.config_path)
        args.output = os.path.join(output_dir, f"{base_name}_revoxelized.npy")

    print(f"Input: {args.config_path}")
    print(f"Output: {args.output}")
    print(
        f"Grid: {args.grid_size}×{args.grid_size}×{args.grid_size}, Used: {args.n_used}, Max ID: {args.max_rigid_id}, Seed: {args.seed}"
    )

    structure_config = load_robot_config(args.config_path)
    print(
        f"Loaded: {structure_config.name} ({len(structure_config.bodies)} bodies, {len(structure_config.constraints)} constraints)"
    )

    grid = revoxelize_robot_config(
        structure_config,
        N=args.grid_size,
        N_used=args.n_used,
        M=args.max_rigid_id,
        seed=args.seed,
    )

    print_grid_statistics(grid, args.max_rigid_id)
    save_voxel_grid(grid, args.output)

    # Visualize
    fig = plt.figure(figsize=(14, 6))

    ax1 = fig.add_subplot(1, 2, 1, projection="3d")
    ax1.set_title("Soft Voxels (ID=1)")
    soft_mask = grid == 1
    plot_binary(ax1, soft_mask, color=(0.5, 0.8, 1.0, 0.8))
    ax1.set_box_aspect([1, 1, 1])  # Equal aspect ratio
    ax1.set_xlim(0, args.grid_size)
    ax1.set_ylim(0, args.grid_size)
    ax1.set_zlim(0, args.grid_size)

    ax2 = fig.add_subplot(1, 2, 2, projection="3d")
    ax2.set_title("Rigid Segments (ID=2 to M)")
    is_rigid = grid >= 2
    # Shift rigid IDs 2..M down to 1..M-1 for the plotting helper; 0 = not rigid.
    segment_id = np.where(grid >= 2, grid - 1, 0)
    segment_colors = [
        (0.8, 0.2, 0.2, 1.0),
        (0.2, 0.8, 0.2, 1.0),
        (0.2, 0.2, 0.8, 1.0),
        (0.8, 0.8, 0.2, 1.0),
        (0.8, 0.2, 0.8, 1.0),
        (0.2, 0.8, 0.8, 1.0),
        (0.9, 0.5, 0.1, 1.0),
        (0.5, 0.1, 0.9, 1.0),
    ]
    plot_colored_rigid_by_id(ax2, is_rigid, segment_id, segment_colors)
    ax2.set_box_aspect([1, 1, 1])  # Equal aspect ratio
    ax2.set_xlim(0, args.grid_size)
    ax2.set_ylim(0, args.grid_size)
    ax2.set_zlim(0, args.grid_size)

    plt.tight_layout()
    plt.show()

    print("Done!")


if __name__ == "__main__":
    main()
