# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

import torch.distributed as dist
from transformers import AutoTokenizer
from lm_eval import tasks, evaluator
from lm_eval.models.huggingface import HFLM
from utils.synchronize import synchronize
from config.config_template import ConfigTemplate
from evaluation.my_hf_wrapper import MyHFWrapper, MyHFConfig


def evaluation(config: ConfigTemplate, model):
    """Evaluate ``model`` on a fixed set of lm-eval tasks and return the results.

    The model is switched to eval mode, wrapped in ``MyHFWrapper`` and then in
    lm-eval's ``HFLM`` together with the gpt2 tokenizer, and evaluated on
    hellaswag, piqa, arc_easy, arc_challenge and lambada_openai.

    The function queries ``dist.get_rank()`` / ``dist.get_world_size()`` and
    calls ``synchronize()`` before and after the evaluation, so it is meant to
    be called from every process of the distributed job.
    NOTE(review): ``synchronize()`` is presumably a cross-rank barrier --
    confirm against ``utils.synchronize``.

    Args:
        config: Provides ``batch_size_fwd``, ``num_class``, ``context_window``
            and ``vocab_size``.
        model: The model to evaluate. It is put into eval mode and is NOT
            switched back to train mode before returning.

    Returns:
        The ``"results"`` entry of the dictionary returned by
        ``evaluator.evaluate``.
    """
    # Resolve the rank once; only rank 0 prints so that messages are not
    # repeated once per process.
    is_main_process = dist.get_rank() == 0

    if is_main_process:
        print("\n\n\n\nEvaluation - Start\n\n\n\n")
    synchronize()

    # Step 1: Set the model to eval mode
    model.eval()

    # Step 2: Wrap the model
    # The tokenizer is hard-coded to gpt2; the notice is emitted on rank 0
    # only, consistent with the Start/End banners.
    if is_main_process:
        print("Notice: Assuming the gpt2 tokenizer")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # First adapt the raw model to the interface HFLM expects ...
    hf_compatible_model = MyHFWrapper(
        model=model,
        batch_size_fwd=config.batch_size_fwd,
        num_class=config.num_class,
        max_allowed_num_token=config.context_window,
        hf_config=MyHFConfig(vocab_size=config.vocab_size),
    )
    # ... then hand it to lm-eval. The batch size passed to lm-eval is the
    # per-forward batch size multiplied by the number of processes.
    wrapped_model = HFLM(
        pretrained=hf_compatible_model,
        tokenizer=tokenizer,
        backend="causal",
        device="cuda",
        batch_size=config.batch_size_fwd * dist.get_world_size(),
    )

    # Step 3: Run evaluation
    task_dict = tasks.get_task_dict([
        "hellaswag",
        "piqa",
        "arc_easy",
        "arc_challenge",
        "lambada_openai",
    ])
    results = evaluator.evaluate(wrapped_model, task_dict)["results"]

    synchronize()
    if is_main_process:
        print("\n\n\n\nEvaluation - End\n\n\n\n")
    synchronize()
    return results
