#!/usr/bin/env python
"""Comprehensive analysis (v2) for GRPO TicTacToe inference results.

Updated Key Features (v2.1 modifications in this commit):
* Simultaneous loading of ASCII and non-ASCII representation result trees.
* Robust filename parsing (model, rep mode, dataset variant, checkpoint).
* Aggregation of full inference results (all checkpoints) with metadata: ascii_flag, random_xy flag, training_group, experiment_mode.
* Wilson confidence intervals + derived metrics (precision under pressure, distraction index).
* Best checkpoint selection now groups by: (model_name, representation_mode, training_group, dataset_variant, experiment_mode, random_xy_moves, ascii_board) with tie-break on higher checkpoint when accuracies equal.
* Plots now separated by dataset_variant (e.g., random_80_10_10 vs canconical-symmetry-grouping) to avoid mixed grouping confusion.
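* Open spaces curves: the ply-based view maps raw ply indices directly (open_spaces = 9 - ply, clipped to 0..9) and the board-based view counts empty cells parsed from individual_outputs; the earlier workaround for missing 8 & 9 (offsetting by the earliest observed ply) is no longer applied.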
* Minimax outcome heatmaps include fixed expected score bins [-4,-2,0,1,3,5] and are generated per dataset_variant & experiment_mode.
* Progression plot baselines (majority legal move proportion & per-model off-the-shelf) made optional via CLI flags (default disabled) and legend repositioned.
* Outcome score heatmaps and open space curves output nested directories per dataset_variant.
"""
from __future__ import annotations
import os
import re
import json
import math
import argparse
from dataclasses import dataclass
from typing import Dict, Any, List, Optional, Tuple

# Added missing imports (safe even if already present in later fuller version)
import os, json, math, argparse, re
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D

# ----------------------------- Config ----------------------------------
REP_MODES = ["nl", "special"]
# Buckets of "number of legal moves" compared by compute_distraction_index
# (few choices vs. many choices).
LOW_CHOICE_SET = {1, 2, 3}
HIGH_CHOICE_SET = {6, 7, 8}
RANDOM_MOVE_UNIVERSE_SIZE = 19  # moves 1..18 plus None token

# ----------------------------- Data Classes ----------------------------
@dataclass
class ParsedFilename:
    """Components recovered from a ``*_results.json`` inference filename."""

    model_name: str
    representation_mode: str
    dataset_variant: str  # e.g. random_80_10_10, canconical-symmetry-grouping, off_the-shelf, etc.
    checkpoint_step: int

# ----------------------------- Parsing ---------------------------------

def parse_result_filename(filename: str, rep_modes: List[str] = REP_MODES) -> Optional[ParsedFilename]:
    """Parse an inference result filename into its components.

    Supported shapes (``<left>-checkpoint-<int>_results.json``):

        Qwen_Qwen2.5-0.5B-Instruct_nl_random_80_10_10-checkpoint-1050_results.json
        meta-llama_Llama-3.2-1B-Instruct_nl_canconical-symmetry-grouping-checkpoint-1050_results.json
        model_special-random-checkpoint-5_results.json   (mode followed by '-')

    Parameters
    ----------
    filename : str
        Bare file name (no directory).
    rep_modes : list of str
        Representation-mode tokens to anchor on, tried in order.

    Returns
    -------
    ParsedFilename or None
        None when the suffix, the checkpoint marker, the integer checkpoint or
        a representation-mode anchor is missing.
    """
    suffix = '_results.json'
    if not filename.endswith(suffix):
        return None
    base = filename[:-len(suffix)]
    if '-checkpoint-' not in base:
        return None
    left, ckpt_str = base.rsplit('-checkpoint-', 1)
    try:
        ckpt = int(ckpt_str)
    except ValueError:
        return None

    # Primary layout: "<model>_<mode>_<variant>"; modes are tried in rep_modes order.
    for mode in rep_modes:
        anchor = f"_{mode}_"
        if anchor in left:
            model_part, rest = left.split(anchor, 1)
            return ParsedFilename(model_name=model_part,
                                  representation_mode=mode,
                                  dataset_variant=rest,
                                  checkpoint_step=ckpt)

    # Fallback layout: mode token followed by '-' (e.g. "model_special-random").
    # The previous regex required "_<mode>_" too, so it could never match here.
    if not rep_modes:
        return None
    alternation = '|'.join(re.escape(m) for m in rep_modes)
    match = re.match(rf"^(?P<model>.+?)_(?P<mode>{alternation})-(?P<variant>.+)$", left)
    if match:
        return ParsedFilename(model_name=match.group('model'),
                              representation_mode=match.group('mode'),
                              dataset_variant=match.group('variant'),
                              checkpoint_step=ckpt)
    return None

# ----------------------------- JSON Extraction -------------------------

def safe_get(d: Dict[str, Any], *path, default=None):
    """Follow ``path`` through nested dicts and return the value found.

    Any hop whose container is not a dict, or lacks the key, yields
    ``default`` instead of raising. An empty ``path`` returns ``d`` itself.
    """
    node = d
    for key in path:
        if isinstance(node, dict) and key in node:
            node = node[key]
        else:
            return default
    return node

# Key paths probed (in order) by extract_overall_accuracy; the first numeric
# hit wins. Values may be fractions or percents depending on the schema.
POSSIBLE_OVERALL_KEYS = [
    ("overall_accuracy",),
    ("overall", "accuracy"),
    ("summary", "overall_accuracy"),
    ("overall_stats", "accuracy_percent"),
]
# Candidate correct/total locations. NOTE(review): extract_correct_total below
# hard-codes its own lookups and does not read this list; ("overall",
# "correct_total") is an unverified guess at a schema.
POSSIBLE_CORRECT_TOTAL_KEYS = [
    ("overall_correct", "overall_total"),
    ("overall", "correct_total"),
    ("summary", "overall_correct"),
    ("overall_stats", "correct_predictions"),
]

# Canonical fine-grained stat name -> legacy top-level keys that may hold it.
FINE_GRAINED_KEYS = {
    'by_game_ply': ['by_game_ply', 'accuracy_by_game_ply'],
    'by_num_legal_moves': ['by_num_legal_moves', 'accuracy_by_num_legal_moves'],
    'by_num_best_moves': ['by_num_best_moves', 'accuracy_by_num_best_moves'],
    'by_minimax_outcome_score': ['by_minimax_outcome_score', 'accuracy_by_minimax_outcome_score'],
}

def extract_overall_accuracy(result: Dict[str, Any]) -> Optional[float]:
    """Return the first numeric accuracy found along POSSIBLE_OVERALL_KEYS.

    The value is returned as stored, with no normalisation: it may be a
    fraction (0..1) or a percent (e.g. overall_stats.accuracy_percent).
    Callers apply their own ">1 means percent" heuristic. Returns None when
    no candidate path holds an int/float.
    """
    found = (safe_get(result, *path) for path in POSSIBLE_OVERALL_KEYS)
    return next((float(v) for v in found if isinstance(v, (int, float))), None)


def extract_correct_total(result: Dict[str, Any]) -> Tuple[Optional[int], Optional[int]]:
    """Return ``(correct, total)`` sample counts from a result dict.

    Explicit integer pairs are tried in schema order: overall_stats, then
    legacy top-level keys, then ``summary``. Failing that, ``correct`` is
    estimated from the overall accuracy (treated as a percent when > 1) times
    the first truthy ``total_samples`` found. Returns ``(None, None)`` if
    nothing usable exists.
    """
    explicit_pairs = (
        (('overall_stats', 'correct_predictions'), ('overall_stats', 'total_samples')),
        (('overall_correct',), ('overall_total',)),
        (('summary', 'overall_correct'), ('summary', 'overall_total')),
    )
    for correct_path, total_path in explicit_pairs:
        n_correct = safe_get(result, *correct_path)
        n_total = safe_get(result, *total_path)
        if isinstance(n_correct, int) and isinstance(n_total, int):
            return n_correct, n_total

    accuracy = extract_overall_accuracy(result)
    n_samples = (safe_get(result, 'overall_stats', 'total_samples')
                 or safe_get(result, 'total_samples')
                 or safe_get(result, 'summary', 'total_samples'))
    if accuracy is None or not isinstance(n_samples, int):
        return None, None
    fraction = accuracy / 100 if accuracy > 1 else accuracy
    return int(round(fraction * n_samples)), n_samples


def extract_fine_grained(result: Dict[str, Any]) -> Dict[str, Dict[str, Dict[str, int]]]:
    """Collect per-bucket stats keyed by canonical FINE_GRAINED_KEYS names.

    The newer schema nests them under ``fine_grained_stats``. Only if that
    yields nothing are the legacy top-level candidate keys scanned, taking
    the first dict-valued candidate for each canonical name.
    """
    nested = safe_get(result, 'fine_grained_stats')
    if isinstance(nested, dict):
        picked = {name: nested[name] for name in FINE_GRAINED_KEYS
                  if isinstance(nested.get(name), dict)}
        if picked:
            return picked

    picked = {}
    for name, legacy_keys in FINE_GRAINED_KEYS.items():
        hit = next((value for value in (safe_get(result, k) for k in legacy_keys)
                    if isinstance(value, dict)), None)
        if hit is not None:
            picked[name] = hit
    return picked

# ----------------------------- Statistics ------------------------------

def wilson_ci(correct: int, total: int, z: float = 1.96) -> Tuple[float, float, float]:
    """Wilson score interval for a binomial proportion.

    Returns ``(p_hat, p_hat - lower, upper - p_hat)``: the observed
    proportion and the distances down/up to the interval bounds, which is the
    form asymmetric error bars expect. ``z`` defaults to 1.96 (about 95%).
    Returns three NaNs when ``total <= 0``.
    """
    if total <= 0:
        return math.nan, math.nan, math.nan
    n = total
    p_hat = correct / n
    z_sq = z**2
    scale = 1 + z_sq / n
    midpoint = p_hat + z_sq / (2*n)
    half_width = z * math.sqrt(p_hat*(1-p_hat)/n + z_sq/(4*n**2))
    lower_bound = (midpoint - half_width) / scale
    upper_bound = (midpoint + half_width) / scale
    return p_hat, p_hat - lower_bound, upper_bound - p_hat

# ----------------------------- Loading ---------------------------------

def _random_xy_dirs(em_path: str, experiment_mode: str) -> List[str]:
    """Return the random_xy directories to scan beneath one experiment-mode folder.

    Two layouts are handled:
    * normal:        <em_path>/random_xy_moves_<flag>/
    * repeated mode: <em_path>/<experiment_mode>/<dir>/  (some off-the-shelf
      dumps, e.g. legal_move/legal_move/random_xy_moves_False)

    Only directories are returned; plain files inside the repeated folder
    previously reached os.listdir and raised NotADirectoryError. The result
    is sorted for reproducible record order and de-duplicated.
    """
    paths: List[str] = []
    nested_repeat = os.path.join(em_path, experiment_mode)
    if os.path.isdir(nested_repeat):
        for name in sorted(os.listdir(nested_repeat)):
            candidate = os.path.join(nested_repeat, name)
            if os.path.isdir(candidate):
                paths.append(candidate)
    for name in sorted(os.listdir(em_path)):
        candidate = os.path.join(em_path, name)
        if name.startswith('random_xy_moves_') and os.path.isdir(candidate):
            paths.append(candidate)
    # dict.fromkeys keeps first-seen order while dropping duplicates.
    return list(dict.fromkeys(paths))


def _random_xy_flag(dirname: str) -> Optional[bool]:
    """Map a directory name ending in 'True'/'False' to a bool; otherwise None."""
    if dirname.endswith('True'):
        return True
    if dirname.endswith('False'):
        return False
    return None


def _resolve_accuracy(data: Any) -> Tuple[Optional[float], Optional[int], Optional[int], Optional[float], Optional[float]]:
    """Return (overall_acc_percent, correct, total, err_low, err_high).

    Correct/total counts take priority and yield Wilson CI half-widths as
    fractions. Otherwise the reported accuracy is heuristically rescaled to a
    percent and the CI fields stay None.
    """
    reported_acc = extract_overall_accuracy(data)
    correct, total = extract_correct_total(data)
    overall_acc = err_low = err_high = None
    if correct is not None and total is not None:
        p, err_low, err_high = wilson_ci(correct, total)
        overall_acc = p * 100.0
    elif reported_acc is not None:
        overall_acc = reported_acc if reported_acc <= 100 else reported_acc / 100.0
        if overall_acc <= 1:
            overall_acc *= 100.0
    return overall_acc, correct, total, err_low, err_high


def _individual_outputs(data: Any, fine: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Pick per-board outputs from the top level, else from ``fine``; dicts only."""
    indiv = data.get('individual_outputs') if isinstance(data, dict) else None
    if not indiv and isinstance(fine, dict):
        for key in ('individual_outputs', 'per_board_outputs'):
            cand = fine.get(key)
            if isinstance(cand, dict):
                indiv = cand
                break
    return indiv if isinstance(indiv, dict) else None


def _load_result_record(full_path: str, fname: str, training_group: str, experiment_mode: str,
                        random_xy_flag: Optional[bool], ascii_flag: bool) -> Optional[Dict[str, Any]]:
    """Parse, load and summarise one *_results.json file; None if it must be skipped."""
    parsed = parse_result_filename(fname)
    # Off-the-shelf files have no checkpoint in the name: "<model>_results.json".
    off_the_shelf_candidate = (parsed is None
                               and training_group.lower().startswith('off_the')
                               and '-checkpoint-' not in fname)
    if parsed is None and not off_the_shelf_candidate:
        return None  # unparseable non-baseline file; no need to read it
    try:
        with open(full_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:  # ValueError covers JSON/Unicode decode errors
        print(f"[warn] Skipping unreadable result file {full_path}: {exc}")
        return None

    baseline_no_ckpt = parsed is None
    if baseline_no_ckpt:
        # Infer model & representation_mode from metadata when present.
        meta = data.get('metadata', {}) if isinstance(data, dict) else {}
        if not isinstance(meta, dict):
            meta = {}
        raw_model = meta.get('model_mark') or meta.get('model_checkpoint') or fname.replace('_results.json', '')
        parsed = ParsedFilename(model_name=raw_model,
                                representation_mode=meta.get('representation_mode', 'nl'),
                                dataset_variant='off_the_shelf',
                                checkpoint_step=0)

    overall_acc, correct, total, err_low, err_high = _resolve_accuracy(data)
    fine = extract_fine_grained(data)
    indiv_outputs = _individual_outputs(data, fine)
    return {
        'filepath': full_path,
        'model_name': clean_model_name(parsed.model_name),
        'raw_model_name': parsed.model_name,
        'representation_mode': parsed.representation_mode,
        'dataset_variant': parsed.dataset_variant,
        'checkpoint_step': parsed.checkpoint_step,
        'training_group': training_group,
        'experiment_mode': experiment_mode,
        'random_xy_moves': random_xy_flag,
        'ascii_board': ascii_flag,
        'overall_accuracy': overall_acc,
        'overall_correct': correct,
        'overall_total': total,
        'ci_err_low': None if err_low is None else abs(err_low * 100.0),
        'ci_err_high': None if err_high is None else abs(err_high * 100.0),
        'fine_grained_stats': fine,
        'individual_outputs': indiv_outputs,
        'individual_outputs_size': None if indiv_outputs is None else len(indiv_outputs),
        'is_zero_shot': detect_zero_shot(training_group, parsed.dataset_variant, parsed.model_name),
        'is_off_the_shelf_file': baseline_no_ckpt,
    }


def walk_result_tree(root_dir: str, ascii_flag: bool) -> List[Dict[str, Any]]:
    """Collect one record per result file under ``root_dir``.

    Expected layout::

        <root_dir>/<training_group>/<experiment_mode>/[<experiment_mode>/]<random_xy dir>/*_results.json

    Each record carries the parsed filename fields, the directory metadata,
    ``ascii_board=ascii_flag``, overall accuracy (percent) with Wilson CI
    error bars in percent, fine-grained stats and individual outputs.
    Returns an empty list if ``root_dir`` is not a directory. Unreadable JSON
    files are reported with a warning and skipped.
    """
    collected: List[Dict[str, Any]] = []
    if not os.path.isdir(root_dir):
        return collected
    for training_group in sorted(os.listdir(root_dir)):
        tg_path = os.path.join(root_dir, training_group)
        if not os.path.isdir(tg_path):
            continue
        for experiment_mode in sorted(os.listdir(tg_path)):
            em_path = os.path.join(tg_path, experiment_mode)
            if not os.path.isdir(em_path):
                continue
            for rx_path in _random_xy_dirs(em_path, experiment_mode):
                random_xy_flag = _random_xy_flag(os.path.basename(rx_path))
                for fname in sorted(os.listdir(rx_path)):
                    if not fname.endswith('_results.json'):
                        continue
                    record = _load_result_record(os.path.join(rx_path, fname), fname,
                                                 training_group, experiment_mode,
                                                 random_xy_flag, ascii_flag)
                    if record is not None:
                        collected.append(record)
    return collected

# ----------------------------- Helpers ---------------------------------

def clean_model_name(name: str) -> str:
    """Normalise a raw model identifier into a short display name.

    Examples: 'meta-llama_Llama-3.2-1B-Instruct' -> 'Llama-3.2-1B-Instruct',
    'Qwen_Qwen2.5-0.5B-Instruct' -> 'Qwen-Qwen2.5-0.5B-Instruct'.
    Rewrites are applied in a fixed order; each takes the previous output.
    """
    rewrites = (
        # Collapse double underscores (single non-overlapping pass).
        lambda s: s.replace('__', '_'),
        # Drop a leading 'meta-llama' vendor prefix, any case, any -/_ separators.
        lambda s: re.sub(r'^meta-llama[-_]+', '', s, flags=re.IGNORECASE),
        # Collapse a run of hyphens after a leading 'Llama' to a single hyphen.
        lambda s: re.sub(r'^Llama-+', 'Llama-', s),
        # Remove duplicated vendor prefix anywhere, e.g. 'Llama-Llama-'.
        lambda s: s.replace('Llama-Llama-', 'Llama-'),
        # Qwen org/model separator: 'Qwen_' -> 'Qwen-'.
        lambda s: s.replace('Qwen_', 'Qwen-'),
        # Final pass: strip any remaining case-sensitive 'meta-llama' token.
        lambda s: s.replace('meta-llama-', '').replace('meta-llama_', ''),
    )
    result = name
    for rewrite in rewrites:
        result = rewrite(result)
    return result

def detect_off_the_shelf(training_group: str, dataset_variant: str, model_name: str) -> bool:
    """Detect off-the-shelf (baseline) models.

    These were previously labeled 'zero shot'. A row is a baseline when its
    training group, dataset variant or model name contains an off-the-shelf
    token. Case and '-' vs '_' separators are ignored, so 'off-the-shelf',
    'off_the-shelf', 'off_the_shelf' and 'offtheshelf' all match. The fully
    hyphenated spelling used to be missed.
    """
    combo = f"{training_group}|{dataset_variant}|{model_name}".lower().replace('-', '_')
    return 'off_the_shelf' in combo or 'offtheshelf' in combo

# Backwards compatibility: keep name expected by older code paths if any.
def detect_zero_shot(training_group: str, dataset_variant: str, model_name: str) -> bool:  # noqa: D401
    """Legacy alias of detect_off_the_shelf ('zero shot' was the former label)."""
    is_baseline = detect_off_the_shelf(training_group, dataset_variant, model_name)
    return is_baseline

# ----------------------------- Derived Metrics -------------------------

def compute_precision_under_pressure(row: pd.Series) -> Optional[float]:
    """Accuracy (percent) on states that have exactly one best move.

    Reads ``row['fine_grained_stats']['by_num_best_moves']['1']`` and accepts
    the count key spellings correct/num_correct/correct_count and
    total/num_total/count. The first key present with a non-None value wins.

    Returns None when stats are missing or not a dict (e.g. NaN from a merged
    frame), or when the bucket or a positive total is missing. A bucket with
    zero correct answers returns 0.0; the old ``or`` chain returned None.
    """
    stats = row.get('fine_grained_stats')
    if not isinstance(stats, dict):
        return None
    by_nbm = stats.get('by_num_best_moves')
    if not isinstance(by_nbm, dict):
        return None
    bucket = by_nbm.get('1')
    if not isinstance(bucket, dict) or not bucket:
        return None
    correct = next((bucket[k] for k in ('correct', 'num_correct', 'correct_count')
                    if bucket.get(k) is not None), None)
    total = next((bucket[k] for k in ('total', 'num_total', 'count')
                  if bucket.get(k) is not None), None)
    if correct is None or not total:
        return None
    return 100 * correct / total

def compute_distraction_index(row: pd.Series,
                              low_choices: Optional[set] = None,
                              high_choices: Optional[set] = None) -> Optional[float]:
    """Accuracy on few-choice states minus accuracy on many-choice states.

    Buckets come from ``row['fine_grained_stats']['by_num_legal_moves']``,
    whose keys are parsed as ints (unparseable keys and non-dict buckets are
    skipped). Correct/total counts are pooled over the low set and the high
    set. Each pooled accuracy is a percent, and the result is ``low - high``.

    Parameters
    ----------
    row : mapping-like (pd.Series or dict)
    low_choices, high_choices : iterable of int, optional
        Legal-move counts forming each side; default LOW_CHOICE_SET and
        HIGH_CHOICE_SET.

    Returns None when stats are missing or not a dict (e.g. NaN) or when
    either side has no samples.
    """
    low_set = LOW_CHOICE_SET if low_choices is None else set(low_choices)
    high_set = HIGH_CHOICE_SET if high_choices is None else set(high_choices)
    stats = row.get('fine_grained_stats')
    if not isinstance(stats, dict):
        return None
    by_nlm = stats.get('by_num_legal_moves')
    if not isinstance(by_nlm, dict) or not by_nlm:
        return None

    def first_count(bucket, keys):
        # First key with a non-None value; an explicit 0 is kept, not skipped.
        return next((bucket[k] for k in keys if bucket.get(k) is not None), 0)

    def pooled_accuracy(selected):
        c = t = 0
        for key, bucket in by_nlm.items():
            if not isinstance(bucket, dict):
                continue
            try:
                n = int(key)
            except (TypeError, ValueError):
                continue
            if n in selected:
                c += first_count(bucket, ('correct', 'num_correct'))
                t += first_count(bucket, ('total', 'num_total'))
        return (100 * c / t) if t else None

    low = pooled_accuracy(low_set)
    high = pooled_accuracy(high_set)
    if low is None or high is None:
        return None
    return low - high

def explode_ply_stats(df: pd.DataFrame) -> pd.DataFrame:
    """Explode per-ply accuracy stats into one row per (result row, ply).

    Recorded ply indices are used raw (no re-basing). Open spaces are mapped
    directly as ``9 - ply`` for a 3x3 board and clipped to 0..9. Buckets with
    zero total get NaN accuracy. Rows whose ``fine_grained_stats`` is not a
    dict (e.g. NaN after a merge) and non-dict buckets are skipped. The
    result always has the documented columns, even when empty.
    """
    columns = ['model_name', 'training_group', 'experiment_mode', 'representation_mode',
               'dataset_variant', 'ascii_board', 'random_xy_moves', 'checkpoint_step',
               'ply', 'open_spaces', 'ply_accuracy']
    rows = []
    for _, r in df.iterrows():
        stats = r.get('fine_grained_stats')
        if not isinstance(stats, dict):
            continue
        by_ply = stats.get('by_game_ply')
        if not isinstance(by_ply, dict):
            continue
        for ply, vals in by_ply.items():
            if not isinstance(vals, dict):
                continue
            try:
                ply_i = int(ply)
            except (TypeError, ValueError):
                continue
            correct = vals.get('correct') or vals.get('num_correct') or 0
            total = vals.get('total') or vals.get('num_total') or 0
            acc = 100 * correct / total if total else np.nan
            rows.append({
                'model_name': r.model_name,
                'training_group': r.training_group,
                'experiment_mode': r.experiment_mode,
                'representation_mode': r.representation_mode,
                'dataset_variant': r.dataset_variant,
                'ascii_board': r.ascii_board,
                'random_xy_moves': r.random_xy_moves,
                'checkpoint_step': r.checkpoint_step,
                'ply': ply_i,
                'open_spaces': min(max(9 - ply_i, 0), 9),
                'ply_accuracy': acc,
            })
    return pd.DataFrame(rows, columns=columns)

def explode_open_space_stats(df: pd.DataFrame, mode: str = 'empties') -> pd.DataFrame:
    """Explode accuracy by open spaces using the raw individual_outputs board keys.

    Parameters
    ----------
    df : DataFrame
        Source results (full or best checkpoints subset).
    mode : {'empties','inverted'}
        'empties'  -> open_spaces = number of empty cells (zeros)  (0..9) [standard]
        'inverted' -> open_spaces = 9 - number_of_empties          (legacy earlier impl)
        The mode was previously accepted but ignored; it is now applied.

    Each row in df may have 'individual_outputs', falling back to
    ``fine_grained_stats['individual_outputs']``:
        { "[b0, b1, ..., b8]": [ {"is_correct": bool, ...}, ... ], ... }
    Board keys are parsed with ast.literal_eval (literals only). Unparseable
    keys, non-list attempt lists and attempts whose is_correct is not a bool
    are ignored.

    Returns columns (present even when the result is empty):
        model_name, training_group, experiment_mode, representation_mode,
        dataset_variant, ascii_board, random_xy_moves, checkpoint_step,
        open_spaces, correct, total, open_space_accuracy

    Raises
    ------
    ValueError
        If ``mode`` is not 'empties' or 'inverted'.
    """
    import ast

    if mode not in ('empties', 'inverted'):
        raise ValueError(f"mode must be 'empties' or 'inverted', got {mode!r}")
    columns = ['model_name', 'training_group', 'experiment_mode', 'representation_mode',
               'dataset_variant', 'ascii_board', 'random_xy_moves', 'checkpoint_step',
               'open_spaces', 'correct', 'total', 'open_space_accuracy']
    out_rows = []
    for _, r in df.iterrows():
        indiv = r.get('individual_outputs')
        if not isinstance(indiv, dict) or not indiv:
            fine = r.get('fine_grained_stats')
            indiv = fine.get('individual_outputs') if isinstance(fine, dict) else None
        if not isinstance(indiv, dict) or not indiv:
            continue
        agg = {}  # open_spaces -> [correct, total]
        for board_key, attempts in indiv.items():
            if not isinstance(attempts, list):
                continue
            try:
                cells = ast.literal_eval(board_key)
                empties = sum(1 for v in cells if v == 0)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                continue
            open_spaces = empties if mode == 'empties' else 9 - empties
            verdicts = [a.get('is_correct') for a in attempts if isinstance(a, dict)]
            c = sum(1 for v in verdicts if v is True)
            t = c + sum(1 for v in verdicts if v is False)  # None / missing are ignored
            if t == 0:
                continue
            prev_c, prev_t = agg.get(open_spaces, (0, 0))
            agg[open_spaces] = [prev_c + c, prev_t + t]
        for os_val, (c_sum, t_sum) in agg.items():
            out_rows.append({
                'model_name': r.model_name,
                'training_group': r.training_group,
                'experiment_mode': r.experiment_mode,
                'representation_mode': r.representation_mode,
                'dataset_variant': r.dataset_variant,
                'ascii_board': getattr(r, 'ascii_board', None),
                'random_xy_moves': getattr(r, 'random_xy_moves', None),
                'checkpoint_step': r.checkpoint_step,
                'open_spaces': os_val,
                'correct': c_sum,
                'total': t_sum,
                'open_space_accuracy': 100 * c_sum / t_sum if t_sum else 0,
            })
        print(f"[info] Processed individual_outputs for {r.model_name} ckpt {r.checkpoint_step}, found {len(agg)} open_spaces groups.")
    return pd.DataFrame(out_rows, columns=columns)

# ----------------------------- Best Checkpoints ------------------------

def select_best_checkpoints(df: pd.DataFrame) -> pd.DataFrame:
    """Keep the best checkpoint row per experimental configuration.

    Rows are grouped by model_name, representation_mode, training_group,
    dataset_variant, experiment_mode, random_xy_moves and ascii_board (only
    those columns present). Within a group the highest numeric
    overall_accuracy wins; ties go to the higher checkpoint_step. Rows whose
    accuracy is not numeric are dropped.

    Groups with a missing key (e.g. random_xy_moves=None) are kept
    (``dropna=False``); they were silently discarded before. Ranking uses the
    numeric accuracy even if the column is stored as strings, while the
    returned rows keep their original values.
    """
    if df.empty:
        return df
    base_group_cols = ['model_name','representation_mode','training_group','dataset_variant','experiment_mode','random_xy_moves','ascii_board']
    # Only keep grouping columns that actually exist to avoid KeyErrors
    group_cols = [c for c in base_group_cols if c in df.columns]
    sort_key = '_overall_accuracy_numeric'
    work = df.assign(**{sort_key: pd.to_numeric(df['overall_accuracy'], errors='coerce')})
    work = work[work[sort_key].notnull()]
    if not group_cols:
        best = work.sort_values([sort_key, 'checkpoint_step'], ascending=[False, False]).head(1)
        return best.drop(columns=sort_key)
    # Higher accuracy first, then higher checkpoint_step as tie-breaker
    work = work.sort_values(group_cols + [sort_key, 'checkpoint_step'],
                            ascending=[True] * len(group_cols) + [False, False])
    best = work.groupby(group_cols, dropna=False, sort=False).head(1)
    return best.drop(columns=sort_key)

# ----------------------------- Plotting Utils --------------------------

def ensure_dir(path: str):
    """Create directory ``path`` and any missing parents; no-op if it exists."""
    os.makedirs(name=path, exist_ok=True)

def save_all_formats(png_path: str, formats: Optional[List[str]], **savefig_kwargs):
    """Save the current Matplotlib figure to multiple formats.

    Parameters
    ----------
    png_path : str
        Existing code constructs a .png path; the output base name is derived
        from it by dropping its extension.
    formats : list of str or None
        Format extensions (e.g. ['png', 'pdf']); a leading '.' is tolerated
        ('.pdf' previously produced 'base..pdf'). ``None`` saves one file
        using png_path's own extension, or 'png' if it has none. An empty list
        saves nothing.
    **savefig_kwargs : dict
        Extra kwargs passed to plt.savefig (e.g., bbox_inches='tight').

    Failures are reported per format and do not stop the remaining formats.
    """
    base, ext = os.path.splitext(png_path)
    if formats is None:
        formats = [ext.lstrip('.') or 'png']
    for fmt in formats:
        clean_fmt = fmt.lstrip('.')
        out = f"{base}.{clean_fmt}"
        try:
            plt.savefig(out, **savefig_kwargs)
            print(f"[saved] {out}")
        except Exception as e:  # best-effort: one bad backend/format must not abort the others
            print(f"[warn] Failed to save {out}: {e}")

def simplify_variant(variant: str) -> str:
    """Map raw dataset_variant strings to concise canonical folder names.

    Matching is case-insensitive, and the first rule whose substring appears
    wins:
    - contains 'random'            -> 'random'    (e.g. random_80_10_10)
    - contains 'cancon' or 'canon' -> 'canonical_symmetry' (tolerates the
      'canconical' misspelling)
    Anything else, e.g. off_the_shelf, is returned unchanged.
    """
    lowered = (variant or '').lower()
    rules = (
        (('random',), 'random'),
        (('cancon', 'canon'), 'canonical_symmetry'),
    )
    for needles, label in rules:
        if any(needle in lowered for needle in needles):
            return label
    return variant

def canonicalize_dataset_variant(variant: str) -> str:
    """Return a canonical, lower-cased form of a dataset_variant string.

    Some training-group result paths bake spurious suffixes into the variant,
    which would otherwise split one logical dataset in two. Steps:
    - strip surrounding whitespace and lower-case;
    - drop a trailing ``_best_move`` / ``_legal_move`` (also ``_bestmove`` /
      ``_legalmove``);
    - fix the 'canconical' misspelling to 'canonical';
    - collapse anything containing both 'off' and 'shelf' to 'off_the_shelf'.
    Non-string inputs are returned untouched. The core token survives, so
    simplify_variant can still map the result to a short folder name.
    """
    if not isinstance(variant, str):
        return variant
    normalized = re.sub(r'(?:_(best|legal)_?move)$', '', variant.strip().lower())
    normalized = normalized.replace('canconical', 'canonical')
    is_off_the_shelf = 'off' in normalized and 'shelf' in normalized
    return 'off_the_shelf' if is_off_the_shelf else normalized

def save_df(df: pd.DataFrame, path: str):
    """Write ``df`` to CSV at ``path`` (no index), best-effort.

    Prints '[saved] <path>' on success. Any exception is reported as a
    '[warn]' line instead of being raised.
    """
    try:
        df.to_csv(path, index=False)
    except Exception as e:
        print(f"[warn] Could not save {path}: {e}")
    else:
        print(f"[saved] {path}")

sns.set_theme(style='whitegrid')  # one global seaborn theme, applied at import time

# ----------------------------- Global Palette -------------------------
# Shared model_name -> colour lookup. It stays empty until init_model_colors()
# fills it; values are entries of seaborn palettes (presumably RGB tuples).
MODEL_COLOR_MAP: Dict[str, str] = dict()

# def init_model_colors(model_names: List[str]):
#     """Initialize a deterministic color map for all models.

#     Uses a seaborn color palette with cycling if more models than base colors.
#     Stable ordering is alphabetical on model name for reproducibility.
#     """
#     global MODEL_COLOR_MAP
#     if MODEL_COLOR_MAP:
#         return  # already initialized
#     base_palette = sns.color_palette('tab20') + sns.color_palette('Set3') + sns.color_palette('husl', 12)
#     uniq = sorted(set(model_names))
#     cmap = {}
#     for idx, name in enumerate(uniq):
#         cmap[name] = base_palette[idx % len(base_palette)]
#     MODEL_COLOR_MAP = cmap


def init_model_colors(model_names: List[str]):
    """Populate MODEL_COLOR_MAP once with a deterministic colour per model.

    Models are sorted alphabetically (after de-duplication) so the
    assignment is reproducible. With 1-10 models the accessible 'colorblind'
    palette is used. Otherwise tab20 + Set3 + husl(12) are concatenated and
    cycled. If the map is already non-empty the call is a no-op, so models
    first seen in a later call are not added.
    """
    global MODEL_COLOR_MAP
    if MODEL_COLOR_MAP:
        return

    ordered = sorted(set(model_names))
    count = len(ordered)
    if count and count <= 10:
        palette = sns.color_palette('colorblind', n_colors=count)
    else:
        palette = (sns.color_palette('tab20')
                   + sns.color_palette('Set3')
                   + sns.color_palette('husl', 12))
    # Modulo only matters in the fallback case, where models may outnumber colours.
    MODEL_COLOR_MAP = {name: palette[i % len(palette)] for i, name in enumerate(ordered)}

def get_model_palette(subset_models: List[str]) -> List[str]:
    """Colours for ``subset_models`` in order; unmapped models get 'gray'."""
    lookup = MODEL_COLOR_MAP
    return list(map(lambda model: lookup.get(model, 'gray'), subset_models))

# Individual plot functions (initial minimal implementations) -------------

def _compute_random_baseline(subset_full: pd.DataFrame, exp_mode: str) -> Optional[float]:
    """Uniform(19) baseline: choose uniformly among 19 tokens (moves 1..18 plus None).

    Probability(correct | state) = (#acceptable_labels)/19, where acceptable labels:
      - legal_move: number of legal moves (n); terminal (n=0) => only None is valid -> 1.
      - best_move: number of best moves (b); if b=0 treat as 0.
    We approximate using aggregated fine_grained_stats distributions.
    """
    if subset_full.empty:
        return None
    rep = None
    for r in subset_full.itertuples():
        fg = getattr(r, 'fine_grained_stats', {})
        if fg:
            rep = r
            break
    if rep is None:
        return None
    fg = rep.fine_grained_stats
    by_legal = fg.get('by_num_legal_moves', {}) or {}
    by_best = fg.get('by_num_best_moves', {}) or {}
    total_states = sum(v.get('total',0) for v in by_legal.values()) or sum(v.get('total',0) for v in by_best.values())
    if total_states == 0:
        return None
    acc_sum = 0.0
    if exp_mode == 'legal_move':
        for k,v in by_legal.items():
            try: n = int(k)
            except Exception: continue
            count = v.get('total',0)
            labels = 1 if n==0 else max(0,n)  # terminal => only None
            acc_sum += count * (labels / RANDOM_MOVE_UNIVERSE_SIZE)
    else:
        for k,v in by_best.items():
            try: b = int(k)
            except Exception: continue
            count = v.get('total',0)
            labels = max(0,b)
            acc_sum += count * (labels / RANDOM_MOVE_UNIVERSE_SIZE)
    return 100.0 * acc_sum / total_states if total_states else None

def _compute_off_the_shelf_baseline(best_subset: pd.DataFrame) -> Optional[float]:
    """Mean accuracy over off-the-shelf (baseline) best checkpoint rows.
    If none present, returns None silently (caller should skip drawing baseline).
    """
    if best_subset.empty:
        return None
    col = 'is_off_the_shelf' if 'is_off_the_shelf' in best_subset.columns else 'is_zero_shot'
    subset = best_subset[best_subset[col] == True]
    if subset.empty:
        return None
    vals = pd.to_numeric(subset['overall_accuracy'], errors='coerce').dropna()
    if vals.empty:
        return None
    return float(vals.mean())

def _compute_majority_move_baseline(subset_full: pd.DataFrame) -> Optional[float]:
    """Approximate majority-label baseline using fine_grained distributions.
    We approximate probability of most frequent outcome by taking max class proportion across aggregated bins.
    For legal_move: majority of (by_num_legal_moves) terminal vs non-terminal contributes fully; we fallback to overall accuracy if counts missing.
    For best_move: use by_num_best_moves distribution and assume model guessing the most frequent best-move cardinality then uniform within it -> expected accuracy = (max_count / total_states) * ( avg_best_moves_within_argmax / avg_legal_moves_within_argmax ). Simplified to proportion of argmax bucket when deeper info missing.
    """
    if subset_full.empty:
        return None
    # pick representative stats row
    rep = None
    for r in subset_full.itertuples():
        fg = getattr(r, 'fine_grained_stats', {})
        if fg:
            rep = r
            break
    if rep is None:
        return None
    fg = rep.fine_grained_stats
    by_legal = fg.get('by_num_legal_moves', {})
    total = sum(v.get('total',0) for v in by_legal.values())
    if total > 0:
        # majority bucket proportion (upper bound baseline)
        maj = max((v.get('total',0) for v in by_legal.values()), default=0)
        return 100.0 * maj / total if total else None
    return None

def plot_overall_bar(best_df: pd.DataFrame, outdir: str, full_df: Optional[pd.DataFrame] = None, formats: Optional[List[str]] = None):
    """Grouped bar charts of best-checkpoint overall accuracy, one figure per split.

    One figure is produced for every (experiment_mode, dataset_variant, ascii_board,
    random_xy_moves) combination and saved to
    ``<outdir>/<simplified variant>/<experiment_mode>/overall_<mode>_ascii-<flag>_rand-<flag>.png``
    (plus any extra ``formats``). x = training_group, hue = model_name, with seaborn's
    95% CI where replicates exist (a single sample per bar shows no CI, as expected).

    When ``full_df`` is given, the best off-the-shelf checkpoint per
    (model_name, representation_mode) of the same mode/ASCII/random-xy split is added
    as an ``off_the_shelf`` group regardless of its dataset_variant, so it appears as a
    reference in every variant. Off-the-shelf / zero-shot bars are hatched.

    Reference lines: uniform-random baseline (dashed red), majority
    legal-move-count proportion (dotted purple, needs ``full_df``), and - only when
    no off-the-shelf bars are present - the off-the-shelf mean (dash-dot black) and
    per-model zero-shot baselines (dotted, in the model's bar colour).
    """
    if best_df.empty:
        return
    ensure_dir(outdir)
    plot_df = best_df.copy()
    # Guarantee numeric accuracies for seaborn aggregation.
    plot_df['overall_accuracy'] = pd.to_numeric(plot_df['overall_accuracy'], errors='coerce')
    for exp_mode in sorted(plot_df['experiment_mode'].unique()):
        sub_mode = plot_df[plot_df['experiment_mode'] == exp_mode]
        for variant in sorted(sub_mode['dataset_variant'].unique()):
            variant_subset = sub_mode[sub_mode['dataset_variant'] == variant]
            for ascii_flag in sorted(variant_subset['ascii_board'].unique()):
                ascii_subset = variant_subset[variant_subset['ascii_board'] == ascii_flag]
                for rand_flag in sorted(ascii_subset['random_xy_moves'].dropna().unique()):
                    sub = ascii_subset[ascii_subset['random_xy_moves'] == rand_flag]
                    if sub.empty:
                        continue
                    # Each split is rendered and saved exactly once.
                    _plot_overall_bar_split(sub, full_df, outdir, formats, exp_mode, variant, ascii_flag, rand_flag)


def _overall_bar_off_group_mask(groups: pd.Series) -> pd.Series:
    """Boolean mask: training_group contains 'off_the' (case-insensitive); nulls count as False."""
    return groups.fillna('').astype(str).str.lower().str.contains('off_the', regex=False).astype(bool)


def _overall_bar_zero_shot_mask(df: pd.DataFrame) -> pd.Series:
    """Boolean mask of rows whose ``is_zero_shot`` or ``is_off_the_shelf`` equals True.

    Missing columns and null flags count as False, so callers never hit a KeyError.
    """
    mask = pd.Series(False, index=df.index)
    for col in ('is_zero_shot', 'is_off_the_shelf'):
        if col in df.columns:
            mask = mask | df[col].eq(True)
    return mask


def _overall_bar_off_the_shelf_rows(full_df: pd.DataFrame, exp_mode, ascii_flag, rand_flag, variant, columns) -> Optional[pd.DataFrame]:
    """Best off-the-shelf checkpoint per (model_name, representation_mode) for one split.

    dataset_variant is deliberately ignored when selecting; the returned rows are
    relabelled with ``variant`` and ``training_group='off_the_shelf'``. Accuracy is
    coerced to numeric before ranking (ties broken by the higher checkpoint_step).
    Only columns also present in ``columns`` are returned; ``None`` when nothing qualifies.
    """
    split = full_df[(full_df['experiment_mode'] == exp_mode)
                    & (full_df['ascii_board'] == ascii_flag)
                    & (full_df['random_xy_moves'] == rand_flag)]
    if split.empty:
        return None
    off = split[_overall_bar_off_group_mask(split['training_group'])].copy()
    off['overall_accuracy'] = pd.to_numeric(off['overall_accuracy'], errors='coerce')
    off = off.dropna(subset=['overall_accuracy'])
    if off.empty:
        return None
    group_cols = ['model_name', 'representation_mode']
    off = off.sort_values(group_cols + ['overall_accuracy', 'checkpoint_step'],
                          ascending=[True] * len(group_cols) + [False, False])
    best = off.groupby(group_cols).head(1).copy()
    best['training_group'] = 'off_the_shelf'
    # Force dataset_variant to the plotted variant so the bars land in this figure.
    best['dataset_variant'] = variant
    if 'is_zero_shot' not in best.columns and 'is_off_the_shelf' in best.columns:
        best['is_zero_shot'] = best['is_off_the_shelf']
    return best[columns.intersection(best.columns)]


def _overall_bar_hatch(ax, sub: pd.DataFrame, order: List[str], hue_order: List[str]) -> bool:
    """Hatch bars of off-the-shelf groups and of zero-shot-flagged (group, model) pairs.

    Each bar's x category is recovered from its centre (categories sit at 0..n-1),
    so the mapping does not depend on how many bars seaborn emitted or in which
    order. The bar's model comes from ``ax.containers``, assumed to hold one
    container per hue level in ``hue_order`` (NOTE(review): verify for the seaborn
    version in use); if the count differs, only group-level hatching is applied.
    Returns True when at least one bar was hatched.
    """
    off_groups = set(sub.loc[_overall_bar_off_group_mask(sub['training_group']), 'training_group'])
    zs_mask = _overall_bar_zero_shot_mask(sub)
    flagged = set(zip(sub.loc[zs_mask, 'training_group'], sub.loc[zs_mask, 'model_name']))
    containers = list(ax.containers)
    if len(containers) == len(hue_order):
        bars = [(model, patch) for model, container in zip(hue_order, containers) for patch in container]
    else:
        bars = [(None, patch) for patch in ax.patches]
    hatched = False
    for model, patch in bars:
        idx = int(round(patch.get_x() + patch.get_width() / 2))
        if not 0 <= idx < len(order):
            continue
        group = order[idx]
        if group in off_groups or (group, model) in flagged:
            patch.set_hatch('///')
            patch.set_edgecolor('black')
            hatched = True
    return hatched


def _overall_bar_model_colors(ax, hue_order: List[str]) -> Dict[str, Any]:
    """Map model name -> bar facecolour via ``ax.containers`` (empty dict if layout is unexpected)."""
    containers = list(ax.containers)
    if len(containers) != len(hue_order):
        return {}
    return {m: c[0].get_facecolor() for m, c in zip(hue_order, containers) if len(c)}


def _overall_bar_legend(ax, zero_shot_baselines: Dict[str, float], hatched: bool) -> None:
    """Place a de-duplicated legend (models + reference lines + hatch key) outside the axes."""
    handles, labels = ax.get_legend_handles_labels()
    handles, labels = list(handles), list(labels)
    # Keep seaborn's hue entries even when its bar artists carry no labels.
    old = ax.get_legend()
    if old is not None:
        old_handles = getattr(old, 'legend_handles', None) or getattr(old, 'legendHandles', None) or []
        handles += list(old_handles)
        labels += [t.get_text() for t in old.get_texts()]
    if zero_shot_baselines:
        handles.append(Line2D([0], [0], color='gray', linestyle=':', label='Per-model off-the-shelf baselines'))
        labels.append('Per-model off-the-shelf baselines')
    if hatched:
        handles.append(mpatches.Patch(facecolor='white', hatch='///', edgecolor='black', label='Off-the-shelf model bar'))
        labels.append('Off-the-shelf model bar')
    seen = set()
    new_h, new_l = [], []
    for h, lbl in zip(handles, labels):
        if lbl in seen:
            continue
        seen.add(lbl)
        new_h.append(h)
        new_l.append(lbl)
    ax.legend(new_h, new_l, bbox_to_anchor=(1.01, 1), loc='upper left', borderaxespad=0)


def _plot_overall_bar_split(sub: pd.DataFrame, full_df: Optional[pd.DataFrame], outdir: str, formats: Optional[List[str]],
                            exp_mode, variant, ascii_flag, rand_flag) -> None:
    """Draw and save the overall-accuracy bar chart for one (mode, variant, ASCII, random-xy) split."""
    # Trim training groups per mode requirements.
    if exp_mode == 'legal_move':
        trained_groups = ['legal_move']
    else:  # best_move
        trained_groups = ['best_move_from_base', 'best_move_from_legal']
    desired_order = trained_groups + ['off_the_shelf']
    sub = sub[sub['training_group'].isin(trained_groups)]
    if full_df is not None:
        off_rows = _overall_bar_off_the_shelf_rows(full_df, exp_mode, ascii_flag, rand_flag, variant, sub.columns)
        if off_rows is not None:
            sub = pd.concat([sub, off_rows], ignore_index=True)
    if sub.empty:
        return
    variant_simple = simplify_variant(variant)
    existing = set(sub['training_group'].unique())
    order = [g for g in desired_order if g in existing]
    hue_order = sorted(sub['model_name'].unique())
    if not hue_order:
        return
    dbg_counts = sub.groupby('training_group')['overall_accuracy'].count().to_dict()
    print(f"[overall_plot_debug] exp_mode={exp_mode} variant={variant} ascii={ascii_flag} rand={rand_flag} groups={dbg_counts}")

    plt.figure(figsize=(16, 8))
    ax = sns.barplot(
        data=sub,
        x='training_group',
        y='overall_accuracy',
        hue='model_name',
        order=order,
        hue_order=hue_order,
        palette=get_model_palette(hue_order),
        errorbar='ci',
    )
    hatched = _overall_bar_hatch(ax, sub, order, hue_order)
    has_off_bars = bool(_overall_bar_off_group_mask(sub['training_group']).any())

    rand_baseline = _compute_random_baseline(sub, exp_mode)
    if rand_baseline is not None:
        ax.axhline(rand_baseline, linestyle='--', color='red', linewidth=1.2, label='Uniform(19) random baseline')
    if full_df is not None:
        subset_full = full_df[(full_df['experiment_mode'] == exp_mode)
                              & (full_df['ascii_board'] == ascii_flag)
                              & (full_df['random_xy_moves'] == rand_flag)
                              & (full_df['dataset_variant'] == variant)]
        majority_baseline = _compute_majority_move_baseline(subset_full)
        if majority_baseline is not None:
            ax.axhline(majority_baseline, linestyle=':', color='purple', linewidth=1.2, label='Majority legal-move-count proportion')

    zero_shot_baselines: Dict[str, float] = {}
    if not has_off_bars:
        # Without off-the-shelf bars, fall back to horizontal reference lines.
        if {'is_off_the_shelf', 'is_zero_shot'} & set(sub.columns):
            off_mean = _compute_off_the_shelf_baseline(sub)
            if off_mean is not None:
                ax.axhline(off_mean, linestyle='-.', color='black', linewidth=1.2, label='Off-the-shelf mean baseline')
        zs_rows = sub[_overall_bar_zero_shot_mask(sub)]
        zs_means = (pd.to_numeric(zs_rows['overall_accuracy'], errors='coerce')
                    .groupby(zs_rows['model_name']).mean().dropna())
        zero_shot_baselines = {m: float(zs_means[m]) for m in hue_order if m in zs_means.index}
        model_color = _overall_bar_model_colors(ax, hue_order)
        for m, baseline_val in zero_shot_baselines.items():
            c = model_color.get(m, 'gray')
            ax.axhline(baseline_val, linestyle=':', color=c, linewidth=1.2)
            ax.text(len(order) - 0.4, baseline_val + 0.3, f'{m} off-the-shelf {baseline_val:.1f}%', color=c, fontsize=8, ha='right', va='bottom')

    plt.title(f"Overall Accuracy - {exp_mode} | {variant_simple} | ASCII={ascii_flag} | RandomXY={rand_flag}")
    plt.ylabel('Accuracy (%)')
    plt.ylim(0, 105)
    plt.xticks(rotation=25, ha='right')
    _overall_bar_legend(ax, zero_shot_baselines, hatched)
    plt.subplots_adjust(right=0.76)
    plt.tight_layout()
    vdir = os.path.join(outdir, variant_simple, exp_mode)
    ensure_dir(vdir)
    fp = os.path.join(vdir, f'overall_{exp_mode}_ascii-{ascii_flag}_rand-{rand_flag}.png')
    save_all_formats(fp, formats or ['png'])
    plt.close()

def plot_progression_all_models(full_df: pd.DataFrame, outdir: str, show_majority: bool = True, show_off_baselines: bool = True, formats: Optional[List[str]] = None):
    """Line plots of overall accuracy across checkpoints for every model and training group.

    One figure per (experiment_mode, dataset_variant, ascii_board, random_xy_moves)
    split, with one column facet per representation_mode (sorted). The figure is saved to
    ``<outdir>/<simplified variant>/<experiment_mode>/progression_<mode>_ascii-<flag>_rand-<flag>.png``.

    Off-the-shelf rows (training_group containing 'off_the') are not drawn as lines.
    When ``show_off_baselines`` is set, each one becomes a dash-dot horizontal line in
    the model's colour. Every facet also gets the uniform-random baseline and, when
    ``show_majority`` is set, the majority legal-move-count proportion.
    """
    if full_df.empty:
        return
    ensure_dir(outdir)
    prog_df = full_df.copy()
    prog_df['overall_accuracy'] = pd.to_numeric(prog_df['overall_accuracy'], errors='coerce')
    prog_df = prog_df.dropna(subset=['overall_accuracy'])
    for exp_mode in sorted(prog_df['experiment_mode'].unique()):
        sub_mode = prog_df[prog_df['experiment_mode'] == exp_mode]
        for variant in sorted(sub_mode['dataset_variant'].unique()):
            sub_variant = sub_mode[sub_mode['dataset_variant'] == variant]
            for ascii_flag in sorted(sub_variant['ascii_board'].unique()):
                sub_ascii = sub_variant[sub_variant['ascii_board'] == ascii_flag]
                for rand_flag in sorted(sub_ascii['random_xy_moves'].dropna().unique()):
                    sub = sub_ascii[sub_ascii['random_xy_moves'] == rand_flag]
                    if sub.empty:
                        continue
                    _plot_progression_split(sub, outdir, exp_mode, variant, ascii_flag, rand_flag,
                                            show_majority, show_off_baselines, formats)


def _progression_legend_entries(g) -> Tuple[list, list]:
    """Return (handles, labels) of the hue/style legend seaborn built for FacetGrid ``g``.

    The legend proxies are first read from the first facet (where the original code
    looked). If that yields nothing, they are read back from ``g._legend``.
    NOTE(review): which location holds them presumably depends on the seaborn version.
    """
    handles, labels = g.axes.flat[0].get_legend_handles_labels()
    legend = g._legend
    if not handles and legend is not None:
        handles = getattr(legend, 'legend_handles', None) or getattr(legend, 'legendHandles', None) or []
        labels = [t.get_text() for t in legend.get_texts()]
    return list(handles), list(labels)


def _plot_progression_split(sub: pd.DataFrame, outdir: str, exp_mode, variant, ascii_flag, rand_flag,
                            show_majority: bool, show_off_baselines: bool, formats: Optional[List[str]]) -> None:
    """Draw and save the progression figure for one (mode, variant, ASCII, random-xy) split."""
    # Bool-dtype mask (no object dtype), so `~` is a proper logical negation.
    off_mask = (sub['training_group'].fillna('').astype(str).str.lower()
                .str.contains('off_the', regex=False).astype(bool))
    sub_progress = sub[~off_mask]
    sub_off = sub[off_mask]
    if sub_progress.empty:
        return
    # Facets are built from this explicit order, so zip(axes, rep_modes) below is aligned.
    rep_modes_here = sorted(sub_progress['representation_mode'].dropna().unique())
    if not rep_modes_here:
        return
    variant_simple = simplify_variant(variant)
    g = sns.relplot(
        data=sub_progress,
        x='checkpoint_step', y='overall_accuracy',
        hue='model_name', style='training_group', col='representation_mode',
        col_order=rep_modes_here,
        palette=get_model_palette(sorted(sub_progress['model_name'].unique())),
        kind='line', markers=True, dashes=False,
        facet_kws={'sharey': True, 'sharex': True}, height=5, aspect=1.6, legend='brief'
    )
    model_handles, model_labels = _progression_legend_entries(g)
    # One colour map for all facets (legend proxies only exist once per figure).
    color_map = {lbl: h.get_color() for h, lbl in zip(model_handles, model_labels) if hasattr(h, 'get_color')}
    fallback_colors: Dict[Any, Any] = {}
    palette_iter = iter(sns.color_palette())

    # Track which baselines we actually draw so the legend only shows real elements.
    drew_random = drew_majority = drew_off = False
    for ax, rep_mode in zip(g.axes.flat, rep_modes_here):
        facet_sub = sub_progress[sub_progress['representation_mode'] == rep_mode]
        rand_bl = _compute_random_baseline(facet_sub, exp_mode)
        if rand_bl is not None:
            ax.axhline(rand_bl, linestyle='--', color='red', linewidth=1)
            drew_random = True
        maj_bl = _compute_majority_move_baseline(facet_sub) if show_majority else None
        if maj_bl is not None:
            ax.axhline(maj_bl, linestyle=':', color='purple', linewidth=1)
            drew_majority = True
        facet_off = sub_off[sub_off['representation_mode'] == rep_mode]
        if show_off_baselines and not facet_off.empty:
            for row in facet_off.itertuples():
                colr = color_map.get(row.model_name)
                if colr is None:
                    # Stable fallback colour per model (not per row).
                    if row.model_name not in fallback_colors:
                        fallback_colors[row.model_name] = next(palette_iter, 'gray')
                    colr = fallback_colors[row.model_name]
                ax.axhline(row.overall_accuracy, linestyle='-.', color=colr, linewidth=1.2, alpha=0.85)
            drew_off = True

    if g._legend is not None:
        g._legend.remove()
    existing = set(model_labels)
    extras = [
        (drew_random, '--', 'red', 'Uniform(19) random baseline'),
        (drew_majority, ':', 'purple', 'Majority legal-move-count proportion'),
        (drew_off, '-.', 'black', 'Per-model off-the-shelf baseline (horizontal)'),
    ]
    for drew, linestyle, color, label in extras:
        if drew and label not in existing:
            model_handles.append(Line2D([0], [0], linestyle=linestyle, color=color, label=label))
            model_labels.append(label)
            existing.add(label)
    # Legend sits outside the plot area on the right to avoid overlapping the lines.
    g.figure.legend(
        model_handles,
        model_labels,
        loc='center left',
        bbox_to_anchor=(0.82, 0.5),
        borderaxespad=0.0,
        frameon=False
    )
    g.set_titles('{col_name}')
    g.set_axis_labels('Checkpoint', 'Accuracy (%)')
    g.figure.suptitle(f'Progression - {exp_mode} | {variant_simple} | ASCII={ascii_flag} | RandomXY={rand_flag}', y=0.98)
    # Leave horizontal space on the right for the figure-level legend.
    g.figure.subplots_adjust(top=0.90, left=0.12, right=0.80)
    vdir = os.path.join(outdir, variant_simple, exp_mode)
    ensure_dir(vdir)
    fp = os.path.join(vdir, f'progression_{exp_mode}_ascii-{ascii_flag}_rand-{rand_flag}.png')
    save_all_formats(fp, formats or ['png'], bbox_inches='tight')
    plt.close()

def plot_precision_and_distraction(best_df: pd.DataFrame, outdir: str, formats: Optional[List[str]] = None):
    """Bar charts of the derived best-move metrics, one figure per dataset_variant.

    * ``precision_under_pressure`` (accuracy when num_best_moves == 1) is saved to
      ``<outdir>/<variant>/best_move/precision_under_pressure.png``.
    * ``distraction_index`` (low-choice minus high-choice accuracy) is saved to
      ``<outdir>/<variant>/best_move/distraction_index.png``.

    A metric is skipped entirely when its column is absent.
    """
    if best_df.empty:
        return
    ensure_dir(outdir)
    p_df = best_df.copy()
    if 'precision_under_pressure' in p_df.columns:
        _plot_best_move_metric_bars(
            p_df, 'precision_under_pressure', outdir, formats,
            title='Precision Under Pressure (num_best_moves == 1)',
            ylabel='Accuracy (%)',
            filename='precision_under_pressure.png',
            ylim=(0, 105),
        )
    if 'distraction_index' in p_df.columns:
        _plot_best_move_metric_bars(
            p_df, 'distraction_index', outdir, formats,
            title='Distraction Index (Low-choice minus High-choice accuracy)',
            ylabel='Accuracy Drop (percentage points)',
            filename='distraction_index.png',
            zero_line=True,
        )


def _plot_best_move_metric_bars(p_df: pd.DataFrame, metric: str, outdir: str, formats: Optional[List[str]], *,
                                title: str, ylabel: str, filename: str,
                                ylim: Optional[Tuple[float, float]] = None, zero_line: bool = False) -> None:
    """Draw one bar chart of ``metric`` per dataset_variant (x = training_group, hue = model).

    Rows with a null metric are dropped. When best-move training groups are present,
    only those are plotted; otherwise all groups are kept.
    """
    best_move_groups = ['best_move_from_base', 'best_move_from_legal']
    for variant in sorted(p_df['dataset_variant'].unique()):
        sub = p_df[(p_df['dataset_variant'] == variant) & pd.notnull(p_df[metric])]
        bm_mask = sub['training_group'].isin(best_move_groups)
        sub_use = sub[bm_mask] if bm_mask.any() else sub
        if sub_use.empty:
            continue
        variant_simple = simplify_variant(variant)
        hue_order = sorted(sub_use['model_name'].unique())
        plt.figure(figsize=(16, 8))
        # errorbar=None alone disables error bars; the deprecated `ci` kwarg is not passed.
        ax = sns.barplot(data=sub_use, x='training_group', y=metric, hue='model_name',
                         errorbar=None, hue_order=hue_order, palette=get_model_palette(hue_order))
        plt.title(f'{title} | {variant_simple}')
        plt.ylabel(ylabel)
        if ylim is not None:
            plt.ylim(*ylim)
        plt.xticks(rotation=25, ha='right')
        if zero_line:
            plt.axhline(0, color='black', linewidth=1)
        # Move seaborn's own hue legend rather than rebuilding it from (possibly unlabelled) artists.
        if ax.get_legend() is not None:
            sns.move_legend(ax, 'upper left', bbox_to_anchor=(1.01, 1))
        plt.subplots_adjust(right=0.76)
        plt.tight_layout()
        vdir = os.path.join(outdir, variant_simple, 'best_move')
        ensure_dir(vdir)
        save_all_formats(os.path.join(vdir, filename), formats or ['png'])
        plt.close()

def plot_open_spaces_curve(source_df: pd.DataFrame, outdir: str, mode: str = 'empties', formats: Optional[List[str]] = None):
    """Line charts of accuracy against open spaces, built from individual board states.

    ``source_df`` is expanded by ``explode_open_space_stats(..., mode=mode)``. One figure
    is written per (experiment_mode, dataset_variant, ascii_board, random_xy_moves)
    combination. For every (model_name, training_group) pair, each open-space value
    0..9 that has no data is padded with a NaN-accuracy row. ``mode == 'empties'``
    labels the axis with the raw empty-cell count; any other mode uses the
    "9 - # empty" wording.
    """
    if source_df.empty:
        return
    ensure_dir(outdir)
    exploded = explode_open_space_stats(source_df, mode=mode)
    if exploded.empty:
        return

    counting_empties = mode == 'empties'
    title_mode = 'empties (# empty)' if counting_empties else 'inverted (9 - # empty)'
    x_label = 'Open Spaces (# empty cells)' if counting_empties else 'Open Spaces (9 - # empty)'
    all_bins = range(10)

    for exp_mode in sorted(exploded['experiment_mode'].unique()):
        for variant in sorted(exploded['dataset_variant'].unique()):
            panel = exploded[(exploded['experiment_mode'] == exp_mode) & (exploded['dataset_variant'] == variant)]
            if panel.empty:
                continue
            variant_simple = simplify_variant(variant)
            for ascii_flag in sorted(panel['ascii_board'].dropna().unique()):
                by_ascii = panel[panel['ascii_board'] == ascii_flag]
                for rx_flag in sorted(by_ascii['random_xy_moves'].dropna().unique()):
                    frame = by_ascii[by_ascii['random_xy_moves'] == rx_flag]
                    if frame.empty:
                        continue

                    padding = []
                    for (model, group), grp in frame.groupby(['model_name', 'training_group']):
                        observed = set(grp['open_spaces'].tolist())
                        template = {
                            'model_name': model,
                            'training_group': group,
                            'experiment_mode': exp_mode,
                            'representation_mode': grp['representation_mode'].iloc[0],
                            'dataset_variant': variant,
                            'ascii_board': ascii_flag,
                            'random_xy_moves': rx_flag,
                            'checkpoint_step': grp['checkpoint_step'].iloc[0],
                        }
                        padding.extend(
                            {**template, 'open_spaces': b, 'open_space_accuracy': np.nan}
                            for b in all_bins
                            if b not in observed
                        )
                    if padding:
                        frame = pd.concat([frame, pd.DataFrame(padding)], ignore_index=True)

                    plt.figure(figsize=(16, 8))
                    models = sorted(frame['model_name'].unique())
                    sns.lineplot(
                        data=frame.sort_values('open_spaces'),
                        x='open_spaces',
                        y='open_space_accuracy',
                        hue='model_name',
                        style='training_group',
                        markers=True,
                        palette=get_model_palette(models),
                        hue_order=models,
                        errorbar=None,
                    )
                    plt.title(f'Accuracy vs Open Spaces ({title_mode}) - {exp_mode} | {variant_simple} | ascii={ascii_flag} random_xy={rx_flag}')
                    plt.ylabel('Accuracy (%)')
                    plt.xlabel(x_label)
                    plt.ylim(0, 105)
                    plt.tight_layout()
                    target_dir = os.path.join(outdir, variant_simple, exp_mode, f'ascii_{ascii_flag}', f'random_{rx_flag}')
                    ensure_dir(target_dir)
                    save_all_formats(os.path.join(target_dir, f'open_spaces_{exp_mode}_ascii{ascii_flag}_rx{rx_flag}.png'), formats or ['png'])
                    plt.close()

def plot_outcome_score_heatmap(best_df: pd.DataFrame, outdir: str, formats: Optional[List[str]] = None):
    """Heatmap of accuracy by minimax outcome score, split per model & training_group & ascii/random settings.

    One figure is written per (experiment_mode, dataset_variant, ascii_board,
    random_xy_moves) combination present in ``best_df``. Rows are
    ``"model | training_group"`` ordered by ``desired_tg_order`` and then model
    name. Columns are the fixed score bins in ``expected_scores``, so every
    figure shares the same x axis. Cells hold ``100 * correct / total`` taken
    from ``fine_grained_stats['by_minimax_outcome_score'][str(score)]``. Bins
    without samples stay NaN (blank). Rows without any data are dropped, and a
    figure with no data at all is skipped.

    Args:
        best_df: Best-checkpoint rows.
        outdir: Root output directory; figures go under
            ``<variant_simple>/<experiment_mode>/ascii_<flag>/random_<flag>/``.
        formats: Formats forwarded to ``save_all_formats`` (defaults to ``['png']``).
    """
    if best_df.empty:
        return
    ensure_dir(outdir)
    expected_scores = [-4, -2, 0, 1, 3, 5]
    desired_tg_order = ['legal_move', 'best_move_from_base', 'best_move_from_legal', 'off_the_shelf']
    unknown_rank = len(desired_tg_order) + 1
    for exp_mode in sorted(best_df['experiment_mode'].unique()):
        for variant in sorted(best_df['dataset_variant'].unique()):
            base_subset = best_df[(best_df['experiment_mode'] == exp_mode) & (best_df['dataset_variant'] == variant)]
            if base_subset.empty:
                continue
            for ascii_flag in sorted(base_subset['ascii_board'].dropna().unique()):
                ascii_subset = base_subset[base_subset['ascii_board'] == ascii_flag]
                for rx_flag in sorted(ascii_subset['random_xy_moves'].dropna().unique()):
                    subset = ascii_subset[ascii_subset['random_xy_moves'] == rx_flag]
                    if subset.empty:
                        continue
                    records = []
                    for row in subset.itertuples():
                        stats = getattr(row, 'fine_grained_stats', None)
                        # Missing cells are NaN after DataFrame construction and NaN is
                        # truthy, so check the type instead of relying on `x or {}`.
                        if not isinstance(stats, dict):
                            stats = {}
                        osm = stats.get('by_minimax_outcome_score') or {}
                        if not isinstance(osm, dict):
                            osm = {}
                        for score in expected_scores:
                            vals = osm.get(str(score)) or {}
                            if not isinstance(vals, dict):
                                vals = {}
                            correct = vals.get('correct') or vals.get('num_correct') or 0
                            total = vals.get('total') or vals.get('num_total') or 0
                            acc = 100 * correct / total if total else np.nan
                            records.append({
                                'model_name': row.model_name,
                                'training_group': row.training_group,
                                'score': score,
                                'accuracy': acc
                            })
                    if not records:
                        continue
                    heat_df = pd.DataFrame(records)
                    # Normalise labels so the string row key never fails on None/NaN.
                    heat_df['model_name'] = heat_df['model_name'].fillna('unknown').astype(str)
                    heat_df['training_group'] = heat_df['training_group'].fillna('unknown').astype(str)
                    heat_df['tg_rank'] = heat_df['training_group'].map(
                        lambda g: desired_tg_order.index(g) if g in desired_tg_order else unknown_rank
                    )
                    heat_df['row_key'] = heat_df['model_name'] + ' | ' + heat_df['training_group']
                    # row_key determines model/training_group, so one mean per (row_key, score) suffices.
                    pivot = heat_df.pivot_table(index='row_key', columns='score', values='accuracy', aggfunc='mean')
                    row_order = (
                        heat_df[['row_key', 'tg_rank', 'model_name', 'training_group']]
                        .drop_duplicates('row_key')
                        .sort_values(['tg_rank', 'model_name', 'training_group'])['row_key']
                    )
                    # pivot_table drops rows whose accuracies are all NaN; keep only rows that
                    # survived, instead of .loc on the full order (which raised KeyError).
                    present = [k for k in row_order if k in pivot.index]
                    if not present:
                        continue
                    pivot = pivot.reindex(index=present, columns=expected_scores)
                    plt.figure(figsize=(14, max(4, 0.5 * len(pivot))))
                    sns.heatmap(pivot, annot=True, fmt='.1f', cmap='coolwarm', center=50, vmin=0, vmax=100)
                    variant_simple = simplify_variant(variant)
                    plt.title(f'Outcome Score Accuracy Heatmap - {exp_mode} | {variant_simple} | ascii={ascii_flag} random_xy={rx_flag}')
                    plt.xlabel('Minimax Outcome Score')
                    plt.ylabel('Model | Training Group')
                    plt.tight_layout()
                    vdir = os.path.join(outdir, variant_simple, exp_mode, f'ascii_{ascii_flag}', f'random_{rx_flag}')
                    ensure_dir(vdir)
                    save_all_formats(os.path.join(vdir, f'outcome_score_heatmap_{exp_mode}_ascii{ascii_flag}_rx{rx_flag}.png'), formats or ['png'])
                    plt.close()

def plot_legal_vs_best_comparison(best_df: pd.DataFrame, outdir: str, formats: Optional[List[str]] = None):
    """Compare best_move vs legal_move evaluation per training_group & (ascii, random_xy) setting.

    Uses only best_df rows. Positive gap means best_move eval > legal_move eval.
    Separately shows training groups: legal_move, best_move_from_base, best_move_from_legal.

    Accuracies are averaged per configuration (model, representation, training
    group, variant, random_xy, ascii) and per experiment mode. The gap is
    ``best_move - legal_move``. Configurations evaluated in only one of the two
    modes have no gap and are left out. Each training group keeps the same
    colour across every figure written by this call.

    Args:
        best_df: Best-checkpoint rows.
        outdir: Root output directory; figures go under
            ``<variant_simple>/ascii_<flag>/random_<flag>/``.
        formats: Formats forwarded to ``save_all_formats`` (defaults to ``['png']``).
    """
    ensure_dir(outdir)
    if best_df.empty:
        return
    work = best_df.copy()
    if 'dataset_variant' in work.columns:
        work['dataset_variant'] = work['dataset_variant'].astype(str)
    # Restrict to experiment modes of interest
    work = work[work['experiment_mode'].isin(['legal_move', 'best_move'])]
    if work.empty:
        return
    base_cols = ['model_name', 'representation_mode', 'training_group', 'dataset_variant', 'random_xy_moves', 'ascii_board']
    pivot = work.pivot_table(index=base_cols, columns='experiment_mode', values='overall_accuracy', aggfunc='mean')
    if 'legal_move' not in pivot.columns or 'best_move' not in pivot.columns:
        return
    pivot = pivot.reset_index()
    pivot['gap_best_minus_legal'] = pivot['best_move'] - pivot['legal_move']
    # Rows missing either mode yield NaN gaps -> empty bars/figures; drop them.
    pivot = pivot.dropna(subset=['gap_best_minus_legal'])
    if pivot.empty:
        return
    # One colour per training group for the whole call, so figures are comparable.
    all_groups = sorted(pivot['training_group'].unique())
    group_palette = dict(zip(all_groups, sns.color_palette(n_colors=len(all_groups))))
    for variant in sorted(pivot['dataset_variant'].unique()):
        v_sub = pivot[pivot['dataset_variant'] == variant]
        if v_sub.empty:
            continue
        variant_simple = simplify_variant(variant)
        for ascii_flag in sorted(v_sub['ascii_board'].dropna().unique()):
            a_sub = v_sub[v_sub['ascii_board'] == ascii_flag]
            for rx_flag in sorted(a_sub['random_xy_moves'].dropna().unique()):
                sub = a_sub[a_sub['random_xy_moves'] == rx_flag]
                if sub.empty:
                    continue
                plt.figure(figsize=(18, 8))
                hue_order = sorted(sub['training_group'].unique())
                # hue_order was previously computed but never passed to barplot.
                sns.barplot(data=sub, x='model_name', y='gap_best_minus_legal', hue='training_group',
                            hue_order=hue_order, palette=group_palette, errorbar=None)
                plt.axhline(0, color='black', linewidth=1)
                plt.ylabel('Accuracy Gain (Best - Legal)')
                plt.title(f'Best vs Legal Eval Gap | {variant_simple} | ascii={ascii_flag} | random_xy={rx_flag}')
                plt.xticks(rotation=30, ha='right')
                plt.tight_layout()
                vdir = os.path.join(outdir, variant_simple, f'ascii_{ascii_flag}', f'random_{rx_flag}')
                ensure_dir(vdir)
                save_all_formats(os.path.join(vdir, f'best_vs_legal_gap_{variant_simple}_ascii{ascii_flag}_rx{rx_flag}.png'), formats or ['png'])
                plt.close()

def plot_off_the_shelf_comparison(best_df: pd.DataFrame, outdir: str, formats: Optional[List[str]] = None):
    """Bar-plot mean best-checkpoint accuracy of off-the-shelf vs fine-tuned models.

    Each dataset_variant gets its own figure. The x axis is experiment_mode and
    the hue is the ``is_off_the_shelf`` flag. If that column is absent it is
    derived via ``detect_off_the_shelf`` (when ``training_group`` exists) or
    copied from ``is_zero_shot``. If neither source exists, nothing is plotted.
    """
    ensure_dir(outdir)
    if best_df.empty:
        return
    frame = best_df.copy()

    def _row_flag(row):
        return detect_off_the_shelf(row.get('training_group',''), row.get('dataset_variant',''), row.get('model_name',''))

    if 'is_off_the_shelf' not in frame.columns:
        if 'training_group' in frame.columns:
            frame['is_off_the_shelf'] = frame.apply(_row_flag, axis=1)
        elif 'is_zero_shot' in frame.columns:
            frame['is_off_the_shelf'] = frame['is_zero_shot']
        else:
            # No way to tell baselines from tuned models.
            return

    means = (
        frame.groupby(['is_off_the_shelf', 'experiment_mode', 'dataset_variant'])['overall_accuracy']
        .mean()
        .reset_index()
    )
    if means.empty:
        return

    for variant in sorted(means['dataset_variant'].unique()):
        chunk = means.loc[means['dataset_variant'] == variant]
        if chunk.empty:
            continue
        short_name = simplify_variant(variant)
        plt.figure(figsize=(10, 6))
        sns.barplot(data=chunk, x='experiment_mode', y='overall_accuracy', hue='is_off_the_shelf')
        plt.title(f'Off-the-shelf (Baseline) vs Fine-tuned Mean Accuracy | {short_name}')
        plt.ylabel('Mean Accuracy (%)')
        plt.ylim(0, 105)
        plt.tight_layout()
        target_dir = os.path.join(outdir, short_name, 'off_the_shelf')
        ensure_dir(target_dir)
        save_all_formats(os.path.join(target_dir, f'off_the_shelf_comparison_{short_name}.png'), formats or ['png'])
        plt.close()

# (Detailed plotting functions to be expanded in later steps.)

# ----------------------------- Main Pipeline ---------------------------

def load_all(ascii_root: str, non_ascii_root: str) -> pd.DataFrame:
    """Gather result records from both result trees into one DataFrame.

    The non-ASCII tree is walked first with ``ascii_flag=False``, then the
    ASCII tree with ``ascii_flag=True``. Records keep that order.
    """
    gathered = [
        *walk_result_tree(non_ascii_root, ascii_flag=False),
        *walk_result_tree(ascii_root, ascii_flag=True),
    ]
    return pd.DataFrame(gathered)


def augment_metrics(df: pd.DataFrame) -> pd.DataFrame:
    """Add derived per-row columns to ``df`` in place and return the same object.

    Added columns:
        * ``is_off_the_shelf``: only if absent, computed with ``detect_off_the_shelf``.
        * ``is_zero_shot``: only if absent, a backward-compatible alias of ``is_off_the_shelf``.
        * ``precision_under_pressure``: from ``compute_precision_under_pressure``.
        * ``distraction_index``: from ``compute_distraction_index``.

    An empty DataFrame is returned unchanged.
    """
    if df.empty:
        return df

    def _baseline_flag(row):
        return detect_off_the_shelf(row.get('training_group',''), row.get('dataset_variant',''), row.get('model_name',''))

    if 'is_off_the_shelf' not in df.columns:
        df['is_off_the_shelf'] = df.apply(_baseline_flag, axis=1)
    # Older visualisation code reads is_zero_shot.
    if 'is_zero_shot' not in df.columns:
        df['is_zero_shot'] = df['is_off_the_shelf']

    derived = (
        ('precision_under_pressure', compute_precision_under_pressure),
        ('distraction_index', compute_distraction_index),
    )
    for column, metric_fn in derived:
        df[column] = df.apply(metric_fn, axis=1)
    return df

def plot_ply_curve(best_df: pd.DataFrame, outdir: str, formats: Optional[List[str]] = None):
    """Plot accuracy vs ply using only best checkpoints; split by ascii/random.

    Per-ply accuracy is ``100 * correct / total`` from
    ``fine_grained_stats['by_game_ply']``. Plies with ``total <= 0`` become NaN.
    One figure is written per (dataset_variant, ascii_board, random_xy_moves,
    experiment_mode). Colour encodes the model and line style the training
    group. No aggregation is applied (``estimator=None``).

    Models with no ``by_game_ply`` stats in any row are reported on stdout. The
    report is printed even when nothing at all can be plotted.

    Args:
        best_df: Best-checkpoint rows.
        outdir: Root output directory; figures go under
            ``<variant_simple>/ply_curves/<experiment_mode>/ascii_<flag>/random_<flag>/``.
        formats: Formats forwarded to ``save_all_formats`` (defaults to ``['png']``).
    """
    if best_df.empty:
        return
    ensure_dir(outdir)
    rows = []
    models_seen = []
    models_with_ply = set()
    for _, r in best_df.iterrows():
        stats = r.get('fine_grained_stats')
        # Missing stats are NaN in the DataFrame; NaN is truthy, so `x or {}` would not catch it.
        if not isinstance(stats, dict):
            stats = {}
        by_ply = stats.get('by_game_ply') or {}
        model = r['model_name']
        models_seen.append(model)
        if by_ply:
            models_with_ply.add(model)
        if not isinstance(by_ply, dict):
            continue
        for ply_str, vals in by_ply.items():
            try:
                ply = int(ply_str)
            except (TypeError, ValueError):
                continue
            if not isinstance(vals, dict):
                continue
            correct = vals.get('correct') or vals.get('num_correct') or 0
            total = vals.get('total') or vals.get('num_total') or 0
            acc = 100 * correct / total if total > 0 else np.nan
            rows.append({
                'model_name': model,
                'training_group': r['training_group'],
                'representation_mode': r['representation_mode'],
                'experiment_mode': r['experiment_mode'],
                'dataset_variant': str(r.get('dataset_variant', '')),
                'ply': ply,
                'accuracy': acc,
                'ascii_board': r.get('ascii_board'),
                'random_xy_moves': r.get('random_xy_moves')
            })
    # Report before any early return: the "no rows" case is exactly when this matters.
    models_missing_ply = sorted({m for m in models_seen if pd.notna(m) and m not in models_with_ply})
    if models_missing_ply:
        print(f"[ply_plot_debug] Models missing by_game_ply stats (not plotted): {models_missing_ply}")
    if not rows:
        return
    df = pd.DataFrame(rows)
    for variant in sorted(df['dataset_variant'].unique()):
        vdf = df[df['dataset_variant'] == variant]
        variant_simple = simplify_variant(variant)
        for ascii_flag in sorted(vdf['ascii_board'].dropna().unique()):
            avdf = vdf[vdf['ascii_board'] == ascii_flag]
            for rx_flag in sorted(avdf['random_xy_moves'].dropna().unique()):
                sub = avdf[avdf['random_xy_moves'] == rx_flag]
                if sub.empty:
                    continue
                for exp_mode in sorted(sub['experiment_mode'].unique()):
                    em_sub = sub[sub['experiment_mode'] == exp_mode]
                    if em_sub.empty:
                        continue
                    plt.figure(figsize=(18, 8))
                    # Plot each model separately; style encodes training_group
                    hue_order = sorted(em_sub['model_name'].unique())
                    sns.lineplot(data=em_sub, x='ply', y='accuracy', hue='model_name', style='training_group', markers=False, palette=get_model_palette(hue_order), hue_order=hue_order, estimator=None, errorbar=None)
                    plt.title(f'Accuracy by Ply | {variant_simple} | {exp_mode} | ascii={ascii_flag} random_xy={rx_flag} (per-model)')
                    plt.ylabel('Accuracy (%)')
                    plt.xlabel('Ply')
                    plt.ylim(0, 105)
                    plt.tight_layout()
                    vdir = os.path.join(outdir, variant_simple, 'ply_curves', exp_mode, f'ascii_{ascii_flag}', f'random_{rx_flag}')
                    ensure_dir(vdir)
                    save_all_formats(os.path.join(vdir, f'ply_curve_{variant_simple}_{exp_mode}_ascii{ascii_flag}_rx{rx_flag}.png'), formats or ['png'])
                    plt.close()

def _complexity_rows(best_df: pd.DataFrame, stats_key: str) -> pd.DataFrame:
    """Flatten ``fine_grained_stats[stats_key]`` buckets into one row per (best row, bucket).

    The result always has the full column set, even when no row has data.
    Without this, downstream boolean filters raise KeyError on the empty frame.
    """
    columns = ['bucket', 'accuracy', 'model_name', 'training_group', 'representation_mode',
               'experiment_mode', 'dataset_variant', 'ascii_board', 'random_xy_moves']
    rows = []
    for _, r in best_df.iterrows():
        stats = r.get('fine_grained_stats')
        # Missing stats are NaN in the DataFrame; NaN is truthy, so `x or {}` would not catch it.
        if not isinstance(stats, dict):
            stats = {}
        buckets = stats.get(stats_key) or {}
        if not isinstance(buckets, dict):
            continue
        for bucket, vals in buckets.items():
            try:
                b = int(bucket)
            except (TypeError, ValueError):
                continue
            if not isinstance(vals, dict):
                continue
            correct = vals.get('correct') or vals.get('num_correct') or 0
            total = vals.get('total') or vals.get('num_total') or 0
            acc = 100 * correct / total if total else np.nan
            rows.append({
                'bucket': b,
                'accuracy': acc,
                'model_name': r['model_name'],
                'training_group': r['training_group'],
                'representation_mode': r['representation_mode'],
                'experiment_mode': r['experiment_mode'],
                'dataset_variant': r['dataset_variant'],
                'ascii_board': r.get('ascii_board'),
                'random_xy_moves': r.get('random_xy_moves')
            })
    return pd.DataFrame(rows, columns=columns)


def _complexity_filter(frame: pd.DataFrame, variant, exp_mode, ascii_flag, rx_flag) -> pd.DataFrame:
    """Select the rows of ``frame`` for one (variant, experiment_mode, ascii_board, random_xy_moves) setting."""
    mask = ((frame['dataset_variant'] == variant) & (frame['experiment_mode'] == exp_mode)
            & (frame['ascii_board'] == ascii_flag) & (frame['random_xy_moves'] == rx_flag))
    return frame[mask]


def _complexity_plot(sub: pd.DataFrame, x_label: str, title_prefix: str, file_prefix: str,
                     variant_simple: str, exp_mode, ascii_flag, rx_flag,
                     outdir: str, formats: Optional[List[str]]):
    """Draw and save one accuracy-vs-bucket figure (hue=model, style=training_group, no averaging)."""
    plt.figure(figsize=(16, 8))
    horder = sorted(sub['model_name'].unique())
    sns.lineplot(data=sub.sort_values('bucket'), x='bucket', y='accuracy', hue='model_name', style='training_group', hue_order=horder, palette=get_model_palette(horder), estimator=None, errorbar=None)
    plt.title(f'{title_prefix} | {variant_simple} | {exp_mode} | ascii={ascii_flag} random_xy={rx_flag}')
    plt.ylabel('Accuracy (%)')
    plt.xlabel(x_label)
    plt.ylim(0, 105)
    plt.tight_layout()
    vdir = os.path.join(outdir, variant_simple, exp_mode, f'ascii_{ascii_flag}', f'random_{rx_flag}')
    ensure_dir(vdir)
    save_all_formats(os.path.join(vdir, f'{file_prefix}_{variant_simple}_{exp_mode}_ascii{ascii_flag}_rx{rx_flag}.png'), formats or ['png'])
    plt.close()


def plot_complexity_curves(best_df: pd.DataFrame, outdir: str, formats: Optional[List[str]] = None):
    """Plot accuracy vs (num legal moves) and (num best moves) using best checkpoints.

    Splits by variant, experiment_mode, ascii_board, random_xy_moves.

    Bucket accuracies come from ``fine_grained_stats['by_num_legal_moves']`` and
    ``fine_grained_stats['by_num_best_moves']``. A setting that has data for
    only one of the two gets only that figure. A warning is printed when
    several checkpoints survive best-checkpoint selection for one configuration.

    Args:
        best_df: Best-checkpoint rows.
        outdir: Root output directory; figures go under
            ``<variant_simple>/<experiment_mode>/ascii_<flag>/random_<flag>/``.
        formats: Formats forwarded to ``save_all_formats`` (defaults to ``['png']``).
    """
    if best_df.empty:
        return
    ensure_dir(outdir)
    df_legal = _complexity_rows(best_df, 'by_num_legal_moves')
    df_best = _complexity_rows(best_df, 'by_num_best_moves')
    # Diagnostic: ensure only single best checkpoint per grouping
    dup_keys = ['model_name', 'representation_mode', 'training_group', 'experiment_mode', 'dataset_variant', 'ascii_board', 'random_xy_moves']
    if set(dup_keys + ['checkpoint_step']).issubset(best_df.columns):
        dup_counts = best_df.groupby(dup_keys)['checkpoint_step'].nunique().reset_index(name='n_ckpts')
        multi = dup_counts[dup_counts['n_ckpts'] > 1]
        if not multi.empty:
            print('[complexity_debug] Warning: multiple checkpoints found after best selection for some groups:')
            print(multi.head().to_string(index=False))
    curve_specs = (
        (df_legal, '# Legal Moves', 'Accuracy vs # Legal Moves', 'legal_moves_curve'),
        (df_best, '# Best Moves', 'Accuracy vs # Best Moves', 'best_moves_curve'),
    )
    for variant in sorted(best_df['dataset_variant'].unique()):
        variant_best = best_df[best_df['dataset_variant'] == variant]
        variant_simple = simplify_variant(variant)
        for exp_mode in sorted(variant_best['experiment_mode'].unique()):
            em_best = variant_best[variant_best['experiment_mode'] == exp_mode]
            for ascii_flag in sorted(em_best['ascii_board'].dropna().unique()):
                ascii_best = em_best[em_best['ascii_board'] == ascii_flag]
                for rx_flag in sorted(ascii_best['random_xy_moves'].dropna().unique()):
                    for frame, x_label, title_prefix, file_prefix in curve_specs:
                        sub = _complexity_filter(frame, variant, exp_mode, ascii_flag, rx_flag)
                        if sub.empty:
                            continue
                        _complexity_plot(sub, x_label, title_prefix, file_prefix, variant_simple,
                                         exp_mode, ascii_flag, rx_flag, outdir, formats)

def main():
    """Command-line entry point: load results, pick best checkpoints, and write every plot.

    Steps, in order:
        * Parse the CLI arguments.
        * Set the global matplotlib DPI.
        * Load the ASCII and non-ASCII result trees.
        * Print load diagnostics.
        * Canonicalise ``dataset_variant`` and derive the extra metrics.
        * Select the best checkpoints and optionally dump CSVs.
        * Render all plot families under ``--output_dir``.
    """
    parser = argparse.ArgumentParser(description='Analyze full inference results (v2).')
    parser.add_argument('--ascii_results_root', type=str, default='/mnt/shared/stlm-logic/results_full_inference_updated_ascii_board')
    parser.add_argument('--non_ascii_results_root', type=str, default='/mnt/shared/stlm-logic/results_full_inference_updated')
    parser.add_argument('--output_dir', type=str, default='/home/data/stlm-game-logic/analysis_plots_v5')
    parser.add_argument('--save_csv', action='store_true')
    parser.add_argument('--show_majority_progression', action='store_true', default=False, help='Show majority legal move count baseline on progression plots (default off).')
    parser.add_argument('--show_off_baselines_progression', action='store_true', default=False, help='Show per-model off-the-shelf baselines on progression plots (default off).')
    parser.add_argument('--open_spaces_mode', type=str, default='empties', choices=['empties','inverted'], help='How to compute open_spaces: empties = number of empty cells; inverted = 9 - empties.')
    parser.add_argument('--dpi', type=int, default=300, help='DPI for all saved figures (default 300).')
    parser.add_argument('--pdf', action='store_true', help='Also save every plot as PDF (vector format) alongside PNG.')
    args = parser.parse_args()

    def _canonicalize_variants(frame, label=''):
        # Best-effort for every frame: warn instead of aborting the run, so a failure
        # on best_df is handled the same way as on full_df (previously only full_df was guarded).
        if 'dataset_variant' not in frame.columns:
            return
        try:
            frame['dataset_variant'] = frame['dataset_variant'].apply(canonicalize_dataset_variant)
        except Exception as e:
            print(f"[warn] dataset_variant canonicalization failed{label}: {e}")

    os.makedirs(args.output_dir, exist_ok=True)
    # Set global DPI so every subsequent plt.savefig uses this resolution unless overridden.
    try:
        plt.rcParams['savefig.dpi'] = args.dpi
        plt.rcParams['figure.dpi'] = args.dpi  # affects some backends / interactive displays
        print(f"[config] Using DPI={args.dpi} for all saved figures.")
    except Exception as e:
        print(f"[warn] Could not set DPI rcParams: {e}")
    print('[load] Collecting results...')
    full_df = load_all(args.ascii_results_root, args.non_ascii_results_root)
    if full_df.empty:
        print('No results found. Exiting.')
        return
    print(f'[load] Loaded {len(full_df)} result files.')
    try:
        init_model_colors(list(full_df['model_name'].unique()))
    except Exception as e:
        print(f'[warn] Could not initialize model color map: {e}')
    # Diagnostics: individual_outputs availability
    try:
        if 'individual_outputs' in full_df.columns:
            non_empty = full_df['individual_outputs'].apply(lambda v: isinstance(v, dict) and len(v)>0)
            cnt_non_empty = int(non_empty.sum())
            print(f'[diag] Rows with non-empty individual_outputs: {cnt_non_empty} / {len(full_df)}')
            if cnt_non_empty == 0:
                print('[diag] All individual_outputs empty or missing. Open spaces plot will be empty. Ensure inference was run with --save_individual_outputs.')
        else:
            print('[diag] Column individual_outputs not present in loaded DataFrame.')
    except Exception as e:
        print(f'[diag] individual_outputs diagnostic error: {e}')
    # Diagnostics: counts per training_group / experiment_mode
    try:
        cnt = full_df.groupby(['training_group','experiment_mode','ascii_board','random_xy_moves']).size().reset_index(name='files')
        print('[diag] File counts (first 20):')
        print(cnt.head(20).to_string(index=False))
    except Exception as e:
        print(f'[diag] Skipped counts: {e}')

    # Canonicalize dataset_variant early
    _canonicalize_variants(full_df)

    full_df = augment_metrics(full_df)
    best_df = select_best_checkpoints(full_df)
    _canonicalize_variants(best_df, ' (best checkpoints)')
    print(f'[select] Best checkpoints: {len(best_df)} rows.')

    if args.save_csv:
        save_df(full_df, os.path.join(args.output_dir, 'full_results_raw.csv'))
        save_df(best_df, os.path.join(args.output_dir, 'best_checkpoints.csv'))

    # Determine formats
    formats = ['png','pdf'] if args.pdf else ['png']
    # Plots
    plot_overall_bar(best_df, os.path.join(args.output_dir, 'overall'), full_df, formats=formats)
    plot_progression_all_models(full_df, os.path.join(args.output_dir, 'progression'), show_majority=args.show_majority_progression, show_off_baselines=args.show_off_baselines_progression, formats=formats)
    plot_precision_and_distraction(best_df, os.path.join(args.output_dir, 'pressure'), formats=formats)
    # Open spaces now always uses best checkpoints for consistency
    plot_open_spaces_curve(best_df, os.path.join(args.output_dir, 'open_spaces'), mode=args.open_spaces_mode, formats=formats)
    # Diagnostics: open space distribution
    try:
        os_tmp = explode_open_space_stats(best_df, mode=args.open_spaces_mode)
        if not os_tmp.empty:
            print(f"[diag] open_spaces (best only) min={os_tmp['open_spaces'].min()} max={os_tmp['open_spaces'].max()} distinct={sorted(os_tmp['open_spaces'].unique())}")
        else:
            print('[diag] open_spaces DataFrame empty after explosion (best checkpoints).')
    except Exception as e:
        print(f'[diag] open_spaces explosion error: {e}')
    plot_ply_curve(best_df, os.path.join(args.output_dir, 'ply'), formats=formats)
    plot_outcome_score_heatmap(best_df, os.path.join(args.output_dir, 'heatmaps'), formats=formats)
    plot_complexity_curves(best_df, os.path.join(args.output_dir, 'complexity'), formats=formats)
    plot_legal_vs_best_comparison(best_df, os.path.join(args.output_dir, 'comparisons'), formats=formats)
    plot_off_the_shelf_comparison(best_df, os.path.join(args.output_dir, 'off_the_shelf'), formats=formats)

    print('[done] Analysis v2 complete.')

if __name__ == "__main__":
    # Run the analysis pipeline only when executed as a script, never on import.
    main()
