import os
import json
import cv2
import numpy as np
from tqdm import tqdm

# Sub-directories (relative to ``image_root``) that hold each source's images.
_SOURCE_IMAGE_DIRS = {
    "expw": ("Expw_original", "output"),
    "rafdb": ("RAF", "data"),
}

# Bounding-box colours per source, in OpenCV's BGR channel order.
_SOURCE_BOX_COLORS = {
    "expw": (0, 0, 255),   # red
    "rafdb": (255, 0, 0),  # blue
}


def _read_json_lines(json_file_path):
    """Merge a JSON-Lines file of ``{emotion: [entries, ...]}`` objects into one dict.

    Entries for the same emotion appearing on different lines are concatenated
    in file order. Blank (whitespace-only) lines are ignored.
    """
    data = {}
    with open(json_file_path, "r", encoding="utf-8") as file:
        for line in file:
            line = line.strip()
            if not line:
                continue
            batch = json.loads(line)
            for emotion, entries in batch.items():
                data.setdefault(emotion, []).extend(entries)
    return data


def _load_annotated_image(entry, image_root):
    """Load the image for ``entry`` and draw its source-coloured box; None if unavailable."""
    source = entry[-1]
    image_path = os.path.join(image_root, *_SOURCE_IMAGE_DIRS[source], entry[0])

    # Check existence first so missing files never reach OpenCV.
    if not os.path.exists(image_path):
        print(f"Image not found: {image_path}, skipping.")
        return None

    image = cv2.imread(image_path)
    if image is None:
        print(f"Failed to load image: {image_path}, skipping.")
        return None

    # Assumes entries look like [filename, top, left, right, bottom, ..., source];
    # TODO confirm the coordinate order against the code that writes the JSON.
    top, left, right, bottom = entry[1:5]
    cv2.rectangle(
        image,
        (int(left), int(top)),
        (int(right), int(bottom)),
        _SOURCE_BOX_COLORS[source],
        2,
    )
    return image


def _pad_to_size(image, width, height):
    """Centre ``image`` on a black canvas of ``width`` x ``height`` pixels."""
    img_height, img_width = image.shape[:2]
    top_pad = (height - img_height) // 2
    left_pad = (width - img_width) // 2
    return cv2.copyMakeBorder(
        image,
        top_pad,
        height - img_height - top_pad,
        left_pad,
        width - img_width - left_pad,
        cv2.BORDER_CONSTANT,
        value=[0, 0, 0],
    )


def _write_video(frames, video_path, fps):
    """Write equally sized ``frames`` to an mp4 at ``video_path``; return False if it can't open."""
    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
    if not writer.isOpened():
        return False
    try:
        for frame in frames:
            writer.write(frame)
    finally:
        # Always release so the container is finalized even if a write fails.
        writer.release()
    return True


def create_videos_from_json(json_file_path, image_root, output_dir, fps=1, num_per_source=10):
    """Build one slideshow video per emotion from a JSON-Lines annotation file.

    Each line of ``json_file_path`` is a JSON object mapping an emotion name to a
    list of entries. Entries for the same emotion across lines are merged. An
    entry's first element is an image file name and its last element is the
    source (``"expw"`` or ``"rafdb"``).

    For every emotion, the first ``num_per_source`` expw entries and the first
    ``num_per_source`` rafdb entries are loaded. A bounding box is drawn on each
    image (red for expw, blue for rafdb). The images are centred on black
    canvases sized to the largest width/height among them, and the frames are
    written to ``<output_dir>/<emotion>.mp4``.

    Emotions with fewer than ``num_per_source`` entries from either source are
    skipped. Individual images that are missing or unreadable are skipped. An
    emotion is skipped entirely if none of its selected images load. Each skip
    prints a message.

    Args:
        json_file_path (str): Path to the JSON-Lines file.
        image_root (str): Root directory containing the images.
        output_dir (str): Directory to save the output videos (created if needed).
        fps (int): Frames per second for the videos.
        num_per_source (int): Images taken from each source per emotion (default 10).

    Returns:
        None

    Raises:
        ValueError: If ``num_per_source`` is less than 1.
        OSError: If the JSON file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if num_per_source < 1:
        raise ValueError(f"num_per_source must be >= 1, got {num_per_source}")

    data = _read_json_lines(json_file_path)
    os.makedirs(output_dir, exist_ok=True)

    for emotion, entries in data.items():
        expw_entries = [entry for entry in entries if entry[-1] == "expw"]
        rafdb_entries = [entry for entry in entries if entry[-1] == "rafdb"]

        if len(expw_entries) < num_per_source or len(rafdb_entries) < num_per_source:
            print(
                f"Skipping {emotion}: needs {num_per_source} images from each of "
                f"expw and rafdb (found {len(expw_entries)} expw, "
                f"{len(rafdb_entries)} rafdb)."
            )
            continue

        selected_entries = expw_entries[:num_per_source] + rafdb_entries[:num_per_source]
        images = [
            image
            for image in (_load_annotated_image(entry, image_root) for entry in selected_entries)
            if image is not None
        ]

        # Without any frames there is no valid frame size for the writer.
        if not images:
            print(f"Skipping {emotion}: none of the selected images could be loaded.")
            continue

        max_height = max(image.shape[0] for image in images)
        max_width = max(image.shape[1] for image in images)
        frames = [_pad_to_size(image, max_width, max_height) for image in images]

        video_path = os.path.join(output_dir, f"{emotion}.mp4")
        if not _write_video(frames, video_path, fps):
            print(f"Failed to open video writer for {video_path}, skipping.")
            continue
        print(f"Video saved for emotion: {emotion}")

if __name__ == "__main__":
    import sys
    import traceback

    # Path to the JSON file
    json_file_path = "/fs/ess/PAS2099/sooyoung/vfm_dataset/emotion/data/processed_data.json"
    # Root directory containing the images
    image_root = "/fs/ess/PAS2099/sooyoung/vfm_dataset/emotion/data"
    # Directory to save the output videos
    output_dir = "/fs/ess/PAS2099/sooyoung/vfm_dataset/emotion/videos"

    try:
        create_videos_from_json(json_file_path, image_root, output_dir, fps=1)
    except Exception as e:
        # Top-level boundary: report the error with its traceback on stderr and
        # exit non-zero so shell scripts / job schedulers can detect the failure.
        print(f"Error: {e}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)