import gc

import pytest

from fishfarm.imports import try_import
from fishfarm.logging import get_logger
from fishfarm.tasks.csbench.data import load_dataset
from fishfarm.tasks.csbench.task import CSBenchTask


with try_import() as _imports:
    import torch
    import vllm

    from fishfarm.models.vllm_model import VLLMModel

# Module-level logger; the test below logs each model's aggregate metrics through it.
logger = get_logger(__name__)

# Reference per-format accuracies (fractions in [0, 1]) for each model, as
# published on the CS-Bench explorer: https://csbench.github.io/#explorer
# Each row is (model name, multiple-choice accuracy, assertion accuracy).
GROUNDTRUTH_ACCS = {
    name: {"multiple-choice": mc_acc, "assertion": assertion_acc}
    for name, mc_acc, assertion_acc in (
        ("mistralai/Mistral-7B-Instruct-v0.2", 0.5051, 0.6077),
        ("meta-llama/Meta-Llama-3-8B-Instruct", 0.5573, 0.6380),
        ("google/gemma-2b-it", 0.3874, 0.5164),
    )
}

"""
Absolute error tolerance per model, fixed here to catch regression bugs inside fishfarm.
The values below were obtained by running each model on CS-Bench and calculating the
mismatch with the ground-truth accuracies above.

Setting the tolerance a posteriori is not ideal, and a single fixed tolerance (e.g. 1%)
for all models would be better. However, slight differences such as
* max_num_seqs in the vLLM model parameters
* presence/absence of whitespace in prompts
can shift accuracies by up to 1-2%, so the values from the paper are hard to
reproduce exactly.
"""
# Per-model absolute tolerance on each format accuracy, as a fraction
# (0.005 == 0.5 percentage points).
ABS_TOLERANCE = {
    "mistralai/Mistral-7B-Instruct-v0.2": 0.005,  # 0.5%
    "google/gemma-2b-it": 0.018,  # 1.8%
    "meta-llama/Meta-Llama-3-8B-Instruct": 0.028,  # 2.8%
}

# Models evaluated by the accuracy test, in evaluation order. Derived from the
# tolerance table so every tested model is guaranteed to have a tolerance entry.
MODELS = list(ABS_TOLERANCE)


@pytest.mark.skipif(not _imports.is_successful(), reason="Failed to import torch and/or vllm")
@pytest.mark.use_gpu
def test_fmt_accuracies() -> None:
    """Compare per-format CS-Bench accuracies of several models with reference values.

    Each model in ``MODELS`` is loaded with vLLM and evaluated on the CS-Bench
    dataset. Its ``acc_format_multiple-choice`` and ``acc_format_assertion``
    aggregate metrics are compared against ``GROUNDTRUTH_ACCS`` within the
    model's ``ABS_TOLERANCE``. Mismatches are collected across all models and
    reported together at the end, so one failing model does not hide the
    results of the others.
    """
    csbench_data = load_dataset()
    task = CSBenchTask(csbench_data)

    # Human-readable descriptions of every accuracy outside its tolerance.
    failures = []
    for model_name in MODELS:
        model = VLLMModel(
            vllm.LLM(
                model_name,
                max_model_len=4096,
                gpu_memory_utilization=0.9,
            ),
            # From Appendix D.3 of CS-Bench paper
            vllm.SamplingParams(
                temperature=0,
                top_p=1,
                max_tokens=2048,
            ),
            chat_template=None,
        )
        try:
            agg_mets = task.evaluate(model).aggregate_metrics
            logger.info("%s: %s", model_name, agg_mets)
            expected_accs = GROUNDTRUTH_ACCS[model_name]
            tolerance = ABS_TOLERANCE[model_name]
            for fmt in ("multiple-choice", "assertion"):
                actual = agg_mets[f"acc_format_{fmt}"]
                expected = expected_accs[fmt]
                if actual != pytest.approx(expected, abs=tolerance):
                    failures.append(
                        f"{model_name} [{fmt}]: got {actual}, "
                        f"expected {expected} +/- {tolerance}"
                    )
        finally:
            # Release the engine even when evaluation raises. Otherwise the
            # failing frame (kept alive by pytest's traceback) would keep the
            # model referenced and pin GPU memory for subsequent tests.
            del model
            gc.collect()
            torch.cuda.empty_cache()

    assert not failures, "CS-Bench accuracy mismatches:\n" + "\n".join(failures)
