from __future__ import annotations

import logging
import math
import random
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

from pel_nas.core.config import DATA_CONFIG, EVAL_CONFIG, PREDICTOR_CONFIG, SEARCH_CONFIG
from pel_nas.data.arch_converter import arch_str_to_tuple
from pel_nas.data.metrics_repository import ArchitectureRecord, MetricsRepository
from pel_nas.data.nas_bench_api import SimplifiedNASBenchAPI
from pel_nas.search.architecture_analyzer import ArchitectureAnalyzer

try:  # Optional dependency; gracefully degraded when unavailable.
    from zc_predictor import ZCPredictor  # type: ignore
except ImportError:  # pragma: no cover - runtime guard
    ZCPredictor = None  # type: ignore


# Module-level logger keyed on this module's import name, so it can be
# configured or silenced through the standard ``logging`` hierarchy.
logger: logging.Logger = logging.getLogger(name=__name__)


@dataclass
class ArchitectureInfo:
    """Container encapsulating evaluation results for a single architecture."""

    # Row index of the architecture in the metrics data (-1 when not found there).
    arch_index: int
    arch_str: str
    # Accuracy actually used downstream (predicted, measured, or fallback).
    accuracy: float
    # Latency on ``hardware_device``; 0.0 when no usable measurement exists.
    latency: float
    conv_category: str
    conv_3x3_count: int
    conv_1x1_count: int
    dataset: str
    hardware_device: str
    # Where ``accuracy`` came from, e.g. "metrics", "predictor" or "missing".
    source: str = "metrics"
    is_valid: bool = True
    measured_accuracy: Optional[float] = None
    predicted_accuracy: Optional[float] = None
    # Latency per device name.
    latencies: Dict[str, Optional[float]] = field(default_factory=dict)
    arch_tuple: Optional[Tuple[int, ...]] = None
    params: Optional[float] = None
    flops: Optional[float] = None
    energy: Optional[float] = None
    metadata: Dict[str, float] = field(default_factory=dict)

    def to_dict(self) -> Dict:
        """Return a plain-dict snapshot of this record.

        ``arch_tuple`` is converted to a list (or ``None``). ``latencies`` and
        ``metadata`` are shallow-copied, so mutating the returned dict never
        changes this instance.
        """
        return {
            "arch_index": self.arch_index,
            "arch_str": self.arch_str,
            "accuracy": self.accuracy,
            "latency": self.latency,
            "conv_category": self.conv_category,
            "conv_3x3_count": self.conv_3x3_count,
            "conv_1x1_count": self.conv_1x1_count,
            "dataset": self.dataset,
            "hardware_device": self.hardware_device,
            "source": self.source,
            "is_valid": self.is_valid,
            "measured_accuracy": self.measured_accuracy,
            "predicted_accuracy": self.predicted_accuracy,
            # Copy: previously the live dict leaked out and could be mutated.
            "latencies": dict(self.latencies),
            "arch_tuple": list(self.arch_tuple) if self.arch_tuple is not None else None,
            "params": self.params,
            "flops": self.flops,
            "energy": self.energy,
            "metadata": dict(self.metadata),
        }


class ArchitectureEvaluator:
    """Evaluate NAS architectures using the consolidated CSV and optional predictor.

    Measured metrics come from the shared ``MetricsRepository`` built from
    ``DATA_CONFIG['metrics_csv']``. A zero-cost predictor may be used if it is
    requested, enabled in ``PREDICTOR_CONFIG`` and successfully loaded. In that
    case its prediction is preferred over the measured accuracy. Otherwise the
    measured accuracy is reported, or ``EVAL_CONFIG['fallback_accuracy']`` when
    the measurement is missing.
    """

    def __init__(
        self,
        dataset: str = "cifar10",
        hardware_device: str = "edgegpu",
        use_predictor: bool = True,
    ) -> None:
        """Load metrics, optionally the predictor, and bucket architectures by category.

        Args:
            dataset: Dataset key used for metric lookups and predictor model selection.
            hardware_device: Device whose latency is read from each metrics record.
            use_predictor: Request the zero-cost predictor. It is only used when
                ``PREDICTOR_CONFIG['enabled']`` (default ``True``) is truthy and
                the predictor loads successfully.

        Raises:
            RuntimeError: If ``DATA_CONFIG['metrics_csv']`` is not configured.
        """
        self.dataset = dataset
        self.hardware_device = hardware_device
        self.request_predictor = use_predictor and bool(PREDICTOR_CONFIG.get("enabled", True))

        csv_path = DATA_CONFIG.get("metrics_csv")
        if not csv_path:
            raise RuntimeError("DATA_CONFIG['metrics_csv'] is not configured")

        self.metrics_repo = MetricsRepository.shared(csv_path)
        self.api = SimplifiedNASBenchAPI(metrics_csv=csv_path)
        self.analyzer = ArchitectureAnalyzer()

        if hardware_device not in self.metrics_repo.devices:
            # Only a warning: construction continues. Latency lookups for this
            # device presumably yield nothing usable (TODO confirm).
            logger.warning(
                "Device %s not present in metrics CSV. Available devices: %s",
                hardware_device,
                ", ".join(self.metrics_repo.devices),
            )

        self.predictor = self._initialize_predictor() if self.request_predictor else None
        # Predictor mode is active only when it was requested AND actually loaded.
        self.use_predictor = self.predictor is not None and self.request_predictor

        # ``or {}`` also tolerates an explicit ``None`` in the config.
        self.conv_categories = SEARCH_CONFIG.get("conv_categories") or {}
        self.conv_category_archs: Dict[str, List[Dict]] = {}
        self._categorize_architectures()

        logger.info(
            "Evaluator ready for dataset=%s, device=%s, predictor=%s",
            self.dataset,
            self.hardware_device,
            "enabled" if self.use_predictor else "disabled",
        )

    # ------------------------------------------------------------------
    # Initialization helpers
    # ------------------------------------------------------------------
    def _initialize_predictor(self) -> Optional[ZCPredictor]:  # type: ignore[override]
        """Load the zero-cost predictor, or return ``None`` if it is unavailable.

        Every failure mode is logged and returns ``None`` rather than raising,
        so the evaluator falls back to measured metrics. The failure modes are:
        the package is missing, the feature path is not configured, no model
        file is found, or the constructor raises.
        """
        if ZCPredictor is None:
            logger.warning("Zero-cost predictor package not available; disabling predictor mode.")
            return None

        feature_path = DATA_CONFIG.get("zero_cost_features")
        if not feature_path:
            logger.warning("Zero-cost feature path not configured; disabling predictor mode.")
            return None

        model_path = self._resolve_model_path()
        if model_path is None:
            logger.warning(
                "No predictor model path configured for %s; disabling predictor mode.",
                self.dataset,
            )
            return None

        feature_order = PREDICTOR_CONFIG.get("feature_order")

        try:
            predictor = ZCPredictor(
                data_path=feature_path,
                model_path=model_path,
                feature_order=feature_order,
            )
            logger.info("Loaded zero-cost predictor from %s", model_path)
            return predictor
        except Exception as exc:  # pragma: no cover - defensive logging
            logger.error("Failed to initialize predictor: %s", exc)
            return None

    def _resolve_model_path(self) -> Optional[Path]:
        """Return the first existing predictor model file for ``self.dataset``.

        The path in ``PREDICTOR_CONFIG['model_paths'][dataset]`` is tried first.
        Next come ``zc_predictor/models/<dataset>`` with the extensions
        ``.json``, ``.bin``, ``.xgb`` and ``.model``, resolved relative to the
        current working directory. Returns ``None`` if no candidate exists.
        """
        # ``or {}`` tolerates ``model_paths`` being absent *or* explicitly ``None``.
        configured = (PREDICTOR_CONFIG.get("model_paths") or {}).get(self.dataset)
        candidates: List[Path] = []

        if configured:
            candidates.append(Path(configured))

        default_dir = Path("zc_predictor") / "models"
        if default_dir.exists():
            candidates.extend(
                default_dir / f"{self.dataset}{ext}" for ext in (".json", ".bin", ".xgb", ".model")
            )

        return next((candidate for candidate in candidates if candidate.exists()), None)

    def _categorize_architectures(self) -> None:
        """Bucket every dataset record by the category reported by the analyzer.

        Each bucket holds ``{"record": ..., "analysis": ...}`` entries.
        Configured categories always get a bucket, even if it stays empty. So
        does ``"unknown"``, which also collects records with no category.
        """
        self.conv_category_archs = {category: [] for category in self.conv_categories}
        self.conv_category_archs.setdefault("unknown", [])

        for record in self.metrics_repo.iter_dataset(self.dataset):
            analysis = self.analyzer.analyze_architecture(record.arch_str)
            category = analysis.get("category") or "unknown"
            # Categories absent from the config get a bucket on demand.
            self.conv_category_archs.setdefault(category, []).append(
                {"record": record, "analysis": analysis}
            )

        for category, entries in self.conv_category_archs.items():
            logger.debug("Category %s populated with %d candidate architectures", category, len(entries))

    # ------------------------------------------------------------------
    # Public evaluation helpers
    # ------------------------------------------------------------------
    def select_initial_architectures(self) -> List[ArchitectureInfo]:
        """Randomly pick up to ``target_count`` architectures per configured category.

        A category with fewer candidates than requested contributes all of them.
        Empty categories are skipped with a warning. Sampling uses the global
        ``random`` state, so seed it for reproducible selections.
        """
        selected: List[ArchitectureInfo] = []

        for category, config in self.conv_categories.items():
            entries = list(self.conv_category_archs.get(category, []))
            target_count = config.get("target_count", 0)

            if not entries:
                logger.warning("Category %s has no available architectures for initialization", category)
                continue

            chosen = random.sample(entries, target_count) if len(entries) > target_count else entries

            logger.info(
                "Selected %d/%d initial architectures for category %s",
                len(chosen),
                target_count,
                category,
            )

            for entry in chosen:
                info = self._build_architecture_info(entry["record"], entry["analysis"])
                if info is not None:
                    selected.append(info)

        return selected

    def evaluate_batch_by_category(
        self, category_architectures: Dict[str, Iterable[str]]
    ) -> Dict[str, List[ArchitectureInfo]]:
        """Evaluate architecture strings grouped by category.

        Args:
            category_architectures: Mapping of category name to architecture strings.

        Returns:
            Mapping of the same category names to their evaluated architectures.
        """
        results: Dict[str, List[ArchitectureInfo]] = {}
        total_valid = 0
        total_evaluated = 0

        for category, arch_strings in category_architectures.items():
            evaluated: List[ArchitectureInfo] = []
            for arch_str in arch_strings:
                info = self.evaluate_architecture_string(arch_str)
                if info is not None:
                    evaluated.append(info)

            # Count once; reused for both the per-category log and the totals.
            valid_count = sum(1 for item in evaluated if item.is_valid)
            total_evaluated += len(evaluated)
            total_valid += valid_count
            results[category] = evaluated
            logger.info(
                "Category %s → %d/%d valid architectures",
                category,
                valid_count,
                len(evaluated),
            )

        logger.info("Batch evaluation summary: %d/%d valid architectures", total_valid, total_evaluated)
        return results

    def evaluate_batch(self, arch_strings: Iterable[str]) -> List[ArchitectureInfo]:
        """Evaluate each architecture string and return results in input order."""
        evaluated: List[ArchitectureInfo] = []
        for arch_str in arch_strings:
            info = self.evaluate_architecture_string(arch_str)
            if info is not None:
                evaluated.append(info)
        logger.info(
            "Batch evaluation processed %d architectures (%d valid)",
            len(evaluated),
            sum(1 for arch in evaluated if arch.is_valid),
        )
        return evaluated

    def evaluate_architecture_string(self, arch_str: str) -> Optional[ArchitectureInfo]:
        """Evaluate one architecture string.

        A string found in the metrics data for ``self.dataset`` is built from
        its record. Otherwise a placeholder is returned with ``is_valid=False``
        and, when possible, a predicted accuracy.
        """
        analysis = self.analyzer.analyze_architecture(arch_str)
        record = self.metrics_repo.get_record(self.dataset, arch_str)

        if record:
            return self._build_architecture_info(record, analysis)

        logger.warning("Architecture not found in metrics CSV: %s", arch_str)
        return self._build_missing_architecture_info(arch_str, analysis)

    def get_category_stats(self) -> Dict[str, Dict]:
        """Summarise candidate counts and configuration per category.

        Configured categories report ``available_count``, ``target_count`` and
        ``description``. An ``"unknown"`` entry with only ``available_count`` is
        added unless the configuration already defines that category.
        """
        stats: Dict[str, Dict] = {}
        for category, config in self.conv_categories.items():
            stats[category] = {
                "available_count": len(self.conv_category_archs.get(category, [])),
                "target_count": config.get("target_count", 0),
                "description": config.get("description", ""),
            }
        # setdefault: don't clobber a configured "unknown" category's target/description.
        stats.setdefault(
            "unknown", {"available_count": len(self.conv_category_archs.get("unknown", []))}
        )
        return stats

    # ------------------------------------------------------------------
    # Internal utilities
    # ------------------------------------------------------------------
    def _build_architecture_info(
        self,
        record: ArchitectureRecord,
        analysis: Dict,
    ) -> ArchitectureInfo:
        """Build an ``ArchitectureInfo`` from a metrics record and its analysis.

        The predicted accuracy is used when predictor mode is active and a
        prediction exists. Otherwise the measured accuracy is used, or
        ``EVAL_CONFIG['fallback_accuracy']`` if it is missing. A missing or
        unusable latency is stored as ``0.0`` and marks the entry invalid.
        """
        latency = self._sanitize_latency(record.latency_for(self.hardware_device))
        measured_accuracy = record.test_accuracy
        predicted_accuracy = self._predict_accuracy(record.arch_tuple)

        if self.use_predictor and predicted_accuracy is not None:
            accuracy = predicted_accuracy
            source = "predictor"
        else:
            accuracy = measured_accuracy if measured_accuracy is not None else EVAL_CONFIG.get("fallback_accuracy", 0.0)
            source = "metrics"

        is_valid = self._is_valid_metrics(accuracy, latency)

        return ArchitectureInfo(
            arch_index=record.index,
            arch_str=record.arch_str,
            accuracy=accuracy,
            latency=latency if latency is not None else 0.0,
            conv_category=analysis.get("category") or "unknown",
            conv_3x3_count=analysis.get("conv_3x3_count", 0),
            conv_1x1_count=analysis.get("conv_1x1_count", 0),
            dataset=self.dataset,
            hardware_device=self.hardware_device,
            source=source,
            is_valid=is_valid,
            measured_accuracy=measured_accuracy,
            predicted_accuracy=predicted_accuracy,
            # Copy so mutating ``info.latencies`` cannot alter the record held by
            # the (presumably shared/cached) metrics repository.
            latencies=dict(record.latencies or {}),
            arch_tuple=record.arch_tuple,
        )

    def _build_missing_architecture_info(self, arch_str: str, analysis: Dict) -> ArchitectureInfo:
        """Build a placeholder for an architecture absent from the metrics data.

        The result is always marked invalid and has ``arch_index=-1`` and
        ``latency=0.0``. Its accuracy is the prediction when predictor mode
        yields one; otherwise it is ``EVAL_CONFIG['fallback_accuracy']``.
        """
        arch_tuple = arch_str_to_tuple(arch_str)
        predicted_accuracy = self._predict_accuracy(arch_tuple) if arch_tuple is not None else None
        accuracy = (
            predicted_accuracy
            if (self.use_predictor and predicted_accuracy is not None)
            else EVAL_CONFIG.get("fallback_accuracy", 0.0)
        )

        return ArchitectureInfo(
            arch_index=-1,
            arch_str=arch_str,
            accuracy=accuracy,
            latency=0.0,
            conv_category=analysis.get("category") or "unknown",
            conv_3x3_count=analysis.get("conv_3x3_count", 0),
            conv_1x1_count=analysis.get("conv_1x1_count", 0),
            dataset=self.dataset,
            hardware_device=self.hardware_device,
            source="predictor" if predicted_accuracy is not None else "missing",
            is_valid=False,
            measured_accuracy=None,
            predicted_accuracy=predicted_accuracy,
            latencies={},
            arch_tuple=tuple(arch_tuple) if arch_tuple is not None else None,
        )

    def _predict_accuracy(self, arch_tuple: Optional[Tuple[int, ...]]) -> Optional[float]:
        """Return the predictor's accuracy for ``arch_tuple``, or ``None``.

        Returns ``None`` in four cases: no predictor is loaded, no tuple is
        given, the predictor fails, or it returns a non-finite value. The last
        case matters because a NaN prediction must never shadow a valid
        measured accuracy.
        """
        if self.predictor is None or arch_tuple is None:
            return None
        try:
            value = float(self.predictor.predict_architecture(self.dataset, str(tuple(arch_tuple))))
        except KeyError:
            # Presumably the predictor has no features for this architecture.
            return None
        except Exception as exc:  # pragma: no cover - defensive logging
            logger.debug("Predictor failed for %s: %s", arch_tuple, exc)
            return None
        return value if math.isfinite(value) else None

    @staticmethod
    def _sanitize_latency(latency: Optional[float]) -> Optional[float]:
        """Return ``latency`` as a float, or ``None`` if missing, non-positive or non-finite."""
        # NaN compares False against everything, so it needs an explicit isfinite check.
        if latency is None or latency <= 0 or not math.isfinite(latency):
            return None
        return float(latency)

    @staticmethod
    def _is_valid_metrics(accuracy: Optional[float], latency: Optional[float]) -> bool:
        """Return True for a finite accuracy in (0, 100] and a finite positive latency."""
        if accuracy is None or not math.isfinite(accuracy) or accuracy <= 0 or accuracy > 100:
            return False
        if latency is None or not math.isfinite(latency) or latency <= 0:
            return False
        return True


# ----------------------------------------------------------------------
# Pareto utilities reused by visualisation / analysis modules
# ----------------------------------------------------------------------

def is_valid_architecture_data(arch: ArchitectureInfo, hardware_device: str | None = None) -> bool:
    """Return True when *arch* is flagged valid and has usable metrics.

    Usable means a finite accuracy in (0, 100] and a finite, strictly positive
    latency. NaN must be rejected explicitly, because every ordered
    comparison with NaN is False and it would otherwise pass the range checks.

    Args:
        arch: The architecture record to check.
        hardware_device: Accepted for signature parity with the other helpers;
            currently unused.
    """
    if not arch.is_valid:
        return False
    accuracy, latency = arch.accuracy, arch.latency
    if accuracy is None or not math.isfinite(accuracy) or accuracy <= 0 or accuracy > 100:
        return False
    if latency is None or not math.isfinite(latency) or latency <= 0:
        return False
    return True


def filter_valid_architectures(
    architectures: Iterable[ArchitectureInfo], hardware_device: str | None = None
) -> List[ArchitectureInfo]:
    """Return the architectures accepted by ``is_valid_architecture_data``.

    Input order is preserved. When anything is dropped, the number removed is
    logged at INFO level, tagged with *hardware_device* (or ``"unknown"``).
    """
    pool = list(architectures)
    kept: List[ArchitectureInfo] = []
    for arch in pool:
        if is_valid_architecture_data(arch, hardware_device):
            kept.append(arch)

    dropped = len(pool) - len(kept)
    if dropped:
        logger.info("Filtered %d invalid architectures for device %s", dropped, hardware_device or "unknown")
    return kept


def calculate_robust_pareto_front(
    architectures: Iterable[ArchitectureInfo], hardware_device: str | None = None
) -> List[ArchitectureInfo]:
    """Return the non-dominated architectures (maximise accuracy, minimise latency).

    Only architectures that survive ``filter_valid_architectures`` take part.
    Candidate A dominates B when A is at least as accurate and at least as
    fast, and strictly better in one of the two. Architectures with identical
    metrics do not dominate each other, so all of them are kept. Uses an
    O(n^2) pairwise comparison and preserves input order.
    """
    candidates = filter_valid_architectures(list(architectures), hardware_device)
    if not candidates:
        logger.warning("No valid architectures available for Pareto calculation")
        return []

    def dominates(challenger: ArchitectureInfo, incumbent: ArchitectureInfo) -> bool:
        no_worse = challenger.accuracy >= incumbent.accuracy and challenger.latency <= incumbent.latency
        somewhere_better = challenger.accuracy > incumbent.accuracy or challenger.latency < incumbent.latency
        return no_worse and somewhere_better

    front = [
        cand
        for cand in candidates
        if not any(dominates(rival, cand) for rival in candidates if rival is not cand)
    ]

    logger.info("Robust Pareto front size: %d", len(front))
    return front


def get_robust_axis_limits(
    architectures: Iterable[ArchitectureInfo],
    hardware_device: str | None = None,
    margin_factor: float = 0.05,
) -> Tuple[Tuple[float, float], Tuple[float, float]]:
    """Compute padded ``(accuracy_range, latency_range)`` plot limits.

    Each range is widened on both sides by ``margin_factor`` times its span.
    The padding is at least 1e-3 for accuracy and 1e-6 for latency, so a
    degenerate (single-value) range still has width. Lower bounds are clamped
    at 0 and the accuracy upper bound at 100. If nothing valid remains, the
    defaults ``((0, 100), (0, 1))`` are returned.
    """
    usable = filter_valid_architectures(list(architectures), hardware_device)
    if not usable:
        return (0.0, 100.0), (0.0, 1.0)

    def padded_span(values: List[float], minimum_pad: float) -> Tuple[float, float, float]:
        low, high = min(values), max(values)
        return low, high, max((high - low) * margin_factor, minimum_pad)

    acc_low, acc_high, acc_pad = padded_span([arch.accuracy for arch in usable], 1e-3)
    lat_low, lat_high, lat_pad = padded_span([arch.latency for arch in usable], 1e-6)

    accuracy_limits = (max(0.0, acc_low - acc_pad), min(100.0, acc_high + acc_pad))
    latency_limits = (max(0.0, lat_low - lat_pad), lat_high + lat_pad)
    return accuracy_limits, latency_limits


# Public API exported by ``from <module> import *``. Includes
# ``is_valid_architecture_data``, the predicate the Pareto and axis helpers
# build on, which was previously omitted.
__all__ = [
    "ArchitectureEvaluator",
    "ArchitectureInfo",
    "calculate_robust_pareto_front",
    "filter_valid_architectures",
    "get_robust_axis_limits",
    "is_valid_architecture_data",
]
