import json
import math
import os
import random
import shutil
from collections import defaultdict

import cv2
import shortuuid
from tqdm import tqdm

def _allocate_quotas(counts, budget):
    """Split ``budget`` slots across groups as evenly as possible (water-filling).

    Each group receives at most ``counts[key]`` slots. Slots a small group cannot
    use are redistributed among groups that still have items left, so the total
    granted is ``min(max(budget, 0), sum(counts.values()))``.

    Args:
        counts: Mapping of group key -> number of available items.
        budget: Total number of slots to hand out (negative is treated as 0).

    Returns:
        Dict mapping every key of ``counts`` to its granted quota.
    """
    quotas = {key: 0 for key in counts}
    remaining = max(budget, 0)
    open_keys = [key for key, count in counts.items() if count > 0]
    while remaining > 0 and open_keys:
        share = remaining // len(open_keys)
        if share == 0:
            # Fewer slots than open groups: one extra slot each, in key order.
            for key in open_keys[:remaining]:
                quotas[key] += 1
            break
        still_open = []
        for key in open_keys:
            grant = min(share, counts[key] - quotas[key])
            quotas[key] += grant
            remaining -= grant
            if quotas[key] < counts[key]:
                still_open.append(key)
        open_keys = still_open
    return quotas


def _crop_bounds(center_x, crop_width, img_width):
    """Return ``(x1, x2)`` of a ``crop_width``-wide window centred on ``center_x``.

    The window is shifted (not shrunk) to stay inside ``[0, img_width)``; it is
    only narrower than ``crop_width`` when the image itself is narrower.
    """
    crop_x1 = max(0, min(center_x - crop_width // 2, img_width - crop_width))
    crop_x2 = min(img_width, crop_x1 + crop_width)
    return crop_x1, crop_x2


def _build_qa(crop):
    """Return the ``(question, answer)`` pair for one saved crop record.

    obj1 is drawn in red and obj2 in blue; the answer names the object with the
    smaller ``closest_depth`` (a tie resolves to "2. blue").
    """
    question = f"<image> Estimate the real-world distances between the objects in this image. Which object is closer to the camera, the {crop['obj1']['class']} (highlighted by a red box) or the {crop['obj2']['class']} (highlighted by a blue box) to the camera? choose one option from below: 1. red , 2. blue"
    answer = "1. red" if crop["obj1"]["closest_depth"] < crop["obj2"]["closest_depth"] else "2. blue"
    return question, answer


def process_and_split_images(json_path, temp_path, root_path, image_dir, max_images, seed=0):
    """Crop frames around object pairs, draw their boxes, and write a train/val QA dataset.

    For every pair listed under a frame's ``"pairs"`` key in the JSON at
    ``json_path``, the image ``<image_dir>/<image_id>.png`` is cropped
    horizontally around both objects; obj1's ``bbox_2d`` is drawn in red and
    obj2's in blue, and a "which object is closer" question is generated whose
    answer compares the objects' ``closest_depth`` values.

    A pair is skipped when either box touches the image border, when the crop
    width falls outside ``[img_height, 2 * img_height]``, or when the two boxes
    do not both fit inside the crop. Surviving crops are grouped by the sorted
    pair of object classes, shuffled with ``seed``, balanced across groups so at
    most ``max_images`` remain, and split 80/20 per group into train and val.

    Args:
        json_path: JSON mapping image id -> {"objects": [...], "pairs": [...]}.
            Objects are read for "id", "bbox_2d", "class" and "closest_depth";
            ``bbox_2d`` is assumed to be [x1, y1, x2, y2] in pixels.
        temp_path: Scratch directory for crops; removed with ``shutil.rmtree``
            at the end.
        root_path: Output root; receives train/images, train/train.json,
            val/images, val/val.json and val/val_ans.json.
        image_dir: Directory containing the ``<image_id>.png`` frames.
        max_images: Upper bound on the total number of crops kept.
        seed: Seed for the per-group shuffle so selection and split are random
            but reproducible.
    """
    rng = random.Random(seed)

    # Ensure temp and output directories exist
    os.makedirs(temp_path, exist_ok=True)
    train_images_dir = os.path.join(root_path, "train/images")
    val_images_dir = os.path.join(root_path, "val/images")
    os.makedirs(train_images_dir, exist_ok=True)
    os.makedirs(val_images_dir, exist_ok=True)

    with open(json_path, 'r', encoding="utf-8") as f:
        data = json.load(f)

    # Valid crops grouped by sorted (class, class) pair, plus filtering stats.
    valid_crops = defaultdict(list)
    total_images = 0
    filtered_bbox_edge = 0
    filtered_invalid_crop = 0
    filtered_bbox_outside_crop = 0

    for image_id, frame_data in tqdm(data.items(), desc="Processing images"):
        total_images += 1
        image_path = os.path.join(image_dir, f"{image_id}.png")
        image = cv2.imread(image_path)

        if image is None:
            print(f"Image {image_path} not found. Skipping...")
            continue

        img_height, img_width = image.shape[:2]

        # Index objects by id once per frame; setdefault keeps the first object
        # for a duplicated id, matching a first-match linear search.
        objects_by_id = {}
        for obj in frame_data.get("objects", []):
            objects_by_id.setdefault(obj["id"], obj)

        for pair in frame_data.get("pairs", []):
            obj1 = objects_by_id.get(pair[0])
            obj2 = objects_by_id.get(pair[1])
            if obj1 is None or obj2 is None:
                continue

            obj1_bbox = obj1["bbox_2d"]
            obj2_bbox = obj2["bbox_2d"]

            # Skip pairs where either box touches the image border.
            if any(b[0] <= 0 or b[1] <= 0 or b[2] >= img_width or b[3] >= img_height
                   for b in (obj1_bbox, obj2_bbox)):
                filtered_bbox_edge += 1
                continue

            # Horizontal extent covering BOTH boxes (a wide left box may reach
            # further right than the other box's x2).
            span_x1 = int(min(obj1_bbox[0], obj2_bbox[0]))
            span_x2 = int(max(obj1_bbox[2], obj2_bbox[2]))
            center_x = (span_x1 + span_x2) // 2

            # Width between the image height and 2x height, capped by the image width.
            crop_width = max(img_height, min(img_height * 2, img_width))
            crop_x1, crop_x2 = _crop_bounds(center_x, crop_width, img_width)
            crop = image[:, crop_x1:crop_x2]

            if crop.shape[1] < img_height or crop.shape[1] > img_height * 2:
                filtered_invalid_crop += 1
                continue

            # Both boxes must be fully inside the crop, otherwise the red/blue
            # highlight would be cut off at the crop border.
            if span_x1 < crop_x1 or span_x2 >= crop_x2:
                filtered_bbox_outside_crop += 1
                continue

            # The slice above is a view of `image`; copy before drawing so other
            # pairs from the same frame are not drawn over.
            frame = crop.copy()
            # OpenCV colours are BGR: (0, 0, 255) is red for obj1, (255, 0, 0) blue for obj2.
            for bbox, color in ((obj1_bbox, (0, 0, 255)), (obj2_bbox, (255, 0, 0))):
                cv2.rectangle(frame, (int(bbox[0]) - crop_x1, int(bbox[1])),
                              (int(bbox[2]) - crop_x1, int(bbox[3])), color, 2)

            uid = shortuuid.uuid()
            crop_path = os.path.join(temp_path, f"{uid}.jpg")
            # imwrite reports failure via its return value; a missing file would
            # otherwise crash the later shutil.move.
            if not cv2.imwrite(crop_path, frame):
                print(f"Failed to write {crop_path}. Skipping...")
                continue

            pair_key = tuple(sorted([obj1["class"], obj2["class"]]))
            valid_crops[pair_key].append({
                "id": uid,
                "image": crop_path,
                "obj1": obj1,
                "obj2": obj2
            })

    # Shuffle each group so limiting and the split draw a random (seeded)
    # subset rather than whatever comes first in the JSON.
    for crops in valid_crops.values():
        rng.shuffle(crops)

    # Balance pair types: distribute max_images as evenly as possible, handing
    # slots unused by small groups to groups that still have crops.
    total_valid_images = sum(len(crops) for crops in valid_crops.values())
    if total_valid_images > max_images:
        quotas = _allocate_quotas({key: len(crops) for key, crops in valid_crops.items()}, max_images)
        for pair_type, quota in quotas.items():
            valid_crops[pair_type] = valid_crops[pair_type][:quota]
        total_valid_images = sum(len(crops) for crops in valid_crops.values())

    train_metadata = []
    val_questions = []
    val_answers = []
    pair_type_summary = defaultdict(lambda: {"train": 0, "test": 0})

    for pair_type, crops in valid_crops.items():
        # 80/20 split per pair type (crops were shuffled above).
        # NOTE(review): crops of different pairs from the same source frame can
        # land in both train and val; split by image id if that leakage matters.
        train_split = int(0.8 * len(crops))
        train_crops = crops[:train_split]
        val_crops = crops[train_split:]

        pair_type_summary[pair_type]["train"] += len(train_crops)
        pair_type_summary[pair_type]["test"] += len(val_crops)

        for crop in train_crops:
            shutil.move(crop["image"], os.path.join(train_images_dir, f"{crop['id']}.jpg"))
            question, answer = _build_qa(crop)
            train_metadata.append({
                "id": crop["id"],
                "image": f"images/{crop['id']}.jpg",
                "conversations": [
                    {"from": "human", "value": question},
                    {"from": "gpt", "value": answer}
                ]
            })

        for crop in val_crops:
            shutil.move(crop["image"], os.path.join(val_images_dir, f"{crop['id']}.jpg"))
            question, answer = _build_qa(crop)
            val_questions.append({
                "question_id": crop['id'],
                "image": f"images/{crop['id']}.jpg",
                "category": "default",
                "text": question
            })
            val_answers.append({
                "question_id": crop['id'],
                "prompt": question,
                "text": answer,
                "answer_id": None,
                "model_id": None,
                "metadata": {}
            })

    with open(os.path.join(root_path, "train/train.json"), "w", encoding="utf-8") as f:
        json.dump(train_metadata, f, indent=2)

    with open(os.path.join(root_path, "val/val.json"), "w", encoding="utf-8") as f:
        json.dump(val_questions, f, indent=2)

    with open(os.path.join(root_path, "val/val_ans.json"), "w", encoding="utf-8") as f:
        json.dump(val_answers, f, indent=2)

    # Crops dropped by the limiting step are still in temp_path; this removes them.
    shutil.rmtree(temp_path)

    print(f"{'✅ Total images processed:':<50} {total_images}")
    print(f"{'❌ Images filtered out due to bbox touching edge:':<50} {filtered_bbox_edge}")
    print(f"{'❌ Images filtered out due to invalid crop dimensions:':<50} {filtered_invalid_crop}")
    print(f"{'❌ Images filtered out due to bbox outside crop:':<50} {filtered_bbox_outside_crop}")
    print(f"{'✅ Total valid images after limiting:':<50} {total_valid_images}")
    print(f"{'✅ Train images:':<50} {len(train_metadata)}")
    print(f"{'✅ Validation images:':<50} {len(val_questions)}")
    print(f"{'✅ Temporary directory deleted:':<50} {temp_path}")
    print(f"{'✅ Processing complete. Train and validation datasets are ready.':<50}")

    print("\nPair Type Summary:")
    for pair_type, counts in pair_type_summary.items():
        print(f"{pair_type}: Train = {counts['train']}, Test = {counts['test']}")

# Example usage. The paths are site-specific; edit them before running.
# The call is guarded so importing this module does not start the pipeline.
json_path = "/fs/ess/PAS2099/sooyoung/perception_system_v2_local/kitti_analysis/FINAL/filtered_label_with_pairs.json"
temp_path = "/fs/scratch/PAS2099/vfm/exterior_depth/kitti/relative_depth_reduced/temp"
root_path = "/fs/scratch/PAS2099/vfm/exterior_depth/kitti/relative_depth_reduced"
image_dir = "/fs/scratch/PAS2099/dataset/kitti_obj3d/training/image_2"
max_images = 10000

if __name__ == "__main__":
    process_and_split_images(json_path, temp_path, root_path, image_dir, max_images)