#!/usr/bin/env python
"""
Wall configuration for pick-and-place task using trajectory-relative coordinates.

The wall is defined perpendicular to the start→end line, positioned at a specific
point along the trajectory (default: midpoint at pos_frac=0.5).

Wall parameters use the SAME coordinate system as control points:
  - angle: 0-360 degrees, where 0° is "up" relative to trajectory
  - distance: fraction of (radius * path_length), same as CP dist_frac
  - The wall plane is perpendicular to the trajectory direction

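Wall definition:
  - position: pos_frac along start→end (0.0=start, 1.0=end, 0.5=middle)
  - corner: (angle, distance) - location of the wall's "upper-left" corner in
    trajectory-relative coords (perp1 is "up", perp2 is "right")
  - width: extent along perp1 direction, measured from the corner towards -perp1
    (in units of radius * path_length)
  - height: extent along perp2 direction, measured from the corner towards +perp2
    (in units of radius * path_length)
  - opening (optional): square hole inside the wall, offset diagonally from the
    corner (see WALL_STYLE_3)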

This allows walls to be defined consistently for both REACH and CARRY phases
using the same parameters, just like control points.
"""

import numpy as np


def build_local_frame(start_pos, target_pos):
    """
    Build an orthonormal frame attached to the start→target segment.

    The first axis is the unit direction of the segment. The second axis is
    world +Z with its component along the segment removed (Gram-Schmidt). When
    the segment is (nearly) vertical, world +Y is used instead. The third axis
    is the cross product of the first two.
    Same as utils.py build_local_frame.

    Returns:
        (line_vec_norm, perp1, perp2): three unit vectors forming the frame.
    """
    direction = target_pos - start_pos
    direction = direction / np.linalg.norm(direction)

    def _orthogonal_component(reference):
        # Strip the part of `reference` that is parallel to `direction`.
        projected = reference - np.dot(reference, direction) * direction
        return projected, np.linalg.norm(projected)

    first, first_len = _orthogonal_component(np.array([0.0, 0.0, 1.0]))
    if first_len < 1e-6:
        # Segment is parallel to world up; fall back to world forward (+Y).
        first, first_len = _orthogonal_component(np.array([0.0, 1.0, 0.0]))
    first = first / first_len

    second = np.cross(direction, first)
    second = second / np.linalg.norm(second)

    return direction, first, second


def trajectory_to_world(start_pos, end_pos, radius, pos_frac, angle_deg, dist_frac):
    """
    Convert trajectory-relative coordinates to world coordinates.

    The returned point lies in the plane perpendicular to start→end at
    ``pos_frac`` along the segment. It is placed at polar coordinates
    (angle_deg, distance) around the straight line, using the perp1/perp2
    frame from build_local_frame.

    Args:
        start_pos: array-like (3,), start position
        end_pos: array-like (3,), end position
        radius: float, same radius used for control points
        pos_frac: float, position along trajectory (0=start, 1=end)
        angle_deg: float, angle in degrees; 0 points along +perp1 ("up"),
            90 along +perp2 ("right")
        dist_frac: float, distance from the start→end line as a fraction of
            (radius * path_length), where path_length = |end_pos - start_pos|

    Returns:
        world_pos: np.ndarray(3,), position in world coordinates
    """
    start_pos = np.asarray(start_pos, dtype=float)
    end_pos = np.asarray(end_pos, dtype=float)
    line_vec = end_pos - start_pos
    path_length = np.linalg.norm(line_vec)
    _, perp1, perp2 = build_local_frame(start_pos, end_pos)

    # Base position on the straight start→end line
    base_pos = start_pos + pos_frac * line_vec

    # Radial offset inside the perpendicular plane, scaled by radius * path_length
    angle_rad = np.deg2rad(angle_deg)
    offset_dist = dist_frac * radius * path_length
    offset = offset_dist * (np.cos(angle_rad) * perp1 + np.sin(angle_rad) * perp2)

    return base_pos + offset


def compute_wall_corners(start_pos, end_pos, radius, wall_config, offset=None):
    """
    Compute the 4 corners of a wall in world coordinates.

    The wall is a rectangle lying in the plane perpendicular to the trajectory
    direction, placed at ``pos_frac`` along start→end. The configured corner is
    the wall's "upper-left" one: perp1 is treated as "up" and perp2 as "right".
    The wall extends ``width`` along -perp1 and ``height`` along +perp2 from
    that corner.

    Args:
        start_pos: array-like (3,), trajectory start
        end_pos: array-like (3,), trajectory end
        radius: float, same radius used for control points
        wall_config: dict with wall parameters:
            - pos_frac: position along trajectory (default 0.5 = middle)
            - corner_angle: angle of corner in degrees (default 0.0)
            - corner_dist: distance of the corner from the trajectory line,
              as a fraction of (radius * path_length) (default 0.0)
            - width: wall extent along -perp1, in units of
              (radius * path_length) (default 1.0)
            - height: wall extent along +perp2, in units of
              (radius * path_length) (default 1.0)
        offset: [perp1_offset, perp2_offset] or None, offset to translate wall position
                in normalized coordinates (units of radius * path_length).
                If None, no offset is applied.

    Returns:
        corners: np.ndarray(4, 3), four corners in world coordinates.
                 Order: upper-left, upper-right, lower-right, lower-left
                 (the order draw_wall_3d expects).
    """
    start_pos = np.asarray(start_pos, dtype=float)
    end_pos = np.asarray(end_pos, dtype=float)
    line_vec = end_pos - start_pos
    path_length = np.linalg.norm(line_vec)
    _, perp1, perp2 = build_local_frame(start_pos, end_pos)

    # Get wall parameters
    pos_frac = wall_config.get("pos_frac", 0.5)
    corner_angle = wall_config.get("corner_angle", 0.0)  # degrees
    corner_dist = wall_config.get("corner_dist", 0.0)
    width = wall_config.get("width", 1.0)   # along -perp1
    height = wall_config.get("height", 1.0)  # along +perp2

    # Optional translation (in normalized perp1/perp2 coordinates)
    offset_perp1 = 0.0
    offset_perp2 = 0.0
    if offset is not None:
        offset_perp1, offset_perp2 = offset

    # Base position on trajectory
    base_pos = start_pos + pos_frac * line_vec

    # Upper-left corner: polar offset from the line plus the translation
    angle_rad = np.deg2rad(corner_angle)
    corner_offset_dist = corner_dist * radius * path_length
    corner_offset = corner_offset_dist * (np.cos(angle_rad) * perp1 + np.sin(angle_rad) * perp2)
    offset_world = (offset_perp1 * radius * path_length) * perp1 + (offset_perp2 * radius * path_length) * perp2
    corner_pos = base_pos + corner_offset + offset_world

    # Wall dimensions in world units
    wall_width = width * radius * path_length
    wall_height = height * radius * path_length

    # Walk the rectangle: right along +perp2, then down along -perp1
    corners = np.array([
        corner_pos,                                             # upper-left
        corner_pos + wall_height * perp2,                       # upper-right
        corner_pos + wall_height * perp2 - wall_width * perp1,  # lower-right
        corner_pos - wall_width * perp1,                        # lower-left
    ])

    return corners


def check_trajectory_wall_collision(ee_positions, start_pos, end_pos, radius, wall_config, offset=None):
    """
    Check if EE trajectory crosses through the wall.

    The wall lies in the plane perpendicular to start→end at ``pos_frac``. Each
    consecutive pair of EE samples is tested for crossing (or landing on) that
    plane. The crossing point is found by linear interpolation. It counts as a
    collision when it falls inside the wall rectangle but outside the optional
    opening. Bounds are inclusive, so a point exactly on a wall edge is inside
    the wall, and a point exactly on an opening edge passes through.

    Args:
        ee_positions: array-like (N, 3), EE trajectory positions
        start_pos: array-like (3,), trajectory start position
        end_pos: array-like (3,), trajectory end position
        radius: same radius used for control points
        wall_config: dict with wall parameters (pos_frac, corner_angle,
            corner_dist, width, height, and optionally "opening" with
            distance_to_corner / length). All lengths are in units of
            radius * path_length.
        offset: [perp1_offset, perp2_offset] or None, offset to translate wall position
                in normalized coordinates. If None, no offset is applied.

    Returns:
        collision: bool, True if collision detected
        collision_idx: int or None, index i of the sample ending the first
            colliding step (ee_positions[i-1] -> ee_positions[i])
    """
    # Accept lists as well as arrays; list - list would raise TypeError below.
    ee_positions = np.atleast_2d(np.asarray(ee_positions, dtype=float))
    start_pos = np.asarray(start_pos, dtype=float)
    end_pos = np.asarray(end_pos, dtype=float)
    if len(ee_positions) < 2:
        # Fewer than two samples cannot cross the plane.
        return False, None

    line_vec = end_pos - start_pos
    path_length = np.linalg.norm(line_vec)
    line_vec_norm, perp1, perp2 = build_local_frame(start_pos, end_pos)
    scale = radius * path_length  # world length of one normalized unit

    # Wall plane position along trajectory
    pos_frac = wall_config.get("pos_frac", 0.5)
    wall_center = start_pos + pos_frac * line_vec

    # Wall corner and dimensions
    corner_angle = wall_config.get("corner_angle", 0.0)
    corner_dist = wall_config.get("corner_dist", 0.0)
    width = wall_config.get("width", 1.0)
    height = wall_config.get("height", 1.0)

    offset_perp1, offset_perp2 = (0.0, 0.0) if offset is None else offset

    # Wall rectangle in (perp1, perp2) coordinates relative to wall_center.
    # The configured corner is upper-left (max perp1, min perp2); the wall
    # extends down (-perp1) by `width` and right (+perp2) by `height`.
    angle_rad = np.deg2rad(corner_angle)
    perp1_max = (corner_dist * np.cos(angle_rad) + offset_perp1) * scale
    perp1_min = perp1_max - width * scale
    perp2_min = (corner_dist * np.sin(angle_rad) + offset_perp2) * scale
    perp2_max = perp2_min + height * scale

    # Optional square opening: its upper-left corner is offset from the wall's
    # by distance_to_corner along both -perp1 and +perp2 (diagonally).
    opening = wall_config.get("opening", None)
    opening_bounds = None
    if opening is not None:
        dist_to_corner = opening.get("distance_to_corner", 0.0) * scale
        opening_length = opening.get("length", 0.0) * scale
        opening_perp1_max = perp1_max - dist_to_corner
        opening_perp2_min = perp2_min + dist_to_corner
        opening_bounds = (
            opening_perp1_max - opening_length,
            opening_perp1_max,
            opening_perp2_min,
            opening_perp2_min + opening_length,
        )

    # Normalized progress of every sample along start→end (0=start, 1=end),
    # computed once instead of twice per step.
    progress = (ee_positions - start_pos) @ line_vec_norm / path_length

    for i in range(1, len(ee_positions)):
        prev_t = progress[i - 1]
        curr_t = progress[i]

        # Crossing (or landing on) the wall plane in either direction
        if not (prev_t < pos_frac <= curr_t or curr_t <= pos_frac < prev_t):
            continue

        prev_pos = ee_positions[i - 1]
        curr_pos = ee_positions[i]
        if abs(curr_t - prev_t) > 1e-6:
            alpha = (pos_frac - prev_t) / (curr_t - prev_t)
            cross_pos = prev_pos + alpha * (curr_pos - prev_pos)
        else:
            # Step barely moves along the trajectory axis; avoid dividing by ~0.
            cross_pos = curr_pos

        # Project crossing point onto the wall plane axes
        rel_pos = cross_pos - wall_center
        cross_perp1 = np.dot(rel_pos, perp1)
        cross_perp2 = np.dot(rel_pos, perp2)

        if not (perp1_min <= cross_perp1 <= perp1_max and perp2_min <= cross_perp2 <= perp2_max):
            continue  # crossed the plane outside the wall

        if opening_bounds is not None:
            o1_min, o1_max, o2_min, o2_max = opening_bounds
            if o1_min <= cross_perp1 <= o1_max and o2_min <= cross_perp2 <= o2_max:
                continue  # Pass through opening, no collision

        return True, i

    return False, None


# ============================================================================
# Predefined Wall Configurations (trajectory-relative)
# ============================================================================
#
# All lengths (corner_dist, width, height, offsets, opening sizes) are in units
# of radius * path_length. The corner given by (corner_angle, corner_dist) is
# the wall's upper-left corner (perp1 = "up", perp2 = "right"); from it the wall
# extends `width` along -perp1 and `height` along +perp2.

# Default wall: small square at the trajectory midpoint. Before any offset it
# spans perp1 in ~[-0.15, 0.35] and perp2 in ~[-0.35, 0.15], so it covers the
# straight start→end line.
DEFAULT_WALL_CONFIG = {
    "pos_frac": 0.5,        # Middle of trajectory
    "corner_angle": 315.0,  # Upper-left corner at 315° (= -45°): +perp1, -perp2
    "corner_dist": 0.5,     # Corner at distance 0.5 * radius * path_length
    "width": 0.5,           # Extends 0.5 * radius * path_length along -perp1
    "height": 0.5,          # Extends 0.5 * radius * path_length along +perp2
    "reach_offset": [-0.2, 0.2],   # [perp1, perp2] offset for REACH phase
    "carry_offset": [-1.4, 0.1],   # [perp1, perp2] offset for CARRY phase
}

# Wall style 1: Large wall blocking most trajectories, letting mode 7 pass
WALL_STYLE_1 = {
    "pos_frac": 0.5,
    "corner_angle": 330.0,
    "corner_dist": 1.4577,
    "width": 2.5,
    "height": 2.5,
    "reach_offset": [-0.2, 0.0],   # [perp1, perp2] offset for REACH phase
    "carry_offset": [-1.3, -0.1],   # [perp1, perp2] offset for CARRY phase (to be tuned)
}

# Wall style 2: Large wall blocking most trajectories, letting mode 4 pass
WALL_STYLE_2 = {
    "pos_frac": 0.5,
    "corner_angle": 300.0,
    "corner_dist": 1.4577,
    "width": 2.5,
    "height": 2.5,
    "reach_offset": [-0.1, 0.0],   # [perp1, perp2] offset for REACH phase
    "carry_offset": [-1.4, -0.1],   # [perp1, perp2] offset for CARRY phase (to be tuned)
}

# Wall style 3: Wall with opening
# The wall is a 2.5 x 2.5 square whose upper-left corner is at
# 1.7675 * (cos 315°, sin 315°) ≈ (+1.25, -1.25), i.e. centred on the trajectory.
# Before offsets, the opening occupies perp1 in [-1.25, -0.25] and perp2 in
# [0.25, 1.25] (the lower-right part of the wall).
WALL_STYLE_3 = {
    "pos_frac": 0.5,
    "corner_angle": 315.0,
    "corner_dist": 1.7675,
    "width": 2.5,
    "height": 2.5,
    # Opening: square hole of size length x length whose upper-left corner is
    # offset from the wall's upper-left corner by distance_to_corner along BOTH
    # -perp1 and +perp2 (i.e. diagonally).
    "opening": {
        "distance_to_corner": 1.5,  # per-axis offset from wall corner (in units of radius * path_length)
        "length": 1.0,              # side of square opening (in units of radius * path_length)
    },
    "reach_offset": [-0.1, 0.0],   # [perp1, perp2] offset for REACH phase
    "carry_offset": [-1.5, 0.0],   # [perp1, perp2] offset for CARRY phase (to be tuned)
}

WALL_STYLES = {
    0: DEFAULT_WALL_CONFIG,  # Default style
    1: WALL_STYLE_1,
    2: WALL_STYLE_2,
    3: WALL_STYLE_3,
}


# ============================================================================
# Visualization helpers
# ============================================================================

def compute_opening_corners(start_pos, end_pos, radius, wall_config, offset=None):
    """
    Locate the square opening of a wall in world coordinates.

    The opening is anchored to the wall's (offset) upper-left corner. Its own
    upper-left corner is shifted from there by ``distance_to_corner`` along
    -perp1 and by the same amount along +perp2. It is a ``length`` x ``length``
    square extending towards -perp1 / +perp2. Both values are in units of
    (radius * path_length).

    Args:
        start_pos: np.ndarray(3,), trajectory start
        end_pos: np.ndarray(3,), trajectory end
        radius: float, same radius used for control points
        wall_config: dict with wall parameters including "opening"
        offset: [perp1_offset, perp2_offset] or None, offset to translate wall position
                in normalized coordinates. If None, no offset is applied.

    Returns:
        np.ndarray(4, 3) with the opening corners ordered upper-left,
        upper-right, lower-right, lower-left (same convention as the wall),
        or None when wall_config has no "opening" entry.
    """
    opening = wall_config.get("opening", None)
    if opening is None:
        return None

    # Row 0 of the wall corners is the wall's upper-left corner, with the polar
    # corner offset and the optional translation already applied.
    wall_corner = compute_wall_corners(start_pos, end_pos, radius, wall_config, offset)[0]

    path_length = np.linalg.norm(end_pos - start_pos)
    _, up, right = build_local_frame(start_pos, end_pos)

    inset = opening.get("distance_to_corner", 0.0) * radius * path_length
    side = opening.get("length", 0.0) * radius * path_length

    top_left = wall_corner - inset * up + inset * right
    top_right = top_left + side * right
    return np.array([
        top_left,
        top_right,
        top_right - side * up,
        top_left - side * up,
    ])


def draw_wall_3d(ax, corners, color='red', alpha=0.3, opening_corners=None):
    """
    Render a wall on a matplotlib 3D axis, optionally leaving a hole in it.

    Both corner arrays use the order upper-left, upper-right, lower-right,
    lower-left. A solid wall is drawn as a single polygon. With an opening, the
    wall is drawn as four quadrilaterals (top, right, bottom, left) framing the
    opening, and the opening itself is drawn as a faint green patch with a
    dark-green outline.

    Args:
        ax: matplotlib 3D axis (must provide add_collection3d)
        corners: np.ndarray(4, 3), four corners of the wall
        color: face color of the wall
        alpha: face transparency of the wall
        opening_corners: np.ndarray(4, 3) or None, corners of the opening to
            leave uncovered
    """
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection

    def _add_face(vertices, face_alpha, face_color, edge_color, line_width):
        # One Poly3DCollection per face, added to the axis immediately.
        ax.add_collection3d(Poly3DCollection(
            [vertices], alpha=face_alpha, facecolor=face_color,
            edgecolor=edge_color, linewidth=line_width))

    if opening_corners is None:
        _add_face(corners.tolist(), alpha, color, 'darkred', 2)
        return

    w, o = corners, opening_corners
    # Quadrilaterals framing the opening, in order: top, right, bottom, left.
    frame_faces = (
        (w[0], w[1], o[1], o[0]),
        (w[1], w[2], o[2], o[1]),
        (o[3], o[2], w[2], w[3]),
        (w[0], o[0], o[3], w[3]),
    )
    for face in frame_faces:
        _add_face([vertex.tolist() for vertex in face], alpha, color, 'darkred', 1)

    # Mark the opening itself with a translucent green outline.
    _add_face(opening_corners.tolist(), 0.1, 'green', 'darkgreen', 2)


def print_wall_info(start_pos, end_pos, radius, wall_config, offset=None):
    """
    Print wall configuration and world coordinates.

    Args:
        start_pos: np.ndarray(3,), trajectory start
        end_pos: np.ndarray(3,), trajectory end
        radius: float, same radius used for control points
        wall_config: dict with wall parameters (optionally including "opening")
        offset: [perp1_offset, perp2_offset] or None. Forwarded to
            compute_wall_corners / compute_opening_corners so the printed
            corners match an offset wall (e.g. a config's "reach_offset").
            When None, the output is the un-offset wall and no offset line
            is printed.
    """
    def _print_corners(title, points):
        # Shared formatter for the wall and opening corner listings.
        print(title)
        for i, corner in enumerate(points):
            print(f"  Corner {i}: [{corner[0]:.4f}, {corner[1]:.4f}, {corner[2]:.4f}]")

    corners = compute_wall_corners(start_pos, end_pos, radius, wall_config, offset)
    path_length = np.linalg.norm(end_pos - start_pos)

    print("Wall Configuration (trajectory-relative):")
    print(f"  Position along trajectory: {wall_config.get('pos_frac', 0.5):.2f}")
    print(f"  Corner: angle={wall_config.get('corner_angle', 0):.1f}°, dist={wall_config.get('corner_dist', 0):.2f}")
    print(f"  Size: width={wall_config.get('width', 1):.2f}, height={wall_config.get('height', 1):.2f}")
    print(f"  Path length: {path_length:.4f}m")
    print(f"  Radius: {radius:.4f}")
    if offset is not None:
        offset_perp1, offset_perp2 = offset
        print(f"  Offset: perp1={offset_perp1:.2f}, perp2={offset_perp2:.2f}")
    _print_corners("\nWall corners (world coordinates):", corners)

    # Print opening info if present
    opening = wall_config.get("opening", None)
    if opening is None:
        return

    opening_corners = compute_opening_corners(start_pos, end_pos, radius, wall_config, offset)
    print("\nOpening Configuration:")
    print(f"  Distance to corner: {opening.get('distance_to_corner', 0):.2f}")
    print(f"  Length: {opening.get('length', 0):.2f}")
    if opening_corners is not None:
        _print_corners("\nOpening corners (world coordinates):", opening_corners)
