"""
Debug script to diagnose JITTER in REACH phase trajectories.

Generates 3 traces for comparison:
- Trace A: Pure math (desired Bezier curve)
- Trace B: Kinematic IK trace (IK+FK only, no dynamics/stepping)
- Trace C: Executed trace (current method with sim stepping)

If B is smooth but C jitters → controller/physics problem
If B already jitters → IK branch switching / singularity sensitivity
"""

import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'custom_tasks'))

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from rlbench.environment import Environment
from pyrep.errors import ConfigurationPathError, IKError

from utils import (
    create_obs_config, create_action_mode, get_task_classes,
    generate_canonical_control_point_params,
    compute_control_point_from_params,
    prefilter_control_points,
    prefilter_control_points_with_collision,
    parabola3D,
    HOME_JOINTS,
)


def _ik_tracking_trace(task_env, robot, tip, targets, start_pos, orientation,
                       step_physics, steps_per_point):
    """Follow ``targets`` with Jacobian IK; return (tip positions, IK failure count).

    For each target: restore the previous joint solution (so the Jacobian
    solver is seeded from it), solve IK for the target position at the fixed
    ``orientation``, and teleport the arm to the solution with
    ``disable_dynamics=True``. When ``step_physics`` is true, the joint target
    velocities are zeroed and the simulation is stepped ``steps_per_point``
    times. The tip position is then recorded.

    On ``IKError``/``ConfigurationPathError`` the previously recorded position
    (or ``start_pos`` for the first waypoint) is repeated and the failure is
    counted; the seed configuration is left unchanged.

    NOTE(review): PyRep's ``set_joint_positions(..., disable_dynamics=True)``
    may advance the simulation internally, so "no stepping" here means no
    *explicit* ``pyrep.step()`` calls — confirm against the PyRep version in use.
    """
    trace = []
    failures = 0
    prev_joints = list(robot.arm.get_joint_positions())
    zero_velocities = [0] * len(prev_joints)

    for target_pos in targets:
        try:
            # Seed the Jacobian solver from the last successful solution.
            robot.arm.set_joint_positions(prev_joints, disable_dynamics=True)
            joint_positions = robot.arm.solve_ik_via_jacobian(target_pos, euler=orientation)

            # Teleport to the IK solution.
            robot.arm.set_joint_positions(joint_positions, disable_dynamics=True)

            if step_physics:
                robot.arm.set_joint_target_velocities(zero_velocities)
                # This is where jitter appears if there are collisions.
                for _ in range(steps_per_point):
                    task_env._scene.pyrep.step()

            trace.append(tip.get_position().copy())
            prev_joints = list(joint_positions)
        except (IKError, ConfigurationPathError):
            failures += 1
            trace.append(trace[-1].copy() if trace else start_pos.copy())

    return np.array(trace), failures


def generate_three_traces(task_env, start_pos, end_pos, control_point, orientation,
                          num_steps=64, steps_per_point=5):
    """
    Generate 3 traces for jitter diagnosis:
    - Trace A: Pure math (desired Bezier curve evaluated with ``parabola3D``)
    - Trace B: Kinematic IK trace (IK + teleport only, no explicit sim stepping)
    - Trace C: Executed trace (IK + teleport, then ``steps_per_point`` sim steps)

    B and C track exactly the waypoints of A: ``num_steps`` evenly spaced
    curve parameters t in [0, 1]. Before trace C the arm is restored to the
    joint configuration it had on entry and the simulation is stepped 10
    times to settle. The arm is left at the end of trace C on return.

    Args:
        task_env: RLBench task environment whose ``_scene.robot`` is driven.
        start_pos, end_pos: curve endpoints passed to ``parabola3D``.
        control_point: curve control point passed to ``parabola3D``.
        orientation: Euler angles held fixed for every IK solve.
        num_steps: number of waypoints; must be at least 2.
        steps_per_point: physics steps after each teleport in trace C.

    Returns:
        dict with ``trace_a``/``trace_b``/``trace_c`` (arrays with one tip
        position per waypoint), ``ik_failures_b``/``ik_failures_c`` (number of
        waypoints where IK failed) and ``final_error_b_mm``/``final_error_c_mm``
        (distance of the last recorded position from ``end_pos`` scaled by
        1000 — millimetres, assuming positions are in metres).

    Raises:
        ValueError: if ``num_steps < 2`` (waypoint spacing divides by
            ``num_steps - 1`` and the final errors need at least one point).
    """
    if num_steps < 2:
        raise ValueError(f"num_steps must be >= 2, got {num_steps}")

    robot = task_env._scene.robot
    tip = robot.arm.get_tip()

    # ========================================================================
    # TRACE A: Pure math - desired Bezier curve positions
    # ========================================================================
    # .copy() guards against parabola3D returning a reused buffer.
    trace_a = np.array([
        parabola3D(start_pos, end_pos, control_point, i / (num_steps - 1)).copy()
        for i in range(num_steps)
    ])

    # ========================================================================
    # TRACE B: Kinematic IK trace - IK + teleport only, no explicit stepping
    # ========================================================================
    initial_joints = list(robot.arm.get_joint_positions())
    trace_b, ik_failures_b = _ik_tracking_trace(
        task_env, robot, tip, trace_a, start_pos, orientation,
        step_physics=False, steps_per_point=0,
    )

    # ========================================================================
    # TRACE C: Executed trace WITH PHYSICS, from the same starting config
    # ========================================================================
    robot.arm.set_joint_positions(initial_joints, disable_dynamics=True)
    robot.arm.set_joint_target_velocities([0] * len(initial_joints))
    for _ in range(10):
        task_env._scene.pyrep.step()

    trace_c, ik_failures_c = _ik_tracking_trace(
        task_env, robot, tip, trace_a, start_pos, orientation,
        step_physics=True, steps_per_point=steps_per_point,
    )

    # Final position errors (x1000: mm if positions are in metres)
    final_error_b = np.linalg.norm(trace_b[-1] - end_pos) * 1000
    final_error_c = np.linalg.norm(trace_c[-1] - end_pos) * 1000

    return {
        'trace_a': trace_a,  # Pure math
        'trace_b': trace_b,  # Kinematic IK (no explicit stepping)
        'trace_c': trace_c,  # Executed (with stepping)
        'ik_failures_b': ik_failures_b,
        'ik_failures_c': ik_failures_c,
        'final_error_b_mm': final_error_b,
        'final_error_c_mm': final_error_c,
    }


def compute_jitter_metric(trace):
    """Score how much a trajectory changes direction (higher = more jitter).

    Consecutive displacement vectors are scaled to unit length. Lengths below
    1e-8 are clamped to 1e-8 to avoid dividing by zero. Each adjacent pair then
    contributes ``1 - dot``:
    - 0 when the heading is kept;
    - 1 for a right-angle turn;
    - 2 for a full reversal.

    An exactly repeated position yields a zero direction vector, which adds 1
    for each neighbouring pair.

    Args:
        trace: sequence of positions, one row per waypoint.

    Returns:
        The summed score, or 0.0 when the trace has fewer than 3 points.
    """
    if len(trace) < 3:
        return 0.0

    steps = np.diff(trace, axis=0)
    step_lengths = np.maximum(np.linalg.norm(steps, axis=1, keepdims=True), 1e-8)
    directions = steps / step_lengths

    cos_turn = (directions[1:] * directions[:-1]).sum(axis=1)
    return np.sum(1 - cos_turn)


# Shared experiment parameters. The jitter filter and the diagnostic trials use
# the same waypoint count and physics steps so their jitter values are comparable.
CONTROL_POINT_RADIUS = 0.05
PREGRASP_HEIGHT = 0.05      # pregrasp target sits this far above the object (5cm)
JITTER_THRESHOLD = 0.08     # max compute_jitter_metric value for a CP to pass the second filter
NUM_STEPS = 64              # waypoints per reach trajectory (matches demo length)
STEPS_PER_POINT = 5         # physics steps after each IK teleport (matches demo)


def _reset_scene_at_home(task_env):
    """Rebuild the scene deterministically and park the arm at HOME.

    Seeds NumPy's global RNG with 42 before ``task_env.reset()`` so every call
    rebuilds the same scene. Then teleports the arm to ``HOME_JOINTS`` with
    dynamics disabled, zeroes joint target velocities, opens the gripper and
    steps the simulation 20 times.

    Returns:
        tuple: ``(robot, tip, home_pos, home_ori, pregrasp_pos)``.
        ``home_pos``/``home_ori`` are the tip pose after settling.
        ``pregrasp_pos`` is ``PREGRASP_HEIGHT`` above the ``'chicken'`` shape.
    """
    from pyrep.objects.shape import Shape

    np.random.seed(42)  # Same scene each time
    task_env.reset()

    robot = task_env._scene.robot
    tip = robot.arm.get_tip()

    robot.arm.set_joint_positions(HOME_JOINTS, disable_dynamics=True)
    robot.arm.set_joint_target_velocities([0] * len(HOME_JOINTS))
    robot.gripper.actuate(1.0, 0.2)  # Open gripper
    for _ in range(20):
        task_env._scene.pyrep.step()

    home_pos = tip.get_position().copy()
    home_ori = tip.get_orientation()

    pregrasp_pos = Shape('chicken').get_position().copy()
    pregrasp_pos[2] += PREGRASP_HEIGHT
    return robot, tip, home_pos, home_ori, pregrasp_pos


def _physics_rollout(task_env, robot, tip, start_pos, end_pos, control_point,
                     orientation, num_steps, steps_per_point):
    """Execute the reach with IK teleports plus physics stepping; return tip positions.

    For ``num_steps`` evenly spaced t in [0, 1], the arm is handled as follows:
    - it is re-seeded from the last IK solution;
    - it is teleported to the Jacobian-IK solution for the ``parabola3D`` point;
    - its target velocities are zeroed;
    - the sim is stepped ``steps_per_point`` times.
    On IK failure the previously recorded position (or ``start_pos``) is repeated.
    """
    trace = []
    prev_joints = list(robot.arm.get_joint_positions())
    zero_velocities = [0] * len(prev_joints)

    for i in range(num_steps):
        t = i / (num_steps - 1)
        target_pos = parabola3D(start_pos, end_pos, control_point, t)

        try:
            robot.arm.set_joint_positions(prev_joints, disable_dynamics=True)
            joint_positions = robot.arm.solve_ik_via_jacobian(target_pos, euler=orientation)
            robot.arm.set_joint_positions(joint_positions, disable_dynamics=True)
            robot.arm.set_joint_target_velocities(zero_velocities)

            for _ in range(steps_per_point):
                task_env._scene.pyrep.step()

            trace.append(tip.get_position().copy())
            prev_joints = list(joint_positions)
        except (IKError, ConfigurationPathError):
            trace.append(trace[-1].copy() if trace else start_pos.copy())

    return np.array(trace)


def _collision_prefilter(task_env, canonical_params):
    """Stage 1: keep control points whose reach passes IK and collision checks.

    Returns:
        The valid control-point indices reported by
        ``prefilter_control_points_with_collision``.
    """
    print(f"\n{'='*60}")
    print("PRE-FILTERING CONTROL POINTS")
    print(f"{'='*60}")

    robot, _, home_pos, home_ori, pregrasp_pos = _reset_scene_at_home(task_env)

    print(f"  HOME pos:     [{home_pos[0]:.4f}, {home_pos[1]:.4f}, {home_pos[2]:.4f}]")
    print(f"  HOME ori:     [{home_ori[0]:.4f}, {home_ori[1]:.4f}, {home_ori[2]:.4f}]")
    print(f"  Pregrasp pos: [{pregrasp_pos[0]:.4f}, {pregrasp_pos[1]:.4f}, {pregrasp_pos[2]:.4f}]")

    print(f"\n  Testing IK + COLLISION for all {len(canonical_params)} control points...")
    valid_indices, prefilter_results = prefilter_control_points_with_collision(
        robot.arm, home_pos, pregrasp_pos, home_ori,
        canonical_params, CONTROL_POINT_RADIUS,
        num_samples=16, require_orientation=True,  # More samples for better collision coverage
    )

    # Count how many CPs had collisions vs IK failures (a CP may have both).
    collision_failures = sum(1 for r in prefilter_results.values() if r['collision_count'] > 0)
    ik_failures = sum(1 for r in prefilter_results.values() if r['ik_failures'] > 0)
    print(f"  Valid (IK + collision-free): {len(valid_indices)}/{len(canonical_params)}")
    print(f"  Rejected due to collision: {collision_failures}")
    print(f"  Rejected due to IK failure: {ik_failures}")
    return valid_indices


def _jitter_filter(task_env, canonical_params, valid_indices):
    """Stage 2: keep control points whose physics rollout has jitter <= JITTER_THRESHOLD.

    Each candidate gets a fresh, identically seeded scene reset before its rollout.
    """
    print(f"\n{'='*60}")
    print("SECOND FILTER: Quick physics rollout to filter high-jitter CPs")
    print(f"{'='*60}")

    final_valid_indices = []
    for cp_idx in valid_indices:
        robot, tip, home_pos, home_ori, pregrasp_pos = _reset_scene_at_home(task_env)

        angle, dist_frac, pos_frac = canonical_params[cp_idx]
        cp = compute_control_point_from_params(
            home_pos, pregrasp_pos, CONTROL_POINT_RADIUS, angle, dist_frac, pos_frac
        )

        trace = _physics_rollout(
            task_env, robot, tip, home_pos, pregrasp_pos, cp, home_ori,
            NUM_STEPS, STEPS_PER_POINT,
        )
        jitter = compute_jitter_metric(trace)

        if jitter <= JITTER_THRESHOLD:
            final_valid_indices.append(cp_idx)
            print(f"  CP {cp_idx:3d}: jitter={jitter:.4f} - PASS")
        else:
            print(f"  CP {cp_idx:3d}: jitter={jitter:.4f} - REJECT (>{JITTER_THRESHOLD})")

    print(f"\n  Final valid CPs: {len(final_valid_indices)}/{len(valid_indices)} (after jitter filter)")
    return final_valid_indices


def _run_trials(task_env, canonical_params, test_cp_indices):
    """Stage 3: generate traces A/B/C for each test CP and diagnose the jitter source.

    Returns:
        list of per-trial dicts with the traces, jitter values, diagnosis,
        HOME/pregrasp positions and the control point.
    """
    num_tests = len(test_cp_indices)

    print(f"\n{'='*60}")
    print(f"DIAGNOSTIC: Testing {num_tests} different CPs (collision-filtered)")
    print("Generating 3 traces per CP: A (math), B (IK only), C (executed with physics)")
    print(f"{'='*60}")

    all_results = []
    for trial_idx, cp_idx in enumerate(test_cp_indices):
        print(f"\n--- Trial {trial_idx+1}/{num_tests} (cp_idx={cp_idx}) ---")

        # FULL SCENE RESET for each trial
        _, _, home_pos, home_ori, pregrasp_pos = _reset_scene_at_home(task_env)

        print(f"  HOME pos:     [{home_pos[0]:.4f}, {home_pos[1]:.4f}, {home_pos[2]:.4f}]")
        print(f"  Pregrasp pos: [{pregrasp_pos[0]:.4f}, {pregrasp_pos[1]:.4f}, {pregrasp_pos[2]:.4f}]")

        angle, dist_frac, pos_frac = canonical_params[cp_idx]
        print(f"  Params: angle={np.degrees(angle):.1f}°, dist={dist_frac:.2f}, pos={pos_frac:.2f}")

        cp = compute_control_point_from_params(
            home_pos, pregrasp_pos, CONTROL_POINT_RADIUS, angle, dist_frac, pos_frac
        )

        traces = generate_three_traces(
            task_env, home_pos, pregrasp_pos, cp, home_ori,
            num_steps=NUM_STEPS, steps_per_point=STEPS_PER_POINT,
        )

        jitter_a = compute_jitter_metric(traces['trace_a'])
        jitter_b = compute_jitter_metric(traces['trace_b'])
        jitter_c = compute_jitter_metric(traces['trace_c'])

        print(f"  Trace A (math):     jitter={jitter_a:.4f}")
        print(f"  Trace B (IK only):  jitter={jitter_b:.4f}, IK_fail={traces['ik_failures_b']}, error={traces['final_error_b_mm']:.1f}mm")
        print(f"  Trace C (executed): jitter={jitter_c:.4f}, IK_fail={traces['ik_failures_c']}, error={traces['final_error_c_mm']:.1f}mm")

        # Per-trial diagnosis: a 2x jump marks the stage that introduces jitter.
        if jitter_b > jitter_a * 2:
            diagnosis = "IK PROBLEM (B already jittery)"
        elif jitter_c > jitter_b * 2:
            diagnosis = "PHYSICS/CONTROLLER PROBLEM (C jitters but B smooth)"
        else:
            diagnosis = "TRACES SIMILAR (no major jitter source identified)"
        print(f"  DIAGNOSIS: {diagnosis}")

        all_results.append({
            "cp_idx": cp_idx,
            "params": (angle, dist_frac, pos_frac),
            "traces": traces,
            "jitter_a": jitter_a,
            "jitter_b": jitter_b,
            "jitter_c": jitter_c,
            "diagnosis": diagnosis,
            "home_pos": home_pos,
            "pregrasp_pos": pregrasp_pos,
            "control_point": cp,
        })

    return all_results


def _print_summary(all_results):
    """Print the per-trial jitter table and an overall root-cause verdict (1.5x rule)."""
    print(f"\n{'='*60}")
    print("JITTER DIAGNOSTIC SUMMARY")
    print(f"{'='*60}")

    print("\nJitter metrics (lower = smoother):")
    print(f"  {'Trial':>5} | {'Trace A':>10} | {'Trace B':>10} | {'Trace C':>10} | Diagnosis")
    print(f"  {'-'*5}-+-{'-'*10}-+-{'-'*10}-+-{'-'*10}-+-{'-'*30}")
    for trial_num, r in enumerate(all_results, start=1):
        print(f"  {trial_num:5d} | {r['jitter_a']:10.4f} | {r['jitter_b']:10.4f} | {r['jitter_c']:10.4f} | {r['diagnosis']}")

    avg_jitter_a = np.mean([r['jitter_a'] for r in all_results])
    avg_jitter_b = np.mean([r['jitter_b'] for r in all_results])
    avg_jitter_c = np.mean([r['jitter_c'] for r in all_results])

    print(f"\nAverage jitter: A={avg_jitter_a:.4f}, B={avg_jitter_b:.4f}, C={avg_jitter_c:.4f}")

    if avg_jitter_b > avg_jitter_a * 1.5:
        print("\n*** ROOT CAUSE: IK solver (solve_ik_via_jacobian) introduces jitter ***")
        print("    Possible causes: branch switching, singularity sensitivity, step size limits")
    elif avg_jitter_c > avg_jitter_b * 1.5:
        print("\n*** ROOT CAUSE: Physics/controller introduces jitter ***")
        print("    Possible causes: simulation stepping, dynamics, controller fighting teleport")
    else:
        print("\n*** All traces have similar jitter - check if jitter is acceptable ***")


def _plot_results(all_results, out_path='reach_only_debug.png'):
    """Draw a 2x3 diagnostic figure of all executed trajectories, save it, then show it."""
    print(f"\n{'='*60}")
    print("Creating diagnostic visualization...")
    print(f"{'='*60}")

    fig = plt.figure(figsize=(18, 12))

    # Every trial reseeds and resets the scene, so the HOME/pregrasp markers
    # are taken from the first trial.
    r0 = all_results[0]
    colors = plt.cm.tab10(np.linspace(0, 1, len(all_results)))

    # Plot 1: 3D view - ALL executed trajectories (Trace C)
    ax1 = fig.add_subplot(231, projection='3d')
    for color, r in zip(colors, all_results):
        trace_c = r['traces']['trace_c']
        ax1.plot(trace_c[:, 0], trace_c[:, 1], trace_c[:, 2], '-', color=color,
                 linewidth=1.5, alpha=0.8, label=f'CP{r["cp_idx"]} (j={r["jitter_c"]:.2f})')
        ax1.scatter(*r["control_point"], c=[color], marker='x', s=30, alpha=0.5)

    ax1.scatter(*r0["home_pos"], c='black', marker='o', s=150, zorder=10, label='HOME')
    ax1.scatter(*r0["pregrasp_pos"], c='orange', marker='*', s=200, zorder=10, label='Pregrasp')

    ax1.set_xlabel('X (m)')
    ax1.set_ylabel('Y (m)')
    ax1.set_zlabel('Z (m)')
    ax1.set_title(f'3D View - All {len(all_results)} Executed Trajectories')
    ax1.legend(fontsize=6, loc='upper left')

    # Plot 2: XZ view (side) - ALL executed trajectories
    ax2 = fig.add_subplot(232)
    for color, r in zip(colors, all_results):
        trace_c = r['traces']['trace_c']
        ax2.plot(trace_c[:, 0], trace_c[:, 2], '-', color=color, linewidth=1.5, alpha=0.8)

    ax2.scatter(r0["home_pos"][0], r0["home_pos"][2], c='black', marker='o', s=150, zorder=10)
    ax2.scatter(r0["pregrasp_pos"][0], r0["pregrasp_pos"][2], c='orange', marker='*', s=200, zorder=10)

    ax2.set_xlabel('X (m)')
    ax2.set_ylabel('Z (m)')
    ax2.set_title('XZ View (Side) - Executed Trajectories')
    ax2.grid(True, alpha=0.3)

    # Plot 3: XY view (top) - ALL executed trajectories
    ax3 = fig.add_subplot(233)
    for color, r in zip(colors, all_results):
        trace_c = r['traces']['trace_c']
        ax3.plot(trace_c[:, 0], trace_c[:, 1], '-', color=color, linewidth=1.5, alpha=0.8)

    ax3.scatter(r0["home_pos"][0], r0["home_pos"][1], c='black', marker='o', s=150, zorder=10)
    ax3.scatter(r0["pregrasp_pos"][0], r0["pregrasp_pos"][1], c='orange', marker='*', s=200, zorder=10)

    ax3.set_xlabel('X (m)')
    ax3.set_ylabel('Y (m)')
    ax3.set_title('XY View (Top) - Executed Trajectories')
    ax3.grid(True, alpha=0.3)

    # Plot 4: Deviation from desired (Trace A) over time for all CPs
    ax4 = fig.add_subplot(234)
    for color, r in zip(colors, all_results):
        trace_a = r['traces']['trace_a']
        trace_c = r['traces']['trace_c']
        error_c = np.linalg.norm(trace_c - trace_a, axis=1) * 1000  # mm
        ax4.plot(error_c, '-', color=color, linewidth=1.5, alpha=0.8, label=f'CP{r["cp_idx"]}')

    ax4.set_xlabel('Waypoint')
    ax4.set_ylabel('Deviation from Desired (mm)')
    ax4.set_title('Tracking Error: Executed vs Desired')
    ax4.legend(fontsize=6, ncol=2)
    ax4.grid(True, alpha=0.3)

    # Plot 5: Jitter comparison bar chart
    ax5 = fig.add_subplot(235)
    x = np.arange(len(all_results))
    width = 0.35

    ax5.bar(x - width/2, [r['jitter_b'] for r in all_results], width, label='B: IK only', color='green', alpha=0.7)
    ax5.bar(x + width/2, [r['jitter_c'] for r in all_results], width, label='C: Executed', color='red', alpha=0.7)

    ax5.set_xlabel('Control Point')
    ax5.set_ylabel('Jitter Metric')
    ax5.set_title('Jitter: IK-only vs Executed')
    ax5.set_xticks(x)
    ax5.set_xticklabels([f'{r["cp_idx"]}' for r in all_results], fontsize=8)
    ax5.legend(fontsize=8)

    # Plot 6: Step size over time for all executed trajectories
    ax6 = fig.add_subplot(236)
    for color, r in zip(colors, all_results):
        trace_c = r['traces']['trace_c']
        vel_mag = np.linalg.norm(np.diff(trace_c, axis=0), axis=1) * 1000  # mm/step
        ax6.plot(vel_mag, '-', color=color, linewidth=1, alpha=0.7)

    ax6.set_xlabel('Waypoint')
    ax6.set_ylabel('Step Size (mm)')
    ax6.set_title('Step Size Over Trajectory (all CPs)')
    ax6.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    print(f"\nPlot saved to: {out_path}")
    plt.show()


def main():
    """Run the full REACH-phase jitter diagnostic on ``meat_off_grill``.

    The pipeline has four steps:
    1. IK + collision prefilter of the canonical control points.
    2. A physics-rollout jitter filter.
    3. A/B/C trace comparison on up to 10 surviving control points spread
       evenly over the surviving set.
    4. A printed summary and a saved figure.

    The simulator is always shut down, including on early exit or exception.
    """
    print("=" * 60)
    print("JITTER DIAGNOSTIC: Comparing 3 Traces")
    print("=" * 60)

    task_name = "meat_off_grill"

    obs_config = create_obs_config(save_video=False)
    action_mode = create_action_mode("abs")  # MUST be "abs" for JointPosition control

    rlbench_env = Environment(
        action_mode=action_mode, obs_config=obs_config, headless=True
    )
    rlbench_env.launch()
    try:
        tasks = get_task_classes([task_name])
        task_env = rlbench_env.get_task(tasks[0])
        task_env.set_variation(0)

        canonical_params = generate_canonical_control_point_params(n_per_axis=5)
        print(f"\nTotal canonical control points: {len(canonical_params)}")

        valid_indices = _collision_prefilter(task_env, canonical_params)
        if len(valid_indices) == 0:
            print("  ERROR: No valid control points found!")
            return

        final_valid_indices = _jitter_filter(task_env, canonical_params, valid_indices)
        if len(final_valid_indices) == 0:
            print("  ERROR: No valid control points after jitter filter!")
            return

        # Test up to 10 CPs spread evenly across the final valid set.
        num_tests = min(10, len(final_valid_indices))
        picks = np.linspace(0, len(final_valid_indices) - 1, num_tests, dtype=int)
        test_cp_indices = [final_valid_indices[i] for i in picks]

        all_results = _run_trials(task_env, canonical_params, test_cp_indices)
        _print_summary(all_results)
        _plot_results(all_results)
    finally:
        # Guarantee the simulator is released even if any stage raises.
        rlbench_env.shutdown()

    print("\n" + "=" * 60)
    print("DIAGNOSTIC COMPLETE")
    print("=" * 60)


if __name__ == "__main__":
    main()
