# NOTE: although this is faster, it only works when the first token of each option is distinct
# [single_gpu, 7.588s] PYTHONPATH=. srun -p mllm_safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=16 --time=30000 python src/eval/eval_prob.py
# [multi_gpu,  4.000s] PYTHONPATH=. srun -p mllm_safety --quotatype=reserved --gres=gpu:8 --cpus-per-task=16 --time=30000 accelerate launch --config_file scripts/accelerate_configs/deepspeed_zero2.yaml --num_processes=8 src/eval/eval_prob.py
import re
import copy

import tqdm
import torch
import transformers
import accelerate

from src import utils


def _resolve_option_tokens(processor, template, mappings):
    """Return ``(option_key, options_list, first_token_of_options)`` for a template.

    ``template[1]`` is the answer format string; the name inside its first
    ``{...}`` placeholder selects which field of each mapping holds the option.

    Raises:
        ValueError: if the answer template has no placeholder, or if two
            distinct options start with the same token once formatted (the
            single-forward-pass ranking cannot tell them apart).
    """
    placeholder = re.search(r'\{(.*?)\}', template[1])
    if placeholder is None:
        raise ValueError(
            f"Answer template {template[1]!r} contains no '{{...}}' placeholder."
        )
    option_key = placeholder.group(1)

    # dict.fromkeys deduplicates while keeping first-seen order, so every
    # process builds the same option order (a set would depend on hash seeds).
    options_list = list(dict.fromkeys(mapping[option_key] for mapping in mappings))
    first_token_of_options = [
        processor.tokenizer.encode(
            template[1].format(**{option_key: option}),
            add_special_tokens=False,
        )[0]
        for option in options_list
    ]

    # Group options by first token; any group larger than one is a conflict.
    token_to_options = {}
    for option, token in zip(options_list, first_token_of_options):
        token_to_options.setdefault(token, []).append(option)
    conflicts = {
        token: options
        for token, options in token_to_options.items()
        if len(options) > 1
    }
    if conflicts:
        conflict_details = "\n".join(
            f"  - Token {token} is shared by options: {', '.join(map(str, options))}"
            for token, options in conflicts.items()
        )
        raise ValueError(
            "First tokens of the options are not unique. Fast ranking requires uniqueness.\n"
            "Conflicts found:\n" + conflict_details
        )

    return option_key, options_list, first_token_of_options


def _shard_mappings(mappings, num_processes, process_index):
    """Return this process's equal-sized slice of ``mappings``.

    A *copy* of the list is padded by cycling through its elements until the
    length is a multiple of ``num_processes`` (so every process gets at least
    one sample and all shards have the same size). The caller's list is never
    modified. An empty input yields an empty shard.
    """
    padded = list(mappings)
    num_original = len(padded)
    if num_original == 0:
        return []

    shortfall = -num_original % num_processes
    padded += [padded[i % num_original] for i in range(shortfall)]

    num_data_per_process = len(padded) // num_processes
    start_idx = process_index * num_data_per_process
    return padded[start_idx:start_idx + num_data_per_process]


def _encode_prompts(processor, prompts, batch_mappings, device):
    """Tokenize a batch of prompts (with images when requested) onto ``device``.

    A prompt still containing the literal ``{image_prefix}`` is treated as an
    image prompt: the placeholder is replaced by the processor-specific image
    marker and each mapping's ``"image"`` entry is passed to the processor.

    Raises:
        NotImplementedError: if an image prompt is used with a processor type
            that has no image prefix defined here.
    """
    # Only the first prompt is inspected: all prompts in a batch come from the
    # same template.
    if "{image_prefix}" not in prompts[0]:
        return processor(
            text=copy.deepcopy(prompts),
            return_tensors="pt",
            padding=True
        ).to(device)

    if isinstance(processor, transformers.MllamaProcessor):
        image_prefix = "<|image|><|begin_of_text|>"
        images = [[mapping["image"]] for mapping in batch_mappings]
    elif isinstance(processor, transformers.Gemma3Processor):
        image_prefix = "<start_of_image> "
        images = [[mapping["image"]] for mapping in batch_mappings]
    elif isinstance(processor, (transformers.LlavaProcessor, transformers.LlavaNextProcessor)):
        image_prefix = "USER: <image>\n ASSISTANT:"
        images = [[mapping["image"]] for mapping in batch_mappings]
    elif isinstance(processor, (transformers.Qwen2VLProcessor, transformers.Qwen2_5_VLProcessor)):
        import qwen_vl_utils
        image_prefix = "<|vision_start|><|image_pad|><|vision_end|>"
        images = [
            [qwen_vl_utils.fetch_image({"image": mapping["image"]})]
            for mapping in batch_mappings
        ]
    else:
        raise NotImplementedError

    prompts = [prompt.format(image_prefix=image_prefix) for prompt in prompts]
    # deepcopy: the processor is handed its own list in case it edits it.
    return processor(
        text=copy.deepcopy(prompts),
        images=images,
        return_tensors="pt",
        padding=True
    ).to(device)


def _last_token_option_probs(model, processor, prompt_batch, first_token_of_options):
    """Return, per sample, the softmax over the options' first-token logits.

    The logits are read at the last non-padding position of each prompt
    (inputs are right-padded) and restricted to ``first_token_of_options``
    before the softmax, so each row sums to one over the options.
    """
    with torch.no_grad():
        outputs = model(**prompt_batch)

    if "attention_mask" in prompt_batch:
        # Preferred: stays correct even if the pad id also occurs inside the
        # prompt. Assumes the usual 1-for-real-token mask convention.
        lens = prompt_batch["attention_mask"].sum(-1)
    else:
        lens = (prompt_batch["input_ids"] != processor.tokenizer.pad_token_id).sum(-1)

    pooled_logits = outputs["logits"][
        torch.arange(lens.size(0), device=model.device), lens - 1]
    pooled_logits = pooled_logits[:, first_token_of_options]
    # Softmax in float32 so low-precision model dtypes do not coarsen the probabilities.
    return pooled_logits.float().softmax(-1).cpu().tolist()


def _merge_gathered_results(gathered_results):
    """Merge per-process results, dropping padded duplicates by ``path``.

    For each template the first entry seen for every ``path`` is kept (in
    process order) and ``prob`` is the mean ground-truth probability over the
    kept entries, or ``None`` when there are no entries.
    """
    merged_results = {}
    for template_key, first_result in gathered_results[0].items():
        merged_meta = []
        seen_paths = set()
        for partial_result in gathered_results:
            for entry in partial_result[template_key]["meta"]:
                if entry["path"] not in seen_paths:
                    seen_paths.add(entry["path"])
                    merged_meta.append(entry)

        mean_prob = (
            sum(entry["prob"] for entry in merged_meta) / len(merged_meta)
            if merged_meta
            else None
        )
        merged_results[template_key] = {
            "template": first_result["template"],
            "prob": mean_prob,
            "meta": merged_meta,
        }
    return merged_results


def eval_prob(
    model,
    processor,
    data_config,
    per_device_eval_batch_size: int = 100,
    tqdm_disable: bool = True,
) -> dict:
    """Score each sample's ground-truth option with one forward pass per batch.

    For every template in ``data_config["templates"]`` the prompt
    (``template[0]``) is filled from each entry of ``data_config["mapping"]``,
    the model is run once, and the next-token logits are compared across the
    first tokens of all candidate options (``template[1]`` formatted with each
    distinct option value). Work is split evenly across accelerate processes
    and the per-process results are gathered and merged on every process.

    Args:
        model: Model called as ``model(**inputs)``; must expose ``.device``
            and return an object with a ``"logits"`` entry.
        processor: Processor with a right-padding ``tokenizer``.
        data_config: Dict with ``"templates"`` (sequence of
            ``(prompt_format, answer_format)`` pairs) and ``"mapping"`` (list
            of dicts with at least ``"name"``, the option field, optionally
            ``"path"``, and ``"image"`` for image prompts). Not modified.
        per_device_eval_batch_size: Number of prompts per forward pass.
        tqdm_disable: Disable the progress bar (it is always disabled on
            non-main processes).

    Returns:
        Dict keyed by the template index as a string. Each value holds the
        ``"template"``, the mean ground-truth probability ``"prob"`` (``None``
        if there were no samples) and the per-sample ``"meta"`` list.

    Raises:
        ValueError: if a template has no placeholder or two options share a
            first token.
        NotImplementedError: if an image prompt is used with an unsupported
            processor type.
    """
    assert processor.tokenizer.padding_side == "right", (
        "eval_prob locates the last prompt token assuming right padding."
    )

    state = accelerate.PartialState()
    # Sharding does not depend on the template, so it is done once, on a copy.
    mappings_this_device = _shard_mappings(
        data_config["mapping"], state.num_processes, state.process_index
    )

    results = {}
    for template_index, template in enumerate(data_config["templates"]):
        template_key = str(template_index)
        option_key, options_list, first_token_of_options = _resolve_option_tokens(
            processor, template, data_config["mapping"]
        )
        option_position = {option: idx for idx, option in enumerate(options_list)}

        meta = []
        for start in tqdm.tqdm(
            range(0, len(mappings_this_device), per_device_eval_batch_size),
            disable=tqdm_disable or not state.is_main_process
        ):
            batch_mappings = mappings_this_device[start:start + per_device_eval_batch_size]
            # SafeDict comes from src.utils; presumably it leaves unknown
            # placeholders such as {image_prefix} untouched — confirm there.
            prompts = [
                template[0].format_map(utils.SafeDict(mapping))
                for mapping in batch_mappings
            ]
            prompt_batch = _encode_prompts(processor, prompts, batch_mappings, model.device)
            pooled_probs = _last_token_option_probs(
                model, processor, prompt_batch, first_token_of_options
            )

            for mapping, pooled_prob in zip(batch_mappings, pooled_probs):
                meta.append({
                    "name": mapping["name"],
                    "path": mapping.get("path", mapping["name"]),
                    "gt": mapping[option_key],
                    "prob": pooled_prob[option_position[mapping[option_key]]],
                    "pooled_probs": dict(zip(options_list, pooled_prob)),
                })

        results[template_key] = {
            "template": template,
            "prob": None,
            "meta": meta,
        }

    # After all templates are processed, gather and merge results from all devices.
    gathered_results = accelerate.utils.gather_object([results])
    return _merge_gathered_results(gathered_results)
    

if __name__ == "__main__":
    import pprint
    from dataclasses import dataclass

    import tyro

    from src.utils import GpuTimer

    @dataclass
    class ScriptArguments:
        # Checkpoint directory or hub id handed to `from_pretrained`.
        model_name_or_path: str = "/mnt/lustrenew/mllm_safety-shared/models/huggingface/meta-llama/Llama-3.2-11B-Vision"
        # Data config file parsed by `utils.parse_data_config`.
        data_config_path: str = "data/animals/config_image.yaml"
        data_overwrite_args: str = "" # e.g. --data_overwrite_args "data.train[0].images_dirs[0]=/new/path/to/images,..."
        per_device_eval_batch_size: int = 8
        # NOTE(review): accepted on the command line but not read anywhere in this script.
        save_path: str = None

    script_args = tyro.cli(ScriptArguments)

    ################
    # Model, Processor
    ################
    print(f"evaluating {script_args.model_name_or_path}...")
    # One full model replica per process, placed on that process's local device.
    model = transformers.AutoModelForVision2Seq.from_pretrained(
        script_args.model_name_or_path,
        torch_dtype=torch.bfloat16,
        device_map={"": accelerate.PartialState().local_process_index},
    )

    # Model-family-specific processor options.
    processor_kwargs = {}
    qwen_vl_classes = (
        transformers.Qwen2VLForConditionalGeneration,
        transformers.Qwen2_5_VLForConditionalGeneration,
    )
    if isinstance(model, qwen_vl_classes):
        processor_kwargs.update(min_pixels=32*28*28, max_pixels=128*28*28)
    if isinstance(model, transformers.LlavaForConditionalGeneration):
        processor_kwargs["add_prefix_space"] = True

    # eval_prob asserts right padding, so request it explicitly.
    processor = transformers.AutoProcessor.from_pretrained(
        script_args.model_name_or_path,
        padding_side="right",
        **processor_kwargs,
    )

    ################
    # Data
    ################
    data_config = utils.parse_data_config(
        script_args.data_config_path,
        script_args.data_overwrite_args,
    )
    train_configs, eval_configs = utils.parse_train_and_eval_config(data_config)

    with GpuTimer():
        results = {}

        # Train configs are evaluated first, then eval configs; result keys
        # are "<split>-<config index>.<template key>".
        for split_name, split_configs in (("train", train_configs), ("eval", eval_configs)):
            for config_idx, split_config in enumerate(split_configs):
                if split_config is None:
                    continue  # nothing to evaluate for this slot
                partial_results = eval_prob(
                    model,
                    processor,
                    split_config,
                    script_args.per_device_eval_batch_size,
                    tqdm_disable=False,
                )
                results.update({
                    f"{split_name}-{config_idx}.{template_key}": template_result
                    for template_key, template_result in partial_results.items()
                })

    pprint.pprint(results, depth=2, width=500)
