# filename: wall_collision.py
"""
Wall collision detection and visualization for multi-modal trajectory filtering.

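The wall is a bounded 2D plane (zero-width) positioned between the robot and the drawer handle.
Trajectories in which any tracked robot point (arm joint origins, end-effector tip, gripper joints)
reaches the cabinet side of this plane within the wall bounds (and outside any opening) are rejected.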

Key Features:
  - Create visual wall (thin semi-transparent cuboid) for rendering
  - Geometric plane collision check on robot joint, tip, and gripper-joint positions
  - Track collision state during trajectory execution
  - Overlay wall on video frames

Usage:
  wall = create_wall(task_env, wall_config)

  for step in trajectory:
      # Execute step...
      collision, link_name, link_pos = check_wall_collision(task_env, wall_config)
      if collision:
          print(f"Wall collision at link {link_name}!")
          break  # Stop trajectory
"""

import numpy as np
from pyrep.objects.shape import Shape
from pyrep.const import PrimitiveShape


# ============================================================================
# Default Wall Configuration
# ============================================================================

# Wall is positioned between robot (y≈0) and drawer handle (y≈-0.35)
# Wall plane: y = wall_y, bounded by [wall_min_x, wall_max_x] and [wall_min_z, wall_max_z]
DEFAULT_WALL_CONFIG = {
    # Wall plane position (Y coordinate where the wall sits)
    # Computed from: home_ee_y + 0.5 * (handle_y - home_ee_y)
    # With HOME_JOINTS=[0, 0.15, 0, -1.7, 0, 1.7, 0.785], home_ee_y≈0.088, handle_y≈-0.35
    # => wall_y = 0.088 + 0.5*(-0.35-0.088) = -0.131
    # NOTE(review): the DEFAULT_OPENING_CONFIG comment below puts the handle at
    # y≈-0.25, which would give a midpoint near -0.081 — confirm which handle_y
    # is current.
    "wall_y": -0.13,  # At control point plane (pos_frac=0.5)

    # Wall bounds in X (left-right from robot's perspective)
    "wall_min_x": 0.30,
    "wall_max_x": 0.35,

    # Wall bounds in Z (vertical extent)
    "wall_min_z": 1.15,
    "wall_max_z": 1.2,

    # Visual properties (only used when rendering the wall)
    "wall_thickness": 0.002,  # 2 mm cuboid, thin enough to read as a plane
    "wall_color": [1.0, 0.2, 0.2],  # RGB floats, red-ish
    "wall_transparency": 0.6,  # Semi-transparent

    # Optional rectangular opening (absolute x/z bounds) through which points may
    # pass without counting as a collision. None means a solid wall.
    "opening": None,  # Example: {"min_x": 0.2, "max_x": 0.4, "min_z": 0.8, "max_z": 1.0}
}

# Default opening configuration (absolute x/z bounds).
# Intended to let trajectories pass through to reach the handle.
# Handle position is approximately [0.325, -0.25, 1.173]
# NOTE(review): that handle point lies outside these bounds
# (x=0.325 > max_x=0.3 and z=1.173 < min_z=1.2) — confirm this is intended.
DEFAULT_OPENING_CONFIG = {
    "min_x": 0.2,
    "max_x": 0.3,
    "min_z": 1.2,
    "max_z": 1.3,
}
# Alternative, tighter opening bounds kept for reference:
#   x in [0.25, 0.27], z in [1.21, 1.24]

# ============================================================================
# Predefined Wall Styles
# ============================================================================


def _wall_style(**overrides):
    """Return an independent copy of DEFAULT_WALL_CONFIG with ``overrides`` applied.

    The ``wall_color`` list is copied too. A plain ``dict.copy()`` would share it
    between DEFAULT_WALL_CONFIG and every style, so mutating one colour would
    change all of them. Key order follows DEFAULT_WALL_CONFIG.
    """
    style = dict(DEFAULT_WALL_CONFIG)
    style["wall_color"] = list(DEFAULT_WALL_CONFIG["wall_color"])
    style.update(overrides)
    return style


# Style 1: only the 0-degree approach can succeed (solid wall, no opening)
WALL_STYLE_1 = _wall_style(
    wall_min_x=0.25,
    wall_max_x=0.55,
    wall_max_z=1.3,
    wall_min_z=0.5,
    opening=None,
)

# Style 2: only the 270-degree approach can succeed (solid wall, no opening)
WALL_STYLE_2 = _wall_style(
    wall_min_x=0.2,
    wall_max_x=0.5,
    wall_max_z=1.23,
    wall_min_z=0.5,
    opening=None,
)

# Style 3: large wall with DEFAULT_OPENING_CONFIG cut into it
WALL_STYLE_3 = _wall_style(
    wall_min_x=0.2,
    wall_max_x=0.5,
    wall_max_z=1.3,
    wall_min_z=0.5,
    opening=DEFAULT_OPENING_CONFIG.copy(),
)

# Dictionary to access styles by number
WALL_STYLES = {
    1: WALL_STYLE_1,
    2: WALL_STYLE_2,
    3: WALL_STYLE_3,
}


# ============================================================================
# Wall Creation and Management
# ============================================================================

def create_wall(task_env, wall_config=None, name="trajectory_wall", force_recreate=True):
    """
    Create (or reuse) a visual wall in the scene.

    The wall is a thin cuboid created with ``respondable=False`` and
    ``static=True`` so it does not take part in physics. It is renderable, so it
    appears in camera images. Transparency is applied on a best-effort basis.

    Args:
        task_env: RLBench task environment (its ``_scene.pyrep`` is stepped
            after the wall is created)
        wall_config: dict with wall configuration (uses DEFAULT_WALL_CONFIG if None)
        name: str, name for the wall shape object
        force_recreate: if True, remove any existing shape called ``name`` and
            create a new one (ensures wall size matches config); if False, reuse
            the existing shape and update it via ``update_wall_properties``

    Returns:
        wall: Shape object representing the wall

    Raises:
        ValueError: if the configured x or z extent is not positive.
    """
    if wall_config is None:
        wall_config = DEFAULT_WALL_CONFIG.copy()

    wall_width = wall_config["wall_max_x"] - wall_config["wall_min_x"]
    wall_height = wall_config["wall_max_z"] - wall_config["wall_min_z"]
    # Validate before touching the scene, so an invalid config never removes a
    # working wall.
    if wall_width <= 0 or wall_height <= 0:
        raise ValueError(
            "Wall bounds must satisfy wall_min_x < wall_max_x and "
            f"wall_min_z < wall_max_z (got width={wall_width}, height={wall_height})"
        )

    # Guard only the lookup: "object not found" is the expected failure here.
    # Errors from removing/updating an existing wall must surface. Swallowing
    # them would create a second shape with the same name.
    try:
        existing_wall = Shape(name)
    except Exception:
        existing_wall = None  # Wall doesn't exist yet, create it below

    if existing_wall is not None:
        if not force_recreate:
            # Wall exists, just update its position/properties
            update_wall_properties(existing_wall, wall_config)
            return existing_wall
        # Remove existing wall to recreate with new config
        existing_wall.remove()
        print(f"  Removed existing wall '{name}' for recreation")

    wall = Shape.create(
        type=PrimitiveShape.CUBOID,
        size=[wall_width, wall_config["wall_thickness"], wall_height],
        mass=0.0,
        respondable=False,  # Don't interact with physics
        static=True,
        renderable=True,
        color=wall_config["wall_color"],
    )
    wall.set_name(name)

    # The cuboid is placed by its centre.
    wall_center_x = (wall_config["wall_min_x"] + wall_config["wall_max_x"]) / 2
    wall_center_z = (wall_config["wall_min_z"] + wall_config["wall_max_z"]) / 2
    wall.set_position([wall_center_x, wall_config["wall_y"], wall_center_z])

    # Transparency is cosmetic: skip it if this PyRep build lacks the setter.
    try:
        wall.set_transparency(wall_config["wall_transparency"])
    except AttributeError:
        pass

    # Step simulation to register the object
    for _ in range(5):
        task_env._scene.pyrep.step()

    print(f"  Wall created: y={wall_config['wall_y']:.2f}, "
          f"x=[{wall_config['wall_min_x']:.2f}, {wall_config['wall_max_x']:.2f}], "
          f"z=[{wall_config['wall_min_z']:.2f}, {wall_config['wall_max_z']:.2f}]")

    return wall


def update_wall_properties(wall, wall_config):
    """
    Update an existing wall's size, position, colour and transparency in place.

    Resizing uses ``wall.set_bounding_box``. If that call fails (e.g. the method
    is not available on this PyRep version), a warning is printed and the
    previous size is kept. Recreate the wall with
    ``create_wall(..., force_recreate=True)`` to apply new dimensions.
    Colour and transparency are applied only when present in ``wall_config``.

    Args:
        wall: Shape object to modify
        wall_config: dict with wall configuration (``wall_thickness`` defaults
            to 0.002 when missing)
    """
    wall_width = wall_config["wall_max_x"] - wall_config["wall_min_x"]
    wall_height = wall_config["wall_max_z"] - wall_config["wall_min_z"]
    wall_thickness = wall_config.get("wall_thickness", 0.002)

    # Update wall size (width, thickness, height)
    try:
        wall.set_bounding_box([wall_width, wall_thickness, wall_height])
    except Exception as err:
        # Not fatal: position and appearance can still be updated. Make the
        # stale size visible instead of silently keeping it.
        print(f"  Warning: could not resize wall ({err!r}); keeping its previous size. "
              "Use create_wall(..., force_recreate=True) to apply new dimensions.")

    # Update wall position (the cuboid is placed by its centre)
    wall_center_x = (wall_config["wall_min_x"] + wall_config["wall_max_x"]) / 2
    wall_center_z = (wall_config["wall_min_z"] + wall_config["wall_max_z"]) / 2
    wall.set_position([wall_center_x, wall_config["wall_y"], wall_center_z])

    # Update appearance so a reused wall matches the requested config.
    color = wall_config.get("wall_color")
    if color is not None:
        wall.set_color(color)
    transparency = wall_config.get("wall_transparency")
    if transparency is not None:
        try:
            wall.set_transparency(transparency)
        except AttributeError:
            pass  # Cosmetic only; same best-effort policy as create_wall

    print(f"  Wall updated: y={wall_config['wall_y']:.2f}, "
          f"x=[{wall_config['wall_min_x']:.2f}, {wall_config['wall_max_x']:.2f}], "
          f"z=[{wall_config['wall_min_z']:.2f}, {wall_config['wall_max_z']:.2f}]")


def remove_wall(name="trajectory_wall"):
    """
    Delete the shape called ``name`` from the scene, if it can be found.

    Returns:
        True when the shape was looked up and removed; False when either the
        lookup or the removal raised (e.g. no such object).
    """
    try:
        Shape(name).remove()
    except Exception:
        return False
    return True


# ============================================================================
# Collision Detection
# ============================================================================

def get_robot_link_positions(task_env):
    """
    Collect the positions of the arm joints, the arm tip and the gripper joints.

    Positions are whatever each object's ``get_position()`` returns, converted
    with ``np.array``. Names are ``joint_<i>`` for arm joints, ``tip`` for the
    arm tip and ``gripper_<i>`` for gripper joints, in that order.
    Gripper joints are optional: any exception while reading them stops
    collecting gripper entries and is otherwise ignored.

    Returns:
        list of (name, position) tuples
    """
    robot = task_env._scene.robot
    arm, gripper = robot.arm, robot.gripper

    # Each arm joint origin stands in for its link.
    positions = [
        (f"joint_{idx}", np.array(jnt.get_position()))
        for idx, jnt in enumerate(arm.joints)
    ]
    positions.append(("tip", np.array(arm.get_tip().get_position())))

    try:
        for idx, finger in enumerate(gripper.joints):
            positions.append((f"gripper_{idx}", np.array(finger.get_position())))
    except Exception:
        pass  # best effort: gripper joints may be unavailable

    return positions


def check_wall_collision(task_env, wall_config=None, check_crossing=True):
    """
    Report the first tracked robot point that is blocked by the wall.

    Points come from ``get_robot_link_positions`` (arm joints, tip, gripper
    joints). A point is "on the solid wall" when its x and z lie inside the
    inclusive wall bounds and outside the optional ``opening``. For such a point:

    - ``check_crossing=True``: collision if y < wall_y (cabinet side).
    - ``check_crossing=False``: collision if |y - wall_y| < 0.01.

    Args:
        task_env: RLBench task environment
        wall_config: dict with wall configuration (DEFAULT_WALL_CONFIG if None)
        check_crossing: selects the test above

    Returns:
        collision: bool, True if collision detected
        collision_link: str or None, name of the link that collided
        collision_pos: np.ndarray or None, position of the colliding link
    """
    cfg = DEFAULT_WALL_CONFIG.copy() if wall_config is None else wall_config

    plane_y = cfg["wall_y"]
    x_lo, x_hi = cfg["wall_min_x"], cfg["wall_max_x"]
    z_lo, z_hi = cfg["wall_min_z"], cfg["wall_max_z"]
    opening = cfg.get("opening", None)

    def blocked_by_wall(p):
        """True when p's (x, z) falls on the solid part of the wall."""
        inside_wall = x_lo <= p[0] <= x_hi and z_lo <= p[2] <= z_hi
        if not inside_wall:
            return False
        if opening is None:
            return True
        in_opening = (opening["min_x"] <= p[0] <= opening["max_x"]
                      and opening["min_z"] <= p[2] <= opening["max_z"])
        return not in_opening

    for link_name, p in get_robot_link_positions(task_env):
        if not blocked_by_wall(p):
            continue
        if check_crossing:
            hit = p[1] < plane_y
        else:
            hit = abs(p[1] - plane_y) < 0.01
        if hit:
            return True, link_name, p

    return False, None, None


def check_trajectory_wall_collision(ee_positions, wall_config=None):
    """
    Find the first end-effector sample that lies behind the solid part of the wall.

    Only the given EE points are tested, not the arm geometry, which makes this a
    cheap pre-filter. A sample collides when:

    - its (x, z) falls inside the inclusive wall bounds,
    - it is outside the optional ``opening``, and
    - its y is strictly less than ``wall_y`` (the cabinet side).

    Args:
        ee_positions: iterable of EE positions, each indexable as (x, y, z)
        wall_config: dict with wall configuration (DEFAULT_WALL_CONFIG if None)

    Returns:
        collision: bool, True if some sample is behind the wall
        collision_idx: int or None, index of the first such sample
        collision_pos: np.ndarray or None, that sample converted with np.array
    """
    cfg = DEFAULT_WALL_CONFIG.copy() if wall_config is None else wall_config

    plane_y = cfg["wall_y"]
    x_lo, x_hi = cfg["wall_min_x"], cfg["wall_max_x"]
    z_lo, z_hi = cfg["wall_min_z"], cfg["wall_max_z"]
    hole = cfg.get("opening", None)

    def on_solid_wall(p):
        # Outside the wall's x/z extent: nothing to hit.
        if not (x_lo <= p[0] <= x_hi and z_lo <= p[2] <= z_hi):
            return False
        # Inside the extent but passing through the opening: allowed.
        if hole is not None and (hole["min_x"] <= p[0] <= hole["max_x"]
                                 and hole["min_z"] <= p[2] <= hole["max_z"]):
            return False
        return True

    hits = (
        (idx, p)
        for idx, p in enumerate(map(np.array, ee_positions))
        if on_solid_wall(p) and p[1] < plane_y
    )
    first_hit = next(hits, None)
    if first_hit is None:
        return False, None, None
    return True, first_hit[0], first_hit[1]


# ============================================================================
# Video Overlay Functions
# ============================================================================

def project_wall_to_image(wall_config, camera, image_size):
    """
    Project the four wall corners to 2D pixel coordinates for overlay.

    Uses a pinhole model built from the camera pose and its perspective angle.
    As in PyRep's ``VisionSensor.get_intrinsic_matrix``, the perspective angle
    applies to the larger image dimension, and the same focal length is then
    used for both axes. The sign convention ``u = cx - f*x/z`` matches PyRep's
    negative focal lengths.

    Args:
        wall_config: dict with wall configuration
        camera: camera object providing ``get_position()``, ``get_matrix()``
            and (optionally) ``get_perspective_angle()`` in degrees; 60 degrees
            is assumed when the angle cannot be read
        image_size: (height, width) of the image

    Returns:
        corners_2d: list of 4 entries, each an (u, v) int tuple (rounded to
        the nearest pixel) or None when that corner is behind, or within
        0.01 of, the camera plane
    """
    # Get wall corners in 3D (4 corners of the rectangular wall)
    wall_y = wall_config["wall_y"]
    corners_3d = [
        [wall_config["wall_min_x"], wall_y, wall_config["wall_min_z"]],
        [wall_config["wall_max_x"], wall_y, wall_config["wall_min_z"]],
        [wall_config["wall_max_x"], wall_y, wall_config["wall_max_z"]],
        [wall_config["wall_min_x"], wall_y, wall_config["wall_max_z"]],
    ]

    cam_pos = np.asarray(camera.get_position(), dtype=float)
    # Recent PyRep returns a 4x4 homogeneous matrix; older versions return the
    # flat 12-value (3x4, row-major) CoppeliaSim matrix.
    cam_matrix = np.asarray(camera.get_matrix(), dtype=float)
    if cam_matrix.ndim == 1:
        cam_matrix = cam_matrix.reshape(-1, 4)
    rotation = cam_matrix[:3, :3]

    try:
        fov_deg = camera.get_perspective_angle()
        fov = fov_deg * np.pi / 180.0
    except Exception:
        fov = 60.0 * np.pi / 180.0

    height, width = image_size[0], image_size[1]
    # The perspective angle spans the larger side; using the height alone
    # mis-scales landscape (width > height) images.
    f = max(height, width) / (2.0 * np.tan(fov / 2.0))
    cx, cy = width / 2.0, height / 2.0

    corners_2d = []
    for p in corners_3d:
        # World -> camera frame: rotate the offset by the inverse (transpose) rotation.
        p_cam = rotation.T @ (np.asarray(p, dtype=float) - cam_pos)
        x_cam, y_cam, z_cam = p_cam[0], p_cam[1], p_cam[2]

        if z_cam > 0.01:
            u = cx - f * x_cam / z_cam
            v = cy - f * y_cam / z_cam
            # Round rather than truncate: int() truncates toward zero, which is
            # biased and turns float fuzz like 49.999... into 49.
            corners_2d.append((int(round(u)), int(round(v))))
        else:
            corners_2d.append(None)

    return corners_2d


def overlay_wall_on_frame(frame, wall_config, camera, collision=False):
    """
    Draw the projected wall onto a copy of a single video frame.

    The wall quad is filled with a semi-transparent colour and outlined. The
    colour and opacity depend on ``collision``.

    Args:
        frame: numpy array (H, W, 3) image
        wall_config: dict with wall configuration
        camera: camera object for projection
        collision: bool, if True, draw with the "collision" colours

    Returns:
        The input ``frame`` object itself when any wall corner cannot be
        projected; otherwise a new array with the wall blended in.
    """
    import cv2

    corners_2d = project_wall_to_image(wall_config, camera, frame.shape[:2])
    if None in corners_2d:
        return frame  # partially behind the camera: leave the frame untouched

    polygon = np.array(corners_2d, dtype=np.int32)

    # (fill colour, fill opacity, outline colour) for each state.
    # NOTE(review): these tuples were described as red/pink, which holds in BGR
    # order (OpenCV's convention); if the frames are RGB, (0, 0, 255) renders
    # blue — confirm the channel order of the incoming frames.
    fill_color, alpha, outline_color = (
        ((0, 0, 255), 0.5, (0, 0, 200)) if collision
        else ((100, 100, 255), 0.3, (50, 50, 150))
    )

    base = frame.copy()
    filled = base.copy()
    cv2.fillPoly(filled, [polygon], fill_color)
    # Blend the filled copy back into `base` (written in place via dst).
    cv2.addWeighted(filled, alpha, base, 1 - alpha, 0, base)
    cv2.polylines(base, [polygon], True, outline_color, 2, cv2.LINE_AA)
    return base


def overlay_wall_on_frames(frames, wall_config, camera, collision_frame=None):
    """
    Apply ``overlay_wall_on_frame`` to every frame of a video.

    Frames with index >= ``collision_frame`` are drawn in the collision state;
    if ``collision_frame`` is None no frame is.

    NOTE: the wall is re-projected for each frame using the camera pose at the
    time this function runs, not the pose when each frame was captured.

    Args:
        frames: list of numpy array images
        wall_config: dict with wall configuration
        camera: camera object for projection
        collision_frame: int or None, frame index where collision occurred

    Returns:
        list of frames with the wall overlay
    """
    def _in_collision(index):
        return collision_frame is not None and index >= collision_frame

    return [
        overlay_wall_on_frame(img, wall_config, camera, _in_collision(idx))
        for idx, img in enumerate(frames)
    ]


# ============================================================================
# Trajectory Filtering
# ============================================================================

def count_valid_trajectories_with_wall(canonical_params, start_pos, target_pos,
                                        control_point_radius, wall_config=None,
                                        num_samples=50):
    """
    Count how many control point modes produce trajectories that don't cross the wall.

    For each ``(angle, dist_frac, pos_frac)`` entry in ``canonical_params``:

    - a control point is computed with
      ``push_utils.compute_control_point_from_params``;
    - the path is sampled with ``push_utils.parabola3D`` at ``num_samples``
      evenly spaced t in [0, 1], both endpoints included;
    - the samples are checked with ``check_trajectory_wall_collision``
      (EE positions only, not the full arm).

    Args:
        canonical_params: iterable of (angle, dist_frac, pos_frac) tuples
        start_pos: start position (3,)
        target_pos: target position (3,)
        control_point_radius: radius for control point offset
        wall_config: wall configuration (DEFAULT_WALL_CONFIG if None)
        num_samples: number of points to sample along each trajectory; must be
            at least 2 so both endpoints are included

    Returns:
        valid_count: number of valid (non-colliding) trajectories
        valid_indices: list of valid control point indices

    Raises:
        ValueError: if ``num_samples`` < 2.
    """
    # With 1 sample the t-spacing below divides by zero; with 0 nothing would be
    # checked and every mode would be reported valid.
    if num_samples < 2:
        raise ValueError(f"num_samples must be >= 2, got {num_samples}")

    if wall_config is None:
        wall_config = DEFAULT_WALL_CONFIG.copy()

    # Import trajectory generation function
    from push_utils import compute_control_point_from_params, parabola3D

    # The sampling parameters are the same for every mode; compute them once.
    ts = [i / (num_samples - 1) for i in range(num_samples)]

    valid_indices = []
    for cp_idx, (angle, dist_frac, pos_frac) in enumerate(canonical_params):
        control_point = compute_control_point_from_params(
            start_pos, target_pos, control_point_radius, angle, dist_frac, pos_frac
        )
        ee_positions = [parabola3D(start_pos, target_pos, control_point, t) for t in ts]

        collision, _, _ = check_trajectory_wall_collision(ee_positions, wall_config)
        if not collision:
            valid_indices.append(cp_idx)

    return len(valid_indices), valid_indices


# ============================================================================
# Integration with Trajectory Generation
# ============================================================================

class WallCollisionTracker:
    """
    Stateful wrapper around ``check_wall_collision`` for step-by-step execution.

    Call :meth:`check_and_update` once per executed step. The first detected
    collision is latched: later calls return True without re-checking until
    :meth:`reset` is called.

    Usage:
        tracker = WallCollisionTracker(task_env, wall_config)

        for step in trajectory:
            # Execute step...
            if tracker.check_and_update():
                print(f"Collision at step {tracker.collision_step}!")
                break

        # Get collision info
        if tracker.has_collision:
            print(f"Collided at step {tracker.collision_step}")
            print(f"Colliding link: {tracker.collision_link}")
    """

    def __init__(self, task_env, wall_config=None):
        self.task_env = task_env
        # Any falsy config (None or an empty dict) falls back to a copy of the defaults.
        self.wall_config = wall_config or DEFAULT_WALL_CONFIG.copy()
        self._clear_state()

    def _clear_state(self):
        """Put all per-trajectory bookkeeping back to its initial values."""
        self.has_collision = False
        self.collision_step = None  # value of current_step when the collision was found
        self.collision_link = None  # link name reported by check_wall_collision
        self.collision_pos = None   # position array of that link
        self.current_step = 0       # number of check_and_update calls that ran a check

    def reset(self):
        """Reset tracker for new trajectory (env and wall config are kept)."""
        self._clear_state()

    def check_and_update(self):
        """
        Run one collision check (unless already latched) and record the result.

        Returns:
            bool: True if a collision has been detected (now or earlier)
        """
        if self.has_collision:
            return True  # latched; current_step is not advanced

        hit, link, pos = check_wall_collision(self.task_env, self.wall_config)
        if hit:
            self.has_collision, self.collision_step = True, self.current_step
            self.collision_link, self.collision_pos = link, pos

        self.current_step += 1
        return hit

    def get_collision_info(self):
        """Return the collision state as a dict (position converted to a list)."""
        pos = self.collision_pos
        return {
            "has_collision": self.has_collision,
            "collision_step": self.collision_step,
            "collision_link": self.collision_link,
            "collision_pos": None if pos is None else pos.tolist(),
        }


# ============================================================================
# Testing/Debug Functions
# ============================================================================

def visualize_wall_in_scene(task_env, wall_config=None, duration_steps=100):
    """
    Debug helper: create the wall, then advance the simulation so it can be seen.

    Args:
        task_env: RLBench task environment
        wall_config: wall configuration (passed straight to ``create_wall``)
        duration_steps: number of ``pyrep.step()`` calls after creation

    Returns:
        The wall Shape returned by ``create_wall``.
    """
    created = create_wall(task_env, wall_config)

    print("Wall created. Stepping simulation for visualization...")
    sim = task_env._scene.pyrep
    for _ in range(duration_steps):
        sim.step()

    return created


def test_wall_collision(task_env, wall_config=None):
    """
    Debug helper: run ``check_wall_collision`` once and print the outcome.

    Despite the ``test_`` prefix this is not a pytest test (it needs a live
    ``task_env``).

    Args:
        task_env: RLBench task environment
        wall_config: wall configuration

    Returns:
        (collision, link_name, link_pos) exactly as reported by
        ``check_wall_collision``.
    """
    hit, link, pos = check_wall_collision(task_env, wall_config)
    message = (f"Collision detected! Link: {link}, Position: {pos}"
               if hit else "No collision detected.")
    print(message)
    return hit, link, pos
