"""
Stack Blocks Initialization Script

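This script pre-computes and saves:
1. Fixed positions (home, object, target, waypoints)
2. Feasible control points for the REACH phase (home -> pregrasp)
3. Feasible control points for the CARRY phase (lift -> prerelease)
4. Grasp parameters extracted from an RLBench demo

A control point is considered feasible when its parabolic path passes the
IK + collision pre-filter and its simulated tip trace stays at or below the
jitter threshold.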

The saved data can be loaded by other scripts to skip the expensive
control point filtering step.

Output file: stack_blocks_init.npz (written next to this script)
"""

import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'custom_tasks'))

import numpy as np

from rlbench.environment import Environment
from pyrep.errors import ConfigurationPathError, IKError
from pyrep.objects.shape import Shape

from utils import (
    create_obs_config, create_action_mode, get_task_classes,
    generate_canonical_control_point_params,
    compute_control_point_from_params,
    prefilter_control_points_with_collision,
    parabola3D,
    HOME_JOINTS,
)

from pick_and_place_utils import (
    extract_demo_waypoints,
    DEFAULT_GRASP_ORIENTATION,
)

# Initialization data is written alongside this script, independent of the
# current working directory's contents.
OUTPUT_FILE = os.path.join(os.path.split(__file__)[0], "stack_blocks_init.npz")


def compute_jitter_metric(trace):
    """Score how much a sampled trajectory changes direction (higher = more jitter).

    For every pair of consecutive displacement vectors the score adds
    ``1 - cos(angle between them)``: 0 for a straight continuation, 1 for a
    right-angle turn, 2 for a full reversal. Displacements shorter than 1e-8
    are divided by 1e-8 instead of their own length, so a (near-)stationary
    step becomes a (near-)zero vector and each pair it appears in contributes
    roughly 1.

    Args:
        trace: Array-like of points, one per row (rank 2, e.g. ``(N, 3)``).

    Returns:
        ``0.0`` when fewer than three points are given, otherwise the summed
        score as a NumPy float.
    """
    if len(trace) < 3:
        return 0.0
    steps = np.diff(trace, axis=0)
    lengths = np.maximum(np.linalg.norm(steps, axis=1, keepdims=True), 1e-8)
    directions = steps / lengths
    cosines = np.sum(directions[:-1] * directions[1:], axis=1)
    return np.sum(1 - cosines)


def _print_header(title):
    """Print ``title`` framed by two 60-character '=' rules (section header)."""
    print(f"\n{'='*60}")
    print(title)
    print(f"{'='*60}")


def _fmt_vec(v):
    """Format the first three components of ``v`` as ``[x, y, z]`` (4 decimals)."""
    return f"[{v[0]:.4f}, {v[1]:.4f}, {v[2]:.4f}]"


def _step_sim(task_env, num_steps):
    """Advance the simulator behind ``task_env`` by ``num_steps`` steps."""
    for _ in range(num_steps):
        task_env._scene.pyrep.step()


def _reset_to_home(task_env, settle_steps=20, seed=42):
    """Reset the task, park the arm at HOME_JOINTS, open the gripper and settle.

    NumPy is re-seeded right before ``task_env.reset()``; presumably this makes
    every reset reproduce the same scene layout (assumes the task draws its
    randomization from NumPy's global RNG -- confirm).

    Args:
        task_env: RLBench task environment to reset.
        settle_steps: Number of simulation steps after homing.
        seed: Value passed to ``np.random.seed`` before the reset.

    Returns:
        The robot of the freshly reset scene.
    """
    np.random.seed(seed)
    task_env.reset()
    robot = task_env._scene.robot
    robot.arm.set_joint_positions(HOME_JOINTS, disable_dynamics=True)
    robot.arm.set_joint_target_velocities([0] * 7)
    robot.gripper.actuate(1.0, 0.2)
    _step_sim(task_env, settle_steps)
    return robot


def _move_shape(name, position):
    """Best-effort move of the shape called ``name``; warn instead of aborting."""
    try:
        Shape(name).set_position(position)
    except Exception as exc:  # the object may not exist in this scene/variation
        print(f"  WARNING: could not move '{name}': {exc}")


def _hide_other_blocks():
    """Move every block except ``stack_blocks_target0`` far off the workspace.

    ``task_env.reset()`` re-initializes the episode (presumably respawning
    these blocks), so this is repeated after every reset to keep all filtering
    steps evaluated in the same scene.
    """
    for i in range(1, 4):
        _move_shape(f'stack_blocks_target{i}', [10, 10, 0])
    for i in range(4):
        _move_shape(f'stack_blocks_distractor{i}', [10, 10 + i*0.1, 0])


def _prefilter_phase(arm, start_pos, end_pos, orientation, canonical_params, radius):
    """Run the IK + collision pre-filter for one phase and print a summary.

    Returns:
        The valid control point indices reported by
        ``prefilter_control_points_with_collision``.
    """
    valid, results = prefilter_control_points_with_collision(
        arm, start_pos, end_pos, orientation,
        canonical_params, radius,
        num_samples=16, require_orientation=True
    )
    collision_failures = sum(1 for r in results.values() if r['collision_count'] > 0)
    ik_failures = sum(1 for r in results.values() if r['ik_failures'] > 0)
    print(f"  Valid (IK + collision-free): {len(valid)}/{len(canonical_params)}")
    print(f"  Rejected due to collision: {collision_failures}")
    print(f"  Rejected due to IK failure: {ik_failures}")
    return valid


def _rollout_parabola_trace(task_env, robot, start_pos, end_pos, cp, orientation,
                            num_steps, steps_per_point):
    """Follow the parabola start -> cp -> end via per-point IK; record tip positions.

    Each sample is solved with ``solve_ik_via_jacobian`` seeded from the
    previous successful joint configuration. On IK failure the last recorded
    position (or ``start_pos`` if none) is repeated, which the jitter metric
    then penalizes as a stationary step.

    Returns:
        ``(num_steps, 3)`` array of tip positions.
    """
    tip = robot.arm.get_tip()
    trace = []
    prev_joints = list(robot.arm.get_joint_positions())
    denom = max(num_steps - 1, 1)  # guard num_steps == 1

    for i in range(num_steps):
        target_pos = parabola3D(start_pos, end_pos, cp, i / denom)
        try:
            robot.arm.set_joint_positions(prev_joints, disable_dynamics=True)
            joint_positions = robot.arm.solve_ik_via_jacobian(target_pos, euler=orientation)
            robot.arm.set_joint_positions(joint_positions, disable_dynamics=True)
            robot.arm.set_joint_target_velocities([0] * 7)
            _step_sim(task_env, steps_per_point)
            trace.append(tip.get_position().copy())
            prev_joints = list(joint_positions)
        except (IKError, ConfigurationPathError):
            trace.append(trace[-1].copy() if trace else start_pos.copy())

    return np.array(trace)


def _passes_jitter(phase, cp_idx, params, trace, threshold):
    """Print the jitter verdict for one control point; True if jitter <= threshold."""
    angle, dist_frac, _ = params
    jitter = compute_jitter_metric(trace)
    angle_deg = np.degrees(angle)
    if jitter <= threshold:
        print(f"  {phase} CP {cp_idx}: angle={angle_deg:.0f}°, dist={dist_frac:.1f}, jitter={jitter:.4f} - PASS")
        return True
    print(f"  {phase} CP {cp_idx}: angle={angle_deg:.0f}°, dist={dist_frac:.1f}, jitter={jitter:.4f} - REJECT (>{threshold})")
    return False


def _filter_reach_by_jitter(task_env, candidate_indices, object_pos, pregrasp_pos,
                            orientation, canonical_params, radius, num_steps,
                            steps_per_point, threshold):
    """Keep REACH control points whose home -> pregrasp rollout is smooth enough.

    Every candidate is evaluated from a freshly reset scene (other blocks
    hidden, ``stack_blocks_target0`` pinned at ``object_pos``).

    Returns:
        List of indices from ``candidate_indices`` that pass the jitter test.
    """
    passed = []
    for cp_idx in candidate_indices:
        robot = _reset_to_home(task_env)
        _hide_other_blocks()
        # Start from the tip position actually reached after this reset.
        start_pos = robot.arm.get_tip().get_position().copy()

        Shape('stack_blocks_target0').set_position(object_pos)
        _step_sim(task_env, 5)

        angle, dist_frac, pos_frac = canonical_params[cp_idx]
        cp = compute_control_point_from_params(
            start_pos, pregrasp_pos, radius, angle, dist_frac, pos_frac
        )
        trace = _rollout_parabola_trace(
            task_env, robot, start_pos, pregrasp_pos, cp, orientation,
            num_steps, steps_per_point,
        )
        if _passes_jitter("REACH", cp_idx, canonical_params[cp_idx], trace, threshold):
            passed.append(cp_idx)
    return passed


def _filter_carry_by_jitter(task_env, candidate_indices, lift_pos, prerelease_pos,
                            orientation, canonical_params, radius, num_steps,
                            steps_per_point, threshold):
    """Keep CARRY control points whose lift -> prerelease rollout is smooth enough.

    Each candidate starts from a fresh reset (other blocks hidden), with the
    arm teleported to ``lift_pos`` and the gripper closed. Candidates for which
    IK to ``lift_pos`` fails are reported and skipped.

    Returns:
        List of indices from ``candidate_indices`` that pass the jitter test.
    """
    passed = []
    for cp_idx in candidate_indices:
        robot = _reset_to_home(task_env, settle_steps=10)
        _hide_other_blocks()

        try:
            lift_joints = robot.arm.solve_ik_via_jacobian(lift_pos, euler=orientation)
        except (IKError, ConfigurationPathError):
            print(f"  CARRY CP {cp_idx}: IK to lift pos failed - SKIP")
            continue
        robot.arm.set_joint_positions(lift_joints, disable_dynamics=True)
        robot.arm.set_joint_target_velocities([0] * 7)
        # NOTE(review): the arm is teleported to lift_pos, so the block is
        # presumably not held during this rollout.
        robot.gripper.actuate(0.0, 0.2)
        _step_sim(task_env, 20)

        angle, dist_frac, pos_frac = canonical_params[cp_idx]
        cp = compute_control_point_from_params(
            lift_pos, prerelease_pos, radius, angle, dist_frac, pos_frac
        )
        trace = _rollout_parabola_trace(
            task_env, robot, lift_pos, prerelease_pos, cp, orientation,
            num_steps, steps_per_point,
        )
        if _passes_jitter("CARRY", cp_idx, canonical_params[cp_idx], trace, threshold):
            passed.append(cp_idx)
    return passed


def main():
    """Run the full initialization pipeline and write ``OUTPUT_FILE``.

    Steps:
      1. Extract grasp parameters from one RLBench demo.
      2. Reset the scene, hide non-target blocks and derive fixed waypoints.
      3-4. Filter REACH (home -> pregrasp) control points by IK/collision,
           then by the jitter threshold.
      5-6. Same for CARRY (lift -> prerelease).
      7. Save positions, grasp parameters and valid indices with ``np.savez``.

    Returns early (without writing the file) if a filtering step leaves no
    valid control point. The simulator is shut down on every exit path,
    including exceptions.
    """
    print("=" * 60)
    print("STACK BLOCKS INITIALIZATION")
    print("Pre-computing feasible control points...")
    print("=" * 60)

    task_name = "stack_blocks"

    # Phase configuration
    REACH_STEPS = 64
    CARRY_STEPS = 64
    STEPS_PER_POINT = 5
    JITTER_THRESHOLD = 0.08

    obs_config = create_obs_config(save_video=False)
    action_mode = create_action_mode("abs")

    rlbench_env = Environment(
        action_mode=action_mode, obs_config=obs_config, headless=True
    )
    rlbench_env.launch()

    try:
        tasks = get_task_classes([task_name])
        task_env = rlbench_env.get_task(tasks[0])
        task_env.set_variation(0)

        # ====================================================================
        # STEP 1: Extract grasp parameters from RLBench demo
        # ====================================================================
        _print_header("EXTRACTING GRASP PARAMETERS FROM RLBENCH DEMO")

        waypoint_params = extract_demo_waypoints(task_env, task_name, num_demos=1)

        PREGRASP_HEIGHT = 0.05  # 5cm above object
        GRASP_HEIGHT = waypoint_params.get("grasp_height", 0.0)
        if GRASP_HEIGHT < -0.05:
            print(f"  WARNING: GRASP_HEIGHT={GRASP_HEIGHT:.4f}m too low, clamping to 0.0")
            GRASP_HEIGHT = 0.0
        LIFT_HEIGHT = waypoint_params.get("lift_height", 0.15)
        GRASP_ORIENTATION = waypoint_params.get("grasp_orientation", DEFAULT_GRASP_ORIENTATION.copy())
        GRASP_XY_OFFSET = waypoint_params.get("grasp_xy_offset", np.zeros(2))

        print("\n  Extracted parameters:")
        print(f"    PREGRASP_HEIGHT: {PREGRASP_HEIGHT:.4f}m")
        print(f"    GRASP_HEIGHT:    {GRASP_HEIGHT:.4f}m")
        print(f"    LIFT_HEIGHT:     {LIFT_HEIGHT:.4f}m")
        print(f"    GRASP_XY_OFFSET: [{GRASP_XY_OFFSET[0]:.4f}, {GRASP_XY_OFFSET[1]:.4f}]m")
        print(f"    GRASP_ORIENTATION: {_fmt_vec(GRASP_ORIENTATION)}")

        # Build canonical control points (4 angles x 2 distances = 8 CPs)
        canonical_params = generate_canonical_control_point_params()
        control_point_radius = 0.05
        print(f"\nTotal canonical control points: {len(canonical_params)}")
        print("  Angles: 0°, 90°, 180°, 270°")
        print("  Distances: 0.5, 1.0")
        print("  Position: fixed at 0.5")

        # ====================================================================
        # STEP 2: Get fixed positions
        # ====================================================================
        _print_header("SETTING UP FIXED POSITIONS")

        robot = _reset_to_home(task_env)
        home_pos = robot.arm.get_tip().get_position().copy()

        # Object position (read before the other blocks are moved away)
        FIXED_OBJECT_POS = Shape('stack_blocks_target0').get_position().copy()

        _hide_other_blocks()
        print("  Removed unnecessary blocks")

        # Compute waypoints
        pregrasp_pos = FIXED_OBJECT_POS.copy()
        pregrasp_pos[:2] += GRASP_XY_OFFSET
        pregrasp_pos[2] += PREGRASP_HEIGHT

        grasp_pos = FIXED_OBJECT_POS.copy()
        grasp_pos[:2] += GRASP_XY_OFFSET
        grasp_pos[2] += GRASP_HEIGHT

        # NOTE(review): lift_pos reuses the pregrasp height; LIFT_HEIGHT is
        # extracted and saved but not applied here. Confirm this is intended.
        lift_pos = pregrasp_pos.copy()

        # Target plane position
        TARGET_POS = Shape('stack_blocks_target_plane').get_position().copy()
        PREPLACE_HEIGHT = waypoint_params.get("preplace_height", 0.20)
        PLACE_HEIGHT = 0.08

        prerelease_pos = TARGET_POS.copy()
        prerelease_pos[2] += PREPLACE_HEIGHT

        release_pos = TARGET_POS.copy()
        release_pos[2] += PLACE_HEIGHT

        print(f"  HOME pos:       {_fmt_vec(home_pos)}")
        print(f"  Object pos:     {_fmt_vec(FIXED_OBJECT_POS)}")
        print(f"  Target pos:     {_fmt_vec(TARGET_POS)}")
        print(f"  Pregrasp pos:   {_fmt_vec(pregrasp_pos)}")
        print(f"  Grasp pos:      {_fmt_vec(grasp_pos)}")
        print(f"  Lift pos:       {_fmt_vec(lift_pos)}")
        print(f"  Prerelease pos: {_fmt_vec(prerelease_pos)}")
        print(f"  Release pos:    {_fmt_vec(release_pos)}")

        # ====================================================================
        # STEP 3: Pre-filter REACH control points (IK + collision)
        # ====================================================================
        _print_header("PRE-FILTERING REACH CONTROL POINTS (IK + Collision)")

        valid_indices = _prefilter_phase(
            robot.arm, home_pos, pregrasp_pos, GRASP_ORIENTATION,
            canonical_params, control_point_radius,
        )
        if len(valid_indices) == 0:
            print("  ERROR: No valid REACH control points found!")
            return

        # ====================================================================
        # STEP 4: Jitter filter for REACH
        # ====================================================================
        _print_header("JITTER FILTER FOR REACH PHASE")

        reach_valid_indices = _filter_reach_by_jitter(
            task_env, valid_indices, FIXED_OBJECT_POS, pregrasp_pos,
            GRASP_ORIENTATION, canonical_params, control_point_radius,
            REACH_STEPS, STEPS_PER_POINT, JITTER_THRESHOLD,
        )

        print(f"\n  Final valid REACH CPs: {len(reach_valid_indices)}/{len(valid_indices)}")
        if len(reach_valid_indices) == 0:
            print("  ERROR: No valid REACH control points after jitter filter!")
            return

        # ====================================================================
        # STEP 5: Pre-filter CARRY control points (IK + collision)
        # ====================================================================
        _print_header("PRE-FILTERING CARRY CONTROL POINTS (IK + Collision)")

        # Same scene as the REACH prefilter: reset, home, other blocks hidden.
        robot = _reset_to_home(task_env)
        _hide_other_blocks()

        print("  CARRY path: lift_pos → prerelease_pos")
        print(f"    Start: {_fmt_vec(lift_pos)}")
        print(f"    End:   {_fmt_vec(prerelease_pos)}")

        carry_valid_indices = _prefilter_phase(
            robot.arm, lift_pos, prerelease_pos, GRASP_ORIENTATION,
            canonical_params, control_point_radius,
        )
        if len(carry_valid_indices) == 0:
            print("  ERROR: No valid CARRY control points found!")
            return

        # ====================================================================
        # STEP 6: Jitter filter for CARRY
        # ====================================================================
        _print_header("JITTER FILTER FOR CARRY PHASE")

        carry_final_valid_indices = _filter_carry_by_jitter(
            task_env, carry_valid_indices, lift_pos, prerelease_pos,
            GRASP_ORIENTATION, canonical_params, control_point_radius,
            CARRY_STEPS, STEPS_PER_POINT, JITTER_THRESHOLD,
        )

        print(f"\n  Final valid CARRY CPs: {len(carry_final_valid_indices)}/{len(carry_valid_indices)}")
        if len(carry_final_valid_indices) == 0:
            print("  ERROR: No valid CARRY control points after jitter filter!")
            return

        # ====================================================================
        # STEP 7: Save everything to file
        # ====================================================================
        _print_header("SAVING INITIALIZATION DATA")

        np.savez(
            OUTPUT_FILE,
            # Positions
            home_pos=home_pos,
            object_pos=FIXED_OBJECT_POS,
            target_pos=TARGET_POS,
            pregrasp_pos=pregrasp_pos,
            grasp_pos=grasp_pos,
            lift_pos=lift_pos,
            prerelease_pos=prerelease_pos,
            release_pos=release_pos,
            # Grasp parameters
            grasp_orientation=GRASP_ORIENTATION,
            grasp_xy_offset=GRASP_XY_OFFSET,
            pregrasp_height=PREGRASP_HEIGHT,
            grasp_height=GRASP_HEIGHT,
            lift_height=LIFT_HEIGHT,
            preplace_height=PREPLACE_HEIGHT,
            place_height=PLACE_HEIGHT,
            # Control point parameters
            canonical_params=np.array(canonical_params),
            control_point_radius=control_point_radius,
            # Valid control point indices
            reach_valid_indices=np.array(reach_valid_indices),
            carry_valid_indices=np.array(carry_final_valid_indices),
            # Configuration
            reach_steps=REACH_STEPS,
            carry_steps=CARRY_STEPS,
            jitter_threshold=JITTER_THRESHOLD,
        )

        print(f"\n  Saved to: {OUTPUT_FILE}")
        print("\n  Contents:")
        print("    Positions: home, object, target, pregrasp, grasp, lift, prerelease, release")
        print("    Grasp params: orientation, xy_offset, heights")
        print(f"    Valid REACH CPs: {len(reach_valid_indices)} indices")
        print(f"    Valid CARRY CPs: {len(carry_final_valid_indices)} indices")
        print(f"    Canonical params: {len(canonical_params)} total (4 angles x 2 distances)")

        # Print valid CPs details
        for label, indices in (("REACH", reach_valid_indices),
                               ("CARRY", carry_final_valid_indices)):
            print(f"\n  Valid {label} control points:")
            for idx in indices:
                angle, dist, pos = canonical_params[idx]
                print(f"    CP {idx}: angle={np.degrees(angle):.0f}°, dist={dist:.1f}, pos={pos:.1f}")
    finally:
        # Release the simulator on every path: success, early return, exception.
        rlbench_env.shutdown()

    # Only reached when the try block completed without returning early.
    print("\n" + "=" * 60)
    print("INITIALIZATION COMPLETE")
    print("=" * 60)
    print("\nTo use this data in other scripts:")
    print(f"  data = np.load('{OUTPUT_FILE}')")
    print("  reach_valid_indices = data['reach_valid_indices']")
    print("  carry_valid_indices = data['carry_valid_indices']")
    print("  home_pos = data['home_pos']")
    print("  # etc...")


if __name__ == "__main__":
    # Script entry point: run the initialization only when executed directly,
    # not when this module is imported.
    main()
