# peptide/plot_compare_logz.py
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import matplotlib.pyplot as plt


# Candidate metric keys for the logz value, most preferred first. main() hands
# this list to auto_pick_key(), which returns the first entry that appears in
# any row of a metrics file.
LOGZ_KEYS_PRIORITY = ["logz_main", "logz"]


def read_jsonl(path: Path) -> List[dict]:
    """Parse a JSON Lines file into a list of decoded records.

    Lines that are empty or contain only whitespace are skipped. Every other
    line is decoded with ``json.loads``; records keep their order in the file.
    A malformed line propagates the ``json.JSONDecodeError`` to the caller.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        # The list is fully built before the ``with`` block closes the file.
        return [json.loads(text) for text in stripped_lines if text]


def choose_metrics_file(base: Path) -> Optional[Path]:
    """Locate the ``metrics.jsonl`` to use for the directory ``base``.

    Resolution order:
      1. ``base/metrics.jsonl`` if it exists.
      2. Otherwise, among immediate sub-directories of ``base`` that contain a
         ``metrics.jsonl``, the one whose directory name sorts last
         (plain string ordering).

    Returns ``None`` when ``base`` does not exist or no metrics file is found.
    """
    top_level = base / "metrics.jsonl"
    if top_level.exists():
        return top_level

    if not base.exists():
        return None

    nested = [
        child / "metrics.jsonl"
        for child in base.iterdir()
        if child.is_dir() and (child / "metrics.jsonl").exists()
    ]
    if not nested:
        return None

    # Sort by the containing directory's name and keep the last entry.
    nested.sort(key=lambda metrics_path: metrics_path.parent.name)
    return nested[-1]


def auto_pick_key(rows: List[dict], keys_priority: List[str]) -> Optional[str]:
    """Return the first entry of ``keys_priority`` used as a key in any row.

    The union of keys across all ``rows`` is collected first, so a key counts
    as present even if only a single row carries it. Returns ``None`` when no
    candidate appears anywhere.
    """
    seen_keys = {name for row in rows for name in row.keys()}
    return next((candidate for candidate in keys_priority if candidate in seen_keys), None)


def extract_epoch_series(rows: List[dict], key: str) -> Dict[int, float]:
    """Build an ``{epoch: value}`` mapping for ``key`` from the given rows.

    A row contributes only if it has both an ``"epoch"`` entry and a ``key``
    entry, and those convert via ``int(...)`` and ``float(...)`` respectively;
    rows failing either condition are silently skipped. When several rows
    share an epoch, the last one wins.
    """
    series: Dict[int, float] = {}
    for record in rows:
        if not ("epoch" in record and key in record):
            continue
        try:
            epoch, value = int(record["epoch"]), float(record[key])
        except Exception:
            # Best-effort parsing: unconvertible entries are dropped.
            continue
        else:
            series[epoch] = value
    return series


def find_method_seed_metrics(root: Path, exp: str, run_id: str) -> Dict[str, List[Path]]:
    """Collect metrics files per method for one run of an experiment.

    Scans ``root/exp/<method>/run_id``. If that run directory contains
    sub-directories named ``seed_*``, each one (in sorted order) is resolved
    with ``choose_metrics_file``; otherwise the run directory itself is
    resolved. Methods with no resolved metrics file are left out.

    Returns:
        Mapping from method directory name to its list of metrics paths.

    Raises:
        FileNotFoundError: if ``root/exp`` does not exist.
    """
    exp_dir = root / exp
    if not exp_dir.exists():
        raise FileNotFoundError(f"Experiment directory not found: {exp_dir}")

    found: Dict[str, List[Path]] = {}

    for method_dir in exp_dir.iterdir():
        run_dir = method_dir / run_id
        if not (method_dir.is_dir() and run_dir.exists()):
            continue

        seed_dirs = sorted(
            entry
            for entry in run_dir.iterdir()
            if entry.is_dir() and entry.name.startswith("seed_")
        )
        # Without seed_* folders, fall back to the run directory itself.
        search_roots = seed_dirs if seed_dirs else [run_dir]

        resolved = [choose_metrics_file(location) for location in search_roots]
        metrics_list = [metrics for metrics in resolved if metrics is not None]

        if metrics_list:
            found[method_dir.name] = metrics_list

    return found


def aggregate_across_seeds(series_by_seed: List[Dict[int, float]]):
    """Compute per-epoch statistics over the seeds that report each epoch.

    Returns a 6-tuple of arrays ``(epochs, mean, std, vmin, vmax, count)``
    aligned on the sorted union of epochs across all seeds. ``std`` is the
    population standard deviation (``ddof=0``) and ``count`` is the number of
    seeds that have a value at that epoch. If there are no epochs at all,
    six empty arrays are returned.
    """
    epoch_pool = set()
    for series in series_by_seed or ():
        epoch_pool.update(series.keys())

    if not epoch_pool:
        empty = np.array([])
        return (empty,) * 6

    ordered_epochs = sorted(epoch_pool)

    # One (mean, std, min, max, n) tuple per epoch, transposed afterwards.
    per_epoch_stats = []
    for epoch in ordered_epochs:
        samples = np.array(
            [series[epoch] for series in series_by_seed if epoch in series],
            dtype=float,
        )
        per_epoch_stats.append(
            (samples.mean(), samples.std(ddof=0), samples.min(), samples.max(), len(samples))
        )

    mean_col, std_col, min_col, max_col, count_col = zip(*per_epoch_stats)
    return (
        np.array(ordered_epochs),
        np.array(mean_col),
        np.array(std_col),
        np.array(min_col),
        np.array(max_col),
        np.array(count_col),
    )


def parse_args():
    """Parse the command-line options of the comparison plot script.

    Options: ``--run_id`` (required), ``--root``, ``--exp``, ``--metric``
    (``logz`` or ``Z``), ``--band`` (``std`` or ``minmax``), ``--sigma``,
    ``--out``, ``--min_epoch`` and ``--max_epoch``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()

    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        ("--run_id", {"required": True, "type": str, "help": 'e.g. "run_27"'}),
        ("--root", {"default": "runs", "type": str}),
        ("--exp", {"default": "peptide", "type": str}),
        (
            "--metric",
            {
                "default": "logz",
                "type": str,
                "choices": ["logz", "Z"],
                "help": "Plot logz or Z=exp(logz)",
            },
        ),
        ("--band", {"default": "std", "choices": ["std", "minmax"], "help": "Band type"}),
        ("--sigma", {"default": 1.0, "type": float, "help": "If band=std: mean ± sigma*std"}),
        ("--out", {"default": None, "type": str}),
        ("--min_epoch", {"default": None, "type": int}),
        ("--max_epoch", {"default": None, "type": int}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()


def _curve_and_band(mean, std, vmin, vmax, *, metric, band, sigma):
    """Return ``(center, low, high)`` arrays in the space selected by ``metric``.

    The band is always built in log space first: ``mean ± sigma*std`` for
    ``band == "std"``, or ``[vmin, vmax]`` otherwise. For ``metric == "Z"``
    the center and both bounds are then exponentiated.
    """
    if band == "std":
        spread = sigma * std
        low, high = mean - spread, mean + spread
    else:
        low, high = vmin, vmax

    if metric == "Z":
        # exp() is monotonic, so exponentiating the log-space bounds gives a
        # band that encloses the curve and collapses onto it when std == 0.
        # Using exp(std) as an additive width would not: exp(std) >= 1, so the
        # half-width would be at least sigma even when all seeds agree.
        # Note that exp(mean of logz) is the geometric mean of Z across seeds.
        return np.exp(mean), np.exp(low), np.exp(high)
    return mean, low, high


def main():
    """Plot one mean curve with a shaded band per method and save the figure.

    Steps:
      1. Discover metrics files for ``--run_id`` under ``<root>/<exp>``.
      2. For each method, read every metrics file, pick the logz key, and
         aggregate the per-epoch values across seeds.
      3. Apply the optional ``--min_epoch`` / ``--max_epoch`` filter.
      4. Draw the curve and band (in ``logz`` or ``Z = exp(logz)`` space).
      5. Save to ``--out``, or to
         ``<root>/<exp>/plots/compare_<metric>_<band>__<run_id>.png``.

    Raises:
        SystemExit: if no method directory has metrics for the run, or if
            nothing could be plotted.
    """
    args = parse_args()
    root = Path(args.root)

    methods = find_method_seed_metrics(root=root, exp=args.exp, run_id=args.run_id)
    if not methods:
        # The hint uses the actual --root instead of a hard-coded "runs".
        raise SystemExit(
            f"No methods found for run_id={args.run_id} under {root / args.exp}.\n"
            f"Expected: {root}/{args.exp}/<method>/{args.run_id}/seed_<seed>/metrics.jsonl"
        )

    # The axis label depends only on the CLI option, so set it once up front.
    ylabel = "Z = exp(logz)" if args.metric == "Z" else "logz"

    plt.figure()
    any_plotted = False

    for method, metrics_paths in sorted(methods.items()):
        per_seed: List[Dict[int, float]] = []

        for mp in metrics_paths:
            rows = read_jsonl(mp)
            key = auto_pick_key(rows, LOGZ_KEYS_PRIORITY)
            if key is None:
                continue
            series = extract_epoch_series(rows, key)
            if series:
                per_seed.append(series)

        if not per_seed:
            print(f"[WARN] {method}: no usable seeds (missing {LOGZ_KEYS_PRIORITY})")
            continue

        epochs, mean, std, vmin, vmax, _count = aggregate_across_seeds(per_seed)
        if epochs.size == 0:
            continue

        # Epoch filter: combine both optional bounds into a single mask.
        keep = np.ones(epochs.shape, dtype=bool)
        if args.min_epoch is not None:
            keep &= epochs >= args.min_epoch
        if args.max_epoch is not None:
            keep &= epochs <= args.max_epoch
        if not keep.any():
            continue
        epochs, mean, std, vmin, vmax = (
            arr[keep] for arr in (epochs, mean, std, vmin, vmax)
        )

        center, low, high = _curve_and_band(
            mean, std, vmin, vmax,
            metric=args.metric, band=args.band, sigma=args.sigma,
        )

        (line,) = plt.plot(epochs, center, label=f"{method} (n={len(per_seed)})")
        # Reuse the line's color so the band matches its curve.
        plt.fill_between(epochs, low, high, alpha=0.2, color=line.get_color())
        any_plotted = True

    if not any_plotted:
        raise SystemExit("Nothing plotted (no usable methods/metrics).")

    plt.xlabel("epoch")
    plt.ylabel(ylabel)
    plt.title(f"{args.exp}: {args.run_id} — mean + band across seeds (band={args.band})")
    plt.legend()
    plt.tight_layout()

    if args.out is None:
        out_dir = root / args.exp / "plots"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"compare_{args.metric}_{args.band}__{args.run_id}.png"
    else:
        out_path = Path(args.out)
        out_path.parent.mkdir(parents=True, exist_ok=True)

    plt.savefig(out_path, dpi=200)
    print(f"Saved: {out_path}")


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()