import os
import json
import random
import shortuuid
import cv2
from tqdm import tqdm

# Display order of the answer options. The 1-based position of an emotion in
# this tuple is the option number used in the prompt and in the answers.
_EMOTION_ORDER = ("happy", "sad", "angry", "fear", "surprise", "neutral", "disgust")

# Question attached to every sample; its numbering must match _EMOTION_ORDER.
_EMOTION_PROMPT = "<image> Which of the following best describes the person's emotion in the red box? 1. happy , 2. sad , 3. angry , 4. fear , 5. surprise , 6. neutral , 7. disgust"


def _export_boxed_images(entries, image_root, output_dir, desc):
    """Write a red-boxed copy of each usable entry's image into ``output_dir``.

    Entries with an unknown source tag, a missing or unreadable image, or a
    failed write are reported on stdout and skipped.

    Args:
        entries (list): Dataset entries. ``entry[0]`` is the image path
            relative to the source's image folder, ``entry[1:5]`` is the box
            as (top, left, right, bottom) and ``entry[-1]`` is the source tag
            ("expw" or "rafdb").
        image_root (str): Root directory containing the source images.
        output_dir (str): Directory that receives the annotated JPEG files.
        desc (str): Label shown on the progress bar.

    Returns:
        list[str]: The generated id of every image that was actually written.
    """
    written_ids = []
    for entry in tqdm(entries, desc=desc):
        # Each source keeps its images in its own sub-folder of image_root.
        source = entry[-1]
        if source == "expw":
            image_path = os.path.join(image_root, "Expw_original", "output", entry[0])
        elif source == "rafdb":
            image_path = os.path.join(image_root, "RAF", "data", entry[0])
        else:
            print(f"Unknown source for entry: {entry}, skipping.")
            continue

        if not os.path.exists(image_path):
            print(f"Image not found: {image_path}, skipping.")
            continue

        # cv2.imread signals an unreadable file by returning None, not by raising.
        image = cv2.imread(image_path)
        if image is None:
            print(f"Failed to load image: {image_path}, skipping.")
            continue

        # The box is stored as top, left, right, bottom, while cv2.rectangle
        # expects (x, y) corner points. The colour tuple is red in BGR order.
        top, left, right, bottom = entry[1:5]
        cv2.rectangle(image, (int(left), int(top)), (int(right), int(bottom)), (0, 0, 255), 2)

        unique_id = shortuuid.uuid()
        output_image_path = os.path.join(output_dir, f"{unique_id}.jpg")
        # Only record the sample if the file really landed on disk, so the
        # JSON files never reference an image that does not exist.
        if not cv2.imwrite(output_image_path, image):
            print(f"Failed to write image: {output_image_path}, skipping.")
            continue

        written_ids.append(unique_id)
    return written_ids


def create_dataset(json_file_path, image_root, output_root):
    """
    Reorganizes the dataset into training and validation sets, draws red bounding boxes,
    and generates the required JSON files.

    The input file is read line by line; every non-blank line must be a JSON
    object mapping an emotion name to a list of entries. Entries of the same
    emotion are merged across lines, shuffled, and split 70% / 30% into
    training and validation. The annotated images are written to
    ``<output_root>/train/images`` and ``<output_root>/val/images``, and the
    records to ``train/train.json``, ``val/val.json`` and ``val/val_ans.json``.
    The statistics printed at the end count only the samples that were
    actually written, not the ones skipped along the way.

    Args:
        json_file_path (str): Path to the JSON file containing the dataset.
        image_root (str): Root directory containing the images.
        output_root (str): Root directory to save the reorganized dataset.

    Returns:
        None

    Raises:
        ValueError: If the input contains an emotion that is not one of the
            answer options. Raised before any image is written.
    """
    # Ensure output directories exist
    train_images_dir = os.path.join(output_root, "train", "images")
    val_images_dir = os.path.join(output_root, "val", "images")
    os.makedirs(train_images_dir, exist_ok=True)
    os.makedirs(val_images_dir, exist_ok=True)

    # Read the line-delimited JSON file, merging the batches per emotion
    data = {}
    with open(json_file_path, 'r', encoding="utf-8") as file:
        for line in file:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no batch.
                continue
            for emotion, entries in json.loads(line).items():
                data.setdefault(emotion, []).extend(entries)

    # Fail fast: an emotion without an option number cannot be labelled, and
    # discovering that halfway through would leave orphaned images behind.
    unknown = [emotion for emotion in data if emotion not in _EMOTION_ORDER]
    if unknown:
        raise ValueError(f"Unsupported emotion label(s) in {json_file_path}: {unknown}")

    # Initialize JSON data structures
    train_json = []
    val_json = []
    val_ans_json = []

    # Initialize counters for statistics
    train_stats = {emotion: 0 for emotion in data}
    val_stats = {emotion: 0 for emotion in data}

    # Process each emotion
    for emotion, entries in data.items():
        random.shuffle(entries)

        # Split into 70% train and 30% val
        split_index = int(len(entries) * 0.7)
        train_entries = entries[:split_index]
        val_entries = entries[split_index:]

        # Answer text is "<option number>. <emotion>", matching the prompt.
        answer = f"{_EMOTION_ORDER.index(emotion) + 1}. {emotion}"

        train_ids = _export_boxed_images(
            train_entries, image_root, train_images_dir,
            f"Processing training data for {emotion}",
        )
        for unique_id in train_ids:
            train_json.append({
                "id": unique_id,
                "image": f"images/{unique_id}.jpg",
                "conversations": [
                    {"from": "human", "value": _EMOTION_PROMPT},
                    {"from": "gpt", "value": answer},
                ],
            })

        val_ids = _export_boxed_images(
            val_entries, image_root, val_images_dir,
            f"Processing validation data for {emotion}",
        )
        for unique_id in val_ids:
            val_json.append({
                "question_id": unique_id,
                "image": f"images/{unique_id}.jpg",
                "category": "default",
                "text": _EMOTION_PROMPT,
            })
            val_ans_json.append({
                "question_id": unique_id,
                "prompt": _EMOTION_PROMPT,
                "text": answer,
                "answer_id": None,
                "model_id": None,
                "metadata": {},
            })

        # Count what was written, so skipped entries do not inflate the report.
        train_stats[emotion] += len(train_ids)
        val_stats[emotion] += len(val_ids)

    # Shuffle train JSON
    random.shuffle(train_json)

    # Shuffle questions and answers together so they stay aligned by index.
    # Unzipping with comprehensions (rather than zip(*pairs)) also works when
    # there are no validation samples at all.
    val_pairs = list(zip(val_json, val_ans_json))
    random.shuffle(val_pairs)
    val_json = [question for question, _ in val_pairs]
    val_ans_json = [answer_record for _, answer_record in val_pairs]

    # Save JSON files
    with open(os.path.join(output_root, "train", "train.json"), 'w', encoding="utf-8") as f:
        json.dump(train_json, f, indent=4)

    with open(os.path.join(output_root, "val", "val.json"), 'w', encoding="utf-8") as f:
        json.dump(val_json, f, indent=4)

    with open(os.path.join(output_root, "val", "val_ans.json"), 'w', encoding="utf-8") as f:
        json.dump(val_ans_json, f, indent=4)

    # Print dataset statistics
    print("\nDataset Statistics:")
    print(f"{'Emotion':<15}{'Train Count':<15}{'Val Count':<15}")
    print("-" * 45)
    total_train = 0
    total_val = 0
    for emotion, train_count in train_stats.items():
        val_count = val_stats[emotion]
        total_train += train_count
        total_val += val_count
        print(f"{emotion:<15}{train_count:<15}{val_count:<15}")
    print("-" * 45)
    print(f"{'Total':<15}{total_train:<15}{total_val:<15}")

    print("\nDataset creation complete.")

if __name__ == "__main__":
    # Script entry point: build the emotion dataset from the fixed locations below.
    # Path to the JSON file
    json_file_path = "/fs/ess/PAS2099/sooyoung/vfm_dataset/emotion/data/processed_data.json"
    # Root directory containing the images
    image_root = "/fs/ess/PAS2099/sooyoung/vfm_dataset/emotion/data"
    # Root directory to save the reorganized dataset
    output_root = "/fs/scratch/PAS2099/vfm/emotion"

    try:
        create_dataset(json_file_path, image_root, output_root)
    except Exception as e:
        print(f"Error: {e}")
        # Exit non-zero so a shell or batch scheduler does not record a
        # failed run as a success.
        raise SystemExit(1)