#!/usr/bin/env python3
import argparse
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from pathlib import Path
import sys
import sys, types, math, torch, torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Set up flash_attn stub BEFORE any imports that might trigger transformers
def _fa_flash_stub(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, **kwargs):
    def _fp32c(t):
        if torch.is_tensor(t) and t.is_floating_point(): t = t.float()
        return t.contiguous() if torch.is_tensor(t) and not t.is_contiguous() else t
    q, k, v = _fp32c(q), _fp32c(k), _fp32c(v)

    if hasattr(F, "scaled_dot_product_attention"):
        return F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=dropout_p if torch.is_grad_enabled() else 0.0,
            is_causal=causal,
        )
    d = q.size(-1)
    scale = (1.0 / math.sqrt(d)) if softmax_scale is None else softmax_scale
    scores = (q @ k.transpose(-2, -1)) * scale
    if causal:
        L = scores.size(-2)
        mask = torch.ones_like(scores, dtype=torch.bool).tril()
        scores = scores.masked_fill(~mask, float("-inf"))
    P = scores.softmax(dim=-1)
    if torch.is_grad_enabled() and dropout_p > 0:
        P = F.dropout(P, p=dropout_p)
    return P @ v

# Create a proper module-like object with __spec__ to avoid import errors
# The issue: transformers checks for flash_attn using importlib.util.find_spec()
# which expects __spec__ to be a valid ModuleSpec object
import importlib.util
import importlib.machinery
import importlib.abc

# Bare module object standing in for the flash_attn package. It exposes only
# flash_attn_func, which is backed by the pure-PyTorch _fa_flash_stub.
flash_attn_module = types.ModuleType("flash_attn")
vars(flash_attn_module).update(
    flash_attn_func=_fa_flash_stub,
    __file__="<stub>",
    __package__="flash_attn",
)
# Try to create a proper ModuleSpec using spec_from_loader
class MinimalLoader(importlib.abc.Loader):
    """No-op loader used only to build a ModuleSpec for the flash_attn stub.

    The stub module is created and populated by hand, so neither hook does
    any work.
    """

    def create_module(self, spec):
        # Returning None asks the import machinery for default module creation.
        return None

    def exec_module(self, module):
        # Nothing to execute: attributes are attached to the module manually.
        return None

# Attach a spec so that importlib.util.find_spec("flash_attn") returns it
# instead of raising (find_spec raises ValueError when a module already in
# sys.modules has __spec__ set to None).
_flash_attn_loader = MinimalLoader()
try:
    spec = importlib.util.spec_from_loader("flash_attn", _flash_attn_loader, origin="<stub>")
except Exception:
    # Best effort: fall back to the spec-like namespace below.
    spec = None

if spec is not None:
    flash_attn_module.__spec__ = spec
    # Reuse the spec's loader instance so __loader__ and __spec__.loader agree.
    flash_attn_module.__loader__ = _flash_attn_loader
else:
    # Minimal spec-like object exposing the attributes importlib reads.
    flash_attn_module.__spec__ = types.SimpleNamespace(
        name="flash_attn",
        loader=None,
        origin="<stub>",
        submodule_search_locations=[],
        has_location=False,
        cached=None,
        loader_state=None
    )

sys.modules["flash_attn"] = flash_attn_module

# Now safe to import train_lambda (which may trigger transformers imports)
from train_lambda import LambdaSchedule

# Path() is the same as Path(""), i.e. ".", so this makes the current working
# directory importable.
# NOTE(review): the train_lambda import above runs before this insertion and
# does not benefit from it.
REPO_ROOT = Path()
sys.path.insert(0, REPO_ROOT.as_posix())

from lit_gpt.model_cache import Config
from lit_gpt.diffmodel import TransEncoder
from inference_sudoku import load_mdm_state_dict
from dataclasses import dataclass
from datetime import timedelta

from GRPO_order import (
    SudokuTokenDataset,
    make_f_theta_from_model,
    build_tokenizer_and_digit_maps,
    TrainConfig,
    FusionOrderPolicy,
    load_mdm_state_dict,
    Config,
)

# ===== Candidate-set utilities (human-alignment) =====

# All Sudoku digits; candidate sets are computed by subtracting from this.
DIGITS = {*range(1, 10)}

def peers_of(idx):
    """Return the set of cell indices sharing a row, column or 3x3 box with ``idx``.

    Cells are numbered row-major, 0..80. ``idx`` itself is excluded.
    """
    row, col = divmod(idx, 9)
    top, left = row - row % 3, col - col % 3
    same_row = {row * 9 + k for k in range(9)}
    same_col = {k * 9 + col for k in range(9)}
    same_box = {(top + dr) * 9 + (left + dc) for dr in range(3) for dc in range(3)}
    related = same_row | same_col | same_box
    related.remove(idx)
    return related

# Precomputed peer sets, indexed by cell.
PEERS = list(map(peers_of, range(81)))

def compute_candidates(grid):
    """Return per-cell candidate digit sets for a flat 81-cell grid (0 = empty).

    A filled cell gets ``{its value}``. An empty cell gets the digits not used
    by any filled peer. If some empty cell has no candidate left, ``[]`` is
    returned instead.
    """
    candidates = []
    for cell in range(81):
        value = grid[cell]
        if value != 0:
            candidates.append({value})
            continue
        seen = {grid[p] for p in PEERS[cell] if grid[p] != 0}
        options = DIGITS - seen
        if not options:
            # Contradiction: this empty cell cannot take any digit.
            return []
        candidates.append(options)
    return candidates

def is_hidden_single(grid, cand, idx):
    """Check whether some candidate of cell ``idx`` is unique within one of its units.

    Returns True if any digit in ``cand[idx]`` appears in exactly one
    candidate set (necessarily ``idx``'s own) within ``idx``'s row, column
    or 3x3 box. ``grid`` is not used.
    """
    row, col = divmod(idx, 9)
    top, left = row - row % 3, col - col % 3
    units = (
        [row * 9 + k for k in range(9)],
        [k * 9 + col for k in range(9)],
        [(top + i) * 9 + (left + j) for i in range(3) for j in range(3)],
    )
    # For each digit, units are checked in row, column, box order; stop at the first hit.
    return any(
        sum(digit in cand[cell] for cell in unit) == 1
        for digit in cand[idx]
        for unit in units
    )


def load_ground_truth_order(jsonl_path: str) -> dict:
    """
    Load ground-truth decode orders from a decoding_human.jsonl-style file.

    Each non-blank line is a JSON object with an optional "puzzle_id" and a
    "steps" list. Only steps carrying a two-element "pos" ([row, col],
    1-indexed in [1, 9]) are treated as assignments. Other steps (e.g.
    elimination steps such as "Bulk Pencil Marking") are skipped.

    Returns:
        {"by_id": {puzzle_id: positions}, "by_idx": {record_index: positions}}
        positions: list of 0-80 linear cell indices, in step order.
        record_index: counts non-blank lines from 0.
        Records without a puzzle_id appear only under "by_idx".
    """
    import json

    by_id = {}   # puzzle_id -> positions
    by_idx = {}  # record index in file -> positions
    record_idx = 0
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # crashing in json.loads.
                continue
            data = json.loads(line)
            # Convert 1-indexed (row, col) to 0-indexed linear position.
            positions = [
                (pos[0] - 1) * 9 + (pos[1] - 1)
                for pos in (step.get("pos") for step in data.get("steps", []))
                if pos and len(pos) == 2
            ]
            puzzle_id = data.get("puzzle_id")
            if puzzle_id is not None:
                by_id[puzzle_id] = positions
            by_idx[record_idx] = positions
            record_idx += 1
    return {"by_id": by_id, "by_idx": by_idx}


@torch.no_grad()
def solve_batch(policy, f_theta, puz_batch, sol_batch, max_steps=40, selection_mode="argmax", lambda_schedule=None, ground_truth_order=None):
    """
    Greedily fill a batch of puzzles, one cell per puzzle per step.

    At every step a position is chosen for each puzzle according to
    ``selection_mode``. The digit with the highest f_theta probability at that
    position is then written into the state. Only positions that are still
    empty are ever written. A puzzle with no empty cells left (or whose chosen
    position is already filled) is left untouched for that step.

    Args:
        policy: callable, state [B,81] -> position logits [B,81]; only called
            when selection_mode == "argmax".
        f_theta: callable, state [B,81] -> digit logits [B,81,9]; index j
            corresponds to digit j + 1.
        puz_batch: [B,81] long, 0 = empty
        sol_batch: [B,81] long, 1..9
        max_steps: maximum number of fill steps.
        selection_mode: how to pick the position to fill:
            "argmax"       - highest policy logit among empty cells
            "margin"       - largest top1 - top2 digit probability
            "entropy"      - lowest digit-distribution entropy
            "balanced"     - largest top1 prob + 0.13 * entropy
            "random"       - uniformly random empty cell
            "ground_truth" - next position from ground_truth_order
        lambda_schedule: optional callable evaluated in "balanced" mode. Its
            output is currently not used in the score (see note below).
        ground_truth_order: dict mapping batch row b -> list of 0..80
            positions; required for "ground_truth" mode.

    Returns:
        pred_batch: [B,81] long (model-filled)
        puzzle_correct: float in [0,1]
        cell_accuracy: float in [0,1], over originally empty cells
        fail_step: [B] long, 0 if never made a wrong move, else 1..max_steps (first wrong fill step)
        decode_order: List[List[int]] of length B, positions filled at each step
        decode_digits: List[List[int]] of length B, digits written at each step

    Raises:
        ValueError: unknown selection_mode, or "ground_truth" without ground_truth_order.
    """
    device = puz_batch.device
    B = puz_batch.size(0)
    state = puz_batch.clone()     # [B,81]
    row_idx = torch.arange(B, device=device)

    # First step (1-based) at which each puzzle writes a wrong digit.
    fail_step = torch.zeros(B, dtype=torch.long, device=device)

    # decode_order[b] / decode_digits[b]: positions / digits actually written for puzzle b.
    decode_order = [[] for _ in range(B)]
    decode_digits = [[] for _ in range(B)]

    for t in range(max_steps):
        step_idx = t + 1  # 1-based step index

        valid = (state == 0)                     # [B,81]
        if not valid.any():
            break

        if selection_mode == "argmax":
            # Policy logits, argmax over empty cells.
            logits = policy(state).to(torch.float32)
            logits = logits.masked_fill(~valid, float('-inf'))
            pos = torch.argmax(logits, dim=-1)       # [B] pos index 0..80
            logits_81x9 = f_theta(state).to(torch.float32)       # [B,81,9]
            probs_81x9 = F.softmax(logits_81x9, dim=-1)
        elif selection_mode == "margin":
            # Largest gap between the top-2 digit probabilities.
            logits_81x9 = f_theta(state).to(torch.float32)  # [B,81,9]
            probs_81x9 = F.softmax(logits_81x9, dim=-1)     # [B,81,9]
            top2_probs, _ = torch.topk(probs_81x9, k=2, dim=-1)  # [B,81,2]
            margins = top2_probs[:, :, 0] - top2_probs[:, :, 1]
            margins = margins.masked_fill(~valid, float('-inf'))
            pos = torch.argmax(margins, dim=-1)
        elif selection_mode == "entropy":
            # Lowest entropy of the digit distribution.
            logits_81x9 = f_theta(state).to(torch.float32)  # [B,81,9]
            probs_81x9 = F.softmax(logits_81x9, dim=-1)     # [B,81,9]
            # Small epsilon avoids log(0) = -inf.
            probs_81x9_safe = probs_81x9 + 1e-10
            entropy = -torch.sum(probs_81x9 * torch.log(probs_81x9_safe), dim=-1)  # [B,81]
            # Filled cells get +inf so argmin ignores them.
            entropy = entropy.masked_fill(~valid, float('inf'))
            pos = torch.argmin(entropy, dim=-1)
        elif selection_mode == "balanced":
            if lambda_schedule is not None:
                t_norm = torch.full((B, 1), step_idx / max_steps, dtype=torch.float32, device=device)
                lambda_ent = lambda_schedule(t_norm)  # [B]
            else:
                lambda_ent = torch.full((B,), 0.13, dtype=torch.float32, device=device)  # [B]
            logits_81x9 = f_theta(state).to(torch.float32)      # [B,81,9]
            probs_81x9 = F.softmax(logits_81x9, dim=-1)         # [B,81,9]
            log_probs_81x9 = F.log_softmax(logits_81x9, dim=-1) # [B,81,9]
            top1_probs, _ = probs_81x9.max(dim=-1)               # [B,81]
            entropy = -(probs_81x9 * log_probs_81x9).sum(dim=-1) # [B,81]
            # NOTE(review): lambda_ent is computed but unused; the score uses a
            # fixed 0.13 weight. The scheduled variant would be
            # top1_probs + lambda_ent.unsqueeze(-1) * entropy.
            scores = top1_probs + 0.13 * entropy
            scores = scores.masked_fill(~valid, float("-inf"))
            pos = torch.argmax(scores, dim=-1)
        elif selection_mode == "random":
            # Uniformly random empty cell per puzzle.
            pos = torch.zeros(B, dtype=torch.long, device=device)
            for b in range(B):
                valid_positions = valid[b].nonzero(as_tuple=False).squeeze(-1)  # [M]
                if valid_positions.numel() > 0:
                    pos[b] = valid_positions[torch.randint(0, valid_positions.numel(), (1,), device=device)]
                else:
                    pos[b] = 0
            logits_81x9 = f_theta(state).to(torch.float32)       # [B,81,9]
            probs_81x9 = F.softmax(logits_81x9, dim=-1)
        elif selection_mode == "ground_truth":
            # Positions from the supplied ground-truth order (keyed by batch row).
            if ground_truth_order is None:
                raise ValueError("ground_truth_order must be provided when selection_mode='ground_truth'")
            pos = torch.zeros(B, dtype=torch.long, device=device)
            for b in range(B):
                gt_order = ground_truth_order.get(b, [])
                if t < len(gt_order):
                    pos[b] = gt_order[t]
                else:
                    # Order exhausted: fall back to the first empty cell.
                    valid_positions = valid[b].nonzero(as_tuple=False).squeeze(-1)
                    if valid_positions.numel() > 0:
                        pos[b] = valid_positions[0]
                    else:
                        pos[b] = 0
            logits_81x9 = f_theta(state).to(torch.float32)       # [B,81,9]
            probs_81x9 = F.softmax(logits_81x9, dim=-1)
        else:
            raise ValueError(f"Unknown selection_mode: {selection_mode}")

        # Most likely digit at the chosen position.
        pred_digit = torch.argmax(probs_81x9[row_idx, pos], dim=-1) + 1  # [B] 1..9

        # A puzzle with no empty cells still receives a pos above (argmax/argmin
        # over fully masked scores, or a fallback/ground-truth index). Such a
        # step must not overwrite a filled cell or count as a failure.
        active = valid[row_idx, pos]  # [B] bool

        for b, (p, d, is_active) in enumerate(zip(pos.tolist(), pred_digit.tolist(), active.tolist())):
            if is_active:
                decode_order[b].append(p)
                decode_digits[b].append(d)

        true_digit = sol_batch[row_idx, pos]  # [B]
        wrong_now = active & (pred_digit != true_digit) & (fail_step == 0)
        fail_step[wrong_now] = step_idx

        state[row_idx, pos] = torch.where(active, pred_digit, state[row_idx, pos])

    pred_batch = state

    puzzle_correct_per = (pred_batch == sol_batch).all(dim=1).float()  # [B]
    puzzle_correct = puzzle_correct_per.mean().item()

    # Cell accuracy only over originally empty cells.
    mask = (puz_batch == 0)
    if mask.sum() > 0:
        cell_accuracy = (pred_batch[mask] == sol_batch[mask]).float().mean().item()
    else:
        cell_accuracy = 1.0

    return pred_batch, puzzle_correct, cell_accuracy, fail_step, decode_order, decode_digits


def _align_decode_steps(puzzle, order, digits):
    """Replay a decode trace on ``puzzle`` and annotate each fill with candidate-set stats.

    puzzle: flat list of 81 ints (0 = empty); it is copied, not mutated.
    order / digits: positions (0..80) and digits written, in step order.
    Returns a list of dicts with keys t, pos ([row, col], 1-indexed), val,
    k (candidate count at the cell before the fill), is_naked_single and
    is_hidden_single. Replay stops early once the grid has an empty cell
    with no candidates.
    """
    grid = list(puzzle)
    aligned = []
    for t, (idx, digit) in enumerate(zip(order, digits)):
        cand = compute_candidates(grid)
        if not cand:
            break
        k = len(cand[idx])
        aligned.append({
            "t": t,
            "pos": [idx // 9 + 1, idx % 9 + 1],
            "val": digit,
            "k": k,
            "is_naked_single": k == 1,
            "is_hidden_single": is_hidden_single(grid, cand, idx),
        })
        grid[idx] = digit
    return aligned


def evaluate(policy, f_theta, dataset, batch_size=32, max_steps=40, device="cuda", selection_mode="argmax", lambda_schedule=None, save_decode_order=False, decode_order_path=None, ground_truth_jsonl=None):
    """
    Run solve_batch over ``dataset`` and report puzzle and cell accuracy.

    Args:
        policy: order policy; switched to eval mode and forwarded to solve_batch.
        f_theta: digit predictor forwarded to solve_batch.
        dataset: yields (puzzle, solution) pairs; presumably 81 entries each
            with 0 = empty in the puzzle (TODO confirm against SudokuTokenDataset).
        batch_size, max_steps, device, selection_mode, lambda_schedule:
            used for the DataLoader / forwarded to solve_batch.
        save_decode_order, decode_order_path: aligned decode logs are written
            as JSONL to decode_order_path when either is truthy.
        ground_truth_jsonl: required for selection_mode == "ground_truth".
            Records are matched to dataset rows by their position in the file.

    Returns:
        (puzzle_acc, cell_acc, fail_steps_all, all_decode_orders, all_decode_digits)
        fail_steps_all: numpy array of first-wrong-fill steps (> 0) for unsolved puzzles.
        puzzle_acc: 0.0 when the dataset is empty.

    Raises:
        ValueError: save_decode_order without decode_order_path, or
            "ground_truth" mode without ground_truth_jsonl.
    """
    import json

    # Fail fast instead of crashing in open() after the whole evaluation.
    if save_decode_order and not decode_order_path:
        raise ValueError("decode_order_path must be provided when save_decode_order=True")

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    total_puzzles = 0
    correct_puzzles = 0
    correct_cells = 0
    total_cells = 0

    policy.eval()

    ground_truth_order = None
    if selection_mode == "ground_truth":
        if ground_truth_jsonl is None:
            raise ValueError("ground_truth_jsonl must be provided when selection_mode='ground_truth'")
        gt_data = load_ground_truth_order(ground_truth_jsonl)
        ground_truth_order = gt_data['by_idx']  # Use index-based matching by default
        print(f"Loaded ground truth order for {len(gt_data['by_id'])} puzzles (by_id) and {len(gt_data['by_idx'])} puzzles (by_idx) from {ground_truth_jsonl}")

    fail_steps_list = []      # failing steps of unsolved puzzles, one tensor per batch
    all_decode_orders = []
    all_decode_digits = []
    puzzle_data = []          # [{"puzzle": [...], "steps": [...]}, ...]

    with torch.no_grad():
        for batch_idx, (puz_batch, sol_batch) in enumerate(loader):
            puz_batch = puz_batch.to(device).long()
            sol_batch = sol_batch.to(device).long()
            B = puz_batch.size(0)

            # Map batch rows to ground-truth orders; valid because shuffle=False
            # and every batch but the last has exactly batch_size rows.
            batch_gt_order = None
            if ground_truth_order is not None:
                first = batch_idx * batch_size
                batch_gt_order = {b: ground_truth_order.get(first + b, []) for b in range(B)}

            pred, pc, cc, fail_step, decode_order, decode_digits = solve_batch(
                policy, f_theta, puz_batch, sol_batch,
                max_steps=max_steps,
                selection_mode=selection_mode,
                lambda_schedule=lambda_schedule,
                ground_truth_order=batch_gt_order,
            )

            all_decode_orders.extend(decode_order)
            all_decode_digits.extend(decode_digits)

            for b, puzzle in enumerate(puz_batch.cpu().tolist()):
                puzzle_data.append({
                    "puzzle": puzzle,
                    "steps": _align_decode_steps(puzzle, decode_order[b], decode_digits[b]),
                })

            success_mask = (pred == sol_batch).all(dim=1)  # [B]
            num_success = success_mask.sum().item()
            failed_mask = ~success_mask

            total_puzzles += B
            correct_puzzles += num_success

            # Cell metrics over originally empty cells.
            mask = (puz_batch == 0)
            correct_cells += (pred[mask] == sol_batch[mask]).sum().item()
            total_cells += mask.sum().item()

            # Record failing steps for failed puzzles (ignore "never wrong but unsolved").
            if failed_mask.any():
                fs = fail_step[failed_mask]
                fs = fs[fs > 0]
                if fs.numel() > 0:
                    fail_steps_list.append(fs.cpu())

            if (batch_idx + 1) % 5 == 0:
                puzzle_acc = correct_puzzles / total_puzzles
                cell_acc = correct_cells / total_cells if total_cells > 0 else 1.0
                print(f"[Eval] Processed {total_puzzles} puzzles "
                      f"(batch {batch_idx+1}) "
                      f"puzzle_acc={puzzle_acc*100:.2f}% "
                      f"cell_acc={cell_acc*100:.2f}%")

    # Guard against an empty dataset (no puzzles evaluated).
    puzzle_acc = correct_puzzles / total_puzzles if total_puzzles > 0 else 0.0
    cell_acc = correct_cells / total_cells if total_cells > 0 else 1.0

    print(f"Puzzle Accuracy: {puzzle_acc*100:.2f}%")
    print(f"Cell Accuracy:   {cell_acc*100:.2f}%")

    if fail_steps_list:
        fail_steps_all = torch.cat(fail_steps_list).numpy()
    else:
        fail_steps_all = np.array([])

    # Save when the flag is set or whenever a path is given.
    if save_decode_order or decode_order_path:
        with open(decode_order_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(item) + "\n" for item in puzzle_data)

        print(f"Saved aligned decode logs to {decode_order_path} (JSONL)")
        print(f"  Saved {len(puzzle_data)} puzzles, each with puzzle input and decode steps")
        if puzzle_data:
            print(f"  Example: First puzzle has {len(puzzle_data[0]['steps'])} decode steps")

    return puzzle_acc, cell_acc, fail_steps_all, all_decode_orders, all_decode_digits


def _load_diffusion_state_dict(ckpt_path):
    """Return the diffusion model state dict stored at ``ckpt_path``.

    ``.pt`` files are loaded on CPU with torch.load. The weights may be wrapped
    under "model_state_dict" or "state_dict"; otherwise the loaded object is
    used as-is. Any other extension goes through load_mdm_state_dict.
    """
    if not ckpt_path.endswith('.pt'):
        return load_mdm_state_dict(ckpt_path)
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version's weights_only default; only load trusted checkpoints.
    ckpt_data = torch.load(ckpt_path, map_location="cpu")
    if isinstance(ckpt_data, dict):
        for key in ("model_state_dict", "state_dict"):
            if key in ckpt_data:
                return ckpt_data[key]
    return ckpt_data


def _plot_fail_steps(fail_steps, max_steps, out_path, selection_mode):
    """Save a histogram of first-wrong-fill steps (bins 1..max_steps) to ``out_path``."""
    plt.figure()
    bins = np.arange(1, max_steps + 2)  # left-aligned integer bins
    plt.hist(fail_steps, bins=bins, align="left", rwidth=0.8)
    plt.xlabel("Failing step (first wrong fill)")
    plt.ylabel("Number of failed puzzles")
    plt.title(f"Distribution of failing steps in test dataset (mode={selection_mode})")
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()
    print(f"Saved failing-step histogram to {out_path}")


def main():
    """CLI entry point: load the models, evaluate on --jsonl, and plot failing steps."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=int, default=1028, help="Diffusion model size (e.g., 1028)")
    parser.add_argument("--ckpt", type=str, required=True, help="Diffusion checkpoint path (.pt or .pth)")
    parser.add_argument("--order_ckpt", type=str, help="Trained order policy checkpoint (required for --selection_mode argmax)")
    parser.add_argument("--jsonl", type=str, required=True, help="Evaluation JSONL with puzzles/solutions")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--max_steps", type=int, default=81)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument(
        "--selection_mode",
        type=str,
        default="argmax",
        choices=["argmax", "margin", "random", "entropy", "balanced", "ground_truth"],
        help="Position selection mode: argmax (top-1 confidence), margin (top1-top2), random, entropy, balanced, or ground_truth (use human order from JSONL)",
    )
    parser.add_argument(
        "--ground_truth_jsonl",
        type=str,
        default='decoding_human_full.jsonl',
        help="Path to decoding_human.jsonl file (required when selection_mode='ground_truth')",
    )
    parser.add_argument(
        "--fail_hist_path",
        type=str,
        default="fail_step_hist_balanced.png",
        help="Path to save histogram of failing steps for unsolved puzzles",
    )
    parser.add_argument(
        "--save_decode_order",
        action="store_true",
        help="Save decode order (selected positions) to a JSON file",
    )
    parser.add_argument(
        "--decode_order_path",
        type=str,
        default="decode_policy.jsonl",
        help="Path for the aligned decode log (JSONL); it is written whenever this path is non-empty (so by default), even without --save_decode_order",
    )
    args = parser.parse_args()
    # The order policy is only consulted in argmax mode.
    if args.order_ckpt is None and args.selection_mode == "argmax":
        parser.error("--order_ckpt is required when --selection_mode is argmax")

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # Diffusion teacher: build and load on CPU first to avoid CUDA OOM, then move.
    config = Config.from_name(f"Diff_LLaMA_{args.model}M")
    diffusion = TransEncoder(config)
    diffusion.load_state_dict(_load_diffusion_state_dict(args.ckpt), strict=True)
    diffusion = diffusion.to(device)
    diffusion.eval()

    tok, digit2id, id2digit, _ = build_tokenizer_and_digit_maps(MASK_ID=32000)
    cfg = TrainConfig(MASK_ID=tok.pad_token_id, epochs=1, batch_size=args.batch_size, lr=3e-4, seed=1234)
    f_theta = make_f_theta_from_model(diffusion, digit2id, mask_id=cfg.MASK_ID, use_amp_bf16=True)

    # Order policy (weights are loaded only when a checkpoint is given).
    policy = FusionOrderPolicy(d_model=256, nhead=8, mlp_hidden=256, topk=2, f_theta=f_theta)
    if args.order_ckpt is not None:
        ckpt = torch.load(args.order_ckpt, map_location="cpu")
        state_dict = ckpt.get("model_state_dict", ckpt)
        incompatible = policy.load_state_dict(state_dict, strict=False)
        # strict=False would otherwise hide mismatched checkpoints silently.
        missing = getattr(incompatible, "missing_keys", [])
        unexpected = getattr(incompatible, "unexpected_keys", [])
        if missing or unexpected:
            print(f"[order policy] non-strict load: {len(missing)} missing keys, {len(unexpected)} unexpected keys")
    policy = policy.to(device)

    # NOTE(review): LambdaSchedule is freshly constructed; no weights are loaded here.
    lambda_schedule = LambdaSchedule(hidden=256).to(device)

    dataset = SudokuTokenDataset(args.jsonl)

    puzzle_acc, cell_acc, fail_steps, decode_orders, decode_digits = evaluate(
        policy,
        f_theta,
        dataset,
        batch_size=args.batch_size,
        max_steps=args.max_steps,
        device=device,
        selection_mode=args.selection_mode,
        lambda_schedule=lambda_schedule,
        save_decode_order=args.save_decode_order,
        decode_order_path=args.decode_order_path,
        ground_truth_jsonl=args.ground_truth_jsonl,
    )

    if fail_steps.size > 0:
        _plot_fail_steps(fail_steps, args.max_steps, Path(args.fail_hist_path), args.selection_mode)
    else:
        print("No failed puzzles or no recorded failing steps; no histogram created.")


if __name__ == "__main__":
    main()
