import os
import json
import cv2
from tqdm import tqdm  # Import tqdm for progress tracking

# Bounding-box colours per object class, in OpenCV's BGR channel order.
_BBOX_COLORS = {
    "car": (0, 0, 255),  # Red
    "pedestrian": (255, 0, 0),  # Blue
    "person_sitting": (255, 0, 0),  # Blue
    "cyclist": (0, 255, 0),  # Green
    "van": (0, 165, 255),  # Orange
}
_DEFAULT_BBOX_COLOR = (255, 255, 255)  # White, for classes not listed above

# Height in pixels of the black bars added above and below each image.
_BLACK_BOX_HEIGHT = 50


def _reference_image_size(data, images_dir, image_ext):
    """Return ``(height, width)`` of the first readable image referenced by ``data``.

    Returns ``None`` when none of the referenced images can be read.
    """
    for image_id in data:
        image = cv2.imread(os.path.join(images_dir, f"{image_id}{image_ext}"))
        if image is not None:
            return image.shape[:2]
    return None


def _draw_object_box(frame, obj):
    """Draw ``obj``'s ``bbox_2d`` ([x1, y1, x2, y2]) on ``frame`` in place, coloured by class."""
    x1, y1, x2, y2 = (int(v) for v in obj["bbox_2d"][:4])
    color = _BBOX_COLORS.get(obj["class"], _DEFAULT_BBOX_COLOR)
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)


def _put_text(frame, text, y):
    """Draw one line of small white text on ``frame`` with its bottom-left corner at (10, y)."""
    cv2.putText(frame, text, (10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                (255, 255, 255), 1, cv2.LINE_AA)


def _render_pair_frame(image, obj1, obj2, frame_number, width, height):
    """Build one annotated video frame for the object pair ``(obj1, obj2)``.

    ``image`` itself is left untouched. Both boxes are drawn on a copy, which
    is then resized to ``(width, height)`` if necessary: ``cv2.VideoWriter``
    only accepts frames of the size it was opened with and ignores others.
    Black bars are then added above and below, carrying the depth text.
    """
    frame = image.copy()
    _draw_object_box(frame, obj1)
    _draw_object_box(frame, obj2)

    # Resize after drawing so the boxes (given in original pixel coordinates)
    # are scaled together with the image.
    if tuple(frame.shape[:2]) != (height, width):
        frame = cv2.resize(frame, (width, height))

    frame = cv2.copyMakeBorder(frame, _BLACK_BOX_HEIGHT, _BLACK_BOX_HEIGHT, 0, 0,
                               cv2.BORDER_CONSTANT, value=(0, 0, 0))

    name1, depth1 = obj1["class"], obj1["closest_depth"]
    name2, depth2 = obj2["class"], obj2["closest_depth"]
    depth_diff = abs(depth1 - depth2)

    # Two lines of text in the top bar, 20 px apart.
    top_lines = (
        f"Frame {frame_number}",
        f"Depth Diff: {depth_diff:.2f} | {name1} - {name2}",
    )
    for i, line in enumerate(top_lines):
        _put_text(frame, line, 20 + 20 * i)

    # One line of text 30 px into the bottom bar.
    bottom_text = (f"{name1}'s Closest Depth: {depth1:.2f} | "
                   f"{name2}'s Closest Depth: {depth2:.2f}")
    _put_text(frame, bottom_text, height + _BLACK_BOX_HEIGHT + 30)
    return frame


def create_video_from_json(json_path, images_dir, output_video_path, fps=1,
                           image_ext=".png", max_depth=None):
    """Render one video frame per object pair listed in a pairs JSON file.

    The JSON maps each image id to a dict with an ``"objects"`` list and a
    ``"pairs"`` list. Each object has ``"id"``, ``"class"``, ``"bbox_2d"``
    (``[x1, y1, x2, y2]``) and ``"closest_depth"`` fields. Each pair is
    ``[id_a, id_b]``, referring to object ids. For every pair, the image
    ``<image_id><image_ext>`` is loaded from ``images_dir`` and both boxes
    are drawn on it. Black bars showing the two depths and their absolute
    difference are added, and the result is written as one video frame.

    All frames share the size of the first readable image referenced by the
    JSON; other images are resized to match.

    Args:
        json_path: Path to the JSON file described above.
        images_dir: Directory containing the images.
        output_video_path: Output video path (encoded with the ``mp4v`` FourCC).
        fps: Output frame rate.
        image_ext: Extension appended to each image id to form its file name.
        max_depth: If given, skip pairs where *both* objects' ``closest_depth``
            exceed this value. ``None`` (default) keeps every pair.

    Raises:
        FileNotFoundError: If no image referenced by the JSON can be read.
        OSError: If the video writer cannot be opened.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    size = _reference_image_size(data, images_dir, image_ext)
    if size is None:
        raise FileNotFoundError(
            f"None of the {len(data)} images referenced by {json_path} "
            f"could be read from {images_dir}"
        )
    img_height, img_width = size

    # Each video frame is the image plus one black bar above and one below.
    frame_width = img_width
    frame_height = img_height + 2 * _BLACK_BOX_HEIGHT

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps,
                                   (frame_width, frame_height))
    if not video_writer.isOpened():
        raise OSError(f"Could not open video writer for {output_video_path}")

    frame_number = 0
    try:
        for image_id, frame_data in tqdm(data.items(), desc="Processing frames"):
            image_name = f"{image_id}{image_ext}"
            image = cv2.imread(os.path.join(images_dir, image_name))
            if image is None:
                # tqdm.write keeps the progress bars intact, unlike print.
                tqdm.write(f"Image {image_name} not found. Skipping...")
                continue

            objects = {obj["id"]: obj for obj in frame_data.get("objects", [])}
            pairs = frame_data.get("pairs", [])

            for pair in tqdm(pairs, desc=f"Processing pairs in {image_name}", leave=False):
                obj1 = objects.get(pair[0])
                obj2 = objects.get(pair[1])
                if obj1 is None or obj2 is None:
                    tqdm.write(f"Object IDs {pair[0]} or {pair[1]} not found in objects. Skipping pair...")
                    continue

                if (max_depth is not None
                        and obj1["closest_depth"] > max_depth
                        and obj2["closest_depth"] > max_depth):
                    continue

                frame = _render_pair_frame(image, obj1, obj2, frame_number,
                                           img_width, img_height)
                video_writer.write(frame)
                frame_number += 1
    finally:
        # Always finalize the container, even if processing fails midway.
        video_writer.release()

    print(f"Video saved to {output_video_path}")

# Example usage. These paths are hard-coded; adjust them for your environment.
json_path = "/fs/ess/PAS2099/sooyoung/perception_system_v2_local/kitti_analysis/FINAL/filtered_label_with_pairs.json"
images_dir = "/fs/scratch/PAS2099/dataset/kitti_obj3d/training/image_2"
output_video_path = "output_video.mp4"

# Only render when run as a script, so importing this module has no heavy side effects.
if __name__ == "__main__":
    create_video_from_json(json_path, images_dir, output_video_path)