from __future__ import annotations

import csv
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

from pel_nas.data.arch_converter import arch_str_to_tuple

# Module-level logger; it is named after this module via ``__name__``.
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class ArchitectureRecord:
    """Lightweight container for architecture metrics sourced from CSV.

    Instances are frozen (attributes cannot be rebound) and hashable, so they
    can be stored in sets or used as dictionary keys. Equality still compares
    every field, including ``latencies``.
    """

    index: int
    arch_str: str
    arch_tuple: Tuple[int, ...]
    dataset: str
    # May be ``None``: the repository loader stores ``None`` when the source
    # cell is empty, NaN or not parseable as a number.
    test_accuracy: Optional[float]
    # Maps a device name to its latency, or to ``None`` when unavailable.
    latencies: Dict[str, Optional[float]]

    def __hash__(self) -> int:
        """Return a hash built from the hashable fields only.

        The hash that ``dataclass`` would generate covers every field and
        therefore raises ``TypeError`` on the ``latencies`` dict. Leaving that
        field out keeps the hash consistent with ``__eq__``: records that
        compare equal agree on all remaining fields and so hash equally.
        """
        return hash(
            (
                self.index,
                self.arch_str,
                self.arch_tuple,
                self.dataset,
                self.test_accuracy,
            )
        )

    def latency_for(self, device: str) -> Optional[float]:
        """Return the latency stored for ``device``.

        Returns ``None`` both when the device is unknown and when its stored
        latency is ``None``.
        """
        return self.latencies.get(device)


class MetricsRepository:
    """In-memory repository that serves NAS metrics from a consolidated CSV file.

    The CSV is read once, at construction time. Rows are grouped by their
    ``dataset`` column and indexed by the tuple that ``arch_str_to_tuple``
    derives from their ``arch_str`` column. Every header ending in
    ``_latency`` is treated as a per-device latency column; the device name is
    the header with that suffix removed.

    Dataset names are matched case-insensitively. The first spelling seen in
    the CSV becomes the canonical name reported by ``datasets``, and rows
    whose dataset differs only in letter case are grouped under it.
    """

    # Suffix that marks a CSV column as a per-device latency measurement.
    _LATENCY_SUFFIX = "_latency"

    # Repositories created through ``shared``, keyed by resolved CSV path.
    _instances: Dict[Path, "MetricsRepository"] = {}

    def __init__(self, csv_path: str | Path) -> None:
        """Load all metrics from ``csv_path``.

        Raises:
            FileNotFoundError: if ``csv_path`` does not exist.
            ValueError: if the CSV has no header row.
        """
        self.csv_path = Path(csv_path)
        if not self.csv_path.exists():
            raise FileNotFoundError(f"Metrics CSV not found: {self.csv_path}")

        # Canonical dataset name -> architecture tuple -> record.
        self._records_by_dataset: Dict[str, Dict[Tuple[int, ...], ArchitectureRecord]] = {}
        # Canonical dataset name -> records in CSV row order.
        self._records_list: Dict[str, List[ArchitectureRecord]] = {}
        # Lower-cased dataset name -> canonical dataset name.
        self._dataset_aliases: Dict[str, str] = {}
        # Device names, in header order.
        self._devices: List[str] = []

        self._load_csv()

    @classmethod
    def shared(cls, csv_path: str | Path) -> "MetricsRepository":
        """Return the cached repository for ``csv_path``, creating it on first use.

        The cache key is the resolved path, so different spellings of the same
        location share one instance.
        """
        path = Path(csv_path).resolve()
        if path not in cls._instances:
            cls._instances[path] = cls(path)
        return cls._instances[path]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    @property
    def datasets(self) -> List[str]:
        """Sorted list of the canonical dataset names that have records."""
        return sorted(self._records_by_dataset)

    @property
    def devices(self) -> List[str]:
        """Copy of the device names derived from the latency column headers."""
        return list(self._devices)

    def iter_dataset(self, dataset: str) -> Iterable[ArchitectureRecord]:
        """Return a fresh list of the records of ``dataset`` in CSV row order.

        An unknown dataset yields an empty list.
        """
        ds = self._normalize_dataset(dataset)
        return list(self._records_list.get(ds, []))

    def get_record(self, dataset: str, arch_str: str) -> Optional[ArchitectureRecord]:
        """Look up a record by architecture string.

        Returns ``None`` when ``arch_str_to_tuple`` returns ``None`` for
        ``arch_str`` or when no matching record exists.
        """
        arch_tuple = arch_str_to_tuple(arch_str)
        if arch_tuple is None:
            return None
        return self.get_record_by_tuple(dataset, arch_tuple)

    def get_record_by_tuple(self, dataset: str, arch_tuple: Tuple[int, ...]) -> Optional[ArchitectureRecord]:
        """Look up a record by architecture tuple.

        ``arch_tuple`` may be any iterable; it is converted with ``tuple``
        before the lookup. Returns ``None`` when the dataset or the
        architecture is unknown.
        """
        ds = self._normalize_dataset(dataset)
        dataset_records = self._records_by_dataset.get(ds)
        if not dataset_records:
            return None
        return dataset_records.get(tuple(arch_tuple))

    def list_records(self, dataset: str) -> List[ArchitectureRecord]:
        """Return a new list of the records of ``dataset`` in CSV row order."""
        ds = self._normalize_dataset(dataset)
        return list(self._records_list.get(ds, []))

    def _normalize_dataset(self, dataset: str) -> str:
        """Map ``dataset`` to its canonical spelling, ignoring letter case.

        Falsy input and names without a known alias are returned unchanged.
        """
        if not dataset:
            return dataset
        key = dataset.lower()
        return self._dataset_aliases.get(key, dataset)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _load_csv(self) -> None:
        """Read ``self.csv_path`` and populate the in-memory indexes.

        Rows without a ``dataset`` or ``arch_str`` value, and rows whose
        architecture string ``arch_str_to_tuple`` maps to ``None``, are
        skipped. When a dataset contains the same architecture tuple more than
        once, the tuple index keeps the last row while the per-dataset list
        keeps every row.

        Raises:
            ValueError: if the CSV has no header row.
        """
        logger.info("Loading architecture metrics from %s", self.csv_path)

        suffix = self._LATENCY_SUFFIX
        skipped = 0

        # "utf-8-sig" decodes plain UTF-8 and also drops a leading byte-order
        # mark, which would otherwise be glued to the first header name. An
        # explicit encoding also makes loading independent of the locale.
        with self.csv_path.open("r", newline="", encoding="utf-8-sig") as handle:
            reader = csv.DictReader(handle)
            if reader.fieldnames is None:
                raise ValueError(f"CSV file {self.csv_path} has no header row")

            # (column header, device name) pairs, computed once for all rows.
            latency_columns = [
                (col, col[: -len(suffix)])
                for col in reader.fieldnames
                if col.endswith(suffix)
            ]
            self._devices = [device for _, device in latency_columns]

            for row in reader:
                dataset = row.get("dataset")
                arch_str = row.get("arch_str")
                if not dataset or not arch_str:
                    skipped += 1
                    continue

                arch_tuple = arch_str_to_tuple(arch_str)
                if arch_tuple is None:
                    logger.debug("Skipping unparsable architecture string: %s", arch_str)
                    skipped += 1
                    continue

                latencies: Dict[str, Optional[float]] = {
                    device: self._safe_float(row.get(col))
                    for col, device in latency_columns
                }

                record = ArchitectureRecord(
                    index=self._safe_int(row.get("index"), default=-1),
                    arch_str=arch_str,
                    arch_tuple=tuple(arch_tuple),
                    dataset=dataset,
                    test_accuracy=self._safe_float(row.get("test_accuracy")),
                    latencies=latencies,
                )

                # Group under the first spelling seen for this dataset name.
                # Lookups resolve names through the same alias table, so using
                # the raw spelling here would make rows that differ only in
                # letter case unreachable.
                ds_key = self._dataset_aliases.setdefault(dataset.lower(), dataset)
                self._records_by_dataset.setdefault(ds_key, {})[record.arch_tuple] = record
                self._records_list.setdefault(ds_key, []).append(record)

        if skipped:
            logger.warning(
                "Skipped %d rows with missing or unparsable dataset/arch_str in %s",
                skipped,
                self.csv_path,
            )

        logger.info(
            "Loaded %d datasets with %d total architectures",
            len(self._records_by_dataset),
            sum(len(v) for v in self._records_by_dataset.values()),
        )

    @staticmethod
    def _safe_float(value, default: Optional[float] = None) -> Optional[float]:
        """Convert ``value`` to ``float``, falling back to ``default``.

        ``default`` is returned for ``None``, the empty string, NaN and
        anything ``float`` cannot parse.
        """
        try:
            if value is None or value == "":
                return default
            result = float(value)
            if result != result:  # NaN is the only float unequal to itself
                return default
            return result
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _safe_int(value, default: int = -1) -> int:
        """Convert ``value`` to ``int``, falling back to ``default``.

        Integer-looking input is converted exactly. Other numeric input such
        as ``"7.9"`` or ``"1e3"`` goes through ``float`` and is truncated
        toward zero. ``default`` is returned for ``None``, unparsable text,
        NaN and infinities.
        """
        try:
            # Exact path first: a float round-trip silently loses precision
            # for integers beyond 2**53.
            return int(value)
        except (TypeError, ValueError, OverflowError):
            pass
        try:
            return int(float(value))
        except (TypeError, ValueError, OverflowError):
            # OverflowError: int() of an infinite float.
            return default


# Names exported by a star-import of this module.
__all__ = ["ArchitectureRecord", "MetricsRepository"]

