from __future__ import annotations

from dataclasses import dataclass
from typing import Any, List, Optional

from utils.config import load_yaml_config


@dataclass
class DatasetSpec:
    """Configuration for a single training dataset.

    ``name``, ``hf_path`` and ``task`` are required; every other field is
    optional and falls back to the default declared on the class.
    """

    # Required fields (no defaults).
    name: str
    hf_path: str
    task: str
    # Optional fields.
    subset: Optional[str] = None
    split: str = "train"
    source: str = "hf_hub"
    max_samples: Optional[int] = None
    weight: float = 1.0

    @classmethod
    def from_dict(cls, raw: dict[str, Any]) -> "DatasetSpec":
        """Build a specification from a raw configuration mapping.

        Args:
            raw: Mapping that must contain the keys ``"name"``,
                ``"hf_path"`` and ``"task"``. The keys ``"subset"``,
                ``"split"``, ``"source"``, ``"max_samples"`` and
                ``"weight"`` are read when present; any other keys are
                ignored.

        Returns:
            A new instance of the class this method is called on, so
            subclasses get instances of themselves.

        Raises:
            ValueError: If any of the required keys is absent from ``raw``.
        """
        required = {"name", "hf_path", "task"}
        # Sort so the error text is deterministic regardless of set ordering.
        missing = sorted(required.difference(raw))
        if missing:
            raise ValueError(f"Dataset specification missing required keys: {missing}")

        return cls(
            name=raw["name"],
            hf_path=raw["hf_path"],
            task=raw["task"],
            subset=raw.get("subset"),
            split=raw.get("split", "train"),
            source=raw.get("source", "hf_hub"),
            max_samples=raw.get("max_samples"),
            # Coerce so values such as ints or numeric strings become floats.
            weight=float(raw.get("weight", 1.0)),
        )


def load_dataset_specs(config: dict[str, Any]) -> List[DatasetSpec]:
    """Convert the ``"datasets"`` entry of a configuration into specs.

    Args:
        config: Mapping whose ``"datasets"`` value is an iterable of raw
            per-dataset dictionaries.

    Returns:
        One ``DatasetSpec`` per entry, in the order the entries appear.

    Raises:
        ValueError: If ``"datasets"`` is absent or ``None``, or if
            ``DatasetSpec.from_dict`` rejects an entry.
    """
    raw_entries = config.get("datasets")
    if raw_entries is not None:
        return [DatasetSpec.from_dict(entry) for entry in raw_entries]
    raise ValueError("data.datasets must be provided in the training configuration.")


def load_training_config(path: str) -> dict[str, Any]:
    """Load the training configuration found at ``path``.

    Thin wrapper that delegates to ``load_yaml_config``; parsing and error
    behavior are whatever that helper provides.
    """
    training_config = load_yaml_config(path)
    return training_config

