"""Lightweight compatibility API backed by the consolidated metrics CSV."""

from __future__ import annotations

import copy
import logging
import random
from dataclasses import asdict
from pathlib import Path
from typing import Dict, Iterable, List, Optional

from pel_nas.core.config import DATA_CONFIG
from pel_nas.data.metrics_repository import ArchitectureRecord, MetricsRepository

logger = logging.getLogger(__name__)


class SimplifiedNASBenchAPI:
    """Compatibility shim offering a subset of the original NAS-Bench API.

    All data comes from a single metrics CSV loaded through
    ``MetricsRepository``. Only accuracy and per-device latency are available;
    ``flops``, ``params`` and ``energy`` are always reported as ``None``.
    """

    def __init__(self, metrics_csv: Optional[str] = None, *_args, **_kwargs) -> None:
        """Load the metrics CSV and build a per-dataset index lookup.

        Args:
            metrics_csv: Path to the consolidated metrics CSV. When omitted or
                empty, the ``metrics_csv`` entry of ``DATA_CONFIG`` is used.
            *_args, **_kwargs: Accepted and ignored so call sites written for
                the original NAS-Bench API constructor keep working.

        Raises:
            RuntimeError: If no CSV path is given and none is configured.
        """
        csv_path = metrics_csv or DATA_CONFIG.get("metrics_csv")
        if not csv_path:
            raise RuntimeError("No metrics CSV configured for SimplifiedNASBenchAPI")

        # NOTE(review): `shared` presumably returns a cached repository per
        # path, so its records/lists may be shared between API instances —
        # hence the defensive copies handed out to callers below.
        self.metrics_repo = MetricsRepository.shared(csv_path)
        # dataset -> {architecture index -> record}, for direct index lookups.
        self._index_lookup: Dict[str, Dict[int, ArchitectureRecord]] = {
            dataset: {record.index: record for record in self.metrics_repo.list_records(dataset)}
            for dataset in self.metrics_repo.datasets
        }

    # ------------------------------------------------------------------
    # Basic accessors
    # ------------------------------------------------------------------
    def __len__(self) -> int:
        """Return the total number of indexed records across all datasets."""
        return sum(len(records) for records in self._index_lookup.values())

    @property
    def datasets(self) -> List[str]:
        """Dataset names known to the repository (a fresh list per call)."""
        # Copy so callers cannot mutate the repository's own list.
        return list(self.metrics_repo.datasets)

    @property
    def devices(self) -> List[str]:
        """Hardware device names known to the repository (a fresh list per call)."""
        return list(self.metrics_repo.devices)

    # ------------------------------------------------------------------
    # Query helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _record_to_info(record: ArchitectureRecord, dataset: str, hardware_device: str) -> Dict:
        """Convert a repository record into the legacy NAS-Bench info dict."""
        latency = record.latency_for(hardware_device)
        return {
            "arch_index": record.index,
            "arch_str": record.arch_str,
            # Compute costs and energy are not present in the CSV export.
            "flops": None,
            "params": None,
            "accuracy": record.test_accuracy,
            # NOTE: a missing latency is reported as 0.0 (not None) to keep the
            # legacy float-valued contract; callers optimising latency should
            # be aware that 0.0 can mean "unknown".
            "latency": latency if latency is not None else 0.0,
            "energy": None,
            "dataset": dataset,
            "hardware_device": hardware_device,
            # Shallow copy so callers cannot mutate the underlying record.
            "latencies": copy.copy(record.latencies),
        }

    def get_architecture_info(
        self,
        arch_index: int,
        dataset: str = "cifar10",
        hardware_device: str = "edgegpu",
    ) -> Optional[Dict]:
        """Return the info dict for ``arch_index`` in ``dataset``.

        Args:
            arch_index: Architecture index as stored in the metrics CSV.
            dataset: Dataset whose metrics to report.
            hardware_device: Device used for the scalar ``latency`` field.

        Returns:
            The info dictionary, or ``None`` if the index is not in ``dataset``.
        """
        record = self._index_lookup.get(dataset, {}).get(arch_index)
        if record is None:
            logger.debug("Index %s not found in dataset %s", arch_index, dataset)
            return None
        return self._record_to_info(record, dataset, hardware_device)

    def query_by_arch_string(
        self,
        arch_str: str,
        dataset: str = "cifar10",
        hardware_device: str = "edgegpu",
    ) -> Optional[Dict]:
        """Return the info dict for the architecture string ``arch_str``.

        Returns ``None`` when the repository has no record for it in ``dataset``.
        """
        record = self.metrics_repo.get_record(dataset, arch_str)
        if record is None:
            logger.debug("Architecture string not found in dataset %s", dataset)
            return None
        return self._record_to_info(record, dataset, hardware_device)

    def sample_random_architectures(
        self,
        n_samples: int,
        dataset: str = "cifar10",
        hardware_device: str = "edgegpu",
        *,
        rng: Optional[random.Random] = None,
    ) -> List[Dict]:
        """Sample up to ``n_samples`` distinct architectures from ``dataset``.

        Args:
            n_samples: Requested sample size; clamped to the number of
                available records. Zero or negative values yield ``[]``.
            dataset: Dataset to sample from.
            hardware_device: Device used for each entry's ``latency`` field.
            rng: Optional ``random.Random`` for reproducible sampling; the
                module-level ``random`` generator is used when omitted.

        Returns:
            A list of info dicts (empty if the dataset has no records).
        """
        records = self.metrics_repo.list_records(dataset)
        if not records or n_samples <= 0:
            # random.sample raises ValueError for negative sizes.
            return []
        sampler = rng if rng is not None else random
        chosen = sampler.sample(records, min(n_samples, len(records)))
        return [self._record_to_info(record, dataset, hardware_device) for record in chosen]

    def get_compute_costs(self, arch_index: int, dataset: str = "cifar10") -> Optional[Dict]:
        """Return compute costs for an architecture; always ``None`` here.

        Params/FLOPs are not part of the CSV export, so this legacy method
        exists only for interface compatibility.
        """
        logger.debug("Compute costs unavailable for CSV-backed repository")
        return None


# Backwards-compatible alias: callers that still reference ``NASBenchAPI``
# get the CSV-backed ``SimplifiedNASBenchAPI`` shim.
NASBenchAPI = SimplifiedNASBenchAPI


def create_nas_api(nas_bench_201_path: str | None = None, hw_nas_bench_path: str | None = None) -> SimplifiedNASBenchAPI:
    """Build a ``SimplifiedNASBenchAPI`` from the configured metrics CSV.

    Factory retained for source compatibility with the original API.

    Args:
        nas_bench_201_path: Ignored; accepted only so legacy call sites work.
        hw_nas_bench_path: Ignored; accepted only so legacy call sites work.

    Returns:
        An API instance backed by the ``metrics_csv`` entry of ``DATA_CONFIG``.

    Raises:
        RuntimeError: Propagated from the constructor when no metrics CSV is
            configured.
    """
    if nas_bench_201_path is not None or hw_nas_bench_path is not None:
        # Surface the dropped arguments instead of ignoring them silently, so a
        # caller expecting these benchmark files to be loaded notices.
        logger.warning(
            "create_nas_api ignores benchmark paths (nas_bench_201_path=%r, hw_nas_bench_path=%r); "
            "using the metrics CSV from DATA_CONFIG instead",
            nas_bench_201_path,
            hw_nas_bench_path,
        )
    return SimplifiedNASBenchAPI(metrics_csv=DATA_CONFIG.get("metrics_csv"))

