#!/usr/bin/env python3
"""
Calculate mean/CVaR (+SE across seeds) from ep_returns.(npy|py) under frozen_logs/runs.
"""

import argparse
import ast
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union

import numpy as np


def load_ep_returns(run_dir: Path) -> np.ndarray:
    """Load the episode returns stored in a single run directory.

    ``ep_returns.npy`` is preferred and loaded with ``np.load``. If it is
    absent, ``ep_returns.py`` is parsed (never executed) and the first
    top-level assignment to the name ``ep_returns`` is evaluated with
    ``ast.literal_eval``. Both plain (``ep_returns = [...]``) and annotated
    (``ep_returns: list = [...]``) assignments are recognised.

    Args:
        run_dir: Directory expected to contain ``ep_returns.npy`` or
            ``ep_returns.py``.

    Returns:
        The array as stored for ``.npy`` input, or a float array built from
        the literal for ``.py`` input.

    Raises:
        FileNotFoundError: Neither file exists in ``run_dir``.
        ValueError: ``ep_returns.py`` has no top-level ``ep_returns``
            assignment (``literal_eval`` also raises this for non-literals).
    """
    npy_path = run_dir / "ep_returns.npy"
    if npy_path.exists():
        return np.load(npy_path)

    py_path = run_dir / "ep_returns.py"
    if not py_path.exists():
        raise FileNotFoundError(f"ep_returns.(npy|py) not found in {run_dir}")

    # Python source is UTF-8 by default; do not depend on the locale encoding.
    module = ast.parse(py_path.read_text(encoding="utf-8"))
    for node in module.body:
        if isinstance(node, ast.Assign):
            targets = node.targets
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            # Annotated assignment has exactly one target; a bare annotation
            # without a value carries no data and is skipped.
            targets = [node.target]
        else:
            continue
        if any(isinstance(t, ast.Name) and t.id == "ep_returns" for t in targets):
            return np.asarray(ast.literal_eval(node.value), dtype=float)
    raise ValueError(f"ep_returns not found in {py_path}")


def select_returns(arr: np.ndarray, steps: Optional[int]) -> np.ndarray:
    """Pick the returns to evaluate and flatten them to 1-D.

    Args:
        arr: Episode returns; converted to a float array.
        steps: ``None`` selects everything. For 1-D input it is the slice
            bound ``arr[:steps]``, so a negative value drops trailing
            entries. For input with two or more dimensions it is an index
            along the first axis; negative values count from the end and
            out-of-range values are clamped to the first/last row.

    Returns:
        A flat float array. It is empty when the input has no rows.
    """
    arr = np.asarray(arr, float)
    # A 0-d array has no axis to slice or index, so return its single value.
    if steps is None or arr.ndim == 0:
        return arr.reshape(-1)
    if arr.ndim == 1:
        return arr[:steps].reshape(-1)

    n_rows = arr.shape[0]
    if n_rows == 0:
        return np.empty(0, dtype=float)

    idx = steps + n_rows if steps < 0 else steps
    # Clamp at both ends so a too-negative index behaves like a too-large one
    # instead of raising IndexError.
    idx = min(max(idx, 0), n_rows - 1)
    return arr[idx].reshape(-1)


def calculate_cvar(scores: np.ndarray, alpha: float) -> float:
    """Return the lower-tail CVaR: the mean of the worst ``alpha`` fraction.

    The number of tail samples is ``ceil(alpha * n)``, clamped to ``[1, n]``.
    ``alpha <= 0`` therefore yields the single worst score and ``alpha >= 1``
    yields the overall mean.

    Args:
        scores: Scores of any shape; flattened before use.
        alpha: Tail fraction, nominally in ``(0, 1]``.

    Returns:
        The tail mean, or NaN when ``scores`` is empty.
    """
    flat = np.asarray(scores, dtype=float).reshape(-1)
    n = flat.size
    if n == 0:
        return float("nan")
    # Without the lower clamp a negative alpha gives a negative k, and the
    # slice [:k] would silently drop the *largest* values instead.
    k = min(max(int(np.ceil(alpha * n)), 1), n)
    return float(np.mean(np.sort(flat)[:k]))


def sem(values: List[float]) -> float:
    """Return the standard error of the mean of ``values``.

    Uses the sample standard deviation (``ddof=1``) divided by the square
    root of the number of values. Returns NaN when fewer than two values are
    given, since the sample deviation is undefined there.
    """
    count = len(values)
    if count < 2:
        return float("nan")
    sample = np.asarray(values, float)
    spread = np.std(sample, ddof=1)
    return float(spread / np.sqrt(count))


def parse_seed(seed_str: str) -> Union[int, str]:
    """Convert a seed label to ``int`` when possible, else return it as-is."""
    try:
        numeric = int(seed_str)
    except ValueError:
        # Not an integer literal: keep the original label.
        return seed_str
    return numeric


def format_value(value: float, error: float, scientific: bool) -> str:
    """Format ``value±error`` for the summary table.

    Args:
        value: Central estimate.
        error: Its uncertainty.
        scientific: When true and the larger magnitude is at least 100, both
            numbers are scaled by a shared power of ten and rendered as
            ``a×10^e±b×10^e``.

    Returns:
        ``"N/A"`` if either input is NaN (``sem`` yields NaN for fewer than
        two values), otherwise the formatted string with two decimals.
    """
    if np.isnan(value) or np.isnan(error):
        return "N/A"
    magnitude = max(abs(value), abs(error))
    # log10(inf) is inf and int(inf) raises OverflowError, so infinite inputs
    # fall through to the plain format, which renders them as "inf".
    if scientific and magnitude >= 100 and np.isfinite(magnitude):
        exp = int(np.floor(np.log10(magnitude)))
        scale = 10**exp
        return f"{value / scale:.2f}×10^{exp}±{error / scale:.2f}×10^{exp}"
    return f"{value:.2f}±{error:.2f}"


def main() -> None:
    """Aggregate per-seed mean and CVaR for every run group and print a table.

    Sub-directories of ``--runs_dir`` whose names contain ``|`` are split
    into ``<prefix>|<seed>``; directories sharing a prefix form one group.
    For each seed listed in ``--seeds`` the returns are loaded, reduced with
    ``select_returns`` and summarised by their mean and CVaR. The table
    reports the across-seed mean and standard error of both statistics.

    A seed that fails to load or process is reported as a warning and
    skipped, so one broken run does not abort the whole summary.

    Raises:
        FileNotFoundError: ``--runs_dir`` does not exist.
    """

    def seed_order(seed: Union[int, str]):
        # Integer seeds sort before string seeds; comparing int with str
        # directly would raise TypeError.
        return (isinstance(seed, str), seed)

    parser = argparse.ArgumentParser(
        description="Compute mean/CVaR (+SE) across seeds from frozen_logs/runs ep_returns."
    )
    parser.add_argument(
        "--runs_dir",
        type=str,
        default="~/tmp/rdp-submission/frozen_logs/runs",
        help="Path to frozen_logs/runs",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=1000,
        help="Use up to this many returns for 1D arrays; or index for 2D arrays.",
    )
    parser.add_argument(
        "--seeds",
        type=str,
        default="0,1,2,3,4",
        help="Comma-separated seed list to include (default: 0-4).",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.1,
        help="CVaR alpha (default: 0.1).",
    )
    parser.add_argument(
        "--scientific",
        action="store_true",
        help="Use scientific notation for large values.",
    )
    args = parser.parse_args()

    # The default starts with "~", which pathlib does not expand by itself;
    # without expanduser() the default location would never be found.
    runs_dir = Path(args.runs_dir).expanduser()
    if not runs_dir.exists():
        raise FileNotFoundError(f"runs_dir not found: {runs_dir}")

    seed_filter = {parse_seed(s.strip()) for s in args.seeds.split(",") if s.strip()}

    # prefix -> {seed -> run directory}
    groups: Dict[str, Dict[Union[int, str], Path]] = defaultdict(dict)
    for run_dir in sorted(runs_dir.iterdir()):
        if not run_dir.is_dir():
            continue
        parts = run_dir.name.split("|")
        if len(parts) < 2:
            continue
        # The last "|"-separated field is the seed; the rest names the group.
        seed = parse_seed(parts[-1])
        prefix = "|".join(parts[:-1])
        if seed in seed_filter:
            groups[prefix][seed] = run_dir

    if not groups:
        print("No matching runs found.")
        return

    rows = []
    for prefix, seed_dirs in sorted(groups.items()):
        seed_means = []
        seed_cvars = []
        episodes_per_seed = None
        used_seeds = []

        for seed in sorted(seed_dirs, key=seed_order):
            try:
                ep_returns = load_ep_returns(seed_dirs[seed])
                selected = select_returns(ep_returns, args.steps)
                # The "Episodes" column shows the count from the first seed
                # that loads successfully.
                if episodes_per_seed is None:
                    episodes_per_seed = len(selected)
                mean_score = float(np.mean(selected)) if len(selected) else float("nan")
                cvar_score = calculate_cvar(selected, args.alpha)
                seed_means.append(mean_score)
                seed_cvars.append(cvar_score)
                used_seeds.append(seed)
            except Exception as exc:
                # Deliberately best-effort: report and skip the broken seed.
                print(f"Warning: {prefix}|{seed}: {exc}")

        if seed_means:
            mean = float(np.mean(seed_means))
            mean_err = sem(seed_means)
            cvar = float(np.mean(seed_cvars))
            cvar_err = sem(seed_cvars)
        else:
            mean = mean_err = cvar = cvar_err = float("nan")

        rows.append(
            {
                "run": prefix,
                "mean": mean,
                "mean_err": mean_err,
                "cvar": cvar,
                "cvar_err": cvar_err,
                "seeds_used": len(used_seeds),
                "episodes": episodes_per_seed or 0,
            }
        )

    run_width = max(len(r["run"]) for r in rows)
    print("\nSUMMARY TABLE (Seed mean ± SE)")
    # Use the mixed-safe key here too, so int and str seeds can be listed
    # together without a TypeError.
    print(f"steps={args.steps}, alpha={args.alpha}, seeds={sorted(seed_filter, key=seed_order)}")
    print("=" * (run_width + 70))
    header = (
        f"{'Run':<{run_width}} "
        f"{'Mean±SE':<20} "
        f"{'CVaR±SE':<20} "
        f"{'Seeds':<6} "
        f"{'Episodes':<9}"
    )
    print(header)
    print("-" * (run_width + 70))

    for row in rows:
        mean_str = format_value(row["mean"], row["mean_err"], args.scientific)
        cvar_str = format_value(row["cvar"], row["cvar_err"], args.scientific)
        print(
            f"{row['run']:<{run_width}} "
            f"{mean_str:<20} "
            f"{cvar_str:<20} "
            f"{row['seeds_used']:<6} "
            f"{row['episodes']:<9}"
        )


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
