import argparse
import json
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.patches as patches
import numpy as np
from matplotlib.animation import FFMpegWriter
from tqdm import tqdm
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from utils import load_instance_data

def rotate_point(x, y, theta):
    """Rotate the point ``(x, y)`` about the origin by ``theta`` radians.

    Applies the standard 2-D rotation matrix (counter-clockwise for positive
    ``theta``) and returns the rotated ``(x, y)`` pair. NumPy arrays are
    accepted as well, in which case the arithmetic broadcasts element-wise.
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    return x * cos_t - y * sin_t, x * sin_t + y * cos_t

def _frame_index(timesteps, frame):
    """Return the position of ``frame`` within ``timesteps``, or ``None`` if absent.

    ``timesteps`` may be a Python list or a NumPy array.
    """
    if isinstance(timesteps, list):
        return timesteps.index(frame) if frame in timesteps else None
    matches = np.where(np.asarray(timesteps) == frame)[0]
    return matches[0] if matches.size else None


def _box_corners(x, y, yaw, length, width):
    """Return the 4 corners of a ``length`` x ``width`` box centred at ``(x, y)``.

    Before rotation, ``length`` spans the box's local x-axis and ``width`` its
    local y-axis; the box is then rotated by ``yaw`` radians about its centre.
    """
    local_corners = [
        (-length / 2, width / 2),
        (-length / 2, -width / 2),
        (length / 2, -width / 2),
        (length / 2, width / 2),
    ]
    corners = []
    for cx, cy in local_corners:
        rx, ry = rotate_point(cx, cy, yaw)
        corners.append((rx + x, ry + y))  # shift from box-local to world coordinates
    return corners


def animate_instance_trajectory(historical_data, forecasted_data, predicted_data, output_file="instance_trajectory.mp4"):
    """Render per-timestep bounding boxes of tracked instances and save them as a video.

    Each data argument maps ``instance_id -> instance``, where ``instance``
    holds parallel ``"timestep"``, ``"translation"`` and ``"rotation"``
    sequences; historical instances additionally need ``"size"``. Frame ``t``
    draws every instance that has an entry at timestep ``t``:

    * predicted boxes in orange, ground-truth forecast boxes in green and
      historical boxes in blue, all sized from the instance's last historical
      size;
    * predicted instances absent from the history ("hallucinated") in red,
      with a default size;
    * forecast instances absent from the history are not drawn.

    Args:
        historical_data: Past observations per instance.
        forecasted_data: Ground-truth future observations per instance. Its
            largest timestep determines the number of frames.
        predicted_data: Predicted observations per instance; may be empty
            (an empty dict or list) when there are no predictions.
        output_file: Path of the video written with FFmpeg.

    Raises:
        ValueError: If there are no positions to plot, or ``forecasted_data``
            contains no timesteps.
    """
    padding = 10  # extra margin (in data units) around the plotted positions

    # (data, base colour, is-prediction); listed in drawing order, so later
    # sources are drawn on top of earlier ones.
    sources = [
        (predicted_data, 'orange', True),
        (forecasted_data, 'green', False),
        (historical_data, 'blue', False),
    ]

    # Gather positions from *every* source so every drawn box fits in view
    # (including hallucinated predictions and historical parts of tracks).
    all_translations = []
    for data_file, _, _ in sources:
        for instance_id in data_file:
            all_translations.extend(data_file[instance_id]["translation"])
    if not all_translations:
        raise ValueError("No instance positions to plot.")

    min_x = min(pos[0] for pos in all_translations)
    max_x = max(pos[0] for pos in all_translations)
    min_y = min(pos[1] for pos in all_translations)
    max_y = max(pos[1] for pos in all_translations)

    # Animate up to the latest timestep of any forecast instance, not just the
    # first one, so longer tracks are not truncated.
    forecast_ends = [
        np.max(forecasted_data[instance_id]["timestep"])
        for instance_id in forecasted_data
        if len(forecasted_data[instance_id]["timestep"]) > 0
    ]
    if not forecast_ends:
        raise ValueError("forecasted_data contains no timesteps to animate.")
    n_frames = 1 + int(max(forecast_ends))

    fig, ax = plt.subplots(figsize=(10, 8))

    def reset_axes():
        """Apply the fixed limits, labels and aspect ratio to ``ax``."""
        ax.set_xlim(min_x - padding, max_x + padding)
        ax.set_ylim(min_y - padding, max_y + padding)
        ax.set_xlabel('X Position')
        ax.set_ylabel('Y Position')
        ax.set_aspect('equal', 'box')

    reset_axes()

    def update_frame(frame):
        """Redraw ``ax`` with every instance that has an entry at timestep ``frame``."""
        ax.clear()  # clear() also drops limits and labels, so re-apply them
        reset_axes()

        for data_file, base_color, is_prediction in sources:
            for instance_id in data_file:
                instance = data_file[instance_id]
                idx = _frame_index(instance["timestep"], frame)
                if idx is None:
                    continue  # no observation of this instance at this frame

                translation = instance["translation"][idx]
                rotation = instance["rotation"][idx]

                if instance_id in historical_data:
                    # Size comes from the last historical observation.
                    # NOTE(review): assumes size[1] is width and size[2] is length; confirm against the data format.
                    historical_size = historical_data[instance_id]["size"][-1]
                    width, length = historical_size[1], historical_size[2]
                    color = base_color
                elif is_prediction:
                    # Hallucinated instance (predicted but absent from the history): default size, red.
                    # NOTE(review): width=4.0 / length=2.0 looks swapped for a typical vehicle; confirm.
                    width, length = 4.0, 2.0
                    color = 'red'
                else:
                    # Non-scored: present in the ground-truth forecast but not in the history.
                    continue

                x, y = translation[0], translation[1]
                yaw = rotation[2]  # third rotation component is used as the heading angle

                bbox = patches.Polygon(
                    _box_corners(x, y, yaw, length, width),
                    closed=True, edgecolor=color, facecolor='none', linewidth=2, alpha=0.5,
                )
                ax.add_patch(bbox)
                # Mark the box centre for clarity.
                ax.scatter(x, y, color=color, zorder=5, alpha=0.5)

        return []

    ani = animation.FuncAnimation(
        fig, update_frame, frames=tqdm(range(n_frames)),
        interval=500, blit=False,
    )

    # 2 fps matches the 500 ms interval used for interactive playback.
    writer = FFMpegWriter(
        fps=2, metadata=dict(artist='Me'), bitrate=1800,
        extra_args=['-vcodec', 'libx264', '-pix_fmt', 'yuv420p'],
    )

    try:
        ani.save(output_file, writer=writer, dpi=300)
    finally:
        plt.close(fig)  # release the figure so repeated calls don't accumulate open figures
    print(f"Video saved to {output_file}")

def main():
    """Parse command-line arguments, load the data files and render the animation.

    The video is written to ``--output_dir``. It is named after the prediction
    file when ``-P`` is given, otherwise after the forecast file, using the
    file's basename up to its first ``.``.
    """
    parser = argparse.ArgumentParser(description="Visualize vehicle forecasting")
    parser.add_argument('-H', '--history_file', type=str, required=True, help="Path to the historical data file")
    parser.add_argument('-F', '--forecast_file', type=str, required=True, help="Path to the ground-truth forecast data file")
    parser.add_argument('-P', '--prediction_file', type=str, default=None, help="Path to the predicted forecast data file")
    parser.add_argument('-O', '--output_dir', type=str, required=True, help="Path to the directory to save output video")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    historical_data, _ = load_instance_data(args.history_file, numpy=True)
    forecasted_data, _ = load_instance_data(args.forecast_file, numpy=True)
    if args.prediction_file is not None:
        predicted_data, _ = load_instance_data(args.prediction_file, numpy=True, historical_data=historical_data)
    else:
        # No predictions: use an empty mapping. The original `... else []`
        # crashed because `[]` cannot be unpacked into two names.
        predicted_data = {}

    source_file = args.prediction_file if args.prediction_file is not None else args.forecast_file
    stem = os.path.basename(source_file).split('.')[0]
    output_file = os.path.join(args.output_dir, f"{stem}.mp4")

    # Generate the animation and save it as a video.
    animate_instance_trajectory(historical_data, forecasted_data, predicted_data, output_file=output_file)


if __name__ == "__main__":
    main()
