import os
import json
import copy
import argparse
import random
import numpy as np
import torch


def _str_to_bool(value):
    """Convert a command-line string such as 'true' / 'False' / '0' to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(<str>)``, which is True for
    every non-empty string (including "False"). This converter parses the
    text instead. The empty string maps to False because that was the only
    value that disabled the option under the old ``type=bool`` behavior.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        f"Expected a boolean value (true/false), got {value!r}."
    )


def parse_args(argv=None):
    """Build the command-line parser for this script and parse the arguments.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--local_rank",
        type=int,
        default=0,
        help="Local rank for distributed training.",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to the MMRB2 edit.json file.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to save the output results (JSON format).",
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="The model to use.",
    )
    parser.add_argument(
        "--lora_path",
        type=str,
        default=None,
        help="Path to the LoRA weights.",
    )
    parser.add_argument(
        "--tensor_parallel_size",
        type=int,
        default=1,
        help="The tensor parallel size.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility.",
    )
    parser.add_argument(
        "--max_num_seqs",
        type=int,
        default=32,
        help="The maximum number of sequences.",
    )
    parser.add_argument(
        "--limit_mm_per_prompt_image",
        type=int,
        default=4,  # Changed from 2 to 4 to support multi-image fusion (3 inputs + 1 output)
        help="The maximum number of images allowed per prompt.",
    )
    parser.add_argument(
        "--max_model_len",
        type=int,
        default=4096,
        help="The maximum model length.",
    )
    parser.add_argument(
        "--max_num_batched_tokens",
        type=int,
        default=4096,
        help="The maximum number of batched tokens.",
    )
    parser.add_argument(
        "--gpu_memory_utilization",
        type=float,
        default=0.85,
        help="The GPU memory utilization.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="The temperature for sampling.",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="The top_p for sampling.",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=20,
        help="The top_k for sampling.",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=512,
        help="The maximum number of tokens to generate.",
    )
    # NOTE: `type=bool` would turn "--enable_prefix_caching False" into True,
    # so the value is parsed explicitly. A bare "--enable_prefix_caching"
    # (no value) is accepted and means True.
    parser.add_argument(
        "--enable_prefix_caching",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=True,
        help="Whether to enable prefix caching (true/false). Default: true.",
    )
    parser.add_argument(
        "--enforce_eager",
        action="store_true",
        help="Whether to enforce eager execution.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="The batch size for inference.",
    )
    parser.add_argument(
        "--distributed_executor_backend",
        type=str,
        default=None,
        help="Distributed executor backend (e.g., 'mp').",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        help="Data type (e.g., 'bfloat16').",
    )
    parser.add_argument(
        "--min_pixels",
        type=int,
        default=56 * 56,
        help="Minimum number of pixels for image processing.",
    )
    parser.add_argument(
        "--max_pixels",
        type=int,
        default=12845056,
        help="Maximum number of pixels for image processing.",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="Number of workers for data loading.",
    )
    parser.add_argument(
        "--score_range",
        type=int,
        default=25,
        help="Score range for evaluation.",
    )
    parser.add_argument(
        "--with_region",
        action="store_true",
        help="Whether to use region-based prompts.",
    )
    parser.add_argument(
        "--interleaved",
        action="store_true",
        help="Use interleaved reasoning format with <|bbox_id|> and <|global|> tokens.",
    )
    parser.add_argument(
        "--score_aggregation",
        type=str,
        default="min",
        choices=["min", "mean", "weighted_power"],
        help="Score aggregation method for SC dimensions: 'min' (take minimum), 'mean' (take average), or 'weighted_power' (weighted power formula).",
    )
    parser.add_argument(
        "--weighted_power_params",
        type=float,
        nargs=5,
        default=None,
        help="Five parameters for weighted_power aggregation: w1 w2 w3 w4 a. Formula: ((w1*s1+w2*s2)**a) * ((w3*s3+w4*s4)**(1-a))",
    )
    parser.add_argument(
        "--add_timestamp",
        action="store_true",
        help="Add timestamp to output filename to avoid overwriting existing results.",
    )
    parser.add_argument(
        "--single_image_only",
        action="store_true",
        default=False,  # Changed: Now evaluate all data by default
        help="Filter to single-image editing tasks only (exclude multi-image fusion tasks). Default: False (evaluate all)",
    )
    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)
    return args


def parse_llm_args(args):
    """Collect the model/engine settings from ``args`` into a keyword dict.

    The fixed keys are always present. ``dtype`` and
    ``distributed_executor_backend`` are added only when their value on
    ``args`` is not None. The consumer of this dict is not visible here;
    presumably it is passed to an LLM engine constructor (verify at the
    call site).

    Args:
        args: Namespace-like object exposing the attributes read below.

    Returns:
        dict mapping keyword names to the corresponding values from ``args``.
    """
    engine_kwargs = {
        "model": args.model,
        "max_num_seqs": args.max_num_seqs,
        "limit_mm_per_prompt": {"image": args.limit_mm_per_prompt_image},
        "tensor_parallel_size": args.tensor_parallel_size,
        "max_model_len": args.max_model_len,
        "gpu_memory_utilization": args.gpu_memory_utilization,
        "enable_prefix_caching": args.enable_prefix_caching,
        "enforce_eager": args.enforce_eager,
    }

    # Optional settings: forwarded only when explicitly set (not None).
    for option in ("dtype", "distributed_executor_backend"):
        setting = getattr(args, option)
        if setting is not None:
            engine_kwargs[option] = setting

    return engine_kwargs


def read_json(file_path):
    """Open ``file_path`` as UTF-8 text and return its parsed JSON content."""
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return json.load(handle)


def back_envs():
    """Remove the listed torchrun-related variables from ``os.environ``.

    Every variable in the list that is currently set is deleted from the
    process environment; variables that are not set are ignored.

    Returns:
        dict mapping each removed variable name to the value it had.
    """
    launcher_keys = (
        'RANK', 'LOCAL_RANK', 'WORLD_SIZE', 'LOCAL_WORLD_SIZE', 'GROUP_RANK',
        'ROLE_RANK', 'ROLE_NAME', 'GROUP_WORLD_SIZE', 'ROLE_WORLD_SIZE',
        'MASTER_ADDR', 'MASTER_PORT', 'TORCHELASTIC_RESTART_COUNT',
        'TORCHELASTIC_MAX_RESTARTS', 'TORCHELASTIC_RUN_ID',
        'TORCHELASTIC_USE_AGENT_STORE', 'TORCH_NCCL_ASYNC_ERROR_HANDLING',
    )
    # pop() both reads the current value and deletes the variable.
    return {
        key: os.environ.pop(key)
        for key in launcher_keys
        if key in os.environ
    }


def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch RNGs with ``seed``.

    When CUDA is available, all CUDA devices are seeded as well.
    """
    # Same order as before: stdlib, NumPy, then torch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def save_data_to_cache(data_dict, cache_path, lock):
    """Write a snapshot of ``data_dict`` to ``cache_path`` as JSON (best effort).

    The snapshot is deep-copied while holding ``lock`` so the data cannot
    change mid-copy; serialization and file I/O happen outside the lock.

    The file is replaced atomically: the JSON text is produced first, written
    to a temporary file next to ``cache_path``, and then moved into place with
    ``os.replace``. A serialization error or an interrupted write therefore
    leaves the previous cache file untouched. Opening ``cache_path`` with
    ``'w'`` directly would truncate it before any data is written.

    Args:
        data_dict: JSON-serializable object to persist.
        cache_path: Destination file path.
        lock: Context manager (e.g. ``threading.Lock``) guarding ``data_dict``.

    Errors are reported with ``print`` and not re-raised, so a failed cache
    write never interrupts the caller.
    """
    with lock:
        snapshot = copy.deepcopy(data_dict)

    # Unique per process and per in-flight call: `snapshot` is alive for the
    # whole call, so its id cannot be shared by a concurrent call in the same
    # process. Same directory as the target so os.replace stays on one
    # filesystem.
    tmp_path = f"{cache_path}.{os.getpid()}.{id(snapshot)}.tmp"

    try:
        # Serialize before touching any file so bad data cannot clobber the cache.
        payload = json.dumps(snapshot, ensure_ascii=False, indent=4)
        with open(tmp_path, 'w', encoding='utf-8') as f:
            f.write(payload)
        os.replace(tmp_path, cache_path)
        print(f"Data saved to {cache_path}")
    except Exception as e:
        print(f"Error saving data to {cache_path}: {e}")
        # Do not leave a partial temporary file behind.
        try:
            os.remove(tmp_path)
        except OSError:
            pass

