
import json
import random
import os
import sys
import cv2
import numpy as np
import ast
from typing import List, Dict, Any, Optional
from PIL import Image, ImageDraw
import torch
from types import SimpleNamespace

# Make sibling modules and the parent package importable when this file is
# run as a script. Each root is prepended only if it is missing; the parent
# is handled first, so when both are added the current directory ends up
# ahead of it on sys.path.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
for _import_root in (parent_dir, current_dir):
    if _import_root not in sys.path:
        sys.path.insert(0, _import_root)
del _import_root


try:
    from transformers import AutoProcessor
    from vllm import LLM, SamplingParams
    import uuid
    from tqdm import tqdm
    VLLM_AVAILABLE = True
except ImportError:
    print("Warning: Could not import model libraries, instruction modification may not work")
    AutoProcessor = None
    LLM = None
    SamplingParams = None
    uuid = None
    tqdm = None
    VLLM_AVAILABLE = False


def calculate_dynamic_mask_size(
    width: int,
    height: int,
    base_width: int = 720,
    base_mask_size: int = 115,
    min_mask_size: int = 50,
    max_mask_size: int = 300
) -> int:
    """Scale ``base_mask_size`` linearly with the image width and clamp it.

    A ``width`` equal to ``base_width`` yields ``base_mask_size``. The scaled
    value is truncated to int, capped at ``max_mask_size`` and then raised to
    at least ``min_mask_size``. ``height`` is accepted for interface symmetry
    but does not influence the result.
    """
    size = int(base_mask_size * (width / base_width))
    # Cap first, then floor: when min > max the minimum wins.
    if size > max_mask_size:
        size = max_mask_size
    if size < min_mask_size:
        size = min_mask_size
    return size


class MaskProcessor:
    """Erase the ground-truth click target from GUI screenshots.

    ``apply_mask`` paints the target black; ``apply_inpaint`` fills it with
    OpenCV Telea inpainting. Both locate the target from the sample's
    ``gt_action`` click coordinate and from whatever element boxes
    ``get_bbox_data`` can recover for the sample's dataset.
    """

    # Accessibility-tree widget classes considered for AndroidControl samples.
    _ANDROID_CONTROL_CLASSES = (
        'android.widget.ImageButton',
        'android.widget.TextView',
        'android.widget.ImageView',
    )

    def __init__(self, mask_object_ratio: Optional[float] = None, use_dynamic_size: bool = True):
        """
        Args:
            mask_object_ratio: Fixed half-side, in pixels, of the square erased
                around the click. ``None`` derives it from the image size via
                ``calculate_dynamic_mask_size``.
            use_dynamic_size: Stored for compatibility. The dynamic size is
                used whenever ``mask_object_ratio`` is ``None``, whatever this
                flag says.
        """
        self.mask_object_ratio = mask_object_ratio
        self.use_dynamic_size = use_dynamic_size

    def get_mask_size(self, image_width: int, image_height: int) -> float:
        """Return the half-side length, in pixels, of the area erased around the click."""
        if self.mask_object_ratio is not None:
            return self.mask_object_ratio
        return calculate_dynamic_mask_size(image_width, image_height)

    def remove_containing_bboxes(self, bbox_list, gt, image_size):
        """Split boxes by whether they contain the ground-truth click.

        Args:
            bbox_list: Boxes as ``[x, y, w, h]`` in pixels. ``None`` is treated
                as empty.
            gt: Click position normalized to 0-1000 on both axes.
            image_size: ``(width, height)`` in pixels.

        Returns:
            ``(outside, containing, (click_x, click_y))`` with the click in
            pixels. A click on a box edge counts as inside.
        """
        click_x = gt[0] / 1000 * image_size[0]
        click_y = gt[1] / 1000 * image_size[1]
        out_bbox_list = []
        in_bbox_list = []
        for bbox in bbox_list or []:
            x, y, w, h = bbox
            if x <= click_x <= x + w and y <= click_y <= y + h:
                in_bbox_list.append(bbox)
            else:
                out_bbox_list.append(bbox)
        return out_bbox_list, in_bbox_list, (click_x, click_y)

    def get_bbox_data(self, obs, image_width=None, image_height=None):
        """Return element boxes for *obs* as a list of ``[x, y, w, h]`` boxes.

        The source depends on ``obs['dataset_name']``:

        * ``AndroidControl``: clickable widgets from the JSON-lines
          accessibility tree at ``obs['accessibility_trees']``.
        * ``AITZ``: ``ui_positions`` of the annotation entry whose
          ``image_path`` occurs in the sample's first image path.
        * anything else: the single box ``obs['bbox']``.

        Returns an empty list when no boxes can be recovered.
        """
        dataset_name = obs.get('dataset_name', '')
        if dataset_name == 'AndroidControl':
            return self._android_control_bboxes(obs, image_width, image_height)
        if dataset_name == 'AITZ':
            return self._aitz_bboxes(obs)
        return self._obs_bbox(obs, image_width, image_height)

    def _android_control_bboxes(self, obs, image_width, image_height):
        """Read clickable widget boxes from an AndroidControl JSON-lines tree.

        When both image dimensions are known, boxes that fall outside the
        image or are degenerate are dropped.
        """
        tree_path = obs.get('accessibility_trees')
        boxes = []
        if not (tree_path and os.path.exists(tree_path)):
            return boxes
        with open(tree_path, "r", encoding="utf-8") as file:
            for line in file:
                try:
                    obj = json.loads(line)
                    bbox = obj.get("bbox_pixels", None)
                    if not (bbox
                            and obj.get("class_name", None) in self._ANDROID_CONTROL_CLASSES
                            and obj.get("is_clickable")):
                        continue
                    x_min, y_min = bbox["x_min"], bbox["y_min"]
                    x_max, y_max = bbox["x_max"], bbox["y_max"]
                    if image_width and image_height and not (
                            0 <= x_min < x_max <= image_width
                            and 0 <= y_min < y_max <= image_height):
                        continue
                    boxes.append([x_min, y_min, x_max - x_min, y_max - y_min])
                except Exception:
                    # Best effort: skip malformed lines instead of failing the sample.
                    continue
        return boxes

    def _aitz_bboxes(self, obs):
        """Read the UI element boxes for the sample's screenshot from an AITZ annotation file."""
        tree_path = obs.get('accessibility_trees')
        if not (tree_path and os.path.exists(tree_path)):
            return []
        with open(tree_path, "r", encoding="utf-8") as file:
            entries = json.load(file)
        image_path = (obs.get('images') or [None])[0]
        if not image_path:
            return []
        for entry in entries:
            entry_image = entry.get('image_path')
            # Entries without a string image path cannot match; checking first
            # avoids a TypeError from ``None in image_path``.
            if not isinstance(entry_image, str) or entry_image not in image_path:
                continue
            positions = entry['ui_positions']
            if isinstance(positions, str):
                positions = ast.literal_eval(positions)
            # Swaps the first and last component pairs, i.e. treats
            # ui_positions as (y, x, h, w) and emits [x, y, w, h].
            # NOTE(review): confirm this order against the AITZ format.
            return [[x, y, w, h] for (y, x, h, w) in positions]
        return []

    def _obs_bbox(self, obs, image_width, image_height):
        """Convert ``obs['bbox']`` ``[x_min, y_min, x_max, y_max]`` to ``[[x, y, w, h]]``.

        With both image dimensions given, the corners are treated as 0-1000
        normalized and scaled to pixels. Otherwise they are used as-is.
        """
        bbox = obs.get('bbox')
        if not bbox or len(bbox) != 4:
            return []
        x_min, y_min, x_max, y_max = bbox
        if image_width and image_height:
            x_min = x_min / 1000 * image_width
            y_min = y_min / 1000 * image_height
            x_max = x_max / 1000 * image_width
            y_max = y_max / 1000 * image_height
        return [[x_min, y_min, x_max - x_min, y_max - y_min]]

    def get_gt_coordinates(self, obs, image_width=None, image_height=None):
        """Return the ground-truth click as ``[x, y]``, or ``None`` if unavailable.

        Only ``click`` and ``long_press`` actions whose ``coordinate`` has at
        least two values qualify. With both image dimensions given, the pixel
        coordinate is rescaled to 0-1000. Otherwise it is returned unchanged.
        """
        gt_action = obs.get('gt_action', {})
        if not gt_action or gt_action.get('action', '') not in ('click', 'long_press'):
            return None
        coord = gt_action.get('coordinate')
        if not coord or len(coord) < 2:
            return None
        x, y = coord[0], coord[1]
        if image_width and image_height:
            return [x / image_width * 1000, y / image_height * 1000]
        return [x, y]

    def _load_and_locate(self, image_path, obs):
        """Load *image_path* as RGB and find the click plus the boxes containing it.

        Returns ``(image, containing_boxes, click_point)``. The last two are
        ``None`` when the sample has no usable ground-truth click.
        """
        # convert() returns a detached copy, so the file handle can be closed here.
        with Image.open(image_path) as opened:
            image = opened.convert('RGB')
        width, height = image.size
        bbox_data = self.get_bbox_data(obs, image_width=width, image_height=height)
        gt = self.get_gt_coordinates(obs, image_width=width, image_height=height)
        if gt is None:
            return image, None, None
        _, containing, point = self.remove_containing_bboxes(
            bbox_list=bbox_data,
            gt=gt,
            image_size=[width, height]
        )
        return image, containing, point

    def apply_mask(self, image_path: str, obs: Dict[str, Any]) -> "Image.Image":
        """Return the screenshot with the click target painted black.

        Every element box containing the click is filled black. So is a
        full-width horizontal band extending ``get_mask_size`` pixels above
        and below the click. If the click cannot be determined, the unmodified
        image is returned and a warning is printed.
        """
        image, containing, point = self._load_and_locate(image_path, obs)
        if point is None:
            print(f"Warning: Cannot get gt coordinates for {image_path}, skipping mask")
            return image

        draw = ImageDraw.Draw(image)
        width, height = image.size
        for x, y, w, h in containing:
            draw.rectangle([x, y, x + w, y + h], fill="black")

        r = self.get_mask_size(width, height)
        # The band spans the same rows as the square of half-side r around the
        # click and the full width, so it already covers that square.
        band_top = max(0, int(point[1] - r))
        band_bottom = min(height, int(point[1] + r))
        if band_bottom > band_top:
            draw.rectangle([0, band_top, width, band_bottom], fill="black")
        return image

    def apply_inpaint(self, image_path: str, obs: Dict[str, Any]) -> "Image.Image":
        """Return the screenshot with the click target removed by inpainting.

        Each element box containing the click is inpainted in turn. Then the
        square of half-side ``get_mask_size`` around the click is inpainted.
        If the click cannot be determined, the unmodified image is returned
        and a warning is printed.
        """
        image, containing, point = self._load_and_locate(image_path, obs)
        if point is None:
            print(f"Warning: Cannot get gt coordinates for {image_path}, skipping inpaint")
            return image

        width, height = image.size
        image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

        for x, y, w, h in containing:
            image_cv = self._inpaint_rect(image_cv, x, y, x + w, y + h)

        r = self.get_mask_size(width, height)
        click_x, click_y = point
        image_cv = self._inpaint_rect(image_cv, click_x - r, click_y - r, click_x + r, click_y + r)

        return Image.fromarray(cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB))

    @staticmethod
    def _inpaint_rect(image_cv, left, top, right, bottom):
        """Inpaint the pixel rectangle ``[left, right) x [top, bottom)``.

        The rectangle is clipped to the image first. Unclipped negative
        coordinates would act as from-the-end NumPy indices and silently
        produce an empty mask. An empty rectangle leaves the image untouched.
        """
        img_height, img_width = image_cv.shape[:2]
        x_min = max(0, int(left))
        y_min = max(0, int(top))
        x_max = min(img_width, int(right))
        y_max = min(img_height, int(bottom))
        if x_max <= x_min or y_max <= y_min:
            return image_cv
        mask = np.zeros((img_height, img_width), dtype=np.uint8)
        mask[y_min:y_max, x_min:x_max] = 255
        return cv2.inpaint(image_cv, mask, 3, cv2.INPAINT_TELEA)


def build_instruction_modification_prompt(original_instruction: str) -> str:
    """Build the LLM prompt that asks for a task unrelated to *original_instruction*.

    The original instruction is interpolated verbatim. The prompt ends with
    the cue line ``New unrelated instruction:``, so the model's completion is
    expected to be the replacement instruction itself.
    """
    return f"""You are a language expert tasked with modifying user instructions to create a completely different, unrelated task.

Original instruction: {original_instruction}

Your task:
1. Analyze the original instruction to understand what task it describes
2. Generate a NEW instruction that describes a COMPLETELY DIFFERENT and UNRELATED task
3. The new instruction should be about a different app, different action, or different goal
4. The new instruction should be realistic and natural, but have nothing to do with the original task
5. Keep the instruction format similar (e.g., if original is about mobile app, new one should also be about mobile app, but different app/action)

Requirements:
- The new instruction must describe a task that is COMPLETELY UNRELATED to the original task
- Do not paraphrase or modify the original - create a completely new task instead
- The instruction should be natural, realistic, and grammatically correct
- Only output the new instruction, without any explanation or additional text
- Make sure the new task is clearly different from the original (different app, different action, different goal)

New unrelated instruction:"""


class InstructionModifier:
    """Rewrite sample instructions into unrelated tasks with a vLLM-served model.

    If vLLM could not be imported or no ``model_path`` is given, the instance
    is created disabled (``llm is None``) and ``modify_instructions_batch``
    returns its input unchanged.
    """

    def __init__(self, model_path: Optional[str] = None,
                 batch_size: int = 512,
                 device_ids: str = "[0]",
                 tensor_parallel_size: Optional[int] = None,
                 seed: int = 42):
        """
        Args:
            model_path: Model passed to vLLM ``LLM`` and ``AutoProcessor``.
            batch_size: Number of prompts submitted per ``LLM.generate`` call.
            device_ids: GPU ids as a Python literal string such as ``"[0,1]"``.
                A list/tuple of ints is also accepted.
            tensor_parallel_size: vLLM tensor-parallel degree. Defaults to the
                number of device ids (at least 1).
            seed: Seed passed to the vLLM engine.
        """
        self.model_path = model_path
        self.batch_size = batch_size
        self.seed = seed

        if not VLLM_AVAILABLE or not model_path:
            self.llm = None
            self.processor = None
            self.sampling_params = None
            print("Warning: vLLM not available or model_path not provided, instruction modification will be skipped")
            return

        device_ids_list = self._parse_device_ids(device_ids)

        if tensor_parallel_size is None:
            # One tensor-parallel rank per listed device; never 0.
            tensor_parallel_size = len(device_ids_list) or 1

        if device_ids_list:
            # Set before the engine is created so only these GPUs are visible to it.
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids_list))

        print(f"[InstructionModifier] Model path: {self.model_path}")
        print(f"[InstructionModifier] Devices: {device_ids_list}, tensor_parallel_size: {tensor_parallel_size}")
        print(f"[InstructionModifier] Batch size: {self.batch_size}")

        self._init_vllm_engine(tensor_parallel_size)

        print(f"[InstructionModifier] Loading processor...")
        self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
        print(f"[InstructionModifier] Processor loaded")

    @staticmethod
    def _parse_device_ids(device_ids) -> List[int]:
        """Parse a device-id spec such as ``"[0,1]"`` into a list.

        Uses ``ast.literal_eval`` rather than ``eval`` so a CLI/config string
        cannot execute code. Handling of each form:

        * a bare integer is wrapped in a list;
        * an existing list/tuple is accepted as-is;
        * anything unparseable falls back to ``[0]``.
        """
        if isinstance(device_ids, (list, tuple)):
            return list(device_ids)
        try:
            parsed = ast.literal_eval(device_ids)
        except (ValueError, SyntaxError, TypeError):
            return [0]
        if isinstance(parsed, int):
            return [parsed]
        if isinstance(parsed, (list, tuple)):
            return list(parsed)
        return [0]

    def _init_vllm_engine(self, tensor_parallel_size: int):
        """Create the vLLM engine and the sampling parameters used for every request."""
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        print(f"[InstructionModifier] Initializing vLLM engine...")

        engine_args = {
            "model": self.model_path,
            "tensor_parallel_size": tensor_parallel_size,
            "seed": self.seed,
            "max_model_len": 16384,
        }

        self.llm = LLM(**engine_args)
        self.sampling_params = SamplingParams(
            temperature=0.7,
            top_p=0.9,
            max_tokens=512
        )

        print(f"[InstructionModifier] vLLM engine initialized")

    def _build_prompt(self, original_instruction: str) -> str:
        """Wrap the modification prompt in the processor's chat template.

        Falls back to the plain prompt text if templating fails.
        """
        prompt_text = build_instruction_modification_prompt(original_instruction)
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt_text}]
            }
        ]
        try:
            return self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        except Exception:
            return prompt_text

    def _generate(self, vllm_inputs: List[Dict[str, Any]]) -> List[str]:
        """Run vLLM over *vllm_inputs* in chunks of ``batch_size``.

        Returns one completion text per input, in order. A chunk whose
        ``generate`` call raises contributes ``""`` for each of its inputs.
        """
        texts: List[str] = []
        total_batches = (len(vllm_inputs) + self.batch_size - 1) // self.batch_size
        for batch_idx in tqdm(range(0, len(vllm_inputs), self.batch_size), desc="vLLM inference batches"):
            batch_inputs = vllm_inputs[batch_idx:batch_idx + self.batch_size]
            batch_num = batch_idx // self.batch_size + 1
            print(f"[InstructionModifier] Processing batch {batch_num}/{total_batches}")
            try:
                batch_outputs = self.llm.generate(batch_inputs, self.sampling_params)
            except Exception as e:
                print(f"[InstructionModifier] Batch {batch_num} failed: {e}")
                texts.extend([""] * len(batch_inputs))
                continue
            for output in batch_outputs:
                try:
                    texts.append(output.outputs[0].text or "")
                except (AttributeError, IndexError) as e:
                    print(f"[InstructionModifier] Failed to parse output: {e}")
                    texts.append("")
        return texts

    @staticmethod
    def _replace_instruction_in_messages(messages: List[Any], old: str, new: str) -> List[Any]:
        """Return a copy of *messages* with *old* replaced by *new* in user turns.

        Each message that changes is copied instead of edited in place. The
        caller's item is only a shallow copy, so in-place edits would leak back
        into the input data. Both plain-string ``content`` and list ``content``
        (parts carrying a string ``text``) are handled. An empty *old* changes
        nothing.
        """
        if not old:
            return list(messages)
        updated = []
        for msg in messages:
            if not isinstance(msg, dict) or msg.get('role') != 'user' or 'content' not in msg:
                updated.append(msg)
                continue
            content = msg['content']
            if isinstance(content, str):
                if old in content:
                    msg = {**msg, 'content': content.replace(old, new)}
            elif isinstance(content, list):
                parts = []
                changed = False
                for part in content:
                    if isinstance(part, dict) and isinstance(part.get('text'), str) and old in part['text']:
                        part = {**part, 'text': part['text'].replace(old, new)}
                        changed = True
                    parts.append(part)
                if changed:
                    msg = {**msg, 'content': parts}
            updated.append(msg)
        return updated

    def modify_instructions_batch(self, data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return copies of *data_list* with each instruction swapped for an unrelated one.

        Every returned item is a shallow copy tagged
        ``trap_mode='instruction_modify'``. The first line of the model's reply
        becomes ``instruction`` and also replaces the old instruction text
        inside user ``messages``. Some samples keep their original instruction
        (a warning reports how many):

        * samples without an instruction;
        * samples whose generation failed;
        * samples whose generation came back empty.

        If the engine is not initialised, or inference fails outside a
        per-batch error, *data_list* itself is returned unmodified.
        """
        if not self.llm:
            print("Warning: vLLM not initialized, returning original data")
            return data_list

        print(f"[InstructionModifier] Building vLLM inputs for {len(data_list)} samples...")
        # Samples without an instruction are not sent to vLLM: an empty prompt
        # can make the engine reject the whole batch, blanking its neighbours.
        vllm_inputs = []
        input_owner = []  # data_list index for each entry of vllm_inputs
        for idx, sample in enumerate(tqdm(data_list, desc="Building inputs")):
            original_instruction = sample.get('instruction', '')
            if original_instruction:
                vllm_inputs.append({"prompt": self._build_prompt(original_instruction)})
                input_owner.append(idx)

        print(f"[InstructionModifier] Starting vLLM batch inference for {len(vllm_inputs)} samples...")
        try:
            generated_texts = self._generate(vllm_inputs)
            print(f"[InstructionModifier] vLLM inference completed")
        except Exception as e:
            print(f"[InstructionModifier] vLLM inference failed: {e}")
            return data_list
        generated_by_index = dict(zip(input_owner, generated_texts))

        processed_data = []
        unchanged = 0
        for idx, sample in enumerate(data_list):
            new_item = sample.copy()
            text = generated_by_index.get(idx)
            if text is None:
                unchanged += 1
            else:
                original_instruction = sample.get('instruction', '')
                modified_instruction = text.strip().split('\n')[0].strip() or original_instruction
                if modified_instruction == original_instruction:
                    unchanged += 1
                new_item['instruction'] = modified_instruction
                messages = new_item.get('messages')
                if isinstance(messages, list) and messages:
                    new_item['messages'] = self._replace_instruction_in_messages(
                        messages, original_instruction, modified_instruction
                    )
            new_item['trap_mode'] = 'instruction_modify'
            processed_data.append(new_item)

        if unchanged:
            print(f"[InstructionModifier] Warning: {unchanged} samples kept their original instruction")
        return processed_data


def get_relative_path_from_datasets(image_path: str, prefixes: Optional[List[str]] = None) -> str:
    """Return *image_path* relative to its dataset root, for mirroring under an output dir.

    The first prefix that matches on a whole path component is stripped. If
    none matches, only leading slashes are removed. The result never starts
    with "/": ``os.path.join(output_dir, "/abs/...")`` would discard
    ``output_dir`` and write outside it.

    Args:
        image_path: Image path, usually absolute.
        prefixes: Dataset roots to strip. Defaults to the built-in roots.
    """
    if prefixes is None:
        prefixes = [
            "/DATASET_PREFIX_1",
            "/DATASET_PREFIX_2",
        ]

    for prefix in prefixes:
        root = prefix.rstrip("/")
        # Require a component boundary so "/root_1" does not match "/root_10/...".
        if image_path == root or image_path.startswith(root + "/"):
            return image_path[len(root):].lstrip("/")

    return image_path.lstrip("/")


def _process_trap_images(data_list: List[Dict[str, Any]],
                         mask_object_ratio: Optional[float],
                         use_dynamic_size: bool,
                         output_image_dir: Optional[str],
                         method_name: str,
                         suffix: str,
                         trap_mode: str) -> List[Dict[str, Any]]:
    """Apply one ``MaskProcessor`` image transform to every sample and save the results.

    Shared implementation of ``batch_process_mask`` and ``batch_process_inpaint``.
    Results are written under *output_image_dir*, mirroring the path from
    ``get_relative_path_from_datasets``, as ``<name>_<suffix><ext>``. Samples
    are skipped with a printed message when:

    * the first image is missing;
    * processing raises an exception.
    """
    if output_image_dir is None:
        output_image_dir = "/OUTPUT_IMAGE_DIR"

    processor = MaskProcessor(mask_object_ratio=mask_object_ratio, use_dynamic_size=use_dynamic_size)
    transform = getattr(processor, method_name)
    processed_data = []

    for item in data_list:
        try:
            image_path = (item.get('images') or [None])[0]
            if not image_path or not os.path.exists(image_path):
                print(f"Warning: Image not found: {image_path}, skipping")
                continue

            result_image = transform(image_path, item)

            relative_path = get_relative_path_from_datasets(image_path)
            name, ext = os.path.splitext(os.path.basename(relative_path))
            output_dir = os.path.join(output_image_dir, os.path.dirname(relative_path))
            os.makedirs(output_dir, exist_ok=True)
            output_path = os.path.join(output_dir, f"{name}_{suffix}{ext}")
            result_image.save(output_path)

            new_item = item.copy()
            new_item['old_images'] = item.get('images', [])
            new_item['images'] = [output_path]
            new_item['trap_mode'] = trap_mode
            processed_data.append(new_item)
        except Exception as e:
            # Best effort: one bad sample must not abort the whole batch.
            print(f"Error processing item: {e}")
            continue

    return processed_data


def batch_process_mask(data_list: List[Dict[str, Any]], 
                       mask_object_ratio: Optional[float] = None,
                       use_dynamic_size: bool = True,
                       output_image_dir: Optional[str] = None) -> List[Dict[str, Any]]:
    """Black out the click target in each sample's first image (``MaskProcessor.apply_mask``).

    Each returned item is a shallow copy with these fields:

    * ``images``: a list holding the new ``<name>_mask<ext>`` file;
    * ``old_images``: the previous ``images`` list;
    * ``trap_mode``: ``'mask'``.

    Samples that cannot be processed are dropped.
    """
    return _process_trap_images(data_list, mask_object_ratio, use_dynamic_size,
                                output_image_dir, "apply_mask", "mask", "mask")


def batch_process_inpaint(data_list: List[Dict[str, Any]], 
                         mask_object_ratio: Optional[float] = None,
                         use_dynamic_size: bool = True,
                         output_image_dir: Optional[str] = None) -> List[Dict[str, Any]]:
    """Inpaint away the click target in each sample's first image (``MaskProcessor.apply_inpaint``).

    Each returned item is a shallow copy with these fields:

    * ``images``: a list holding the new ``<name>_inpaint<ext>`` file;
    * ``old_images``: the previous ``images`` list;
    * ``trap_mode``: ``'inpaint'``.

    Samples that cannot be processed are dropped.
    """
    return _process_trap_images(data_list, mask_object_ratio, use_dynamic_size,
                                output_image_dir, "apply_inpaint", "inpaint", "inpaint")


def batch_process_instruction_modify(data_list: List[Dict[str, Any]],
                                     model_path: Optional[str] = None,
                                     batch_size: int = 512,
                                     device_ids: str = "[0,1,2,3,4,5,6,7]",
                                     tensor_parallel_size: Optional[int] = None,
                                     seed: int = 42) -> List[Dict[str, Any]]:
    """Rewrite the instructions of *data_list* with an ``InstructionModifier``.

    Without a *model_path*, a warning is printed and *data_list* itself is
    returned. Otherwise a modifier is built with the given engine settings
    and its ``modify_instructions_batch`` result is returned.
    """
    if model_path:
        modifier = InstructionModifier(model_path=model_path,
                                       batch_size=batch_size,
                                       device_ids=device_ids,
                                       tensor_parallel_size=tensor_parallel_size,
                                       seed=seed)
        return modifier.modify_instructions_batch(data_list)

    print("Warning: model_path not provided, returning original data")
    return data_list


def _resolve_trap_counts(trap_config: Dict[str, Any], total_count: int):
    """Return ``(mask, inpaint, instruction)`` sample counts for *total_count* samples.

    A section's explicit ``count`` wins, including 0. Otherwise the count is
    ``ratio`` (default 0.33) times *total_count*, truncated. Counts are
    floored at 0. If they sum past *total_count*, mask/inpaint are scaled
    down proportionally and the instruction count takes the remainder.
    """
    def count_for(section: str) -> int:
        cfg = trap_config.get(section) or {}
        count = cfg.get("count")
        if count is None:
            ratio = cfg.get("ratio")
            count = total_count * (0.33 if ratio is None else ratio)
        return max(0, int(count))

    mask_count = count_for("mask")
    inpaint_count = count_for("inpaint")
    instruction_count = count_for("instruction_modify")

    total_trap_count = mask_count + inpaint_count + instruction_count
    if total_trap_count > total_count:
        scale = total_count / total_trap_count
        mask_count = int(mask_count * scale)
        inpaint_count = int(inpaint_count * scale)
        instruction_count = total_count - mask_count - inpaint_count
    return mask_count, inpaint_count, instruction_count


def create_trap_dataset(
    original_data_list: List[Dict[str, Any]],
    trap_config: Optional[Dict[str, Any]] = None,
    output_image_dir: Optional[str] = None,
    random_seed: int = 42
) -> List[Dict[str, Any]]:
    """Build a shuffled trap dataset from disjoint random subsets of *original_data_list*.

    The shuffled input is cut into consecutive mask / inpaint /
    instruction-modify slices, sized from *trap_config* by
    ``_resolve_trap_counts``. Each slice is processed by the matching
    ``batch_process_*`` function, and the combined results are shuffled.
    Samples beyond the allocated counts are not used.

    Args:
        original_data_list: Source samples (not modified by the shuffling).
        trap_config: Per-mode settings keyed ``"mask"``, ``"inpaint"``, and
            ``"instruction_modify"``, each with ``count`` or ``ratio`` plus
            processor options. Defaults to roughly a third for each mode.
        output_image_dir: Where generated images are written.
        random_seed: Seed for the private RNG driving both shuffles.
    """
    if trap_config is None:
        trap_config = {
                "mask": {
                    "ratio": 0.33,
                    "mask_object_ratio": None,
                    "use_dynamic_size": True
                },
                "inpaint": {
                    "ratio": 0.33,
                    "mask_object_ratio": None,
                    "use_dynamic_size": True
                },
                "instruction_modify": {
                    "ratio": 0.33,
                    "model_path": "/MODEL_PATH",
                    "batch_size": 512,
                    "device_ids": "[0,1,2,3,4,5,6,7]",
                    "tensor_parallel_size": 8,
                    "seed": 42
                }
        }

    # A private RNG keeps results reproducible from random_seed alone and does
    # not reseed (or get disturbed by) the process-wide ``random`` state.
    rng = random.Random(random_seed)

    total_count = len(original_data_list)
    mask_count, inpaint_count, instruction_count = _resolve_trap_counts(trap_config, total_count)

    print(f"Trap data allocation:")
    print(f"  Mask: {mask_count} samples")
    print(f"  Inpaint: {inpaint_count} samples")
    print(f"  Instruction Modify: {instruction_count} samples")
    print(f"  Total: {mask_count + inpaint_count + instruction_count} samples")

    shuffled_data = list(original_data_list)
    rng.shuffle(shuffled_data)

    inpaint_end = mask_count + inpaint_count
    mask_data = shuffled_data[:mask_count]
    inpaint_data = shuffled_data[mask_count:inpaint_end]
    instruction_data = shuffled_data[inpaint_end:inpaint_end + instruction_count]

    trap_data = []

    if mask_count > 0:
        print(f"\nProcessing mask data...")
        mask_config = trap_config.get("mask", {})
        processed_mask = batch_process_mask(
            mask_data,
            mask_object_ratio=mask_config.get("mask_object_ratio"),
            use_dynamic_size=mask_config.get("use_dynamic_size", True),
            output_image_dir=output_image_dir
        )
        trap_data.extend(processed_mask)

    if inpaint_count > 0:
        print(f"\nProcessing inpaint data...")
        inpaint_config = trap_config.get("inpaint", {})
        processed_inpaint = batch_process_inpaint(
            inpaint_data,
            mask_object_ratio=inpaint_config.get("mask_object_ratio"),
            use_dynamic_size=inpaint_config.get("use_dynamic_size", True),
            output_image_dir=output_image_dir
        )
        trap_data.extend(processed_inpaint)

    if instruction_count > 0:
        print(f"\nProcessing instruction modification data...")
        instruction_config = trap_config.get("instruction_modify", {})
        processed_instruction = batch_process_instruction_modify(
            instruction_data,
            model_path=instruction_config.get("model_path", "/MODEL_PATH"),
            batch_size=instruction_config.get("batch_size", 512),
            device_ids=instruction_config.get("device_ids", "[0,1,2,3,4,5,6,7]"),
            tensor_parallel_size=instruction_config.get("tensor_parallel_size", 8),
            seed=instruction_config.get("seed", 42)
        )
        trap_data.extend(processed_instruction)

    rng.shuffle(trap_data)

    return trap_data


def main():
    """Command-line entry point: read samples, build a trap dataset, write it as JSON.

    Prints a per-``trap_mode`` count of the generated samples at the end.
    """
    import argparse
    from collections import Counter

    parser = argparse.ArgumentParser(description="Create trap dataset")
    parser.add_argument("--input", type=str, default='/INPUT_FILE',
                       help="Input JSON file path")
    parser.add_argument("--output", type=str, default='/OUTPUT_FILE',
                       help="Output JSON file path")
    parser.add_argument("--output_image_dir", type=str, default='/OUTPUT_IMAGE_DIR',
                       help="Output image directory")
    parser.add_argument("--mask_ratio", type=float, default=0.35,
                       help="Mask ratio")
    parser.add_argument("--inpaint_ratio", type=float, default=0.35,
                       help="Inpaint ratio")
    parser.add_argument("--instruction_ratio", type=float, default=0.3,
                       help="Instruction modification ratio")
    parser.add_argument("--mask_object_ratio", type=float, default=None,
                       help="Mask object ratio")
    parser.add_argument("--model_path", type=str, default="/MODEL_PATH",
                       help="Model path for instruction modification")
    parser.add_argument("--batch_size", type=int, default=512,
                       help="Batch size")
    parser.add_argument("--device_ids", type=str, default="[0,1,2,3,4,5,6,7]",
                       help="CUDA device IDs")
    # Defaults to None so the tensor-parallel degree follows --device_ids;
    # a fixed 8 would over-subscribe runs given fewer GPUs.
    parser.add_argument("--tensor_parallel_size", type=int, default=None,
                       help="Tensor parallel size (defaults to the number of --device_ids)")
    parser.add_argument("--random_seed", type=int, default=42,
                       help="Random seed")

    args = parser.parse_args()

    print(f"Reading data: {args.input}")
    with open(args.input, 'r', encoding='utf-8') as f:
        original_data = json.load(f)

    print(f"Total {len(original_data)} original samples")

    trap_config = {
        "mask": {
            "ratio": args.mask_ratio,
            "mask_object_ratio": args.mask_object_ratio,
            "use_dynamic_size": True
        },
        "inpaint": {
            "ratio": args.inpaint_ratio,
            "mask_object_ratio": args.mask_object_ratio,
            "use_dynamic_size": True
        },
        "instruction_modify": {
            "ratio": args.instruction_ratio,
            "model_path": args.model_path,
            "batch_size": args.batch_size,
            "device_ids": args.device_ids,
            "tensor_parallel_size": args.tensor_parallel_size,
            "seed": args.random_seed
        }
    }

    trap_data = create_trap_dataset(
        original_data,
        trap_config=trap_config,
        output_image_dir=args.output_image_dir,
        random_seed=args.random_seed
    )

    print(f"\nSaving trap data to: {args.output}")
    with open(args.output, 'w', encoding='utf-8') as f:
        json.dump(trap_data, f, ensure_ascii=False, indent=2)

    print(f"Completed! Generated {len(trap_data)} trap samples")

    trap_types = Counter(item.get("trap_mode", "unknown") for item in trap_data)

    print("\nTrap data statistics:")
    for trap_mode, count in trap_types.items():
        print(f"  {trap_mode}: {count} samples")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()

