from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, Iterable, List, Sequence


class ZCFeatureStore:
    """Lightweight accessor for zero-cost NASBench201 features.

    The backing JSON file is read once at construction time. The code reads
    the following layout from it::

        {
          "zc_names": [<feature name>, ...],
          "datasets": {
            <dataset>: {
              <arch>: {"zc_scores": {<feature name>: <score>}, "val_accuracy": <value>}
            }
          }
        }

    ``val_accuracy`` is looked up with ``dict.get`` and may therefore be absent,
    in which case ``None`` is reported for it.
    """

    def __init__(self, data_path: str | Path | None = None) -> None:
        """Load the feature file.

        Args:
            data_path: Location of the JSON file. When omitted,
                ``data/nb201_zc_scores.json`` next to this module is used.

        Raises:
            FileNotFoundError: If the resolved path does not exist.
        """
        if data_path is not None:
            self.data_path = Path(data_path)
        else:
            base_dir = Path(__file__).resolve().parent
            self.data_path = base_dir / "data" / "nb201_zc_scores.json"
        if not self.data_path.exists():
            raise FileNotFoundError(f"Zero-cost feature file not found: {self.data_path}")

        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding.
        with self.data_path.open("r", encoding="utf-8") as handle:
            payload = json.load(handle)

        self.zc_names: List[str] = payload["zc_names"]
        self._datasets: Dict[str, Dict[str, Dict[str, Dict[str, float]]]] = payload["datasets"]

    def list_datasets(self) -> List[str]:
        """Return the names of all loaded datasets in sorted order."""
        return sorted(self._datasets)

    def has_dataset(self, dataset: str) -> bool:
        """Return whether ``dataset`` is present in the store."""
        return dataset in self._datasets

    def list_architectures(self, dataset: str) -> List[str]:
        """Return the architecture keys of ``dataset`` in file order.

        Raises:
            KeyError: If ``dataset`` is unknown.
        """
        return list(self._get_dataset(dataset))

    def get_scores(self, dataset: str, arch: str) -> Dict[str, float]:
        """Return a copy of the name -> score mapping for one architecture.

        Raises:
            KeyError: If ``dataset`` or ``arch`` is unknown.
        """
        record = self._get_arch_record(dataset, arch)
        return record["zc_scores"].copy()

    def get_feature_vector(self, dataset: str, arch: str, order: Sequence[str] | None = None) -> List[float]:
        """Return the scores of one architecture as a list.

        Args:
            dataset: Dataset name.
            arch: Architecture key within the dataset.
            order: Feature names selecting and ordering the entries of the
                result. Defaults to ``self.zc_names``.

        Raises:
            KeyError: If ``dataset`` or ``arch`` is unknown, or a requested
                feature is missing for the architecture.
        """
        names = list(order) if order is not None else self.zc_names
        record = self._get_arch_record(dataset, arch)
        return self._feature_row(record["zc_scores"], names, arch)

    def get_val_accuracy(self, dataset: str, arch: str) -> float | None:
        """Return the stored ``val_accuracy`` of an architecture, or ``None`` if absent.

        Raises:
            KeyError: If ``dataset`` or ``arch`` is unknown.
        """
        record = self._get_arch_record(dataset, arch)
        return record.get("val_accuracy")

    def export_dataset(self, dataset: str, order: Sequence[str] | None = None) -> Dict[str, list]:
        """Export a whole dataset as parallel lists.

        Args:
            dataset: Dataset name.
            order: Feature names selecting and ordering the columns of each
                feature row. Defaults to ``self.zc_names``.

        Returns:
            A dict with keys ``"architectures"`` (architecture keys),
            ``"features"`` (one list of scores per architecture),
            ``"accuracies"`` (``val_accuracy`` or ``None`` per architecture) and
            ``"feature_names"`` (the column names used). All lists are freshly
            created, so mutating them does not affect the store.

        Raises:
            KeyError: If ``dataset`` is unknown or a requested feature is
                missing for some architecture.
        """
        # Always copy: handing out self.zc_names itself would let callers
        # mutate the store's state through the returned dict.
        names = list(order) if order is not None else list(self.zc_names)
        data = self._get_dataset(dataset)

        arch_ids: List[str] = []
        features: List[List[float]] = []
        accuracies: list = []
        for arch, record in data.items():
            features.append(self._feature_row(record["zc_scores"], names, arch))
            accuracies.append(record.get("val_accuracy"))
            arch_ids.append(arch)
        return {"architectures": arch_ids, "features": features, "accuracies": accuracies, "feature_names": names}

    def iter_dataset(self, dataset: str) -> Iterable[tuple[str, Dict[str, float], float | None]]:
        """Yield ``(arch, scores, val_accuracy)`` for every architecture in ``dataset``.

        ``scores`` is a copy of the stored mapping (matching ``get_scores``), and
        ``val_accuracy`` is ``None`` when the record has no such entry.

        Raises:
            KeyError: If ``dataset`` is unknown (raised on first iteration,
                since this is a generator).
        """
        data = self._get_dataset(dataset)
        for arch, record in data.items():
            yield arch, record["zc_scores"].copy(), record.get("val_accuracy")

    @staticmethod
    def _feature_row(scores: Dict[str, float], names: Sequence[str], arch: str) -> List[float]:
        """Pick ``names`` out of ``scores``, raising a descriptive KeyError on a miss."""
        try:
            return [scores[name] for name in names]
        except KeyError as exc:
            missing = exc.args[0]
            raise KeyError(f"Feature '{missing}' not available for architecture {arch}") from exc

    def _get_dataset(self, dataset: str) -> Dict[str, Dict[str, Dict[str, float]]]:
        """Return the arch -> record mapping of ``dataset`` or raise KeyError."""
        if dataset not in self._datasets:
            available = ", ".join(self.list_datasets())
            raise KeyError(f"Dataset '{dataset}' not found. Available datasets: {available}")
        return self._datasets[dataset]

    def _get_arch_record(self, dataset: str, arch: str) -> Dict[str, Dict[str, float]]:
        """Return the record of ``arch`` within ``dataset`` or raise KeyError."""
        data = self._get_dataset(dataset)
        if arch not in data:
            raise KeyError(f"Architecture '{arch}' not found in dataset '{dataset}'")
        return data[arch]
