import argparse
import mteb
from mteb.benchmarks import Benchmark
from mteb.overview import get_tasks
import logging

# Show INFO-level progress messages from MTEB.
# Raising the "mteb" logger's level alone is not enough: when no handler is
# configured anywhere, Python falls back to logging.lastResort, which only
# emits WARNING and above, so the INFO records would be silently dropped.
# basicConfig() attaches a stream handler to the root logger (and is a no-op
# if the root logger already has handlers). The root level is left at its
# default WARNING, so loggers that keep the default level stay quiet.
logging.basicConfig()
logging.getLogger("mteb").setLevel(logging.INFO)

# === Custom Benchmark ===
# Tasks evaluated by this script. Uncomment entries below to enable more tasks,
# and keep `description` in sync with whatever is actually enabled.
imclass = Benchmark(
    name="ImClass",
    tasks=get_tasks(
        tasks=[
            # MSCOCO
            "MSCOCOT2IRetrieval",
            "MSCOCOI2TRetrieval",
            # Flickr
            "Flickr30kT2IRetrieval",
            "Flickr30kI2TRetrieval",
            # # Stanford Cars
            # "StanfordCarsZeroShot",
            # # Oxford Pets
            # "OxfordPetsZeroShot",
            # # Food101
            # "Food101ZeroShot",
            # # EuroSAT
            # "EuroSATZeroShot",
            # # Resisc45
            # "RESISC45ZeroShot",
            # # FER2013
            # "FER2013ZeroShot",
            # # SUN397
            # "SUN397ZeroShot",
            # # Caltech101
            # "Caltech101ZeroShot",
            # # Country211
            # "Country211ZeroShot",
            # # DTD
            # "DTDZeroShot",
            # # Imagenet1k
            # "Imagenet1kZeroShot",
        ],
    ),
    # The enabled tasks are text-to-image / image-to-text retrieval; the
    # zero-shot classification tasks above are currently commented out.
    description="""Image-text retrieval subset (MSCOCO, Flickr30k); zero-shot classification tasks are currently disabled.""",
)

def parse_args(argv=None):
    """Parse the command-line options for the benchmark runner.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) reads the real command line (``sys.argv[1:]``), exactly
            as before; passing a list is useful for tests and programmatic
            calls.

    Returns:
        argparse.Namespace with ``model_name`` (str, required),
        ``batch_size`` (int, default 8) and ``output_folder``
        (str, default "results").

    Raises:
        SystemExit: via ``parser.error`` on invalid arguments, including a
            ``--batch-size`` that is not a positive integer.
    """
    parser = argparse.ArgumentParser(description="Run MTEB benchmarks on a given model")
    parser.add_argument(
        "--model-name", "-m",
        type=str,
        required=True,
        help="Name of the model to load via mteb.get_model",
    )
    parser.add_argument(
        "--batch-size", "-b",
        type=int,
        default=8,
        help="Batch size for encoding tasks",
    )
    parser.add_argument(
        "--output-folder", "-o",
        type=str,
        default="results",
        help="Folder to save the results",
    )
    args = parser.parse_args(argv)

    # A zero or negative batch size can never encode anything; reject it at the
    # CLI boundary with a usage message instead of failing deep inside a run.
    if args.batch_size < 1:
        parser.error("argument --batch-size/-b: must be a positive integer")

    return args

def main():
    """Evaluate the model named on the command line on the ``imclass`` benchmark.

    Reads options via ``parse_args()``, loads the model with
    ``mteb.get_model`` and runs every task in ``imclass`` through
    ``mteb.MTEB.run``, forwarding the output folder and encode batch size.
    """
    args = parse_args()

    model = mteb.get_model(model_name=args.model_name)
    evaluator = mteb.MTEB(tasks=imclass)
    # Results are written to --output-folder by the run itself, so the
    # returned value is not needed here.
    evaluator.run(
        model,
        output_folder=args.output_folder,
        verbosity=2,
        encode_kwargs={"batch_size": args.batch_size},
    )

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
