"""
Plot end-effector (EE) trajectories and their latent z embeddings side by side, using matching colors per episode.

Creates two plots:
1. Angle test: 3D EE trajectories (left) + z embeddings (right)
2. Distance test: 3D EE trajectories (left) + z embeddings (right)
"""
import os
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


def get_12_colors():
    """Get 12 visually distinct colors for trajectories.

    ``tab20`` is a 20-entry *listed* colormap made of (dark, light) pairs of
    the same hue.  Sampling it with fractions such as ``i / 12`` truncates to
    table indices 0, 1, ..., 6, 7, ..., so neighbouring trajectories received
    two shades of the same hue.  Indexing with integers instead, taking every
    dark shade first and then the light shades, keeps the first ten colors on
    ten different hues.

    Returns:
        list of 12 RGBA tuples.
    """
    cmap = plt.cm.tab20
    # Integer arguments index the lookup table directly (no float binning).
    order = list(range(0, cmap.N, 2)) + list(range(1, cmap.N, 2))
    return [cmap(idx) for idx in order[:12]]


def load_ee_trajectory(ep_path):
    """Return the end-effector positions recorded for one episode directory.

    A precomputed ``ee_trajectory.npy`` is used when present.  Otherwise the
    positions are sliced out of ``states.npy`` as columns 15:18 (presumably
    x, y, z of the gripper pose; confirm against the state layout).
    """
    cached = os.path.join(ep_path, 'ee_trajectory.npy')
    if not os.path.exists(cached):
        # No cached trajectory: gripper_pose begins at column 15 of each state row.
        state_rows = np.load(os.path.join(ep_path, 'states.npy'))
        return state_rows[:, 15:18]
    return np.load(cached)


def plot_angle_test(test_dir, z_path, output_path):
    """
    Plot angle test: 3D EE trajectories (left) + z embeddings (right).

    Episode ``i`` (``episode<N>`` subdirectories sorted by N) and row ``i`` of
    the z array are drawn in the same color.  A legend maps each color to the
    episode's angle, and episodes whose metadata reports failure are marked
    "(failed)" in that legend.

    Args:
        test_dir: directory holding ``episode<N>`` subdirectories, each with a
            ``metadata.npy`` dict and either ``ee_trajectory.npy`` or
            ``states.npy``.
        z_path: ``.npy`` array of embeddings; columns 0 and 1 of row ``i`` are
            plotted.  An optional ``angles.npy`` next to it supplies fallback
            angles for episodes whose metadata has no ``'angle'`` key.
        output_path: file path the PNG is written to.
    """
    # Only real directories: a stray file named "episode..." would otherwise
    # be treated as an episode (or crash the numeric sort key).
    episode_dirs = sorted(
        (d for d in os.listdir(test_dir)
         if d.startswith('episode') and os.path.isdir(os.path.join(test_dir, d))),
        key=lambda x: int(x.replace('episode', '')))

    # Load z embeddings (row i is paired with episode i by position).
    z = np.load(z_path)
    angles_path = os.path.join(os.path.dirname(z_path), 'angles.npy')
    angles = np.load(angles_path) if os.path.exists(angles_path) else np.arange(0, 360, 30)
    if len(z) != len(episode_dirs):
        print(f"Warning: {len(episode_dirs)} episodes but {len(z)} z embeddings; "
              f"only the first {min(len(z), len(episode_dirs))} are paired")

    fig = plt.figure(figsize=(16, 7))
    try:
        # Left: 3D EE trajectories; right: z embeddings.
        ax3d = fig.add_subplot(1, 2, 1, projection='3d')
        ax_z = fig.add_subplot(1, 2, 2)

        colors = get_12_colors()

        for i, ep_dir in enumerate(episode_dirs):
            ep_path = os.path.join(test_dir, ep_dir)
            ee_pos = load_ee_trajectory(ep_path)
            metadata = np.load(os.path.join(ep_path, 'metadata.npy'), allow_pickle=True).item()
            fallback_angle = angles[i] if i < len(angles) else i * 30
            angle = metadata.get('angle', fallback_angle)
            success = metadata.get('success', True)
            label = f'{float(angle):g}°' + ('' if success else ' (failed)')

            color = colors[i % len(colors)]
            # Trajectory line, start point (circle) and end point (square).
            ax3d.plot(ee_pos[:, 0], ee_pos[:, 1], ee_pos[:, 2],
                      color=color, linewidth=2, alpha=0.8, label=label)
            ax3d.scatter(ee_pos[0, 0], ee_pos[0, 1], ee_pos[0, 2],
                         color=color, s=60, marker='o', edgecolors='black', linewidths=0.5)
            ax3d.scatter(ee_pos[-1, 0], ee_pos[-1, 1], ee_pos[-1, 2],
                         color=color, s=60, marker='s', edgecolors='black', linewidths=0.5)

            # Matching z embedding, if one exists for this episode.
            if i < len(z):
                ax_z.scatter(z[i, 0], z[i, 1], c=[color], s=150, alpha=0.8,
                             edgecolors='black', linewidths=1)

        # Format 3D plot
        ax3d.set_xlabel('X (m)', fontsize=11)
        ax3d.set_ylabel('Y (m)', fontsize=11)
        ax3d.set_zlabel('Z (m)', fontsize=11)
        ax3d.set_title('EE Trajectories (3D)', fontsize=13)
        if episode_dirs:
            ax3d.legend(title='angle', fontsize=8, loc='upper left', ncol=2)

        # Format z plot
        ax_z.set_xlabel('z[0]', fontsize=12)
        ax_z.set_ylabel('z[1]', fontsize=12)
        ax_z.set_title('Latent Space (z)', fontsize=13)
        ax_z.grid(True, alpha=0.3)
        ax_z.set_aspect('equal', adjustable='datalim')

        fig.suptitle('Angle Generalization Test\n'
                     f'(distance=1.0, position=0.5, {len(episode_dirs)} angles)', fontsize=14)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if loading an episode failed.
        plt.close(fig)
    print(f"Saved angle test plot to {output_path}")


def plot_distance_test(test_dir, z_path, output_path):
    """
    Plot distance test: 3D EE trajectories (left) + z embeddings (right).

    Episode ``i`` (``episode<N>`` subdirectories sorted by N) and row ``i`` of
    the z array are drawn in the same color.
    - Episodes whose metadata angle equals 0 use solid lines with
      circle/square markers.
    - Every other angle uses dashed lines with triangle markers and square z
      points; the figure titles assume that other angle is 180°.
    - A legend maps each color to the episode's distance and angle.

    Args:
        test_dir: directory holding ``episode<N>`` subdirectories, each with a
            ``metadata.npy`` dict and either ``ee_trajectory.npy`` or
            ``states.npy``.
        z_path: ``.npy`` array of embeddings; columns 0 and 1 of row ``i`` are
            plotted.
        output_path: file path the PNG is written to.
    """
    # Only real directories: a stray file named "episode..." would otherwise
    # be treated as an episode (or crash the numeric sort key).
    episode_dirs = sorted(
        (d for d in os.listdir(test_dir)
         if d.startswith('episode') and os.path.isdir(os.path.join(test_dir, d))),
        key=lambda x: int(x.replace('episode', '')))

    # Load z embeddings (row i is paired with episode i by position).
    z = np.load(z_path)
    if len(z) != len(episode_dirs):
        print(f"Warning: {len(episode_dirs)} episodes but {len(z)} z embeddings; "
              f"only the first {min(len(z), len(episode_dirs))} are paired")

    fig = plt.figure(figsize=(16, 7))
    try:
        # Left: 3D EE trajectories; right: z embeddings.
        ax3d = fig.add_subplot(1, 2, 1, projection='3d')
        ax_z = fig.add_subplot(1, 2, 2)

        colors = get_12_colors()

        for i, ep_dir in enumerate(episode_dirs):
            ep_path = os.path.join(test_dir, ep_dir)
            ee_pos = load_ee_trajectory(ep_path)
            metadata = np.load(os.path.join(ep_path, 'metadata.npy'), allow_pickle=True).item()

            angle = metadata.get('angle', 0)
            dist = metadata.get('distance', 0.5)
            success = metadata.get('success', True)
            label = f'd={float(dist):g}, {float(angle):g}°' + ('' if success else ' (failed)')

            color = colors[i % len(colors)]

            # Use different markers/line styles for angle 0 vs. any other angle
            is_zero_angle = angle == 0
            marker_start = 'o' if is_zero_angle else '^'
            marker_end = 's' if is_zero_angle else 'v'
            linestyle = '-' if is_zero_angle else '--'

            # Plot 3D EE trajectory
            ax3d.plot(ee_pos[:, 0], ee_pos[:, 1], ee_pos[:, 2],
                      color=color, linewidth=2, alpha=0.8, linestyle=linestyle, label=label)
            ax3d.scatter(ee_pos[0, 0], ee_pos[0, 1], ee_pos[0, 2],
                         color=color, s=60, marker=marker_start, edgecolors='black', linewidths=0.5)
            ax3d.scatter(ee_pos[-1, 0], ee_pos[-1, 1], ee_pos[-1, 2],
                         color=color, s=60, marker=marker_end, edgecolors='black', linewidths=0.5)

            # Matching z embedding, if one exists for this episode.
            if i < len(z):
                marker_z = 'o' if is_zero_angle else 's'
                ax_z.scatter(z[i, 0], z[i, 1], c=[color], s=150, alpha=0.8,
                             edgecolors='black', linewidths=1, marker=marker_z)

        # Format 3D plot
        ax3d.set_xlabel('X (m)', fontsize=11)
        ax3d.set_ylabel('Y (m)', fontsize=11)
        ax3d.set_zlabel('Z (m)', fontsize=11)
        ax3d.set_title('EE Trajectories (3D)\n(solid=angle 0°, dashed=angle 180°)', fontsize=12)
        if episode_dirs:
            ax3d.legend(title='distance, angle', fontsize=8, loc='upper left', ncol=2)

        # Format z plot
        ax_z.set_xlabel('z[0]', fontsize=12)
        ax_z.set_ylabel('z[1]', fontsize=12)
        ax_z.set_title('Latent Space (z)\n(circle=angle 0°, square=angle 180°)', fontsize=12)
        ax_z.grid(True, alpha=0.3)
        ax_z.set_aspect('equal', adjustable='datalim')

        fig.suptitle('Distance Generalization Test\n(angle=0°,180°, 6 distances each, position=0.5)', fontsize=14)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if loading an episode failed.
        plt.close(fig)
    print(f"Saved distance test plot to {output_path}")


def main(base_dir=None):
    """Render the angle- and distance-test figures into ``<base_dir>/plots``.

    Each test is plotted only when both its episode directory
    (``<name>_test``) and its embedding file (``z_<name>_test.npy``) exist
    under ``base_dir``.  Otherwise a skip message reports which one is missing.

    Args:
        base_dir: root directory of the generalization-test outputs.  When
            omitted, the original hard-coded path is used, so calling
            ``main()`` behaves exactly as before.
    """
    if base_dir is None:
        base_dir = '/scratch4/workspace/placeholder-hdp1/dppo/data/close_drawer/variation2/encoder/z2/generalize_test'

    plots_dir = os.path.join(base_dir, 'plots')
    os.makedirs(plots_dir, exist_ok=True)

    # (test name, plotting function); paths are derived from the name so the
    # two tests cannot drift apart in naming.
    jobs = (
        ('angle', plot_angle_test),
        ('distance', plot_distance_test),
    )
    for name, plot_fn in jobs:
        test_dir = os.path.join(base_dir, f'{name}_test')
        z_path = os.path.join(base_dir, f'z_{name}_test.npy')
        has_dir = os.path.exists(test_dir)
        has_z = os.path.exists(z_path)
        if has_dir and has_z:
            plot_fn(test_dir, z_path, os.path.join(plots_dir, f'{name}_test_ee_and_z.png'))
        else:
            print(f"Skipping {name} test: dir={has_dir}, z={has_z}")

    print(f"\nPlots saved to {plots_dir}")


if __name__ == "__main__":
    # Render the figures only when run as a script, never on import.
    main()
