# ============================================================
# RL Order Policy for Sudoku (Tokenizer-based, MASK=32000)
# Minimal: policy NN + teacher-forced reward from your model
# Model API: logits = model(ids)[0] with shape [B,81,Vocab]
# JSONL: each line is a dict with keys "puzzles" and "solutions"
#        (each a flat list of 81 ints; 0 = empty, 1..9 = digit; lists of
#        such boards and "puzzle"/"solution" keys are also accepted)
# ============================================================
#!/usr/bin/env python3
import sys, os
from pathlib import Path
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Callable
import math
import time
import argparse, json, random, torch
import torch.multiprocessing as mp
from copy import deepcopy

if torch.cuda.is_available():
    # Force the safe (math / matmul) attention backend for
    # F.scaled_dot_product_attention by disabling the flash and
    # memory-efficient kernels globally.
    # NOTE: `torch.backends.cuda.sdp_kernel(...)` is a context manager; calling
    # it without `with` never runs its body, so the previous code had no effect.
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)


# put this at the top of GRPO_order.py (before importing modules that import flash_attn)
import sys, types, math, torch, torch.nn.functional as F
import importlib.util
import importlib.machinery
import importlib.abc

def _fa_flash_stub(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, **kwargs):
    def _fp32c(t):
        if torch.is_tensor(t) and t.is_floating_point(): t = t.float()
        return t.contiguous() if torch.is_tensor(t) and not t.is_contiguous() else t
    q, k, v = _fp32c(q), _fp32c(k), _fp32c(v)

    if hasattr(F, "scaled_dot_product_attention"):
        return F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=dropout_p if torch.is_grad_enabled() else 0.0,
            is_causal=causal,
        )
    d = q.size(-1)
    scale = (1.0 / math.sqrt(d)) if softmax_scale is None else softmax_scale
    scores = (q @ k.transpose(-2, -1)) * scale
    if causal:
        L = scores.size(-2)
        mask = torch.ones_like(scores, dtype=torch.bool).tril()
        scores = scores.masked_fill(~mask, float("-inf"))
    P = scores.softmax(dim=-1)
    if torch.is_grad_enabled() and dropout_p > 0:
        P = F.dropout(P, p=dropout_p)
    return P @ v

# In-memory stand-in for the `flash_attn` package. It exposes only
# `flash_attn_func` (backed by `_fa_flash_stub`); a __spec__ is attached below
# so importlib treats it like a normally imported module.
flash_attn_module = types.ModuleType("flash_attn")
vars(flash_attn_module).update(
    flash_attn_func=_fa_flash_stub,
    __file__="<stub>",
    __package__="flash_attn",
)

class MinimalLoader(importlib.abc.Loader):
    """No-op importlib loader used for the in-memory ``flash_attn`` stub's spec."""

    def create_module(self, spec):
        # Returning None asks importlib to use its default module creation.
        return None

    def exec_module(self, module):
        # The stub module is populated by hand, so there is nothing to execute.
        return None

# Attach a ModuleSpec to the stub, preferring a real one built from
# MinimalLoader; if that fails (raises or returns None), fall back to a
# spec-like namespace carrying the same attribute names.
try:
    spec = importlib.util.spec_from_loader("flash_attn", MinimalLoader(), origin="<stub>")
except Exception:
    # Best effort: any failure here just selects the fallback namespace below.
    spec = None

if spec is None:
    flash_attn_module.__spec__ = types.SimpleNamespace(
        name="flash_attn",
        loader=None,
        origin="<stub>",
        submodule_search_locations=[],
        has_location=False,
        cached=None,
        loader_state=None
    )
else:
    flash_attn_module.__spec__ = spec
    flash_attn_module.__loader__ = MinimalLoader()

# Register the stub so later `import flash_attn` / `from flash_attn import ...` resolve to it.
sys.modules["flash_attn"] = flash_attn_module

# Use the 'spawn' start method for child processes (e.g. DataLoader workers) so
# they never inherit a forked CUDA context, which can deadlock. This must run
# before any CUDA work. `force=True` overrides a previously configured method;
# a RuntimeError is ignored as a best effort.
try:
    mp.set_start_method(method='spawn', force=True)
except RuntimeError:
    pass  # start method could not be changed here; keep the current one
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm
import numpy as np
import wandb
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

# ------------------ your repo wiring ------------------
# NOTE(review): Path("") stringifies to ".", so this inserts the current
# working directory (not this script's directory) at the front of sys.path.
# The `lit_gpt` / `inference_sudoku` imports below presumably rely on launching
# from the repo root; set REPO_ROOT explicitly if that is not the case.
REPO_ROOT = Path("")
sys.path.insert(0, str(REPO_ROOT))

from lit_gpt.model_cache import Config
from lit_gpt.diffmodel import TransEncoder
from inference_sudoku import load_mdm_state_dict
from dataclasses import dataclass
from datetime import timedelta


# ============================================================
# Distributed Training Utilities
# ============================================================
def setup_distributed():
    """
    Initialize torch.distributed from torchrun-style environment variables.

    When both RANK and WORLD_SIZE are set, this reads RANK / WORLD_SIZE /
    LOCAL_RANK and binds this process to CUDA device LOCAL_RANK. It then
    creates an NCCL process group with a 1800 s timeout and waits on a barrier.
    Otherwise the run is treated as single-process.

    Returns:
        (rank, world_size, local_rank); (0, 1, 0) when not launched distributed.
    """
    env = os.environ
    if not ('RANK' in env and 'WORLD_SIZE' in env):
        return 0, 1, 0

    rank, world_size, local_rank = (int(env[key]) for key in ('RANK', 'WORLD_SIZE', 'LOCAL_RANK'))
    torch.cuda.set_device(local_rank)  # bind the device before creating the NCCL group
    dist.init_process_group(backend='nccl', timeout=timedelta(seconds=1800))
    dist.barrier()  # wait until every rank has finished initialization
    return rank, world_size, local_rank

def cleanup_distributed():
    """Destroy the default process group if one was initialized; otherwise do nothing."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()

def is_main_process(rank=0):
    """Return True when `rank` is the main process's rank (0)."""
    main_rank = 0
    return rank == main_rank

def set_seed(seed: int = 1234):
    """Seed Python `random`, NumPy and PyTorch (CPU, plus every CUDA device when available)."""
    import numpy as np

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# -----------------------------
# Config
# -----------------------------
@dataclass
class TrainConfig:
    """Hyper-parameters for training the Sudoku order policy.

    NOTE(review): several fields are not read by the code visible in this
    chunk (e.g. train_order_policy_grpo_supervised takes its own epochs /
    lr_initial / lr_final arguments); confirm which ones the caller forwards.
    """
    # Evaluated once when the class is defined (import time), not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size: int = 64  # presumably the DataLoader batch size — confirm at call site
    # NOTE(review): train_order_policy_grpo_supervised builds AdamW with its own
    # lr_initial argument, not cfg.lr.
    lr: float = 2e-6
    epochs: int = 5
    max_steps: int = 81  # 81 = number of Sudoku cells, i.e. at most one fill per cell
    entropy_weight: float = 1e-3  # presumably an entropy-bonus coefficient — TODO confirm use
    grad_clip: float = 1.0  # presumably a gradient-norm clip value — TODO confirm use
    ema_lambda: float = 0.95  # presumably an EMA decay (e.g. reward baseline) — TODO confirm use
    log_every: int = 50
    # policy model
    # NOTE(review): FusionOrderPolicy accepts and ignores d_model / nhead (and any
    # other extra kwargs); these only sized the disabled transformer policy.
    d_model: int = 128
    nhead: int = 8
    num_layers: int = 4
    dim_feedforward: int = 256
    dropout: float = 0.1
    # tokenizer/mask
    MASK_ID: int = 32000  # vocab id used for empty cells (digit 0) and as pad_token_id
    use_amp_bf16: bool = True  # presumably forwarded to make_f_theta_from_model (bf16 autocast on CUDA)
    seed: int = 1234


# ============================================================
# 1) Your tokenizer builder (verbatim style)
# ============================================================
def build_tokenizer_and_digit_maps(MASK_ID: int = 32000):
    """
    Load the TinyLlama tokenizer and build Sudoku digit <-> vocab-id maps.

    Returns:
      tok:             AutoTokenizer with a '[PAD]' special token added and
                       pad_token_id forced to MASK_ID
      digit2id:        {1..9 -> last token id of " d", 0 -> MASK_ID}
                       (digit 0 = empty cell maps to the mask id, NOT the "0" token)
      id2digit:        inverse of digit2id {vocab id -> 0..9}
      digit_vocab_ids: LongTensor[10]; entry d is digit2id[d] (entry 0 is MASK_ID)
    """
    tok = AutoTokenizer.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
        padding_side="right",
        use_fast=True,
    )
    tok.add_special_tokens({'pad_token': '[PAD]'})
    tok.pad_token_id = MASK_ID

    # Digits are encoded with a leading space (" d") to match the original
    # encoding; the last produced id is taken as that digit's vocab id.
    digit2id = {}
    for digit in range(1, 10):
        encoded = tok.encode(f" {digit}", add_special_tokens=False)
        digit2id[digit] = encoded[-1]
    digit2id[0] = MASK_ID  # empty cell -> mask token

    id2digit = {vocab_id: digit for digit, vocab_id in digit2id.items()}
    ordered_ids = [digit2id[d] for d in range(10)]
    digit_vocab_ids = torch.tensor(ordered_ids, dtype=torch.long)
    return tok, digit2id, id2digit, digit_vocab_ids

# ============================================================
# 2) Dataset (reads your JSONL)
# ============================================================
class SudokuTokenDataset(Dataset):
    """
    Sudoku (puzzle, solution) pairs loaded from a JSONL file.

    Each item is a tuple:
      - init_tokens:     LongTensor [81] in {0..9}, 0 is empty/mask
      - solution_tokens: LongTensor [81] in {1..9}

    Accepted JSONL records (one JSON object per line; blank lines are skipped):
      {"puzzles": [...81 ints...],   "solutions": [...81 ints...]}
      {"puzzles": [[...81...], ...], "solutions": [[...81...], ...]}   # many boards
      {"puzzle":  [...81 ints...],   "solution":  [...81 ints...]}
    Records matching none of these, or whose boards are not flat 81-cell
    lists, are ignored.

    Args:
      jsonl_path: path to the JSONL file (read as UTF-8).
      limit: if > 0, keep only the first `limit` pairs (reading stops early).

    Raises:
      ValueError: if no valid pairs are found. A malformed non-blank line raises
        json.JSONDecodeError, which is also a ValueError.
    """

    def __init__(self, jsonl_path: str, limit: int = 0):
        # Store as lists, NOT tensors, to avoid shared memory issues with multiprocessing
        self.items: List[Tuple[List[int], List[int]]] = []
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                if limit and limit > 0 and len(self.items) >= limit:
                    break  # enough pairs collected; skip parsing the rest
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines (e.g. trailing newlines)
                ex = json.loads(line)
                if not isinstance(ex, dict):
                    continue
                self.items.extend(self._pairs_from_record(ex))
        if limit and limit > 0:
            self.items = self.items[:limit]
        if not self.items:
            raise ValueError("No (puzzle, solution) pairs found in JSONL.")

    @staticmethod
    def _is_board(x) -> bool:
        """True if `x` is a flat list of exactly 81 cells (no nested lists)."""
        return isinstance(x, list) and len(x) == 81 and not any(isinstance(v, list) for v in x)

    @classmethod
    def _pairs_from_record(cls, ex: dict) -> List[Tuple[List[int], List[int]]]:
        """Extract every valid (puzzle, solution) board pair from one parsed record."""
        if "puzzles" in ex and "solutions" in ex:
            puz, sol = ex["puzzles"], ex["solutions"]
            # Single flat board each; checking flatness keeps a list of exactly
            # 81 boards from being mistaken for one board.
            if cls._is_board(puz) and cls._is_board(sol):
                return [(puz, sol)]
            # Parallel lists of boards
            if isinstance(puz, list) and isinstance(sol, list) and len(puz) == len(sol):
                return [(p, s) for p, s in zip(puz, sol) if cls._is_board(p) and cls._is_board(s)]
            return []
        # (Optional) also support {"puzzle": [...], "solution": [...]}
        if "puzzle" in ex and "solution" in ex:
            p, s = ex["puzzle"], ex["solution"]
            if cls._is_board(p) and cls._is_board(s):
                return [(p, s)]
        return []

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        # Convert to tensors in worker process (avoids shared memory issues)
        puz, sol = self.items[i]
        return (torch.tensor(puz, dtype=torch.long),
                torch.tensor(sol, dtype=torch.long))



# ============================================================
# 3) Minimal Sudoku helpers
# ============================================================
def all_filled(tokens: torch.Tensor) -> torch.Tensor:
    """Per-board completion flag. tokens: [B,81] -> [B] bool, True when no cell is 0."""
    has_empty = (tokens == 0).any(dim=1)
    return ~has_empty

def mask_empty_positions(tokens: torch.Tensor) -> torch.Tensor:
    """Selectability mask: 1.0 where the cell is empty (0), 0.0 otherwise. Same shape as tokens."""
    is_empty = tokens.eq(0)
    return is_empty.to(torch.float32)

def tokens_fill_truth(tokens: torch.Tensor, pos: torch.Tensor, gt_digit: torch.Tensor) -> torch.Tensor:
    """Teacher forcing: return a copy of tokens with row b's cell pos[b] set to gt_digit[b]."""
    batch_rows = torch.arange(tokens.size(0), device=tokens.device)
    filled = tokens.clone()
    filled[batch_rows, pos] = gt_digit
    return filled

def gather_gt_digits(solution_tokens: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    """Return solution_tokens[b, pos[b]] for every row b (the ground-truth digit at each chosen cell)."""
    batch_rows = torch.arange(solution_tokens.size(0), device=solution_tokens.device)
    return solution_tokens[batch_rows, pos]


# ============================================================
# 4) Order policy (no legality features)
# ============================================================
class LambdaSchedule(nn.Module):
    """
    Learned time schedule λ(t) = sigmoid(MLP(t/T)), with values in (0, 1).

    Input:  t_norm, normalized time, shape [B,1] or [B]
    Output: λ(t), shape [B]
    """

    def __init__(self, hidden: int = 64):
        super().__init__()
        layers = [
            nn.Linear(1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        # Kept as an nn.Sequential named `net` so state_dict keys stay net.{0,2,4}.*
        self.net = nn.Sequential(*layers)

    def forward(self, t_norm: torch.Tensor) -> torch.Tensor:
        """
        t_norm: [B,1] or [B]
        returns: [B] in (0,1)
        """
        column = t_norm.unsqueeze(-1) if t_norm.dim() == 1 else t_norm  # [B,1]
        return torch.sigmoid(self.net(column)).squeeze(-1)              # [B]

# NOTE: Legacy transformer + top-k-confidence order policy, disabled by wrapping
# it in a bare string literal: the string is evaluated and discarded at import,
# so this class is never defined. The active λ(t)-based OrderPolicyWithFusion
# is the class definition that follows this literal.
'''
class OrderPolicyWithFusion(nn.Module):
    """
    tokens: [B,81] in {0..9}, 0=empty
    returns logits over 81 positions: [B,81]
    Architecture:
      tokens -> token/pos embeddings -> 1-layer Transformer -> pos_feat[a]
      f_theta(tokens) -> TopK digit probs per pos -> conf_feat[a]
      concat [pos_feat[a]; conf_feat[a]] -> 3-layer MLP -> score[a]
    """
    def __init__(self, d_model=256, nhead=8, mlp_hidden=256, topk=3, vocab_size=10, freeze_f_theta=True, use_conf=False):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.mlp_hidden = mlp_hidden
        self.topk = topk
        self.freeze_f_theta = freeze_f_theta
        self.encoder_chunk_size = 4096

        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb   = nn.Embedding(81, d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.encoder  = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.use_conf = use_conf

        self.fusion = nn.Sequential(
            nn.Linear(d_model + topk, mlp_hidden),
            nn.GELU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.GELU(),
            nn.Linear(mlp_hidden, 1),
        )

    def forward(self, tokens, f_theta):
        B = tokens.size(0)
        device = tokens.device
        
        if B == 0:
            return torch.zeros(
                (0, tokens.size(1)),
                device=device,
                dtype=self.fusion[-1].weight.dtype,
            )
        

        pos_indices = torch.arange(tokens.size(1), device=device)
        pos_emb = self.pos_emb(pos_indices).unsqueeze(0)

        h_chunks = []
        topk_chunks = [] if self.use_conf else None
        chunk_size = self.encoder_chunk_size

        for tokens_chunk in tokens.split(chunk_size, dim=0):
            if tokens_chunk.size(0) == 0:
                continue

            pos_emb_chunk = pos_emb.expand(tokens_chunk.size(0), -1, -1)
            x_chunk = self.token_emb(tokens_chunk) + pos_emb_chunk
            x_chunk = x_chunk.contiguous()
            h_chunk = self.encoder(x_chunk).contiguous()                 # [b,81,d]
            h_chunks.append(h_chunk)

            if self.use_conf:
                if self.freeze_f_theta:
                    with torch.no_grad():
                        with torch.cuda.amp.autocast(enabled=False):
                            logits_chunk = f_theta(tokens_chunk)
                else:
                    with torch.cuda.amp.autocast(enabled=False):
                        logits_chunk = f_theta(tokens_chunk)

                probs_chunk = torch.softmax(logits_chunk.float(), dim=-1)
                topk_vals_chunk = torch.topk(probs_chunk, k=self.topk, dim=-1).values
                topk_chunks.append(topk_vals_chunk.to(h_chunk.dtype))

        h = torch.cat(h_chunks, dim=0)

        if self.use_conf:
            topk_vals = torch.cat(topk_chunks, dim=0)
        else:
            topk_vals = torch.zeros(
                (B, tokens.size(1), self.topk),
                device=h.device,
                dtype=h.dtype,
            )


        # fuse & score
        feats  = torch.cat([h, topk_vals], dim=-1).contiguous()            # [B,81,d+K]
        scores = self.fusion(feats).squeeze(-1)                            # [B,81]
        return scores

'''

class OrderPolicyWithFusion(nn.Module):
    """
    Simple λ(t)-based order policy.

    - tokens: [B,81] in {0..9}, 0=empty
    - f_theta: frozen diffusion head giving logits over digits 1..9 ([B,81,9])
    - λ(t) is a scalar schedule (via LambdaSchedule) shared across positions at step t

    At each call:
      scores[b,i] = top1_prob[b,i] + λ_t[b] * entropy[b,i]

    The only trainable parameters live in `lambdanet`; f_theta runs under no_grad.
    """
    def __init__(self, f_theta, lambda_hidden: int = 64):
        super().__init__()
        self.f_theta = f_theta
        self.lambdanet = LambdaSchedule(hidden=lambda_hidden)

    def forward(self, tokens: torch.Tensor, t_norm: torch.Tensor) -> torch.Tensor:
        """
        tokens: [B,81] in {0..9}
        t_norm: normalized time t/T as a [B] or [B,1] tensor; a 0-d tensor or
                Python number is broadcast to every row.

        returns:
            scores: [B,81] float32 (logits over positions; non-empty cells are
                    NOT masked here, the caller masks them)
        """
        device = tokens.device

        # diffusion logits over digits 1..9 (f_theta is frozen)
        with torch.no_grad():
            logits_81x9 = self.f_theta(tokens)           # [B,81,9], likely bf16
        logits_81x9 = logits_81x9.float()                # work in fp32 for stability

        probs_81x9 = F.softmax(logits_81x9, dim=-1)      # [B,81,9]
        log_probs  = F.log_softmax(logits_81x9, dim=-1)  # [B,81,9]

        # top1 prob and entropy per cell
        top1 = probs_81x9.max(dim=-1).values             # [B,81]
        entropy = -(probs_81x9 * log_probs).sum(dim=-1)  # [B,81]

        # λ(t) per puzzle. Cast t_norm to the schedule's parameter dtype so that
        # float64 / integer inputs (or a bf16 policy) do not break nn.Linear.
        lam_dtype = next(self.lambdanet.parameters()).dtype
        t_norm = torch.as_tensor(t_norm, device=device).to(dtype=lam_dtype)
        if t_norm.dim() == 0:
            t_norm = t_norm.expand(tokens.size(0))        # scalar time -> [B]
        lam_t = self.lambdanet(t_norm)                   # [B]

        # scores for each position
        return top1 + lam_t.unsqueeze(1) * entropy       # [B,81]


class FusionOrderPolicy(OrderPolicyWithFusion):
    """
    Thin wrapper so we can call the fusion policy with stored f_theta, matching the GRPO codepath.

    Accepts the old constructor arguments (positional args, d_model, nhead,
    mlp_hidden, topk, and any other keyword) for backward compatibility, but
    ignores them. `f_theta` must be passed by keyword.

    Raises:
        ValueError: if f_theta is not provided.
    """
    def __init__(self, *args, f_theta=None, lambda_hidden=64, d_model=None, nhead=None, mlp_hidden=None, topk=None, **kwargs):
        # `f_theta` is a named keyword parameter, so it can never arrive via **kwargs;
        # the former kwargs.pop('f_theta') fallback was unreachable and has been removed.
        if f_theta is None:
            raise ValueError("FusionOrderPolicy requires f_theta")
        # Pass only the arguments that OrderPolicyWithFusion expects
        super().__init__(f_theta=f_theta, lambda_hidden=lambda_hidden)
        # Store old kwargs for reference (not used by the current implementation)
        self._old_kwargs = {'d_model': d_model, 'nhead': nhead, 'mlp_hidden': mlp_hidden, 'topk': topk}

    def forward(self, tokens: torch.Tensor, t_norm: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        tokens: [B,81] in {0..9}
        t_norm: [B] or [B,1] normalized time, or None (defaults to zeros, i.e.
                λ evaluated at t=0, for backward compatibility)
        returns: [B,81] position scores
        """
        if t_norm is None:
            t_norm = torch.zeros(tokens.size(0), device=tokens.device)
        return super().forward(tokens, t_norm=t_norm)


# ============================================================
# 5) Model-to-reward adapter (use your model directly)
# ============================================================
def sudoku_tokens_to_vocab_ids(tokens: torch.Tensor, digit2id: Dict[int,int], mask_id: int) -> torch.Tensor:
    """
    Translate Sudoku cells into model vocab ids.

    Cells equal to 1..9 become digit2id[d]; every other value (including 0,
    the empty cell) becomes mask_id.
    Returns: LongTensor with the same shape and device as tokens.
    """
    vocab_ids = tokens.new_full(tokens.shape, mask_id, dtype=torch.long)
    for digit in range(1, 10):
        vocab_ids.masked_fill_(tokens == digit, digit2id[digit])
    return vocab_ids

def make_f_theta_from_model(
    model: nn.Module,
    digit2id: Dict[int, int],
    mask_id: int = 32000,
    use_amp_bf16: bool = True,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """
    Wrap the diffusion model as f_theta(tokens) -> digit logits.

    Your model API: model(ids) returns a tuple; logits = model(ids)[0] with shape [B,81,Vocab].
    Returns f_theta(tokens[B,81 in {0..9}]) -> logits[B,81,9] over digits 1..9
    (column j holds digit j+1). f_theta runs under no_grad. When use_amp_bf16
    is set and the model is on CUDA, the forward runs under bf16 autocast, so
    the returned logits may be bf16.

    The model's device is captured once, here; move the model before building f_theta.
    """
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    digit_ids = torch.tensor([digit2id[d] for d in range(1, 10)], dtype=torch.long, device=device)
    use_autocast = use_amp_bf16 and device.type == "cuda"

    @torch.no_grad()
    def f_theta(tokens: torch.Tensor) -> torch.Tensor:
        # Map Sudoku tokens to vocab ids (0->mask_id, d->digit2id[d])
        ids = sudoku_tokens_to_vocab_ids(tokens.to(device), digit2id, mask_id)  # [B,81]

        # torch.autocast("cuda", ...) replaces the deprecated torch.cuda.amp.autocast
        # (same semantics, no FutureWarning); no_grad serves as a no-op context otherwise.
        ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16) if use_autocast else torch.no_grad()
        with ctx:
            out = model(ids)               # tuple/list or tensor
        logits = out[0] if isinstance(out, (list, tuple)) else out  # [B,81,V]
        if logits.dim() == 2 and logits.shape[0] == ids.shape[1]:
            logits = logits.unsqueeze(0)  # single item fallback

        # gather only digit columns 1..9
        logits = torch.as_tensor(logits, device=device)
        return logits.index_select(-1, digit_ids)      # [B,81,9]

    return f_theta


@torch.no_grad()
def reward_logprob(
    f_theta: Callable[[torch.Tensor], torch.Tensor],
    tokens: torch.Tensor,             # [B,81] in {0..9}
    pos_idx: torch.Tensor,            # [B] in 0..80
    gt_digit: torch.Tensor,           # [B] in 1..9
) -> torch.Tensor:
    """
    Per-step reward r_t = log p_theta(gt_digit | tokens, pos_idx).

    Args:
        f_theta: callable returning digit logits [B,81,9] (column j = digit j+1).
        tokens: current boards [B,81].
        pos_idx: chosen cell per row [B].
        gt_digit: ground-truth digit (1..9) at that cell [B].

    Returns:
        float32 tensor [B] of log-probabilities.
    """
    rows = torch.arange(tokens.size(0), device=tokens.device)
    logits_81x9 = f_theta(tokens)                        # [B,81,9], may be bf16
    # Upcast before log_softmax: in bf16 the log-probabilities keep only ~3
    # significant digits, and they are accumulated over many steps as rewards.
    cell_logits = logits_81x9[rows, pos_idx].float()     # [B,9]
    logp = F.log_softmax(cell_logits, dim=-1)
    return logp[rows, gt_digit - 1]                      # [B]

# ============================================================
# 6) Sample trajectory rollout with policy
# ============================================================
# Roll out B*N teacher-forced trajectories in parallel and record per-step rewards
@torch.no_grad()
def sample_trajectories_with_policy_batched(
    policy_frozen: nn.Module,
    f_theta,                               # used by reward_logprob
    init_tokens: torch.Tensor,             # [B,81], {0..9}
    solution_tokens: torch.Tensor,         # [B,81], {1..9}
    max_steps: int,
    n_traj_per_puzzle: int = 1,
    sample: bool = True,                   # True: sample from π_old; False: greedy
):
    """
    Parallel version of sample_trajectory_with_policy.

    Rolls out B*N teacher-forced trajectories (B puzzles, N per puzzle). At
    every step t, each unfinished trajectory does the following:
      1. Records its board BEFORE acting.
      2. Scores positions with policy_frozen(tokens, t_norm), where
         t_norm = t / T (the same time input the single-trajectory sampler uses).
      3. Picks an empty cell, either sampled from the masked scores or by argmax.
      4. Gets the step reward r_t = log p_theta(gt_digit | state, pos) via reward_logprob.
      5. Writes the ground-truth digit into the chosen cell.

    Returns dict with:
      states_pad:  [T, BN, 81]  tokens before each action (padded)
      actions_pad: [T, BN]      chosen positions; -1 wherever the row did not act
      rewards_pad: [T, BN]      per-step rewards; 0 wherever the row did not act
      done_steps:  [BN]         number of steps taken by each trajectory
                                (0 for a board with no empty cell; T if unfinished)
      B, N, BN, T: ints
    """
    device = init_tokens.device
    dtype_long = torch.long

    B = init_tokens.size(0)
    N = int(n_traj_per_puzzle)
    BN = B * N
    T = int(max_steps)

    # Repeat puzzles N times → big batch
    tokens = init_tokens.to(device, dtype_long).repeat_interleave(N, dim=0)      # [BN,81]
    sol    = solution_tokens.to(device, dtype_long).repeat_interleave(N, dim=0)  # [BN,81]

    # Buffers (padded)
    states_pad  = torch.zeros(T, BN, 81, dtype=tokens.dtype, device=device)
    actions_pad = torch.full((T, BN), -1, dtype=dtype_long, device=device)
    rewards_pad = torch.zeros(T, BN, dtype=torch.float32, device=device)

    done = torch.zeros(BN, dtype=torch.bool, device=device)
    # Lengths default to the full horizon T and are overwritten with the true
    # length when a row finishes. (A 0 "unset" sentinel would wrongly turn a
    # legitimate length of 0, i.e. a board with no empty cells, into T.)
    done_steps = torch.full((BN,), T, dtype=dtype_long, device=device)

    for t in range(T):
        # Record state BEFORE action
        states_pad[t] = tokens

        valid = mask_empty_positions(tokens)                 # [BN,81] float in {0,1}
        has_empty = valid.sum(dim=1) > 0                     # [BN] bool

        # Rows with nothing left to fill are finished (prevents sampling on all -inf rows)
        newly_done_pre = (~done) & (~has_empty)
        done_steps = torch.where(newly_done_pre, torch.full_like(done_steps, t), done_steps)
        done = done | newly_done_pre

        # Indices that can act this step
        idx_act = ((~done) & has_empty).nonzero(as_tuple=False).squeeze(-1)  # [M]

        # Padding fillers for rows that do not act (the documented -1 / 0 convention)
        a = torch.full((BN,), -1, dtype=dtype_long, device=device)
        r_t = torch.zeros(BN, dtype=rewards_pad.dtype, device=device)

        if idx_act.numel() > 0:
            tokens_act = tokens.index_select(0, idx_act)                           # [M,81]
            # Feed the normalized time so the policy's λ(t) actually depends on t
            t_norm_act = torch.full((idx_act.numel(), 1), t / T, device=device)    # [M,1]
            pos_logits = policy_frozen(tokens_act, t_norm_act)                     # [M,81]
            pos_logits = torch.nan_to_num(pos_logits, nan=0.0,
                                          posinf=1e30, neginf=-1e30)

            # Only empty cells are selectable
            valid_act = valid.index_select(0, idx_act).bool()                      # [M,81]
            masked_logits = pos_logits.masked_fill(~valid_act, float('-inf'))

            # Extra safety: if any row still becomes all -inf (shouldn't), set to zeros
            all_neg_inf = torch.isneginf(masked_logits).all(dim=1)                 # [M]
            if all_neg_inf.any():
                masked_logits[all_neg_inf] = 0.0

            # Sample (or greedy) on active rows
            if sample:
                a_act = torch.distributions.Categorical(logits=masked_logits).sample()  # [M]
            else:
                a_act = masked_logits.argmax(dim=-1)

            # Scatter actions back to full batch
            a[idx_act] = a_act

            # Gather gt digits + compute rewards on the pre-fill states
            gt_digit = sol[idx_act, a_act]                                         # [M]
            r_act = reward_logprob(f_theta=f_theta,
                                   tokens=tokens_act,
                                   pos_idx=a_act,
                                   gt_digit=gt_digit)                              # [M]
            r_t[idx_act] = r_act.to(r_t.dtype)

            # Teacher forcing: write the ground-truth digit at the chosen cell
            tokens = tokens.clone()
            tokens[idx_act, a_act] = gt_digit

        # Record action & reward (non-acting rows keep -1 / 0)
        actions_pad[t] = a
        rewards_pad[t] = r_t

        # Close out anyone who just finished after filling
        newly_done_post = (~done) & (mask_empty_positions(tokens).sum(dim=1) == 0)
        done_steps = torch.where(newly_done_post,
                                 torch.full_like(done_steps, t + 1), done_steps)
        done = done | newly_done_post

        if done.all():
            break

    return {
        "states_pad":  states_pad,    # [T, BN, 81] tokens BEFORE each action
        "actions_pad": actions_pad,   # [T, BN]     positions (-1 for padded)
        "rewards_pad": rewards_pad,   # [T, BN]     r_t (0 for padded)
        "done_steps":  done_steps,    # [BN]        #steps taken
        "B": B, "N": N, "BN": BN, "T": T,
    }

@torch.no_grad()
def sample_trajectory_with_policy(
    policy_frozen: nn.Module,
    f_theta: Callable[[torch.Tensor], torch.Tensor],
    init_tokens: torch.Tensor,         # [81], {0..9}
    solution_tokens: torch.Tensor,     # [81], {1..9}
    max_steps: int
):
    """
    Roll out one teacher-forced trajectory, always sampling from π_old.

    Each step scores positions with policy_frozen(board[None], [[t / max_steps]]).
    It then samples one empty cell, computes the reward for the ground-truth
    digit there with reward_logprob, and writes that digit into the board.
    Stops when no empty cell is left or after max_steps steps.

    Returns:
        states:       List[Tensor[81]]  tokens *before* each action
        actions:      List[int]         chosen positions
        rewards_step: List[float]       r_t = log p_theta(gt | state, pos)
        done_steps:   int               number of steps actually taken
    """
    device = init_tokens.device
    board = init_tokens.clone()
    states = []
    actions = []
    rewards_step = []

    for step in range(max_steps):
        selectable = mask_empty_positions(board.unsqueeze(0)).squeeze(0)   # [81], 1.0 where empty
        if not bool(selectable.any()):
            break

        batch_board = board.unsqueeze(0)                                    # [1,81]
        t_norm = torch.tensor([[step / max_steps]], device=board.device)   # [1,1]
        row_scores = policy_frozen(batch_board, t_norm)[0]                  # [81]
        row_scores = row_scores.masked_fill(selectable == 0, float('-inf'))

        # sample a cell from π_old
        pos = torch.distributions.Categorical(logits=row_scores).sample().item()  # int in [0..80]

        # reward from the diffusion head for the ground-truth digit at that cell
        digit = solution_tokens[pos].item()                                 # 1..9
        reward = reward_logprob(
            f_theta,
            batch_board,
            torch.tensor([pos], device=device),
            torch.tensor([digit], device=device),
        ).item()

        # record BEFORE filling
        states.append(board.clone())
        actions.append(pos)
        rewards_step.append(reward)

        # teacher forcing: fill ground-truth at the chosen cell
        board[pos] = digit

    return states, actions, rewards_step, len(actions)

# NOTE: Disabled helper kept for reference. The KL(π_new || π_old) function
# below is wrapped in a bare string literal, so `kl_new_old_at_state` is NOT
# defined at import time. It also calls the policies as policy(tokens),
# without a t_norm argument.
'''
#KL penalty for the new policy at a given state
def kl_new_old_at_state(policy_new: nn.Module,
                        policy_old: nn.Module,
                        state_tokens: torch.Tensor  # [81]
                        ) -> torch.Tensor:
    """
    Returns scalar KL( π_new(.|s) || π_old(.|s) ) over VALID indices only,
    computed from re-normalized distributions on that subset.
    Guards against degenerate cases.
    """
    with torch.no_grad():
        valid_mask = (state_tokens == 0)  # 1 where empty/valid to choose
    valid_idx = valid_mask.nonzero(as_tuple=False).view(-1)  # [K]

    # If nothing (or only 1) is valid, KL is zero by definition for our use.
    if valid_idx.numel() <= 1:
        return torch.zeros((), device=state_tokens.device)

    # Get logits and restrict to valid indices
    logits_new_full = policy_new(state_tokens.unsqueeze(0))[0]  # [81]
    logits_old_full = policy_old(state_tokens.unsqueeze(0))[0]  # [81]

    logits_new = logits_new_full.index_select(0, valid_idx)
    logits_old = logits_old_full.index_select(0, valid_idx)

    # Fresh log-softmax over ONLY the valid subset
    logp_new = F.log_softmax(logits_new, dim=-1)
    logp_old = F.log_softmax(logits_old, dim=-1)
    p_new = logp_new.exp()

    # KL over the valid subset
    kl = (p_new * (logp_new - logp_old)).sum()

    # Extra safety
    if not torch.isfinite(kl):
        kl = torch.nan_to_num(kl, nan=0.0, posinf=0.0, neginf=0.0)

    return kl
'''

# ============================================================
# 7) Training loop (self-referenced GRPO-style weighted-NLL updates)
# ============================================================
def _group_normalized_advantages(R: torch.Tensor, valid_traj: torch.Tensor) -> torch.Tensor:
    """Z-score trajectory returns within each puzzle's group of rollouts.

    Args:
        R: [B, N] per-trajectory returns.
        valid_traj: [B, N] bool mask, True for trajectories that took >= 1 step.

    Returns:
        [B, N] detached advantages. Entries of invalid trajectories are 0. The
        per-group std is clamped to 1e-8 to avoid division by zero.
    """
    n_valid = valid_traj.sum(dim=1).clamp_min(1)                                   # [B]
    baseline = (R.masked_fill(~valid_traj, 0).sum(dim=1) / n_valid).unsqueeze(1)   # [B, 1]
    centered = (R - baseline).masked_fill(~valid_traj, 0.0)
    # Re-centre over the valid entries (a no-op up to rounding) before scaling.
    mean = (centered.sum(dim=1) / n_valid).unsqueeze(1)
    var = ((centered - mean) ** 2).masked_fill(~valid_traj, 0).sum(dim=1) / n_valid
    std = var.sqrt().clamp_min(1e-8).unsqueeze(1)
    return ((centered - mean) / std).masked_fill(~valid_traj, 0.0).detach()


def train_order_policy_grpo_supervised(
    policy: nn.Module,
    f_theta: Callable[[torch.Tensor], torch.Tensor],
    train_loader: DataLoader,
    cfg: TrainConfig,
    n_traj_per_puzzle: int = 2,     # N trajectories per puzzle
    alpha: float = 1.0,             # unused (kept for interface compatibility)
    beta_kl: float = 0.05,          # unused (no KL term is currently computed)
    epochs: int = 3,
    rank: int = 0,
    world_size: int = 1,
    log_every: int = 10,
    eval_every: int = 200,
    val_loader: Optional[DataLoader] = None,
    max_eval_batches: Optional[int] = 10,
    save_path: str = "workdir/GRPO_order/policy_ckpt.pt",
    lr_warmup_epochs: int = 10,     # Use initial LR for first N epochs
    lr_initial: float = 1e-3,       # Initial LR
    lr_final: float = 1e-4,         # LR after warmup
):
    """Train the order policy with a GRPO-style group-normalized policy gradient.

    Each epoch:
      - pi_old := frozen deepcopy of the current policy, used only for rollouts.
      - For each batch, ``n_traj_per_puzzle`` trajectories per puzzle are sampled
        with pi_old. A trajectory's return R is its summed per-step reward,
        clamped to [-50, 0].
      - Advantages are R z-scored within each puzzle's group of trajectories.
      - Loss = -sum_k A_k * mean_t log pi(a_t | s_t), averaged over puzzles that
        have at least one non-empty trajectory.

    The LR multiplier is 1.0 for the first ``lr_warmup_epochs`` epochs and
    ``lr_final / lr_initial`` afterwards. The scheduler is stepped once per epoch.

    On the main process:
      - Metrics are printed every ``log_every`` batches.
      - Every ``eval_every`` batches a checkpoint is written. When ``val_loader``
        is given, a validation pass also runs and the best checkpoint is tracked.
      - A value <= 0 disables the respective action.

    ``alpha``, ``beta_kl`` and ``world_size`` are accepted for interface
    compatibility but are not used by the current loss.
    """
    device = torch.device(cfg.device)

    # Unwrapped module for deepcopy / state_dict / evaluation (strip DDP if present).
    model_without_ddp = policy.module if isinstance(policy, DDP) else policy

    policy.to(device)
    # Base LR is lr_initial; the LambdaLR below rescales it per epoch.
    opt = torch.optim.AdamW(policy.parameters(), lr=lr_initial, weight_decay=0.01)

    def lr_lambda(epoch):
        return 1.0 if epoch < lr_warmup_epochs else lr_final / lr_initial

    scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)

    save_dir = Path(save_path).parent
    ckpt_stem = Path(save_path).stem
    global_step = 0
    best_val_metric = float('-inf')

    for epoch in range(epochs):
        # A DistributedSampler only reshuffles between epochs when told the epoch number.
        sampler = getattr(train_loader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)

        # Freeze a rollout snapshot pi_old (copied from the unwrapped model).
        policy_old = deepcopy(model_without_ddp).eval().to(device)
        for p in policy_old.parameters():
            p.requires_grad_(False)

        policy.train()
        for batch_idx, (init_tokens, solution_tokens) in enumerate(train_loader):
            # Debug: detect puzzles that are already filled before rollout
            blanks_per_puzzle = (init_tokens == 0).sum(dim=1)
            if blanks_per_puzzle.min().item() == 0 and is_main_process(rank):
                print(
                    f"[debug] batch {batch_idx}: puzzles with no blanks",
                    blanks_per_puzzle.tolist(),
                )
            if batch_idx % 10 == 0 and is_main_process(rank):
                print(f"[Epoch {epoch}] Processing batch {batch_idx}...")
            global_step += 1
            batch_start = time.time()

            init_tokens = init_tokens.to(device).long()       # [B,81]
            solution_tokens = solution_tokens.to(device).long()
            dev = init_tokens.device
            with torch.no_grad():
                roll = sample_trajectories_with_policy_batched(
                    policy_frozen=policy_old,
                    f_theta=f_theta,
                    init_tokens=init_tokens,          # [B,81]
                    solution_tokens=solution_tokens,  # [B,81]
                    max_steps=cfg.max_steps,
                    n_traj_per_puzzle=n_traj_per_puzzle,
                    sample=True,                      # sampling for on-policy training
                )
            B, N = roll["B"], roll["N"]
            T, BN = roll["T"], roll["BN"]
            lens = roll["done_steps"]                                  # [BN]

            # --- Per-trajectory returns and group-normalized advantages ---
            R = torch.clamp(roll["rewards_pad"].sum(dim=0).view(B, N), -50, 0)   # [B, N]
            valid_traj = lens.view(B, N) > 0
            advantages = _group_normalized_advantages(R, valid_traj)

            # --- Log-probs of the sampled actions under the current policy ---
            alive = torch.arange(T, device=dev)[:, None] < lens[None, :]         # [T, BN]

            states_flat = roll["states_pad"].reshape(T * BN, 81).long()
            logits_flat = policy(states_flat)                          # [T*BN, 81]

            valid_flat = mask_empty_positions(states_flat).bool()
            # Padding steps past a trajectory's end may hold fully-filled boards. Leave such
            # rows unmasked so log_softmax stays finite; their values are zeroed via `alive`.
            valid_flat = valid_flat | ~valid_flat.any(dim=-1, keepdim=True)
            logits_flat = logits_flat.masked_fill(~valid_flat, float('-inf'))

            logp_flat = torch.log_softmax(logits_flat, dim=-1)
            # Clamp padded action ids into range; padded steps are discarded via `alive`.
            a_safe = roll["actions_pad"].reshape(T * BN).clamp(min=0)
            logp_taken = logp_flat[torch.arange(T * BN, device=dev), a_safe].view(T, BN)
            logp_taken = torch.where(alive, logp_taken, torch.zeros_like(logp_taken))

            mean_logp = (logp_taken.sum(dim=0) / lens.clamp_min(1)).view(B, N)   # [B, N]
            mean_logp = mean_logp.masked_fill(~valid_traj, 0.0)

            # --- GRPO loss: advantage-weighted mean log-prob per trajectory ---
            loss_per_puzzle = -(advantages * mean_logp).sum(dim=1)     # [B]
            valid_puzzles = valid_traj.any(dim=1)                      # [B]
            loss = loss_per_puzzle[valid_puzzles].sum() / valid_puzzles.sum().clamp_min(1)

            count_traj = int(valid_traj.sum().item())
            count_steps = float(lens.view(B, N)[valid_traj].sum().item())

            opt.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(policy.parameters(), cfg.grad_clip)
            opt.step()

            is_main = is_main_process(rank)

            # --- logging (main process only, to avoid duplicate output) ---
            if is_main and log_every > 0 and batch_idx % log_every == 0:
                bt = time.time() - batch_start
                avg_R = R.mean().item() if B > 0 else 0.0
                current_lr = opt.param_groups[0]['lr']
                print(
                    f"[E{epoch} B{batch_idx}] "
                    f"loss={loss.item():.4f} "
                    f"| avg_R={avg_R:.4f} steps/tra={count_steps/max(1,count_traj):.1f} "
                    f"| Ntraj={n_traj_per_puzzle} | lr={current_lr:.2e} | time={bt:.2f}s"
                )
                if wandb.run is not None:
                    wandb.log({
                        "train/loss": loss.item(),
                        "train/loss_nll": loss.item(),
                        "train/avg_R": avg_R,
                        # bookkeeping
                        "train/epoch": epoch,
                        "train/batch_idx": batch_idx,
                    }, step=global_step)

            # --- checkpoint (+ validation) every `eval_every` batches, main process only ---
            if is_main and eval_every > 0 and batch_idx % eval_every == 0:
                save_dir.mkdir(parents=True, exist_ok=True)
                epoch_ckpt_path = str(save_dir / f"{ckpt_stem}_epoch{epoch:02d}_batch{batch_idx:04d}.pt")

                checkpoint = {
                    'epoch': epoch,
                    'global_step': global_step,
                    'model_state_dict': model_without_ddp.state_dict(),
                    'optimizer_state_dict': opt.state_dict(),
                    'config': {
                        'd_model': cfg.d_model,
                        'nhead': cfg.nhead,
                        'num_layers': cfg.num_layers,
                        'dim_feedforward': cfg.dim_feedforward,
                        'dropout': cfg.dropout,
                        'mask_id': cfg.MASK_ID,
                    },
                }

                if val_loader is not None:
                    policy.eval()
                    val_metrics = evaluate_order_policy(model_without_ddp, f_theta, val_loader, cfg, max_eval_batches=max_eval_batches)
                    policy.train()

                    checkpoint['val_metrics'] = val_metrics

                    print(f"\n[Epoch {epoch}] Validation metrics: {val_metrics}")
                    if wandb.run is not None:
                        wandb.log({
                            "val/avg_logp_per_step": val_metrics['avg_logp_per_step'],
                            "val/total_reward": val_metrics['total_reward'],
                            "epoch": epoch,
                        }, step=global_step)

                    # Track best model by mean per-step log-prob reward.
                    current_metric = val_metrics['avg_logp_per_step']
                    if current_metric > best_val_metric:
                        best_val_metric = current_metric
                        best_ckpt_path = str(save_dir / f"{ckpt_stem}_best.pt")
                        torch.save(checkpoint, best_ckpt_path)
                        print(f"✓ New best model! Saved to {best_ckpt_path} (val_logp={current_metric:.4f})")
                        if wandb.run is not None:
                            wandb.save(best_ckpt_path)

                torch.save(checkpoint, epoch_ckpt_path)
                print(f"✓ Saved epoch {epoch} checkpoint to {epoch_ckpt_path}")
                if wandb.run is not None:
                    wandb.save(epoch_ckpt_path)

        # Step learning rate scheduler at end of each epoch
        scheduler.step()
        current_lr = opt.param_groups[0]['lr']
        if is_main_process(rank):
            print(f"[Epoch {epoch} end] Learning rate: {current_lr:.2e}")
            if wandb.run is not None:
                wandb.log({"train/lr": current_lr}, step=global_step)



# ============================================================
# 8) Evaluation (teacher forcing; greedy position)
# ============================================================
@torch.no_grad()
def evaluate_order_policy(
    policy: nn.Module,
    f_theta: Callable[[torch.Tensor], torch.Tensor],
    loader: DataLoader,
    cfg: TrainConfig,
    max_eval_batches: Optional[int] = None,
) -> Dict[str, float]:
    """Greedy, teacher-forced rollout of the order policy.

    At each step the policy picks the highest-scoring empty position (argmax over
    positions allowed by ``mask_empty_positions``). The reward ``reward_logprob``
    is recorded for that position, and the ground-truth digit is then filled in.
    Only steps taken by puzzles that were not yet complete are counted.

    Args:
        policy: order policy returning [B, 81] position logits.
        f_theta: digit model passed through to ``reward_logprob``.
        loader: yields (init_tokens, solution_tokens) batches.
        cfg: provides ``device`` and ``max_steps``.
        max_eval_batches: stop after this many batches (None = whole loader).

    Returns:
        {"avg_logp_per_step": mean counted reward per step,
         "total_reward": sum of counted rewards}.

    Note: leaves ``policy`` in eval mode; callers that keep training must call
    ``.train()`` afterwards.
    """
    policy.eval()
    device = torch.device(cfg.device)
    total_logp, total_steps = 0.0, 0.0

    for batch_idx, (init_tokens, solution_tokens) in enumerate(loader):
        # Limit number of batches for faster evaluation during training
        if max_eval_batches is not None and batch_idx >= max_eval_batches:
            break

        tokens = init_tokens.to(device).long()
        solution_tokens = solution_tokens.to(device).long()
        done = all_filled(tokens)

        for _ in range(cfg.max_steps):
            if done.all():
                break

            pos_logits = policy(tokens)                    # [B,81]
            valid = mask_empty_positions(tokens)
            masked_logits = pos_logits.masked_fill(valid == 0, float('-inf'))
            action = masked_logits.argmax(dim=-1)          # greedy

            gt_digit = gather_gt_digits(solution_tokens, action)
            r_t = reward_logprob(f_theta, tokens, action, gt_digit)

            # Select rather than multiply by a 0/1 mask: for already-finished puzzles the
            # argmax lands on a filled cell, and a non-finite r_t there times 0 is NaN.
            active = ~done
            total_logp += torch.where(active, r_t, torch.zeros_like(r_t)).sum().item()
            total_steps += active.sum().item()

            tokens = tokens_fill_truth(tokens, action, gt_digit)
            done = all_filled(tokens)

    return {"avg_logp_per_step": total_logp / max(total_steps, 1.0), "total_reward": total_logp}


def _build_dataloaders(ds, args, rank: int, world_size: int):
    """Build the train (and optional validation) DataLoaders for ``ds``.

    When ``args.val_split > 0`` the dataset is split with a generator seeded by
    ``args.seed``, so every rank derives the same partition. With
    ``world_size > 1`` each loader uses a ``DistributedSampler``.

    Returns:
        (train_loader, val_loader); ``val_loader`` is None when no split is made.
    """
    if args.val_split > 0 and len(ds) > 0:
        val_size = int(len(ds) * args.val_split)
        train_size = len(ds) - val_size
        # Seeded so all ranks compute the identical split. An unseeded split differs per
        # process and leaks one rank's validation samples into another rank's training data.
        split_gen = torch.Generator().manual_seed(args.seed)
        train_ds, val_ds = torch.utils.data.random_split(
            ds, [train_size, val_size], generator=split_gen
        )

        # Keep workers at 0 to avoid shared-memory issues.
        num_workers = 0
        if world_size > 1:
            train_sampler = DistributedSampler(train_ds, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True)
            val_sampler = DistributedSampler(val_ds, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True)
            train_loader = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler,
                                      drop_last=True, num_workers=num_workers, pin_memory=True, persistent_workers=False)
            val_loader = DataLoader(val_ds, batch_size=args.batch_size, sampler=val_sampler,
                                    drop_last=True, num_workers=num_workers, pin_memory=True, persistent_workers=False)
        else:
            train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                                      drop_last=False, num_workers=num_workers, pin_memory=True)
            val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                                    drop_last=False, num_workers=num_workers, pin_memory=True)

        if is_main_process(rank):
            print(f"Train: {train_size} samples, Val: {val_size} samples")
            print(f"DataLoader: num_workers={num_workers}, pin_memory=True")
        return train_loader, val_loader

    num_workers = min(4, max(1, args.batch_size // 8)) if world_size == 1 else min(4, max(2, args.batch_size // 8))
    if world_size > 1:
        train_sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True)
        train_loader = DataLoader(ds, batch_size=args.batch_size, sampler=train_sampler,
                                  drop_last=True, num_workers=num_workers, pin_memory=True, persistent_workers=False)
    else:
        train_loader = DataLoader(ds, batch_size=args.batch_size, shuffle=True,
                                  drop_last=False, num_workers=num_workers, pin_memory=True)
    if is_main_process(rank):
        print(f"Train: {len(ds)} samples (no validation split)")
    return train_loader, None


def main():
    """CLI entry point: load the frozen digit model, then train and evaluate the order policy.

    Steps:
      1. Set up distributed training, seeding, and (main rank only) wandb.
      2. Build the data loaders before moving the model to the GPU.
      3. Train with periodic checkpoints and run a final evaluation.
      4. Save the final policy checkpoint on the main rank.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--jsonl", type=str, required=True, help="Path to sudoku.jsonl (lines with puzzles/solutions)")
    parser.add_argument("--model", type=int, default=1028, help="LLaMA M-size for Config, e.g., 8 for Diff_LLaMA_8M")
    parser.add_argument("--ckpt", type=str, help="Checkpoint path")
    parser.add_argument("--order_ckpt", type=str, help="Optional pretrained order policy checkpoint")
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-3, help="Base learning rate (overridden by lr_initial if provided)")
    parser.add_argument("--lr_initial", type=float, default=1e-3, help="Initial learning rate for first N epochs")
    parser.add_argument("--lr_final", type=float, default=1e-4, help="Final learning rate after warmup epochs")
    parser.add_argument("--lr_warmup_epochs", type=int, default=14, help="Number of epochs to use initial LR before switching to final LR")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--limit", type=int, default=0, help="Optional: cap number of (puzzle,solution) pairs")
    parser.add_argument("--save_path", type=str, default="GRPO_ckpt_adt_lr/policy_ckpt_n8.pt", help="Path to save trained policy checkpoint")
    parser.add_argument("--val_split", type=float, default=0.7, help="Validation split ratio (0.0 to disable)")
    parser.add_argument("--eval_every", type=int, default=1, help="Evaluate every N batches")
    parser.add_argument("--max_eval_batches", type=int, default=20,
                        help="Max number of validation batches per in-training evaluation (default=20 for speed)")
    parser.add_argument("--wandb_project", type=str, default="sudoku-rl-order", help="Wandb project name")
    parser.add_argument("--wandb_name", type=str, default=None, help="Wandb run name (optional)")
    parser.add_argument("--wandb_mode", type=str, default="online", choices=["online", "offline", "disabled"],
                        help="Wandb mode: online, offline, or disabled")
    args = parser.parse_args()

    # Setup distributed training
    rank, world_size, local_rank = setup_distributed()

    # Apply --seed (it was previously parsed but never used). Every rank uses the same
    # seed, so all ranks start from the same RNG state.
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Set device based on local rank
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(local_rank)
        print(f"[Rank {rank}/{world_size}] Using device: {device}, GPU: {torch.cuda.get_device_name(local_rank)}")
    else:
        device = torch.device("cpu")

    # Initialize wandb (only on main process). On failure, wandb.run stays None and
    # every later wandb call below is skipped.
    if is_main_process(rank):
        # Set wandb timeout environment variables to avoid hanging
        os.environ.setdefault("WANDB__SERVICE_WAIT", "300")  # 5 minute timeout
        os.environ.setdefault("WANDB_INIT_TIMEOUT", "60")    # 1 minute init timeout

        try:
            wandb.init(
                project=args.wandb_project,
                name=args.wandb_name,
                mode=args.wandb_mode,
                settings=wandb.Settings(
                    _service_wait=300,  # 5 minute timeout for service
                    start_method="thread",  # Use thread instead of fork
                ),
                config={
                    "model_size": args.model,
                    "epochs": args.epochs,
                    "batch_size": args.batch_size,
                    "lr": args.lr,
                    "seed": args.seed,
                    "val_split": args.val_split,
                    "eval_every": args.eval_every,
                    "checkpoint": args.ckpt,
                    "world_size": world_size,
                }
            )
        except Exception as e:
            print(f"WARNING: Failed to initialize wandb: {e}")
            print("Continuing without wandb logging...")
            args.wandb_mode = "disabled"

    if is_main_process(rank):
        print(f"Running on {world_size} GPU(s)")
        if world_size == 1:
            print("⚠️  WARNING: Running in SINGLE GPU mode!")
            print("   If you intended multi-GPU, make sure to use 'torchrun --nproc_per_node=N'")
        else:
            print(f"✓ Multi-GPU mode active with {world_size} processes")

    # --- Tokenizer (NO CUDA yet) ---
    tok, digit2id, id2digit, digit_vocab_ids = build_tokenizer_and_digit_maps(MASK_ID=32000)
    if is_main_process(rank):
        print("digit2id (0..9):", digit2id, "| MASK_ID:", tok.pad_token_id)

    # --- Data (Create DataLoaders BEFORE loading model to GPU!) ---
    ds = SudokuTokenDataset(args.jsonl, limit=args.limit)
    train_loader, val_loader = _build_dataloaders(ds, args, rank, world_size)

    # CRITICAL: Load model AFTER DataLoader creation to avoid fork-after-CUDA issues
    if is_main_process(rank):
        print("\n⚠️  Loading model AFTER DataLoader creation (to avoid CUDA fork issues)...")

    config = Config.from_name(f"Diff_LLaMA_{args.model}M")
    model  = TransEncoder(config).to(device)
    state  = load_mdm_state_dict(args.ckpt)
    model.load_state_dict(state, strict=False)
    model.eval()

    if is_main_process(rank):
        print(f"✓ Model loaded on {device}")

    # --- f_theta and policy ---
    cfg = TrainConfig(MASK_ID=tok.pad_token_id, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, seed=args.seed)
    f_theta = make_f_theta_from_model(model, digit2id, mask_id=cfg.MASK_ID, use_amp_bf16=True)
    # NOTE(review): the policy is not wrapped in DistributedDataParallel, so with
    # world_size > 1 each rank updates its own copy without gradient synchronisation.
    # Confirm this is intended.
    policy = FusionOrderPolicy(d_model=256, nhead=8, mlp_hidden=256, topk=2, f_theta=f_theta).to(device)
    if args.order_ckpt:
        if is_main_process(rank):
            print(f"Loading pretrained order policy from {args.order_ckpt}")
        order_ckpt = torch.load(args.order_ckpt, map_location=device)
        state_dict = order_ckpt.get("model_state_dict", order_ckpt)
        policy.load_state_dict(state_dict, strict=False)

    # --- Train with periodic evaluation ---
    train_order_policy_grpo_supervised(
        policy=policy,
        f_theta=f_theta,
        train_loader=train_loader,
        cfg=cfg,
        n_traj_per_puzzle=7,   # try 2–8
        alpha=0.5,             # try 0.5–2.0; smaller = sharper weighting
        beta_kl=0.05,          # try 0.01–0.2
        epochs=args.epochs,
        rank=rank,
        world_size=world_size,
        log_every=args.eval_every,
        eval_every=5*args.eval_every,
        val_loader=val_loader,
        max_eval_batches=args.max_eval_batches,
        save_path=args.save_path,
        lr_warmup_epochs=args.lr_warmup_epochs,
        lr_initial=args.lr_initial,
        lr_final=args.lr_final,
    )

    # Get the unwrapped model for evaluation and saving
    model_to_save = policy.module if isinstance(policy, DDP) else policy

    # Final evaluation runs on all ranks. With a DistributedSampler each rank evaluates its
    # own shard, so the metrics printed/logged below come from the main rank's shard only.
    if val_loader is not None:
        metrics = evaluate_order_policy(model_to_save, f_theta, val_loader, cfg)
        if is_main_process(rank):
            print("\n=== Final Evaluation ===")
            print("Final validation metrics:", metrics)
            if wandb.run is not None:
                wandb.log({f"final_val/{k}": v for k, v in metrics.items()})
    else:
        # Evaluate on training set if no validation set
        metrics = evaluate_order_policy(model_to_save, f_theta, train_loader, cfg)
        if is_main_process(rank):
            print("Final training metrics:", metrics)
            if wandb.run is not None:
                wandb.log({f"final_train/{k}": v for k, v in metrics.items()})

    # Save final checkpoint (only on main process)
    if is_main_process(rank):
        # The directory may not exist yet, e.g. if no in-training checkpoint was written.
        Path(args.save_path).parent.mkdir(parents=True, exist_ok=True)
        torch.save({
            'epoch': args.epochs,
            'model_state_dict': model_to_save.state_dict(),
            'config': {
                'd_model': cfg.d_model,
                'nhead': cfg.nhead,
                'num_layers': cfg.num_layers,
                'dim_feedforward': cfg.dim_feedforward,
                'dropout': cfg.dropout,
                'mask_id': cfg.MASK_ID,
            },
            'metrics': metrics,
        }, args.save_path)
        print(f"Saved policy checkpoint to {args.save_path}")

        if wandb.run is not None:
            wandb.save(args.save_path)
            print(f"Uploaded checkpoint to wandb")
            wandb.finish()

    # Clean up distributed training
    cleanup_distributed()


# Script entry point: run main() only when executed directly, not when imported.
if __name__ == "__main__":
    main()