"""Aggregation of ZSCT accuracy across seeds and episodes (Table 4 format)."""
from __future__ import annotations

import json
import math
from typing import Sequence

from meta_rg.stats import token_stats


def aggregate(episode_results: Sequence[dict]) -> dict:
    """Summarise ZSCT accuracy (as a percentage) over episode result dicts.

    Each episode's "zsct_acc" is multiplied by 100 before aggregation.
    Returns {"mean", "std", "n"} with mean and std rounded to one decimal;
    std is the sample standard deviation (n - 1 denominator, 0 for a single
    episode). An empty input yields zeros for all three fields.
    """
    percentages = [result["zsct_acc"] * 100.0 for result in episode_results]
    count = len(percentages)
    if not percentages:
        return {"mean": 0.0, "std": 0.0, "n": 0}

    average = sum(percentages) / count
    # Clamp the degrees of freedom so a single episode does not divide by 0.
    degrees_of_freedom = max(count - 1, 1)
    squared_deviations = [(p - average) ** 2 for p in percentages]
    spread = math.sqrt(sum(squared_deviations) / degrees_of_freedom)
    return {"mean": round(average, 1), "std": round(spread, 1), "n": count}


# Sentinel token-stats record used when no episode reports token usage.
# Treat as read-only: aggregate_seeds() hands out copies, never this object.
_NULL_TOK = {"mean": -1.0, "std": -1.0, "min": -1, "max": -1, "n": -1}


def aggregate_seeds(seed_results: Sequence[Sequence[dict]]) -> dict:
    """
    Aggregate ZSCT accuracy, token usage and error counters across seeds.

    seed_results[i] = list of episode dicts for seed i. Every episode must
    carry "zsct_acc"; it is multiplied by 100 for reporting.

    mean / std are computed over all individual episodes pooled across seeds
    (std is the sample standard deviation over n_episodes, not over per-seed
    means). Per-seed means are preserved in the "per_seed" field for
    reference; a seed with no episodes reports 0.0 there.

    Optional per-episode keys that are read when present:
      - "prompt_tokens" / "completion_tokens": dicts with "mean" and "n"
      - "n_test", "n_train": game counts used as the rate denominator
      - "n_truncated", "n_adapter_errors", "n_re_prompt_truncated",
        "n_format_errors": error counters
      - "context_overflow": truthy when the episode overflowed its context

    All metric keys are always present in the returned dict, including when
    there are no episodes at all (no seeds, or only empty seeds); in that
    case mean and std are 0.0 and n is 0.
    Metrics unavailable for a given strategy carry the sentinel value -1 / -1.0.
    """
    all_eps = [ep for eps in seed_results for ep in eps]

    # max(len, 1) keeps an empty seed from dividing by zero (its mean is 0.0).
    per_seed_means = [
        sum(ep["zsct_acc"] * 100.0 for ep in eps) / max(len(eps), 1)
        for eps in seed_results
    ]

    all_accs = [ep["zsct_acc"] * 100.0 for ep in all_eps]
    n = len(all_accs)
    if n:
        mean = sum(all_accs) / n
        # Sample variance; a single episode yields 0 rather than dividing by 0.
        variance = sum((x - mean) ** 2 for x in all_accs) / max(n - 1, 1)
    else:
        # No episodes anywhere: report zeros instead of raising
        # ZeroDivisionError, and still fall through so every key is emitted.
        mean = 0.0
        variance = 0.0

    # ── Token stats: skip sentinel (-1) episodes ──────────────────────────────
    def _has_mean(ep: dict, key: str) -> bool:
        return ep.get(key, {}).get("mean", -1.0) >= 0

    prompt_eps = [ep for ep in all_eps if _has_mean(ep, "prompt_tokens")]
    ep_prompt_means = [ep["prompt_tokens"]["mean"] for ep in prompt_eps]
    ep_completion_means = [
        ep["completion_tokens"]["mean"]
        for ep in all_eps
        if _has_mean(ep, "completion_tokens")
    ]
    tok = {
        # dict(_NULL_TOK): each field gets its own copy so callers mutating the
        # result cannot alias the two fields or corrupt the module constant.
        "prompt_tokens": (
            token_stats(ep_prompt_means) if ep_prompt_means else dict(_NULL_TOK)
        ),
        "completion_tokens": (
            token_stats(ep_completion_means)
            if ep_completion_means
            else dict(_NULL_TOK)
        ),
        "n_lm_calls": (
            sum(
                ep["prompt_tokens"]["n"]
                for ep in all_eps
                if ep.get("prompt_tokens", {}).get("n", -1) >= 0
            )
            if ep_prompt_means
            else -1
        ),
        # Episodes are selected by a valid prompt-token mean only.
        "per_episode_tokens": [
            {
                "prompt_tokens": ep["prompt_tokens"],
                "completion_tokens": ep["completion_tokens"],
            }
            for ep in prompt_eps
        ],
    }

    # ── Error stats: skip sentinel (-1) episodes ──────────────────────────────
    def _sum_valid(key: str) -> int:
        vals = [ep[key] for ep in all_eps if ep.get(key, -1) >= 0]
        return sum(vals) if vals else -1

    def _rate(num: int, denom: int) -> float:
        # A sentinel numerator propagates; the denominator is clamped to >= 1.
        return round(num / max(denom, 1), 3) if num >= 0 else -1.0

    total_games = sum(ep.get("n_test", 0) + ep.get("n_train", 0) for ep in all_eps)
    n_trunc = _sum_valid("n_truncated")
    n_adapter = _sum_valid("n_adapter_errors")
    n_re = _sum_valid("n_re_prompt_truncated")
    n_fmt = _sum_valid("n_format_errors")
    # Denominator for the re-prompt failure rate: truncations plus adapter
    # errors, counting an unavailable counter as 0 unless both are unavailable.
    if n_trunc >= 0 or n_adapter >= 0:
        n_reprompted = max(n_trunc, 0) + max(n_adapter, 0)
    else:
        n_reprompted = -1

    trunc_stats = {
        "n_truncated": n_trunc,
        "n_adapter_errors": n_adapter,
        "n_re_prompt_truncated": n_re,
        "n_format_errors": n_fmt,
        "truncation_rate": _rate(n_trunc, total_games),
        "adapter_error_rate": _rate(n_adapter, total_games),
        # _rate already clamps the denominator, so -1 / 0 both behave as 1.
        "re_prompt_failure_rate": _rate(n_re, n_reprompted),
        "format_error_rate": _rate(n_fmt, total_games),
    }

    n_overflow = sum(1 for ep in all_eps if ep.get("context_overflow", False))
    overflow_stats = {
        "n_context_overflow": n_overflow,
        "context_overflow_rate": round(n_overflow / max(n, 1), 3),
    }

    return {
        "mean": round(mean, 1),
        "std": round(math.sqrt(variance), 1),
        "n": n,  # n_episodes
        "per_seed": [round(v, 1) for v in per_seed_means],
        **tok,
        **trunc_stats,
        **overflow_stats,
    }


def format_table_row(model_name: str, results_by_condition: dict) -> str:
    """
    Render one model's results as a single Table 4 row.

    results_by_condition: {(O, S): {"mean": ..., "std": ...}, ...}

    The row holds the model name left-aligned in 20 columns followed by one
    "mean±std" cell (one decimal each) for the conditions (1, 1), (1, 2),
    (4, 1), (4, 2), in that order. A missing condition renders as 0.0±0.0.
    """
    conditions = ((1, 1), (1, 2), (4, 1), (4, 2))
    fallback = {"mean": 0.0, "std": 0.0}
    stats_in_order = (
        results_by_condition.get(condition, fallback) for condition in conditions
    )
    body = " | ".join(
        f"{stats['mean']:.1f}±{stats['std']:.1f}" for stats in stats_in_order
    )
    return f"{model_name:<20} | " + body


def print_table(results: dict) -> None:
    """
    Print a Table 4 style summary to stdout: a header line, a dashed rule of
    the same width, then one row per model in the mapping's iteration order.

    results: {model_name: {(O, S): {"mean": ..., "std": ...}}}
    """
    title_line = f"{'Model':<20} | O=1,S=1     | O=1,S=2     | O=4,S=1     | O=4,S=2"
    rule = "-" * len(title_line)
    print(title_line)
    print(rule)
    for name, per_condition in results.items():
        row = format_table_row(name, per_condition)
        print(row)


def save_results(path: str, results: dict) -> None:
    """Write *results* to *path* as JSON indented by two spaces.

    The payload is serialised in memory before the file is opened, so a
    serialisation failure leaves any existing file at *path* untouched
    instead of truncating it or leaving partial JSON behind.

    Raises:
        TypeError: if *results* is not JSON-serialisable. Note that JSON
            object keys must be str/int/float/bool/None, so tables keyed by
            (O, S) tuples need their keys converted by the caller first.
        OSError: if *path* cannot be opened for writing.
    """
    payload = json.dumps(results, indent=2)
    # Explicit encoding keeps the output independent of the platform locale.
    with open(path, "w", encoding="utf-8") as f:
        f.write(payload)
