"""
Visualize grasp positions from generated dataset.
Plots EE positions at grasp point, colored by approach angle mode.
Shows points arranged in a circle around the cup.

Usage:
    python visualize_trajectories.py --data_path=/path/to/episodes
"""

import os
import sys
import numpy as np
import pickle
import matplotlib.pyplot as plt
from collections import defaultdict

# Default location of the generated grasp-task episodes (overridable via --data_path).
DEFAULT_DATA_PATH = "/scratch4/workspace/placeholder-hdp1/dppo/data/grasp/variation0/train/episodes"

# (plot color, legend label) for each integer approach-angle mode, in mode order.
# Mode k corresponds to an approach angle of k * 90 degrees
# (0 = 0°, 1 = 90°, 2 = 180°, 3 = 270°).
_MODE_STYLES = (
    ('red', '0°'),
    ('green', '90°'),
    ('blue', '180°'),
    ('orange', '270°'),
)

# Lookup tables keyed by the integer mode stored in episode metadata.
MODE_COLORS = {mode: color for mode, (color, _label) in enumerate(_MODE_STYLES)}
MODE_LABELS = {mode: label for mode, (_color, label) in enumerate(_MODE_STYLES)}


def load_episode_data(episode_path):
    """Read one episode's observations and metadata from disk.

    Args:
        episode_path: Episode directory expected to contain ``low_dim_obs.pkl``
            and ``metadata.npy``.

    Returns:
        ``(observations, metadata)``. ``observations`` is the object unpickled
        from ``low_dim_obs.pkl``. ``metadata`` is the object stored in
        ``metadata.npy``, unwrapped with ``.item()``; callers treat it as a dict.
        Returns ``(None, None)`` if either file is missing.
    """
    obs_file = os.path.join(episode_path, "low_dim_obs.pkl")
    meta_file = os.path.join(episode_path, "metadata.npy")

    if not all(os.path.exists(p) for p in (obs_file, meta_file)):
        return None, None

    # NOTE(review): pickle.load / allow_pickle=True can execute arbitrary code;
    # only point this at trusted, locally generated datasets.
    with open(obs_file, 'rb') as handle:
        observations = pickle.load(handle)
    metadata = np.load(meta_file, allow_pickle=True).item()

    return observations, metadata


def get_grasp_position(observations, metadata, default_grasp_idx=72):
    """
    Extract the end-effector (EE) xyz position at the grasp step of an episode.

    The grasp step is chosen from ``metadata['phase_indices']`` in this order:
      1. the end index of the 'descend' phase,
      2. otherwise the start index of the 'hold_grasp' phase,
      3. otherwise ``default_grasp_idx``.
    The chosen index is clamped into ``[0, len(observations) - 1]``.

    NOTE(review): if phase end indices are exclusive, choice (1) selects the
    first step *after* descend. Confirm against the data generator.

    Args:
        observations: Sequence of per-step observation objects (must support
            ``len`` and integer indexing).
        metadata: Dict-like episode metadata.
        default_grasp_idx: Step used when neither phase is recorded. The
            default of 72 matches the previous hard-coded fallback.

    Returns:
        The first three entries of the selected observation's
        ``gripper_pose``. Returns None if ``observations`` is None or empty,
        or if the observation has no ``gripper_pose`` attribute or it is None.
    """
    if observations is None or len(observations) == 0:
        return None

    # `or {}` also covers metadata that stores phase_indices=None explicitly.
    phase_indices = metadata.get('phase_indices') or {}

    if 'descend' in phase_indices:
        _, grasp_idx = phase_indices['descend']
    elif 'hold_grasp' in phase_indices:
        grasp_idx, _ = phase_indices['hold_grasp']
    else:
        grasp_idx = default_grasp_idx

    # Clamp to a valid step. A negative index would otherwise silently wrap
    # around to the end of the episode instead of meaning "before the start".
    grasp_idx = max(0, min(int(grasp_idx), len(observations) - 1))

    pose = getattr(observations[grasp_idx], 'gripper_pose', None)
    if pose is None:
        return None
    return pose[:3]


def _collect_grasp_samples(data_path):
    """Load every successful episode under data_path and group its grasp data by mode.

    Returns:
        ``(mode_positions, mode_angles, mode_cups, cup_positions)``:
        - ``mode_positions[mode]``: grasp EE xyz positions.
        - ``mode_angles[mode]``: the matching metadata 'approach_angle' values,
          converted to degrees.
        - ``mode_cups[mode]``: the matching per-episode 'cup_pos' (None when
          absent).
        - ``cup_positions``: every recorded 'cup_pos' of a kept episode.
    """
    episodes = sorted(d for d in os.listdir(data_path) if d.startswith('episode'))
    print(f"Found {len(episodes)} episodes")

    mode_positions = defaultdict(list)
    mode_angles = defaultdict(list)  # Store actual approach angles (degrees)
    mode_cups = defaultdict(list)
    cup_positions = []

    for ep_name in episodes:
        observations, metadata = load_episode_data(os.path.join(data_path, ep_name))

        if metadata is None:
            print(f"  Skipping {ep_name} - no metadata")
            continue

        # Only successful episodes are plotted.
        if not metadata.get('task_success', False):
            print(f"  Skipping {ep_name} - task failed")
            continue

        grasp_pos = get_grasp_position(observations, metadata)
        if grasp_pos is None:
            continue

        mode = metadata.get('mode', -1)
        approach_angle = metadata.get('approach_angle', 0)  # radians
        cup_pos = metadata.get('cup_pos', None)

        mode_positions[mode].append(grasp_pos)
        mode_angles[mode].append(np.degrees(approach_angle))
        mode_cups[mode].append(cup_pos)
        if cup_pos is not None:
            cup_positions.append(cup_pos)

    return mode_positions, mode_angles, mode_cups, cup_positions


def _plot_top_down(ax, mode_positions, avg_cup_pos, offset_radius):
    """Scatter absolute grasp XY positions per mode around the average cup position."""
    for mode in sorted(mode_positions):
        positions = np.array(mode_positions[mode])
        ax.scatter(positions[:, 0], positions[:, 1],
                   c=MODE_COLORS.get(mode, 'gray'), s=80, alpha=0.7,
                   label=MODE_LABELS.get(mode, f'Mode {mode}'),
                   edgecolors='black', linewidths=0.5)

    # Cup position (intended center of the grasp circle)
    ax.scatter([avg_cup_pos[0]], [avg_cup_pos[1]],
               c='purple', s=200, marker='*', label='Cup', zorder=10, edgecolors='black')

    # Reference circle at the expected gripper offset from the cup center
    theta = np.linspace(0, 2 * np.pi, 100)
    ax.plot(avg_cup_pos[0] + offset_radius * np.cos(theta),
            avg_cup_pos[1] + offset_radius * np.sin(theta),
            'k--', alpha=0.3, linewidth=1, label='Gripper offset circle')

    ax.set_xlabel('X (m)', fontsize=12)
    ax.set_ylabel('Y (m)', fontsize=12)
    ax.set_title('Top-Down View (XY) - Grasp Positions', fontsize=14)
    ax.legend(loc='upper right')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)


def _plot_angle_vs_distance(ax, mode_positions, mode_cups, fallback_cup_pos, offset_radius):
    """Scatter each grasp's polar angle vs. distance relative to its own episode's cup."""
    fallback_xy = np.asarray(fallback_cup_pos, dtype=float)[:2]

    for mode in sorted(mode_positions):
        positions = np.asarray(mode_positions[mode], dtype=float)

        # Measure each grasp against the cup of the same episode, so varying
        # cup placement does not distort the distances. Fall back to the
        # average cup when an episode has no recorded 'cup_pos'.
        centers = np.array([
            fallback_xy if cup is None else np.asarray(cup, dtype=float)[:2]
            for cup in mode_cups[mode]
        ])
        rel = positions[:, :2] - centers
        angles = np.degrees(np.arctan2(rel[:, 1], rel[:, 0]))
        distances = np.hypot(rel[:, 0], rel[:, 1])

        ax.scatter(angles, distances * 100,  # Convert to cm
                   c=MODE_COLORS.get(mode, 'gray'), s=80, alpha=0.7,
                   label=MODE_LABELS.get(mode, f'Mode {mode}'),
                   edgecolors='black', linewidths=0.5)

    ax.axhline(y=offset_radius * 100, color='k', linestyle='--', alpha=0.3,
               label=f'Expected offset ({offset_radius * 100:g}cm)')
    ax.set_xlabel('Approach Angle (degrees)', fontsize=12)
    ax.set_ylabel('Distance from Cup Center (cm)', fontsize=12)
    ax.set_title('Angle vs Distance Distribution', fontsize=14)
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)
    ax.set_xlim(-200, 200)


def visualize_grasp_positions(data_path, output_path=None, offset_radius=0.04):
    """
    Visualize EE positions at the grasp point for all successful episodes.

    Produces a two-panel figure:
      - a top-down (XY) scatter of grasp positions colored by mode, with the
        average cup position and a reference circle;
      - the angle and distance of each grasp relative to its episode's cup.
    Points should form a circle around the cup position.

    Args:
        data_path: Directory containing ``episode*`` subdirectories.
        output_path: Where to write the PNG. Defaults to
            ``<data_path>/../grasp_positions.png``.
        offset_radius: Expected gripper offset from the cup center, in meters.
            Drawn as the reference circle and the expected-distance line. The
            default of 0.04 matches the previous hard-coded value.
    """
    mode_positions, mode_angles, mode_cups, cup_positions = _collect_grasp_samples(data_path)

    # Average cup position (used as the plotted cup and as fallback center)
    if cup_positions:
        avg_cup_pos = np.mean(cup_positions, axis=0)
        print(f"Average cup position: {avg_cup_pos}")
    else:
        avg_cup_pos = np.array([0.25, -0.05, 0.82])
        print(f"Using default cup position: {avg_cup_pos}")

    # Print statistics
    total_points = sum(len(pts) for pts in mode_positions.values())
    print(f"\nCollected {total_points} grasp positions:")
    for mode in sorted(mode_positions):
        angles = mode_angles[mode]
        print(f"  Mode {mode} ({MODE_LABELS.get(mode, '?')}): {len(mode_positions[mode])} points, "
              f"angle={np.mean(angles):.1f}° ± {np.std(angles):.1f}°")

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    _plot_top_down(ax1, mode_positions, avg_cup_pos, offset_radius)
    _plot_angle_vs_distance(ax2, mode_positions, mode_cups, avg_cup_pos, offset_radius)

    fig.suptitle(f'Grasp Position Visualization ({total_points} successful grasps)', fontsize=16)
    fig.tight_layout()

    if output_path is None:
        output_path = os.path.join(data_path, "..", "grasp_positions.png")

    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    print(f"\nSaved visualization to {output_path}")
    # Close this specific figure rather than whichever figure is current.
    plt.close(fig)


def main():
    """Parse command-line options and render the grasp-position figure.

    Exits with status 1 (after printing an error) when --data_path does not exist.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Visualize grasp positions from generated dataset")
    cli.add_argument("--data_path", type=str, default=DEFAULT_DATA_PATH,
                     help="Path to episodes directory")
    cli.add_argument("--output", type=str, default=None,
                     help="Output path for visualization image")
    opts = cli.parse_args()

    data_path = opts.data_path
    if os.path.exists(data_path):
        visualize_grasp_positions(data_path, opts.output)
        return

    print(f"Error: Data path does not exist: {data_path}")
    sys.exit(1)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
