import json
from pathlib import Path
from typing import Any

from datasets import Dataset


def load_umwp_dataset(path: Path, answerable_only: bool = False) -> Dataset:
    """Load the UMWP (Unanswerable Math Word Problem) benchmark as a ``Dataset``.

    The UMWP dataset contains 5200 math word problems (2600 answerable, 2600
    unanswerable) designed to evaluate LLM hallucination in Question Answering
    tasks.

    Categories:
    - 0: Answerable questions with complete information
    - 1-5: Unanswerable questions with different types of missing/ambiguous information

    Reference: https://github.com/Yuki-Asuuna/UMWP
    Paper: https://arxiv.org/abs/2403.03558

    Args:
        path: Path to the StandardDataset.jsonl file.
        answerable_only: If True, keep only records whose ``answerable`` field is
            the boolean ``True``. If False (default), keep every record.

    Returns:
        Dataset with columns:
        - id: Sequential index (0, 1, 2, ...) within the returned dataset
        - question_id: Original ID from the UMWP dataset
        - question: The math word problem text
        - answer: The first answer value as a string, or "unanswerable" when the
          raw answer is null or an empty list
        - answerable: Boolean indicating if the question is answerable
        - category: Category (0=answerable, 1-5=unanswerable types)
        - source: Original dataset source (e.g., SVAMP, MAWPS)
        Other fields present in the raw file are passed through unchanged.
    """
    raw_rows = _load_jsonl(path)
    if answerable_only:
        # Strict identity check: only a JSON ``true`` counts as answerable.
        raw_rows = [row for row in raw_rows if row["answerable"] is True]

    normalized_rows = [
        _renumber_umwp_record(row, position)
        for position, row in enumerate(raw_rows)
    ]
    return Dataset.from_list(normalized_rows)


def _renumber_umwp_record(record: dict[str, Any], position: int) -> dict[str, Any]:
    """Rewrite one raw UMWP record in place and return it.

    The original ``id`` is moved to ``question_id``. ``id`` is replaced by
    *position*, and ``answer`` is normalized to a string via ``_extract_answer``.
    """
    record["question_id"] = record.pop("id")
    record["id"] = position
    record["answer"] = _extract_answer(record["answer"])
    return record


def _load_jsonl(path: Path) -> list[dict[str, Any]]:
    records = []
    with path.open() as f:
        for line in f:
            records.append(json.loads(line))
    return records


def _extract_answer(answer: list[float] | float | None) -> str:
    if answer is None:
        return "unanswerable"
    elif isinstance(answer, list):
        return str(answer[0]) if answer else "unanswerable"
    return str(answer)
