#!/usr/bin/env python
import argparse
import importlib.util
import logging
from collections.abc import Sequence
from typing import Optional

import mteb
import torch
from mteb.model_meta import ModelMeta
from mteb.models.colvbert_models import ColVBertWrapper


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Run MTEB evaluation on a ColQwen 2.5 model."
    )
    parser.add_argument(
        "model_name",
        help="HF repo ID or local path of the model to evaluate.",
    )
    parser.add_argument(
        "--benchmarks",
        nargs="+",
        default=["ViDoRe(v1)", "ViDoRe(v2)"],
        metavar="BENCHMARK",
        help="One or more benchmark names (space-separated).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size passed to encode_kwargs.",
    )
    return parser.parse_args()


def main() -> None:
    """Evaluate a ColVBert model on one or more MTEB benchmarks.

    Parses the command line with ``parse_args``, describes the model with an
    ad-hoc ``ModelMeta`` (loader ``ColVBertWrapper``, ``max_sim`` scoring),
    loads it on CUDA when available (CPU otherwise), and runs ``mteb.MTEB``
    over the requested benchmarks with the chosen encoding batch size.
    """
    args = parse_args()

    # --- Logging ---
    # Attach a root handler first: with no handler configured, Python falls
    # back to ``logging.lastResort``, which only emits WARNING and above, so
    # raising the "mteb" logger to INFO alone would print nothing extra.
    # The root logger itself stays at WARNING, so other libraries' INFO
    # messages are not shown; only "mteb" records propagate at INFO.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    logging.getLogger("mteb").setLevel(logging.INFO)

    # --- Model metadata ---
    # Ad-hoc metadata for a model that is not registered in mteb; fields that
    # are unknown for this checkpoint are left as None.
    custom_model_meta = ModelMeta(
        loader=ColVBertWrapper,
        name=args.model_name,
        modalities=["image", "text"],
        framework=["ColPali"],
        similarity_fn_name="max_sim",
        use_instructions=True,
        revision=None,
        release_date=None,
        languages=None,
        n_parameters=None,
        memory_usage_mb=None,
        max_tokens=None,
        # NOTE(review): assumed to be the per-token embedding size of the
        # model's projection head — confirm against the checkpoint config.
        embed_dim=128,
        license="apache-2.0",
        open_weights=True,
        public_training_code=None,
        public_training_data=None,
        training_datasets=None,
    )

    # --- Load model ---
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # FlashAttention-2 requires a CUDA device and the optional ``flash_attn``
    # package; requesting it without both makes model loading (or the first
    # forward pass) fail, e.g. on CPU-only machines. Fall back to "eager",
    # the reference attention implementation in transformers.
    # NOTE(review): presumes the wrapper forwards ``attn_implementation`` to
    # ``from_pretrained`` — confirm in ColVBertWrapper.
    has_flash_attn = importlib.util.find_spec("flash_attn") is not None
    attn_implementation = (
        "flash_attention_2" if device == "cuda" and has_flash_attn else "eager"
    )
    custom_model = custom_model_meta.load_model(
        model_name=args.model_name,
        device=device,
        attn_implementation=attn_implementation,
    )
    # NOTE: to cap image resolution (and memory use), the processor's longest
    # edge can be lowered, e.g.
    #   custom_model.processor.image_processor.size["longest_edge"] = 1024
    # This assumes the wrapper exposes ``processor`` — confirm before enabling.

    # --- Load tasks ---
    benchmarks = mteb.get_benchmarks(names=args.benchmarks)
    evaluator = mteb.MTEB(tasks=benchmarks)

    # --- Run evaluation ---
    evaluator.run(
        model=custom_model,
        verbosity=2,
        encode_kwargs={"batch_size": args.batch_size},
    )


# Script entry point: run the evaluation only when executed directly, not
# when this module is imported.
if __name__ == "__main__":
    main()
