import os
import sys
# Assuming /mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl
# is the base directory containing the 'src' folder
# Adjust sys.path based on your actual project structure if needed.
# Make sure the path points to the directory *containing* 'src'.
sys.path.append("/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl")

import re
import itertools
import json
from tqdm import tqdm
from PIL import Image
import numpy as np
import concurrent.futures # Import for multi-threading

# Assuming these imports from src work based on sys.path
try:
    from src.utils.scene import DiffusionScene
    from src.utils.prompt import gen_prompt, edit_prompt, identity_prompt
    from src.utils.vlm import vlm_request, extract_and_parse_json
except ImportError as e:
    print(f"Error importing from src: {e}")
    print("Please ensure '/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl' is correct and contains the 'src' directory.")
    sys.exit(1)


# --- Helper functions and constant definitions ---
def check_overlap(mask_1, mask_2):
    """Report whether two masks have at least one set element in common.

    Both arguments are converted with ``.bool()`` and intersected
    element-wise with ``&``; the value returned is whatever ``.any()``
    yields on that intersection.
    """
    first = mask_1.bool()
    second = mask_2.bool()
    shared = first & second
    return shared.any()

def check_out_of_bounds(mask):
    """Return True when a 2-D mask has a set element on any border.

    The border is taken from the mask's own shape rather than a fixed
    512x512 size, so masks of other resolutions are handled correctly.

    Args:
        mask: 2-D mask exposing ``.cpu()`` whose result NumPy can convert
            to an array (presumably a torch tensor -- confirm with callers).

    Returns:
        bool: True if any non-zero element lies in the first/last row or
        the first/last column; False otherwise, including for a mask with
        no non-zero elements.
    """
    grid = np.asarray(mask.cpu())
    rows, cols = np.where(grid)

    # An empty mask touches no border; min()/max() would raise on it.
    if rows.size == 0:
        return False

    last_row = grid.shape[0] - 1
    last_col = grid.shape[1] - 1
    return bool(
        rows.min() == 0
        or rows.max() == last_row
        or cols.min() == 0
        or cols.max() == last_col
    )

# Relation phrases grouped under the keys "2d" and "3d". Within this file
# only the "3d" list is read (by generate_scene_combinations); the phrases
# are inserted verbatim into captions, so their exact wording matters.
RELATIONS = {
    "2d": [
        "on the side of", "next to", "near",
        "on the left of", "on the right of",
        "on the bottom of", "on the top of",
    ],
    "3d": [
        "in front of",
        "at the back left of", "at the front left of",
        "behind of",
        "at the back right of", "at the front right of",
        "hidden by",
    ],
}

# Each entry pairs a scene phrase with the object categories (keys of
# OBJECTS_CATEGORIES) that may appear in that scene.
SCENES = [
    ("on the desert",        ["animals", "outdoor", "person"]),
    ("in the room",          ["indoor", "person"]),
    ("on the street",        ["outdoor", "person"]),
    ("in the jungle",        ["animals", "person"]),
    ("on the road",          ["animals", "outdoor", "person"]),
    ("in the studio",        ["indoor", "person"]),
    ("on the beach",         ["animals", "person"]),
    ("on a snowy landscape", ["outdoor", "person"]),
    ("in the apartment",     ["indoor", "person"]),
    ("in the library",       ["indoor", "person"]),
]

# Just the scene phrases, in the same order as SCENES.
SCENES_PROMPT = [scene_phrase for scene_phrase, _categories in SCENES]

# Object vocabulary grouped by category; the category names are the ones
# referenced by the per-scene lists in SCENES.
OBJECTS_CATEGORIES = {
    "animals": [
        "dog", "mouse", "sheep", "cat", "cow",
        "chicken", "turtle", "giraffe", "pig", "butterfly",
        "horse", "bird", "rabbit", "frog", "fish",
    ],
    "indoor": [
        "bed", "desk", "key", "chair", "vase", "candle",
        "cup", "phone", "computer", "bowl", "sofa", "balloon",
        "plate", "refrigerator", "wallet", "bag", "painting", "suitcase",
        "table", "couch", "clock", "book", "lamp", "television",
    ],
    "outdoor": [
        "car", "motorcycle", "backpack", "bench",
        "train", "airplane", "bicycle",
    ],
    "person": ["woman", "man", "boy", "girl"],
}

def generate_scene_combinations():
    """Build every unique ``(obj1, relation, obj2)`` triple the tables allow.

    For each entry of ``SCENES`` the objects of its allowed categories are
    pooled, every ordered pair of two distinct objects is formed, and each
    pair is combined with every phrase in ``RELATIONS["3d"]``. A triple that
    is reachable from several scenes is returned only once. The scene itself
    is not part of the returned tuples.

    Returns:
        list[tuple[str, str, str]]: the unique triples, in first-seen order.
        The order is reproducible across runs because objects are pooled in
        table order rather than through a hash-ordered ``set``.
    """
    relations_3d = RELATIONS.get("3d", [])

    # A dict acts as an insertion-ordered set: duplicate checks are O(1)
    # instead of scanning a growing list for every candidate triple.
    unique_combinations = {}

    for _scene_name, allowed_categories in SCENES:
        # Pool this scene's objects, de-duplicated but in table order.
        scene_objects = {}
        for category in allowed_categories:
            if category not in OBJECTS_CATEGORIES:
                print(f"  Warning: Category '{category}' not found in OBJECTS_CATEGORIES.")
                continue
            scene_objects.update(dict.fromkeys(OBJECTS_CATEGORIES[category]))

        # Permutations, not combinations: "A in front of B" differs from
        # "B in front of A".
        for obj1, obj2 in itertools.permutations(scene_objects, 2):
            for relation in relations_3d:
                unique_combinations.setdefault((obj1, relation, obj2))

    return list(unique_combinations)

# Reverse lookup from each object name to the category that lists it, plus
# the set of every object name present in OBJECTS_CATEGORIES.
OBJECT_TO_CATEGORY = {
    member: group
    for group, members in OBJECTS_CATEGORIES.items()
    for member in members
}

all_known_objects = set(OBJECT_TO_CATEGORY)

def find_applicable_scenes(prompt):
    """List the scene phrases compatible with the objects named in a prompt.

    The prompt is lower-cased and searched for every known object as a whole
    word. The categories of the objects found are collected, and a scene
    qualifies when all of those categories are among the scene's allowed
    categories.

    Returns:
        list[str]: qualifying scene phrases in ``SCENES`` order; an empty
        list (after printing a warning) when no known object, or no
        category, is found.
    """
    lowered = prompt.lower()

    mentioned = {
        name
        for name in all_known_objects
        if re.search(r'\b' + re.escape(name) + r'\b', lowered)
    }
    if not mentioned:
        print(f"Warning: No known objects found in prompt: '{prompt}'")
        return []

    categories = {
        group
        for group in map(OBJECT_TO_CATEGORY.get, mentioned)
        if group
    }
    if not categories:
        print(f"Warning: Could not determine categories for found objects: {mentioned}")
        return []

    return [
        scene_phrase
        for scene_phrase, allowed in SCENES
        if categories <= set(allowed)
    ]

# --- Worker function for each combination ---
def process_combination(combination, data_path, json_path):
    """Produce the depth renders and the JSON record for one combination.

    Steps: build the caption, skip if its JSON file already exists, ask the
    VLM for a layout, build and render the scene, save the images, and only
    then write the JSON file.

    The JSON file doubles as the "already processed" marker checked on
    later runs. It is therefore written last, through a temporary file and
    ``os.replace``, so a failed request, an unusable layout, a rendering
    error or an interrupted write never leaves a marker behind that would
    cause the combination to be skipped forever.

    Args:
        combination: ``(obj1, relation, obj2)`` tuple of strings.
        data_path: directory under which a per-caption image folder is made.
        json_path: directory that receives ``<caption>.json``.

    Returns:
        str: ``"Skipped: ..."``, ``"Processed: ..."`` or
        ``"Error processing ...: ..."``. Exceptions are caught, reported on
        stderr and turned into the error string rather than propagated.
    """
    obj1, relation, obj2 = combination
    caption = f"a {obj1} {relation} a {obj2}"
    entities = [obj1, obj2]

    # An existing JSON file means this caption was fully processed earlier.
    json_output_path = f'{json_path}/{caption}.json'
    if os.path.exists(json_output_path):
        return f"Skipped: {caption}"

    try:
        # Ask the VLM for a layout; only the text after the last '</think>'
        # marker is parsed as the answer.
        request_text = gen_prompt.replace('<caption>', caption).replace('<entities>', json.dumps(entities))
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": request_text},
                ],
            }
        ]
        content = vlm_request(messages)
        answer = content.split('</think>')[-1]
        ans_json = extract_and_parse_json(answer)

        # Derive scene parameters from the answer and build the scene.
        # NOTE(review): DiffusionScene is used from several worker threads;
        # confirm it is safe to use that way if it relies on shared
        # resources such as a GPU context.
        scene_size = ans_json['scene_parameters']['scene_size'] / 2
        cam_pitch_angle = 90 - ans_json['scene_parameters']['camera_pitch_angle']
        floor_offset = - scene_size / 5

        scene = DiffusionScene(scene_size=scene_size)
        # Original author's note: rotation_axis(x,z,y), translation(x, z, y)
        scene.move_camera(rotation_angle=cam_pitch_angle, rotation_axis=[1, 0, 0], translation=[0, 0, 0])
        scene.build_floor(scale_x=1, scale_y=1, floor_offset=floor_offset)

        for i, entity in enumerate(ans_json['entity_layout']):
            scene.add_box(id=f"box_{i}", size=entity['size'], origin=entity['position'], prompt=entity['entity_name'])

        # Render and save both sets of depth images.
        image_output_dir = f"{data_path}/{caption}"
        os.makedirs(image_output_dir, exist_ok=True)

        depth_all = scene.render(single=True, floor=True, depth_max=4*scene_size)
        for j, depth in enumerate(depth_all):
            Image.fromarray(depth).save(f'{image_output_dir}/render_depth_{j}.png')

        depth_all = scene.render_bas()
        for j, depth in enumerate(depth_all):
            Image.fromarray(depth).save(f'{image_output_dir}/bas_depth_{j}.png')

        # Everything succeeded: publish the JSON marker. Writing to a
        # temporary name first means a partially written file is never
        # visible under the final name.
        data = {
            'caption': caption,
            'entities': entities,
            'ans_json': ans_json,
            'content': content,
        }
        tmp_output_path = f'{json_output_path}.tmp'
        with open(tmp_output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4)
        os.replace(tmp_output_path, json_output_path)

        return f"Processed: {caption}"

    except Exception as e:
        # Worker boundary: report the failure for this caption and let the
        # remaining combinations continue.
        print(f"\nError processing '{caption}': {e}", file=sys.stderr)
        return f"Error processing {caption}: {e}"


# --- Main execution block using ThreadPoolExecutor ---
if __name__ == "__main__":
    data_path = 'data/render'
    json_path = 'data/json'
    for output_root in (data_path, json_path):
        os.makedirs(output_root, exist_ok=True)

    # Keep only the combinations whose JSON file does not exist yet, then
    # sort them so the submission order is stable.
    all_combinations = [
        (obj1, relation, obj2)
        for obj1, relation, obj2 in generate_scene_combinations()
        if not os.path.exists(f'{json_path}/a {obj1} {relation} a {obj2}.json')
    ]
    all_combinations.sort()
    print(f"Generated {len(all_combinations)} total combinations.")

    # Number of worker threads. Tune it to the VLM API rate limits, file
    # system speed and available system resources.
    MAX_WORKERS = 4
    print(f"Using ThreadPoolExecutor with {MAX_WORKERS} workers.")

    results = []  # status string of every finished task

    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Queue one task per pending combination, remembering which future
        # belongs to which combination.
        future_to_combination = {}
        for pending in all_combinations:
            task = executor.submit(process_combination, pending, data_path, json_path)
            future_to_combination[task] = pending

        # Collect results as tasks finish, with a progress bar.
        finished = concurrent.futures.as_completed(future_to_combination)
        for future in tqdm(finished, total=len(all_combinations), desc="Processing combinations"):
            try:
                outcome = future.result()
            except Exception as exc:
                # The worker reports its own errors; this covers anything
                # that still escapes from the future.
                combination = future_to_combination[future]
                print(f'\nCombination {combination} generated an exception: {exc}', file=sys.stderr)
                outcome = f"Failed: {combination} - {exc}"
            results.append(outcome)

    print("\nAll combinations processing attempted.")