import os

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch

from tasks import *

# Paper-sized seaborn context, white background and a serif font family for
# every text element in the figure.
sns.set_context(context="paper", font_scale=0.68)
sns.set_style("white", {"font.family": "serif"})

# One row of four side-by-side panels, one per entry of `methods` below.
# NOTE(review): the panel count is hard-coded and must stay equal to
# len(methods), otherwise the plotting loop indexes past `axes`.
fig, axes = plt.subplots(1, 4, figsize=(32, 9))
# wspace is the horizontal gap as a fraction of the average axes width.
# 0.1 is *narrower* than Matplotlib's default (0.2), packing panels closer.
fig.subplots_adjust(wspace=0.1)

# Figure-wide title, positioned slightly below the top edge (y=0.9).
fig.suptitle(
    "Trajectory Distribution Across Iterations", fontsize=52, y=0.9, fontfamily="serif"
)

# Methods compared, left to right; each panel's index is shown as its
# iteration number in the plot titles.
methods = [
    "gemini",
    "gemini_few_shot_oracle",
    "gemini_iteration_2_oracle",
    "gemini_iteration_3_oracle",
]
task = "fridge"

# Create output directory if it doesn't exist
os.makedirs(f"vis/{task}", exist_ok=True)

# The environment is needed only to read the render camera's calibration.
# "EnvFridge-v0" is presumably registered by `from tasks import *` -- confirm.
env = gym.make(
    "EnvFridge-v0",
    render_mode="rgb_array",
    human_render_camera_configs=dict(shader_pack="rt", width=800, height=800),
)

camera = env.unwrapped.scene.human_render_cameras["render_camera"]
# [0] selects the first entry along the leading dimension (presumably the
# environment/batch index -- confirm). clone() guarantees the matrices own
# their storage, so they stay valid after the environment is closed.
extrinsics = camera.camera.get_extrinsic_matrix().cpu()[0].float().clone()
intrinsics = camera.camera.get_intrinsic_matrix().cpu()[0].float().clone()

# Release simulator/renderer resources now that the calibration is copied out;
# nothing below uses `env` again.
env.close()

# Colour shared by every trajectory line (warm orange).
TRAJECTORY_COLOR = "#FF9E4A"
# Upper bound on the number of trajectories drawn per keypoint in each panel.
MAX_TRAJECTORIES = 30
# Keypoints whose trajectories are overlaid on every panel, in drawing order.
TRAJECTORY_KEYPOINTS = ("handle", "hand")


def _project_to_pixels(points, extrinsics, intrinsics):
    """Project 3-D points into pixel coordinates with a pinhole camera model.

    Args:
        points: Tensor whose last dimension holds (x, y, z) coordinates.
        extrinsics: Matrix applied to the homogeneous points (x, y, z, 1);
            its output must have 3 components for ``intrinsics`` to apply
            (presumably a 3x4 world-to-camera matrix -- confirm).
        intrinsics: Camera intrinsic matrix (presumably 3x3 -- confirm).

    Returns:
        Tensor with the same leading dimensions as ``points`` and a last
        dimension of size 2 holding (u, v) pixel coordinates.

    NOTE(review): points with non-positive camera depth are not filtered and
    would project to meaningless pixel positions.
    """
    homogeneous = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
    in_camera = homogeneous @ extrinsics.T
    projected = in_camera @ intrinsics.T
    # Perspective divide by the third (depth-scaled) component.
    return projected[..., :2] / projected[..., 2:3]


for idx, method in enumerate(methods):
    run_dir = f"data/{method}/{task}/train1"
    # weights_only=False unpickles arbitrary Python objects; only load files
    # from a trusted source.
    trajectories = torch.load(f"{run_dir}/trajectories.pt", weights_only=False)
    keypoints = torch.load(f"{run_dir}/keypoints.pt", weights_only=False)

    ax = axes[idx]

    # Background: first image stored with the "handle" keypoint (presumably a
    # PIL image -- confirm), brightened by 1.5x and clipped back into uint8.
    img = np.array(keypoints["handle"].imgs[0])
    img = np.clip(img * 1.5, 0, 255).astype(np.uint8)
    ax.imshow(img)

    # Zoom in from the full 0-800 extent to x in [150, 650], y in [100, 600].
    # The y limits are given bottom-first so the image keeps its top-left origin.
    ax.set_xlim(150, 650)
    ax.set_ylim(600, 100)

    # Panel label placed below the axes (negative y in axes coordinates).
    ax.set_title(f"Iteration {idx}", pad=25, fontsize=40, fontfamily="serif", y=-0.15)

    for i in range(MAX_TRAJECTORIES):
        for name in TRAJECTORY_KEYPOINTS:
            traj = trajectories[name]
            # Skip instead of raising IndexError when fewer trajectories exist.
            if i >= len(traj):
                continue
            points = traj[i, :, :3].cpu().float()
            pixels = _project_to_pixels(points, extrinsics, intrinsics)

            # Plot trajectory as a line without point markers.
            ax.plot(
                pixels[:, 0].numpy(),
                pixels[:, 1].numpy(),
                "-",
                color=TRAJECTORY_COLOR,
                linewidth=4,
                alpha=0.7,
            )

    ax.axis("off")

# Write the whole figure as a tightly cropped PDF, then free it.
# NOTE(review): the file lands in the working directory, not in the
# vis/{task} directory created earlier -- confirm which location is intended.
fig.savefig("trajectories_comparison.pdf", bbox_inches="tight", pad_inches=0)
plt.close(fig)
