# PYTHONPATH=. srun -p mllm_safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=16 --time=30000 python src/tools/patches_sfilter.py
# PYTHONPATH=. srun -p mllm_safety --quotatype=reserved --gres=gpu:8 --cpus-per-task=16 --time=30000 python src/tools/patches_sfilter.py --model_name_or_path "/mnt/lustrenew/mllm_safety-shared/models/huggingface/meta-llama/Llama-3.2-90B-Vision"
# --- Imports ---
import os
import shutil
import pprint
import json
import pathlib
from dataclasses import dataclass

import torch
import transformers
import accelerate
import tyro

from src import utils
from src.eval import eval_rank
from src.eval.eval_rank_internvl import eval_rank as eval_rank_internvl
from src.utils import GpuTimer

# --- Argument Parsing ---
@dataclass
class ScriptArguments:
    """Command-line options for the rank-threshold image filter (parsed by tyro)."""

    # Checkpoint path/name; an "InternVL" substring selects the InternVL code path.
    model_name_or_path: str = "/mnt/lustrenew/mllm_safety-shared/models/huggingface/meta-llama/Llama-3.2-11B-Vision"
    # Data config file and an override string, both passed to utils.parse_data_config.
    data_config_path: str = "data/animals/test.yaml"
    data_overwrite_args: str = "data.eval[0].images_dirs[0]=tmp/data/animals/files/8x8"
    # Root for the per-threshold image folders and the stats.json completion marker.
    output_dir: str = "tmp/data/animals/files/others/sfilter/8x8/Llama-3.2-11B-Vision/unrecognizable"
    per_device_eval_batch_size: int = 8
    # Comma-separated integer thresholds (e.g. "1,3,5"); one output folder per value.
    rank_thresholds: str = "1,3,5"

    def __post_init__(self) -> None:
        # Validate thresholds at parse time so a typo fails immediately
        # instead of after the (expensive) model load and evaluation.
        for token in self.rank_thresholds.split(","):
            if not token.strip():
                continue
            try:
                int(token.strip())
            except ValueError:
                raise ValueError(
                    f"rank_thresholds must be comma-separated integers, got {self.rank_thresholds!r}"
                ) from None

script_args = tyro.cli(ScriptArguments)

# Only single-process runs are supported. Use an explicit check rather than
# `assert` so the guard still applies when Python runs with -O.
_num_processes = accelerate.PartialState().num_processes
if _num_processes != 1:
    raise SystemExit(f"This script must run as a single process, got num_processes={_num_processes}")

# stats.json is written only after all thresholds are saved, so its presence
# marks a completed run; skip the whole job in that case.
_stats_marker = pathlib.Path(script_args.output_dir) / "stats.json"
if _stats_marker.exists():
    print(f"Found {_stats_marker}; output already complete, skipping.")
    # `exit()` is an interactive-shell helper from `site`; SystemExit is the
    # program-safe equivalent (exit status 0).
    raise SystemExit(0)

# --- Model and Processor Loading ---
print(f"Evaluating model: {script_args.model_name_or_path}")
if "InternVL" not in script_args.model_name_or_path:
    # Resolve the concrete transformers class named by the checkpoint config's
    # first `architectures` entry and let accelerate spread it over devices.
    model_config = transformers.PretrainedConfig.from_pretrained(script_args.model_name_or_path)
    _model_cls = getattr(transformers, model_config.architectures[0])
    model = _model_cls.from_pretrained(
        script_args.model_name_or_path,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
else:
    # InternVL goes through AutoModelForCausalLM and is pinned entirely to this
    # process's local device index.
    _local_device = accelerate.PartialState().local_process_index
    model = transformers.AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        trust_remote_code=True,
        device_map={"": _local_device},
        torch_dtype=torch.bfloat16,
    )

# Additional processor settings for specific models
processor_kwargs = dict(padding_side="right", trust_remote_code=True)

_qwen_vl_model_types = (
    transformers.Qwen2VLForConditionalGeneration,
    transformers.Qwen2_5_VLForConditionalGeneration,
)
if isinstance(model, _qwen_vl_model_types):
    # Qwen-VL processors take explicit pixel bounds (multiples of 28 * 28).
    processor_kwargs["min_pixels"] = 32 * 28 * 28
    processor_kwargs["max_pixels"] = 128 * 28 * 28

if isinstance(model, transformers.LlavaForConditionalGeneration):
    processor_kwargs["add_prefix_space"] = True

# InternVL is driven by a bare tokenizer; every other model uses a full processor.
if "InternVL" not in script_args.model_name_or_path:
    processor = transformers.AutoProcessor.from_pretrained(script_args.model_name_or_path, **processor_kwargs)
else:
    tokenizer = transformers.AutoTokenizer.from_pretrained(script_args.model_name_or_path, **processor_kwargs)

# --- Data Configuration ---
# Build the data config (file + override string); only the eval split is used,
# the train part returned alongside it is discarded.
data_config = utils.parse_data_config(
    script_args.data_config_path,
    script_args.data_overwrite_args,
)
_unused_train_configs, eval_configs = utils.parse_train_and_eval_config(data_config)

# --- Evaluation ---
# Both rank evaluators share the data/batching arguments; they differ only in
# the entry point and in whether a tokenizer or a processor is supplied.
with GpuTimer():
    _shared_eval_kwargs = dict(
        model=model,
        data_config=eval_configs[0],
        per_device_eval_batch_size=script_args.per_device_eval_batch_size,
        tqdm_disable=False,
    )
    if "InternVL" not in script_args.model_name_or_path:
        results = eval_rank(processor=processor, **_shared_eval_kwargs)
    else:
        results = eval_rank_internvl(tokenizer=tokenizer, **_shared_eval_kwargs)

pprint.pprint(results, depth=2, width=500)

# Save Unrecognizable Images by Threshold
# For each threshold t, copy every evaluated image whose rank is >= t into
# `<output_dir>/threshold<t>/`, then write stats.json. stats.json is written
# last because its existence is what makes a re-run of this script exit early;
# a run interrupted mid-copy is therefore redone from the start.
thresholds = [int(t.strip()) for t in script_args.rank_thresholds.split(",") if t.strip()]
if accelerate.PartialState().is_main_process:
    output_root = pathlib.Path(script_args.output_dir)
    # NOTE(review): assumes results["0"]["meta"] is a list of dicts carrying
    # "rank" and "path" keys, as returned by eval_rank / eval_rank_internvl.
    meta = results["0"]["meta"]
    saved_counts = {}
    for threshold in thresholds:
        save_dir = output_root / f"threshold{threshold}"
        save_dir.mkdir(parents=True, exist_ok=True)

        # Images are flattened into one folder by basename, so two sources with
        # the same file name would silently overwrite each other; warn on that.
        copied_names = set()
        count = 0
        for item in meta:
            if item["rank"] >= threshold:
                src_path = item["path"]
                filename = os.path.basename(src_path)
                if filename in copied_names:
                    print(f"Warning: duplicate filename {filename} in {save_dir}; overwriting with {src_path}")
                copied_names.add(filename)
                shutil.copy(src_path, save_dir / filename)
                count += 1

        saved_counts[f"threshold{threshold}"] = count
        print(f"Saved {count} images with rank >= {threshold} to {save_dir}")

    # Completion marker; also records how many images each threshold kept.
    with open(output_root / "stats.json", "w", encoding="utf-8") as f:
        json.dump(saved_counts, f, ensure_ascii=False, indent=4)
