"""
Meat Off Grill Initialization Script

This script pre-computes and saves:
1. Fixed positions (home, object, target, waypoints)
2. Feasible control points for the REACH phase (home -> pregrasp)
3. Feasible control points for the CARRY phase (lift -> prerelease)
4. Grasp parameters extracted from an RLBench demo

The saved data can be loaded by other scripts to skip the expensive
control point filtering step (IK/collision pre-filter plus jitter filter).

Output file: meat_off_grill_init.npz (written next to this script)
"""

import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'custom_tasks'))

import numpy as np

from rlbench.environment import Environment
from pyrep.errors import ConfigurationPathError, IKError

from utils import (
    create_obs_config, create_action_mode, get_task_classes,
    generate_canonical_control_point_params,
    compute_control_point_from_params,
    prefilter_control_points_with_collision,
    parabola3D,
    HOME_JOINTS,
)

from pick_and_place_utils import (
    extract_demo_waypoints,
    DEFAULT_GRASP_ORIENTATION,
)

# Destination of the saved initialization data: an .npz archive placed in the
# same directory as this script.
OUTPUT_FILE = os.path.join(
    os.path.dirname(__file__),
    "meat_off_grill_init.npz",
)


def compute_jitter_metric(trace):
    """Score how much a sampled trajectory changes direction.

    For every pair of consecutive displacement vectors the score adds
    ``1 - cos(theta)``, where ``theta`` is the angle between them. A straight
    line therefore scores 0, a right-angle turn adds 1 and a full reversal
    adds 2. Displacements shorter than 1e-8 are divided by 1e-8 instead of
    their own length, so a repeated sample yields a near-zero direction and
    contributes about 1 per neighbouring pair.

    Args:
        trace: Positions, one row per sample (anything ``np.diff`` accepts
            along axis 0).

    Returns:
        The summed score; 0.0 when fewer than three samples are given.
    """
    # Fewer than three samples means fewer than two displacements to compare.
    if len(trace) < 3:
        return 0.0

    steps = np.diff(trace, axis=0)
    # Floor the lengths to avoid dividing by zero on stationary samples.
    lengths = np.maximum(np.linalg.norm(steps, axis=1, keepdims=True), 1e-8)
    unit_steps = steps / lengths

    cosines = np.sum(unit_steps[:-1] * unit_steps[1:], axis=1)
    return np.sum(1 - cosines)


def _print_section(title):
    """Print *title* between two 60-character ``=`` rules, after a blank line."""
    rule = "=" * 60
    print(f"\n{rule}")
    print(title)
    print(rule)


def _fmt_vec3(v):
    """Format the first three components of *v* as ``[x, y, z]`` with 4 decimals."""
    return f"[{v[0]:.4f}, {v[1]:.4f}, {v[2]:.4f}]"


def _offset_position(base, dz, xy_offset=None):
    """Return a copy of *base* shifted by *xy_offset* in x/y (if given) and by *dz* in z."""
    pos = base.copy()
    if xy_offset is not None:
        pos[:2] += xy_offset
    pos[2] += dz
    return pos


def _reset_to_home(task_env, settle_steps):
    """Reset the task deterministically and park the arm at HOME with the gripper open.

    Re-seeds NumPy with 42, resets the task, sets the arm to ``HOME_JOINTS``
    with zero target velocities, opens the gripper and steps the simulation
    *settle_steps* times.

    Returns:
        ``(robot, tip, gripper)`` handles taken from the freshly reset scene.
    """
    np.random.seed(42)
    task_env.reset()

    robot = task_env._scene.robot
    tip = robot.arm.get_tip()
    gripper = robot.gripper

    robot.arm.set_joint_positions(HOME_JOINTS, disable_dynamics=True)
    robot.arm.set_joint_target_velocities([0] * 7)
    gripper.actuate(1.0, 0.2)
    for _ in range(settle_steps):
        task_env._scene.pyrep.step()
    return robot, tip, gripper


def _get_target_object(task):
    """Return ``(object, name)``: ``task._chicken`` if present, otherwise ``task._steak``."""
    try:
        return task._chicken, 'chicken'
    except AttributeError:
        return task._steak, 'steak'


def _report_prefilter(valid_indices, results, total):
    """Print how many control points survived the IK + collision pre-filter."""
    collision_failures = sum(1 for r in results.values() if r['collision_count'] > 0)
    ik_failures = sum(1 for r in results.values() if r['ik_failures'] > 0)
    print(f"  Valid (IK + collision-free): {len(valid_indices)}/{total}")
    print(f"  Rejected due to collision: {collision_failures}")
    print(f"  Rejected due to IK failure: {ik_failures}")


def _trace_parabola(task_env, robot, tip, start, end, cp, num_steps,
                    steps_per_point, orientation):
    """Drive the tip along ``parabola3D(start, end, cp, t)`` and record its positions.

    Each of the *num_steps* waypoints (``t`` evenly spaced over [0, 1]) is
    solved with Jacobian IK seeded from the previous successful solution, then
    the simulation is stepped *steps_per_point* times. When IK fails, the last
    recorded tip position (or *start* if none yet) is repeated, so the result
    always has one row per waypoint; those repeats register as jitter.

    Returns:
        NumPy array of recorded tip positions, one row per waypoint.
    """
    pyrep = task_env._scene.pyrep
    trace = []
    prev_joints = list(robot.arm.get_joint_positions())

    for i in range(num_steps):
        t = i / (num_steps - 1)
        target_pos = parabola3D(start, end, cp, t)

        try:
            robot.arm.set_joint_positions(prev_joints, disable_dynamics=True)
            joint_positions = robot.arm.solve_ik_via_jacobian(target_pos, euler=orientation)
            robot.arm.set_joint_positions(joint_positions, disable_dynamics=True)
            robot.arm.set_joint_target_velocities([0] * 7)

            for _ in range(steps_per_point):
                pyrep.step()

            trace.append(tip.get_position().copy())
            prev_joints = list(joint_positions)
        except (IKError, ConfigurationPathError):
            trace.append(trace[-1].copy() if trace else start.copy())

    return np.array(trace)


def _jitter_verdict(phase, cp_idx, angle, dist_frac, jitter, threshold):
    """Print the PASS/REJECT line for one control point; return True if it passes."""
    head = (f"  {phase} CP {cp_idx}: angle={np.degrees(angle):.0f}°, "
            f"dist={dist_frac:.1f}, jitter={jitter:.4f}")
    if jitter <= threshold:
        print(f"{head} - PASS")
        return True
    print(f"{head} - REJECT (>{threshold})")
    return False


def _print_cp_details(indices, canonical_params):
    """Print angle/distance/position parameters for each control point index."""
    for idx in indices:
        angle, dist, pos = canonical_params[idx]
        print(f"    CP {idx}: angle={np.degrees(angle):.0f}°, dist={dist:.1f}, pos={pos:.1f}")


def _run_initialization(rlbench_env, task_name):
    """Run steps 1-7 of the initialization inside an already-launched environment.

    Returns:
        True once the data has been written to ``OUTPUT_FILE``; False when a
        filtering stage leaves no valid control points (nothing is saved).
    """
    # Phase configuration
    REACH_STEPS = 64
    CARRY_STEPS = 64
    STEPS_PER_POINT = 5
    JITTER_THRESHOLD = 0.08

    t_cls = get_task_classes([task_name])[0]
    task_env = rlbench_env.get_task(t_cls)
    task_env.set_variation(0)  # Use chicken (variation 0)

    # ========================================================================
    # STEP 1: Extract grasp parameters from RLBench demo
    # ========================================================================
    _print_section("EXTRACTING GRASP PARAMETERS FROM RLBENCH DEMO")

    waypoint_params = extract_demo_waypoints(task_env, task_name, num_demos=1)

    PREGRASP_HEIGHT = 0.05  # 5cm above object
    GRASP_HEIGHT = waypoint_params.get("grasp_height", 0.0)
    if GRASP_HEIGHT < -0.05:
        # A demo grasp more than 5 cm below the object origin is treated as bogus.
        print(f"  WARNING: GRASP_HEIGHT={GRASP_HEIGHT:.4f}m too low, clamping to 0.0")
        GRASP_HEIGHT = 0.0
    LIFT_HEIGHT = waypoint_params.get("lift_height", 0.12)  # Slightly lower for flat meat
    GRASP_ORIENTATION = waypoint_params.get("grasp_orientation", DEFAULT_GRASP_ORIENTATION.copy())
    GRASP_XY_OFFSET = waypoint_params.get("grasp_xy_offset", np.zeros(2))

    print("\n  Extracted parameters:")
    print(f"    PREGRASP_HEIGHT: {PREGRASP_HEIGHT:.4f}m")
    print(f"    GRASP_HEIGHT:    {GRASP_HEIGHT:.4f}m")
    print(f"    LIFT_HEIGHT:     {LIFT_HEIGHT:.4f}m")
    print(f"    GRASP_XY_OFFSET: [{GRASP_XY_OFFSET[0]:.4f}, {GRASP_XY_OFFSET[1]:.4f}]m")
    print(f"    GRASP_ORIENTATION: {_fmt_vec3(GRASP_ORIENTATION)}")

    # Build canonical control points (4 angles x 2 distances = 8 CPs)
    canonical_params = generate_canonical_control_point_params()
    control_point_radius = 0.05
    print(f"\nTotal canonical control points: {len(canonical_params)}")
    print("  Angles: 0°, 90°, 180°, 270°")
    print("  Distances: 0.5, 1.0")
    print("  Position: fixed at 0.5")

    # ========================================================================
    # STEP 2: Get fixed positions
    # ========================================================================
    _print_section("SETTING UP FIXED POSITIONS")

    robot, tip, _ = _reset_to_home(task_env, settle_steps=20)
    home_pos = tip.get_position().copy()

    # Get object (chicken or steak depending on variation)
    task = task_env._task
    target_object, object_shape_name = _get_target_object(task)
    FIXED_OBJECT_POS = target_object.get_position().copy()
    print(f"  Using object: {object_shape_name}")

    # Get target position from success sensor
    TARGET_POS = task._success_sensor.get_position().copy()
    print(f"  Target (success sensor) pos: {_fmt_vec3(TARGET_POS)}")

    # Compute waypoints
    pregrasp_pos = _offset_position(FIXED_OBJECT_POS, PREGRASP_HEIGHT, GRASP_XY_OFFSET)
    grasp_pos = _offset_position(FIXED_OBJECT_POS, GRASP_HEIGHT, GRASP_XY_OFFSET)
    # NOTE(review): lift_pos reuses the pregrasp point, so LIFT_HEIGHT is only
    # saved to the output file and does not shape the CARRY path - confirm
    # this is intended before relying on data['lift_height'] downstream.
    lift_pos = pregrasp_pos.copy()

    # Height parameters for placement (meat is flatter than blocks)
    PREPLACE_HEIGHT = waypoint_params.get("preplace_height", 0.15)
    PLACE_HEIGHT = 0.05  # Lower for flat meat pieces
    prerelease_pos = _offset_position(TARGET_POS, PREPLACE_HEIGHT)
    release_pos = _offset_position(TARGET_POS, PLACE_HEIGHT)

    for label, pos in (
        ("HOME pos:", home_pos),
        ("Object pos:", FIXED_OBJECT_POS),
        ("Target pos:", TARGET_POS),
        ("Pregrasp pos:", pregrasp_pos),
        ("Grasp pos:", grasp_pos),
        ("Lift pos:", lift_pos),
        ("Prerelease pos:", prerelease_pos),
        ("Release pos:", release_pos),
    ):
        print(f"  {label:<16}{_fmt_vec3(pos)}")

    # ========================================================================
    # STEP 3: Pre-filter REACH control points (IK + collision)
    # ========================================================================
    _print_section("PRE-FILTERING REACH CONTROL POINTS (IK + Collision)")

    valid_indices, prefilter_results = prefilter_control_points_with_collision(
        robot.arm, home_pos, pregrasp_pos, GRASP_ORIENTATION,
        canonical_params, control_point_radius,
        num_samples=16, require_orientation=True
    )
    _report_prefilter(valid_indices, prefilter_results, len(canonical_params))

    # len() rather than truthiness: the helper may return a NumPy array.
    if len(valid_indices) == 0:
        print("  ERROR: No valid REACH control points found!")
        return False

    # ========================================================================
    # STEP 4: Jitter filter for REACH
    # ========================================================================
    _print_section("JITTER FILTER FOR REACH PHASE")

    reach_valid_indices = []
    for cp_idx in valid_indices:
        robot, tip, _ = _reset_to_home(task_env, settle_steps=20)
        home_pos_curr = tip.get_position().copy()

        # Pin the object to the position recorded in STEP 2 and let it settle.
        target_object, _ = _get_target_object(task_env._task)
        target_object.set_position(FIXED_OBJECT_POS)
        for _ in range(5):
            task_env._scene.pyrep.step()

        pregrasp_pos_curr = _offset_position(FIXED_OBJECT_POS, PREGRASP_HEIGHT, GRASP_XY_OFFSET)

        angle, dist_frac, pos_frac = canonical_params[cp_idx]
        cp = compute_control_point_from_params(
            home_pos_curr, pregrasp_pos_curr, control_point_radius, angle, dist_frac, pos_frac
        )
        trace = _trace_parabola(
            task_env, robot, tip, home_pos_curr, pregrasp_pos_curr, cp,
            REACH_STEPS, STEPS_PER_POINT, GRASP_ORIENTATION,
        )
        if _jitter_verdict("REACH", cp_idx, angle, dist_frac,
                           compute_jitter_metric(trace), JITTER_THRESHOLD):
            reach_valid_indices.append(cp_idx)

    print(f"\n  Final valid REACH CPs: {len(reach_valid_indices)}/{len(valid_indices)}")

    if not reach_valid_indices:
        print("  ERROR: No valid REACH control points after jitter filter!")
        return False

    # ========================================================================
    # STEP 5: Pre-filter CARRY control points (IK + collision)
    # ========================================================================
    _print_section("PRE-FILTERING CARRY CONTROL POINTS (IK + Collision)")

    robot, _, _ = _reset_to_home(task_env, settle_steps=20)

    print("  CARRY path: lift_pos → prerelease_pos")
    print(f"    Start: {_fmt_vec3(lift_pos)}")
    print(f"    End:   {_fmt_vec3(prerelease_pos)}")

    carry_valid_indices, carry_prefilter_results = prefilter_control_points_with_collision(
        robot.arm, lift_pos, prerelease_pos, GRASP_ORIENTATION,
        canonical_params, control_point_radius,
        num_samples=16, require_orientation=True
    )
    _report_prefilter(carry_valid_indices, carry_prefilter_results, len(canonical_params))

    if len(carry_valid_indices) == 0:
        print("  ERROR: No valid CARRY control points found!")
        return False

    # ========================================================================
    # STEP 6: Jitter filter for CARRY
    # ========================================================================
    _print_section("JITTER FILTER FOR CARRY PHASE")

    carry_final_valid_indices = []
    for cp_idx in carry_valid_indices:
        robot, tip, gripper = _reset_to_home(task_env, settle_steps=10)
        angle, dist_frac, pos_frac = canonical_params[cp_idx]

        # Start the carry at lift_pos with the gripper closed.
        try:
            joint_positions = robot.arm.solve_ik_via_jacobian(lift_pos, euler=GRASP_ORIENTATION)
            robot.arm.set_joint_positions(joint_positions, disable_dynamics=True)
            robot.arm.set_joint_target_velocities([0] * 7)
            gripper.actuate(0.0, 0.2)
            for _ in range(20):
                task_env._scene.pyrep.step()
        except (IKError, ConfigurationPathError):
            # Report instead of skipping silently so every candidate is logged.
            print(f"  CARRY CP {cp_idx}: angle={np.degrees(angle):.0f}°, "
                  f"dist={dist_frac:.1f}, IK to lift_pos failed - REJECT")
            continue

        cp = compute_control_point_from_params(
            lift_pos, prerelease_pos, control_point_radius, angle, dist_frac, pos_frac
        )
        trace = _trace_parabola(
            task_env, robot, tip, lift_pos, prerelease_pos, cp,
            CARRY_STEPS, STEPS_PER_POINT, GRASP_ORIENTATION,
        )
        if _jitter_verdict("CARRY", cp_idx, angle, dist_frac,
                           compute_jitter_metric(trace), JITTER_THRESHOLD):
            carry_final_valid_indices.append(cp_idx)

    print(f"\n  Final valid CARRY CPs: {len(carry_final_valid_indices)}/{len(carry_valid_indices)}")

    if not carry_final_valid_indices:
        print("  ERROR: No valid CARRY control points after jitter filter!")
        return False

    # ========================================================================
    # STEP 7: Save everything to file
    # ========================================================================
    _print_section("SAVING INITIALIZATION DATA")

    np.savez(
        OUTPUT_FILE,
        # Positions
        home_pos=home_pos,
        object_pos=FIXED_OBJECT_POS,
        target_pos=TARGET_POS,
        pregrasp_pos=pregrasp_pos,
        grasp_pos=grasp_pos,
        lift_pos=lift_pos,
        prerelease_pos=prerelease_pos,
        release_pos=release_pos,
        # Grasp parameters
        grasp_orientation=GRASP_ORIENTATION,
        grasp_xy_offset=GRASP_XY_OFFSET,
        pregrasp_height=PREGRASP_HEIGHT,
        grasp_height=GRASP_HEIGHT,
        lift_height=LIFT_HEIGHT,
        preplace_height=PREPLACE_HEIGHT,
        place_height=PLACE_HEIGHT,
        # Control point parameters
        canonical_params=np.array(canonical_params),
        control_point_radius=control_point_radius,
        # Valid control point indices
        reach_valid_indices=np.array(reach_valid_indices),
        carry_valid_indices=np.array(carry_final_valid_indices),
        # Configuration
        reach_steps=REACH_STEPS,
        carry_steps=CARRY_STEPS,
        jitter_threshold=JITTER_THRESHOLD,
        # Task-specific
        object_shape_name=object_shape_name,
    )

    print(f"\n  Saved to: {OUTPUT_FILE}")
    print("\n  Contents:")
    print("    Positions: home, object, target, pregrasp, grasp, lift, prerelease, release")
    print("    Grasp params: orientation, xy_offset, heights")
    print(f"    Valid REACH CPs: {len(reach_valid_indices)} indices")
    print(f"    Valid CARRY CPs: {len(carry_final_valid_indices)} indices")
    print(f"    Canonical params: {len(canonical_params)} total (4 angles x 2 distances)")
    print(f"    Object shape name: {object_shape_name}")

    print("\n  Valid REACH control points:")
    _print_cp_details(reach_valid_indices, canonical_params)

    print("\n  Valid CARRY control points:")
    _print_cp_details(carry_final_valid_indices, canonical_params)

    return True


def main():
    """Pre-compute and save feasible control points and waypoints for meat_off_grill.

    Launches a headless RLBench environment and runs the initialization
    pipeline: demo grasp extraction, waypoint setup, IK/collision
    pre-filtering and jitter filtering for the REACH and CARRY phases, then
    saving to ``OUTPUT_FILE``. The simulator is shut down even if a stage
    raises.
    """
    print("=" * 60)
    print("MEAT OFF GRILL INITIALIZATION")
    print("Pre-computing feasible control points...")
    print("=" * 60)

    task_name = "meat_off_grill"

    obs_config = create_obs_config(save_video=False)
    action_mode = create_action_mode("abs")

    rlbench_env = Environment(
        action_mode=action_mode, obs_config=obs_config, headless=True
    )
    rlbench_env.launch()
    try:
        saved = _run_initialization(rlbench_env, task_name)
    finally:
        # Always release the simulator, including on errors or Ctrl-C.
        rlbench_env.shutdown()

    if not saved:
        return

    print("\n" + "=" * 60)
    print("INITIALIZATION COMPLETE")
    print("=" * 60)
    print("\nTo use this data in other scripts:")
    print(f"  data = np.load('{OUTPUT_FILE}')")
    print("  reach_valid_indices = data['reach_valid_indices']")
    print("  carry_valid_indices = data['carry_valid_indices']")
    print("  home_pos = data['home_pos']")
    print("  # etc...")


# Script entry point: run the initialization only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
