#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse, glob, json, os, sys
from collections import defaultdict
import numpy as np
import fnmatch  # NEW: for folder_name glob

# For colored output
def print_red(text, file=None):
    """Print ``text`` wrapped in the ANSI red (code 91) escape sequence.

    Args:
        text: String to print.
        file: Stream to write to. ``None`` (the default) lets ``print`` pick
            up ``sys.stdout`` at call time, so stdout redirection that happens
            after this module was imported is honoured.
    """
    RED = "\033[91m"
    RESET = "\033[0m"
    # file=None is resolved by print() itself to the current sys.stdout.
    print(RED + text + RESET, file=file)

def print_blue(text, file=None):
    """Print ``text`` wrapped in the ANSI blue (code 94) escape sequence.

    Args:
        text: String to print.
        file: Stream to write to. ``None`` (the default) lets ``print`` pick
            up ``sys.stdout`` at call time, so stdout redirection that happens
            after this module was imported is honoured.
    """
    BLUE = "\033[94m"
    RESET = "\033[0m"
    # file=None is resolved by print() itself to the current sys.stdout.
    print(BLUE + text + RESET, file=file)

# -----------------------------
# Bins over the number of missing (empty) cells on a board
# -----------------------------
# Each entry is (label, lo, hi) with lo/hi inclusive.
# Seven consecutive width-5 ranges: 11-15, 16-20, ..., 41-45.
BINS = [(f"{lo}-{lo + 4}", lo, lo + 4) for lo in range(11, 46, 5)]

# Extra categories. They use the same (label, lo, hi) shape as the bins, but
# have no missing-count range, so lo and hi are None.
EXTRA_CATS = [(label, None, None) for label in ("hard", "variant")]

def _build_cats(for_horizon_generalization: bool):
    """Build the category table and its index bookkeeping for one mode.

    Default mode (flag false): categories are ``EXTRA_CATS + BINS``; "hard"
    and "variant" are single buckets at indices 0 and 1, and the output order
    lists the bins first, then hard, then variant.

    Horizon-generalization mode (flag true): categories are the bins, then a
    "(hard)" copy of every bin, then a single "variant" bucket; the output
    order is simply the table order.

    Returns:
        ``(CATS, CAT_HARD_IDX, CAT_VARIANT_IDX, BIN_OFFSET, OUT_ORDER,
        HARD_BIN_OFFSET)``. ``CAT_HARD_IDX`` is None in horizon mode and
        ``HARD_BIN_OFFSET`` is None in default mode.
    """
    if for_horizon_generalization:
        hard_bins = [(f"{label}(hard)", lo, hi) for label, lo, hi in BINS]
        cats = BINS + hard_bins + [("variant", None, None)]
        variant_idx = len(cats) - 1
        out_order = list(range(len(cats)))
        # No single hard bucket here; hard boards go to HARD_BIN_OFFSET + bin.
        return cats, None, variant_idx, 0, out_order, len(BINS)

    cats = EXTRA_CATS + BINS
    hard_idx, variant_idx = 0, 1
    bin_offset = len(EXTRA_CATS)
    out_order = [*range(bin_offset, len(cats)), hard_idx, variant_idx]
    return cats, hard_idx, variant_idx, bin_offset, out_order, None

# Initialize the module-level category globals in the default (original) mode.
# main() rebinds all six when --for_horizon_generalization is passed.
CATS, CAT_HARD_IDX, CAT_VARIANT_IDX, BIN_OFFSET, OUT_ORDER, HARD_BIN_OFFSET = _build_cats(False)

# Defaults for the --hard_repo / --hard_split command-line options.
# NOTE(review): YOUR_PATH is not defined anywhere in the visible source, so as
# written this line raises NameError when the module is loaded. It looks like
# a placeholder for the real dataset repo id (a string) — TODO confirm and
# replace.
HARD_REPO_DEFAULT = YOUR_PATH
HARD_SPLIT_DEFAULT = "train"

# -----------------------------
# Expected counts per folder (must match original OUT_ORDER)
# Order: 11-15 16-20 21-25 26-30 31-35 36-40 41-45 hard variant
# -----------------------------
# Shared row for folders whose expected count is zero in every column.
_EXPECT_ALL_ZERO = [0] * 9

# Folder name -> expected board counts, one value per output column.
EXPECTED_COUNTS_BY_FOLDER = {
    "last_final_standard": [100, 100, 50, 50, 100, 100, 50, 0, 0],
    "last_final_standard_hotfix": [100, 100, 100, 100, 100, 100, 50, 0, 0],
    "last_final_standard_hotfix_only": [0, 0, 50, 50, 0, 0, 0, 0, 0],
    "last_final_standard_remain": [0, 0, 50, 50, 0, 0, 0, 250, 68],
    "last_final_standard_all": [100, 100, 100, 100, 100, 100, 50, 250, 68],
    "last_final_standard_hard": [0, 0, 50, 50, 50, 50, 50, 250, 0],
    # These folders all point at the same all-zero row object.
    **dict.fromkeys(
        (
            "last_final_standard_hard_chance",
            "last_final_standard_hard_chance1",
            "last_final_standard_hard_chance3",
            "last_final_standard_hard_chance5",
            "last_final_standard_hard_remove_hardskill",
        ),
        _EXPECT_ALL_ZERO,
    ),
}

# -----------------------------
# Args / IO
# -----------------------------
def parse_args():
    """Parse the script's command-line options and return the namespace.

    Options: ``--results_dir`` (required), ``--prefix``, ``--hard_repo``,
    ``--hard_split``, ``--for_horizon_generalization``, ``--folder_name``
    and ``--ignore_maxk``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--results_dir", required=True)

    # Plain value options that only differ in their default.
    value_options = (
        ("--prefix", None),
        ("--hard_repo", HARD_REPO_DEFAULT),
        ("--hard_split", HARD_SPLIT_DEFAULT),
    )
    for flag, default in value_options:
        parser.add_argument(flag, default=default)

    parser.add_argument("--for_horizon_generalization", action="store_true")
    parser.add_argument(
        "--folder_name",
        default=None,
        help="Glob pattern to select subfolder names under results_dir (e.g. 'last_final_standard*').",
    )
    # store_true already defaults to False.
    parser.add_argument(
        "--ignore_maxk",
        action="store_true",
        help="If set, skip the max_k==4 checks.",
    )
    return parser.parse_args()

def load_json(path):
    """Read the UTF-8 text file at ``path`` and return its parsed JSON value."""
    with open(path, mode="r", encoding="utf-8") as handle:
        text = handle.read()
    return json.loads(text)

# -----------------------------
# Board utils
# -----------------------------
def _normalize_board(board):
    """Return a canonical string for ``board`` so it can be used as a key.

    ``None`` becomes ``""``. Lists, tuples and dicts are serialized as compact
    JSON (no spaces, non-ASCII kept); keys are sorted only when ``board``
    itself is a dict. Strings are stripped, and anything else is converted
    with ``str`` and stripped.
    """
    if board is None:
        return ""
    if isinstance(board, (list, tuple, dict)):
        return json.dumps(
            board,
            ensure_ascii=False,
            sort_keys=isinstance(board, dict),
            separators=(",", ":"),
        )
    text = board if isinstance(board, str) else str(board)
    return text.strip()

def iter_cells(x):
    """Yield the leaf cells of a (possibly nested) board value.

    Strings are yielded one character at a time, lists and tuples are walked
    element by element, dicts are walked over their values, and any other
    value is yielded as a single cell.

    A ``None`` that sits *inside* a container is yielded as a ``None`` cell so
    that callers can count it as an empty cell. A top-level ``None`` (no board
    at all) yields nothing.
    """
    if x is None:
        return
    if isinstance(x, str):
        yield from x
    elif isinstance(x, (list, tuple, dict)):
        children = x.values() if isinstance(x, dict) else x
        for child in children:
            if child is None:
                # Recursing here would hit the top-level guard and silently
                # drop the cell; emit it so blanks encoded as None are seen.
                yield None
            else:
                yield from iter_cells(child)
    else:
        yield x

def count_missing(board) -> int:
    """Count the cells of ``board`` that look empty.

    A cell produced by ``iter_cells`` counts as missing when it is ``None``,
    a number whose ``int()`` value is 0, or one of the strings ``"0"`` / ``"."``.
    """
    return sum(
        1
        for cell in iter_cells(board)
        if cell is None
        or (isinstance(cell, (int, float)) and int(cell) == 0)
        or (isinstance(cell, str) and cell in ("0", "."))
    )

def bin_local_idx(m: int) -> int:
    """Return the position within ``BINS`` of the bin whose range contains ``m``.

    Raises:
        ValueError: if ``m`` is below 11 or above 45, or if no bin's
            inclusive ``[lo, hi]`` range contains it.
    """
    if m < 11 or m > 45:
        raise ValueError(f"missing={m} out of range [11,45]")

    position = next(
        (idx for idx, (_label, low, high) in enumerate(BINS) if low <= m <= high),
        None,
    )
    if position is None:
        raise ValueError(f"missing={m} fits no bin")
    return position

# -----------------------------
# Classification
# -----------------------------
def load_hard_board_set(repo_id: str, split: str) -> set:
    """Load a dataset split and return its normalized ``initial_board`` values.

    The dataset is fetched with ``load_dataset`` from the ``datasets``
    package (imported lazily here), and every entry of its ``initial_board``
    column is passed through ``_normalize_board``.
    """
    from datasets import load_dataset

    dataset = load_dataset(repo_id, split=split)
    normalized = set()
    for raw_board in dataset["initial_board"]:
        normalized.add(_normalize_board(raw_board))
    return normalized

def category_idx(task: dict, board_norm: str, hard_set: set) -> int:
    """Return the index into ``CATS`` of the category ``task`` belongs to.

    Priority is variant > hard > missing-bin:
      * a task whose ``task_name`` contains "variant" goes to the variant bucket;
      * a board found in ``hard_set`` goes to the single hard bucket when
        ``HARD_BIN_OFFSET`` is None, otherwise to the hard copy of its bin;
      * everything else goes to the plain bin for its missing-cell count.

    ``bin_local_idx`` raises ValueError when the missing count has no bin.
    """
    name = str(task.get("task_name", "") or "")
    if "variant" in name:
        return CAT_VARIANT_IDX

    is_hard = board_norm in hard_set
    if is_hard and HARD_BIN_OFFSET is None:
        # Original layout: all hard boards share one bucket.
        return CAT_HARD_IDX

    base = HARD_BIN_OFFSET if is_hard else BIN_OFFSET
    missing = count_missing(task.get("initial_board"))
    return base + bin_local_idx(missing)

# -----------------------------
# IO helpers
# -----------------------------
def collect_files(d, prefix):
    """Return the sorted, de-duplicated ``.json`` paths found under ``d``.

    The search is recursive. When ``prefix`` is truthy only names matching
    ``*{prefix}*.json`` are returned; otherwise every ``*.json`` name is.
    """
    if prefix:
        pattern = f"*{prefix}*.json"
    else:
        pattern = "*.json"
    matches = glob.glob(os.path.join(d, "**", pattern), recursive=True)
    unique_paths = set(matches)
    return sorted(unique_paths)

# -----------------------------
# Core processing
# -----------------------------
def list_target_subdirs(results_dir: str, folder_name: str | None = None):
    """
    - Only subdirectories directly under results_dir are considered
    - name.startswith('last_final_standard')
    - Must exist in EXPECTED_COUNTS_BY_FOLDER
    - If folder_name is given, additional filter using fnmatch (glob)
    """
    subs = []
    try:
        for name in sorted(os.listdir(results_dir)):
            if folder_name is not None:
                if not fnmatch.fnmatch(name, folder_name):
                    continue

            p = os.path.join(results_dir, name)
            if (
                os.path.isdir(p)
                and name.startswith("last_final_standard")
                and name in EXPECTED_COUNTS_BY_FOLDER
            ):
                subs.append((name, p))
    except FileNotFoundError:
        raise ValueError(f"results_dir not found: {results_dir}")
    return subs

def _validate_dir_names(parsed_names):
    """Check that the selected folder names form an allowed combination.

    Args:
        parsed_names: Iterable of folder-name strings.

    Raises:
        ValueError: if the set of names is not exactly one of the allowed
            combinations listed below (an empty selection is not allowed).
    """
    allowed_sets = (
        frozenset({"last_final_standard"}),
        frozenset({"last_final_standard", "last_final_standard_hotfix_only"}),
        frozenset({"last_final_standard", "last_final_standard_hotfix_only", "last_final_standard_remain"}),
        frozenset({"last_final_standard_hotfix"}),
        frozenset({"last_final_standard_hotfix", "last_final_standard_remain"}),
        frozenset({"last_final_standard_all"}),
        frozenset({"last_final_standard_hard_chance"}),
        frozenset({"last_final_standard_hard_chance1"}),
        frozenset({"last_final_standard_hard_chance3"}),
        frozenset({"last_final_standard_hard_chance5"}),
        frozenset({"last_final_standard_hard_remove_hardskill"}),
    )

    got = set(parsed_names)
    if any(got == combo for combo in allowed_sets):
        return

    got_str = ", ".join(sorted(got)) if got else "(empty)"
    allowed_str = " | ".join(", ".join(sorted(combo)) for combo in allowed_sets)
    raise ValueError(
        "Parsed dir_names are not allowed.\n"
        f"got: {got_str}\n"
        f"allowed sets: {allowed_str}"
    )

def process_one_dir(dir_path: str, prefix: str | None, hard_set: set, act_mode: str | None = None):
    """Read every result file under ``dir_path`` and group its runs.

    Each JSON file must hold a list of records. For every record the task's
    board is normalized, the record is assigned a category via
    ``category_idx``, and a ``(success, progress, eff_steps)`` tuple is
    appended under that category and board.

    Args:
        dir_path: Folder searched (recursively) by ``collect_files``.
        prefix: Optional filename filter forwarded to ``collect_files``.
        hard_set: Normalized boards that count as "hard".
        act_mode: ``"one_act"`` requires ``task["add_info"] == "only_one_action"``;
            ``"multi_act"`` requires the ``add_info`` key to be absent;
            ``None`` disables the check.

    Returns:
        ``(files, runs_by_cat_board, bad_records, add_info_values)`` where
        ``runs_by_cat_board`` is a list (one ``defaultdict(list)`` per entry
        of ``CATS``) mapping normalized board -> list of run tuples,
        ``bad_records`` counts records skipped for a missing board or a
        failed classification, and ``add_info_values`` is the set of
        stringified ``add_info`` values seen.

    Raises:
        ValueError: if a file is not a JSON list, a record or its task is not
            a dict, or the ``act_mode`` constraint is violated.
    """
    files = collect_files(dir_path, prefix)
    runs_by_cat_board = [defaultdict(list) for _ in range(len(CATS))]
    bad_records = 0
    add_info_values = set()

    for fp in files:
        data = load_json(fp)
        if not isinstance(data, list):
            raise ValueError(f"{fp}: JSON list expected")

        for j, r in enumerate(data):
            # Null-safe view of the record, shared by every lookup below so a
            # `null` entry is treated as an empty record instead of crashing.
            rec = r or {}
            if not isinstance(rec, dict):
                raise ValueError(f"{fp}[{j}]: record must be a dict, got {type(rec)}")

            task = rec.get("task") or {}
            if not isinstance(task, dict):
                raise ValueError(f"{fp}[{j}]: task must be a dict, got {type(task)}")

            # ---- add_info enforcement ----
            if act_mode == "one_act":
                if task.get("add_info") != "only_one_action":
                    raise ValueError(
                        f"{fp}[{j}]: results_dir has 'one_act' but task.add_info != 'only_one_action' "
                        f"(got {task.get('add_info')!r})"
                    )
            elif act_mode == "multi_act":
                if "add_info" in task:
                    raise ValueError(
                        f"{fp}[{j}]: results_dir has 'multi_act' but task.add_info must NOT exist "
                        f"(got {task.get('add_info')!r})"
                    )

            # A missing add_info is recorded as the string "None".
            add_info_values.add(str(task.get("add_info")))

            obs = rec.get("last_observation") or {}
            steps = rec.get("steps") or []

            board_norm = _normalize_board(task.get("initial_board"))
            if not board_norm:
                bad_records += 1
                continue

            try:
                ci = category_idx(task, board_norm, hard_set)
            except Exception:
                # Best effort by design: an unclassifiable record is counted
                # in bad_records rather than aborting the whole report.
                bad_records += 1
                continue

            term = obs.get("termination_reason", "") or ""
            success = "COMPLETE" in term

            try:
                prog = float(obs.get("progress", 0.0) or 0.0)
            except (TypeError, ValueError, OverflowError):
                prog = 0.0

            # Step count is only recorded for runs that ended COMPLETE or FILLED.
            eff = len(steps) if (success or "FILLED" in term) else None

            runs_by_cat_board[ci][board_norm].append((success, prog, eff))

    return files, runs_by_cat_board, bad_records, add_info_values

def merge_runs(dst_runs, src_runs):
    """Append every run list of ``src_runs`` onto ``dst_runs`` in place.

    Both arguments are per-category sequences of ``board -> list of runs``
    mappings; category ``i`` of the source is merged into category ``i`` of
    the destination. The pairing is taken from the arguments themselves, so
    the function does not depend on the module-level category table.

    Args:
        dst_runs: Sequence of mappings that create a list for a new board on
            first access (e.g. ``defaultdict(list)``); mutated in place.
        src_runs: Sequence of mappings to read from; not modified.

    Raises:
        ValueError: if the two sequences have different lengths.
    """
    if len(dst_runs) != len(src_runs):
        raise ValueError(
            f"category count mismatch: dst has {len(dst_runs)}, src has {len(src_runs)}"
        )
    for dst_boards, src_boards in zip(dst_runs, src_runs):
        for board, runs in src_boards.items():
            dst_boards[board].extend(runs)

def aggregate(runs_by_cat_board):
    """Compute per-category summary statistics from grouped runs.

    For each category index in ``CATS``:
      * ``cnt``: number of distinct boards;
      * ``pass_at_k``: percentage of boards with at least one successful run;
      * ``avg_succ``: mean over boards of the per-board success rate, in percent;
      * ``avg_prog``: mean over boards of the per-board mean progress, times 100;
      * ``avg_eff``: mean of all recorded (non-None) step counts pooled over
        the category's runs, or 0.0 when there are none.

    Categories without boards keep 0 / 0.0 in every list.

    Returns:
        ``(cnt, pass_at_k, avg_succ, avg_prog, avg_eff)``, each a list with
        one entry per category.
    """
    n_cats = len(CATS)
    cnt = [0] * n_cats
    pass_at_k = [0.0] * n_cats
    avg_succ = [0.0] * n_cats
    avg_prog = [0.0] * n_cats
    avg_eff = [0.0] * n_cats

    for ci, _cat in enumerate(CATS):
        per_board = list(runs_by_cat_board[ci].values())
        cnt[ci] = len(per_board)
        if not per_board:
            continue

        solved_any = [any(ok for ok, _, _ in runs) for runs in per_board]
        success_rates = [
            sum(ok for ok, _, _ in runs) / len(runs) if runs else 0.0
            for runs in per_board
        ]
        progress_means = [
            sum(progress for _, progress, _ in runs) / len(runs) if runs else 0.0
            for runs in per_board
        ]
        step_counts = [
            steps
            for runs in per_board
            for _, _, steps in runs
            if steps is not None
        ]

        pass_at_k[ci] = float(np.mean(solved_any)) * 100.0
        avg_succ[ci] = float(np.mean(success_rates)) * 100.0
        avg_prog[ci] = float(np.mean(progress_means)) * 100.0
        avg_eff[ci] = float(np.mean(step_counts)) if step_counts else 0.0

    return cnt, pass_at_k, avg_succ, avg_prog, avg_eff

def out_counts_in_out_order(cnt):
    """Return the entries of ``cnt`` rearranged into ``OUT_ORDER``."""
    ordered = []
    for cat_index in OUT_ORDER:
        ordered.append(cnt[cat_index])
    return ordered

def print_block(title, files_n, bad_records, cnt, pass_at_k, avg_succ, avg_prog, avg_eff, add_info_set=None):
    """Print one report block to stdout.

    The block consists of a ``=== title ===`` headline (blue unless the title
    is ``"FINAL_MERGED"``), a summary line, an optional ``add_info`` line, and
    six tab-separated rows in ``OUT_ORDER``: category labels, board counts,
    then ``pass_at_k``, ``avg_succ``, ``avg_prog`` and ``avg_eff`` formatted
    with two decimals.

    Args:
        title: Headline text, normally the folder name.
        files_n: Number of files read.
        bad_records: Number of skipped records.
        cnt, pass_at_k, avg_succ, avg_prog, avg_eff: Per-category sequences
            indexable by every index in ``OUT_ORDER``.
        add_info_set: Optional iterable of strings; when given, its values
            are printed in sorted order.
    """
    headline = f"=== {title} ==="
    if title != "FINAL_MERGED":
        print_blue(headline)
    else:
        print(headline)
    print(f"files={files_n}\tbad_records={bad_records}\ttotal_boards={sum(cnt)}")

    if add_info_set is not None:
        # Sorted so the line is identical from run to run; iterating a set of
        # strings directly follows hash order, which changes per process.
        print(f"add_info values: {', '.join(sorted(add_info_set))}")

    print("\t".join(CATS[i][0] for i in OUT_ORDER))
    print("\t".join(str(cnt[i]) for i in OUT_ORDER))
    for series in (pass_at_k, avg_succ, avg_prog, avg_eff):
        print("\t".join(f"{series[i]:.2f}" for i in OUT_ORDER))

def check_expected(folder_name: str, cnt):
    """Warn in red on stdout when a folder's counts differ from the expected row.

    Folders without an entry in ``EXPECTED_COUNTS_BY_FOLDER`` are ignored.
    The comparison uses ``cnt`` rearranged into output order.
    """
    expected = EXPECTED_COUNTS_BY_FOLDER.get(folder_name)
    if expected is None:
        return

    real = out_counts_in_out_order(cnt)
    if real == expected:
        return

    def tab_row(values):
        return "\t".join(str(value) for value in values)

    print_red(f"WARNING: count mismatch for {folder_name}")
    print_red("expected:\t" + tab_row(expected))
    print_red("real:\t\t" + tab_row(real))

# -----------------------------
# NEW UTILITY: max k assertor
# variant category is an exception for max_k checking
# -----------------------------
def _cat_label(ci: int) -> str:
    """Return the label of category ``ci``, or ``"cat_<ci>"`` if the lookup fails."""
    # Each CATS entry is (name, lo, hi).
    try:
        entry = CATS[ci]
        label = entry[0]
    except Exception:
        label = f"cat_{ci}"
    return label

def assert_max_k_is_4(runs_by_cat_board, *, context: str = ""):
    """Require that the largest number of runs for any board is exactly 4.

    The category named "variant" is excluded from the check. On violation the
    per-category maxima (variant included) are printed in red, following
    ``OUT_ORDER``, and a ValueError is raised.

    Args:
        runs_by_cat_board: One ``board -> list of runs`` mapping per category.
        context: Optional text appended to the violation header.

    Raises:
        ValueError: if the maximum over the non-variant categories is not 4.
    """
    variant_idx = next(
        (idx for idx, cat in enumerate(CATS) if cat[0] == "variant"),
        None,
    )

    # Largest run count per category; 0 for a category without boards.
    per_cat_maxk = [
        max((len(runs) for runs in cat_runs.values()), default=0)
        for cat_runs in runs_by_cat_board
    ]
    global_maxk = max(
        (mk for ci, mk in enumerate(per_cat_maxk) if ci != variant_idx),
        default=0,
    )
    if global_maxk == 4:
        return

    print_red(f"[max_k violation]{' ' + context if context else ''}")
    summary = ", ".join(f"{_cat_label(i)}={per_cat_maxk[i]}" for i in OUT_ORDER)
    print_red("per-category max k: " + summary)

    raise ValueError(f"Invalid max k detected (except variant category): max k = {global_maxk}, expected exactly 4")

# -----------------------------
# Main
# -----------------------------
def main():
    """Run the whole report: select folders, aggregate each, then the merge.

    Steps, in order: parse arguments and rebuild the category globals for the
    requested mode; derive the add_info mode from the ``results_dir`` text;
    select and validate the sub-folders; load the hard-board set; print one
    block per folder (with the max-k and expected-count checks); finally
    print the ``FINAL_MERGED`` block over all folders combined.
    """
    global CATS, CAT_HARD_IDX, CAT_VARIANT_IDX, BIN_OFFSET, OUT_ORDER, HARD_BIN_OFFSET

    args = parse_args()

    # Rebuild the category layout for the requested mode.
    (
        CATS,
        CAT_HARD_IDX,
        CAT_VARIANT_IDX,
        BIN_OFFSET,
        OUT_ORDER,
        HARD_BIN_OFFSET,
    ) = _build_cats(args.for_horizon_generalization)

    # The add_info constraint is chosen from the results_dir text.
    rd = str(args.results_dir)
    names_one_act = "one_act" in rd
    names_multi_act = "multi_act" in rd
    if names_one_act and names_multi_act:
        raise ValueError(f'results_dir contains both "one_act" and "multi_act": {rd}')
    if names_one_act:
        act_mode = "one_act"
    elif names_multi_act:
        act_mode = "multi_act"
    else:
        act_mode = None

    subdirs = list_target_subdirs(args.results_dir, args.folder_name)
    if not subdirs:
        names = ", ".join(sorted(EXPECTED_COUNTS_BY_FOLDER.keys()))
        raise ValueError(
            f"No matching subdirs found under {args.results_dir}. "
            f"folder_name={args.folder_name!r}. "
            f"Expected one of: {names}"
        )

    _validate_dir_names([folder for folder, _path in subdirs])

    hard_set = load_hard_board_set(args.hard_repo, args.hard_split)

    # Accumulators for the merged report.
    merged_runs = [defaultdict(list) for _ in range(len(CATS))]
    total_files = 0
    total_bad = 0
    enforce_maxk = not args.ignore_maxk

    for name, path in subdirs:
        files, runs, bad, add_info_set = process_one_dir(
            path, args.prefix, hard_set, act_mode=act_mode
        )

        # max k must be 4 (variant category excepted) unless disabled.
        if enforce_maxk:
            assert_max_k_is_4(runs, context=f"folder={name}")

        stats = aggregate(runs)
        print_block(name, len(files), bad, *stats, add_info_set=add_info_set)

        # The expected-count table only applies to the original layout.
        if not args.for_horizon_generalization:
            check_expected(name, stats[0])
        print()

        merge_runs(merged_runs, runs)
        total_files += len(files)
        total_bad += bad

    # The merged runs are held to the same max-k rule.
    if enforce_maxk:
        assert_max_k_is_4(merged_runs, context="folder=FINAL_MERGED")

    merged_stats = aggregate(merged_runs)
    print_block("FINAL_MERGED", total_files, total_bad, *merged_stats)

# Run the report only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
