#!/usr/bin/env python
"""
Visualize EE trajectories with trajectory-relative wall configuration.

The wall is defined perpendicular to the trajectory direction, using the same
coordinate system as control points (angle, distance in units of radius * path_length).

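For each wall style, saves (under --output_path):
- {split}_reach_traj_wall.png / {split}_carry_traj_wall.png: Phase visualizations at root
- wall_config.npy: Wall configuration at root
- style{id}/summary.npy: Summary with successful modes
- style{id}/episodes/: Directory with copies of the successful episode data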

Usage:
  python visualize_trajectory_wall.py

  # Custom wall configuration
  python visualize_trajectory_wall.py --wall_pos=0.5 --wall_corner_angle=45 --wall_corner_dist=1.0 --wall_width=2.0 --wall_height=2.0

  # Use predefined style
  python visualize_trajectory_wall.py --style=1
"""

import os
import pickle
import shutil
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from absl import app, flags

from wall_config import (
    DEFAULT_WALL_CONFIG,
    WALL_STYLES,
    build_local_frame,
    compute_wall_corners,
    compute_opening_corners,
    check_trajectory_wall_collision,
    draw_wall_3d,
    print_wall_info,
)


# ============================================================================
# Synthetic trajectory generation (same as dataset generator)
# ============================================================================

def parabola3D(S, E, C, t):
    """Evaluate a 3D quadratic Bezier curve from S to E at parameter t.

    ``C`` is not the raw Bezier control point: the actual control point is
    reflected through the chord midpoint (``2*C - (S+E)/2``), which makes the
    curve pass exactly through ``C`` at ``t = 0.5``.

    Args:
        S: start point (t = 0).
        E: end point (t = 1).
        C: point the curve passes through at t = 0.5.
        t: curve parameter, normally in [0, 1].

    Returns:
        The point on the curve at ``t``.
    """
    bezier_ctrl = 2 * C - 0.5 * (S + E)
    u = 1 - t
    return u * u * S + 2 * u * t * bezier_ctrl + t * t * E


def compute_control_point_from_params(start_pos, target_pos, radius, angle, dist_frac, pos_frac):
    """
    Place a control point relative to the segment start_pos -> target_pos.

    The point sits at fraction ``pos_frac`` along the segment and is pushed
    sideways by ``dist_frac * radius * |target_pos - start_pos|`` in the
    direction ``cos(angle) * perp1 + sin(angle) * perp2``, where ``perp1`` and
    ``perp2`` are the perpendicular axes returned by ``build_local_frame``.
    Intended to mirror utils.py compute_control_point_from_params (per the
    original note; not verifiable from this file).

    Args:
        start_pos: segment start point.
        target_pos: segment end point.
        radius: control point radius (scales the sideways offset).
        angle: offset direction in radians, measured from perp1 toward perp2.
        dist_frac: offset length in units of radius * path_length.
        pos_frac: position along the segment (0=start, 1=end).

    Returns:
        The control point position.
    """
    segment = target_pos - start_pos
    # The first axis of the local frame (along the segment) is not needed here.
    _, perp1, perp2 = build_local_frame(start_pos, target_pos)

    side_dir = np.cos(angle) * perp1 + np.sin(angle) * perp2
    side_len = dist_frac * radius * np.linalg.norm(segment)
    on_segment = start_pos + pos_frac * segment

    return on_segment + side_len * side_dir


def generate_synthetic_trajectory(start_pos, end_pos, radius, angle_deg, dist_frac, pos_frac=0.5, n_steps=64):
    """
    Sample a quadratic Bezier path from start_pos to end_pos bent toward a control point.

    Args:
        start_pos: np.ndarray(3,), start position
        end_pos: np.ndarray(3,), end position
        radius: float, control point radius
        angle_deg: float, control point direction in degrees, measured from the
            first perpendicular axis of build_local_frame toward the second
            (the original notes 0=up, 90=right; depends on build_local_frame)
        dist_frac: float, control point offset in units of radius * path_length
        pos_frac: float, control point position along the segment (default 0.5)
        n_steps: int, number of samples, evenly spaced in the curve parameter

    Returns:
        trajectory: np.ndarray of the n_steps sampled positions
    """
    control_point = compute_control_point_from_params(
        start_pos, end_pos, radius, np.deg2rad(angle_deg), dist_frac, pos_frac
    )

    # max(..., 1) avoids dividing by zero; a 1-step path is just the start point.
    last_step = max(n_steps - 1, 1)
    samples = [
        parabola3D(start_pos, end_pos, control_point, step / last_step)
        for step in range(n_steps)
    ]
    return np.array(samples)

FLAGS = flags.FLAGS

# Repository root: two directory levels above this script.
_REPO_ROOT = os.path.dirname(os.path.dirname(__file__))

# Default dataset location; the DPPO_DATA_DIR environment variable overrides the data root.
DEFAULT_DATA_PATH = os.path.join(
    os.environ.get("DPPO_DATA_DIR", os.path.join(_REPO_ROOT, "data")),
    "stack_blocks",
    "variation0",
)

# Default location for images and exported wall settings.
DEFAULT_OUTPUT_PATH = os.path.join(_REPO_ROOT, "block_setting")

# Input / output locations.
flags.DEFINE_string("data_path", DEFAULT_DATA_PATH, "Path to the dataset.")
flags.DEFINE_string("output_path", DEFAULT_OUTPUT_PATH, "Output path for images.")
flags.DEFINE_string("split", "train", "Which split to visualize.")

# Wall style selector; main() falls back to the --wall_* flags when the value is not a key of WALL_STYLES.
flags.DEFINE_integer("style", 0, "Predefined wall style (1, 2, 3). 0 = use flags below.")

# Wall geometry in trajectory-relative coordinates: (flag name, default, help).
_WALL_GEOMETRY_FLAGS = (
    ("wall_pos", 0.5, "Wall position along trajectory (0=start, 1=end)"),
    ("wall_corner_angle", 45.0, "Corner angle in degrees (0=up)"),
    ("wall_corner_dist", 1.0, "Corner distance (fraction of radius)"),
    ("wall_width", 2.0, "Wall width (in units of radius * path_length)"),
    ("wall_height", 2.0, "Wall height (in units of radius * path_length)"),
)
for _flag_name, _flag_default, _flag_help in _WALL_GEOMETRY_FLAGS:
    flags.DEFINE_float(_flag_name, _flag_default, _flag_help)
del _flag_name, _flag_default, _flag_help

# Control point radius; must agree with the value used by the dataset generator.
flags.DEFINE_float("radius", 0.05, "Control point radius (same as dataset generator)")

# Synthetic "new mode" trajectory, drawn to test whether a fresh CP mode clears the wall.
flags.DEFINE_bool("show_new_mode", True, "Show synthetic trajectory with new CP mode")
flags.DEFINE_float("new_mode_angle", 45.0, "New mode CP angle in degrees")
flags.DEFINE_float("new_mode_dist", 1.0, "New mode CP distance fraction")

# Number of steps in each demo phase, in execution order.
PHASE_STEPS = dict(
    reach=64,
    descend=8,
    grasp=8,
    lift=8,
    carry=64,
    descend_release=8,
    release=8,
)


def get_phase_indices(phase_steps=None):
    """Return the half-open ``(start, end)`` step range of every phase.

    Phases are laid out back to back in the iteration order of
    ``phase_steps``, so each phase starts where the previous one ended.

    Args:
        phase_steps: optional mapping of phase name -> number of steps.
            Defaults to the module-level ``PHASE_STEPS`` table.

    Returns:
        dict mapping each phase name to ``(start, end)``, end exclusive.
    """
    if phase_steps is None:
        phase_steps = PHASE_STEPS

    indices = {}
    start = 0
    for phase, steps in phase_steps.items():
        indices[phase] = (start, start + steps)
        start += steps
    return indices


def load_demo_trajectory(episode_path):
    """Read end-effector positions from ``<episode_path>/low_dim_obs.pkl``.

    The pickle is iterated as a sequence of observation objects. Every
    observation exposing a non-None ``gripper_pose`` contributes the first
    three entries of that pose (presumably xyz); observations without a pose
    are skipped.

    Returns:
        np.ndarray with one row per contributing observation, or None when the
        file is missing or no observation carried a pose.
    """
    pkl_path = os.path.join(episode_path, "low_dim_obs.pkl")
    if not os.path.exists(pkl_path):
        return None

    # NOTE(review): pickle.load can execute arbitrary code from the file; only
    # point this at datasets you generated or otherwise trust.
    with open(pkl_path, "rb") as handle:
        demo = pickle.load(handle)

    poses = (getattr(obs, 'gripper_pose', None) for obs in demo)
    positions = [pose[:3] for pose in poses if pose is not None]

    return np.array(positions) if positions else None


def _split_at_collision(traj, collision, collision_idx):
    """Return ``(drawable_prefix, hit_point)`` for a collision-checked trajectory.

    When a collision is reported together with an index, the trajectory is cut
    just after that sample and the sample is returned as the hit point. In every
    other case (no collision, or a collision reported without an index) the
    whole trajectory is drawable and ``hit_point`` is None. Checking the index
    before doing ``collision_idx + 1`` avoids a TypeError on a None index.
    """
    if collision and collision_idx is not None:
        return traj[:collision_idx + 1], traj[collision_idx]
    return traj, None


def visualize_phase(ax, trajectories, waypoints_list, phase_name, phase_indices,
                    wall_config, radius, title_suffix="",
                    show_new_mode=False, new_mode_angle=45.0, new_mode_dist=1.0):
    """
    Visualize a single phase with trajectory-relative wall.

    The wall is built from the waypoints of the first non-None entry in
    ``waypoints_list``, and every trajectory is checked against that one wall.

    Args:
        ax: matplotlib 3D axis
        trajectories: list of full trajectory arrays (rows are xyz samples)
        waypoints_list: list of waypoint dicts (or None), parallel to trajectories
        phase_name: "reach" or "carry" (any other name uses the carry waypoints)
        phase_indices: dict mapping phase name to (start, end)
        wall_config: trajectory-relative wall configuration
        radius: control point radius
        title_suffix: additional text appended to the title
        show_new_mode: whether to show synthetic trajectory with new CP mode
        new_mode_angle: angle in degrees for new mode CP
        new_mode_dist: distance fraction for new mode CP

    Returns:
        n_success: int, number of successful trajectories
        n_collision: int, number of colliding trajectories
        new_mode_success: bool, whether new mode trajectory passes
        successful_indices: list of int, indices into ``trajectories`` that pass
    """
    start_idx, end_idx = phase_indices[phase_name]

    # Waypoints bounding this phase.
    if phase_name == "reach":
        start_wp, end_wp = "home", "pregrasp"
    else:
        start_wp, end_wp = "lift", "prerelease"

    ref_wp = next((wp for wp in waypoints_list if wp is not None), None)
    if ref_wp is None:
        print(f"Error: No waypoints found for {phase_name}")
        return 0, 0, False, []

    start_pos = np.array(ref_wp[start_wp])
    end_pos = np.array(ref_wp[end_wp])

    # Wall geometry for this phase.
    corners = compute_wall_corners(start_pos, end_pos, radius, wall_config)
    opening_corners = compute_opening_corners(start_pos, end_pos, radius, wall_config)

    print(f"\n{phase_name.upper()} Phase:")
    print_wall_info(start_pos, end_pos, radius, wall_config)

    # Plot each demo trajectory: blue if it clears the wall, red up to the hit otherwise.
    n_success = 0
    n_collision = 0
    successful_indices = []

    for i, traj in enumerate(trajectories):
        phase_traj = traj[start_idx:end_idx]
        if len(phase_traj) == 0:
            continue

        collision, collision_idx = check_trajectory_wall_collision(
            phase_traj, start_pos, end_pos, radius, wall_config
        )
        plot_traj, hit_point = _split_at_collision(phase_traj, collision, collision_idx)

        if collision:
            n_collision += 1
            color = 'red'
        else:
            n_success += 1
            successful_indices.append(i)
            color = 'blue'

        ax.plot(plot_traj[:, 0], plot_traj[:, 1], plot_traj[:, 2],
                '-', color=color, alpha=0.7, linewidth=1.5)

        if hit_point is not None:
            ax.scatter(hit_point[0], hit_point[1], hit_point[2],
                       c='red', marker='x', s=100, zorder=15)

    # Synthetic trajectory for the new mode (dotted line).
    new_mode_success = False
    if show_new_mode:
        synthetic_traj = generate_synthetic_trajectory(
            start_pos, end_pos, radius,
            new_mode_angle, new_mode_dist,
            pos_frac=0.5, n_steps=end_idx - start_idx
        )

        collision, collision_idx = check_trajectory_wall_collision(
            synthetic_traj, start_pos, end_pos, radius, wall_config
        )
        new_mode_success = not collision
        plot_traj, hit_point = _split_at_collision(synthetic_traj, collision, collision_idx)

        status = 'PASS' if new_mode_success else 'COLLISION'
        new_mode_color = 'cyan' if new_mode_success else 'purple'
        new_mode_label = f'New mode (angle={new_mode_angle}°, dist={new_mode_dist}) - {status}'

        ax.plot(plot_traj[:, 0], plot_traj[:, 1], plot_traj[:, 2],
                ':', color=new_mode_color, alpha=0.9, linewidth=3,
                label=new_mode_label)

        if hit_point is not None:
            ax.scatter(hit_point[0], hit_point[1], hit_point[2],
                       c='purple', marker='x', s=150, zorder=16)

        print(f"  New mode trajectory: {status}")

    # Wall (with opening if defined).
    draw_wall_3d(ax, corners, opening_corners=opening_corners)

    # Start and end waypoints.
    ax.scatter(*start_pos, c='green', marker='o', s=150,
               label=f'{start_wp.title()} (start)', zorder=10)
    ax.scatter(*end_pos, c='orange', marker='*', s=150,
               label=f'{end_wp.title()} (end)', zorder=10)

    ax.set_xlabel('X (m)')
    ax.set_ylabel('Y (m)')
    ax.set_zlabel('Z (m)')

    title = f'{phase_name.upper()} Phase - {len(trajectories)} Trajectories\n'
    title += f'Success: {n_success} (blue) | Collision: {n_collision} (red)\n'
    title += f'Wall pos={wall_config["pos_frac"]:.2f}, '
    title += f'corner=({wall_config["corner_angle"]:.0f}°, {wall_config["corner_dist"]:.1f}), '
    title += f'size=({wall_config["width"]:.1f}x{wall_config["height"]:.1f})'
    if show_new_mode:
        title += f'\nNew mode: angle={new_mode_angle}°, dist={new_mode_dist} ({"PASS" if new_mode_success else "COLLISION"})'
    title += title_suffix

    ax.set_title(title)
    ax.legend(loc='upper left', fontsize=8)
    ax.view_init(elev=25, azim=-60)

    return n_success, n_collision, new_mode_success, successful_indices


# By convention the synthetic "new mode" (style 3) is recorded as mode 8, one
# past anchor modes 0-7. If a dataset has more anchors it is bumped so it never
# aliases a real anchor index.
_NEW_MODE_IDX = 8


def _episode_sort_key(name):
    """Order ``episode<N>`` names numerically (episode2 before episode10)."""
    suffix = name[len("episode"):]
    if suffix.isdigit():
        return (0, int(suffix), name)
    return (1, 0, name)


def _build_wall_config():
    """Return the wall config from a predefined style or from the --wall_* flags."""
    if FLAGS.style in WALL_STYLES:
        print(f"Using predefined wall style {FLAGS.style}")
        return WALL_STYLES[FLAGS.style].copy()
    if FLAGS.style != 0:
        print(f"Warning: wall style {FLAGS.style} is not predefined; using --wall_* flags")
    return {
        "pos_frac": FLAGS.wall_pos,
        "corner_angle": FLAGS.wall_corner_angle,
        "corner_dist": FLAGS.wall_corner_dist,
        "width": FLAGS.wall_width,
        "height": FLAGS.wall_height,
    }


def _resolve_new_mode():
    """Return (show_new_mode, angle, dist); style 3 forces the mode through its opening."""
    if FLAGS.style == 3:
        new_mode_angle = 135.0
        new_mode_dist = 1.0
        print(f"  Style 3 detected: automatically setting new mode to angle={new_mode_angle}°, dist={new_mode_dist}")
        return True, new_mode_angle, new_mode_dist
    return FLAGS.show_new_mode, FLAGS.new_mode_angle, FLAGS.new_mode_dist


def _load_anchor_trajectories(episodes_path, episode_dirs, metadata):
    """Load anchor demos (demo_in_mode == 0 and not with_noise) and their waypoints.

    ``metadata[i]`` is paired with ``episode_dirs[i]``. Without metadata every
    episode is loaded and its waypoints entry is None.

    Returns:
        (trajectories, waypoints_list, anchor_dirs): parallel lists, so index k
        in any of them refers to the same episode.
    """
    trajectories, waypoints_list, anchor_dirs = [], [], []
    for i, ep_dir in enumerate(episode_dirs):
        meta = metadata[i] if metadata is not None and i < len(metadata) else None
        if meta is not None and (meta.get('demo_in_mode', 0) != 0 or meta.get('with_noise', False)):
            continue

        traj = load_demo_trajectory(os.path.join(episodes_path, ep_dir))
        if traj is None or len(traj) == 0:
            continue

        trajectories.append(traj)
        anchor_dirs.append(ep_dir)
        if meta is not None:
            waypoints_list.append({
                'home': np.array(meta.get('home_pos', [0, 0, 0])),
                'pregrasp': np.array(meta.get('pregrasp_pos', [0, 0, 0])),
                'lift': np.array(meta.get('lift_pos', [0, 0, 0])),
                'prerelease': np.array(meta.get('prerelease_pos', [0, 0, 0])),
            })
        else:
            waypoints_list.append(None)
    return trajectories, waypoints_list, anchor_dirs


def _render_phase(phase_name, trajectories, waypoints_list, phase_indices, wall_config,
                  show_new_mode, new_mode_angle, new_mode_dist, output_path):
    """Plot one phase to ``{split}_{phase}_traj_wall.png``; return (visualize_phase result, path)."""
    fig = plt.figure(figsize=(10, 8))
    try:
        ax = fig.add_subplot(111, projection='3d')
        result = visualize_phase(
            ax, trajectories, waypoints_list, phase_name, phase_indices,
            wall_config, FLAGS.radius,
            show_new_mode=show_new_mode,
            new_mode_angle=new_mode_angle,
            new_mode_dist=new_mode_dist
        )
        fig.tight_layout()
        image_path = os.path.join(output_path, f'{FLAGS.split}_{phase_name}_traj_wall.png')
        fig.savefig(image_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting fails.
        plt.close(fig)
    return result, image_path


def _export_successful_episodes(style_path, episodes_path, anchor_dirs, successful_modes,
                                synthetic_mode_idx, new_mode_angle, new_mode_dist):
    """Write ``style_path/episodes/episode<k>`` for each successful mode.

    Anchor modes are copied from ``anchor_dirs[mode]``. The synthetic mode, if
    present, only gets a ``metadata.npy`` describing its CP parameters.
    """
    demos_path = os.path.join(style_path, "episodes")
    # Start clean so episodes left by an earlier run with more successes do not linger.
    if os.path.isdir(demos_path):
        shutil.rmtree(demos_path)
    if not successful_modes:
        return
    os.makedirs(demos_path, exist_ok=True)

    print(f"\nCopying {len(successful_modes)} successful episodes to {demos_path}")
    for idx, mode_idx in enumerate(successful_modes):
        dst_ep_path = os.path.join(demos_path, f"episode{idx}")
        if mode_idx == synthetic_mode_idx:
            # No trajectory is generated yet; only the new mode's parameters are recorded.
            os.makedirs(dst_ep_path, exist_ok=True)
            new_mode_metadata = {
                'mode': mode_idx,
                'cp_angle': new_mode_angle,
                'cp_dist': new_mode_dist,
                'cp_pos_frac': 0.5,
                'is_synthetic': True,
            }
            np.save(os.path.join(dst_ep_path, 'metadata.npy'), new_mode_metadata)
            print(f"  Created synthetic mode {mode_idx} -> episode{idx}")
        else:
            src_ep_path = os.path.join(episodes_path, anchor_dirs[mode_idx])
            shutil.copytree(src_ep_path, dst_ep_path)
            print(f"  Copied mode {mode_idx} -> episode{idx}")


def main(argv):
    """Render REACH/CARRY wall-collision plots and export the anchor modes that pass.

    Builds the wall config (predefined --style or --wall_* flags), loads the
    anchor demos of --split, and checks each against the wall in both phases.
    It then writes ``{split}_{phase}_traj_wall.png`` and ``wall_config.npy``
    under --output_path, plus ``style{N}/summary.npy`` and ``style{N}/episodes/``.
    Mode indices refer to positions in the loaded anchor list; summary key
    ``anchor_episode_dirs`` maps them back to episode directories.

    Args:
        argv: positional command-line arguments from absl (unused).
    """
    del argv  # Unused.

    print("=" * 70)
    print("TRAJECTORY-RELATIVE WALL VISUALIZATION")
    print("=" * 70)

    wall_config = _build_wall_config()
    show_new_mode, new_mode_angle, new_mode_dist = _resolve_new_mode()

    print(f"\nWall Configuration (trajectory-relative):")
    print(f"  Position: {wall_config['pos_frac']:.2f} (0=start, 1=end)")
    print(f"  Corner: angle={wall_config['corner_angle']:.1f}°, dist={wall_config['corner_dist']:.2f}")
    print(f"  Size: {wall_config['width']:.2f} x {wall_config['height']:.2f}")
    print(f"  Radius: {FLAGS.radius}")

    # Locate the requested split (not always "train").
    variation_path = FLAGS.data_path
    split_root = os.path.join(variation_path, FLAGS.split)
    if not os.path.exists(split_root):
        print(f"\nError: Cannot find data at {FLAGS.data_path}")
        return

    print(f"\nLoading data from: {variation_path}")

    episodes_path = os.path.join(split_root, "episodes")
    metadata_path = os.path.join(split_root, f"{FLAGS.split}_metadata.npy")

    if not os.path.exists(episodes_path):
        print(f"Error: Episodes path not found: {episodes_path}")
        return

    metadata = None
    if os.path.exists(metadata_path):
        # NOTE(review): allow_pickle=True unpickles the file; only use trusted datasets.
        metadata = np.load(metadata_path, allow_pickle=True)
        print(f"Total {FLAGS.split} demos in metadata: {len(metadata)}")

    # Numeric order keeps episode_dirs[i] aligned with metadata[i] past episode9.
    episode_dirs = sorted(
        (d for d in os.listdir(episodes_path) if d.startswith("episode")),
        key=_episode_sort_key,
    )
    phase_indices = get_phase_indices()

    print("\nLoading anchor trajectories...")
    trajectories, waypoints_list, anchor_dirs = _load_anchor_trajectories(
        episodes_path, episode_dirs, metadata
    )
    print(f"Loaded {len(trajectories)} anchor trajectories")

    if not trajectories:
        print("Error: No trajectories loaded!")
        return

    output_path = FLAGS.output_path
    os.makedirs(output_path, exist_ok=True)

    if show_new_mode:
        print(f"\nNew mode parameters:")
        print(f"  Angle: {new_mode_angle}°")
        print(f"  Distance: {new_mode_dist}")

    render_args = (trajectories, waypoints_list, phase_indices, wall_config,
                   show_new_mode, new_mode_angle, new_mode_dist, output_path)

    reach_result, reach_path = _render_phase("reach", *render_args)
    reach_success, reach_collision, reach_new_mode_success, reach_successful_indices = reach_result
    print(f"\nSaved REACH visualization to: {reach_path}")

    carry_result, carry_path = _render_phase("carry", *render_args)
    carry_success, carry_collision, carry_new_mode_success, carry_successful_indices = carry_result
    print(f"Saved CARRY visualization to: {carry_path}")

    print(f"\n{'='*70}")
    print("SUMMARY")
    print(f"{'='*70}")
    print(f"Total trajectories: {len(trajectories)}")
    print(f"\nREACH: {reach_success} success, {reach_collision} collision")
    print(f"  Successful modes: {reach_successful_indices}")
    print(f"CARRY: {carry_success} success, {carry_collision} collision")
    print(f"  Successful modes: {carry_successful_indices}")

    if show_new_mode:
        print(f"\nNew mode (angle={new_mode_angle}°, dist={new_mode_dist}):")
        print(f"  REACH: {'PASS' if reach_new_mode_success else 'COLLISION'}")
        print(f"  CARRY: {'PASS' if carry_new_mode_success else 'COLLISION'}")

    # Save data (like the close_drawer block_setting structure).
    np.save(os.path.join(output_path, 'wall_config.npy'), wall_config)
    print(f"\nSaved wall_config.npy to {output_path}")

    style_name = f"style{FLAGS.style}"
    style_path = os.path.join(output_path, style_name)
    os.makedirs(style_path, exist_ok=True)

    # A mode must clear the wall in BOTH the reach and carry phases.
    successful_modes = sorted(set(reach_successful_indices) & set(carry_successful_indices))
    print(f"\nModes passing both REACH and CARRY: {successful_modes}")

    synthetic_mode_idx = None
    if FLAGS.style == 3 and reach_new_mode_success and carry_new_mode_success:
        synthetic_mode_idx = max(_NEW_MODE_IDX, len(trajectories))
        print(f"New mode (mode {synthetic_mode_idx}) passes both phases, adding to successful modes")
        successful_modes = successful_modes + [synthetic_mode_idx]

    ref_wp = next((wp for wp in waypoints_list if wp is not None), None)

    summary = {
        'style': FLAGS.style,
        'wall_config': wall_config,
        'num_successful': len(successful_modes),
        'successful_modes': successful_modes,
        'reach_successful_modes': reach_successful_indices,
        'carry_successful_modes': carry_successful_indices,
        'new_mode_reach_success': reach_new_mode_success if show_new_mode else None,
        'new_mode_carry_success': carry_new_mode_success if show_new_mode else None,
        'new_mode_angle': new_mode_angle if show_new_mode else None,
        'new_mode_dist': new_mode_dist if show_new_mode else None,
        # Mode k (k < len) refers to anchor_episode_dirs[k] in the source split.
        'anchor_episode_dirs': list(anchor_dirs),
    }
    if ref_wp is not None:
        summary['home_pos'] = ref_wp['home'].tolist()
        summary['pregrasp_pos'] = ref_wp['pregrasp'].tolist()
        summary['lift_pos'] = ref_wp['lift'].tolist()
        summary['prerelease_pos'] = ref_wp['prerelease'].tolist()

    np.save(os.path.join(style_path, 'summary.npy'), summary)
    print(f"Saved summary.npy to {style_path}")

    _export_successful_episodes(
        style_path, episodes_path, anchor_dirs, successful_modes,
        synthetic_mode_idx, new_mode_angle, new_mode_dist
    )

    print("\nDone!")


if __name__ == "__main__":
    # absl's app.run parses command-line flags into FLAGS, then calls main(argv).
    app.run(main)
