"""
Plot end-effector (EE) trajectories for the angle and distance generalization tests.

Each trajectory uses the same color as in the z visualization plots.
"""
import os
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# State format: [joint_positions(7), joint_velocities(7), gripper_open(1), gripper_pose(7)]
# gripper_pose: [x, y, z, qx, qy, qz, qw]
# EE position is in gripper_pose[0:3], which is state indices 15:18


def get_12_colors():
    """Return the 12 trajectory colors shared with the z visualization plots.

    The colors are sampled from matplotlib's ``tab20`` colormap at the
    fractions 0, 1/12, ..., 11/12. Keep this sampling unchanged so that
    trajectory ``i`` gets the same color here as in the z plots.
    """
    palette = plt.cm.tab20
    n_colors = 12
    return [palette(k / n_colors) for k in range(n_colors)]


def extract_ee_positions(states):
    """Slice the EE xyz position columns out of a 2D state array.

    Each row is laid out as joint_pos(7) + joint_vel(7) + gripper_open(1)
    + gripper_pose(7). The gripper pose therefore begins at column 15, and
    its first three entries are x, y, z.
    """
    pose_start = 7 + 7 + 1  # skip joint positions, joint velocities, gripper_open
    return states[:, pose_start:pose_start + 3]


def load_ee_trajectory(ep_path):
    """Return the EE trajectory stored for one episode directory.

    If ``ee_trajectory.npy`` exists in ``ep_path``, it is loaded and returned
    as-is. Otherwise ``states.npy`` is loaded and the EE position columns are
    extracted with ``extract_ee_positions``.
    """
    cached_path = os.path.join(ep_path, 'ee_trajectory.npy')
    if os.path.exists(cached_path):
        return np.load(cached_path)

    # No precomputed trajectory: derive it from the raw state log.
    raw_states = np.load(os.path.join(ep_path, 'states.npy'))
    return extract_ee_positions(raw_states)


def _list_episode_dirs(test_dir):
    """Return the numbered ``episode<N>`` subdirectories of ``test_dir``, sorted by N.

    Entries are skipped, rather than crashing the numeric sort, when they are
    not directories or their suffix after ``episode`` is not an integer
    (e.g. ``episode_summary.txt``).
    """
    prefix = 'episode'
    indexed = []
    for name in os.listdir(test_dir):
        if not name.startswith(prefix) or not os.path.isdir(os.path.join(test_dir, name)):
            continue
        try:
            index = int(name[len(prefix):])
        except ValueError:
            continue  # not a numbered episode folder
        indexed.append((index, name))
    return [name for _, name in sorted(indexed)]


def _load_episode(test_dir, ep_dir):
    """Load ``(ee_pos, metadata)`` for a single episode folder."""
    ep_path = os.path.join(test_dir, ep_dir)
    ee_pos = load_ee_trajectory(ep_path)
    # metadata.npy is read back with .item() and then queried with .get(), so it
    # is expected to hold a pickled dict-like object. NOTE: allow_pickle=True can
    # execute code embedded in a malicious file; only use this on trusted output.
    metadata = np.load(os.path.join(ep_path, 'metadata.npy'), allow_pickle=True).item()
    return ee_pos, metadata


def _create_trajectory_figure():
    """Create the 2x2 figure: a 3D view plus X-Y, X-Z and Y-Z projections.

    Returns:
        ``(fig, ax3d, views)``. ``views`` is a list of ``(axis, (i, j))`` pairs,
        one per 2D panel, giving the EE columns (0=x, 1=y, 2=z) that the
        panel plots on its horizontal and vertical axes.
    """
    fig = plt.figure(figsize=(18, 12))
    ax3d = fig.add_subplot(2, 2, 1, projection='3d')
    views = [
        (fig.add_subplot(2, 2, 2), (0, 1)),  # top-down view (X-Y)
        (fig.add_subplot(2, 2, 3), (0, 2)),  # side view (X-Z)
        (fig.add_subplot(2, 2, 4), (1, 2)),  # side view (Y-Z)
    ]
    return fig, ax3d, views


def _plot_trajectory(ax3d, views, ee_pos, color, label,
                     linestyle='-', marker_start='o', marker_end='s'):
    """Draw one EE trajectory on every panel and mark its first and last points."""
    line_style = dict(color=color, linewidth=1.5, alpha=0.8, linestyle=linestyle)
    endpoints = ((0, marker_start), (-1, marker_end))

    # Only the 3D line carries the label, so the legend has one entry per trajectory.
    ax3d.plot(ee_pos[:, 0], ee_pos[:, 1], ee_pos[:, 2], label=label, **line_style)
    for row, marker in endpoints:
        ax3d.scatter(ee_pos[row, 0], ee_pos[row, 1], ee_pos[row, 2],
                     color=color, s=50, marker=marker, edgecolors='black')

    for ax, (i, j) in views:
        ax.plot(ee_pos[:, i], ee_pos[:, j], **line_style)
        for row, marker in endpoints:
            ax.scatter(ee_pos[row, i], ee_pos[row, j],
                       color=color, s=30, marker=marker, edgecolors='black')


def _format_axes(ax3d, views, title_3d, xy_title, legend_fontsize):
    """Apply axis labels, titles, grids and the 3D legend to all panels."""
    ax3d.set_xlabel('X (m)')
    ax3d.set_ylabel('Y (m)')
    ax3d.set_zlabel('Z (m)')
    ax3d.set_title(title_3d)
    ax3d.legend(loc='upper left', fontsize=legend_fontsize, ncol=2)

    axis_names = 'XYZ'
    view_titles = (xy_title, 'Side View (X-Z)', 'Side View (Y-Z)')
    for (ax, (i, j)), title in zip(views, view_titles):
        ax.set_xlabel(f'{axis_names[i]} (m)')
        ax.set_ylabel(f'{axis_names[j]} (m)')
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
    # Only the top-down view uses an equal aspect ratio.
    views[0][0].set_aspect('equal')


def _save_figure(fig, suptitle, output_path):
    """Add the figure title, save at 150 dpi, then close the figure."""
    fig.suptitle(suptitle, fontsize=14)
    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    # Close this specific figure rather than whichever figure happens to be current.
    plt.close(fig)


def plot_angle_test_trajectories(test_dir, output_path):
    """Plot EE trajectories for the angle generalization test.

    Draws one trajectory per numbered ``episode<N>`` folder in ``test_dir``:
    - Color: chosen by the episode's position in the sorted list, using the
      same palette as the z visualization.
    - Label: the ``angle`` metadata value, falling back to ``i * 30``.
    - A " (FAIL)" suffix is added when ``success`` is false.
    - Start points are circles and end points are squares.

    Episodes with an empty trajectory are skipped with a message. The figure
    is saved to ``output_path``.
    """
    fig, ax3d, views = _create_trajectory_figure()
    colors = get_12_colors()

    for i, ep_dir in enumerate(_list_episode_dirs(test_dir)):
        ee_pos, metadata = _load_episode(test_dir, ep_dir)
        if len(ee_pos) == 0:
            print(f"Skipping {ep_dir}: empty EE trajectory")
            continue
        angle = metadata.get('angle', i * 30)
        success = metadata.get('success', True)
        label = f'{angle:.0f}°' + ('' if success else ' (FAIL)')
        _plot_trajectory(ax3d, views, ee_pos, colors[i % len(colors)], label)

    _format_axes(ax3d, views,
                 title_3d='3D View - Angle Test (dist=1.0, pos=0.5)',
                 xy_title='Top-Down View (X-Y)',
                 legend_fontsize=8)
    _save_figure(fig,
                 'Angle Generalization Test - EE Trajectories\n(12 angles, distance=1.0, position=0.5)',
                 output_path)
    print(f"Saved angle test EE trajectories to {output_path}")


def plot_distance_test_trajectories(test_dir, output_path):
    """Plot EE trajectories for the distance generalization test.

    Draws one trajectory per numbered ``episode<N>`` folder in ``test_dir``:
    - Color: chosen by the episode's position in the sorted list.
    - Label: the ``angle`` and ``distance`` metadata values (defaults 0 and
      0.5), plus " (FAIL)" when ``success`` is false.
    - Angle-0 episodes use solid lines with circle/square endpoints.
    - Every other angle uses dashed lines with triangle-up/triangle-down
      endpoints.

    Episodes with an empty trajectory are skipped with a message. The figure
    is saved to ``output_path``.
    """
    fig, ax3d, views = _create_trajectory_figure()
    colors = get_12_colors()

    for i, ep_dir in enumerate(_list_episode_dirs(test_dir)):
        ee_pos, metadata = _load_episode(test_dir, ep_dir)
        if len(ee_pos) == 0:
            print(f"Skipping {ep_dir}: empty EE trajectory")
            continue
        angle = metadata.get('angle', 0)
        dist = metadata.get('distance', 0.5)
        success = metadata.get('success', True)

        # Angle 0 -> solid line with o/s endpoints; any other angle (180 here) -> dashed with ^/v.
        is_zero_angle = angle == 0
        label = f'angle={angle:.0f}°, dist={dist:.2f}' + ('' if success else ' (FAIL)')
        _plot_trajectory(ax3d, views, ee_pos, colors[i % len(colors)], label,
                         linestyle='-' if is_zero_angle else '--',
                         marker_start='o' if is_zero_angle else '^',
                         marker_end='s' if is_zero_angle else 'v')

    _format_axes(ax3d, views,
                 title_3d='3D View - Distance Test (angle=0,180, pos=0.5)',
                 xy_title='Top-Down View (X-Y)\n(solid=angle 0°, dashed=angle 180°)',
                 legend_fontsize=7)
    _save_figure(fig,
                 'Distance Generalization Test - EE Trajectories\n(angle=0°,180°, 6 distances each, position=0.5)',
                 output_path)
    print(f"Saved distance test EE trajectories to {output_path}")


def main(base_dir=None):
    """Render the EE-trajectory plots for the angle and distance generalization tests.

    Args:
        base_dir: Directory that may contain ``angle_test/`` and
            ``distance_test/`` episode folders. Plots are written to
            ``<base_dir>/plots``, which is created if missing. Defaults to the
            original hard-coded experiment path, so ``main()`` with no
            arguments behaves as before.
    """
    if base_dir is None:
        base_dir = '/scratch4/workspace/placeholder-hdp1/dppo/data/close_drawer/variation2/encoder/z2/generalize_test'

    angle_test_dir = os.path.join(base_dir, 'angle_test')
    distance_test_dir = os.path.join(base_dir, 'distance_test')
    plots_dir = os.path.join(base_dir, 'plots')

    os.makedirs(plots_dir, exist_ok=True)

    # Use isdir rather than exists: the plot functions call os.listdir on these paths.
    if os.path.isdir(angle_test_dir):
        plot_angle_test_trajectories(angle_test_dir, os.path.join(plots_dir, 'angle_test_ee_trajectories.png'))
    else:
        print(f"Angle test directory not found: {angle_test_dir}")

    if os.path.isdir(distance_test_dir):
        plot_distance_test_trajectories(distance_test_dir, os.path.join(plots_dir, 'distance_test_ee_trajectories.png'))
    else:
        print(f"Distance test directory not found: {distance_test_dir}")

    print(f"\nPlots saved to {plots_dir}")


if __name__ == "__main__":
    # Generate the plots only when run as a script, not when imported.
    main()
