# [single_gpu, 44.207s] PYTHONPATH=. srun -p mllm_safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=16 --time=30000 python src/eval/_eval_rank_slow.py
# [multi_gpu,  12.805s] PYTHONPATH=. srun -p mllm_safety --quotatype=reserved --gres=gpu:8 --cpus-per-task=16 --time=30000 accelerate launch --config_file scripts/accelerate_configs/deepspeed_zero2.yaml --num_processes=8 src/eval/_eval_rank_slow.py
import re
import copy
import math

import torch
import transformers
import accelerate

from src import utils


def rank_matrix_rows_desc(tensor):
    """Order the rows of a 2-D tensor lexicographically from largest to smallest.

    Column 0 is the most significant key, column 1 breaks ties in column 0, and so
    on. Rows that compare equal in every column keep their original relative order.

    Args:
        tensor: 2-D tensor of shape ``(num_rows, num_cols)``.

    Returns:
        1-D int64 tensor of row indices on ``tensor.device``, best row first.
    """
    # An ascending sort of the negated values is a descending sort of the originals.
    flipped = -tensor
    order = torch.arange(tensor.size(0), device=tensor.device)
    # Stable-sort once per column, least significant column first. Each later pass
    # (on a more significant column) keeps the order established by earlier passes
    # among rows it sees as equal, so the final order is lexicographic.
    last_col = tensor.size(1) - 1
    for col in range(last_col, -1, -1):
        keys = flipped[order, col]
        order = order[torch.argsort(keys, stable=True)]
    return order


def _select_templates(data_config, eval_template):
    """Flatten ``data_config["templates"]`` into ``{"<group>-<i>": template}`` and filter it.

    ``"train"`` or ``"eval"`` keeps only that group's templates, a list keeps exactly
    the listed keys, and anything else (e.g. ``"all"``) keeps every template.
    """
    templates = {
        f"{group}-{i}": tmpl
        for group, tmpl_list in data_config["templates"].items()
        for i, tmpl in enumerate(tmpl_list)
    }
    if eval_template in ("train", "eval"):
        prefix = f"{eval_template}-"
        return {k: v for k, v in templates.items() if k.startswith(prefix)}
    if isinstance(eval_template, list):
        return {k: v for k, v in templates.items() if k in eval_template}
    return templates


def _shard_indices(num_items, num_processes, process_index):
    """Return this process's contiguous range of indices into the cyclically padded data.

    Every process receives ``ceil(num_items / num_processes)`` indices. Indices
    ``>= num_items`` are padding: callers map them back with ``idx % num_items`` and
    must not count them in the results. Returns an empty range when there is no data.
    """
    if num_items == 0:
        return range(0)
    per_process = math.ceil(num_items / num_processes)
    start = process_index * per_process
    return range(start, start + per_process)


def _image_prefix_and_image(processor, mapping):
    """Return ``(image_prefix, image)`` for the given processor type.

    ``image_prefix`` is the text substituted for ``{image_prefix}`` in the prompt, and
    ``image`` is derived from ``mapping["image"]``.

    Raises:
        NotImplementedError: if no image-prefix rule exists for the processor type.
    """
    if isinstance(processor, transformers.MllamaProcessor):
        return "<|image|><|begin_of_text|>", mapping["image"]
    if isinstance(processor, transformers.Gemma3Processor):
        return "<start_of_image> ", mapping["image"]
    if isinstance(processor, (transformers.LlavaProcessor, transformers.LlavaNextProcessor)):
        return "USER: <image>\n ASSISTANT:", mapping["image"]
    if isinstance(processor, (transformers.Qwen2VLProcessor, transformers.Qwen2_5_VLProcessor)):
        import qwen_vl_utils
        # Qwen-VL images are routed through qwen_vl_utils.fetch_image before processing.
        image = qwen_vl_utils.fetch_image({"image": mapping["image"]})
        return "<|vision_start|><|image_pad|><|vision_end|>", image
    raise NotImplementedError(
        f"No image prefix is defined for processor type {type(processor).__name__}"
    )


def _rank_options(model, processor, prompt, prompt_options, images=None):
    """Score each ``prompt + option`` text with ``model`` and return option indices, best first.

    Args:
        model: called as ``model(**batch)``; its output must expose ``"logits"``.
        processor: processor whose tokenizer pads on the right.
        prompt: the shared prefix; its token length marks where the options start.
        prompt_options: full texts, each ``prompt`` followed by one candidate option.
        images: per-sample image list (``[image]``) passed alongside every text, or
            ``None`` for text-only prompts.

    Returns:
        List of indices into ``prompt_options``, ordered by ``rank_matrix_rows_desc``
        over the per-token log-probabilities of the option tokens.
    """
    option_kwargs = {} if images is None else {"images": [images] * len(prompt_options)}
    prompt_kwargs = {} if images is None else {"images": [images]}

    # list(...) hands the processor a fresh list, as the former deepcopy did (presumably
    # in case a processor edits `text` in place; the strings themselves are immutable).
    options_batch = processor(
        text=list(prompt_options), return_tensors="pt", padding=True, **option_kwargs
    ).to(model.device)
    # Only the prompt's token length is needed, so this batch is not moved to the device.
    prompt_len = processor(
        text=[prompt], return_tensors="pt", padding=True, **prompt_kwargs
    )["input_ids"].size(1)

    # TODO: batch the options here if they no longer fit in a single forward pass.
    with torch.no_grad():
        outputs = model(**options_batch)

    # Logits at position t predict token t + 1, hence the one-step shift.
    # Upcast before log_softmax: low-precision (e.g. bf16) log-probs are coarse enough
    # to create spurious ties between options.
    logits = outputs["logits"][:, prompt_len - 1:-1].float()
    labels = options_batch["input_ids"][:, prompt_len:].clone()

    # Prefer the attention mask to find padding, so a real token that shares the pad id
    # is not mistaken for padding.
    attention_mask = options_batch.get("attention_mask")
    if attention_mask is not None:
        pad_mask = attention_mask[:, prompt_len:] == 0
    else:
        pad_mask = labels == processor.tokenizer.pad_token_id
    labels[pad_mask] = 0  # any valid id works; these scores are overwritten below

    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    # +inf past an option's end: at a column where one option has ended and another has
    # not, the ended (shorter) option compares as larger in the lexicographic ranking.
    per_token_logps.masked_fill_(pad_mask, float("inf"))
    return rank_matrix_rows_desc(per_token_logps).tolist()


def _merge_gathered_results(gathered):
    """Concatenate per-process ``meta`` lists per template and compute the mean rank.

    ``rank`` is ``None`` when no samples were scored, instead of dividing by zero.
    """
    merged = {}
    for template_key, first in gathered[0].items():
        meta = [entry for partial in gathered for entry in partial[template_key]["meta"]]
        ranks = [entry["rank"] for entry in meta]
        merged[template_key] = {
            "template": first["template"],
            "rank": sum(ranks) / len(ranks) if ranks else None,
            "meta": meta,
        }
    return merged


def eval_rank(model, processor, data_config, eval_template="all"):
    """Rank every candidate option for each sample and report where the ground truth lands.

    Each selected template is indexed as ``template[0]`` (prompt) and ``template[1]``
    (option). The first ``{key}`` placeholder in the option template names the
    mapping field that holds the answer. The candidate set is the de-duplicated set
    of that field over all of ``data_config["mapping"]``.

    For each sample:
    - the prompt is concatenated with every formatted candidate;
    - the candidates are ranked with ``rank_matrix_rows_desc`` over per-token
      log-probabilities;
    - the 0-based position of the sample's own value is recorded (0 means the
      ground truth ranked first).

    Work is split across ``accelerate`` processes. The mapping list is cyclically
    padded so every process scores the same number of samples; the padded
    duplicates are scored but excluded from the results. ``data_config`` is not
    modified. Results are combined with ``accelerate.utils.gather_object``, so
    every process must call this function.

    Args:
        model: model called as ``model(**batch)`` whose output exposes ``"logits"``.
        processor: processor whose ``tokenizer`` pads on the right.
        data_config: dict with ``"templates"`` (group name -> list of templates) and
            ``"mapping"`` (list of per-sample dicts with at least ``"name"``, the
            option field, and ``"image"`` when prompts contain ``{image_prefix}``).
        eval_template: ``"all"``, ``"train"``, ``"eval"``, or a list of template keys.

    Returns:
        ``{template_key: {"template": ..., "rank": mean 0-based rank (None if no
        samples), "meta": [{"name", "gt", "rank", "ranked_options"}, ...]}}``.

    Raises:
        ValueError: if the tokenizer does not pad on the right.
        NotImplementedError: if an image prompt is used with an unsupported processor.
    """
    if processor.tokenizer.padding_side != "right":
        raise ValueError("eval_rank requires processor.tokenizer.padding_side == 'right'")

    state = accelerate.PartialState()
    mappings = data_config["mapping"]
    num_samples = len(mappings)
    shard = _shard_indices(num_samples, state.num_processes, state.process_index)

    results = {}
    for template_key, template in _select_templates(data_config, eval_template).items():
        prompt_template, option_template = template[0], template[1]
        option_key = re.search(r"\{(.*?)\}", option_template).group(1)
        # dict.fromkeys de-duplicates while keeping first-seen order. Unlike set(), the
        # order (which decides ties in the ranking) is then identical on every process
        # and across runs, independent of string-hash randomization.
        options_list = list(dict.fromkeys(m[option_key] for m in mappings))
        formatted_options = [
            option_template.format(**{option_key: option}) for option in options_list
        ]

        meta = []
        for global_idx in shard:
            mapping = mappings[global_idx % num_samples]
            prompt = prompt_template.format_map(utils.SafeDict(mapping))
            prompt_options = [prompt + option for option in formatted_options]

            images = None
            if "{image_prefix}" in prompt_options[0]:
                image_prefix, image = _image_prefix_and_image(processor, mapping)
                # str.replace (not str.format) so other braces in the text are left alone.
                prompt = prompt.replace("{image_prefix}", image_prefix)
                prompt_options = [
                    option.replace("{image_prefix}", image_prefix) for option in prompt_options
                ]
                images = [image]

            ranked_args = _rank_options(model, processor, prompt, prompt_options, images)
            if global_idx >= num_samples:
                continue  # padding duplicate: scored only to keep per-process work equal

            gt = mapping[option_key]
            meta.append({
                "name": mapping["name"],
                "gt": gt,
                "rank": ranked_args.index(options_list.index(gt)),
                "ranked_options": [options_list[i] for i in ranked_args],
            })

        results[template_key] = {"template": template, "rank": None, "meta": meta}

    gathered_results = accelerate.utils.gather_object([results])
    return _merge_gathered_results(gathered_results)
    

if __name__ == "__main__":
    import json
    import os
    import pprint
    from dataclasses import dataclass
    from typing import Optional

    import tyro

    from src.utils import GpuTimer

    @dataclass
    class ScriptArguments:
        model_name_or_path: str = "/mnt/lustrenew/mllm_safety-shared/models/huggingface/meta-llama/Llama-3.2-11B-Vision"
        data_config_path: str = "data/animals/config_image.yaml"
        # "all", "train" or "eval"; forwarded to eval_rank
        eval_template: str = "all"
        # NOTE(review): currently unused; eval_rank scores all options of a sample in one forward pass
        per_device_eval_batch_size: int = 100
        # if set, the main process writes the merged results to this path as JSON
        save_path: Optional[str] = None

    script_args = tyro.cli(ScriptArguments)
    state = accelerate.PartialState()

    ################
    # Model, Processor
    ################
    if state.is_main_process:
        print(f"evaluating {script_args.model_name_or_path}...")
    model = transformers.AutoModelForVision2Seq.from_pretrained(
        script_args.model_name_or_path,
        torch_dtype=torch.bfloat16,
        # attn_implementation="flash_attention_2",
        # device_map="auto",
        # one full model copy per process, placed on that process's local device
        device_map={"": state.local_process_index},
    )
    # Qwen2-VL / Qwen2.5-VL get min/max pixel settings (presumably to cap image size /
    # image-token count). Classes are looked up defensively so older transformers
    # versions that lack one of them still work; isinstance(x, ()) is simply False.
    qwen_model_types = tuple(
        getattr(transformers, name)
        for name in ("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
        if hasattr(transformers, name)
    )
    processor_kwargs = {}
    if isinstance(model, qwen_model_types):
        processor_kwargs["min_pixels"] = 32 * 28 * 28
        processor_kwargs["max_pixels"] = 128 * 28 * 28
    processor = transformers.AutoProcessor.from_pretrained(
        script_args.model_name_or_path,
        padding_side="right",  # eval_rank requires right padding
        **processor_kwargs,
    )
    data_config = utils.parse_data_config(script_args.data_config_path)
    train_config, eval_config = utils.parse_train_and_eval_config(data_config)

    with GpuTimer():
        # Prefix keys by data split: both configs may use the same template keys
        # (e.g. "train-0"), in which case a flat dict.update would let the second
        # evaluation silently overwrite the first.
        results = {}
        for split_name, split_config in (("train_data", train_config), ("eval_data", eval_config)):
            split_results = eval_rank(model, processor, split_config, script_args.eval_template)
            results.update({f"{split_name}/{key}": value for key, value in split_results.items()})

    # The gathered results are available on every process; report and save them once.
    if state.is_main_process:
        pprint.pprint(results, depth=2, width=500)
        if script_args.save_path is not None:
            save_dir = os.path.dirname(script_args.save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            with open(script_args.save_path, "w", encoding="utf-8") as f:
                json.dump(results, f, indent=2, ensure_ascii=False)
