#!/usr/bin/env python
"""
Visualize EE trajectories from saved dataset with wall collision detection.

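Loads saved trajectory data (same as dataset_visualization.py), keeps only the
anchor (noise-free) demos, and creates two separate 3D plots for the REACH and
CARRY phases, both drawn against the SAME wall configuration.
Trajectories are colored: Blue = no collision, Red = collision (drawn only up
to the reported collision step, which is marked with an X).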

No simulation required - just reads saved demo data.

Usage:
  python visualize_dataset_wall_collision.py

  # Custom wall position
  python visualize_dataset_wall_collision.py --wall_y=0.0

  # Use predefined style
  python visualize_dataset_wall_collision.py --style=1
"""

import os
import pickle
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

from absl import app, flags

from wall_collision import (
    DEFAULT_WALL_CONFIG,
    WALL_STYLES,
    check_ee_trajectory_wall_collision,
)

# Global absl flag registry; the flags below are registered at import time.
FLAGS = flags.FLAGS

# Default path matches dataset_generator_pick_place_cp.py output.
# Resolves to $DPPO_DATA_DIR/stack_blocks/variation0, or, when DPPO_DATA_DIR is
# unset, to <parent of this script's directory>/data/stack_blocks/variation0.
DEFAULT_DATA_PATH = os.path.join(
    os.environ.get(
        "DPPO_DATA_DIR",
        os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")
    ),
    "stack_blocks", "variation0"
)

# Output path for visualizations: <parent of this script's directory>/block_setting.
# main() falls back to the resolved variation path if --output_path is empty.
DEFAULT_OUTPUT_PATH = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),  # RLBench_pick_place/
    "block_setting"
)

flags.DEFINE_string("data_path", DEFAULT_DATA_PATH,
    "Path to the dataset (e.g., .../stack_blocks/variation0).")
flags.DEFINE_string("output_path", DEFAULT_OUTPUT_PATH,
    "Output path for images. Defaults to RLBench_pick_place/block_setting.")
flags.DEFINE_string("split", "train",
    "Which split to visualize: train or eval.")

# Wall style (predefined configs: 1, 2, 3). main() uses WALL_STYLES[style] when
# the value is a key of WALL_STYLES; any other value (e.g. the default 0) falls
# back to the individual wall_* flags below.
flags.DEFINE_integer("style", 0, "Predefined wall style (1, 2, or 3). 0 = use individual flags.")

# Wall configuration (used when --style is not a WALL_STYLES key, e.g. style=0).
# Defaults come from wall_collision.DEFAULT_WALL_CONFIG.
flags.DEFINE_float("wall_y", DEFAULT_WALL_CONFIG["wall_y"], "Wall Y position")
flags.DEFINE_float("wall_min_x", DEFAULT_WALL_CONFIG["wall_min_x"], "Wall min X")
flags.DEFINE_float("wall_max_x", DEFAULT_WALL_CONFIG["wall_max_x"], "Wall max X")
flags.DEFINE_float("wall_min_z", DEFAULT_WALL_CONFIG["wall_min_z"], "Wall min Z")
flags.DEFINE_float("wall_max_z", DEFAULT_WALL_CONFIG["wall_max_z"], "Wall max Z")

# Phase configuration (must match dataset_generator_pick_place_cp.py).
# Number of recorded steps per phase, listed in execution order: get_phase_indices()
# relies on this dict's insertion order to compute consecutive [start, end) ranges.
# NOTE(review): assumes every demo records exactly these step counts with one
# observation per step in low_dim_obs.pkl -- confirm against the generator.
PHASE_STEPS = {
    "reach": 64,
    "descend": 8,
    "grasp": 8,
    "lift": 8,
    "carry": 64,
    "descend_release": 8,
    "release": 8,
}


def get_phase_indices(phase_steps=None):
    """Return the half-open ``[start, end)`` step range of each phase.

    Phases are laid out back-to-back in the iteration order of *phase_steps*,
    so each phase begins exactly where the previous one ended.

    Args:
        phase_steps: Optional mapping of phase name -> number of steps, in
            execution order. Defaults to the module-level ``PHASE_STEPS``.

    Returns:
        dict mapping each phase name to a ``(start, end)`` tuple.
    """
    if phase_steps is None:
        phase_steps = PHASE_STEPS
    indices = {}
    offset = 0
    for phase, steps in phase_steps.items():
        indices[phase] = (offset, offset + steps)
        offset += steps
    return indices


def load_demo_trajectory(episode_path):
    """Read end-effector XYZ positions from ``<episode_path>/low_dim_obs.pkl``.

    Every stored observation that exposes a non-None ``gripper_pose`` contributes
    its first three components; other observations are skipped.

    Returns:
        A numpy array with one row per kept observation, or None when the pickle
        file is missing or no observation carried a gripper pose.
    """
    pkl_file = os.path.join(episode_path, "low_dim_obs.pkl")
    if not os.path.exists(pkl_file):
        return None

    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # datasets you generated yourself / trust.
    with open(pkl_file, "rb") as handle:
        observations = pickle.load(handle)

    xyz_rows = [
        ob.gripper_pose[:3]
        for ob in observations
        if hasattr(ob, 'gripper_pose') and ob.gripper_pose is not None
    ]
    return np.array(xyz_rows) if xyz_rows else None


def draw_wall_3d(ax, wall_config):
    """Add the wall to a 3D axes as a translucent red rectangle.

    The rectangle lies in the plane ``y = wall_config["wall_y"]`` and spans the
    configured X and Z bounds.

    NOTE(review): any ``"opening"`` entry in *wall_config* is not drawn here;
    confirm whether predefined styles expect a gap to be rendered.
    """
    y = wall_config["wall_y"]
    x_lo, x_hi = wall_config["wall_min_x"], wall_config["wall_max_x"]
    z_lo, z_hi = wall_config["wall_min_z"], wall_config["wall_max_z"]

    # Walk the rectangle's corners in order around its X-Z outline at fixed Y.
    outline = ((x_lo, z_lo), (x_hi, z_lo), (x_hi, z_hi), (x_lo, z_hi))
    corners = [[x, y, z] for x, z in outline]

    ax.add_collection3d(
        Poly3DCollection([corners], alpha=0.3, facecolor='red',
                         edgecolor='darkred', linewidth=2)
    )


def _build_wall_config():
    """Return the wall config chosen by --style, or built from the --wall_* flags.

    A --style value that is a key of WALL_STYLES selects a copy of that preset,
    so the shared preset dict is never mutated. Any other value falls back to the
    individual flags; a non-zero unknown style is reported rather than ignored.
    """
    if FLAGS.style in WALL_STYLES:
        wall_config = WALL_STYLES[FLAGS.style].copy()
        print(f"Using predefined wall style {FLAGS.style}")
        return wall_config
    if FLAGS.style != 0:
        print(f"Warning: unknown wall style {FLAGS.style}; "
              "using individual --wall_* flags instead.")
    return {
        "wall_y": FLAGS.wall_y,
        "wall_min_x": FLAGS.wall_min_x,
        "wall_max_x": FLAGS.wall_max_x,
        "wall_min_z": FLAGS.wall_min_z,
        "wall_max_z": FLAGS.wall_max_z,
        "opening": None,
    }


def _resolve_variation_path(data_path):
    """Return the variation directory for *data_path*, or None if not found.

    Accepts either the variation directory itself (it contains ``train/``) or a
    dataset root that contains ``stack_blocks/variation0``.
    """
    if os.path.exists(os.path.join(data_path, "train")):
        return data_path
    nested = os.path.join(data_path, "stack_blocks", "variation0")
    if os.path.exists(nested):
        return nested
    return None


def _episode_sort_key(name):
    """Sort key putting 'episode2' before 'episode10' (numeric, not lexicographic).

    Metadata entries are matched to episodes by position, so the episode list
    must follow numeric order. Names without a pure numeric suffix sort last.
    """
    suffix = name[len("episode"):]
    if suffix.isdigit():
        return (0, int(suffix), name)
    return (1, 0, name)


def _load_anchor_trajectories(episodes_path, episode_dirs, metadata):
    """Load EE trajectories for anchor demos (demo_in_mode == 0, no noise).

    Args:
        episodes_path: directory containing the episode folders.
        episode_dirs: episode folder names, already in metadata order.
        metadata: per-episode metadata sequence aligned with *episode_dirs*, or
            None. Episodes without a metadata entry are kept unconditionally.

    Returns:
        (trajectories, waypoints_list): parallel lists. Each waypoint entry is a
        dict of numpy arrays, or None when the episode had no metadata entry.
    """
    trajectories = []
    waypoints_list = []
    for i, ep_dir in enumerate(episode_dirs):
        meta = metadata[i] if metadata is not None and i < len(metadata) else None

        # Skip noisy trajectories, only keep anchors.
        if meta is not None and (meta.get('demo_in_mode', 0) != 0
                                 or meta.get('with_noise', False)):
            continue

        traj = load_demo_trajectory(os.path.join(episodes_path, ep_dir))
        if traj is None or len(traj) == 0:
            continue

        trajectories.append(traj)
        if meta is None:
            waypoints_list.append(None)
        else:
            waypoints_list.append({
                key: np.array(meta.get(f'{key}_pos', [0, 0, 0]))
                for key in ('home', 'pregrasp', 'lift', 'prerelease')
            })
    return trajectories, waypoints_list


def _plot_phase(trajectories, start, end, wall_config, markers, phase_name,
                fig_path):
    """Plot the [start, end) slice of every trajectory against the wall and save it.

    Each slice is checked with check_ee_trajectory_wall_collision. Colliding
    slices are drawn red up to the reported collision step, which is marked with
    an X. Collision-free slices are drawn blue in full. Trajectories whose slice
    is empty (too short to reach this phase) are skipped.

    Args:
        trajectories: list of per-demo EE position arrays indexed [step, xyz].
        start, end: half-open step range of the phase.
        wall_config: wall dict shared by the collision check and draw_wall_3d.
        markers: list of (xyz, scatter_kwargs) waypoint markers; may be empty.
        phase_name: label used in the title, e.g. 'REACH'.
        fig_path: destination PNG path.

    Returns:
        (n_success, n_collision) over the trajectories that were not skipped.
    """
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    n_success = 0
    n_collision = 0

    for traj in trajectories:
        segment = traj[start:end]
        if len(segment) == 0:
            continue

        collision, collision_idx = check_ee_trajectory_wall_collision(
            segment, wall_config)
        if collision:
            n_collision += 1
            color = 'red'
            # A collision may be reported without an index; plot the whole
            # segment then instead of slicing with None.
            plot_traj = (segment if collision_idx is None
                         else segment[:collision_idx + 1])
        else:
            n_success += 1
            color = 'blue'
            plot_traj = segment

        ax.plot(plot_traj[:, 0], plot_traj[:, 1], plot_traj[:, 2],
                '-', color=color, alpha=0.7, linewidth=1.5)

        # Mark collision point with X.
        if collision and collision_idx is not None:
            ax.scatter(segment[collision_idx, 0], segment[collision_idx, 1],
                       segment[collision_idx, 2],
                       c='red', marker='x', s=100, zorder=15)

    draw_wall_3d(ax, wall_config)

    for xyz, scatter_kwargs in markers:
        ax.scatter(*xyz, **scatter_kwargs)

    ax.set_xlabel('X (m)')
    ax.set_ylabel('Y (m)')
    ax.set_zlabel('Z (m)')
    ax.set_title(f'{phase_name} Phase - {len(trajectories)} Trajectories\n'
                 f'Success: {n_success} (blue) | Collision: {n_collision} (red)\n'
                 f'Wall Y={wall_config["wall_y"]:.3f}')
    if markers:
        # Only labeled artists are the waypoint markers; skip an empty legend.
        ax.legend(loc='upper left', fontsize=8)
    ax.view_init(elev=25, azim=-60)

    fig.tight_layout()
    fig.savefig(fig_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    return n_success, n_collision


def main(argv):
    """Plot REACH and CARRY EE trajectories of anchor demos against one wall.

    Returns:
        None on success, or 1 when the dataset or its trajectories cannot be
        found (absl's app.run passes main's return value to sys.exit).
    """
    del argv  # Unused.

    print("=" * 70)
    print("DATASET WALL COLLISION VISUALIZATION")
    print("=" * 70)

    wall_config = _build_wall_config()
    print("\nWall Configuration:")
    print(f"  Y position: {wall_config['wall_y']:.3f}")
    print(f"  X bounds: [{wall_config['wall_min_x']:.2f}, {wall_config['wall_max_x']:.2f}]")
    print(f"  Z bounds: [{wall_config['wall_min_z']:.2f}, {wall_config['wall_max_z']:.2f}]")

    variation_path = _resolve_variation_path(FLAGS.data_path)
    if variation_path is None:
        print(f"\nError: Cannot find data at {FLAGS.data_path}")
        print("Make sure dataset_generator_pick_place_cp.py has been run first.")
        return 1
    print(f"\nLoading data from: {variation_path}")

    split_root = os.path.join(variation_path, FLAGS.split)
    episodes_path = os.path.join(split_root, "episodes")
    metadata_path = os.path.join(split_root, f"{FLAGS.split}_metadata.npy")

    if not os.path.exists(episodes_path):
        print(f"Error: Episodes path not found: {episodes_path}")
        return 1

    metadata = None
    if os.path.exists(metadata_path):
        # NOTE: allow_pickle=True deserializes arbitrary objects; trusted data only.
        metadata = np.load(metadata_path, allow_pickle=True)
        print(f"Total {FLAGS.split} demos in metadata: {len(metadata)}")

    # Numeric ordering keeps episode folders aligned with metadata positions.
    episode_dirs = sorted(
        (d for d in os.listdir(episodes_path) if d.startswith("episode")),
        key=_episode_sort_key,
    )
    print(f"Episode directories found: {len(episode_dirs)}")

    phase_indices = get_phase_indices()
    reach_start, reach_end = phase_indices["reach"]
    carry_start, carry_end = phase_indices["carry"]
    print("\nPhase indices:")
    print(f"  REACH: [{reach_start}, {reach_end})")
    print(f"  CARRY: [{carry_start}, {carry_end})")

    print("\nLoading anchor trajectories (no noise)...")
    trajectories, waypoints_list = _load_anchor_trajectories(
        episodes_path, episode_dirs, metadata)
    print(f"Successfully loaded {len(trajectories)} anchor trajectories (8 modes)")

    if not trajectories:
        print("Error: No trajectories loaded!")
        return 1

    output_path = FLAGS.output_path if FLAGS.output_path else variation_path
    os.makedirs(output_path, exist_ok=True)

    # Waypoints of the first anchor that has metadata serve as plot references.
    ref_wp = next((wp for wp in waypoints_list if wp is not None), None)
    reach_markers = []
    carry_markers = []
    if ref_wp is not None:
        reach_markers = [
            (ref_wp['home'], dict(c='green', marker='o', s=150,
                                  label='Home (start)', zorder=10)),
            (ref_wp['pregrasp'], dict(c='orange', marker='*', s=150,
                                      label='Pregrasp (end)', zorder=10)),
        ]
        carry_markers = [
            (ref_wp['lift'], dict(c='purple', marker='^', s=150,
                                  label='Lift (start)', zorder=10)),
            (ref_wp['prerelease'], dict(c='blue', marker='s', s=150,
                                        label='Prerelease (end)', zorder=10)),
        ]

    reach_fig_path = os.path.join(output_path, f'{FLAGS.split}_reach_wall_collision.png')
    reach_success, reach_collision = _plot_phase(
        trajectories, reach_start, reach_end, wall_config, reach_markers,
        'REACH', reach_fig_path)
    print(f"\nSaved REACH visualization to: {reach_fig_path}")

    # Same wall config for CARRY so both phases are directly comparable.
    carry_fig_path = os.path.join(output_path, f'{FLAGS.split}_carry_wall_collision.png')
    carry_success, carry_collision = _plot_phase(
        trajectories, carry_start, carry_end, wall_config, carry_markers,
        'CARRY', carry_fig_path)
    print(f"Saved CARRY visualization to: {carry_fig_path}")

    total = len(trajectories)
    print(f"\n{'='*70}")
    print("SUMMARY")
    print(f"{'='*70}")
    print(f"Total trajectories: {total}")
    for label, n_ok, n_hit in (("REACH", reach_success, reach_collision),
                               ("CARRY", carry_success, carry_collision)):
        print(f"\n{label} phase:")
        print(f"  Success (no collision): {n_ok} ({100*n_ok/total:.1f}%)")
        print(f"  Wall collision: {n_hit} ({100*n_hit/total:.1f}%)")
        n_skipped = total - n_ok - n_hit
        if n_skipped:
            print(f"  Skipped (trajectory too short): {n_skipped}")

    print("\nDone!")


if __name__ == "__main__":
    # absl parses the command-line flags defined above, then invokes main(argv).
    app.run(main)
