#!/usr/bin/env python3
# ============================================================
# Train diffusion model on Sudoku using a FIXED, GREEDY order policy
# Loss = sum_t - log p_theta(gt_digit | state, pos_t)
# - No RL
# - No sampling from diffusion
# - Order policy is frozen and executed greedily
# - Teacher forcing transitions
#
# DDP friendly (torchrun)
# ============================================================

import os, sys, math, time, json, random, argparse
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Callable
from datetime import timedelta
from contextlib import nullcontext
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer
import wandb

# ------------------------------
# Safer attention backend
# ------------------------------
# When CUDA is available, switch scaled-dot-product attention to the plain
# "math" kernel: the flash and memory-efficient kernels are disabled and the
# math kernel is enabled.
if torch.cuda.is_available():
    try:
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)
    except Exception:
        # Best effort: if any toggle fails, the remaining ones are skipped and
        # the script continues with whatever backend selection is in effect.
        pass

# ------------------------------
# Optional flash_attn stub (kept minimal)
# ------------------------------
import types, importlib.util, importlib.abc

def _fa_flash_stub(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, **kwargs):
    if hasattr(F, "scaled_dot_product_attention"):
        return F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=dropout_p if torch.is_grad_enabled() else 0.0,
            is_causal=causal,
        )
    d = q.size(-1)
    scale = (1.0 / math.sqrt(d)) if softmax_scale is None else softmax_scale
    scores = (q @ k.transpose(-2, -1)) * scale
    if causal:
        mask = torch.ones_like(scores, dtype=torch.bool).tril()
        scores = scores.masked_fill(~mask, float("-inf"))
    P = scores.softmax(dim=-1)
    if torch.is_grad_enabled() and dropout_p > 0:
        P = F.dropout(P, p=dropout_p)
    return P @ v

# Publish the stub as the module "flash_attn" unless a module of that name is
# already present in sys.modules. NOTE: the check looks only at sys.modules,
# so an installed flash_attn package that has not been imported yet is
# shadowed by this stub.
if "flash_attn" not in sys.modules:
    flash_attn_module = types.ModuleType("flash_attn")
    flash_attn_module.flash_attn_func = _fa_flash_stub
    # Loader whose hooks do nothing; it exists only so a module spec can be built.
    class _MinimalLoader(importlib.abc.Loader):
        def create_module(self, spec): return None
        def exec_module(self, module): pass
    spec = importlib.util.spec_from_loader("flash_attn", _MinimalLoader(), origin="<stub>")
    # Attach the spec to the synthetic module before registering it.
    flash_attn_module.__spec__ = spec
    # From here on, `import flash_attn` resolves to the stub module.
    sys.modules["flash_attn"] = flash_attn_module

# ------------------------------
# Multiprocessing start method
# ------------------------------
# Select the "spawn" start method for torch.multiprocessing; force=True
# overrides a start method that has already been set.
try:
    mp.set_start_method("spawn", force=True)
except RuntimeError:
    # Keep whatever start method is in effect if it cannot be changed.
    pass

# ------------------ your repo wiring ------------------
# Path('') stringifies to ".", so the current working directory is placed
# first on sys.path; the project imports that follow are resolved from there.
REPO_ROOT = Path('')
sys.path.insert(0, str(REPO_ROOT))

from lit_gpt.model_cache import Config
from lit_gpt.diffmodel import TransEncoder
from inference_sudoku import load_mdm_state_dict


# ============================================================
# Distributed utils
# ============================================================
def setup_distributed():
    """Initialise the NCCL process group from launcher environment variables.

    Returns:
        ``(rank, world_size, local_rank)``. When ``RANK`` or ``WORLD_SIZE`` is
        absent from the environment, nothing is initialised and ``(0, 1, 0)``
        is returned.

    Raises:
        KeyError: if ``RANK`` and ``WORLD_SIZE`` are set but ``LOCAL_RANK`` is not.
    """
    env = os.environ
    launched_distributed = "RANK" in env and "WORLD_SIZE" in env
    if not launched_distributed:
        return 0, 1, 0

    rank, world_size, local_rank = (
        int(env[name]) for name in ("RANK", "WORLD_SIZE", "LOCAL_RANK")
    )
    # Bind this process to its GPU before creating the process group.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", timeout=timedelta(seconds=1800))
    # Wait until every rank has joined before returning.
    dist.barrier()
    return rank, world_size, local_rank

def cleanup_distributed():
    """Destroy the default process group if one has been initialised."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()

def is_main_process(rank=0):
    """Return whether *rank* is rank 0, the process used for logging and saving."""
    main_rank = 0
    return rank == main_rank

def set_seed(seed: int = 1234):
    """Seed Python's ``random`` module and PyTorch (CPU, plus all CUDA devices if present)."""
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# ============================================================
# Config
# ============================================================
@dataclass
class TrainConfig:
    """Settings bundle passed to the training loop."""
    # Device string; the training loop turns it into a torch.device.
    device: str = "cuda"
    # Per-process batch size (recorded here; the DataLoader gets its own value).
    batch_size: int = 32
    # Learning rate (not read by the training loop visible in this file).
    lr: float = 1e-5
    # Number of epochs (the training loop takes its own `epochs` argument).
    epochs: int = 1
    # Upper bound on fill steps per batch; also the divisor used to normalise t.
    max_steps: int = 81
    # Max gradient norm handed to clip_grad_norm_.
    grad_clip: float = 1.0
    # Vocabulary id that stands for an empty Sudoku cell in model inputs.
    MASK_ID: int = 32000
    # Whether bf16 autocast is requested (recorded here; the model adapters
    # receive their own flag).
    use_amp_bf16: bool = True
    # Base random seed.
    seed: int = 1234


# ============================================================
# Tokenizer + digit maps
# ============================================================
def build_tokenizer_and_digit_maps(MASK_ID: int = 32000):
    """Load the TinyLlama tokenizer and derive Sudoku digit <-> vocab-id tables.

    Each digit 1-9 maps to the last token id produced by encoding ``" {d}"``
    (leading space, no special tokens). Digit 0 (empty cell) maps to
    ``MASK_ID``, which is also installed as the tokenizer's pad token id.

    Returns:
        ``(tok, digit2id, id2digit, digit_vocab_ids)`` where
        ``digit_vocab_ids`` is a long tensor holding the vocab id of digits
        0..9 in that order.
    """
    tok = AutoTokenizer.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
        padding_side="right",
        use_fast=True,
    )
    tok.add_special_tokens({"pad_token": "[PAD]"})
    tok.pad_token_id = MASK_ID

    # Digits 1..9 first, then 0, so the inverse map is built in the same order.
    digit2id = {}
    for digit in range(1, 10):
        encoded = tok.encode(f" {digit}", add_special_tokens=False)
        digit2id[digit] = encoded[-1]
    digit2id[0] = MASK_ID

    id2digit = {}
    for digit, vocab_id in digit2id.items():
        id2digit[vocab_id] = digit

    ordered_ids = [digit2id[digit] for digit in range(10)]
    digit_vocab_ids = torch.tensor(ordered_ids, dtype=torch.long)
    return tok, digit2id, id2digit, digit_vocab_ids


# ============================================================
# Custom collate function for variable-length positions
# ============================================================
def collate_fn_with_positions(batch):
    """Collate dataset samples, leaving per-puzzle position lists unstacked.

    ``batch`` is a list of ``(puzzle, solution)`` or
    ``(puzzle, solution, positions)`` tuples; the arity of the first sample
    decides which form is assumed. Puzzle and solution tensors are stacked
    along a new leading dimension. Position lists may differ in length, so
    they are returned as a plain list of lists.
    """
    has_positions = len(batch[0]) == 3
    if not has_positions:
        puzzle_rows, solution_rows = zip(*batch)
        return (torch.stack(puzzle_rows), torch.stack(solution_rows))

    puzzle_rows, solution_rows, position_lists = zip(*batch)
    stacked_puzzles = torch.stack(puzzle_rows)
    stacked_solutions = torch.stack(solution_rows)
    return (stacked_puzzles, stacked_solutions, list(position_lists))


# ============================================================
# Dataset
# ============================================================
def reconstruct_solution_from_steps(puzzle: List[int], steps: List[dict]) -> List[int]:
    """Apply the assignment steps to a copy of *puzzle* and return the result.

    Args:
        puzzle: Flat list of cell values in row-major order; not modified.
        steps: Step records. Only steps carrying a truthy ``"pos"`` of length
            two (1-indexed ``[row, col]``) and a truthy ``"val"`` are applied.

    Returns:
        A new list with every applicable assignment written in.

    Steps whose row or column lies outside 1..9 are ignored. Row and column
    are validated separately because checking only the flattened index lets
    an out-of-range coordinate such as ``[1, 10]`` land on a different,
    valid cell.
    """
    solution = list(puzzle)
    for step in steps:
        pos = step.get("pos")
        val = step.get("val")
        if not pos or not val or len(pos) != 2:
            continue
        row, col = pos  # 1-indexed [row, col]
        if 1 <= row <= 9 and 1 <= col <= 9:
            solution[(row - 1) * 9 + (col - 1)] = val
    return solution

class SudokuTokenDataset(Dataset):
    """Sudoku (puzzle, solution[, fill order]) samples read from a JSONL file.

    Each non-blank line is one JSON object. Supported layouts:

    * ``"puzzles"`` / ``"solutions"``: either a single pair of 81-cell lists,
      or two equally long lists of 81-cell lists (one sample per pair).
    * ``"puzzle"`` with ``"solution"``: a single pair of 81-cell lists.
    * ``"puzzle"`` with ``"steps"`` only: the solution is rebuilt by applying
      the steps to the puzzle.

    Records that match none of these layouts are skipped.

    Args:
        jsonl_path: Path of the JSONL file.
        limit: Keep at most this many samples; ``0`` keeps everything.
        use_ground_truth_order: When true, the cell order found in the
            record's ``"steps"`` is stored and returned with each sample.

    Raises:
        ValueError: if no sample could be extracted from the file.
    """

    def __init__(self, jsonl_path: str, limit: int = 0, use_ground_truth_order: bool = False):
        self.items: List[Tuple[List[int], List[int], Optional[List[int]]]] = []
        self.use_ground_truth_order = use_ground_truth_order
        capped = bool(limit and limit > 0)

        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                # Blank lines (e.g. a trailing newline) are not valid JSON; skip them.
                if not line.strip():
                    continue
                ex = json.loads(line)

                positions = None
                if use_ground_truth_order and "steps" in ex:
                    positions = self._positions_from_steps(ex.get("steps", []))

                self.items.extend(self._pairs_from_record(ex, positions))

                # No need to parse the rest of the file once the cap is reached.
                if capped and len(self.items) >= limit:
                    break

        if capped:
            self.items = self.items[:limit]
        if not self.items:
            raise ValueError("No (puzzle, solution) pairs found in JSONL.")
        if use_ground_truth_order:
            missing_positions = sum(1 for _, _, pos in self.items if pos is None or len(pos) == 0)
            if missing_positions > 0:
                print(f"WARNING: {missing_positions} puzzles missing ground truth positions", flush=True)

    @staticmethod
    def _is_grid(cells) -> bool:
        """Return whether *cells* is a list of exactly 81 entries."""
        return isinstance(cells, list) and len(cells) == 81

    @staticmethod
    def _positions_from_steps(steps) -> List[int]:
        """Convert each step's 1-indexed ``[row, col]`` into a 0-indexed linear cell index."""
        positions = []
        for step in steps:
            pos = step.get("pos", None)
            if pos and len(pos) == 2:
                r, c = pos[0], pos[1]
                positions.append((r - 1) * 9 + (c - 1))
        return positions

    @staticmethod
    def _pairs_from_record(ex: dict, positions: Optional[List[int]]):
        """Return the ``(puzzle, solution, positions)`` tuples contained in one record."""
        is_grid = SudokuTokenDataset._is_grid

        if "puzzles" in ex and "solutions" in ex:
            puz, sol = ex["puzzles"], ex["solutions"]
            if is_grid(puz) and is_grid(sol):
                return [(puz, sol, positions)]
            if isinstance(puz, list) and isinstance(sol, list) and len(puz) == len(sol):
                # Batched record: every sub-puzzle shares the record's positions.
                return [(p, s, positions) for p, s in zip(puz, sol) if is_grid(p) and is_grid(s)]
            return []

        if "puzzle" in ex:
            p = ex["puzzle"]
            if not is_grid(p):
                return []
            if "solution" in ex:
                s = ex["solution"]
                return [(p, s, positions)] if is_grid(s) else []
            if "steps" in ex:
                s = reconstruct_solution_from_steps(p, ex.get("steps", []))
                return [(p, s, positions)]

        # Neither a solution nor steps: nothing usable in this record.
        return []

    def __len__(self): return len(self.items)

    def __getitem__(self, i):
        """Return ``(puzzle, solution)`` long tensors, plus the position list when enabled."""
        puz, sol, positions = self.items[i]
        result = (torch.tensor(puz, dtype=torch.long),
                  torch.tensor(sol, dtype=torch.long))
        if self.use_ground_truth_order:
            # Positions stay a Python list (possibly empty) because their
            # length varies per puzzle and cannot be stacked by a collate fn.
            result = result + (positions if positions is not None else [],)
        return result


# ============================================================
# Sudoku helpers
# ============================================================
def all_filled(tokens: torch.Tensor) -> torch.Tensor:
    """Return, per row, whether no entry along dim 1 equals 0 (the empty marker)."""
    return (tokens != 0).all(dim=1)

def mask_empty_positions(tokens: torch.Tensor) -> torch.Tensor:
    """Return a boolean mask that is True wherever *tokens* equals 0 (empty cell)."""
    return tokens.eq(0)

def gather_gt_digits(solution_tokens: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    """Pick, for every row ``b``, the entry ``solution_tokens[b, pos[b]]``."""
    batch_size = solution_tokens.size(0)
    row_index = torch.arange(batch_size, device=solution_tokens.device)
    return solution_tokens[row_index, pos]


# ============================================================
# Order policy (fusion λ(t))
# ============================================================
class LambdaSchedule(nn.Module):
    """Small MLP mapping a scalar time value to a weight in (0, 1).

    The network is Linear(1, hidden) -> ReLU -> Linear(hidden, hidden) -> ReLU
    -> Linear(hidden, 1), followed by a sigmoid.
    """

    def __init__(self, hidden: int = 64):
        super().__init__()
        # Layer order is kept fixed so parameter names (net.0, net.2, net.4)
        # and initialisation order stay the same.
        layers = [
            nn.Linear(1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, t_norm: torch.Tensor) -> torch.Tensor:
        """Return sigmoid(MLP(t_norm)) with the trailing size-1 dim removed.

        A 1-D input is treated as a batch of scalars and gets a feature
        dimension appended first.
        """
        features = t_norm.unsqueeze(-1) if t_norm.dim() == 1 else t_norm
        gate = torch.sigmoid(self.net(features))
        return gate.squeeze(-1)

class OrderPolicyWithFusion(nn.Module):
    """Scores every cell as ``max-probability + lambda(t) * entropy``.

    ``f_theta`` supplies per-cell digit logits and is always evaluated under
    ``torch.no_grad()``; ``lambdanet`` supplies the time-dependent weight and
    is evaluated outside that context.
    """

    def __init__(self, f_theta, lambda_hidden: int = 64):
        super().__init__()
        self.f_theta = f_theta
        self.lambdanet = LambdaSchedule(hidden=lambda_hidden)

    def forward(self, tokens: torch.Tensor, t_norm: torch.Tensor) -> torch.Tensor:
        """Return one score per position (leading dims of the logits are kept)."""
        # The feature model is frozen: never build a graph through it.
        with torch.no_grad():
            digit_logits = self.f_theta(tokens)  # [B,81,9]
        digit_logits = digit_logits.float()

        p = F.softmax(digit_logits, dim=-1)
        log_p = F.log_softmax(digit_logits, dim=-1)
        confidence = p.max(dim=-1).values          # [B,81]
        entropy = -(p * log_p).sum(dim=-1)         # [B,81]

        weight = self.lambdanet(t_norm.to(tokens.device))  # [B]
        return confidence + weight.unsqueeze(1) * entropy  # [B,81]

class FusionOrderPolicy(OrderPolicyWithFusion):
    """``OrderPolicyWithFusion`` with a keyword-only constructor and optional ``t_norm``.

    Extra positional and keyword constructor arguments are accepted and
    ignored.
    """

    def __init__(self, *args, f_theta=None, lambda_hidden=64, **kwargs):
        # f_theta is a keyword-only parameter, so it can never end up inside
        # **kwargs; a missing value is therefore always an error.
        if f_theta is None:
            raise ValueError("FusionOrderPolicy requires f_theta")
        super().__init__(f_theta=f_theta, lambda_hidden=lambda_hidden)

    def forward(self, tokens: torch.Tensor, t_norm: torch.Tensor = None) -> torch.Tensor:
        """Score positions; a missing ``t_norm`` is treated as all zeros."""
        if t_norm is None:
            t_norm = torch.zeros(tokens.size(0), device=tokens.device)
        # Give a flat batch of time values an explicit feature dimension.
        time_input = t_norm.unsqueeze(-1) if t_norm.dim() == 1 else t_norm
        return super().forward(tokens, t_norm=time_input)


# ============================================================
# Model adapters
# ============================================================
def sudoku_tokens_to_vocab_ids(tokens: torch.Tensor, digit2id: Dict[int,int], mask_id: int) -> torch.Tensor:
    """Translate Sudoku cell values into model vocabulary ids.

    Cells holding a digit 1-9 receive ``digit2id[digit]``; every other value
    (including 0 for an empty cell) receives ``mask_id``. The input is left
    untouched and a new long tensor on the same device is returned.
    """
    vocab_ids = torch.full_like(tokens, mask_id, dtype=torch.long, device=tokens.device)
    for digit in range(1, 10):
        vocab_ids.masked_fill_(tokens == digit, digit2id[digit])
    return vocab_ids

def make_f_theta_from_model(
    model: nn.Module,
    digit2id: Dict[int, int],
    mask_id: int = 32000,
    use_amp_bf16: bool = True,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Wrap *model* as a gradient-free function from Sudoku tokens to digit logits.

    Args:
        model: Network called with vocabulary ids; if it returns a tuple or
            list, the first element is taken as the logits.
        digit2id: Vocab id of each digit 1-9.
        mask_id: Vocab id used for every cell that is not a digit 1-9.
        use_amp_bf16: Run the forward pass under bf16 autocast when the model
            lives on a CUDA device.

    Returns:
        ``f_theta(tokens)`` which yields the logits restricted to the nine
        digit ids (last dimension of size 9, ordered 1..9), computed under
        ``torch.no_grad()`` on the model's device.
    """
    device = next(model.parameters()).device
    digit_ids = torch.tensor([digit2id[d] for d in range(1, 10)], dtype=torch.long, device=device)
    autocast_bf16 = use_amp_bf16 and device.type == "cuda"

    @torch.no_grad()
    def f_theta(tokens: torch.Tensor) -> torch.Tensor:
        ids = sudoku_tokens_to_vocab_ids(tokens.to(device), digit2id, mask_id)
        # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
        amp = (torch.autocast(device_type="cuda", dtype=torch.bfloat16)
               if autocast_bf16 else nullcontext())
        with amp:
            out = model(ids)
        logits = out[0] if isinstance(out, (tuple, list)) else out
        return logits.index_select(-1, digit_ids)  # [B,81,9]
    return f_theta

def make_logprob_trainable_from_model(
    model: nn.Module,
    digit2id: Dict[int, int],
    mask_id: int = 32000,
    use_amp_bf16: bool = True,
):
    """Build a differentiable ``log p(gt_digit | tokens, position)`` function.

    Args:
        model: Network called with vocabulary ids; if it returns a tuple or
            list, the first element is taken as the logits. Gradients flow
            through this call.
        digit2id: Vocab id of each digit 1-9.
        mask_id: Vocab id used for every cell that is not a digit 1-9.
        use_amp_bf16: Run the forward pass under bf16 autocast when the model
            lives on a CUDA device.

    Returns:
        ``logprob_trainable(tokens, pos_idx, gt_digit)`` returning, per batch
        row, the log-probability of ``gt_digit`` at cell ``pos_idx`` under a
        softmax over the nine digit logits (computed in float32).

    NOTE: ``gt_digit`` must lie in 1..9. The value is used as index
    ``gt_digit - 1``, so a 0 would silently select the last digit.
    """
    device = next(model.parameters()).device
    digit_ids = torch.tensor([digit2id[d] for d in range(1, 10)], dtype=torch.long, device=device)
    autocast_bf16 = use_amp_bf16 and device.type == "cuda"

    def logprob_trainable(tokens: torch.Tensor, pos_idx: torch.Tensor, gt_digit: torch.Tensor) -> torch.Tensor:
        tokens = tokens.to(device).long()
        pos_idx = pos_idx.to(device).long()
        gt_digit = gt_digit.to(device).long()

        ids = sudoku_tokens_to_vocab_ids(tokens, digit2id, mask_id)

        # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
        amp = (torch.autocast(device_type="cuda", dtype=torch.bfloat16)
               if autocast_bf16 else nullcontext())
        with amp:
            out = model(ids)
            logits = out[0] if isinstance(out, (tuple, list)) else out

        rows = torch.arange(tokens.size(0), device=device)
        logits_9 = logits.index_select(-1, digit_ids)  # [B,81,9]
        ls = logits_9[rows, pos_idx]                   # [B,9]
        # Softmax in float32 for numerical stability regardless of autocast.
        logp = F.log_softmax(ls.float(), dim=-1)
        return logp[rows, gt_digit - 1]  # [B]

    return logprob_trainable


# ============================================================
# Training: GREEDY order, loss = sum negative reward
# ============================================================
def _build_position_table(gt_positions_list, batch_size, max_len, device):
    """Pack per-puzzle fill orders into a ``[batch_size, max_len]`` long tensor padded with -1."""
    table = torch.full((batch_size, max_len), -1, dtype=torch.long)
    for b, positions_b in enumerate(gt_positions_list[:batch_size]):
        if positions_b:
            kept = positions_b[:max_len]
            table[b, :len(kept)] = torch.tensor(kept, dtype=torch.long)
    return table.to(device)


def _allreduce_mean_gradients(model):
    """Average the gradients of all trainable parameters across ranks, one collective per parameter."""
    world = dist.get_world_size()
    for p in model.parameters():
        if not p.requires_grad:
            continue
        if p.grad is None:
            # A rank that ran no step must still join every collective.
            p.grad = torch.zeros_like(p)
        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
        p.grad.div_(world)


def train_diffusion_greedy(
    model: nn.Module,               # TRAIN THIS (DDP ok)
    policy: Optional[nn.Module],   # FROZEN (optional, only used if use_ground_truth_order=False)
    logprob_trainable: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
    train_loader: DataLoader,
    cfg: TrainConfig,
    epochs: int,
    rank: int,
    world_size: int,
    lr_initial: float,
    lr_final: float,
    lr_warmup_epochs: int,
    log_every: int,
    val_loader: Optional[DataLoader],
    max_eval_batches: int,
    save_path: str,
    use_ground_truth_order: bool = False,
):
    """Train *model* with teacher forcing along a fixed cell order.

    For every batch the puzzles are filled one cell per step. The cell is
    taken either from the ground-truth order shipped with the batch or from
    the greedy argmax of the frozen *policy* over empty cells. At each step
    the loss ``-sum log p(gt_digit | state, cell)`` is back-propagated, then
    the ground-truth digit is written into the state. Gradients accumulate
    over all steps of the batch and a single optimizer step follows.

    Args:
        model: Network to optimise (plain module or DDP wrapper).
        policy: Frozen order policy; required whenever a batch carries no
            ground-truth positions.
        logprob_trainable: ``(tokens, pos, gt_digit) -> log-prob`` closure
            that runs *model* with gradients enabled.
        train_loader: Yields ``(puzzle, solution)`` or
            ``(puzzle, solution, positions)`` batches.
        cfg: Supplies ``device``, ``max_steps`` and ``grad_clip``.
        epochs: Number of passes over *train_loader*.
        rank: Rank of this process; rank 0 logs and writes checkpoints.
        world_size: Unused here; kept for interface compatibility.
        lr_initial: Learning rate used while ``epoch < lr_warmup_epochs``.
        lr_final: Learning rate used afterwards.
        lr_warmup_epochs: Number of epochs run at ``lr_initial``.
        log_every: Log every this many batches.
        val_loader, max_eval_batches: Unused; kept for interface compatibility.
        save_path: Final checkpoint path; per-epoch checkpoints are written
            next to it with an ``_epochNN`` suffix.
        use_ground_truth_order: Take cell order from the batch when present.

    Raises:
        ValueError: if *save_path* is empty, or a batch needs the policy and
            none was given.
    """
    # Fail before training rather than after the first epoch.
    if not save_path:
        raise ValueError("save_path is required to write checkpoints")

    device = torch.device(cfg.device)
    is_ddp = isinstance(model, DDP)
    model_without_ddp = model.module if is_ddp else model
    # Under DDP every backward() would launch gradient all-reduces. The number
    # of fill steps differs between ranks, so those collectives would not line
    # up. Accumulate locally and synchronise once per batch instead.
    accumulate_ctx = model.no_sync if is_ddp else nullcontext

    # Freeze policy params (if using policy)
    if policy is not None:
        policy.eval()
        for p in policy.parameters():
            p.requires_grad_(False)

    # Make diffusion trainable
    model.train()
    for p in model.parameters():
        p.requires_grad_(True)

    opt = torch.optim.AdamW(model.parameters(), lr=lr_initial, weight_decay=0.01)

    def lr_lambda(epoch):
        # Step schedule: lr_initial for the first lr_warmup_epochs, lr_final after.
        return 1.0 if epoch < lr_warmup_epochs else (lr_final / lr_initial)

    scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)

    global_step = 0

    for epoch in range(epochs):
        if isinstance(train_loader.sampler, DistributedSampler):
            train_loader.sampler.set_epoch(epoch)

        model.train()
        for batch_idx, batch_data in enumerate(train_loader):
            global_step += 1
            batch_start = time.time()

            if use_ground_truth_order:
                if len(batch_data) == 3:
                    init_tokens, solution_tokens, gt_positions_list = batch_data
                else:
                    # Fallback: if somehow we don't have positions, use policy
                    init_tokens, solution_tokens = batch_data[:2]
                    gt_positions_list = None
                    if is_main_process(rank) and batch_idx == 0:
                        print(f"WARNING: Expected 3 items in batch but got {len(batch_data)}, falling back to policy", flush=True)
            else:
                init_tokens, solution_tokens = batch_data
                gt_positions_list = None

            use_gt_positions = use_ground_truth_order and gt_positions_list is not None
            if not use_gt_positions and policy is None:
                raise ValueError("An order policy is required when the batch carries no ground truth positions")

            tokens = init_tokens.to(device).long()         # current state, teacher-forced
            sol = solution_tokens.to(device).long()
            B = tokens.size(0)

            done = all_filled(tokens)

            opt.zero_grad(set_to_none=True)

            total_reward = 0.0   # sum logp(gt)
            total_loss = 0.0     # sum -logp(gt)
            total_steps = 0      # count of (puzzle,step) pairs where puzzle not done

            if use_gt_positions:
                # default=0 covers a batch in which no puzzle has any position.
                longest = max((len(positions) for positions in gt_positions_list if positions), default=0)
                max_len = min(longest, cfg.max_steps)
                # Built once per batch; column t holds the cell for step t,
                # -1 where a puzzle's order is exhausted.
                pos_table = _build_position_table(gt_positions_list, B, max_len, device)
            else:
                max_len = cfg.max_steps
                pos_table = None

            for t in range(max_len):
                if done.all():
                    break

                if use_gt_positions:
                    pos = pos_table[:, t]
                else:
                    # GREEDY position selection from frozen order policy
                    with torch.no_grad():
                        t_norm = torch.full((B, 1), t / cfg.max_steps, device=device)
                        pos_scores = policy(tokens, t_norm)  # [B,81]
                        pos_scores = torch.nan_to_num(pos_scores, nan=0.0, posinf=1e30, neginf=-1e30)

                        # Only empty cells may be chosen.
                        valid = mask_empty_positions(tokens)  # [B,81] bool
                        pos_scores = pos_scores.masked_fill(~valid, float("-inf"))

                        pos = pos_scores.argmax(dim=-1)      # [B]

                active = ~done
                # Filter out invalid positions (e.g., -1 from exhausted ground truth)
                if use_ground_truth_order:
                    valid_pos = (pos >= 0) & (pos < 81)
                    active = active & valid_pos

                if not active.any():
                    break

                # Compute logprob only on active puzzles (saves compute)
                tokens_a = tokens[active]                 # [Ba,81]
                pos_a = pos[active]                       # [Ba]
                gt_a = sol[active, pos_a]                 # [Ba]

                # Forward and backward both sit inside the context so that DDP
                # skips its per-backward gradient synchronisation.
                with accumulate_ctx():
                    logp_a = logprob_trainable(tokens_a, pos_a, gt_a)  # [Ba]
                    loss_t = -(logp_a).sum()
                    loss_t.backward()

                total_reward += float(logp_a.detach().sum().item())
                total_loss += float(loss_t.detach().item())
                total_steps += int(active.sum().item())

                # Teacher forcing transition: fill GT digit at chosen pos
                tokens[active, pos_a] = gt_a
                done = all_filled(tokens)

            if is_ddp:
                # One synchronisation per batch, identical on every rank.
                _allreduce_mean_gradients(model)

            nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            opt.step()

            # Logging
            if (batch_idx % log_every == 0) and is_main_process(rank):
                lr_now = opt.param_groups[0]["lr"]
                bt = time.time() - batch_start
                avg_r = total_reward / max(total_steps, 1)  # average logp per filled step (negative number)
                print(
                    f"[E{epoch} B{batch_idx}] "
                    f"loss_sum={total_loss:.3f} | avg_r={avg_r:.6f} | steps={total_steps} "
                    f"| lr={lr_now:.2e} | time={bt:.2f}s"
                )
                if wandb.run is not None:
                    wandb.log({
                        "train/loss_sum": total_loss,
                        "train/avg_logp_per_step": avg_r,
                        "train/steps": total_steps,
                        "train/lr": lr_now,
                        "train/epoch": epoch,
                        "train/batch_idx": batch_idx,
                    }, step=global_step)

        scheduler.step()

        # Save epoch checkpoint (rank0)
        if is_main_process(rank):
            save_dir = Path(save_path).parent
            save_dir.mkdir(parents=True, exist_ok=True)
            epoch_path = str(Path(save_path).with_name(f"{Path(save_path).stem}_epoch{epoch:02d}.pt"))
            try:
                torch.save({
                    "epoch": epoch,
                    "global_step": global_step,
                    "model_state_dict": model_without_ddp.state_dict(),
                    "optimizer_state_dict": opt.state_dict(),
                }, epoch_path)
                print(f"✓ Saved epoch ckpt to {epoch_path}")
            except (OSError, IOError, RuntimeError) as e:
                # A failed epoch checkpoint must not abort the run.
                print(f"WARNING: Failed to save epoch checkpoint to {epoch_path}: {e}", flush=True)
                print("Continuing training without saving epoch checkpoint...", flush=True)

    # Final save (rank0)
    if is_main_process(rank):
        save_dir = Path(save_path).parent
        save_dir.mkdir(parents=True, exist_ok=True)
        try:
            torch.save({
                "epoch": epochs,
                "global_step": global_step,
                "model_state_dict": model_without_ddp.state_dict(),
                "optimizer_state_dict": opt.state_dict(),
            }, save_path)
            print(f"✓ Saved final diffusion ckpt to {save_path}")
        except (OSError, IOError, RuntimeError) as e:
            print(f"ERROR: Failed to save final checkpoint to {save_path}: {e}", flush=True)
            print("Training completed but checkpoint was not saved due to disk quota issues.", flush=True)


# ============================================================
# Main
# ============================================================
def main():
    """Parse CLI arguments, build data / models / order policy, and run training.

    Steps: initialise distributed state and seeds, optionally start a wandb
    run on rank 0, load the JSONL dataset, load the trainable diffusion model
    from ``--ckpt``, build the frozen order policy from ``--order_ckpt``
    unless ``--use_ground_truth_order`` is given, train, then upload the
    final checkpoint to wandb and tear down the process group.

    Raises:
        ValueError: if neither ``--order_ckpt`` nor ``--use_ground_truth_order``
            is supplied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--jsonl", type=str, required=True)
    parser.add_argument("--model", type=int, default=1028, help="Config name: Diff_LLaMA_{M}M")
    parser.add_argument("--ckpt", type=str, required=True, help="Diffusion checkpoint (init weights)")
    parser.add_argument("--order_ckpt", type=str, default=None, help="Pretrained order policy checkpoint (required if not using ground truth order)")
    parser.add_argument("--use_ground_truth_order", action="store_true", help="Use positions from JSONL 'steps' field instead of order policy")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=32)  # PER GPU
    parser.add_argument("--lr_initial", type=float, default=1e-5)
    parser.add_argument("--lr_final", type=float, default=1e-6)
    parser.add_argument("--lr_warmup_epochs", type=int, default=10)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--limit", type=int, default=0)
    # Required: checkpoints are always written, and a missing path would
    # otherwise only fail after the first epoch has been trained.
    parser.add_argument("--save_path", type=str, required=True, help="Where to write the final checkpoint")
    parser.add_argument("--val_split", type=float, default=0.0)
    parser.add_argument("--log_every", type=int, default=10)
    parser.add_argument("--wandb_project", type=str, default="sudoku-diffusion-greedy")
    parser.add_argument("--wandb_name", type=str, default=None)
    parser.add_argument("--wandb_mode", type=str, default="online", choices=["online", "offline", "disabled"])
    args = parser.parse_args()

    rank, world_size, local_rank = setup_distributed()
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    # Each rank gets its own RNG stream.
    set_seed(args.seed + rank)

    # wandb (rank0 only)
    if is_main_process(rank):
        os.environ.setdefault("WANDB__SERVICE_WAIT", "300")
        os.environ.setdefault("WANDB_INIT_TIMEOUT", "60")
        if args.wandb_mode != "disabled":
            try:
                wandb.init(
                    project=args.wandb_project,
                    name=args.wandb_name,
                    mode=args.wandb_mode,
                    settings=wandb.Settings(_service_wait=300, start_method="thread"),
                    config=vars(args) | {"world_size": world_size},
                )
            except Exception as e:
                # Logging is optional: training proceeds without wandb.
                print(f"WARNING: wandb init failed: {e}")
                args.wandb_mode = "disabled"

    if is_main_process(rank):
        print(f"Running on {world_size} GPU(s). Device={device}")

    # Tokenizer + maps
    tok, digit2id, _, _ = build_tokenizer_and_digit_maps(MASK_ID=32000)
    cfg = TrainConfig(
        device=str(device),
        batch_size=args.batch_size,
        epochs=args.epochs,
        MASK_ID=tok.pad_token_id,
        seed=args.seed,
        use_amp_bf16=True,
    )

    # Data
    ds = SudokuTokenDataset(args.jsonl, limit=args.limit, use_ground_truth_order=args.use_ground_truth_order)
    if args.val_split > 0.0 and len(ds) > 10:
        val_size = int(len(ds) * args.val_split)
        train_size = len(ds) - val_size
        # The global RNG is seeded with seed + rank, so an unseeded split
        # would differ per rank. A dedicated generator seeded with the base
        # seed gives every rank the same train/val partition.
        split_generator = torch.Generator().manual_seed(args.seed)
        train_ds, val_ds = torch.utils.data.random_split(
            ds, [train_size, val_size], generator=split_generator
        )
    else:
        train_ds, val_ds = ds, None

    # Use custom collate function if using ground truth order
    collate_fn = collate_fn_with_positions if args.use_ground_truth_order else None

    if world_size > 1:
        train_sampler = DistributedSampler(train_ds, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True)
        train_loader = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler,
                                  num_workers=0, pin_memory=True, drop_last=True, collate_fn=collate_fn)
        val_loader = None
    else:
        train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                                  num_workers=0, pin_memory=True, drop_last=False, collate_fn=collate_fn)
        val_loader = None

    # Load diffusion TRAIN model
    if is_main_process(rank):
        print("Loading TRAINABLE diffusion model...")
    config = Config.from_name(f"Diff_LLaMA_{args.model}M")
    model_train = TransEncoder(config).to(device)
    state = load_mdm_state_dict(args.ckpt)
    model_train.load_state_dict(state, strict=False)

    if world_size > 1:
        model_train = DDP(model_train, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)

    # Load order policy (only if not using ground truth order)
    policy = None
    if not args.use_ground_truth_order:
        if args.order_ckpt is None:
            raise ValueError("--order_ckpt is required when not using --use_ground_truth_order")

        # Load frozen diffusion snapshot for ORDER policy features (keeps order fixed)
        if is_main_process(rank):
            print("Loading FROZEN diffusion snapshot for order policy features...")
        model_order = TransEncoder(config).to(device)
        model_order.load_state_dict(state, strict=False)
        model_order.eval()
        for p in model_order.parameters():
            p.requires_grad_(False)

        f_theta_order = make_f_theta_from_model(model_order, digit2id, mask_id=cfg.MASK_ID, use_amp_bf16=True)

        # Build policy and load checkpoint
        policy = FusionOrderPolicy(f_theta=f_theta_order).to(device)
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        order_ckpt = torch.load(args.order_ckpt, map_location='cpu')  # Load to CPU to avoid OOM
        sd = order_ckpt.get("model_state_dict", order_ckpt)
        policy.load_state_dict(sd, strict=False)
        policy.eval()
        for p in policy.parameters():
            p.requires_grad_(False)
        del order_ckpt, sd
        torch.cuda.empty_cache()

        if is_main_process(rank):
            print("✓ Order policy loaded + frozen + greedy execution")
    else:
        if is_main_process(rank):
            print("✓ Using ground truth order from JSONL 'steps' field")

    # Trainable logprob for diffusion
    logprob_trainable = make_logprob_trainable_from_model(model_train, digit2id, mask_id=cfg.MASK_ID, use_amp_bf16=True)

    # Train
    train_diffusion_greedy(
        model=model_train,
        policy=policy,
        logprob_trainable=logprob_trainable,
        train_loader=train_loader,
        cfg=cfg,
        epochs=args.epochs,
        rank=rank,
        world_size=world_size,
        lr_initial=args.lr_initial,
        lr_final=args.lr_final,
        lr_warmup_epochs=args.lr_warmup_epochs,
        log_every=args.log_every,
        val_loader=val_loader,
        max_eval_batches=0,
        save_path=args.save_path,
        use_ground_truth_order=args.use_ground_truth_order,
    )

    if is_main_process(rank) and wandb.run is not None:
        wandb.save(args.save_path)
        wandb.finish()

    cleanup_distributed()


# Run training only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
