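"""
Visualize end-effector trajectories from a generated dataset.

For each mode, all of that mode's trajectories are projected onto the first
video frame of that mode. If the camera configuration (config.npy) cannot be
found, the script falls back to showing the last video frame of one episode
per mode, which already contains that episode's trajectory overlay.

Usage:
    python visualize_trajectories.py --data_path=/path/to/episodes [--output=out.png]
"""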

import os
import sys
import numpy as np
import pickle
import cv2
import matplotlib.pyplot as plt
from collections import defaultdict

# Default episodes directory, used by main() when --data_path is not given.
# This looks like a machine-specific scratch path; pass --data_path to point
# the script at a dataset elsewhere.
DEFAULT_DATA_PATH = "/scratch4/workspace/placeholder-hdp1/dppo/data/close_drawer/variation2/train/episodes"


def load_episode_data(episode_path):
    """Read the observations and metadata stored in one episode directory.

    Args:
        episode_path: Directory expected to contain ``low_dim_obs.pkl`` and
            ``metadata.npy``.

    Returns:
        ``(observations, metadata)``: the unpickled contents of
        ``low_dim_obs.pkl`` and the object stored in ``metadata.npy``
        (unwrapped with ``.item()``). Returns ``(None, None)`` when either
        file is missing.
    """
    obs_file, meta_file = (
        os.path.join(episode_path, name)
        for name in ("low_dim_obs.pkl", "metadata.npy")
    )
    if not all(os.path.exists(p) for p in (obs_file, meta_file)):
        return None, None

    # NOTE: pickle.load and allow_pickle=True can execute code embedded in the
    # files; only use this on trusted, locally generated datasets.
    with open(obs_file, 'rb') as fh:
        observations = pickle.load(fh)
    metadata = np.load(meta_file, allow_pickle=True).item()
    return observations, metadata


def extract_ee_trajectory(observations):
    """Collect the end-effector xyz position of each observation.

    Observations without a ``gripper_pose`` attribute, or whose
    ``gripper_pose`` is None, are skipped. The first three entries of
    ``gripper_pose`` are taken as the position (the pose is presumably laid
    out as [x, y, z, qx, qy, qz, qw]).

    Returns:
        A numpy array with one row of 3 values per kept observation. It is
        an empty array when no observation has a pose.
    """
    return np.array([
        ob.gripper_pose[:3]
        for ob in observations
        if getattr(ob, 'gripper_pose', None) is not None
    ])


def get_first_frame(episode_path):
    """Return the first frame of ``<episode_path>/video.mp4`` as an RGB array.

    The original docstring describes this as the frame without the trajectory
    overlay. Returns None when the video file is missing or the first frame
    cannot be decoded.
    """
    video_file = os.path.join(episode_path, "video.mp4")
    if not os.path.exists(video_file):
        return None

    capture = cv2.VideoCapture(video_file)
    ok, bgr = capture.read()
    capture.release()
    # OpenCV decodes frames in BGR channel order; matplotlib expects RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) if ok else None


def _read_last_frame_sequentially(video_path):
    """Decode every frame of *video_path* in order and return the last one (BGR), or None."""
    cap = cv2.VideoCapture(video_path)
    last = None
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            last = frame
    finally:
        cap.release()
    return last


def get_last_frame(episode_path):
    """Return the last frame of ``<episode_path>/video.mp4`` as an RGB array.

    The original docstring describes this frame as carrying the complete
    trajectory overlay.

    The function first tries a direct seek to ``CAP_PROP_FRAME_COUNT - 1``.
    The frame count comes from container metadata and may be missing or
    inaccurate, and seeking is not reliable for every codec. If the seek or
    read fails, the whole video is decoded sequentially and the final frame
    is kept.

    Returns:
        The RGB frame, or None when the video is missing or no frame can be
        decoded.
    """
    video_path = os.path.join(episode_path, "video.mp4")
    if not os.path.exists(video_path):
        return None

    frame = None
    cap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames > 0:
            cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1)
            ret, seeked = cap.read()
            if ret:
                frame = seeked
    finally:
        # Release exactly once, whatever path was taken above.
        cap.release()

    if frame is None:
        frame = _read_last_frame_sequentially(video_path)
    if frame is None:
        return None
    # OpenCV decodes frames in BGR channel order; matplotlib expects RGB.
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def load_camera_config(data_path):
    """
    Read the camera parameters from the dataset's ``config.npy``.

    The file is looked up two directories above *data_path* (for
    ``.../variation/train/episodes`` that is the variation folder).

    Returns:
        A dict with 'camera_position' and 'camera_matrix' (numpy arrays) and
        'camera_fov' (the stored value, defaulting to 60.0). Returns None when
        the file is missing or contains no 'camera_matrix' entry.
    """
    config_file = os.path.join(data_path, "..", "..", "config.npy")
    if not os.path.exists(config_file):
        return None

    cfg = np.load(config_file, allow_pickle=True).item()
    if 'camera_matrix' not in cfg:
        return None

    return dict(
        camera_position=np.array(cfg['camera_position']),
        camera_matrix=np.array(cfg['camera_matrix']),
        camera_fov=cfg.get('camera_fov', 60.0),
    )


def project_3d_to_2d(points_3d, camera_config, image_size=(256, 256)):
    """
    Project world-frame 3D points to integer pixel coordinates (pinhole model).

    Camera-frame coordinates are computed as
    ``camera_matrix.T @ (point - camera_position)``. The focal length in
    pixels is derived from the FOV (degrees) and the image height.

    Args:
        points_3d: Iterable of 3-element points.
        camera_config: Dict with 'camera_position', 'camera_matrix' and
            'camera_fov', as returned by load_camera_config, or None. When it
            is None, CAMERA_POSITION is imported from close_drawer_config and
            an identity rotation with a 60-degree FOV is used. This is a rough
            fallback that is unlikely to match the real camera.
        image_size: (height, width) of the target image in pixels.

    Returns:
        One entry per input point: a ``(u, v)`` tuple of ints (truncated
        toward zero), or None when the camera-frame depth is not greater than
        0.01. Points outside the image bounds are not filtered out.
    """
    if camera_config is not None:
        origin = camera_config['camera_position']
        rotation = camera_config['camera_matrix']
        fov_degrees = camera_config['camera_fov']
    else:
        # Rough fallback: position from the task config, but no rotation info.
        from close_drawer_config import CAMERA_POSITION
        origin = np.array(CAMERA_POSITION)
        rotation = np.eye(3)
        fov_degrees = 60.0

    half_fov = fov_degrees * np.pi / 180.0 / 2.0
    focal = image_size[0] / (2.0 * np.tan(half_fov))
    center_u = image_size[1] / 2.0
    center_v = image_size[0] / 2.0

    pixels = []
    for point in points_3d:
        cam_x, cam_y, depth = rotation.T @ (np.array(point) - origin)
        # Points at or behind the near plane have no valid projection.
        pixels.append(
            (int(center_u - focal * cam_x / depth), int(center_v - focal * cam_y / depth))
            if depth > 0.01
            else None
        )
    return pixels


def _collect_mode_episodes(data_path, episodes, load_trajectories):
    """Load each episode's metadata (and optionally its trajectory) and group the records by mode."""
    mode_episodes = defaultdict(list)
    for ep_name in episodes:
        ep_path = os.path.join(data_path, ep_name)
        observations, metadata = load_episode_data(ep_path)

        if metadata is None:
            print(f"  Skipping {ep_name} - no metadata")
            continue

        trajectory = None
        if load_trajectories and observations:
            trajectory = extract_ee_trajectory(observations)

        mode_episodes[metadata.get('mode', -1)].append({
            'name': ep_name,
            'path': ep_path,
            'trajectory': trajectory,
            'metadata': metadata,
        })
    return mode_episodes


def _first_available_frame(episodes_data, frame_loader):
    """Return the first non-None frame that *frame_loader* yields over the episodes, or None."""
    # Frames are decoded lazily, so videos after the first usable one are never opened.
    for ep_data in episodes_data:
        frame = frame_loader(ep_data['path'])
        if frame is not None:
            return frame
    return None


def _draw_trajectory(ax, traj, camera_config, image_size, color):
    """Project one 3D trajectory into image space and draw it with start (o) / end (x) markers."""
    height, width = image_size[0], image_size[1]
    projected = project_3d_to_2d(traj, camera_config, image_size)
    valid_points = [p for p in projected if p is not None]
    if not valid_points:
        return

    def in_bounds(p):
        return 0 <= p[0] < width and 0 <= p[1] < height

    if len(valid_points) > 1:
        # Out-of-image points become NaN. Matplotlib breaks a line at NaN, so
        # exactly the segments with both endpoints inside the image are drawn,
        # all in one Line2D instead of one artist per segment.
        xs = [p[0] if in_bounds(p) else np.nan for p in valid_points]
        ys = [p[1] if in_bounds(p) else np.nan for p in valid_points]
        ax.plot(xs, ys, color=color, linewidth=1.5, alpha=0.7)

    start, end = valid_points[0], valid_points[-1]
    if in_bounds(start):
        ax.scatter([start[0]], [start[1]], c=[color], s=20, marker='o', zorder=5)
    if in_bounds(end):
        ax.scatter([end[0]], [end[1]], c=[color], s=20, marker='x', zorder=5)


def visualize_modes(data_path, output_path=None):
    """
    Visualize all trajectories grouped by mode and save the figure as a PNG.

    The figure has one subplot per mode in a 4-column grid, so 5-8 modes give
    a 2x4 grid.

    - If a camera config is available (see load_camera_config), every
      episode's end-effector trajectory is projected onto the first
      available video frame of its mode.
    - Otherwise the last available video frame of one episode per mode is
      shown instead, since that frame is described as already carrying the
      trajectory overlay.

    Args:
        data_path: Directory containing the ``episode*`` subdirectories.
        output_path: Where to save the PNG. If falsy, the figure is saved as
            ``<data_path>/../trajectory_visualization.png``.
    """
    # Load camera configuration for projection
    camera_config = load_camera_config(data_path)
    use_last_frames = camera_config is None
    if use_last_frames:
        print("Warning: Camera config not found in config.npy.")
        print("  Will use last frames from videos (which have trajectory overlay).")
        print("  To get multi-trajectory plots, re-run dataset_generator to save camera matrix.")
    else:
        print(f"Loaded camera config: FOV={camera_config['camera_fov']:.1f}deg")

    # Find all episode directories
    episodes = sorted(d for d in os.listdir(data_path) if d.startswith('episode'))
    print(f"Found {len(episodes)} episodes")

    # Trajectories are only needed when they will be projected.
    mode_episodes = _collect_mode_episodes(
        data_path, episodes, load_trajectories=not use_last_frames)

    modes = sorted(mode_episodes)
    num_modes = len(modes)
    print(f"Found {num_modes} modes")
    for mode in modes:
        print(f"  Mode {mode}: {len(mode_episodes[mode])} episodes")

    if num_modes == 0:
        # plt.subplots(0, n) would raise; there is nothing to draw anyway.
        print("No episodes with metadata found; nothing to plot.")
        return

    n_cols = 4
    n_rows = (num_modes + n_cols - 1) // n_cols
    # squeeze=False always gives a 2-D axes array; height scales with row count.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 4 * n_rows), squeeze=False)

    # Color palette for trajectories (cycled when a mode has more than 10 episodes)
    colors = plt.cm.viridis(np.linspace(0, 1, 10))

    for mode_idx, mode in enumerate(modes):
        ax = axes[mode_idx // n_cols, mode_idx % n_cols]
        episodes_data = mode_episodes[mode]

        # Title uses the base params of the mode's first episode:
        # base_cp_params[0] is converted from radians to degrees, [1] is shown as d.
        base_params = episodes_data[0]['metadata'].get('base_cp_params', (0, 0, 0))
        title = f"Mode {mode}: {np.degrees(base_params[0]):.0f}°, d={base_params[1]:.1f}"

        if use_last_frames:
            # Fallback mode: show the last frame of one episode (has trajectory overlay)
            frame = _first_available_frame(episodes_data, get_last_frame)
            if frame is not None:
                ax.imshow(frame)
            else:
                ax.set_facecolor('lightgray')
            title += "\n(single trajectory)"
        else:
            # Normal mode: plot all trajectories on the first frame
            frame = _first_available_frame(episodes_data, get_first_frame)
            if frame is not None:
                ax.imshow(frame)
            else:
                ax.set_facecolor('lightgray')

            image_size = frame.shape[:2] if frame is not None else (256, 256)
            for i, ep_data in enumerate(episodes_data):
                traj = ep_data['trajectory']
                if traj is None or len(traj) == 0:
                    continue
                _draw_trajectory(ax, traj, camera_config, image_size,
                                 colors[i % len(colors)])

        ax.set_title(title, fontsize=10)
        ax.axis('off')

    # Hide unused subplots
    for idx in range(num_modes, n_rows * n_cols):
        axes[idx // n_cols, idx % n_cols].axis('off')

    # Report the true per-mode episode count (a range if modes differ), not
    # just whichever mode happened to be processed last.
    counts = sorted({len(mode_episodes[m]) for m in modes})
    demos = str(counts[0]) if len(counts) == 1 else f"{counts[0]}-{counts[-1]}"
    fig.suptitle(f"Trajectory Visualization: {num_modes} modes x {demos} demos each", fontsize=14)
    fig.tight_layout()

    if not output_path:
        output_path = os.path.join(data_path, "..", "trajectory_visualization.png")
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    print(f"Saved visualization to {output_path}")

    plt.close(fig)


def main():
    """Command-line entry point: parse arguments and render the visualization.

    Exits with status 1 if ``--data_path`` does not exist.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Visualize trajectories from generated dataset")
    cli.add_argument(
        "--data_path", type=str, default=DEFAULT_DATA_PATH,
        help="Path to episodes directory",
    )
    cli.add_argument(
        "--output", type=str, default=None,
        help="Output path for visualization image",
    )
    opts = cli.parse_args()

    if os.path.exists(opts.data_path):
        visualize_modes(opts.data_path, opts.output)
        return

    print(f"Error: Data path does not exist: {opts.data_path}")
    sys.exit(1)


if __name__ == "__main__":
    main()
