# filename: blocked_zone.py
"""
Blocked zone definitions for grasp task evaluation.

Similar to wall_collision.py for close_drawer, this defines constraints on valid
grasp configurations (approach_angle, grasp_height) for different evaluation styles.

The goal is to test adaptation:
- Style 1: Only one specific angle works (other angles are "blocked")
- Style 2: Only one specific angle works, mirroring Style 1 (180 deg instead of 0 deg)
- Style 3: Novel approach angle not seen in training (requires adaptation)

In the physical interpretation:
- These constraints simulate objects with restricted grasp zones
- E.g., a cup in a corner where only certain approach angles are feasible
- E.g., a cup with a band around it marking the valid grasp height
"""

import numpy as np


# ============================================================================
# Evaluation Style Definitions
# ============================================================================

# Every style below is a plain dict with the same keys:
#   name, description     -- identifiers used in logs / reports
#   valid_angles_deg      -- admissible approach angles, in degrees
#   valid_heights         -- admissible grasp heights, in meters
#   angle_tolerance_deg   -- +/- band (deg) accepted around each valid angle
#   height_tolerance      -- +/- band (m) accepted around each valid height

# Style 1: cup pushed against a wall, so only the front (0 deg) approach works.
STYLE_1 = dict(
    name="single_angle_0",
    description="Only 0 degree approach works (cup against wall)",
    valid_angles_deg=[0],          # the one admissible approach direction
    valid_heights=[0.12],          # single rim height
    angle_tolerance_deg=15.0,      # +/- 15 deg around the valid angle
    height_tolerance=0.02,         # +/- 20 mm around the valid height
)

# Style 2: the mirror case -- cup against the opposite wall, only 180 deg works.
STYLE_2 = dict(
    name="single_angle_180",
    description="Only 180 degree approach works (opposite wall)",
    valid_angles_deg=[180],
    valid_heights=[0.12],
    angle_tolerance_deg=15.0,
    height_tolerance=0.02,
)

# Style 3: adaptation test. Training covers angles [0, 90, 180, 270]; this
# style demands 45 deg, which never appears there, and uses a tighter
# angular tolerance than styles 1 and 2.
STYLE_3 = dict(
    name="novel_angle",
    description="Novel 45 degree approach - not in training (adaptation test)",
    valid_angles_deg=[45],         # halfway between the trained 0 and 90
    valid_heights=[0.12],
    angle_tolerance_deg=10.0,      # must land within 10 deg of 45
    height_tolerance=0.02,
)

# Lookup table: style number -> style dict.
BLOCKED_ZONE_STYLES = {1: STYLE_1, 2: STYLE_2, 3: STYLE_3}


# ============================================================================
# Validation Functions
# ============================================================================

def check_grasp_valid(approach_angle_rad, grasp_height, style_config):
    """
    Check if a grasp configuration is valid for a given blocked zone style.

    A configuration is valid only if BOTH conditions hold:
      * the approach angle is within ``angle_tolerance_deg`` of some entry of
        ``valid_angles_deg``, measured around the circle (so 355 deg is 5 deg
        away from 0 deg), and
      * the grasp height is within ``height_tolerance`` of some entry of
        ``valid_heights``.

    Args:
        approach_angle_rad: float, approach angle in radians (any real value;
            it is wrapped into [0, 360) degrees before comparison)
        grasp_height: float, grasp height in meters
        style_config: dict with "valid_angles_deg" and "valid_heights";
            optional "angle_tolerance_deg" (default 15.0) and
            "height_tolerance" (default 0.008)

    Returns:
        is_valid: bool, True if configuration is valid
        reason: str, "Valid configuration" or an explanation of why invalid

    Raises:
        KeyError: if "valid_angles_deg" or "valid_heights" is missing.
    """
    approach_angle_deg = np.degrees(approach_angle_rad) % 360

    # Get valid values and tolerances
    valid_angles_deg = style_config["valid_angles_deg"]
    valid_heights = style_config["valid_heights"]
    angle_tol = style_config.get("angle_tolerance_deg", 15.0)
    height_tol = style_config.get("height_tolerance", 0.008)

    # Circular distance (deg) to each valid angle. The raw difference is reduced
    # modulo 360 first so that valid angles listed outside [0, 360) (e.g. -90
    # or 450) compare correctly; otherwise a raw difference above 360 makes
    # ``360 - diff`` negative, which always passes the tolerance check.
    raw_angle_diffs = [abs(approach_angle_deg - valid_angle) % 360
                       for valid_angle in valid_angles_deg]
    min_angle_diff = min((min(d, 360 - d) for d in raw_angle_diffs),
                         default=float('inf'))
    min_height_diff = min((abs(grasp_height - valid_height)
                           for valid_height in valid_heights),
                          default=float('inf'))

    # Empty valid lists leave the minimum at +inf, i.e. nothing is accepted.
    angle_valid = min_angle_diff <= angle_tol
    height_valid = min_height_diff <= height_tol

    # Both must be valid
    if angle_valid and height_valid:
        return True, "Valid configuration"
    elif not angle_valid and not height_valid:
        return False, f"Invalid angle ({approach_angle_deg:.1f} deg, diff={min_angle_diff:.1f}) and height ({grasp_height:.4f}m, diff={min_height_diff:.4f}m)"
    elif not angle_valid:
        return False, f"Invalid angle: {approach_angle_deg:.1f} deg (min diff={min_angle_diff:.1f} deg, tolerance={angle_tol} deg)"
    else:
        return False, f"Invalid height: {grasp_height:.4f}m (min diff={min_height_diff:.4f}m, tolerance={height_tol}m)"


def get_valid_modes_for_style(canonical_params, style_config):
    """
    Filter training modes down to the ones a blocked zone style accepts.

    Args:
        canonical_params: iterable of (angle_rad, height_m) pairs from
            training, e.g. an np.ndarray whose rows are such pairs
        style_config: dict, style configuration forwarded to check_grasp_valid

    Returns:
        valid_indices: list of positions in canonical_params that are valid
        valid_params: list of the matching (angle, height) tuples, in the same
            order as valid_indices
    """
    accepted = [
        (mode_idx, (angle, height))
        for mode_idx, (angle, height) in enumerate(canonical_params)
        if check_grasp_valid(angle, height, style_config)[0]
    ]
    valid_indices = [mode_idx for mode_idx, _ in accepted]
    valid_params = [pair for _, pair in accepted]
    return valid_indices, valid_params


def get_target_config_for_style(style_config):
    """
    Get the target (ideal) configuration for a blocked zone style.

    The target is the FIRST entry of ``valid_angles_deg`` paired with the
    FIRST entry of ``valid_heights``. No averaging or range-centering is
    performed. For the built-in styles, which each list exactly one angle and
    one height, this is the style's unique valid configuration.

    Args:
        style_config: dict, style configuration

    Returns:
        target_angle_rad: float, target approach angle in radians
        target_height: float, target grasp height in meters

    Raises:
        KeyError: if "valid_angles_deg" or "valid_heights" is missing.
        IndexError: if either list is empty.
    """
    # Use first valid angle and height as target
    target_angle_deg = style_config["valid_angles_deg"][0]
    target_height = style_config["valid_heights"][0]

    return np.radians(target_angle_deg), target_height


# ============================================================================
# Demo Generation Helpers
# ============================================================================

def sample_valid_config_for_style(style_config, rng=None):
    """
    Draw a random grasp (angle, height) that satisfies a blocked zone style.

    A base angle and a base height are picked uniformly from the style's valid
    lists. Each is then jittered uniformly within HALF of its tolerance, so the
    sample stays safely inside the accepted band.

    Args:
        style_config: dict, style configuration ("valid_angles_deg",
            "valid_heights", optional "angle_tolerance_deg" / "height_tolerance")
        rng: numpy random generator (optional); a fresh
            ``np.random.default_rng()`` is created when omitted

    Returns:
        approach_angle_rad: float, sampled approach angle in radians
        grasp_height: float, sampled grasp height in meters
    """
    generator = np.random.default_rng() if rng is None else rng

    # Read every config entry up front so a malformed config fails before
    # any random numbers are consumed.
    angles_deg = style_config["valid_angles_deg"]
    heights_m = style_config["valid_heights"]
    half_angle_band = 0.5 * style_config.get("angle_tolerance_deg", 15.0)
    half_height_band = 0.5 * style_config.get("height_tolerance", 0.008)

    # The draw order is fixed: angle choice, angle jitter, height choice,
    # height jitter. This keeps seeded generators reproducible.
    angle_deg = generator.choice(angles_deg) + generator.uniform(-half_angle_band, half_angle_band)
    height = generator.choice(heights_m) + generator.uniform(-half_height_band, half_height_band)

    return np.radians(angle_deg), height


# ============================================================================
# Visualization Helpers
# ============================================================================

def get_valid_region_description(style_config):
    """
    Get a human-readable description of the valid grasp region.

    Heights are reported in millimetres with one decimal place, in both the
    single-height and multi-height forms.

    Args:
        style_config: dict, style configuration ("valid_angles_deg",
            "valid_heights", optional "angle_tolerance_deg" (default 15.0)
            and "height_tolerance" (default 0.008))

    Returns:
        str: Description of valid region, e.g.
            "Valid angles: 0 deg (+/- 15.0), Valid heights: 120.0mm (+/- 20.0mm)"
    """
    angles = style_config["valid_angles_deg"]
    heights = style_config["valid_heights"]
    angle_tol = style_config.get("angle_tolerance_deg", 15.0)
    height_tol = style_config.get("height_tolerance", 0.008)

    if len(angles) == 1:
        angle_str = f"{angles[0]} deg (+/- {angle_tol})"
    else:
        angle_str = f"any of {angles} deg (+/- {angle_tol})"

    # Format every mm value with ``.1f``. Printing a raw ``[h*1000 ...]`` list
    # (as the multi-height branch used to) can expose binary floating-point
    # noise and disagrees with the single-height formatting.
    heights_mm = [f"{h * 1000:.1f}" for h in heights]
    tol_mm = f"{height_tol * 1000:.1f}"
    if len(heights_mm) == 1:
        height_str = f"{heights_mm[0]}mm (+/- {tol_mm}mm)"
    else:
        height_str = f"any of [{', '.join(heights_mm)}]mm (+/- {tol_mm}mm)"

    return f"Valid angles: {angle_str}, Valid heights: {height_str}"


def print_style_info(style_id):
    """
    Print a framed summary (name, description, valid region) of one style.

    Args:
        style_id: key into BLOCKED_ZONE_STYLES (1, 2 or 3 for the built-in styles)

    An unknown id prints a single "Unknown style" line and returns without
    raising.
    """
    style = BLOCKED_ZONE_STYLES.get(style_id)
    if style is not None:
        rule = "=" * 60
        print("\n" + rule)
        print(f"Blocked Zone Style {style_id}: {style['name']}")
        print(rule)
        print(f"Description: {style['description']}")
        print(f"Valid region: {get_valid_region_description(style)}")
        print(rule)
    else:
        print(f"Unknown style: {style_id}")


# ============================================================================
# Testing
# ============================================================================

if __name__ == "__main__":
    # Print info about every registered style
    for style_id in BLOCKED_ZONE_STYLES:
        print_style_info(style_id)

    # Smoke-test validation. Heights are derived from each style's own config
    # rather than hard-coded, so the "expected" labels stay correct when
    # valid heights or tolerances change.
    print("\n\nTesting validation:")

    style1 = BLOCKED_ZONE_STYLES[1]   # only ~0 deg works
    h1 = style1["valid_heights"][0]
    tol1 = style1.get("height_tolerance", 0.008)

    style3 = BLOCKED_ZONE_STYLES[3]   # novel ~45 deg config
    h3 = style3["valid_heights"][0]
    tol3 = style3.get("height_tolerance", 0.008)

    # Each case: (approach angle in rad, grasp height in m, expected validity)
    suites = [
        ("Style 1", style1, [
            (np.radians(0), h1, True),                 # Valid
            (np.radians(10), h1, True),                # Valid (angle within tolerance)
            (np.radians(90), h1, False),               # Invalid angle
            (np.radians(0), h1 + 0.5 * tol1, True),    # Valid (height within tolerance)
        ]),
        ("Style 3", style3, [
            (np.radians(45), h3, True),                # Valid (target)
            (np.radians(40), h3, True),                # Valid (angle within tolerance)
            (np.radians(0), h3, False),                # Invalid angle
            (np.radians(45), h3 + 2.0 * tol3, False),  # Invalid height
        ]),
    ]

    for label, style, cases in suites:
        print(f"\n{label} tests:")
        for angle, height, expected in cases:
            valid, reason = check_grasp_valid(angle, height, style)
            flag = "" if valid == expected else "  <-- UNEXPECTED"
            print(f"  angle={np.degrees(angle):.1f}, height={height:.4f}: {valid} - {reason}{flag}")
