#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse, glob, json, os
from collections import defaultdict
import numpy as np

# Result bins as (label, lo, hi) tuples; lo and hi are inclusive bounds.
# The label is what gets printed in the header row of the output table, and
# the position of a tuple in this list is the bin index used by bin_idx().
# NOTE(review): the first label reads "4-6" but its lower bound is 3, so a
# value of 3 is reported under "4-6" — confirm whether the label or the bound
# is the intended one.
BINS = [
    ("4-6", 3, 6),
    ("7-9", 7, 9),
    ("10-12", 10, 12),
    ("13-15", 13, 15),
    ("16-18", 16, 18),
    ("19-21", 19, 21),
]

def parse_args():
    """Parse the command line and return the resulting namespace.

    Options:
        --results_dir  (required) directory to search for result files.
        --prefix       (optional, default None) substring filter for file names.
    """
    parser = argparse.ArgumentParser()
    option_table = (
        ("--results_dir", {"required": True}),
        ("--prefix", {"default": None}),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser.parse_args()

def load_json(path):
    """Read the UTF-8 text file at *path* and return its decoded JSON content."""
    with open(path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def _normalize_board(board):
    """Return a canonical string for *board*, usable as a grouping key.

    - None becomes the empty string.
    - Strings are stripped of surrounding whitespace.
    - Lists, tuples and dicts are serialized as compact JSON (no spaces after
      separators, non-ASCII characters kept as-is); keys are sorted only when
      the top-level value is a dict.
    - Anything else is converted with str() and stripped.
    """
    if board is None:
        return ""
    if isinstance(board, str):
        return board.strip()

    top_level_is_dict = isinstance(board, dict)
    if top_level_is_dict or isinstance(board, (list, tuple)):
        return json.dumps(
            board,
            ensure_ascii=False,
            sort_keys=top_level_is_dict,
            separators=(",", ":"),
        )

    return str(board).strip()

def iter_cells(x):
    """Yield the leaf cells of a (possibly nested) board value, depth-first.

    - A top-level None yields nothing (there is no board to walk).
    - A string yields its characters one by one.
    - Lists and tuples are walked element by element; dicts by their values.
    - Any other value is yielded as a single cell.

    A None found *inside* a container is yielded as a cell instead of being
    dropped, so callers can tell an empty cell apart from an absent one.
    """
    if x is None:
        return
    if isinstance(x, str):
        yield from x
        return

    if isinstance(x, dict):
        children = x.values()
    elif isinstance(x, (list, tuple)):
        children = x
    else:
        # Scalar leaf (number, bool, ...).
        yield x
        return

    for child in children:
        if child is None:
            # Recursing would hit the top-level guard and lose the cell.
            yield None
        else:
            yield from iter_cells(child)

def count_missing(board) -> int:
    """Count the cells of *board* that denote an empty position.

    Cells are obtained from iter_cells(board). A cell counts as missing when
    it is None, a number equal to zero, or the string "0" or ".".

    Numbers are compared with ``== 0`` rather than truncated through int():
    truncation would count fractions such as 0.5 as zero and would raise on
    NaN or infinity.
    """
    missing = 0
    for cell in iter_cells(board):
        if cell is None:
            missing += 1
        elif isinstance(cell, (int, float)):
            if cell == 0:
                missing += 1
        elif isinstance(cell, str) and cell in ("0", "."):
            missing += 1
    return missing

def bin_idx(m: int) -> int:
    """Return the index in BINS of the bin whose inclusive bounds contain *m*.

    Raises:
        ValueError: if *m* is below 3 or above 21, or if it lies inside that
            range but no bin's [lo, hi] interval contains it.
    """
    if m < 3 or m > 21:
        raise ValueError(f"missing={m} out of range [3,21]")

    # First matching bin wins; None signals that nothing matched.
    hit = next(
        (pos for pos, (_label, low, high) in enumerate(BINS) if low <= m <= high),
        None,
    )
    if hit is None:
        raise ValueError(f"missing={m} fits no bin")
    return hit

def collect_files(d, prefix):
    """Return the sorted, de-duplicated ``.json`` paths found under directory *d*.

    The search is recursive. When *prefix* is truthy, only files whose name
    contains it are matched (pattern ``*<prefix>*.json``); otherwise every
    ``*.json`` file is matched.
    """
    if prefix:
        name_pattern = f"*{prefix}*.json"
    else:
        name_pattern = "*.json"
    search_pattern = os.path.join(d, "**", name_pattern)
    unique_paths = {*glob.glob(search_pattern, recursive=True)}
    return sorted(unique_paths)

def main():
    """Aggregate run results into per-bin metrics and print them as a TSV table.

    Every JSON file selected by ``--results_dir`` / ``--prefix`` must contain a
    list of run records. Each record is assigned to a bin via its
    ``task.min_moves`` and grouped by its normalized ``task.board_config``.

    Printed rows (tab-separated, one column per bin):
        1. bin labels
        2. number of distinct boards in the bin
        3. percentage of boards with at least one successful run
        4. mean of the per-board success rates, as a percentage
        5. a "-" placeholder per bin
        6. mean number of steps over runs whose termination reason contains
           "COMPLETE" or "FILLED" (0.00 when there is no such run)

    Raises:
        ValueError: if a file's top-level JSON value is not a list, if a record
            has no ``task.min_moves``, or if that value fits no bin.
    """
    args = parse_args()
    files = collect_files(args.results_dir, args.prefix)
    n_bins = len(BINS)

    # bin -> board -> list of run records (success, progress, eff_step)
    runs_by_bin_board = [defaultdict(list) for _ in range(n_bins)]

    for fp in files:
        data = load_json(fp)
        if not isinstance(data, list):
            raise ValueError(f"{fp}: JSON list expected")
        for r in data:
            # Treat a null record like an empty one for every field lookup.
            rec = r or {}
            task = rec.get("task") or {}
            obs = rec.get("last_observation") or {}
            steps = rec.get("steps") or []

            board_norm = _normalize_board(task.get("board_config"))
            m = task.get("min_moves")
            if m is None:
                # Fail with the file name instead of a bare TypeError from
                # the comparison inside bin_idx().
                raise ValueError(f"{fp}: record has no task.min_moves")
            b = bin_idx(m)

            # A null termination_reason must not break the substring tests.
            term = obs.get("termination_reason") or ""
            # NOTE(review): this is a substring test, so a reason such as
            # "INCOMPLETE" would also count as success — confirm the set of
            # possible termination reasons.
            success = "COMPLETE" in term
            try:
                prog = float(obs.get("progress", 0.0) or 0.0)
            except (TypeError, ValueError, OverflowError):
                # Unparseable progress is treated as no progress.
                prog = 0.0

            # Step counts are only recorded for runs that ran to an end state.
            eff = len(steps) if (success or "FILLED" in term) else None

            runs_by_bin_board[b][board_norm].append((success, prog, eff))

    cnt = [0] * n_bins
    pass_at_k = [0.0] * n_bins
    avg_succ = [0.0] * n_bins
    avg_prog = [0.0] * n_bins
    avg_eff = [0.0] * n_bins

    for bi, boards in enumerate(runs_by_bin_board):
        cnt[bi] = len(boards)
        if not boards:
            continue

        pass_rate_list = []
        pass_at_k_list = []
        prog_mean_list = []
        eff_steps_all = []

        # Every board entry holds at least one run: entries are only created
        # by the append above, so the per-board divisions are safe.
        for rs in boards.values():
            num_trials = len(rs)
            pass_rate_list.append(sum(s for s, _, _ in rs) / num_trials)
            pass_at_k_list.append(any(s for s, _, _ in rs))
            prog_mean_list.append(sum(p for _, p, _ in rs) / num_trials)
            eff_steps_all.extend(e for _, _, e in rs if e is not None)

        pass_at_k[bi] = float(np.mean(pass_at_k_list)) * 100.0
        avg_succ[bi] = float(np.mean(pass_rate_list)) * 100.0
        avg_prog[bi] = float(np.mean(prog_mean_list)) * 100.0
        avg_eff[bi] = float(np.mean(eff_steps_all)) if eff_steps_all else 0.0

    print("\t".join(b[0] for b in BINS))
    print("\t".join(str(x) for x in cnt))
    print("\t".join(f"{x:.2f}" for x in pass_at_k))
    print("\t".join(f"{x:.2f}" for x in avg_succ))
    # NOTE(review): avg_prog is computed above but this row prints "-" for
    # every bin — confirm whether the progress metric is hidden on purpose.
    print("\t".join("-" for _ in avg_prog))
    print("\t".join(f"{x:.2f}" for x in avg_eff))

# Run the aggregation only when the file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
