# train_lambda_llada_parallel.py

import argparse
from datetime import datetime
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from dataclasses import dataclass

from transformers import AutoModel, AutoTokenizer
from llada_dataset import GSM8KLLADA, MBPPLLADA, MATH500LLADA, HumanEvalLLADA
import wandb
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

MASK_ID = 126336  # LLaDA [MASK] token id


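# =========================
# Lambda Network
# =========================
# NOTE: the triple-quoted string that immediately follows is a disabled legacy
# implementation (a time-conditioned LambdaNet plus a duplicate get_base_model).
# It is a bare string expression, so it is never executed; the active
# `get_base_model` and the contextual `LambdaNet` are defined right after it.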
'''
class LambdaNet(nn.Module):
    def __init__(self, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1),
        )

    def forward(self, t_norm: torch.Tensor) -> torch.Tensor:
        """
        t_norm: [B] or [B,1] in [0,1]
        returns: [B] λ(t)
        """
        if t_norm.dim() == 1:
            t_norm = t_norm.unsqueeze(-1)
        raw = self.net(t_norm)
        lam = torch.sigmoid(raw)        # [B,1] in (0,1)
        return lam.squeeze(-1)  


def get_base_model(model: nn.Module) -> nn.Module:
    if isinstance(model, DDP):
        return model.module
    return model
'''

def get_base_model(model: nn.Module) -> nn.Module:
    """Return the underlying module, unwrapping a DDP container when present.

    Any model that is not a ``DistributedDataParallel`` instance is handed
    back unchanged.
    """
    return model.module if isinstance(model, DDP) else model

class LambdaNet(nn.Module):
    """
    Contextual Lambda network:
      - takes teacher features per position (top_p, entropy)
      - plus a candidate mask (which positions are currently selectable)
      - runs a Transformer encoder over the sequence of positions
      - outputs λ_vec[b, j] in (-1, 1) for each position j

    Callers combine the output as
        score_{b,j} = top_p_{b,j} + λ_{b,j} * ent_{b,j}

    Args:
        d_model: Width of the Transformer encoder.
        nhead: Number of attention heads.
        num_layers: Number of encoder layers.
        dim_feedforward: Hidden width of each encoder feed-forward block.
        max_len: Largest sequence length supported by the learned positional
            embedding; longer inputs are rejected in ``forward``.
        dropout: Dropout probability of the encoder layers. The default (0.1)
            matches the ``nn.TransformerEncoderLayer`` default, so existing
            callers are unaffected. Pass 0.0 when log-probabilities must be
            reproducible across forward passes in train mode.
    """

    def __init__(
        self,
        d_model: int = 128,
        nhead: int = 4,
        num_layers: int = 2,
        dim_feedforward: int = 256,
        max_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len

        # Input feature dimension:
        #   top_p          : 1
        #   ent            : 1
        #   is_candidate   : 1   (mask for currently selectable positions)
        d_in = 3

        self.in_proj = nn.Linear(d_in, d_model)

        # Learned positional embedding over positions 0..max_len-1
        self.pos_emb = nn.Embedding(max_len, d_model)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,   # tensors are [B, L, d]
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        # Project back to scalar λ per position
        self.out_proj = nn.Linear(d_model, 1)

    def forward(
        self,
        top_p: torch.Tensor,          # [B, L]
        ent: torch.Tensor,            # [B, L]
        candidate_mask: torch.Tensor, # [B, L] bool or 0/1; True for selectable positions
        pad_mask: torch.Tensor,       # [B, L] bool; True for PAD positions (ignored in attention)
    ) -> torch.Tensor:
        """
        Compute a per-position λ from teacher features.

        Returns:
          lam_vec: [B, L] (per-position λ_j, bounded to (-1, 1) by tanh)

        Raises:
          ValueError: if the sequence length exceeds ``max_len``.
        """
        B, L = top_p.shape
        if L > self.max_len:
            # Fail early with a readable message instead of an out-of-range
            # embedding lookup (which is a device-side assert on CUDA).
            raise ValueError(
                f"Sequence length {L} exceeds LambdaNet max_len={self.max_len}"
            )
        device = top_p.device

        # Build feature tensor: [B, L, 3]
        feat = torch.stack(
            [top_p, ent, candidate_mask.float()],
            dim=-1,
        )

        # Project to model dimension
        h = self.in_proj(feat)                    # [B, L, d_model]

        # Add positional embeddings; [L, d_model] broadcasts over the batch
        h = h + self.pos_emb(torch.arange(L, device=device))

        # src_key_padding_mask: True for positions to IGNORE (pad)
        src_key_padding_mask = pad_mask.bool()

        # Transformer encoder
        h_enc = self.encoder(h, src_key_padding_mask=src_key_padding_mask)  # [B, L, d_model]

        # Project to scalar per position
        lam_raw = self.out_proj(h_enc).squeeze(-1)   # [B, L]

        # Bound λ to a fixed range for stability
        lam_vec = torch.tanh(lam_raw)                # in (-1, 1)

        return lam_vec

@dataclass
class TrajectoryBatch:
    """Rollout results returned by ``rollout_lambda_grpo``.

    Per-trajectory tensors are flattened over prompts and samples: each prompt
    is repeated ``n_traj`` times with ``repeat_interleave``, giving
    ``batch_size * n_traj`` rows.
    """
    # Position chosen at every step; -1 where a trajectory took no action at that step.
    actions: torch.Tensor        # [B*N, T]
    # Sum over steps of the behavior policy's log-prob of the chosen position.
    logp_old_sum: torch.Tensor   # [B*N]
    # Sum over steps of the per-step rewards.
    returns: torch.Tensor        # [B*N]
    # Number of step columns T in `actions`.
    max_steps: int
    # Number of prompts B.
    batch_size: int
    # Number of trajectories N sampled per prompt.
    n_traj: int


@torch.no_grad()
def rollout_lambda_grpo(
    teacher: AutoModel,
    lambdanet: nn.Module,
    full_ids: torch.Tensor,         # [B, L]
    prompt_len: torch.Tensor,       # [B]
    ans_len: torch.Tensor,          # [B]
    n_traj_per_prompt: int,
    pad_token_id: int,
    device: torch.device,
) -> TrajectoryBatch:
    """
    Roll out trajectories with a frozen behavior policy for PPO-style GRPO.

    Every prompt is repeated ``n_traj_per_prompt`` times. Each trajectory
    starts from a canvas whose answer span ``[prompt_len, prompt_len + ans_len)``
    is replaced by ``MASK_ID``. At every step:

      1. the teacher is run on the current canvas (in small chunks) to get the
         per-position max probability, entropy, and the log-prob it assigns to
         the gold token;
      2. ``lambdanet`` turns those features into per-position λ, and one still
         masked position is sampled from
         ``Categorical(logits = top_p + λ * entropy)``;
      3. the reward is the teacher's log-prob of the gold token at the chosen
         position (evaluated on the canvas *before* it is revealed);
      4. the gold token is written into the canvas (teacher forcing).

    The whole rollout runs under ``torch.no_grad()``.

    Args:
        teacher: Model called as ``teacher(ids, attention_mask=...)`` whose
            output exposes ``.logits`` of shape [rows, L, V].
        lambdanet: Policy network called as
            ``lambdanet(top_p, ent, candidate_mask, pad_mask)``.
        full_ids: [B, L] prompt + answer + padding token ids.
        prompt_len: [B] prompt lengths.
        ans_len: [B] answer lengths.
        n_traj_per_prompt: Number of trajectories N sampled per prompt.
        pad_token_id: Token id treated as padding (excluded from attention).
        device: Device all tensors are moved to.

    Returns:
        TrajectoryBatch with chosen positions per step (``-1`` for steps where
        a trajectory had already finished), summed behavior log-probs, and
        summed rewards, all with ``B * N`` rows.
    """
    full_ids = full_ids.to(device)
    prompt_len = prompt_len.to(device)
    ans_len = ans_len.to(device)

    B, L = full_ids.shape
    N = int(n_traj_per_prompt)
    BN = B * N

    full = full_ids.repeat_interleave(N, dim=0)    # [BN, L]
    pl = prompt_len.repeat_interleave(N)           # [BN]
    al = ans_len.repeat_interleave(N)              # [BN]

    idxs = torch.arange(L, device=device).unsqueeze(0).expand(BN, L)
    in_answer = (idxs >= pl.unsqueeze(1)) & (idxs < (pl + al).unsqueeze(1))

    canvas = full.clone()
    canvas[in_answer] = MASK_ID

    remaining = in_answer.sum(dim=1)
    max_steps = int(al.max().item()) if al.numel() > 0 else 0

    actions = torch.full((BN, max_steps), -1, device=device, dtype=torch.long)
    logp_steps = torch.zeros((BN, max_steps), device=device)
    rewards = torch.zeros((BN, max_steps), device=device)

    # Teacher is evaluated a few rows at a time to bound peak memory.
    chunk_size_local = max(2, min(4, BN // 4))

    for t in range(max_steps):
        alive = remaining > 0
        if not alive.any():
            break

        attn = (canvas != pad_token_id).long()

        top_p_chunks = []
        ent_chunks = []
        gold_logp_chunks = []
        for cs in range(0, BN, chunk_size_local):
            ce = min(cs + chunk_size_local, BN)
            logits_chunk = teacher(canvas[cs:ce], attention_mask=attn[cs:ce]).logits

            # Log-prob of the gold token at every position. Extracting it here
            # means the reward for whichever position gets sampled is already
            # known, so the canvas does not need a second teacher pass.
            gold_logp_chunks.append(
                F.log_softmax(logits_chunk, dim=-1)
                .gather(-1, full[cs:ce].unsqueeze(-1))
                .squeeze(-1)
            )

            probs_chunk = F.softmax(logits_chunk, dim=-1)
            top_p_chunks.append(probs_chunk.max(dim=-1).values)
            ent_chunks.append(
                -(probs_chunk * probs_chunk.clamp_min(1e-9).log()).sum(dim=-1)
            )
            del probs_chunk, logits_chunk
            torch.cuda.empty_cache()

        top_p = torch.cat(top_p_chunks, dim=0)          # [BN, L]
        ent = torch.cat(ent_chunks, dim=0)              # [BN, L]
        gold_logp = torch.cat(gold_logp_chunks, dim=0)  # [BN, L]
        del top_p_chunks, ent_chunks, gold_logp_chunks

        valid = (canvas == MASK_ID)
        idx_act = (alive & valid.any(dim=1)).nonzero(as_tuple=False).squeeze(-1)
        if idx_act.numel() == 0:
            break

        top_p_act = top_p.index_select(0, idx_act)
        ent_act = ent.index_select(0, idx_act)
        valid_act = valid.index_select(0, idx_act)
        pad_mask_act = canvas.index_select(0, idx_act) == pad_token_id

        lam_vec_act = lambdanet(top_p_act, ent_act, valid_act, pad_mask_act)
        scores_act = top_p_act + lam_vec_act * ent_act
        scores_act = scores_act.masked_fill(~valid_act, float("-inf"))

        # Named `policy` so the module-level `torch.distributed as dist`
        # alias is not shadowed.
        policy = torch.distributions.Categorical(logits=scores_act)
        a_act = policy.sample()
        logp_act = policy.log_prob(a_act)

        actions[idx_act, t] = a_act
        logp_steps[idx_act, t] = logp_act

        # Reward: teacher log-prob of the gold token at the chosen position.
        rewards[idx_act, t] = gold_logp[idx_act, a_act].to(rewards.dtype)

        # Teacher forcing: reveal the gold token at the chosen position.
        canvas[idx_act, a_act] = full[idx_act, a_act]
        remaining[idx_act] -= 1

        # Cleanup
        del top_p, ent, gold_logp
        del top_p_act, ent_act, valid_act, pad_mask_act
        del lam_vec_act, scores_act, a_act, logp_act, policy
        torch.cuda.empty_cache()

    returns = rewards.sum(dim=1)
    logp_old_sum = logp_steps.sum(dim=1)

    return TrajectoryBatch(
        actions=actions,
        logp_old_sum=logp_old_sum,
        returns=returns,
        max_steps=max_steps,
        batch_size=B,
        n_traj=N,
    )


def compute_logp_new_for_actions(
    teacher: AutoModel,
    lambdanet: nn.Module,
    full_ids: torch.Tensor,         # [B, L]
    prompt_len: torch.Tensor,       # [B]
    ans_len: torch.Tensor,          # [B]
    actions: torch.Tensor,          # [BN, T]
    n_traj_per_prompt: int,
    pad_token_id: int,
    device: torch.device,
) -> torch.Tensor:
    """
    Recompute log-probs of stored actions under the current policy.

    The stored trajectories are replayed step by step from the fully masked
    answer canvas. At each step the teacher features are recomputed without
    gradients, ``lambdanet`` is evaluated with gradients enabled, and the
    log-prob of the stored position under
    ``Categorical(logits = top_p + λ * entropy)`` is recorded. The gold token
    is then written at the stored position so the next step sees the same
    canvas as during the rollout.

    Args:
        teacher: Model called as ``teacher(ids, attention_mask=...)`` whose
            output exposes ``.logits``.
        lambdanet: Policy network being optimized.
        full_ids: [B, L] prompt + answer + padding token ids.
        prompt_len: [B] prompt lengths.
        ans_len: [B] answer lengths.
        actions: [B*N, T] stored positions; negative entries mean "no action".
        n_traj_per_prompt: Number of trajectories N per prompt.
        pad_token_id: Token id treated as padding.
        device: Device all tensors are moved to.

    Returns:
        Log-prob sums shaped [B, N].
    """
    full_ids = full_ids.to(device)
    prompt_len = prompt_len.to(device)
    ans_len = ans_len.to(device)
    # Stored actions may come from another device (e.g. a CPU buffer).
    actions = actions.to(device)

    B, L = full_ids.shape
    N = int(n_traj_per_prompt)
    BN, max_steps = actions.shape

    full = full_ids.repeat_interleave(N, dim=0)
    pl = prompt_len.repeat_interleave(N)
    al = ans_len.repeat_interleave(N)

    idxs = torch.arange(L, device=device).unsqueeze(0).expand(BN, L)
    in_answer = (idxs >= pl.unsqueeze(1)) & (idxs < (pl + al).unsqueeze(1))

    canvas = full.clone()
    canvas[in_answer] = MASK_ID
    remaining = in_answer.sum(dim=1)

    logp_new_steps = torch.zeros_like(actions, dtype=torch.float32, device=device)

    # Teacher is evaluated a few rows at a time to bound peak memory.
    chunk_size_local = max(2, min(4, BN // 4))

    for t in range(max_steps):
        a_t = actions[:, t]
        valid_action = (a_t >= 0) & (remaining > 0)
        if not valid_action.any():
            break

        attn = (canvas != pad_token_id).long()

        top_p_chunks = []
        ent_chunks = []
        for cs in range(0, BN, chunk_size_local):
            ce = min(cs + chunk_size_local, BN)
            # The teacher only provides features; no gradient flows through it.
            with torch.no_grad():
                logits_chunk = teacher(canvas[cs:ce], attention_mask=attn[cs:ce]).logits
                probs_chunk = F.softmax(logits_chunk, dim=-1)
                top_p_chunks.append(probs_chunk.max(dim=-1).values)
                ent_chunks.append(
                    -(probs_chunk * probs_chunk.clamp_min(1e-9).log()).sum(dim=-1)
                )
            del probs_chunk, logits_chunk
            torch.cuda.empty_cache()

        top_p = torch.cat(top_p_chunks, dim=0)
        ent = torch.cat(ent_chunks, dim=0)
        del top_p_chunks, ent_chunks

        valid = (canvas == MASK_ID)

        idx_act = valid_action.nonzero(as_tuple=False).squeeze(-1)
        top_p_act = top_p.index_select(0, idx_act)
        ent_act = ent.index_select(0, idx_act)
        valid_act = valid.index_select(0, idx_act)
        pad_mask_act = canvas.index_select(0, idx_act) == pad_token_id

        lam_vec_act = lambdanet(top_p_act, ent_act, valid_act, pad_mask_act)
        scores_act = top_p_act + lam_vec_act * ent_act
        scores_act = scores_act.masked_fill(~valid_act, float("-inf"))

        # Named `policy` so the module-level `torch.distributed as dist`
        # alias is not shadowed.
        policy = torch.distributions.Categorical(logits=scores_act)
        chosen = a_t.index_select(0, idx_act)
        logp_act = policy.log_prob(chosen)
        logp_new_steps[idx_act, t] = logp_act

        # Replay the stored action: reveal the gold token at that position.
        canvas[idx_act, chosen] = full[idx_act, chosen]
        remaining[idx_act] -= 1

        # Cleanup
        del top_p, ent, top_p_act, ent_act, valid_act, pad_mask_act
        del lam_vec_act, scores_act, policy, logp_act
        torch.cuda.empty_cache()

    logp_new_sum = logp_new_steps.sum(dim=1)
    return logp_new_sum.view(B, N)


def compute_grpo_loss(
    returns: torch.Tensor,      # [B, N]
    logp_old: torch.Tensor,     # [B, N]
    logp_new: torch.Tensor,     # [B, N]
    clip_eps: float,
    adv_whiten: bool = True,
) -> tuple[torch.Tensor, dict]:
    """
    PPO-style clipped surrogate loss with a per-prompt (group) baseline.

    The advantage of each trajectory is its return minus the mean return of
    its prompt's N trajectories, optionally divided by the group's standard
    deviation. The loss is the negated mean of
    ``min(ratio * adv, clip(ratio, 1 - eps, 1 + eps) * adv)`` with
    ``ratio = exp(logp_new - logp_old)``.

    Args:
        returns: [B, N] total return per (prompt, trajectory).
        logp_old: [B, N] behavior-policy log-prob sums.
        logp_new: [B, N] current-policy log-prob sums (carries gradients).
        clip_eps: PPO clipping range.
        adv_whiten: Divide advantages by the per-prompt std. Skipped when
            N == 1, where the sample std is undefined (NaN) and would turn
            the whole loss into NaN.

    Returns:
        (loss, stats) where ``stats`` holds Python floats for logging.
    """
    with torch.no_grad():
        adv = returns - returns.mean(dim=1, keepdim=True)
        if adv_whiten and adv.size(1) > 1:
            adv = adv / adv.std(dim=1, keepdim=True).clamp_min(1e-6)

    ratio = torch.exp(logp_new - logp_old)
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv
    loss = -torch.minimum(unclipped, clipped).mean()

    def _mean(x: torch.Tensor) -> float:
        return float(x.detach().mean().cpu())

    def _std(x: torch.Tensor) -> float:
        # The sample std of fewer than two values is NaN; report 0.0 instead.
        if x.numel() < 2:
            return 0.0
        return float(x.detach().std().cpu())

    stats = {
        "loss": float(loss.detach().cpu()),
        "adv_mean": _mean(adv),
        "adv_std": _std(adv),
        "ratio_mean": _mean(ratio),
        "ratio_std": _std(ratio),
        "returns_mean": _mean(returns),
        "returns_std": _std(returns),
    }
    return loss, stats


# =========================
# Batched GRPO sampler (with optional block-wise decoding)
# =========================
def sample_trajectories_lambda_batched(
    teacher: AutoModel,
    lambdanet: nn.Module,           # can be DDP-wrapped
    full_ids: torch.Tensor,         # [B, L]
    prompt_len: torch.Tensor,       # [B]
    ans_len: torch.Tensor,          # [B]
    n_traj_per_prompt: int,
    pad_token_id: int,
    is_random: bool = True,
    block_size: int | None = None,  # block length in tokens; None = whole answer
    device: torch.device | None = None,
    chunk_size: int | None = None,  # rows per teacher forward (None = auto)
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Batched GRPO-style sampling with optional block-wise decoding.

    Every prompt is repeated N times. The answer span of each copy is masked
    with ``MASK_ID`` and then revealed one position per step. The position is
    chosen by the policy ``softmax(top_p + λ * entropy)`` over the still
    masked positions of the current block, and the gold token is written
    there (teacher forcing). The teacher runs without gradients; ``lambdanet``
    runs under the ambient grad mode, so ``logpi_mat`` is differentiable
    w.r.t. its parameters when gradients are enabled.

    Args:
        teacher: Model called as ``teacher(ids, attention_mask=...)`` whose
            output exposes ``.logits`` of shape [rows, L, V].
        lambdanet: Policy network called as
            ``lambdanet(top_p, ent, candidate_mask, pad_mask)``.
        full_ids: [B, L] = prompt + answer + padding.
        prompt_len: [B] prompt lengths.
        ans_len: [B] answer lengths.
        n_traj_per_prompt: Number of trajectories N per prompt.
        pad_token_id: Token id treated as padding.
        is_random: Sample positions when True; pick the argmax when False.
        block_size:
            * None -> decode the entire answer as one block
            * int  -> split the answer span into blocks of length <= block_size,
                      decode blocks sequentially; order within a block is
                      chosen by the policy.
        device: Device to run on. Defaults to LambdaNet's device. A device
            without an index (e.g. ``torch.device("cuda")``) is resolved to
            LambdaNet's device of the same type.
        chunk_size: Rows per teacher forward pass; None picks 2-4 rows.

    Returns:
      - R_mat:     [B, N]  total return per (prompt, traj)
                           = sum_t log p_teacher(x*_t | state_t)
      - logpi_mat: [B, N]  sum_t log π(a_t | state_t) per (prompt, traj)

    Raises:
        ValueError: if ``device`` conflicts with LambdaNet's device, or
            ``block_size`` is not positive.
        RuntimeError: if the policy selects a position that is not a valid
            candidate.
    """
    # --- resolve device against LambdaNet ---
    lambdanet_device = next(lambdanet.parameters()).device
    if device is not None:
        device = torch.device(device)
        # torch.device("cuda") != torch.device("cuda:0"), so compare the type
        # and only compare the index when one was given explicitly.
        compatible = device.type == lambdanet_device.type and (
            device.index is None or device.index == lambdanet_device.index
        )
        if not compatible:
            raise ValueError(f"LambdaNet on {lambdanet_device}, but device={device}")
    device = lambdanet_device

    if block_size is not None and block_size <= 0:
        raise ValueError(f"block_size must be positive or None (got {block_size})")

    # move inputs to this rank's device
    full_ids = full_ids.to(device)
    prompt_len = prompt_len.to(device)
    ans_len = ans_len.to(device)

    B, L = full_ids.shape
    N = int(n_traj_per_prompt)
    BN = B * N

    # Expand prompts into BN trajectories
    full = full_ids.repeat_interleave(N, dim=0)           # [BN, L]
    pl = prompt_len.repeat_interleave(N)                  # [BN]
    al = ans_len.repeat_interleave(N)                     # [BN]

    # Absolute indices grid
    idxs = torch.arange(L, device=device).unsqueeze(0).expand(BN, L)  # [BN, L]

    # Answer span [ans_start, ans_end)
    ans_start = pl                      # [BN]
    ans_end = pl + al                   # [BN]
    in_answer = (idxs >= ans_start.unsqueeze(1)) & (idxs < ans_end.unsqueeze(1))  # [BN, L]

    # Initialize canvas: prompt kept, answer masked, padding unchanged
    canvas = full.clone()               # [BN, L]
    canvas[in_answer] = MASK_ID

    # Global accumulators (over all blocks)
    R = torch.zeros(BN, device=device)          # total return per trajectory
    logp_traj = torch.zeros(BN, device=device)  # sum log π per trajectory

    # Rows per teacher forward; small chunks keep peak memory low.
    if chunk_size is None:
        chunk_size_val = max(2, min(4, BN // 4))
    else:
        chunk_size_val = max(1, min(chunk_size, BN))

    def decode_block(block_mask: torch.Tensor) -> None:
        """Reveal every masked position inside ``block_mask`` ([BN, L] bool)."""
        block_remaining = ((canvas == MASK_ID) & block_mask).sum(dim=1)  # [BN]
        max_block_steps = (
            int(block_remaining.max().item()) if block_remaining.numel() > 0 else 0
        )

        for _ in range(max_block_steps):
            alive = block_remaining > 0
            if not alive.any():
                break

            attn = (canvas != pad_token_id).long()            # [BN, L]

            top_p_chunks = []
            ent_chunks = []
            gold_logp_chunks = []
            for cs in range(0, BN, chunk_size_val):
                ce = min(cs + chunk_size_val, BN)
                with torch.no_grad():
                    logits_chunk = teacher(
                        canvas[cs:ce], attention_mask=attn[cs:ce]
                    ).logits                                   # [rows, L, V]

                    # Log-prob of the gold token at every position, so the
                    # reward for the selected position needs no second
                    # teacher pass over the same canvas.
                    gold_logp_chunks.append(
                        F.log_softmax(logits_chunk, dim=-1)
                        .gather(-1, full[cs:ce].unsqueeze(-1))
                        .squeeze(-1)
                    )

                    probs_chunk = F.softmax(logits_chunk, dim=-1)
                    top_p_chunks.append(probs_chunk.max(dim=-1).values)
                    ent_chunks.append(
                        -(probs_chunk * probs_chunk.clamp_min(1e-9).log()).sum(dim=-1)
                    )
                del probs_chunk, logits_chunk
                torch.cuda.empty_cache()

            top_p = torch.cat(top_p_chunks, dim=0)            # [BN, L]
            ent = torch.cat(ent_chunks, dim=0)                # [BN, L]
            gold_logp = torch.cat(gold_logp_chunks, dim=0)    # [BN, L]
            del top_p_chunks, ent_chunks, gold_logp_chunks

            # Valid positions = masked tokens IN THIS BLOCK
            valid = (canvas == MASK_ID) & block_mask          # [BN, L]

            idx_act = (alive & valid.any(dim=1)).nonzero(as_tuple=False).squeeze(-1)  # [M]
            if idx_act.numel() == 0:
                break
            rows = torch.arange(idx_act.size(0), device=device)

            top_p_act = top_p.index_select(0, idx_act)        # [M, L]
            ent_act = ent.index_select(0, idx_act)            # [M, L]
            valid_act = valid.index_select(0, idx_act)        # [M, L]
            pad_mask_act = canvas.index_select(0, idx_act) == pad_token_id

            lam_vec_act = lambdanet(top_p_act, ent_act, valid_act, pad_mask_act)
            scores_act = top_p_act + lam_vec_act * ent_act    # [M, L]
            scores_act = scores_act.masked_fill(~valid_act, float("-inf"))

            if is_random:
                # Named `policy` so the module-level `torch.distributed as
                # dist` alias is not shadowed.
                policy = torch.distributions.Categorical(logits=scores_act)
                a_act = policy.sample()                       # [M]
                logp_a = policy.log_prob(a_act)               # [M]
            else:
                a_act = torch.argmax(scores_act, dim=-1)      # [M]
                logp_a = F.log_softmax(scores_act, dim=-1)[rows, a_act]

            # Safety: chosen actions must be valid
            valid_sel = valid_act[rows, a_act]
            if not valid_sel.all():
                bad = (~valid_sel).nonzero(as_tuple=False).squeeze(-1)
                raise RuntimeError(
                    f"Invalid action selected at indices {bad.tolist()}, "
                    f"actions={a_act[bad].tolist()}"
                )

            # Accumulate GRPO stats: teacher log-prob of the gold token at the
            # chosen position, and the policy log-prob of choosing it.
            R[idx_act] += gold_logp[idx_act, a_act].to(R.dtype)
            logp_traj[idx_act] += logp_a

            # Teacher forcing: fill in gold token
            canvas[idx_act, a_act] = full[idx_act, a_act]

            # One fewer masked token in this block
            block_remaining[idx_act] -= 1

            # Release per-step tensors before the next teacher pass.
            del top_p, ent, gold_logp
            del top_p_act, ent_act, valid_act, pad_mask_act
            del lam_vec_act, scores_act, a_act, logp_a
            torch.cuda.empty_cache()

    if block_size is None:
        # === Case 1: no blocking, the whole answer is one block ===
        decode_block(in_answer)
    else:
        # === Case 2: block-wise decoding over answer span ===
        max_ans_len = int(al.max().item()) if al.numel() > 0 else 0
        for offset in range(0, max_ans_len, block_size):
            # block covers [ans_start + offset, ans_start + offset + block_size)
            block_start = ans_start + offset                                 # [BN]
            block_end = torch.minimum(ans_end, block_start + block_size)     # [BN]

            block_mask = (idxs >= block_start.unsqueeze(1)) & (
                idxs < block_end.unsqueeze(1)
            )                                                                # [BN, L]

            if not block_mask.any():
                continue

            decode_block(block_mask)
            torch.cuda.empty_cache()

    # After all blocks, canvas should equal ground truth
    assert torch.equal(canvas, full), "Final canvas does not match ground truth"

    torch.cuda.empty_cache()

    # Reshape back to [B, N]
    R_mat = R.view(B, N)
    logpi_mat = logp_traj.view(B, N)
    return R_mat, logpi_mat

# =========================
# Evaluation on GSM8K test
# =========================

@torch.no_grad()
def eval_lambda_on_gsm8k_test(
    teacher: AutoModel,
    lambdanet: nn.Module,
    model_path: str,
    pad_token_id: int,
    batch_size: int = 2,
    max_length: int = 256,
    device: torch.device = torch.device("cuda"),
    max_eval_batches: int = 5,
    block_size: int | None = None,
) -> float:
    """
    Evaluate the current LambdaNet on the GSM8K test split using GREEDY order
    selection.

    For each of the first ``max_eval_batches`` batches:
      - run ``sample_trajectories_lambda_batched`` with ``is_random=False``
        and one trajectory per prompt
      - collect R [B, 1] (sum of teacher log-probs of the gold tokens)

    ``lambdanet`` and ``teacher`` are put in eval mode. ``lambdanet`` is
    returned to the mode it was in before the call, even if evaluation raises;
    ``teacher`` is left in eval mode.

    Args:
        teacher: Teacher model passed through to the sampler.
        lambdanet: Policy network to evaluate.
        model_path: Passed to ``GSM8KLLADA``.
        pad_token_id: Token id treated as padding.
        batch_size: Evaluation batch size.
        max_length: Passed to ``GSM8KLLADA``.
        device: Device for the batch tensors. A device without an index (such
            as the default ``torch.device("cuda")``) is resolved to
            LambdaNet's device of the same type, because ``cuda`` and
            ``cuda:0`` do not compare equal.
        max_eval_batches: Number of batches to evaluate.
        block_size: Block length forwarded to the sampler (None = whole answer).

    Returns:
        Average return over all evaluated samples (0.0 if none were evaluated).
    """
    ds_test = GSM8KLLADA(split="test", model_path=model_path, max_length=max_length)
    dl_test = DataLoader(ds_test, batch_size=batch_size, shuffle=False)

    # Resolve an index-less device against the network's actual device.
    device = torch.device(device)
    net_device = next(lambdanet.parameters()).device
    if device.index is None and device.type == net_device.type:
        device = net_device

    was_training = lambdanet.training
    lambdanet.eval()
    teacher.eval()

    total_R = 0.0
    total_count = 0

    try:
        for b_idx, batch in enumerate(dl_test):
            if b_idx >= max_eval_batches:
                break

            full_ids = batch["full_ids"].to(device)       # [B, L]
            prompt_len = batch["prompt_len"].to(device)   # [B]
            ans_len = batch["ans_len"].to(device)         # [B]

            R_vec, _ = sample_trajectories_lambda_batched(
                teacher=teacher,
                lambdanet=lambdanet,
                full_ids=full_ids,
                prompt_len=prompt_len,
                ans_len=ans_len,
                n_traj_per_prompt=1,
                pad_token_id=pad_token_id,
                is_random=False,
                block_size=block_size,
                device=device,
            )  # [B, 1]

            # sanitize just in case
            R_vec = torch.nan_to_num(R_vec, nan=-1e9, neginf=-1e9, posinf=1e9)

            total_R += R_vec.sum().item()
            total_count += R_vec.numel()
    finally:
        # Restore whatever mode the caller had, not unconditionally train().
        lambdanet.train(was_training)

    return total_R / max(total_count, 1)


# =========================
# GRPO-style training loop
# =========================
def _first_nonfinite_param(module: nn.Module) -> str | None:
    """Return the name of the first parameter of *module* holding NaN/Inf, else None."""
    for name, param in module.named_parameters():
        if not torch.isfinite(param).all():
            return name
    return None


def _first_nonfinite_grad(module: nn.Module) -> str | None:
    """Return the name of the first parameter whose gradient holds NaN/Inf, else None."""
    for name, param in module.named_parameters():
        if param.grad is not None and not torch.isfinite(param.grad).all():
            return name
    return None


def _reinit_linear_layers(module: nn.Module) -> None:
    """Re-initialize every nn.Linear inside *module* (Xavier weights, zero bias)."""
    for sub in module.modules():
        if isinstance(sub, nn.Linear):
            nn.init.xavier_uniform_(sub.weight)
            if sub.bias is not None:
                nn.init.zeros_(sub.bias)


def _any_rank_true(flag: bool, device: torch.device, is_distributed: bool) -> bool:
    """Return True if *flag* is True on any rank (logical OR across the process group)."""
    if not is_distributed:
        return flag
    flag_t = torch.tensor([1 if flag else 0], device=device, dtype=torch.int32)
    dist.all_reduce(flag_t, op=dist.ReduceOp.MAX)
    return bool(flag_t.item())


def train_lambda_llada_grpo(
    batch_size: int = 2,
    n_traj: int = 4,
    device: str | torch.device = "cuda",
    max_length: int = 256,
    lr: float = 1e-3,
    epochs: int = 1,
    save_dir: str = "workdir/lambda_llada_grpo",
    rank: int = 0,
    world_size: int = 1,
    block_size: int | None = None,
    true_grpo: bool = False,
    clip_eps: float = 0.2,
    adv_whiten: bool = True,
):
    """
    GRPO-style training for the lambda network.

    For each batch of B prompts (default path, ``true_grpo=False``):
      - Sample N trajectories per prompt in parallel (B*N total)
      - Get R_mat [B,N], logpi_mat [B,N]
      - Per-prompt baseline: b_i = mean_j R_{i,j}
      - Advantage A_{i,j} = R_{i,j} - b_i, whitened per prompt by its std
      - Loss = - E_{i,j}[ A_{i,j} * logpi_{i,j} ]

    If ``true_grpo`` is enabled, trajectories are rolled out with a frozen
    copy of the network and the loss comes from ``compute_grpo_loss``
    (PPO-style importance ratios with clipping).

    Args:
        batch_size: Prompts per step (per rank).
        n_traj: Trajectories sampled per prompt; must be >= 2.
        device: Requested device; falls back to CPU when CUDA is unavailable.
        max_length: Max sequence length for the dataset and the network's
            ``max_len``.
        lr: Adam learning rate.
        epochs: Number of passes over the (truncated) training set.
        save_dir: Parent directory; rank 0 writes into a timestamped
            ``run_*`` subdirectory of it.
        rank, world_size: Ignored in favour of the values reported by
            ``torch.distributed`` (or 0/1 when it is not initialized).
        block_size: Optional block size forwarded to the batched sampler and
            to evaluation.
        true_grpo: Use the ratio/clipping objective instead of plain
            REINFORCE with a per-prompt baseline.
        clip_eps: Clipping epsilon forwarded to ``compute_grpo_loss``.
        adv_whiten: Forwarded to ``compute_grpo_loss``; the default path
            always whitens advantages.

    Returns:
        The trained lambda network (DDP-wrapped when running distributed).

    Raises:
        ValueError: If ``n_traj < 2``.
    """
    # ------ DDP flags ------
    is_distributed = dist.is_available() and dist.is_initialized()
    if is_distributed:
        # Overwrite rank/world_size with the real values to be safe.
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1

    is_main = (rank == 0)

    if n_traj < 2:
        raise ValueError(f"n_traj must be >= 2 for GRPO (got {n_traj})")

    # Add timestamp to save_dir for uniqueness (only rank 0 writes outputs).
    if is_main:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        save_dir = os.path.join(save_dir, f"run_{timestamp}")

    os.makedirs(save_dir, exist_ok=True)

    # Initialize wandb on rank 0 only; training continues without it on failure.
    wandb_initialized = False
    if is_main:
        try:
            wandb.init(
                project="lambda-llada-grpo",
                config={
                    "batch_size": batch_size,
                    "n_traj": n_traj,
                    "max_length": max_length,
                    "lr": lr,
                    "epochs": epochs,
                    "true_grpo": true_grpo,
                    "clip_eps": clip_eps,
                    "adv_whiten": adv_whiten,
                },
                dir=save_dir,
                # mode="offline",  # Use offline mode to avoid service connection issues
            )
            wandb_initialized = True
        except Exception as e:
            print(f"[WARNING] Failed to initialize wandb: {e}")
            print("[INFO] Continuing without wandb logging")
            os.environ["WANDB_MODE"] = "disabled"

    device = torch.device(device if torch.cuda.is_available() else "cpu")
    # Normalize CUDA device to an explicit index so it matches module parameters.
    if device.type == "cuda" and device.index is None:
        device = torch.device(f"cuda:{torch.cuda.current_device()}")

    model_path = "GSAI-ML/LLaDA-8B-Instruct"

    # Frozen teacher
    teacher = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    ).to(device).eval()

    tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tok.pad_token_id is None:
        tok.pad_token_id = tok.eos_token_id
    pad_token_id = tok.pad_token_id

    ds0 = GSM8KLLADA(split="train", model_path=model_path, max_length=max_length)

    # Cap the training set and make its size divisible by world_size so every
    # rank sees the same number of items (and hence the same number of steps).
    N = min(10000, len(ds0))
    N_eff = (N // world_size) * world_size
    ds = Subset(ds0, list(range(N_eff)))

    if is_distributed:
        # DistributedSampler is the single place where data is sharded across
        # ranks. Pre-sharding the Subset by rank as well would shard twice and
        # leave each rank with only 1/world_size**2 of the data.
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            ds,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
        )
        dl = DataLoader(ds, batch_size=batch_size, sampler=train_sampler)
    else:
        train_sampler = None
        dl = DataLoader(ds, batch_size=batch_size, shuffle=True)

    # Shared hyperparameters so the frozen "old" copy always matches the
    # trainable network and its state dict can be loaded.
    lambdanet_kwargs = dict(
        d_model=128,
        nhead=4,
        num_layers=2,
        dim_feedforward=256,
        max_len=max_length,
    )
    lambdanet = LambdaNet(**lambdanet_kwargs).to(device)
    lambdanet.train()

    if is_distributed:
        lambdanet = DDP(
            lambdanet,
            device_ids=[device.index] if device.type == "cuda" else None,
            find_unused_parameters=False,
        )
    opt = torch.optim.Adam(lambdanet.parameters(), lr=lr)

    lambdanet_old = None
    if true_grpo:
        lambdanet_old = LambdaNet(**lambdanet_kwargs).to(device)
        lambdanet_old.load_state_dict(get_base_model(lambdanet).state_dict())
        lambdanet_old.eval()

    global_step = 0
    eval_every = 2          # evaluate + checkpoint every this many steps (rank 0)
    empty_cache_every = 5   # release cached CUDA memory every this many steps

    for ep in range(epochs):
        if is_distributed and train_sampler is not None:
            train_sampler.set_epoch(ep)
        for batch in tqdm(dl, desc=f"Epoch {ep+1}/{epochs}") if is_main else dl:
            full_ids = batch["full_ids"].to(device)       # [B, L]
            prompt_len = batch["prompt_len"].to(device)   # [B]
            ans_len = batch["ans_len"].to(device)         # [B]

            if true_grpo:
                traj = rollout_lambda_grpo(
                    teacher=teacher,
                    lambdanet=lambdanet_old,
                    full_ids=full_ids,
                    prompt_len=prompt_len,
                    ans_len=ans_len,
                    n_traj_per_prompt=n_traj,
                    pad_token_id=pad_token_id,
                    device=device,
                )
                returns_mat = traj.returns.view(full_ids.size(0), n_traj)
                logp_old_mat = traj.logp_old_sum.view(full_ids.size(0), n_traj)

                logp_new_mat = compute_logp_new_for_actions(
                    teacher=teacher,
                    lambdanet=lambdanet,
                    full_ids=full_ids,
                    prompt_len=prompt_len,
                    ans_len=ans_len,
                    actions=traj.actions,
                    n_traj_per_prompt=n_traj,
                    pad_token_id=pad_token_id,
                    device=device,
                )

                loss, ppo_stats = compute_grpo_loss(
                    returns=returns_mat,
                    logp_old=logp_old_mat,
                    logp_new=logp_new_mat,
                    clip_eps=clip_eps,
                    adv_whiten=adv_whiten,
                )
                avg_train_return = returns_mat.mean().item()

                # Clean up trajectory data
                del traj, returns_mat, logp_old_mat, logp_new_mat
            else:
                # Sample B*N trajectories in parallel
                R_mat, logpi_mat = sample_trajectories_lambda_batched(
                    teacher=teacher,
                    lambdanet=lambdanet,
                    full_ids=full_ids,
                    prompt_len=prompt_len,
                    ans_len=ans_len,
                    n_traj_per_prompt=n_traj,
                    pad_token_id=pad_token_id,
                    is_random=True,
                    block_size=block_size,
                    device=device,
                )  # both [B, N]

                # Clean up after sampling
                torch.cuda.empty_cache()

                with torch.no_grad():
                    baseline = R_mat.mean(dim=1, keepdim=True)   # [B,1]
                    adv = R_mat - baseline                       # [B,N]

                    # Normalize by std per prompt (reduce variance)
                    mean_adv = adv.mean(dim=1, keepdim=True)     # [B,1]
                    std_adv = adv.std(dim=1, keepdim=True).clamp_min(1e-8)  # [B,1]
                    adv_norm = (adv - mean_adv) / std_adv        # [B,N]
                loss = -(logpi_mat * adv_norm).mean()
                avg_train_return = R_mat.mean().item()

            # Skip the update when the loss is not finite. The decision is
            # shared across ranks: if one rank skipped backward() on its own,
            # the others would block in DDP's gradient all-reduce.
            local_bad_loss = not bool(torch.isfinite(loss))
            if local_bad_loss:
                print(f"[WARNING] Loss is not finite ({loss.item()}), skipping update")
            if _any_rank_true(local_bad_loss, device, is_distributed):
                global_step += 1
                continue

            loss_value = loss.item()

            opt.zero_grad()
            loss.backward()

            # Check for NaN/Inf gradients before clipping. Under DDP the
            # gradients are already all-reduced at this point.
            bad_grad_name = _first_nonfinite_grad(lambdanet)
            if bad_grad_name is not None:
                print(f"[WARNING] NaN/Inf gradient in {bad_grad_name}, skipping update")
                opt.zero_grad()  # Clear bad gradients
                global_step += 1
                continue

            # clip_grad_norm_ returns the total gradient norm *before* clipping.
            grad_norm = torch.nn.utils.clip_grad_norm_(lambdanet.parameters(), 1.0).item()

            # Check if parameters have NaN/Inf before optimizer step
            bad_param_name = _first_nonfinite_param(lambdanet)
            if bad_param_name is not None:
                print(f"[WARNING] Parameter {bad_param_name} has NaN/Inf before optimizer step, reinitializing network")
                _reinit_linear_layers(get_base_model(lambdanet))
                print("[INFO] Network reinitialized due to NaN parameters")
                opt.zero_grad()  # Clear gradients
                global_step += 1
                continue

            opt.step()

            # Check if parameters have NaN/Inf after optimizer step
            bad_param_name = _first_nonfinite_param(lambdanet)
            if bad_param_name is not None:
                print(f"[WARNING] Parameter {bad_param_name} became NaN/Inf after optimizer step, reinitializing")
                _reinit_linear_layers(get_base_model(lambdanet))
                if is_main:
                    print("[INFO] Network reinitialized due to NaN parameters after step")

            # Refresh the frozen rollout policy with the just-updated weights.
            if true_grpo and lambdanet_old is not None:
                lambdanet_old.load_state_dict(get_base_model(lambdanet).state_dict())
                lambdanet_old.eval()

            global_step += 1

            # Batch tensors are no longer needed past this point.
            del full_ids, prompt_len, ans_len

            # Clear cache periodically to prevent accumulation
            if global_step % empty_cache_every == 0:
                torch.cuda.empty_cache()

            # Log metrics to wandb
            if wandb_initialized and wandb.run is not None:
                log_payload = {
                    "train/loss": loss_value,
                    "train/avg_R": avg_train_return,
                    "train/grad_norm": grad_norm,
                    "step": global_step,
                    "epoch": ep + 1,
                }
                if true_grpo:
                    log_payload.update({
                        "train/ratio_mean": ppo_stats["ratio_mean"],
                        "train/ratio_std": ppo_stats["ratio_std"],
                        "train/adv_mean": ppo_stats["adv_mean"],
                    })
                wandb.log(log_payload)

            if global_step % eval_every == 0 and is_main:
                print(
                    f"[ep {ep+1} step {global_step}] "
                    f"loss={loss_value:.4f}, avg_R={avg_train_return:.4f}"
                )

                # Evaluate the unwrapped module: only rank 0 runs this, so it
                # must not go through the DDP wrapper.
                base_model = get_base_model(lambdanet)
                avg_R_test = eval_lambda_on_gsm8k_test(
                    teacher=teacher,
                    lambdanet=base_model,
                    model_path=model_path,
                    pad_token_id=pad_token_id,
                    batch_size=batch_size,
                    max_length=max_length,
                    device=device,
                    max_eval_batches=5,
                    block_size=block_size
                )
                print(
                    f"[eval]  step={global_step} "
                    f"avg_R_test (5 batches, sum log p) = {avg_R_test:.4f}"
                )

                # Log eval metrics to wandb
                if wandb_initialized and wandb.run is not None:
                    wandb.log({
                        "eval/avg_R_test": avg_R_test,
                        "step": global_step,
                    })

                # Save a per-step checkpoint plus a rolling "latest" copy at
                # every evaluation step.
                checkpoint = {
                    "step": global_step,
                    "epoch": ep + 1,
                    "lambdanet_state_dict": base_model.state_dict(),
                    "optimizer_state_dict": opt.state_dict(),
                    "loss": loss_value,
                    "avg_R": avg_train_return,
                    "avg_R_test": avg_R_test,
                }
                checkpoint_path = os.path.join(save_dir, f"checkpoint_step_{global_step}.pt")
                latest_path = os.path.join(save_dir, "checkpoint_latest.pt")
                torch.save(checkpoint, checkpoint_path)
                torch.save(checkpoint, latest_path)

                if wandb_initialized and wandb.run is not None:
                    wandb.save(checkpoint_path)
                    wandb.save(latest_path)

                print(f"[INFO] Saved checkpoint to {checkpoint_path}")

    # Finish wandb run
    if wandb_initialized and wandb.run is not None and is_main:
        wandb.finish()

    return lambdanet


# =========================
# CLI
# =========================
def main():
    """
    CLI entry point: parse arguments, set up (optional) DDP, and run training.

    When launched via torchrun with WORLD_SIZE > 1, each process binds to the
    GPU given by LOCAL_RANK and joins an NCCL process group; otherwise a
    single-process run is started on ``--device`` (or CPU if CUDA is
    unavailable). A per-run output directory named with a timestamp shared
    across ranks is created under ``--save_dir``, and rank 0 stores the parsed
    arguments there as ``args.json``.

    Raises:
        RuntimeError: If LOCAL_RANK does not correspond to a visible GPU.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--n_traj", type=int, default=6)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--device", type=str, default="cuda")
    # A default is required: without one, omitting the flag leaves save_dir as
    # None and the os.path.join below raises TypeError.
    parser.add_argument("--save_dir", type=str, default="workdir/lambda_llada_grpo", help="Parent directory for run outputs")
    parser.add_argument("--block_size", type=int, default=0, help="Block size for block-wise decoding; set to 0 or negative for no blocking")
    parser.add_argument("--true_grpo", action="store_true", help="Enable PPO-style GRPO (ratio + clipping)")
    parser.add_argument("--clip_eps", type=float, default=0.2)
    parser.add_argument("--no_adv_whiten", action="store_true", help="Disable per-prompt advantage whitening in true GRPO")
    args = parser.parse_args()

    # ----- DDP init (torchrun) -----
    if "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1:
        distributed = True
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ.get("LOCAL_RANK", 0))

        # Validate device before setting
        num_gpus = torch.cuda.device_count()
        if local_rank >= num_gpus:
            raise RuntimeError(
                f"LOCAL_RANK {local_rank} is invalid. Only {num_gpus} GPU(s) available. "
                f"Make sure --nproc_per_node matches the number of GPUs requested in SLURM."
            )

        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
        device = torch.device(f"cuda:{local_rank}")
    else:
        distributed = False
        rank = 0
        world_size = 1
        device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    try:
        # Make per-run directory using a shared timestamp across ranks:
        # rank 0's timestamp is broadcast so every rank agrees on the path.
        run_ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        if distributed:
            ts_list = [run_ts]
            dist.broadcast_object_list(ts_list, src=0)
            run_ts = ts_list[0]
        run_save_dir = os.path.join(args.save_dir, f"run_lambda_{run_ts}")

        # Save a copy of the args to args.json in the run_save_dir
        if rank == 0:
            import json
            os.makedirs(run_save_dir, exist_ok=True)
            args_path = os.path.join(run_save_dir, "args.json")
            with open(args_path, "w") as f:
                json.dump(vars(args), f, indent=4)

        train_lambda_llada_grpo(
            batch_size=args.batch_size,
            n_traj=args.n_traj,
            device=device,
            max_length=args.max_length,
            lr=args.lr,
            epochs=args.epochs,
            save_dir=run_save_dir,
            rank=rank,
            world_size=world_size,
            # Non-positive block sizes mean "no blocking".
            block_size=args.block_size if args.block_size > 0 else None,
            true_grpo=args.true_grpo,
            clip_eps=args.clip_eps,
            adv_whiten=not args.no_adv_whiten,
        )
    finally:
        # Always tear down the process group, even if training raised.
        if distributed and dist.is_initialized():
            dist.destroy_process_group()



# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()