import os
from typing import Optional

from transformers import (
    PreTrainedModel,
    PreTrainedTokenizer
)

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
# from lm_eval.tasks import initialize_tasks
from lm_eval.utils import make_table

# Few-shot example counts keyed by lm-eval task name.
# NOTE(review): these values appear to mirror the Open LLM Leaderboard (v1)
# settings (ARC 25, HellaSwag 10, TruthfulQA 0, MMLU/Winogrande/GSM8K 5) —
# confirm that is the intent. `evaluate_fewshot` and `evaluation` below take
# `num_fewshot` explicitly and do not consult this table.
TASK_TO_NUM_FEWSHOT = dict(
    arc_challenge=25,
    hellaswag=10,
    truthfulqa=0,
    mmlu=5,
    winogrande=5,
    gsm8k=5,
)

def evaluate_fewshot(
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        task: str,
        num_fewshot: int,
        eval_batch_size: Optional[int] = 4,
        log: Optional[bool] = True,
        output_path: Optional[str] = None,
):
    """Score ``model`` on an lm-evaluation-harness task with fixed seeds.

    The model/tokenizer pair is wrapped in an ``HFLM`` adapter
    (``device_map="auto"``) and evaluated with ``evaluator.simple_evaluate``.

    Args:
        model: Hugging Face model to evaluate.
        tokenizer: Tokenizer matching ``model``.
        task: lm-eval task name. A non-string value (e.g. a list of names)
            is forwarded unchanged.
        num_fewshot: Number of in-context examples per prompt.
        eval_batch_size: Batch size used by the ``HFLM`` wrapper.
        log: If truthy, print the result table(s) to stdout.
        output_path: If given, append the result table(s) to this file
            (UTF-8).

    Returns:
        The results object from ``simple_evaluate``. If that is ``None``
        (lm-eval returns ``None`` on non-zero ranks in distributed runs),
        nothing is printed or written and ``None`` is returned.
    """
    lm = HFLM(
        pretrained=model,
        tokenizer=tokenizer,
        batch_size=eval_batch_size,
        device_map="auto",
    )
    results = evaluator.simple_evaluate(
        model=lm,
        # simple_evaluate documents `tasks` as a list of task names.
        tasks=[task] if isinstance(task, str) else task,
        num_fewshot=num_fewshot,
        # NOTE(review): lm-eval appears to use this only when it builds the
        # model from a string; the HFLM above already carries the batch size.
        batch_size=eval_batch_size,
        random_seed=0,
        numpy_random_seed=1234,
        torch_random_seed=1234,
    )

    # Nothing to render when no results were produced on this process.
    if results is None:
        return results

    if log or output_path:
        # Render each table once and reuse it for stdout and the file.
        tables = [make_table(results)]
        if "groups" in results:
            tables.append(make_table(results, "groups"))

        if log:
            for table in tables:
                print(table)

        if output_path:
            # Append mode, so repeated calls accumulate in one report.
            # `with` closes the handle even if a write fails.
            with open(output_path, "a", encoding="utf-8") as f:
                for table in tables:
                    print(table, file=f)

    return results

def evaluation(model, tokenizer, task=None, result_path=None, num_fewshot=0):
    """Evaluate ``model`` on one lm-eval task or on a fixed default suite.

    Args:
        model: Hugging Face model to evaluate.
        tokenizer: Tokenizer matching ``model``.
        task: A single task name. Any non-string value (including the
            default ``None``) runs the built-in suite below instead.
        result_path: Optional file that result tables are appended to. Its
            parent directory is created if it does not exist.
        num_fewshot: Number of in-context examples used for every task.

    Returns:
        For a single task, that task's result from ``evaluate_fewshot``. For
        the suite, a dict mapping each task name to its result.
    """
    if result_path is not None:
        # dirname() is "" for a bare filename, meaning the current directory.
        # os.makedirs("") would raise, so create only a non-empty parent.
        # exist_ok=True avoids an exists()/makedirs() race.
        result_dir = os.path.dirname(result_path)
        if result_dir:
            os.makedirs(result_dir, exist_ok=True)

    if isinstance(task, str):
        return evaluate_fewshot(
            model,
            tokenizer=tokenizer,
            task=task,
            num_fewshot=num_fewshot,
            output_path=result_path,
            log=True,
        )

    # Default suite: (task name, evaluation batch size), run in this order.
    suite = (
        ("winogrande", 64),
        ("arc_challenge", 64),
        ("arc_easy", 64),
        ("boolq", 32),
        ("hellaswag", 64),
        ("mmlu", 16),
        ("openbookqa", 32),
        ("rte", 32),
    )
    all_results = {}
    for name, batch_size in suite:
        all_results[name] = evaluate_fewshot(
            model,
            tokenizer=tokenizer,
            task=name,
            num_fewshot=num_fewshot,
            eval_batch_size=batch_size,
            output_path=result_path,
            log=True,
        )
    return all_results