# autointerp_hf/eval_output.py
from __future__ import annotations

import statistics
import time
from collections.abc import Callable, Mapping
from dataclasses import asdict, dataclass, is_dataclass
from typing import Any, Dict, List, Union


@dataclass
class AutoInterpResultSingleLatent:
    """
    Per-latent auto-interpretability result in its canonical form.

    Fields:
        latent: int
            Index of the SAE feature (a "neuron"/"latent") this result covers.
        explanation: str
            Natural-language description produced by the judge LLM.
        predictions: List[int]
            1-based indices of the examples the judge predicted would activate.
        correct_seqs: List[int]
            1-based indices of the examples that actually activated in the
            scoring set.
        score: float
            Classification accuracy (or a similar metric) of the judge on
            this latent.
        logs: str
            Human-readable debugging dump (prompts, tables, ...) for this latent.
    """

    latent: int
    explanation: str
    predictions: List[int]
    correct_seqs: List[int]
    score: float
    logs: str

    def to_dict(self) -> Dict[str, Any]:
        """
        Return this result as a plain, JSON-serializable dict.

        ``predictions`` and ``correct_seqs`` are copied into fresh lists, and
        ``score`` is coerced to ``float`` unless it is ``None``.
        """
        raw_score = self.score
        plain: Dict[str, Any] = {
            "latent": self.latent,
            "explanation": self.explanation,
            "predictions": [*self.predictions],
            "correct_seqs": [*self.correct_seqs],
            "score": None if raw_score is None else float(raw_score),
            "logs": self.logs,
        }
        return plain


@dataclass
class AutoInterpEvalOutput:
    """
    Top-level shape of an autointerp evaluation result.

    run_eval.py expects a plain dict it can pass to json.dump, so this
    dataclass is generally not returned to it directly; it documents the
    structure that gets emitted.
    """

    eval_config: Dict[str, Any]
    eval_id: str
    datetime_epoch_millis: int
    metrics: Dict[str, Any]
    per_latent_results: Dict[str, Dict[str, Any]]
    sae_metadata: Dict[str, Any]
    model_name_or_path: str
    hook_module_path: str

    def to_dict(self) -> Dict[str, Any]:
        """
        Return the fields as a plain dict under their stable key names.

        This does not use ``dataclasses.asdict``. Field values are inserted
        as-is, with no copying or recursive conversion, so nested containers
        are shared with this instance. They must already be JSON-serializable
        for the result to be.
        """
        field_order = (
            "eval_config",
            "eval_id",
            "datetime_epoch_millis",
            "metrics",
            "per_latent_results",
            "sae_metadata",
            "model_name_or_path",
            "hook_module_path",
        )
        return {name: getattr(self, name) for name in field_order}


def _canonical_result_fields(get: Callable[[str, Any], Any]) -> Dict[str, Any]:
    """
    Build the canonical per-latent dict from a ``get(key, default)`` accessor.

    ``predictions``/``correct_seqs`` are copied into fresh lists. A missing or
    ``None`` value becomes ``[]`` instead of crashing in ``list(None)``.
    ``score`` is coerced to ``float`` unless it is missing/``None``.
    """
    predictions = get("predictions", None)
    correct_seqs = get("correct_seqs", None)
    score = get("score", None)
    return {
        "latent": get("latent", None),
        "explanation": get("explanation", ""),
        "predictions": [] if predictions is None else list(predictions),
        "correct_seqs": [] if correct_seqs is None else list(correct_seqs),
        "score": None if score is None else float(score),
        "logs": get("logs", ""),
    }


def _result_obj_to_plain_dict(res: Any) -> Dict[str, Any]:
    """
    Normalize a single latent result into a plain dict with the canonical keys
    latent, explanation, predictions, correct_seqs, score, logs.

    Accepted inputs, checked in this order:
      - a dict (new code path);
      - a dataclass instance such as AutoInterpResultSingleLatent;
      - any object exposing a callable ``to_dict()`` returning a mapping;
      - any other object, read via attribute access.

    Missing keys/attributes fall back to None (latent, score), "" (explanation,
    logs) or [] (predictions, correct_seqs).

    Raises:
        ValueError / TypeError: if a present ``score`` cannot be converted with
        ``float()``, or a present sequence field is not iterable.
    """
    # Case 1: already a dict (the current pipeline behavior).
    if isinstance(res, dict):
        return _canonical_result_fields(res.get)

    # Case 2: a dataclass such as AutoInterpResultSingleLatent.
    if is_dataclass(res):
        return _canonical_result_fields(asdict(res).get)

    # Case 3: an object that knows how to serialize itself.
    to_dict = getattr(res, "to_dict", None)
    if callable(to_dict):
        return _canonical_result_fields(to_dict().get)

    # Last fallback: plain attribute access.
    return _canonical_result_fields(lambda key, default: getattr(res, key, default))


def build_eval_output(
    eval_config: Any = None,
    results_dict: Dict[int, Any] | None = None,
    eval_id: str = "",
    eval_start_time_millis: int | None = None,
    sae_metadata: Dict[str, Any] | None = None,
    model_name_or_path: str = "",
    hook_module_path: str = "",
    cfg: Any = None,
) -> Dict[str, Any]:
    """
    Build the final evaluation output dictionary.

    NOTE:
    - We accept BOTH `eval_config` and `cfg` for backwards compatibility.
      Older callsites may do: build_eval_output(eval_config=cfg, ...)
      Newer callsites may do: build_eval_output(cfg=cfg, ...)
      We unify them internally; `eval_config` wins when both are given.

    Args:
        eval_config / cfg:
            AutoInterpEvalConfig (or similar). Serialized into a plain dict
            under top-level key "eval_config" -> {"cfg": ...}. Mappings,
            pydantic models (v2 ``model_dump`` / v1 ``dict``), dataclass
            instances and objects with ``__dict__`` are converted to dicts;
            anything else is stored as ``str(obj)``.
        results_dict:
            Mapping latent_id -> per-latent result (dict or dataclass-like).
            None is treated as empty.
        eval_id:
            String identifier for this eval pass.
        eval_start_time_millis:
            ms since epoch at run start (or None -> the current time is used).
        sae_metadata:
            Info about which SAE / which hook layer, etc. None -> {}.
        model_name_or_path:
            HF model ID or path.
        hook_module_path:
            The module path we hooked inside the model.

    Returns:
        A dict with keys eval_config, eval_id, datetime_epoch_millis, metrics,
        per_latent_results, sae_metadata, model_name_or_path, hook_module_path.
        ``metrics`` holds the mean ("autointerp_score") and population standard
        deviation ("autointerp_std_dev") of all non-None per-latent scores. Both
        are 0.0 when there are no scores. ``per_latent_results`` is keyed by
        ``str(latent_id)`` because JSON object keys must be strings.
    """

    # -----------------------
    # 0. Normalize config obj
    # -----------------------
    # Prefer explicit eval_config if given. Otherwise fall back to cfg.
    config_obj = eval_config if eval_config is not None else cfg

    def _cfg_to_plain_dict(obj: Any) -> Dict[str, Any]:
        """
        Convert a config-like object into something JSON-safe, always wrapped
        as {"cfg": ...} to keep the shape stable.
        """
        if obj is None:
            return {"cfg": None}

        # Already a mapping (e.g. a plain dict config). Without this check a
        # dict has no .model_dump/.dict and no __dict__, so it would fall
        # through to str().
        if isinstance(obj, Mapping):
            return {"cfg": dict(obj)}

        # pydantic v2 style
        if callable(getattr(obj, "model_dump", None)):
            return {"cfg": obj.model_dump()}

        # pydantic v1 style
        if callable(getattr(obj, "dict", None)):
            return {"cfg": obj.dict()}

        # Dataclass instances, including slotted ones that have no __dict__.
        # asdict() also recursively converts nested dataclasses.
        if is_dataclass(obj) and not isinstance(obj, type):
            return {"cfg": asdict(obj)}

        # plain python object with __dict__
        if hasattr(obj, "__dict__"):
            return {"cfg": dict(obj.__dict__)}

        # last fallback: just string it
        return {"cfg": str(obj)}

    eval_cfg_plain = _cfg_to_plain_dict(config_obj)

    # Basic sanity for optional inputs
    if results_dict is None:
        results_dict = {}
    if sae_metadata is None:
        sae_metadata = {}

    # -----------------------
    # 1. Normalize per-latent
    # -----------------------
    per_latent_results: Dict[str, Dict[str, Any]] = {}
    scores: List[float] = []

    for latent_id, res in results_dict.items():
        res_dict = _result_obj_to_plain_dict(res)

        # The normalizer already coerced "score" to float (or None), so no
        # extra conversion or error handling is needed here.
        score_val = res_dict.get("score")
        if score_val is not None:
            scores.append(float(score_val))

        # JSON wants string keys
        per_latent_results[str(latent_id)] = res_dict

    # -----------------------
    # 2. Aggregate metrics
    # -----------------------
    if scores:
        mean_score = statistics.fmean(scores)
        # Population stdev; pstdev of a single value is already 0.0.
        std_dev = statistics.pstdev(scores)
    else:
        mean_score = 0.0
        std_dev = 0.0

    metrics = {
        "autointerp_score": mean_score,
        "autointerp_std_dev": std_dev,
    }

    # -----------------------
    # 3. Timestamp
    # -----------------------
    if eval_start_time_millis is None:
        eval_start_time_millis = int(time.time() * 1000)

    # -----------------------
    # 4. Final JSON-safe dict
    # -----------------------
    return {
        "eval_config": eval_cfg_plain,
        "eval_id": eval_id,
        "datetime_epoch_millis": eval_start_time_millis,
        "metrics": metrics,
        "per_latent_results": per_latent_results,
        "sae_metadata": sae_metadata,
        "model_name_or_path": model_name_or_path,
        "hook_module_path": hook_module_path,
    }
