import dataclasses
from typing import List

import pytest
from pytest_mock import MockerFixture

from fishfarm.models.base import GenerationRequest, GenerationResult, Model
from fishfarm.tasks.csbench.data import CSBenchSample, load_dataset
from fishfarm.tasks.csbench.task import CSBenchTask


@pytest.fixture
def csbench_samples() -> List[CSBenchSample]:
    """Build four minimal CSBench samples: two multiple-choice, then two assertion.

    Every sample is derived from one template via ``dataclasses.replace``.
    Only ``index``, ``format``, ``answer`` and (for the assertion items)
    ``choice_desc`` vary.
    """
    template = CSBenchSample(
        index=0,
        question="",
        answer="A",
        domain="Computer Network",
        sub_domain="Application Layer",
        format="Multiple-choice",
        tag="Knowledge",
        choice_desc=dict(A="", B="", C="", D=""),
    )
    # Per-sample field overrides, applied on top of the template in this order.
    overrides = [
        dict(index=0, format="Multiple-choice", answer="A"),
        dict(index=1, format="Multiple-choice", answer="B"),
        # Assertion items carry no answer options, hence choice_desc=None.
        # NOTE(review): index 2 uses the bool False while index 3 uses the string
        # "True"; confirm whether this mix of answer types is intentional.
        dict(index=2, format="Assertion", answer=False, choice_desc=None),
        dict(index=3, format="Assertion", answer="True", choice_desc=None),
    ]
    return [dataclasses.replace(template, **fields) for fields in overrides]


@pytest.fixture
def llm_results() -> List[GenerationResult]:
    """Canned model outputs, in the same order as the ``csbench_samples`` fixture.

    All results share a single empty ``GenerationRequest``; only the
    generated text differs.
    """
    shared_request = GenerationRequest(messages=[])
    generations = [
        "The answer is A.",  # LLM answer is A
        "The answer is not B, but A.",  # LLM answer is currently parsed as B
        "false",  # LLM answer is False
        "That is not true, the answer is False.",  # LLM answer is currently parsed as True
    ]
    return [GenerationResult(generation=text, request=shared_request) for text in generations]


def test_fmt_accuracies(
    mocker: MockerFixture,
    llm_results: List[GenerationResult],
    csbench_samples: List[CSBenchSample],
) -> None:
    """Per-format accuracy is 1 when the canned outputs are scored against the fixture samples.

    The model is an autospec'd ``Model`` whose ``generate`` returns ``llm_results``.
    The test also checks that ``evaluate`` makes exactly one ``generate`` call,
    passing the list built from each sample's ``to_request()``.
    """
    mock_model = mocker.create_autospec(Model)
    mock_model.generate.return_value = llm_results

    metrics = CSBenchTask(csbench_samples).evaluate(mock_model).aggregate_metrics

    expected_requests = [sample.to_request() for sample in csbench_samples]
    mock_model.generate.assert_called_once_with(expected_requests)

    for fmt in ("multiple-choice", "assertion"):
        assert metrics[f"acc_format_{fmt}"] == pytest.approx(1)


def test_load_dataset() -> None:
    """``load_dataset`` returns 2183 samples with the expected per-format breakdown."""
    samples = load_dataset()
    assert len(samples) == 2183

    # The four per-format counts below sum to the 2183 total asserted above.
    expected_counts = {
        "Assertion": 442,
        "Multiple-choice": 1336,
        "Fill-in-the-blank": 235,
        "Open-ended": 170,
    }
    for fmt, expected in expected_counts.items():
        actual = sum(1 for sample in samples if sample.format == fmt)
        assert actual == expected
