# filename: wall_collision.py
"""
Wall collision detection and visualization for pick-and-place task.

The wall is a 2D plane positioned between the robot and object/target positions.
Trajectories that cause any part of the robot arm to cross this plane are rejected.

Key Features:
  - Create visual wall (thin semi-transparent cuboid) for rendering
  - Geometric plane collision check on sampled arm points (joint origins, tip, gripper joints)
  - Track collision state during trajectory execution
  - Overlay wall on video frames

For pick-and-place (stack_blocks):
  - Robot HOME position: ~[0.278, -0.008, 1.472]
  - Object position: ~[0.250, 0.100, 0.775]
  - Target position: ~[-0.050, 0.000, 0.775]
  - The wall is always a constant-Y plane (y = wall_y); choose wall_y and the X/Z bounds to select which motion phase is filtered

Usage:
  wall = create_wall(task_env, wall_config)

  for step in trajectory:
      # Execute step...
      collision, link_name, link_pos = check_wall_collision(task_env, wall_config)
      if collision:
          print(f"Wall collision at link {link_name}!")
          break  # Stop trajectory
"""

import numpy as np
from pyrep.objects.shape import Shape
from pyrep.const import PrimitiveShape


# ============================================================================
# Default Wall Configuration for Pick-and-Place
# ============================================================================

# Key positions from stack_blocks_init.npz:
#   HOME position:       Y ≈ -0.008 (near Y=0)
#   Object position:     Y ≈ -0.048, X ≈ 0.50, Z ≈ 0.775
#   Target position:     Y ≈ +0.043, X ≈ 0.24
#   Pregrasp position:   Y ≈ -0.048, X ≈ 0.50, Z ≈ 0.825
#   Prerelease position: Y ≈ +0.043, X ≈ 0.24, Z ≈ 0.97
#
# REACH phase: HOME (Y=-0.008) → Pregrasp (Y=-0.048)
#   - Wall can block at Y=-0.02 to Y=-0.03 (between HOME and pregrasp)
#
# CARRY phase: Lift (Y=-0.048) → Prerelease (Y=+0.043)
#   - Wall can block at Y=0.0 (between lift and prerelease)
#   - Trajectory must cross from negative Y to positive Y

DEFAULT_WALL_CONFIG = {
    # --- Plane placement ----------------------------------------------------
    # The wall lies in the plane y = wall_y. Y=0.0 sits between the lift pose
    # (Y<0) and the prerelease pose (Y>0), so it targets CARRY motions that
    # pass through the middle.
    "wall_y": 0.0,

    # --- Extent along X (left-right from the robot's perspective) -----------
    # Covers the object (X~0.50) and the target (X~0.24).
    "wall_min_x": 0.20,
    "wall_max_x": 0.55,

    # --- Vertical extent along Z --------------------------------------------
    # Paths start around Z~0.825 (pregrasp) and may arc up to Z~1.3.
    "wall_min_z": 0.75,
    "wall_max_z": 1.15,

    # --- Rendering-only properties ------------------------------------------
    "wall_thickness": 0.002,  # 2mm for visibility
    "wall_color": [1.0, 0.2, 0.2],  # Red-ish
    "wall_transparency": 0.6,  # Semi-transparent

    # --- Optional opening ---------------------------------------------------
    # None -> solid wall; otherwise a dict with absolute min_x/max_x/min_z/max_z
    # bounds through which trajectories are allowed to pass.
    "opening": None,
}

# Example opening (absolute bounds) that lets trajectories through to the target.
DEFAULT_OPENING_CONFIG = dict(min_x=0.30, max_x=0.45, min_z=1.0, max_z=1.15)

# ============================================================================
# Predefined Wall Styles for Pick-and-Place
# ============================================================================
# Every style is a shallow copy of DEFAULT_WALL_CONFIG with its geometry
# overridden. Re-assigning an existing key inside a dict display keeps that
# key's original position, so all styles share the default's key order.

# Style 1: wall at Y=0 in the CARRY path (lift -> prerelease).
# Only paths that arc high (Z > 1.15) or swing wide (X < 0.20 or X > 0.55)
# get past it.
WALL_STYLE_1 = {
    **DEFAULT_WALL_CONFIG,
    "wall_y": 0.0,
    "wall_min_x": 0.20,
    "wall_max_x": 0.55,
    "wall_min_z": 0.75,
    "wall_max_z": 1.15,
    "opening": None,
}

# Style 2: wall at Y=-0.02 in the REACH path (HOME -> pregrasp); paths have
# to arc above or go around it.
WALL_STYLE_2 = {
    **DEFAULT_WALL_CONFIG,
    "wall_y": -0.02,
    "wall_min_x": 0.30,
    "wall_max_x": 0.55,
    "wall_min_z": 0.80,
    "wall_max_z": 1.10,
    "opening": None,
}

# Style 3: wall at Y=0 with a window near the top; CARRY paths that arc high
# enough can pass through the window.
WALL_STYLE_3 = {
    **DEFAULT_WALL_CONFIG,
    "wall_y": 0.0,
    "wall_min_x": 0.15,
    "wall_max_x": 0.55,
    "wall_min_z": 0.75,
    "wall_max_z": 1.25,
    "opening": {
        "min_x": 0.30,
        "max_x": 0.45,
        "min_z": 1.05,
        "max_z": 1.20,
    },
}

# Style number -> configuration lookup.
WALL_STYLES = {1: WALL_STYLE_1, 2: WALL_STYLE_2, 3: WALL_STYLE_3}


# ============================================================================
# Wall Creation and Management
# ============================================================================

def create_wall(task_env, wall_config=None, name="trajectory_wall", force_recreate=True):
    """
    Create a visual wall in the scene (or reuse an existing one).

    The wall is a thin cuboid created with ``respondable=False`` and
    ``static=True``, so it shows up in rendered frames without taking part in
    contact dynamics. Only the outer X-Z rectangle is rendered; an
    ``opening`` in the config is not cut out of the visual shape.

    Args:
        task_env: RLBench task environment (its ``_scene.pyrep`` is stepped a
            few times after creation so the new object is registered)
        wall_config: dict with wall configuration (uses DEFAULT_WALL_CONFIG if None)
        name: str, name for the wall shape object
        force_recreate: if True, remove any existing shape called ``name`` and
            build a new one (guarantees the size matches the config); if
            False, reuse the existing shape via ``update_wall_properties``

    Returns:
        wall: Shape object representing the wall

    Raises:
        Errors raised while removing or updating an *existing* wall are
        propagated rather than being mistaken for "no wall found".
    """
    if wall_config is None:
        wall_config = DEFAULT_WALL_CONFIG.copy()

    # Only the lookup is guarded: failing here means "no such wall". Errors
    # from updating/removing a wall that WAS found must surface; swallowing
    # them would create a second shape with the same name.
    try:
        existing_wall = Shape(name)
    except Exception:
        existing_wall = None

    if existing_wall is not None:
        if not force_recreate:
            # Wall exists, just update its position/properties
            update_wall_properties(existing_wall, wall_config)
            return existing_wall
        # Remove existing wall to recreate with new config
        existing_wall.remove()
        print(f"  Removed existing wall '{name}' for recreation")

    # Cuboid size is (x=width, y=thickness, z=height)
    wall_width = wall_config["wall_max_x"] - wall_config["wall_min_x"]
    wall_height = wall_config["wall_max_z"] - wall_config["wall_min_z"]
    wall_thickness = wall_config.get("wall_thickness", 0.002)

    wall = Shape.create(
        type=PrimitiveShape.CUBOID,
        size=[wall_width, wall_thickness, wall_height],
        mass=0.0,
        respondable=False,  # Don't interact with physics
        static=True,
        renderable=True,
        color=wall_config["wall_color"],
    )
    wall.set_name(name)

    # Centre the cuboid on the configured rectangle, in the plane y = wall_y
    wall_center_x = (wall_config["wall_min_x"] + wall_config["wall_max_x"]) / 2
    wall_center_z = (wall_config["wall_min_z"] + wall_config["wall_max_z"]) / 2
    wall.set_position([wall_center_x, wall_config["wall_y"], wall_center_z])

    # set_transparency may be unavailable on some PyRep builds; AttributeError
    # is ignored so the wall is still created (opaque).
    try:
        wall.set_transparency(wall_config["wall_transparency"])
    except AttributeError:
        pass

    # Step simulation to register the object
    for _ in range(5):
        task_env._scene.pyrep.step()

    print(f"  Wall created: y={wall_config['wall_y']:.2f}, "
          f"x=[{wall_config['wall_min_x']:.2f}, {wall_config['wall_max_x']:.2f}], "
          f"z=[{wall_config['wall_min_z']:.2f}, {wall_config['wall_max_z']:.2f}]")

    return wall


def _resize_wall(wall, target_size):
    """Best-effort resize of ``wall`` to ``target_size`` = [x, y, z]; return True on success."""
    # First choice: an explicit bounding-box setter, if the object has one.
    try:
        wall.set_bounding_box(target_size)
        return True
    except Exception:
        pass

    # Fallback: rescale relative to the current local bounding box, which PyRep
    # reports as [min_x, max_x, min_y, max_y, min_z, max_z].
    try:
        bbox = wall.get_bounding_box()
        current = [bbox[1] - bbox[0], bbox[3] - bbox[2], bbox[5] - bbox[4]]
        if min(current) <= 0:
            return False  # Degenerate box: scale factors are undefined
        factors = [want / have for want, have in zip(target_size, current)]
        if any(abs(k - 1.0) > 1e-9 for k in factors):
            wall.scale_object(*factors)
        return True
    except Exception:
        return False


def update_wall_properties(wall, wall_config):
    """
    Update an existing wall shape in place so it matches ``wall_config``.

    Resizes the wall to (width, thickness, height), moves it to the centre of
    the configured X-Z rectangle at ``wall_y`` and, when the keys are present,
    re-applies ``wall_color`` and ``wall_transparency``.

    Resizing is best-effort. ``set_bounding_box`` is tried first, then a
    relative rescale via ``get_bounding_box``/``scale_object``. If neither
    works, a warning is printed instead of failing silently; recreate the wall
    with ``create_wall(..., force_recreate=True)`` to guarantee the size.

    Args:
        wall: wall shape object (must provide ``set_position``).
        wall_config: dict with wall configuration.
    """
    wall_width = wall_config["wall_max_x"] - wall_config["wall_min_x"]
    wall_height = wall_config["wall_max_z"] - wall_config["wall_min_z"]
    wall_thickness = wall_config.get("wall_thickness", 0.002)

    # Update wall size (width, thickness, height)
    if not _resize_wall(wall, [wall_width, wall_thickness, wall_height]):
        print("  Warning: could not resize wall; its size may not match the config "
              "(use force_recreate=True to rebuild it)")

    # Update wall position
    wall_center_x = (wall_config["wall_min_x"] + wall_config["wall_max_x"]) / 2
    wall_center_z = (wall_config["wall_min_z"] + wall_config["wall_max_z"]) / 2
    wall.set_position([wall_center_x, wall_config["wall_y"], wall_center_z])

    # Re-apply appearance; these setters are optional on the wall object.
    if "wall_color" in wall_config:
        try:
            wall.set_color(wall_config["wall_color"])
        except AttributeError:
            pass
    if "wall_transparency" in wall_config:
        try:
            wall.set_transparency(wall_config["wall_transparency"])
        except AttributeError:
            pass

    print(f"  Wall updated: y={wall_config['wall_y']:.2f}, "
          f"x=[{wall_config['wall_min_x']:.2f}, {wall_config['wall_max_x']:.2f}], "
          f"z=[{wall_config['wall_min_z']:.2f}, {wall_config['wall_max_z']:.2f}]")


def remove_wall(name="trajectory_wall"):
    """
    Delete the scene shape called ``name``, if there is one.

    Returns:
        bool: True when the shape was found and removed; False when looking
        it up or removing it raised.
    """
    try:
        Shape(name).remove()
    except Exception:
        return False
    return True


# ============================================================================
# Collision Detection
# ============================================================================

def get_robot_link_positions(task_env):
    """
    Collect world positions of sample points along the robot.

    The samples are the arm joint origins, the arm tip, and, when they can be
    read, the gripper joint origins. They stand in for the links; full link
    geometry is not considered.

    Args:
        task_env: RLBench task environment (reads ``_scene.robot``)

    Returns:
        list of (name, np.ndarray) tuples ordered as ``joint_0 .. joint_{n-1}``,
        ``tip``, then ``gripper_0 ..``
    """
    robot = task_env._scene.robot
    arm, gripper = robot.arm, robot.gripper

    samples = [
        (f"joint_{idx}", np.array(arm_joint.get_position()))
        for idx, arm_joint in enumerate(arm.joints)
    ]
    samples.append(("tip", np.array(arm.get_tip().get_position())))

    # Gripper joints are optional: points read before any failure are kept.
    try:
        for idx, finger_joint in enumerate(gripper.joints):
            samples.append((f"gripper_{idx}", np.array(finger_joint.get_position())))
    except Exception:
        pass

    return samples


def check_wall_collision(task_env, wall_config=None, check_crossing=True):
    """
    Test the robot's current sample points against the wall.

    A point (see ``get_robot_link_positions``) counts as colliding when its
    (x, z) lies inside the wall rectangle, it is not inside the optional
    opening, and:
      * ``check_crossing=True``: its y is strictly greater than ``wall_y``;
      * ``check_crossing=False``: its y is within 0.01 of ``wall_y``.

    Note: this is a snapshot of where the points are now, not of how they got
    there. With ``check_crossing=True``, any point sitting on the y > wall_y
    side inside the rectangle is reported, whatever path it took.

    Args:
        task_env: RLBench task environment
        wall_config: dict with wall configuration (DEFAULT_WALL_CONFIG if None)
        check_crossing: choose the half-space test (True) or the near-plane
            test (False)

    Returns:
        (collision, link_name, link_pos) for the first offending point in
        ``get_robot_link_positions`` order, else (False, None, None)
    """
    cfg = DEFAULT_WALL_CONFIG.copy() if wall_config is None else wall_config

    plane_y = cfg["wall_y"]
    x_lo, x_hi = cfg["wall_min_x"], cfg["wall_max_x"]
    z_lo, z_hi = cfg["wall_min_z"], cfg["wall_max_z"]
    hole = cfg.get("opening", None)

    for link_name, pos in get_robot_link_positions(task_env):
        x, y, z = pos[0], pos[1], pos[2]

        # Outside the wall rectangle: cannot touch the wall
        if not (x_lo <= x <= x_hi and z_lo <= z <= z_hi):
            continue

        # Inside the opening: allowed through
        if hole is not None and (hole["min_x"] <= x <= hole["max_x"] and
                                 hole["min_z"] <= z <= hole["max_z"]):
            continue

        hit = (y > plane_y) if check_crossing else (abs(y - plane_y) < 0.01)
        if hit:
            return True, link_name, pos

    return False, None, None


def check_trajectory_wall_collision(ee_positions, wall_config=None):
    """
    Check whether an end-effector path passes through the wall rectangle.

    The path is treated as a polyline through ``ee_positions``. A collision is
    a segment that goes from ``y <= wall_y`` to ``y > wall_y``, the same
    direction ``check_ee_trajectory_wall_collision`` uses by default. Its
    intersection with the plane ``y = wall_y``, found by linear
    interpolation, must lie inside the wall's X-Z bounds and outside the
    optional opening.

    The test looks at the crossing point rather than at whether single
    samples lie beyond the plane. So a path that goes over or around the wall
    and then settles behind it (e.g. at a target on the far side) is not
    rejected. A crossing that falls between two samples is still caught.

    This only considers EE positions, not the full arm geometry; it is meant
    as a quick pre-filter.

    Args:
        ee_positions: sequence of (x, y, z) EE positions (Nx3)
        wall_config: dict with wall configuration (DEFAULT_WALL_CONFIG if None)

    Returns:
        collision: bool, True if the path passes through the wall
        collision_idx: int or None, index of the first sample past the wall
            on the offending segment
        collision_pos: np.ndarray or None, interpolated (x, wall_y, z) point
            where that segment meets the wall plane
    """
    if wall_config is None:
        wall_config = DEFAULT_WALL_CONFIG.copy()

    wall_y = wall_config["wall_y"]
    min_x = wall_config["wall_min_x"]
    max_x = wall_config["wall_max_x"]
    min_z = wall_config["wall_min_z"]
    max_z = wall_config["wall_max_z"]
    opening = wall_config.get("opening", None)

    points = [np.asarray(p, dtype=float) for p in ee_positions]

    for i in range(1, len(points)):
        prev_pos, curr_pos = points[i - 1], points[i]

        # Only segments that move through the plane toward y > wall_y matter
        if not (prev_pos[1] <= wall_y < curr_pos[1]):
            continue

        dy = curr_pos[1] - prev_pos[1]
        if dy > 1e-6:
            t = (wall_y - prev_pos[1]) / dy
            cross_x = prev_pos[0] + t * (curr_pos[0] - prev_pos[0])
            cross_z = prev_pos[2] + t * (curr_pos[2] - prev_pos[2])
        else:
            # Tiny Y step: use the endpoint instead of dividing by ~0
            cross_x, cross_z = curr_pos[0], curr_pos[2]

        # Crossed the plane outside the wall (went over/around it)
        if not (min_x <= cross_x <= max_x and min_z <= cross_z <= max_z):
            continue

        # Crossed through the opening
        if opening is not None and (
            opening["min_x"] <= cross_x <= opening["max_x"]
            and opening["min_z"] <= cross_z <= opening["max_z"]
        ):
            continue

        return True, i, np.array([cross_x, wall_y, cross_z], dtype=float)

    return False, None, None


def check_ee_trajectory_wall_collision(ee_positions, wall_config=None, debug=False,
                                       bidirectional=False):
    """
    Check if the EE trajectory crosses THROUGH the wall (EE only, not full arm).

    Each pair of consecutive samples is a straight segment. A collision is a
    segment that crosses the plane ``y = wall_y`` at a point (found by linear
    interpolation) inside the wall's X-Z bounds and outside the optional
    opening. By default only crossings from ``y <= wall_y`` to ``y > wall_y``
    count. With ``bidirectional=True``, crossings from ``y > wall_y`` to
    ``y <= wall_y`` count too. That is needed for walls such as the REACH
    wall between HOME (Y≈-0.008) and pregrasp (Y≈-0.048), which the EE
    crosses toward -Y.

    Args:
        ee_positions: sequence of (x, y, z) EE positions
        wall_config: dict with wall configuration (DEFAULT_WALL_CONFIG if None)
        debug: if True, print details for every relevant plane crossing
        bidirectional: if True, also detect crossings toward -Y

    Returns:
        collision: bool
        collision_idx: int or None, index of the sample just after the
            offending crossing
    """
    if wall_config is None:
        wall_config = DEFAULT_WALL_CONFIG.copy()

    wall_y = wall_config["wall_y"]
    min_x = wall_config["wall_min_x"]
    max_x = wall_config["wall_max_x"]
    min_z = wall_config["wall_min_z"]
    max_z = wall_config["wall_max_z"]
    opening = wall_config.get("opening", None)

    for i in range(1, len(ee_positions)):
        prev_pos = ee_positions[i - 1]
        curr_pos = ee_positions[i]

        # Did this step cross the Y plane in a direction we check?
        forward = prev_pos[1] <= wall_y < curr_pos[1]
        backward = bidirectional and curr_pos[1] <= wall_y < prev_pos[1]
        if not (forward or backward):
            continue

        # Interpolate to find approximate position at wall crossing
        dy = curr_pos[1] - prev_pos[1]
        if abs(dy) > 1e-6:
            t = (wall_y - prev_pos[1]) / dy
            cross_x = prev_pos[0] + t * (curr_pos[0] - prev_pos[0])
            cross_z = prev_pos[2] + t * (curr_pos[2] - prev_pos[2])
        else:
            cross_x = curr_pos[0]
            cross_z = curr_pos[2]

        # Check if crossing point is within wall X-Z bounds
        in_x_bounds = min_x <= cross_x <= max_x
        in_z_bounds = min_z <= cross_z <= max_z

        if debug:
            print(f"    [DEBUG] Y crossing at step {i}: "
                  f"cross_pos=[{cross_x:.3f}, {wall_y:.3f}, {cross_z:.3f}], "
                  f"in_x_bounds={in_x_bounds} (x in [{min_x:.2f},{max_x:.2f}]), "
                  f"in_z_bounds={in_z_bounds}")

        if not (in_x_bounds and in_z_bounds):
            # Crossed Y plane but outside wall bounds - bypassed the wall
            if debug:
                print("    [DEBUG] -> Bypassed wall (outside X-Z bounds)")
            continue

        # Check if crossing point is within opening
        if opening is not None:
            if (opening["min_x"] <= cross_x <= opening["max_x"] and
                    opening["min_z"] <= cross_z <= opening["max_z"]):
                if debug:
                    print("    [DEBUG] -> Passed through opening")
                continue

        # Collision: crossed wall while within bounds
        if debug:
            print("    [DEBUG] -> COLLISION at wall")
        return True, i

    return False, None


# ============================================================================
# Video Overlay Functions
# ============================================================================

def project_wall_to_image(wall_config, camera, image_size):
    """
    Project the four wall corners to 2D image coordinates for overlay.

    Uses a pinhole model built from the camera pose (``get_position`` and the
    rotation part of ``get_matrix``) and its perspective angle in degrees;
    60 degrees is assumed if the angle cannot be read. As in PyRep's
    ``VisionSensor`` intrinsics, the perspective angle spans the LARGER image
    side, so one focal length derived from ``max(height, width)`` applies to
    both axes. For square images this matches the previous behaviour.

    Args:
        wall_config: dict with wall configuration
        camera: camera object exposing ``get_position()``, ``get_matrix()``
            (4x4) and, optionally, ``get_perspective_angle()``
        image_size: (height, width) of the image

    Returns:
        corners_2d: list of 4 entries for the corners (min_x, min_z),
        (max_x, min_z), (max_x, max_z), (min_x, max_z); each is an integer
        (u, v) tuple, or None when that corner is not in front of the camera
        (camera-frame z <= 0.01)
    """
    wall_y = wall_config["wall_y"]
    corners_3d = [
        [wall_config["wall_min_x"], wall_y, wall_config["wall_min_z"]],
        [wall_config["wall_max_x"], wall_y, wall_config["wall_min_z"]],
        [wall_config["wall_max_x"], wall_y, wall_config["wall_max_z"]],
        [wall_config["wall_min_x"], wall_y, wall_config["wall_max_z"]],
    ]

    cam_pos = np.array(camera.get_position())
    # Rotation part of the camera pose; its transpose maps world -> camera axes
    cam_matrix = camera.get_matrix()[:3, :3]

    try:
        fov = camera.get_perspective_angle() * np.pi / 180.0
    except Exception:
        fov = 60.0 * np.pi / 180.0  # Fallback when the angle is unavailable

    height, width = image_size[0], image_size[1]
    f = max(height, width) / (2.0 * np.tan(fov / 2.0))
    cx, cy = width / 2.0, height / 2.0

    corners_2d = []
    for p in corners_3d:
        x_cam, y_cam, z_cam = cam_matrix.T @ (np.array(p) - cam_pos)

        if z_cam > 0.01:  # In front of the camera (small near-plane margin)
            u = cx - f * x_cam / z_cam
            v = cy - f * y_cam / z_cam
            corners_2d.append((int(u), int(v)))
        else:
            corners_2d.append(None)

    return corners_2d


def overlay_wall_on_frame(frame, wall_config, camera, collision=False):
    """
    Draw the projected wall as a translucent, outlined polygon on one frame.

    If any corner projects to None (behind the camera), nothing is drawn and
    the input ``frame`` object itself is returned. Otherwise a new array is
    returned and the input is left unmodified.

    Colour tuples are passed straight to OpenCV, which treats them as BGR
    ((0, 0, 255) is red there). NOTE(review): if the frames are RGB, these
    colours will appear blue; confirm the frame channel order.

    Args:
        frame: numpy array (H, W, 3) image
        wall_config: dict with wall configuration
        camera: camera object for projection
        collision: bool, if True use the stronger "collision" styling

    Returns:
        frame_overlay: frame with wall drawn
    """
    import cv2

    corners = project_wall_to_image(wall_config, camera, frame.shape[:2])
    if any(corner is None for corner in corners):
        return frame

    if collision:
        fill_color, alpha, edge_color = (0, 0, 255), 0.5, (0, 0, 200)
    else:
        fill_color, alpha, edge_color = (100, 100, 255), 0.3, (50, 50, 150)

    polygon = np.array(corners, dtype=np.int32)

    result = frame.copy()
    tinted = result.copy()
    cv2.fillPoly(tinted, [polygon], fill_color)
    # Blend the filled copy back into the result for translucency
    cv2.addWeighted(tinted, alpha, result, 1 - alpha, 0, result)
    cv2.polylines(result, [polygon], True, edge_color, 2, cv2.LINE_AA)

    return result


def overlay_wall_on_frames(frames, wall_config, camera, collision_frame=None):
    """
    Apply ``overlay_wall_on_frame`` to every frame of a clip.

    Frames whose index is >= ``collision_frame`` are drawn in the collision
    style; when ``collision_frame`` is None no frame is.

    Args:
        frames: list of numpy array images
        wall_config: dict with wall configuration
        camera: camera object for projection
        collision_frame: int or None, frame index where collision occurred

    Returns:
        frames_with_wall: list of frames with wall overlay
    """
    return [
        overlay_wall_on_frame(
            img,
            wall_config,
            camera,
            collision_frame is not None and idx >= collision_frame,
        )
        for idx, img in enumerate(frames)
    ]


# ============================================================================
# Trajectory Filtering
# ============================================================================

def count_valid_trajectories_with_wall(canonical_params, start_pos, target_pos,
                                        control_point_radius, wall_config=None,
                                        num_samples=50):
    """
    Count how many control point modes produce trajectories that don't cross the wall.

    For each (angle, dist_frac, pos_frac) entry, a control point is computed
    and the parabola from ``start_pos`` to ``target_pos`` is sampled at
    ``num_samples`` evenly spaced t in [0, 1]. The samples are then tested
    with ``check_trajectory_wall_collision``.

    Args:
        canonical_params: iterable of (angle, dist_frac, pos_frac) control point parameters
        start_pos: start position (3,)
        target_pos: target position (3,)
        control_point_radius: radius for control point offset
        wall_config: wall configuration (DEFAULT_WALL_CONFIG if None)
        num_samples: number of points to sample along each trajectory (>= 2)

    Returns:
        valid_count: number of valid (non-colliding) trajectories
        valid_indices: list of valid control point indices

    Raises:
        ValueError: if ``num_samples < 2``. The sampling step
            ``1 / (num_samples - 1)`` needs at least two samples; with fewer,
            it divides by zero or samples nothing.
    """
    if num_samples < 2:
        raise ValueError(f"num_samples must be >= 2, got {num_samples}")

    if wall_config is None:
        wall_config = DEFAULT_WALL_CONFIG.copy()

    # Project-local trajectory helpers
    from utils import compute_control_point_from_params, parabola3D

    valid_indices = []

    for cp_idx, (angle, dist_frac, pos_frac) in enumerate(canonical_params):
        control_point = compute_control_point_from_params(
            start_pos, target_pos, control_point_radius, angle, dist_frac, pos_frac
        )

        # Evenly spaced samples t = 0, 1/(n-1), ..., 1 along the parabola
        ee_positions = [
            parabola3D(start_pos, target_pos, control_point, i / (num_samples - 1))
            for i in range(num_samples)
        ]

        collision, _, _ = check_trajectory_wall_collision(ee_positions, wall_config)
        if not collision:
            valid_indices.append(cp_idx)

    return len(valid_indices), valid_indices


# ============================================================================
# Integration with Trajectory Generation
# ============================================================================

class WallCollisionTracker:
    """
    Remember the first wall collision seen while stepping through a trajectory.

    Each ``check_and_update`` call runs ``check_wall_collision`` on the current
    robot state, unless a collision has already been recorded, and advances
    the internal step counter.

    Usage:
        tracker = WallCollisionTracker(task_env, wall_config)

        for step in trajectory:
            # Execute step...
            if tracker.check_and_update():
                print(f"Collision at step {tracker.collision_step}!")
                break

        # Get collision info
        if tracker.has_collision:
            print(f"Collided at step {tracker.collision_step}")
            print(f"Colliding link: {tracker.collision_link}")
    """

    def __init__(self, task_env, wall_config=None):
        self.task_env = task_env
        # Any falsy config (None, empty dict) falls back to a copy of the default
        self.wall_config = wall_config or DEFAULT_WALL_CONFIG.copy()
        self._clear()

    def _clear(self):
        """Put every piece of per-trajectory state back to its initial value."""
        self.has_collision = False
        self.collision_step = None
        self.collision_link = None
        self.collision_pos = None
        self.current_step = 0

    def reset(self):
        """Forget any recorded collision and restart step counting at 0."""
        self._clear()

    def check_and_update(self):
        """
        Test the current state and record the first collision.

        Once a collision is recorded, returns True without re-checking or
        advancing the step counter.

        Returns:
            bool: True if a collision has been detected
        """
        if self.has_collision:
            return True

        hit, link, pos = check_wall_collision(self.task_env, self.wall_config)
        if hit:
            self.has_collision, self.collision_step = True, self.current_step
            self.collision_link, self.collision_pos = link, pos

        self.current_step += 1
        return hit

    def get_collision_info(self):
        """Return the recorded collision state as a plain dict (position as a list)."""
        pos = self.collision_pos
        return {
            "has_collision": self.has_collision,
            "collision_step": self.collision_step,
            "collision_link": self.collision_link,
            "collision_pos": None if pos is None else pos.tolist(),
        }


# ============================================================================
# Testing/Debug Functions
# ============================================================================

def visualize_wall_in_scene(task_env, wall_config=None, duration_steps=100):
    """
    Debug helper: build the wall, then let the simulator run for a while.

    Args:
        task_env: RLBench task environment
        wall_config: wall configuration forwarded to ``create_wall``
        duration_steps: how many ``pyrep.step()`` calls to make afterwards

    Returns:
        The wall shape returned by ``create_wall``.
    """
    created = create_wall(task_env, wall_config)

    print("Wall created. Stepping simulation for visualization...")
    for _tick in range(duration_steps):
        task_env._scene.pyrep.step()

    return created


def test_wall_collision(task_env, wall_config=None):
    """
    Run ``check_wall_collision`` once on the current robot state and print the result.

    Args:
        task_env: RLBench task environment
        wall_config: wall configuration

    Returns:
        collision: bool
        link_name: str or None
        link_pos: array or None
    """
    collision, link, pos = check_wall_collision(task_env, wall_config)

    message = (f"Collision detected! Link: {link}, Position: {pos}"
               if collision else "No collision detected.")
    print(message)

    return collision, link, pos
