"""Pareto metrics utilities aligned with the CSV-based data pipeline."""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np

from pel_nas.core.config import DATA_CONFIG
from pel_nas.data.metrics_repository import ArchitectureRecord, MetricsRepository
from pel_nas.search.architecture_analyzer import ArchitectureAnalyzer
from pel_nas.search.evaluator import (
    ArchitectureInfo,
    calculate_robust_pareto_front,
    filter_valid_architectures,
)

# Module-level logger named after this module (standard ``getLogger(__name__)`` pattern).
logger = logging.getLogger(__name__)


# ----------------------------------------------------------------------
# Lightweight Pareto helpers (adapted from compute_pareto_metrics.py)
# ----------------------------------------------------------------------

def pareto_front(points: np.ndarray) -> np.ndarray:
    """Return indices of the non-dominated rows of a 2-objective *points* array.

    Both columns are minimised. Rows are swept in order of increasing column 0
    (ties broken by increasing column 1). A row survives only if its column-1
    value beats the best value seen so far by more than ``1e-12``. As a result,
    exact duplicates keep only their first occurrence.

    Args:
        points: Array of shape ``(n, 2)``.

    Returns:
        Sorted integer indices of the surviving rows; an empty int array for
        empty input.
    """
    if points.size == 0:
        return np.array([], dtype=int)

    # np.lexsort sorts by the LAST key first: column 0 is the primary key and
    # column 1 breaks ties. The sort is stable, so duplicates keep input order.
    sweep_order = np.lexsort((points[:, 1], points[:, 0]))
    keep = np.ones(len(points), dtype=bool)
    running_best = np.inf

    for position in sweep_order:
        candidate = points[position, 1]
        if candidate < running_best - 1e-12:
            running_best = candidate
            continue
        # Not strictly better than an earlier (lower-or-equal column 0) row.
        keep[position] = False

    return np.flatnonzero(keep)


def normalize(points: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Min-max scale each column of *points*.

    Columns with zero range are divided by 1.0 instead of 0, so they map to 0.

    Args:
        points: Non-empty array of shape ``(n, k)``.

    Returns:
        ``(scaled, column_mins, column_maxs)``.
    """
    lo = points.min(axis=0)
    hi = points.max(axis=0)
    span = np.where(hi > lo, hi - lo, 1.0)
    scaled = (points - lo) / span
    return scaled, lo, hi


def hypervolume_2d(front: np.ndarray, reference: Tuple[float, float] = (1.0, 1.0)) -> float:
    """Area dominated by *front* and bounded above by *reference* (minimisation).

    Points are swept by increasing column 0. Each point adds the horizontal strip
    between its column-1 value and the lowest column-1 value seen so far, from its
    column-0 value to ``reference[0]``. Dominated points and points beyond the
    reference contribute nothing.

    Args:
        front: Array of shape ``(n, 2)``; need not be pre-filtered.
        reference: Upper-right corner ``(ref_x, ref_y)`` bounding the area.

    Returns:
        Non-negative hypervolume as a Python float (0.0 for empty input).
    """
    if front.size == 0:
        return 0.0

    ref_x, ref_y = reference
    ceiling = ref_y
    area = 0.0
    for px, py in front[np.argsort(front[:, 0])]:
        if not py < ceiling:
            continue
        width = max(ref_x - px, 0.0)
        height = max(ceiling - py, 0.0)
        area += width * height
        ceiling = py
    return float(max(area, 0.0))


def igd(opt_front: np.ndarray, found_front: np.ndarray) -> float:
    """Inverted generational distance of *found_front* w.r.t. *opt_front*.

    This is the mean Euclidean distance from each point of *opt_front* to its
    nearest point in *found_front*.

    Returns:
        0.0 when *opt_front* is empty. When *found_front* is empty, returns the
        mean distance of *opt_front* to the point ``(1, 1)``, presumably the
        worst corner of the normalised objective space.
    """
    if opt_front.size == 0:
        return 0.0
    if found_front.size == 0:
        return float(np.linalg.norm(opt_front - 1.0, axis=1).mean())

    # Pairwise distance matrix: rows index opt_front, columns index found_front.
    gaps = opt_front[:, np.newaxis, :] - found_front[np.newaxis, :, :]
    nearest = np.linalg.norm(gaps, axis=2).min(axis=1)
    return float(np.mean(nearest))


# ----------------------------------------------------------------------
# Data-access helpers
# ----------------------------------------------------------------------

@dataclass
class DatasetPareto:
    """All usable architectures of a dataset together with their Pareto front."""

    # Every architecture collected for the dataset/device pair.
    dataset_architectures: List[ArchitectureInfo]
    # Subset of ``dataset_architectures`` forming the Pareto front.
    dataset_front: List[ArchitectureInfo]

    @property
    def dataset_sample(self) -> List[ArchitectureInfo]:
        """Alias for :attr:`dataset_architectures`."""
        sample = self.dataset_architectures
        return sample

    @property
    def dataset_pareto(self) -> List[ArchitectureInfo]:
        """Alias for :attr:`dataset_front`."""
        pareto = self.dataset_front
        return pareto


class ParetoMetrics:
    """Utility that exposes dataset fronts and evaluation metrics.

    Objectives are ``(latency, -accuracy)``, both minimised. Front-quality
    metrics normalise the generated front with the min/max bounds of the
    *dataset* front, so both fronts live in the same scaled space.
    """

    def __init__(self, metrics_csv: Optional[str] = None) -> None:
        """Bind to the metrics CSV repository and an architecture analyzer.

        Args:
            metrics_csv: Path to the metrics CSV. Falls back to
                ``DATA_CONFIG["metrics_csv"]`` when omitted or empty.

        Raises:
            RuntimeError: If no CSV path is given or configured.
        """
        csv_path = metrics_csv or DATA_CONFIG.get("metrics_csv")
        if not csv_path:
            raise RuntimeError("DATA_CONFIG['metrics_csv'] must be configured for ParetoMetrics")

        self.metrics_repo = MetricsRepository.shared(csv_path)
        self.analyzer = ArchitectureAnalyzer()

    # ------------------------------------------------------------------
    # Dataset helpers
    # ------------------------------------------------------------------
    def collect_dataset_architectures(self, dataset: str, device: str) -> List[ArchitectureInfo]:
        """Build an ``ArchitectureInfo`` for every usable record of *dataset*.

        A record is kept only when its test accuracy and its latency on *device*
        are both present, finite and strictly positive.

        Args:
            dataset: Dataset name passed to ``MetricsRepository.list_records``.
            device: Hardware device whose latency is read from each record.

        Returns:
            The kept architectures, in repository order.
        """
        records = self.metrics_repo.list_records(dataset)
        infos: List[ArchitectureInfo] = []
        skipped = 0

        for record in records:
            latency = record.latency_for(device)
            accuracy = record.test_accuracy
            # ``NaN <= 0`` is False, so a plain sign check would let missing
            # values through and poison the min/max normalisation downstream.
            if not (self._is_positive_finite(accuracy) and self._is_positive_finite(latency)):
                skipped += 1
                continue

            analysis = self.analyzer.analyze_architecture(record.arch_str)
            info = ArchitectureInfo(
                arch_index=record.index,
                arch_str=record.arch_str,
                accuracy=float(accuracy),
                latency=float(latency),
                conv_category=analysis.get("category") or "unknown",
                conv_3x3_count=analysis.get("conv_3x3_count", 0),
                conv_1x1_count=analysis.get("conv_1x1_count", 0),
                dataset=dataset,
                hardware_device=device,
                source="metrics",
                is_valid=True,
                measured_accuracy=float(accuracy),
                predicted_accuracy=None,
                latencies=record.latencies,
                arch_tuple=record.arch_tuple,
            )
            infos.append(info)

        if skipped:
            logger.debug(
                "Skipped %d/%d records for dataset=%s device=%s "
                "(missing, non-finite or non-positive accuracy/latency)",
                skipped,
                skipped + len(infos),
                dataset,
                device,
            )
        return infos

    def compute_dataset_front(self, dataset: str, device: str) -> DatasetPareto:
        """Collect *dataset* architectures and compute their front for *device*.

        The front itself comes from ``calculate_robust_pareto_front``.
        """
        all_infos = self.collect_dataset_architectures(dataset, device)
        front = calculate_robust_pareto_front(all_infos, hardware_device=device)
        return DatasetPareto(all_infos, front)

    # ------------------------------------------------------------------
    # Metric computation
    # ------------------------------------------------------------------
    def calculate_hypervolume_ratio(
        self,
        dataset_front: Sequence[ArchitectureInfo],
        generated_front: Sequence[ArchitectureInfo],
    ) -> Dict[str, float]:
        """Compare the hypervolumes of the generated and dataset fronts.

        Both fronts are scaled with the dataset front's bounds. Hypervolumes use
        the default reference point of ``hypervolume_2d``.

        Returns:
            Dict with ``hv_ratio`` (generated / dataset, 0.0 when the dataset
            hypervolume is ~0), ``dataset_hv`` and ``generated_hv``. All three are
            0.0 if either front is empty.
        """
        dataset_points = self._to_objective_points(dataset_front)
        gen_points = self._to_objective_points(generated_front)

        if dataset_points.size == 0 or gen_points.size == 0:
            return {"hv_ratio": 0.0, "dataset_hv": 0.0, "generated_hv": 0.0}

        dataset_norm, gen_norm = self._normalize_against(dataset_points, gen_points)

        dataset_hv = hypervolume_2d(dataset_norm)
        generated_hv = hypervolume_2d(gen_norm)
        hv_ratio = generated_hv / dataset_hv if dataset_hv > 1e-12 else 0.0

        return {
            "hv_ratio": hv_ratio,
            "dataset_hv": dataset_hv,
            "generated_hv": generated_hv,
        }

    def calculate_igd(
        self,
        dataset_front: Sequence[ArchitectureInfo],
        generated_front: Sequence[ArchitectureInfo],
    ) -> float:
        """Compute the IGD of the generated front relative to the dataset front.

        Both fronts are scaled with the dataset front's bounds first.

        Returns:
            The IGD value, or ``inf`` if either front is empty.
        """
        dataset_points = self._to_objective_points(dataset_front)
        gen_points = self._to_objective_points(generated_front)
        if dataset_points.size == 0 or gen_points.size == 0:
            return float("inf")

        dataset_norm, gen_norm = self._normalize_against(dataset_points, gen_points)
        return igd(dataset_norm, gen_norm)

    def calculate_all_metrics(
        self,
        dataset_front: Sequence[ArchitectureInfo],
        generated_front: Sequence[ArchitectureInfo],
    ) -> Dict[str, float]:
        """Return the hypervolume metrics, IGD and both front sizes in one dict."""
        hv_data = self.calculate_hypervolume_ratio(dataset_front, generated_front)
        igd_value = self.calculate_igd(dataset_front, generated_front)
        return {
            **hv_data,
            "igd": igd_value,
            "dataset_pareto_size": len(dataset_front),
            "generated_pareto_size": len(generated_front),
        }

    @staticmethod
    def _normalize_against(
        reference_points: np.ndarray, other_points: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Scale both point sets with the per-column min/max of *reference_points*."""
        reference_norm, mins, maxs = normalize(reference_points)
        # Same zero-range guard as ``normalize`` so both sets share one scaling.
        denom = np.where(maxs > mins, maxs - mins, 1.0)
        return reference_norm, (other_points - mins) / denom

    @staticmethod
    def _is_positive_finite(value: Optional[float]) -> bool:
        """Return True when *value* is not None, finite and strictly positive."""
        if value is None:
            return False
        number = float(value)
        return bool(np.isfinite(number)) and number > 0

    @staticmethod
    def _to_objective_points(architectures: Sequence[ArchitectureInfo]) -> np.ndarray:
        """Convert architectures into an ``(n, 2)`` array of ``[latency, -accuracy]``."""
        if not architectures:
            return np.empty((0, 2))
        points = np.array([[arch.latency, -arch.accuracy] for arch in architectures], dtype=float)
        return points


def evaluate_pareto_performance(
    dataset_front: Sequence[ArchitectureInfo],
    generated_front: Sequence[ArchitectureInfo],
) -> Dict[str, float]:
    """Compute HV/IGD metrics of *generated_front* against *dataset_front*.

    This is a convenience wrapper around ``ParetoMetrics().calculate_all_metrics``.
    The default ``ParetoMetrics`` constructor reads ``DATA_CONFIG['metrics_csv']``
    and raises ``RuntimeError`` when it is not configured.
    """
    # NOTE(review): constructing ParetoMetrics also sets up the metrics repository
    # and analyzer, which the metric computation itself does not use.
    return ParetoMetrics().calculate_all_metrics(dataset_front, generated_front)


# Names exported by ``from <module> import *``.
__all__ = ["DatasetPareto", "ParetoMetrics", "evaluate_pareto_performance", "pareto_front"]
