"""
Visualize EE trajectories from the pick-and-place dataset.

4 subplots (all 3D):
1. Overall trajectory - full EE path
2. Reach phase only - start -> pregrasp (64 steps)
3. Carry phase only - lift -> prerelease (64 steps)
4. Relative trajectory - both reach and carry normalized to [0,1] along start-end axis

No simulation required - just reads the saved demo data.
"""

import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from absl import app
from absl import flags

# Global absl flag registry; values are only readable after app.run() has
# parsed the command line.
FLAGS = flags.FLAGS

# Default path matches dataset_generator_pick_place_cp.py output
# Use DPPO_DATA_DIR env var, or fall back to local data/ directory
# (the fallback is "<parent of this file's directory>/data"). The task and
# variation sub-directories "meat_off_grill/variation0" are always appended.
DEFAULT_DATA_PATH = os.path.join(
    os.environ.get(
        "DPPO_DATA_DIR",
        os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")
    ),
    "meat_off_grill", "variation0"
)

# NOTE(review): the help text below uses ".../stack_blocks/variation0" as its
# example while the default points at "meat_off_grill" - the example is only
# illustrative of the expected directory depth.
flags.DEFINE_string("data_path", DEFAULT_DATA_PATH,
    "Path to the dataset (e.g., .../stack_blocks/variation0).")
# When left unset, main() writes the image into the resolved variation
# directory.
flags.DEFINE_string("output_path", None,
    "Output path for the image. Defaults to data_path.")
# Selects the "<split>/episodes" directory and "<split>_metadata.npy" file.
flags.DEFINE_string("split", "train",
    "Which split to visualize: train or eval.")
# NOTE(review): main() only consults color_by_carry_cp; reach-CP colouring is
# simply what happens when that flag is False, so this flag's value is not
# read anywhere in this file.
flags.DEFINE_bool("color_by_reach_cp", True,
    "Color trajectories by reach CP index.")
flags.DEFINE_bool("color_by_carry_cp", False,
    "Color trajectories by carry CP index (overrides color_by_reach_cp).")

# Phase configuration (must match dataset_generator_pick_place_cp.py)
# Maps phase name -> number of timesteps spent in that phase. The insertion
# order is significant: get_phase_indices() walks the dict in order to turn
# these counts into consecutive [start, end) index ranges, so the phases below
# are listed in execution order (168 steps in total).
PHASE_STEPS = {
    "reach": 64,           # Phase 1: Bezier curve start -> pregrasp (LEARNED)
    "descend": 8,          # Phase 2: pregrasp -> grasp (linear Z)
    "grasp": 8,            # Phase 3: hold + close gripper
    "lift": 8,             # Phase 4: grasp -> lift (linear Z)
    "carry": 64,           # Phase 5: Bezier curve lift -> prerelease (LEARNED)
    "descend_release": 8,  # Phase 6: prerelease -> release (linear Z)
    "release": 8,          # Phase 7: hold + open gripper
}


def get_phase_indices():
    """Map every phase name to its half-open (start, end) timestep range.

    The ranges are laid out back to back, in the iteration order of
    PHASE_STEPS, beginning at timestep 0.

    Returns:
        dict: phase name -> (start, end) tuple, where ``end - start`` equals
        the phase's step count and each phase starts where the previous ended.
    """
    bounds = {}
    offset = 0
    for name in PHASE_STEPS:
        stop = offset + PHASE_STEPS[name]
        bounds[name] = (offset, stop)
        # The next phase begins exactly where this one stops.
        offset = stop
    return bounds


def load_demo_trajectory(episode_path):
    """Read the end-effector positions stored in an episode directory.

    Args:
        episode_path: directory expected to contain ``low_dim_obs.pkl``.

    Returns:
        An array with one row per usable observation holding the first three
        entries of that observation's ``gripper_pose`` (presumably the xyz
        position - confirm against the dataset generator), or ``None`` when
        the pickle file is missing or no observation carries a pose.
    """
    pkl_path = os.path.join(episode_path, "low_dim_obs.pkl")
    if not os.path.exists(pkl_path):
        return None

    # NOTE(review): pickle.load executes arbitrary code embedded in the file;
    # only point this script at datasets you generated or otherwise trust.
    with open(pkl_path, "rb") as f:
        demo = pickle.load(f)

    # Observations without a gripper_pose attribute, or with it set to None,
    # are skipped rather than treated as errors.
    positions = [
        obs.gripper_pose[:3]
        for obs in demo
        if getattr(obs, 'gripper_pose', None) is not None
    ]

    return np.array(positions) if positions else None


def build_local_frame(start_pos, end_pos):
    """
    Construct an orthonormal basis aligned with the start -> end segment.

    Per the original note this is meant to match the frame built in utils.py
    (not verifiable from this file).

    Args:
        start_pos: array with the segment's start point.
        end_pos: array with the segment's end point. Must differ from
            start_pos; a zero-length segment divides by zero and yields NaNs.

    Returns:
        line_vec_norm: unit vector along trajectory
        perp1: unit vector perpendicular to the trajectory, obtained by
            projecting world +Z (or world +Y when the trajectory is nearly
            vertical) onto the plane normal to it
        perp2: unit vector completing the frame, cross(line_vec_norm, perp1)
    """
    axis = end_pos - start_pos
    axis = axis / np.linalg.norm(axis)

    # Reference directions tried in order: world up first, world +Y as the
    # fallback when the segment is (nearly) parallel to world up.
    references = (np.array([0.0, 0.0, 1.0]), np.array([0.0, 1.0, 0.0]))
    for reference in references:
        # Remove the component along the segment, leaving the part of the
        # reference that lies in the plane perpendicular to it.
        perp1 = reference - np.dot(reference, axis) * axis
        perp1_len = np.linalg.norm(perp1)
        if not perp1_len < 1e-6:
            break

    perp1 = perp1 / perp1_len

    # Third axis: orthogonal to both the segment direction and perp1.
    perp2 = np.cross(axis, perp1)
    perp2 = perp2 / np.linalg.norm(perp2)

    return axis, perp1, perp2


def normalize_trajectory(traj, start_pos, end_pos):
    """
    Normalize trajectory relative to start and end positions.

    Points are expressed in the frame returned by build_local_frame():
    - column 0 (progress): projection along start->end, 0 at start, 1 at end
    - column 1: offset along perp1 (vertical plane), divided by path length
    - column 2: offset along perp2 (horizontal), divided by path length

    Dividing by the start->end distance makes segments of different length
    and orientation directly comparable, so REACH and CARRY curves built from
    the same CP params are expected to overlap in this coordinate system.

    Args:
        traj: (N, 3) array-like of points.
        start_pos: array-like with the 3 coordinates of the segment start.
        end_pos: array-like with the 3 coordinates of the segment end.

    Returns:
        Float array with the same shape as ``traj``. When start and end are
        closer than 1e-6 no direction is defined, and the trajectory is only
        translated so that ``start_pos`` becomes the origin.
    """
    # Work in floating point: the result is fractional even when the inputs
    # happen to be integer arrays or plain Python lists.
    traj = np.asarray(traj, dtype=float)
    start_pos = np.asarray(start_pos, dtype=float)
    end_pos = np.asarray(end_pos, dtype=float)

    # Direction from start to end
    direction = end_pos - start_pos
    dist = np.linalg.norm(direction)
    if dist < 1e-6:
        return traj - start_pos  # Fallback: just translate

    if traj.size == 0:
        # Nothing to project; keep the (empty) input shape.
        return np.zeros_like(traj)

    # Build trajectory-relative frame
    line_dir, perp1, perp2 = build_local_frame(start_pos, end_pos)

    # Progress along the line for every point (0 at start, 1 at end).
    progress = (traj - start_pos) @ line_dir / dist

    # Perpendicular deviation of each point from the straight start->end line.
    offset = traj - (start_pos + progress[:, None] * direction)

    return np.stack(
        [progress, offset @ perp1 / dist, offset @ perp2 / dist],
        axis=1,
    )


# Scatter style (colour, marker) used for each named waypoint.
_WAYPOINT_MARKERS = {
    'home': ('green', 'o'),
    'pregrasp': ('orange', '*'),
    'lift': ('purple', '^'),
    'prerelease': ('blue', 's'),
}


def _episode_sort_key(name):
    """Sort key that orders directory names by their trailing number."""
    prefix = name.rstrip("0123456789")
    digits = name[len(prefix):]
    return (prefix, int(digits)) if digits else (name, -1)


def _line_style(success):
    """Return (alpha, linewidth); successful demos are drawn bolder."""
    return (0.7, 1.2) if success else (0.3, 0.6)


def _plot_segments(ax, segments, segment_colors, successes,
                   linestyle='-', alpha_scale=1.0):
    """Draw each non-empty (N, 3) segment as a 3D line on ``ax``."""
    for segment, color, success in zip(segments, segment_colors, successes):
        if segment is None or len(segment) == 0:
            continue
        alpha, linewidth = _line_style(success)
        ax.plot(segment[:, 0], segment[:, 1], segment[:, 2],
                linestyle, color=color, alpha=alpha * alpha_scale,
                linewidth=linewidth)


def _mark_waypoints(ax, waypoints, labels):
    """Scatter the waypoints named in ``labels`` (name -> legend label)."""
    if waypoints is None:
        return
    for name, label in labels.items():
        color, marker = _WAYPOINT_MARKERS[name]
        ax.scatter(*waypoints[name], c=color, marker=marker, s=150,
                   label=label, zorder=10)


def _finish_axes(ax, title, axis_labels=('X (m)', 'Y (m)', 'Z (m)')):
    """Apply axis labels, title and the shared legend style to ``ax``."""
    ax.set_xlabel(axis_labels[0])
    ax.set_ylabel(axis_labels[1])
    ax.set_zlabel(axis_labels[2])
    ax.set_title(title)
    ax.legend(loc='upper left', fontsize=8)


def _normalized_for_plot(segment, start_pos, end_pos):
    """Normalize a phase segment and order its columns for the 4th subplot.

    normalize_trajectory() returns (progress, perp1, perp2) where perp1 is the
    vertical-plane offset and perp2 the horizontal one. The subplot labels its
    Y axis "Lateral offset" and its Z axis "Vertical offset", so the columns
    are reordered to (progress, perp2, perp1). Returns None for an empty
    segment.
    """
    if len(segment) == 0:
        return None
    return normalize_trajectory(segment, start_pos, end_pos)[:, [0, 2, 1]]


def main(argv):
    """Load one dataset split, plot its EE trajectories and save a PNG.

    Reads FLAGS.data_path / FLAGS.split, loads every episode's end-effector
    path plus the optional ``<split>_metadata.npy`` (CP indices, success flag,
    waypoints), renders a 2x2 grid of 3D plots (overall path, reach phase,
    carry phase, normalized reach+carry) and writes
    ``<split>_ee_trajectories.png`` to FLAGS.output_path, or to the variation
    directory when that flag is unset. Progress and summary statistics are
    printed to stdout.

    Args:
        argv: positional command-line arguments forwarded by absl; unused.

    Returns:
        None. Problems (missing data, no trajectories) are reported on stdout
        and end the run early.
    """
    # Find variation path: accept either the variation directory itself or a
    # dataset root that contains meat_off_grill/variation0.
    nested_path = os.path.join(FLAGS.data_path, "meat_off_grill", "variation0")
    if os.path.exists(os.path.join(FLAGS.data_path, "train")):
        variation_path = FLAGS.data_path
    elif os.path.exists(nested_path):
        variation_path = nested_path
    else:
        print(f"Error: Cannot find data at {FLAGS.data_path}")
        print("Make sure dataset_generator_pick_place_cp.py has been run first.")
        return

    print(f"Loading data from: {variation_path}")

    # Load metadata
    split_root = os.path.join(variation_path, FLAGS.split)
    episodes_path = os.path.join(split_root, "episodes")
    metadata_path = os.path.join(split_root, f"{FLAGS.split}_metadata.npy")

    if not os.path.exists(episodes_path):
        print(f"Error: Episodes path not found: {episodes_path}")
        return

    metadata = None
    if os.path.exists(metadata_path):
        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only
        # load metadata files from a trusted source.
        metadata = np.load(metadata_path, allow_pickle=True)
        print(f"Total {FLAGS.split} demos in metadata: {len(metadata)}")
    else:
        print(f"Warning: {FLAGS.split}_metadata.npy not found")

    # Get episode directories. They are sorted by numeric suffix because a
    # plain string sort puts "episode10" before "episode2", which would pair
    # trajectories with the wrong metadata row below.
    episode_dirs = sorted(
        (d for d in os.listdir(episodes_path) if d.startswith("episode")),
        key=_episode_sort_key,
    )
    print(f"Episode directories found: {len(episode_dirs)}")

    # Get phase indices
    phase_indices = get_phase_indices()
    reach_start, reach_end = phase_indices["reach"]
    carry_start, carry_end = phase_indices["carry"]

    print("\nPhase indices:")
    print(f"  REACH: [{reach_start}, {reach_end})")
    print(f"  CARRY: [{carry_start}, {carry_end})")

    # Load all trajectories
    print("\nLoading trajectories...")
    trajectories = []
    reach_cp_indices = []
    carry_cp_indices = []
    task_successes = []
    waypoints_list = []

    for i, ep_dir in enumerate(episode_dirs):
        traj = load_demo_trajectory(os.path.join(episodes_path, ep_dir))

        if traj is not None and len(traj) > 0:
            trajectories.append(traj)

            # presumably metadata row i describes the i-th episode directory;
            # verify against the dataset generator.
            if metadata is not None and i < len(metadata):
                meta = metadata[i]
                # Support both old format (reach_cp_index/carry_cp_index) and new format (cp_index)
                cp_idx = meta.get('cp_index', meta.get('reach_cp_index', 0))
                reach_cp_indices.append(cp_idx)
                carry_cp_indices.append(meta.get('carry_cp_index', cp_idx))
                task_successes.append(meta.get('task_success', False))

                # Get waypoints for normalization
                waypoints_list.append({
                    name: np.array(meta.get(f'{name}_pos', [0, 0, 0]))
                    for name in ('home', 'pregrasp', 'lift', 'prerelease')
                })
            else:
                reach_cp_indices.append(0)
                carry_cp_indices.append(0)
                task_successes.append(False)
                waypoints_list.append(None)

        if (i + 1) % 50 == 0:
            print(f"  Loaded {i + 1}/{len(episode_dirs)} episodes...")

    print(f"Successfully loaded {len(trajectories)} trajectories")

    if len(trajectories) == 0:
        print("Error: No trajectories loaded!")
        return

    # Per-demo slices of the two learned phases (empty when a demo is shorter
    # than the phase's start index).
    reach_segments = [traj[reach_start:reach_end] for traj in trajectories]
    carry_segments = [traj[carry_start:carry_end] for traj in trajectories]

    # DEBUG: Print normalized offsets at midpoint for first few demos
    print("\n=== DEBUG: Normalized offsets at trajectory midpoint (trajectory-relative frame) ===")
    for i, wp in enumerate(waypoints_list[:3]):
        reach_traj = reach_segments[i]
        carry_traj = carry_segments[i]
        # Skip demos without waypoints or too short to contain both phases;
        # indexing the midpoint of an empty slice would raise IndexError.
        if wp is None or len(reach_traj) == 0 or len(carry_traj) == 0:
            continue

        # Normalize both trajectories using trajectory-relative frames
        reach_norm = normalize_trajectory(reach_traj, wp['home'], wp['pregrasp'])
        carry_norm = normalize_trajectory(carry_traj, wp['lift'], wp['prerelease'])

        # Get midpoint values
        reach_mid_idx = len(reach_norm) // 2
        carry_mid_idx = len(carry_norm) // 2

        reach_len = np.linalg.norm(wp['pregrasp'] - wp['home'])
        carry_len = np.linalg.norm(wp['prerelease'] - wp['lift'])

        print(f"Demo {i} (reach_cp={reach_cp_indices[i]}, carry_cp={carry_cp_indices[i]}):")
        print(f"  REACH: path_len={reach_len:.4f}")
        print(f"    normalized midpoint: t={reach_norm[reach_mid_idx, 0]:.3f}, perp1={reach_norm[reach_mid_idx, 1]:.6f}, perp2={reach_norm[reach_mid_idx, 2]:.6f}")
        print(f"  CARRY: path_len={carry_len:.4f}")
        print(f"    normalized midpoint: t={carry_norm[carry_mid_idx, 0]:.3f}, perp1={carry_norm[carry_mid_idx, 1]:.6f}, perp2={carry_norm[carry_mid_idx, 2]:.6f}")
    print("=" * 60)

    # Determine coloring
    if FLAGS.color_by_carry_cp:
        color_indices = carry_cp_indices
        color_label = "Carry CP"
    else:
        color_indices = reach_cp_indices
        color_label = "Reach CP"

    unique_indices = sorted(set(color_indices))
    n_colors = len(unique_indices)
    n_success = sum(task_successes)
    print(f"\nColoring by: {color_label}")
    print(f"Unique {color_label}s: {n_colors}")
    print(f"Task successes: {n_success}/{len(trajectories)}")

    # Create color map: qualitative palettes while they have enough distinct
    # entries, a continuous map beyond that.
    if n_colors <= 10:
        colors = plt.cm.tab10(np.linspace(0, 1, 10))
    elif n_colors <= 20:
        colors = plt.cm.tab20(np.linspace(0, 1, 20))
    else:
        colors = plt.cm.viridis(np.linspace(0, 1, n_colors))

    idx_to_color = {idx: colors[i % len(colors)] for i, idx in enumerate(unique_indices)}
    traj_colors = [idx_to_color.get(cidx, colors[0]) for cidx in color_indices]

    # Create 4-subplot figure (2x2 grid, all 3D)
    fig = plt.figure(figsize=(16, 14))

    ax1 = fig.add_subplot(221, projection='3d')  # Overall trajectory
    ax2 = fig.add_subplot(222, projection='3d')  # Reach phase
    ax3 = fig.add_subplot(223, projection='3d')  # Carry phase
    ax4 = fig.add_subplot(224, projection='3d')  # Relative trajectory

    # Get reference waypoints from first valid metadata
    ref_waypoints = next((wp for wp in waypoints_list if wp is not None), None)

    # ========== Plot 1: Overall Trajectory ==========
    _plot_segments(ax1, trajectories, traj_colors, task_successes)
    _mark_waypoints(ax1, ref_waypoints, {
        'home': 'Home',
        'pregrasp': 'Pregrasp',
        'lift': 'Lift',
        'prerelease': 'Prerelease',
    })
    _finish_axes(ax1, f'Overall Trajectory\n({len(trajectories)} demos, {n_success} success)')

    # ========== Plot 2: Reach Phase Only ==========
    _plot_segments(ax2, reach_segments, traj_colors, task_successes)
    _mark_waypoints(ax2, ref_waypoints, {
        'home': 'Home (start)',
        'pregrasp': 'Pregrasp (end)',
    })
    _finish_axes(ax2, f'REACH Phase\n(Home -> Pregrasp, {reach_end - reach_start} steps)')

    # ========== Plot 3: Carry Phase Only ==========
    _plot_segments(ax3, carry_segments, traj_colors, task_successes)
    _mark_waypoints(ax3, ref_waypoints, {
        'lift': 'Lift (start)',
        'prerelease': 'Prerelease (end)',
    })
    _finish_axes(ax3, f'CARRY Phase\n(Lift -> Prerelease, {carry_end - carry_start} steps)')

    # ========== Plot 4: Relative Trajectory (Normalized) ==========
    # Normalize both reach and carry trajectories to same coordinate frame
    # X: 0 at start, 1 at end
    # Y: lateral (perp2) offset, Z: vertical (perp1) offset, both normalized
    # by path length. Demos without waypoints cannot be normalized and are
    # left out of this subplot.
    reach_normalized = [
        None if wp is None
        else _normalized_for_plot(segment, wp['home'], wp['pregrasp'])
        for wp, segment in zip(waypoints_list, reach_segments)
    ]
    carry_normalized = [
        None if wp is None
        else _normalized_for_plot(segment, wp['lift'], wp['prerelease'])
        for wp, segment in zip(waypoints_list, carry_segments)
    ]
    _plot_segments(ax4, reach_normalized, traj_colors, task_successes)
    # Carry uses the same frame and colour, but dashed and slightly fainter.
    _plot_segments(ax4, carry_normalized, traj_colors, task_successes,
                   linestyle='--', alpha_scale=0.8)

    # Mark normalized start (0,0,0) and end (1,0,0)
    ax4.scatter(0, 0, 0, c='green', marker='o', s=150, label='Start (0,0,0)', zorder=10)
    ax4.scatter(1, 0, 0, c='red', marker='*', s=150, label='End (1,0,0)', zorder=10)

    # Draw reference straight line
    ax4.plot([0, 1], [0, 0], [0, 0], 'k--', alpha=0.5, linewidth=2, label='Straight path')

    _finish_axes(
        ax4,
        'Relative Trajectory (Normalized)\nSolid=REACH, Dashed=CARRY',
        axis_labels=('Progress (0→1)', 'Lateral offset', 'Vertical offset'),
    )

    # Pad the progress axis slightly beyond the [0, 1] range
    ax4.set_xlim(-0.1, 1.1)

    plt.tight_layout()

    # Save figure
    output_path = FLAGS.output_path if FLAGS.output_path else variation_path
    # A user-supplied output directory may not exist yet.
    os.makedirs(output_path, exist_ok=True)
    output_file = os.path.join(output_path, f"{FLAGS.split}_ee_trajectories.png")
    plt.savefig(output_file, dpi=150, bbox_inches='tight')
    print(f"\nSaved visualization to: {output_file}")

    # Print statistics
    print(f"\n{'='*60}")
    print("Dataset Statistics:")
    print(f"{'='*60}")
    print(f"  Split: {FLAGS.split}")
    print(f"  Total trajectories: {len(trajectories)}")
    print(f"  Unique reach CPs: {len(set(reach_cp_indices))}")
    print(f"  Unique carry CPs: {len(set(carry_cp_indices))}")
    print(f"  Task successes: {n_success}/{len(trajectories)}")

    traj_lengths = [len(t) for t in trajectories]
    print(f"  Trajectory length: {np.mean(traj_lengths):.1f} ± {np.std(traj_lengths):.1f} steps")

    if ref_waypoints is not None:
        print("\n  Waypoints (from first demo):")
        for name, pos in ref_waypoints.items():
            print(f"    {name}: [{pos[0]:.4f}, {pos[1]:.4f}, {pos[2]:.4f}]")

    plt.close(fig)


# Script entry point: hand main() to absl's app.run, which runs it once the
# command-line flags defined above have been processed.
if __name__ == "__main__":
    app.run(main)
