import os

import pytest

import lm_eval.api as api
import lm_eval.evaluator as evaluator
from lm_eval import tasks


@pytest.mark.parametrize(
    "limit,model,model_args",
    [
        (
            10,
            "hf",
            "pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu",
        ),
    ],
)
def test_include_correctness(limit: int, model: str, model_args: str):
    task_name = ["arc_easy"]

    task_manager = tasks.TaskManager()
    task_dict = tasks.get_task_dict(task_name, task_manager)

    e1 = evaluator.simple_evaluate(
        model=model,
        tasks=task_name,
        limit=limit,
        model_args=model_args,
    )
    assert e1 is not None

    # run with evaluate() and "arc_easy" test config (included from ./testconfigs path)
    lm = api.registry.get_model(model).create_from_arg_string(
        model_args,
        {
            "batch_size": None,
            "max_batch_size": None,
            "device": None,
        },
    )

    task_name = ["arc_easy"]

    task_manager = tasks.TaskManager(
        include_path=os.path.dirname(os.path.abspath(__file__)) + "/testconfigs",
        include_defaults=False,
    )
    task_dict = tasks.get_task_dict(task_name, task_manager)

    e2 = evaluator.evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
    )

    assert e2 is not None
    # check that caching is working

    def r(x):
        return x["results"]["arc_easy"]

    assert all(
        x == y
        for x, y in zip([y for _, y in r(e1).items()], [y for _, y in r(e2).items()])
    )


# test that setting include_defaults = False works as expected and that include_path works
def test_no_include_defaults():
    task_name = ["arc_easy"]

    task_manager = tasks.TaskManager(
        include_path=os.path.dirname(os.path.abspath(__file__)) + "/testconfigs",
        include_defaults=False,
    )
    # should succeed, because we've included an 'arc_easy' task from this dir
    task_dict = tasks.get_task_dict(task_name, task_manager)

    # should fail, since ./testconfigs has no arc_challenge task
    task_name = ["arc_challenge"]
    with pytest.raises(KeyError):
        task_dict = tasks.get_task_dict(task_name, task_manager)  # noqa: F841


# test that include_path containing a task shadowing another task's name fails
# def test_shadowed_name_fails():

#     task_name = ["arc_easy"]

#     task_manager = tasks.TaskManager(include_path=os.path.dirname(os.path.abspath(__file__)) + "/testconfigs")
#     task_dict = tasks.get_task_dict(task_name, task_manager)
