# run_verification_weight_grpo_trap.py
from __future__ import annotations
import os
import time
import argparse

import torch
import torch.nn.functional as F

from markov_chain import build_markov_transitions
from model import GPT, GPTConfig


def _build_allowed_matrix(next_states, n_states: int, device: torch.device) -> torch.Tensor:
    """Return an ``(n_states, n_states)`` boolean transition mask.

    Entry ``[src, dst]`` is set to True for every ``dst`` listed in ``next_states[src]``;
    all other entries stay False. Each element of ``next_states`` may be anything
    ``torch.as_tensor`` can convert to an index tensor (list, array or tensor).
    """
    adjacency = torch.zeros((n_states, n_states), dtype=torch.bool, device=device)
    for src, targets in enumerate(next_states):
        dst = torch.as_tensor(targets, device=device, dtype=torch.long)
        adjacency[src, dst] = True
    return adjacency


def _build_run_tag(args: argparse.Namespace, cfg: GPTConfig) -> str:
    """Compose the run identifier embedded in this script's output filenames.

    Encodes the task settings (hubs, hub size, verify_k, weight mode) followed by the
    model shape (embedding width, layers, heads, block size).
    """
    task_part = "trap_h{}_m{}_k{}_w{}".format(args.n_hubs, args.m, args.verify_k, args.weight_mode)
    model_part = "_emb{}_l{}_head{}_bs{}".format(cfg.n_embd, cfg.n_layer, cfg.n_head, cfg.block_size)
    return task_part + model_part

def _get_trap_candidates(n_hubs: int, m: int) -> tuple[int, int]:
    """Return the two candidate trap state ids.

    Both lie in the block starting at ``(n_hubs - 1) * m``: the first at offset
    ``m - 2`` and the second at offset ``m - 3``.
    """
    base = m * (n_hubs - 1)
    return base + m - 2, base + m - 3


def _compute_shortest_steps(next_states, n_states: int, trap_state: int | None = None) -> list[int]:
    """Shortest number of transitions from each state to the final state ``n_states - 1``.

    Only non-decreasing transitions (``dst >= src``) are considered, and when
    ``trap_state`` is given every edge into or out of it is removed. Distances are
    found by breadth-first search backwards from the final state.

    Args:
        next_states: per-state successor collections; each must provide ``.tolist()``.
        n_states: number of states.
        trap_state: optional state id to exclude from all paths.

    Returns:
        A list of length ``n_states``; ``-1`` marks states that cannot reach the final
        state under these constraints.
    """
    from collections import deque

    def _blocked(state) -> bool:
        return trap_state is not None and state == trap_state

    # predecessors[dst] lists every src with an admissible edge src -> dst.
    predecessors: list[list[int]] = [[] for _ in range(n_states)]
    for src, targets in enumerate(next_states):
        if _blocked(src):
            continue
        for dst in targets.tolist():
            if dst < src or _blocked(dst):
                continue
            predecessors[int(dst)].append(src)

    goal = n_states - 1
    steps = [-1] * n_states
    steps[goal] = 0

    frontier = deque([goal])
    while frontier:
        node = frontier.popleft()
        if _blocked(node):
            continue
        for pred in predecessors[node]:
            if _blocked(pred) or steps[pred] != -1:
                continue
            steps[pred] = steps[node] + 1
            frontier.append(pred)
    return steps


def _compute_log_probs(model: GPT, tokens: torch.Tensor, amp: bool) -> torch.Tensor:
    """Teacher-forced log-probability of each next token under ``model``.

    Feeds ``tokens[:, :-1]`` through the model and returns, for every position ``t``, the
    log-softmax probability (computed in float32) assigned to ``tokens[:, t + 1]``.
    Output shape is ``(batch, seq_len - 1)``.
    """
    inputs, targets = tokens[:, :-1], tokens[:, 1:]
    with torch.cuda.amp.autocast(enabled=amp):
        raw = model(inputs)
    log_dist = F.log_softmax(raw.float(), dim=-1)
    return log_dist.gather(dim=2, index=targets.unsqueeze(-1)).squeeze(-1)


def _compute_trial_success(
    tokens: torch.Tensor,
    n_states: int,
    final_state: int,
    retry_id: int,
    pad_id: int,
    allowed: torch.Tensor,
    shortest_steps_by_trap: torch.Tensor,
    trap_idx: torch.Tensor,
    trap_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Score a batch of single attempts, one attempt per row of ``tokens``.

    An attempt ends at its first token that is ``final_state``, ``retry_id``, ``pad_id``
    or any id ``>= n_states``; if none occurs, it ends at the last column. It counts as a
    success only if all of the following hold:

    * it ends on ``final_state``;
    * no in-attempt transition moves to a smaller state id;
    * every in-attempt state transition is marked True in ``allowed``;
    * it never visits the row's trap state (``trap_states``);
    * its transition count equals ``shortest_steps_by_trap[trap_idx, start]``, which must be
      non-negative.

    Returns:
        ``(success, lengths)``. ``lengths`` counts tokens up to and including the terminator.
    """
    n_rows, width = tokens.shape
    dev = tokens.device

    # Locate the terminator of each attempt; rows without one run to the last column.
    stop = (tokens >= n_states) | (tokens == final_state) | (tokens == retry_id) | (tokens == pad_id)
    first_stop = stop.float().argmax(dim=1)
    first_stop = torch.where(stop.any(dim=1), first_stop, torch.full_like(first_stop, width - 1))
    lengths = first_stop + 1
    reached_goal = tokens.gather(1, first_stop.unsqueeze(1)).squeeze(1) == final_state

    positions = torch.arange(width, device=dev)
    in_trial = positions.unsqueeze(0) < lengths.unsqueeze(1)
    visited_trap = ((tokens == trap_states.unsqueeze(1)) & in_trial).any(dim=1)

    # Transitions src -> dst inside the attempt where both ends are real states.
    src, dst = tokens[:, :-1], tokens[:, 1:]
    in_trial_edge = positions[:-1].unsqueeze(0) < (lengths - 1).unsqueeze(1)
    is_state_edge = (src < n_states) & (dst < n_states) & in_trial_edge

    went_backwards = ((dst < src) & is_state_edge).any(dim=1)
    # Clamp so non-state ids index safely; those edges are excluded by is_state_edge anyway.
    edge_ok = allowed[src.clamp(0, n_states - 1), dst.clamp(0, n_states - 1)]
    broke_rules = (~edge_ok & is_state_edge).any(dim=1)

    origin = tokens[:, 0].clamp(0, n_states - 1)
    optimal = shortest_steps_by_trap[trap_idx, origin]
    took_shortest = (optimal >= 0) & (lengths - 1 == optimal)

    success = reached_goal & ~went_backwards & ~broke_rules & took_shortest & ~visited_trap
    return success, lengths


def _sample_verification_weighted_trajectories(
    model: GPT,
    ref_model: GPT | None,
    starts: torch.Tensor,
    max_len: int,
    verify_k: int,
    retry_id: int,
    pad_id: int,
    final_state: int,
    n_states: int,
    allowed: torch.Tensor,
    trap_candidates: torch.Tensor,
    shortest_steps_by_trap: torch.Tensor,
    temperature: float,
    top_k: int,
    amp: bool,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor | None,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Sample up to ``verify_k`` verified attempts per start state, packed into one sequence.

    Each row starts at ``starts[row]`` and is extended token by token with the policy
    (temperature / top-k sampling).

    An attempt ends when the model emits ``final_state``, ``retry_id``, ``pad_id`` or an id
    ``>= n_states``, or when the attempt reaches ``max_len`` tokens. The finished attempt is
    scored with ``_compute_trial_success`` against a trap state drawn per row from
    ``trap_candidates``.

    A success, or a failure on attempt ``verify_k``, stops the row. Otherwise a ``retry_id``
    token is forced (unless the model just emitted one) followed by the start state, and the
    next attempt begins.

    Returns:
        tokens: ``(total, max_total_len)`` packed sequences; ``pad_id`` after a row stops.
        log_probs: ``(total, max_total_len - 1)``; entry ``t - 1`` is the sampling log-prob of
            ``tokens[:, t]`` (0 where the token was forced or the row was inactive).
        action_mask: float mask, 1.0 where ``tokens[:, t]`` was sampled by the model.
        log_probs_ref: like ``log_probs`` but under ``ref_model`` (same temperature / top-k
            transform), or ``None`` when ``ref_model`` is ``None``.
        attempt_ids: 1-based attempt number of each sampled token, 0 elsewhere.
        attempt_rewards: ``(total, verify_k)`` 0/1 success of each attempt.
        trials_used: ``attempt_idx + 1`` per row, i.e. attempts consumed.
        trap_states: trap state assigned to each row.
    """
    device = starts.device
    total = starts.size(0)
    # One trap state per row, chosen uniformly among the candidates.
    trap_idx = torch.randint(0, trap_candidates.numel(), (total,), device=device)
    trap_states = trap_candidates[trap_idx]
    # verify_k attempts of at most max_len tokens each, separated by verify_k - 1 retry tokens.
    max_total_len = verify_k * max_len + (verify_k - 1)

    tokens = torch.full((total, max_total_len), pad_id, device=device, dtype=torch.long)
    tokens[:, 0] = starts

    # Tokens of the attempt currently in progress; scored on its own when it ends.
    trial_tokens = torch.full((total, max_len), pad_id, device=device, dtype=torch.long)
    trial_tokens[:, 0] = starts
    trial_lengths = torch.ones(total, device=device, dtype=torch.long)

    active = torch.ones(total, device=device, dtype=torch.bool)
    attempt_idx = torch.zeros(total, device=device, dtype=torch.long)  # 0-based current attempt
    need_retry = torch.zeros(total, device=device, dtype=torch.bool)  # force retry_id next
    need_start = torch.zeros(total, device=device, dtype=torch.bool)  # force start state next

    attempt_rewards = torch.zeros((total, verify_k), device=device, dtype=torch.float32)

    logp_steps: list[torch.Tensor] = []
    action_steps: list[torch.Tensor] = []
    attempt_steps: list[torch.Tensor] = []
    logp_ref_steps: list[torch.Tensor] | None = [] if ref_model is not None else None

    for t in range(1, max_total_len):
        if not active.any():
            break

        step_logp = torch.zeros(total, device=device, dtype=torch.float32)
        step_mask = torch.zeros(total, device=device, dtype=torch.float32)
        step_attempt = torch.zeros(total, device=device, dtype=torch.long)
        step_logp_ref = (
            torch.zeros(total, device=device, dtype=torch.float32) if ref_model is not None else None
        )

        # Forced separator after a failed attempt that did not itself end on retry_id.
        retry_mask = active & need_retry
        if retry_mask.any():
            tokens[retry_mask, t] = retry_id
            need_retry[retry_mask] = False
            need_start[retry_mask] = True

        # Forced start state opening a new attempt (never on the same step as the retry token).
        start_mask = active & need_start & ~retry_mask
        if start_mask.any():
            tokens[start_mask, t] = starts[start_mask]
            need_start[start_mask] = False
            trial_tokens[start_mask] = pad_id
            trial_tokens[start_mask, 0] = starts[start_mask]
            trial_lengths[start_mask] = 1

        model_mask = active & ~(need_retry | need_start | retry_mask | start_mask)
        if model_mask.any():
            x = tokens[:, :t].clone()
            with torch.cuda.amp.autocast(enabled=amp):
                logits = model(x)[:, -1, :]
            logits = logits.float()

            if temperature != 1.0:
                logits = logits / temperature
            if top_k > 0:
                v, _ = torch.topk(logits, top_k)
                kth = v[:, -1].unsqueeze(-1)
                logits = torch.where(logits < kth, torch.full_like(logits, -float("inf")), logits)

            logp = F.log_softmax(logits, dim=-1)
            probs = torch.exp(logp)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)

            tokens[model_mask, t] = next_token[model_mask]

            # next_token was sampled from this very distribution, so gathered is finite.
            gathered = torch.gather(logp, 1, next_token.unsqueeze(1)).squeeze(1)
            mask_f = model_mask.float()
            step_logp = gathered * mask_f
            step_mask = mask_f
            step_attempt = (attempt_idx + 1) * model_mask.long()

            if ref_model is not None:
                with torch.no_grad():
                    with torch.cuda.amp.autocast(enabled=amp):
                        ref_logits = ref_model(x)[:, -1, :]
                    ref_logits = ref_logits.float()
                    if temperature != 1.0:
                        ref_logits = ref_logits / temperature
                    if top_k > 0:
                        v, _ = torch.topk(ref_logits, top_k)
                        kth = v[:, -1].unsqueeze(-1)
                        ref_logits = torch.where(
                            ref_logits < kth, torch.full_like(ref_logits, -float("inf")), ref_logits
                        )
                    ref_logp = F.log_softmax(ref_logits, dim=-1)
                    ref_gathered = torch.gather(ref_logp, 1, next_token.unsqueeze(1)).squeeze(1)
                    # Mask with torch.where, not multiplication. With top_k > 0 the sampled token
                    # may fall outside the reference model's top-k (log-prob -inf), and
                    # -inf * 0 = nan would poison rows that took no model action at this step.
                    step_logp_ref = torch.where(model_mask, ref_gathered, torch.zeros_like(ref_gathered))

            idx = torch.nonzero(model_mask, as_tuple=False).squeeze(1)
            pos = trial_lengths[idx]
            trial_tokens[idx, pos] = next_token[idx]
            trial_lengths[idx] = trial_lengths[idx] + 1

            is_final = next_token == final_state
            is_retry = next_token == retry_id
            is_pad = next_token == pad_id
            is_invalid = next_token >= n_states
            hit_len = trial_lengths >= max_len
            end_mask = model_mask & (is_final | is_retry | is_pad | is_invalid | hit_len)

            if end_mask.any():
                ended_idx = torch.nonzero(end_mask, as_tuple=False).squeeze(1)
                with torch.no_grad():
                    trial_success, _ = _compute_trial_success(
                        trial_tokens[ended_idx],
                        n_states,
                        final_state,
                        retry_id,
                        pad_id,
                        allowed,
                        shortest_steps_by_trap,
                        trap_idx[ended_idx],
                        trap_states[ended_idx],
                    )

                attempt_rewards[ended_idx, attempt_idx[ended_idx]] = trial_success.float()

                done_success = ended_idx[trial_success]
                if done_success.numel() > 0:
                    active[done_success] = False
                    need_retry[done_success] = False
                    need_start[done_success] = False

                fail_mask = ~trial_success
                if fail_mask.any():
                    fail_idx = ended_idx[fail_mask]
                    last_trial = attempt_idx[fail_idx] >= (verify_k - 1)
                    if last_trial.any():
                        done_fail = fail_idx[last_trial]
                        active[done_fail] = False
                        need_retry[done_fail] = False
                        need_start[done_fail] = False
                    cont_mask = ~last_trial
                    if cont_mask.any():
                        cont_idx = fail_idx[cont_mask]
                        attempt_idx[cont_idx] = attempt_idx[cont_idx] + 1
                        # If the model already emitted retry_id, go straight to the start state.
                        ended_by_retry = is_retry[cont_idx]
                        if ended_by_retry.any():
                            need_start[cont_idx[ended_by_retry]] = True
                        if (~ended_by_retry).any():
                            need_retry[cont_idx[~ended_by_retry]] = True

        logp_steps.append(step_logp)
        action_steps.append(step_mask)
        attempt_steps.append(step_attempt)
        if logp_ref_steps is not None:
            logp_ref_steps.append(step_logp_ref)

    if logp_steps:
        log_probs = torch.stack(logp_steps, dim=1)
        action_mask = torch.stack(action_steps, dim=1)
        attempt_ids = torch.stack(attempt_steps, dim=1)
    else:
        log_probs = torch.zeros((total, 0), device=device, dtype=torch.float32)
        action_mask = torch.zeros((total, 0), device=device, dtype=torch.float32)
        attempt_ids = torch.zeros((total, 0), device=device, dtype=torch.long)

    # Right-pad per-step outputs to max_total_len - 1 when every row stopped early.
    steps = log_probs.size(1)
    if steps < max_total_len - 1:
        pad = torch.zeros((total, max_total_len - 1 - steps), device=device, dtype=log_probs.dtype)
        log_probs = torch.cat([log_probs, pad], dim=1)
        action_mask = torch.cat([action_mask, pad], dim=1)
        attempt_pad = torch.zeros((total, max_total_len - 1 - steps), device=device, dtype=attempt_ids.dtype)
        attempt_ids = torch.cat([attempt_ids, attempt_pad], dim=1)

    log_probs_ref = None
    if logp_ref_steps is not None:
        if logp_ref_steps:
            log_probs_ref = torch.stack(logp_ref_steps, dim=1)
        else:
            log_probs_ref = torch.zeros((total, 0), device=device, dtype=torch.float32)
        steps_ref = log_probs_ref.size(1)
        if steps_ref < max_total_len - 1:
            pad_ref = torch.zeros((total, max_total_len - 1 - steps_ref), device=device, dtype=log_probs_ref.dtype)
            log_probs_ref = torch.cat([log_probs_ref, pad_ref], dim=1)

    trials_used = attempt_idx + 1
    return tokens, log_probs, action_mask, log_probs_ref, attempt_ids, attempt_rewards, trials_used, trap_states


def _compute_raw_attempt_weights(
    rewards: torch.Tensor,
    attempt_exists: torch.Tensor,
    weight_mode: str,
    eps: float,
) -> torch.Tensor:
    """Un-normalised per-attempt weights for each group.

    Args:
        rewards: ``(batch, group, k)`` per-attempt rewards; slots for attempts that were
            not made are expected to hold 0.
        attempt_exists: same-shape boolean mask of attempts that were actually made.
        weight_mode: ``"none"`` (all ones), ``"equal"`` (``1 / N_i``) or ``"optimal"``.
            ``"optimal"`` gives ``prod_{j != i}(1 - rho_j + eps) / N_i``, where ``N_i`` is the
            number of group members that made attempt ``i`` and ``rho_i`` is their success rate.
        eps: added to every ``1 - rho_j`` factor in ``"optimal"`` mode.

    Returns:
        ``(batch, k)`` weight tensor.

    Raises:
        ValueError: if ``weight_mode`` is not one of the modes above.
    """
    batch_size, _, num_attempts = rewards.shape
    if weight_mode == "none":
        return torch.ones((batch_size, num_attempts), device=rewards.device, dtype=rewards.dtype)

    success_counts = rewards.sum(dim=1)
    attempt_counts = attempt_exists.sum(dim=1).to(rewards.dtype)
    denom = attempt_counts.clamp(min=1.0)

    if weight_mode == "equal":
        w = 1.0 / denom
    elif weight_mode == "optimal":
        rho = success_counts / denom
        one_minus = 1.0 - rho + eps
        # Leave-one-out product prod_{j != i} one_minus_j via exclusive prefix/suffix cumprods.
        # Dividing prod_all by one_minus_i instead would give 0/0 = nan whenever a factor is 0
        # (e.g. eps == 0 and rho_i == 1).
        ones = torch.ones_like(one_minus[:, :1])
        prefix = torch.cumprod(torch.cat([ones, one_minus[:, :-1]], dim=1), dim=1)
        suffix = torch.cumprod(torch.cat([ones, one_minus.flip(1)[:, :-1]], dim=1), dim=1).flip(1)
        w = prefix * suffix / denom
    else:
        raise ValueError(f"unknown weight_mode: {weight_mode}")
    return w


def _compute_attempt_weights(
    rewards: torch.Tensor,
    attempt_exists: torch.Tensor,
    weight_mode: str,
    eps: float,
) -> torch.Tensor:
    """Per-attempt weights rescaled to mean 1 over the attempts a group actually reached.

    Starts from ``_compute_raw_attempt_weights``. In ``"none"`` mode those raw weights are
    returned unchanged. Otherwise each group's weights are divided by their average over
    attempt slots that at least one group member made.

    Returns:
        ``(batch, k)`` weight tensor.
    """
    raw = _compute_raw_attempt_weights(rewards, attempt_exists, weight_mode, eps)
    if weight_mode == "none":
        return raw
    # 1.0 for attempt slots reached by at least one group member, else 0.0.
    reached = (attempt_exists.sum(dim=1).to(rewards.dtype) > 0).to(rewards.dtype)
    n_reached = reached.sum(dim=1, keepdim=True).clamp(min=1.0)
    avg_raw = (raw * reached).sum(dim=1, keepdim=True) / n_reached
    return raw / avg_raw


def _format_trajectory(tokens: list[int]) -> str:
    """Render a token sequence as ``[t0 t1 ...]``, each element converted with ``int``."""
    if not tokens:
        return "[]"
    body = " ".join(str(int(tok)) for tok in tokens)
    return f"[{body}]"


def _evaluate_vwgrpo(
    model: GPT,
    eval_start_count: int,
    max_len: int,
    verify_k: int,
    weight_mode: str,
    weight_eps: float,
    retry_id: int,
    pad_id: int,
    final_state: int,
    n_states: int,
    last_hub_base: int,
    allowed: torch.Tensor,
    trap_candidates: torch.Tensor,
    shortest_steps_by_trap: torch.Tensor,
    temperature: float,
    top_k: int,
    amp: bool,
    device: torch.device,
) -> tuple[float, float, float, list[int], list[int], torch.Tensor, torch.Tensor, torch.Tensor]:
    """Roll out the policy from fresh random starts and summarise verify-and-retry performance.

    Draws ``eval_start_count`` start states uniformly from ``[0, last_hub_base)``. It samples
    without a reference model and under ``torch.no_grad``, with the model in eval mode. The
    model's previous train/eval mode is restored afterwards, even if sampling raises.

    Returns:
        mean_reward: mean of all ``(row, attempt)`` rewards; unused attempts count as 0.
        success_rate: fraction of rows with at least one successful attempt.
        avg_trials: mean number of attempts consumed per row.
        total_attempts: ``[i]`` = rows that consumed at least ``i + 1`` attempts.
        correct_attempts: ``[i]`` = rows whose attempt ``i + 1`` succeeded.
        wraw_mean, w_mean, rho_mean: ``(verify_k,)`` averages of raw weights, normalised
            weights and per-attempt success rate. Each row is treated as a group of size 1 and
            only rows that reached every attempt are averaged.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            starts = torch.randint(0, last_hub_base, (eval_start_count,), device=device)
            _, _, _, _, _, attempt_rewards, trials_used, _ = _sample_verification_weighted_trajectories(
                model,
                None,
                starts,
                max_len,
                verify_k,
                retry_id,
                pad_id,
                final_state,
                n_states,
                allowed,
                trap_candidates,
                shortest_steps_by_trap,
                temperature,
                top_k,
                amp,
            )
    finally:
        # Restore the caller's mode rather than forcing training mode.
        model.train(was_training)

    total = attempt_rewards.size(0)
    attempt_order = torch.arange(verify_k, device=device).unsqueeze(0)
    attempt_exists = attempt_order < trials_used.unsqueeze(1)
    # Treat every row as its own group of size 1.
    rewards = attempt_rewards.view(total, 1, verify_k)
    attempt_exists_group = attempt_exists.view(total, 1, verify_k)
    success_counts = rewards.sum(dim=1)
    attempt_counts = attempt_exists_group.sum(dim=1).to(rewards.dtype)
    rho = success_counts / attempt_counts.clamp(min=1.0)
    wraw_group = _compute_raw_attempt_weights(rewards, attempt_exists_group, weight_mode, weight_eps)
    wtilde_group = _compute_attempt_weights(rewards, attempt_exists_group, weight_mode, weight_eps)
    group_has_all = (attempt_counts > 0).all(dim=1)
    group_mask = group_has_all.unsqueeze(1).to(rewards.dtype)
    wraw_sum = (wraw_group * group_mask).sum(dim=0)
    w_sum = (wtilde_group * group_mask).sum(dim=0)
    rho_sum = (rho * group_mask).sum(dim=0)
    denom = group_has_all.sum().to(rewards.dtype).clamp(min=1.0)
    wraw_mean = wraw_sum / denom
    w_mean = w_sum / denom
    rho_mean = rho_sum / denom
    mean_reward = attempt_rewards.mean().item()
    success_rate = (attempt_rewards.sum(dim=1) > 0).float().mean().item()
    avg_trials = trials_used.float().mean().item()
    total_attempts = [int((trials_used >= (i + 1)).sum().item()) for i in range(verify_k)]
    correct_attempts = [int(attempt_rewards[:, i].sum().item()) for i in range(verify_k)]
    return (
        mean_reward,
        success_rate,
        avg_trials,
        total_attempts,
        correct_attempts,
        wraw_mean,
        w_mean,
        rho_mean,
    )


def train(args: argparse.Namespace) -> None:
    """Fine-tune a checkpointed GPT with verification-weighted GRPO on the trap Markov task.

    Loads ``args.init_ckpt``, reading its ``"config"``, ``"model"`` and optional ``"ds"``
    entries. When ``args.kl_coef > 0`` it keeps a frozen copy as the KL reference.

    Each step:

    1. Draws ``args.batch_size`` start states, each repeated ``args.group_size`` times.
    2. Rolls out up to ``args.verify_k`` attempts per row.
    3. Takes one clipped policy-gradient step. Advantages are mean-centred within each group
       and scaled by per-attempt weights.

    Summaries of training and periodic evaluation go to a CSV file: ``args.eval_log``, or a
    file in ``args.save_dir``. Checkpoints are written according to ``args.save_best`` and
    ``args.save_every``.

    Raises:
        ValueError: if ``n_hubs``/``m`` are missing or invalid, the vocab size disagrees with
            the checkpoint, or ``block_size`` cannot hold ``verify_k`` packed attempts.
    """
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    if device.type == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    use_amp = args.amp and device.type == "cuda"

    ckpt = torch.load(args.init_ckpt, map_location="cpu")
    cfg = GPTConfig(**ckpt["config"])
    model = GPT(cfg)
    model.load_state_dict(ckpt["model"])
    model.to(device)
    model.train()

    ref_model = None
    if args.kl_coef > 0:
        # Frozen copy of the initial policy used as the KL anchor.
        ref_model = GPT(cfg)
        ref_model.load_state_dict(ckpt["model"])
        ref_model.to(device)
        ref_model.eval()
        for p in ref_model.parameters():
            p.requires_grad_(False)

    ds = ckpt.get("ds", {})
    # Non-zero CLI values win; otherwise fall back to the checkpoint's dataset metadata.
    args.n_hubs = args.n_hubs or ds.get("n_hubs", args.n_hubs)
    args.m = args.m or ds.get("m", args.m)
    if not args.n_hubs or not args.m:
        raise ValueError("n_hubs and m must be set either in checkpoint or via args.")
    if args.n_hubs <= 1:
        raise ValueError("n_hubs must be > 1 for trap variant.")

    n_states, next_states, _ = build_markov_transitions(args.n_hubs, args.m)
    retry_id = int(ds.get("retry_id", n_states))
    pad_id = int(ds.get("pad_id", n_states + 1))
    vocab_size = int(ds.get("vocab_size", n_states + 2))
    if cfg.vocab_size != vocab_size:
        raise ValueError(f"checkpoint vocab_size={cfg.vocab_size} != expected {vocab_size}")

    final_state = n_states - 1
    last_hub_base = (args.n_hubs - 1) * args.m
    max_len = args.max_path_len or (args.n_hubs * args.m)
    max_len = max(2, min(max_len, cfg.block_size))
    max_total_len = args.verify_k * max_len + (args.verify_k - 1)
    if max_total_len > cfg.block_size:
        raise ValueError(
            f"block_size={cfg.block_size} too small for verify_k={args.verify_k} and max_path_len={max_len} "
            f"(need >= {max_total_len})."
        )

    allowed = _build_allowed_matrix(next_states, n_states, device)
    trap_x_m1, trap_x_m2 = _get_trap_candidates(args.n_hubs, args.m)
    shortest_steps_m1 = _compute_shortest_steps(next_states, n_states, trap_state=trap_x_m1)
    shortest_steps_m2 = _compute_shortest_steps(next_states, n_states, trap_state=trap_x_m2)
    # Row i of shortest_steps_by_trap must correspond to trap_candidates[i].
    shortest_steps_by_trap = torch.stack(
        [
            torch.tensor(shortest_steps_m1, device=device, dtype=torch.long),
            torch.tensor(shortest_steps_m2, device=device, dtype=torch.long),
        ],
        dim=0,
    )
    trap_candidates = torch.tensor([trap_x_m1, trap_x_m2], device=device, dtype=torch.long)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
        betas=(0.9, 0.95),
    )
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    os.makedirs(args.save_dir, exist_ok=True)
    run_tag = _build_run_tag(args, cfg)

    def _policy_log_probs(seq: torch.Tensor) -> torch.Tensor:
        # Teacher-forced log-probs of seq[:, 1:] under the same temperature / top-k transform
        # the sampler used. Then exp(new - old) compares like with like (it is 1 on-policy),
        # and the KL term matches how the reference log-probs were computed.
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(seq[:, :-1])
        logits = logits.float()
        if args.temperature != 1.0:
            logits = logits / args.temperature
        if args.top_k > 0:
            kth = torch.topk(logits, args.top_k, dim=-1).values[..., -1:]
            logits = torch.where(logits < kth, torch.full_like(logits, -float("inf")), logits)
        logp = F.log_softmax(logits, dim=-1)
        return torch.gather(logp, 2, seq[:, 1:].unsqueeze(-1)).squeeze(-1)

    def _checkpoint_payload(at_step: int, reward: float) -> dict:
        # Shared layout for best-so-far and periodic checkpoints.
        return {
            "model": model.state_dict(),
            "config": cfg.__dict__,
            "ds": {
                "n_hubs": args.n_hubs,
                "m": args.m,
                "retry_id": retry_id,
                "pad_id": pad_id,
                "vocab_size": vocab_size,
            },
            "vwgrpo": {
                "step": at_step,
                "mean_reward": reward,
                "max_len": max_len,
                "group_size": args.group_size,
                "verify_k": args.verify_k,
                "weight_mode": args.weight_mode,
                "temperature": args.temperature,
                "top_k": args.top_k,
                "kl_coef": args.kl_coef,
            },
        }

    best_reward = -float("inf")
    t0 = time.time()

    eval_start_count = n_states * args.eval_start_mult
    eval_log_path = args.eval_log or os.path.join(args.save_dir, f"eval_vwgrpo_{run_tag}.csv")
    eval_header = ["step", "mean_reward", "success_rate", "avg_trials"]
    eval_header.extend([f"total_attempt{i}" for i in range(1, args.verify_k + 1)])
    eval_header.extend([f"correct_attempt{i}" for i in range(1, args.verify_k + 1)])
    eval_header.extend([f"w{i + 1}" for i in range(args.verify_k)])
    eval_header.extend([f"w_raw{i + 1}" for i in range(args.verify_k)])
    eval_header.extend([f"rho{i + 1}" for i in range(args.verify_k)])
    eval_header.append("phase")
    eval_log_buffer: list[str] = []
    log_flush_every = 200
    last_log_flush = 0

    # Running sums for the per-attempt statistics printed every args.log_every steps.
    w_sum_accum = torch.zeros(args.verify_k, device=device, dtype=torch.float32)
    w_count_accum = torch.zeros(args.verify_k, device=device, dtype=torch.float32)
    wraw_sum_accum = torch.zeros(args.verify_k, device=device, dtype=torch.float32)
    wraw_count_accum = torch.zeros(args.verify_k, device=device, dtype=torch.float32)
    rho_sum_accum = torch.zeros(args.verify_k, device=device, dtype=torch.float32)
    n_sum_accum = torch.zeros(args.verify_k, device=device, dtype=torch.float32)
    rho_count_accum = torch.zeros(args.verify_k, device=device, dtype=torch.float32)

    # `with` guarantees the log file is closed on every exit path.
    with open(eval_log_path, "w", encoding="ascii") as eval_log_f:
        eval_log_f.write(",".join(eval_header) + "\n")
        eval_log_f.flush()
        try:
            for step in range(args.steps):
                starts = torch.randint(0, last_hub_base, (args.batch_size,), device=device)
                # Consecutive group_size rows share a start state and form one GRPO group.
                starts = starts.repeat_interleave(args.group_size)

                tokens, log_probs, action_mask, log_probs_ref, attempt_ids, attempt_rewards, trials_used, trap_states = (
                    _sample_verification_weighted_trajectories(
                        model,
                        ref_model,
                        starts,
                        max_len,
                        args.verify_k,
                        retry_id,
                        pad_id,
                        final_state,
                        n_states,
                        allowed,
                        trap_candidates,
                        shortest_steps_by_trap,
                        args.temperature,
                        args.top_k,
                        use_amp,
                    )
                )

                # Never train on the token that follows a retry separator.
                retry_prev = tokens[:, :-1] == retry_id
                action_mask = action_mask * (~retry_prev).float()
                attempt_ids = attempt_ids * (~retry_prev).long()
                log_probs = log_probs * action_mask
                if log_probs_ref is not None:
                    # torch.where, not multiplication, so a -inf/nan at a masked slot becomes 0.
                    log_probs_ref = torch.where(action_mask > 0, log_probs_ref, torch.zeros_like(log_probs_ref))

                log_probs_old = log_probs.detach()
                log_probs_new = _policy_log_probs(tokens)
                # With top_k > 0 a non-action target may fall outside the top-k set (-inf);
                # -inf * 0 would be nan, so mask with torch.where.
                log_probs_new = torch.where(action_mask > 0, log_probs_new, torch.zeros_like(log_probs_new))

                ratio = torch.ones_like(log_probs_new)
                ratio = torch.where(action_mask > 0, torch.exp(log_probs_new - log_probs_old), ratio)
                ratio_clipped = torch.clamp(ratio, 1.0 - args.clip_eps, 1.0 + args.clip_eps)

                attempt_order = torch.arange(args.verify_k, device=device).unsqueeze(0)
                attempt_exists = attempt_order < trials_used.unsqueeze(1)

                rewards = attempt_rewards.view(args.batch_size, args.group_size, args.verify_k)
                attempt_exists_group = attempt_exists.view(args.batch_size, args.group_size, args.verify_k)
                exists_f = attempt_exists_group.float()
                counts = exists_f.sum(dim=1, keepdim=True).clamp(min=1.0)
                mean = (rewards * exists_f).sum(dim=1, keepdim=True) / counts
                # TODO: optionally divide by (std + args.adv_eps) for standard GRPO normalisation;
                # the advantage is currently only mean-centred per group and attempt.
                adv = (rewards - mean) * exists_f
                adv = adv.view(-1, args.verify_k).detach()

                success_counts = rewards.sum(dim=1)
                attempt_counts = attempt_exists_group.sum(dim=1).to(rewards.dtype)
                rho = success_counts / attempt_counts.clamp(min=1.0)

                wraw_group = _compute_raw_attempt_weights(rewards, attempt_exists_group, args.weight_mode, args.weight_eps)
                wtilde_group = _compute_attempt_weights(rewards, attempt_exists_group, args.weight_mode, args.weight_eps)
                group_has_all = (attempt_counts > 0).all(dim=1)
                if group_has_all.any():
                    group_mask = group_has_all.unsqueeze(1).to(rewards.dtype)
                    wraw_sum_accum = wraw_sum_accum + (wraw_group * group_mask).sum(dim=0)
                    w_sum_accum = w_sum_accum + (wtilde_group * group_mask).sum(dim=0)
                    rho_sum_accum = rho_sum_accum + (rho * group_mask).sum(dim=0)
                    n_sum_accum = n_sum_accum + (attempt_counts * group_mask).sum(dim=0)
                    count_inc = group_has_all.sum().to(rewards.dtype)
                    w_count_accum = w_count_accum + count_inc
                    wraw_count_accum = wraw_count_accum + count_inc
                    rho_count_accum = rho_count_accum + count_inc
                wtilde = wtilde_group.unsqueeze(1).expand(-1, args.group_size, -1).reshape(-1, args.verify_k)

                attempt_idx = (attempt_ids - 1).clamp(min=0)
                token_mask = (attempt_ids > 0) & (action_mask > 0)
                token_mask_f = token_mask.float()

                adv_token = adv.gather(1, attempt_idx) * token_mask_f
                wtilde_token = wtilde.gather(1, attempt_idx) * token_mask_f

                surr1 = ratio * adv_token
                surr2 = ratio_clipped * adv_token
                obj = torch.minimum(surr1, surr2) * wtilde_token

                token_counts = action_mask.sum(dim=1).clamp(min=1.0)
                loss = -(obj * action_mask).sum(dim=1) / token_counts
                loss = loss.mean()
                if log_probs_ref is not None:
                    # NOTE(review): with top_k > 0 an action token outside the reference top-k
                    # still has log_probs_ref = -inf here, making this term infinite.
                    kl_token = (log_probs_new - log_probs_ref) * wtilde_token
                    kl = (kl_token * action_mask).sum(dim=1) / token_counts
                    kl_mean = kl.mean()
                    loss = loss + args.kl_coef * kl_mean

                optimizer.zero_grad(set_to_none=True)
                scaler.scale(loss).backward()
                if args.grad_clip > 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                scaler.step(optimizer)
                scaler.update()

                mean_reward = rewards.mean().item()
                success_rate = (attempt_rewards.sum(dim=1) > 0).float().mean().item()
                avg_trials = trials_used.float().mean().item()
                if step % args.log_every == 0:
                    dt = time.time() - t0
                    kl_val = 0.0
                    if log_probs_ref is not None:
                        kl_val = float(kl_mean.item())
                    w_raw_mean = wraw_sum_accum / wraw_count_accum.clamp(min=1.0)
                    w_mean = w_sum_accum / w_count_accum.clamp(min=1.0)
                    rho_mean = rho_sum_accum / rho_count_accum.clamp(min=1.0)
                    n_mean = n_sum_accum / rho_count_accum.clamp(min=1.0)
                    w_raw_vals = w_raw_mean.detach().cpu().tolist()
                    w_vals = w_mean.detach().cpu().tolist()
                    rho_vals = rho_mean.detach().cpu().tolist()
                    n_vals = n_mean.detach().cpu().tolist()
                    w_str = " ".join(f"w{i + 1}={w_vals[i]:.3f}" for i in range(args.verify_k))
                    rho_str = " ".join(f"rho{i + 1}={rho_vals[i]:.3f}" for i in range(args.verify_k))
                    n_str = " ".join(f"N{i + 1}={n_vals[i]:.2f}" for i in range(args.verify_k))
                    print(
                        f"step={step} loss={loss.item():.4f} reward={mean_reward:.3f} "
                        f"success={success_rate:.3f} trials={avg_trials:.2f} kl={kl_val:.4f} "
                        f"{w_str} {rho_str} {n_str} dt={dt:.2f}s"
                    )
                    total_attempts = [int((trials_used >= (i + 1)).sum().item()) for i in range(args.verify_k)]
                    correct_attempts = [int(attempt_rewards[:, i].sum().item()) for i in range(args.verify_k)]
                    total_csv = ",".join(str(v) for v in total_attempts)
                    correct_csv = ",".join(str(v) for v in correct_attempts)
                    w_csv = ",".join(f"{v:.6f}" for v in w_vals)
                    w_raw_csv = ",".join(f"{v:.6f}" for v in w_raw_vals)
                    rho_csv = ",".join(f"{v:.6f}" for v in rho_vals)
                    eval_log_buffer.append(
                        f"{step},{mean_reward:.6f},{success_rate:.6f},{avg_trials:.6f},"
                        f"{total_csv},{correct_csv},{w_csv},{w_raw_csv},{rho_csv},train\n"
                    )
                    t0 = time.time()
                    w_sum_accum.zero_()
                    w_count_accum.zero_()
                    wraw_sum_accum.zero_()
                    wraw_count_accum.zero_()
                    rho_sum_accum.zero_()
                    n_sum_accum.zero_()
                    rho_count_accum.zero_()

                if args.eval_every > 0 and step % args.eval_every == 0:
                    (
                        eval_reward,
                        eval_success,
                        eval_trials,
                        eval_totals,
                        eval_corrects,
                        eval_w_raw,
                        eval_w,
                        eval_rho,
                    ) = _evaluate_vwgrpo(
                        model,
                        eval_start_count,
                        max_len,
                        args.verify_k,
                        args.weight_mode,
                        args.weight_eps,
                        retry_id,
                        pad_id,
                        final_state,
                        n_states,
                        last_hub_base,
                        allowed,
                        trap_candidates,
                        shortest_steps_by_trap,
                        args.temperature,
                        args.top_k,
                        use_amp,
                        device,
                    )
                    w_raw_vals = eval_w_raw.detach().cpu().tolist()
                    w_vals = eval_w.detach().cpu().tolist()
                    rho_vals = eval_rho.detach().cpu().tolist()
                    totals_csv = ",".join(str(v) for v in eval_totals)
                    corrects_csv = ",".join(str(v) for v in eval_corrects)
                    w_str = ",".join(f"{v:.6f}" for v in w_vals)
                    w_raw_str = ",".join(f"{v:.6f}" for v in w_raw_vals)
                    rho_str = ",".join(f"{v:.6f}" for v in rho_vals)
                    eval_log_buffer.append(
                        f"{step},{eval_reward:.6f},{eval_success:.6f},{eval_trials:.6f},"
                        f"{totals_csv},{corrects_csv},{w_str},{w_raw_str},{rho_str},eval\n"
                    )
                    print(
                        f"eval step={step} reward={eval_reward:.3f} "
                        f"success={eval_success:.3f} trials={eval_trials:.2f}"
                    )

                if eval_log_buffer and (step - last_log_flush) >= log_flush_every:
                    eval_log_f.write("".join(eval_log_buffer))
                    eval_log_f.flush()
                    eval_log_buffer.clear()
                    last_log_flush = step

                if args.save_best and mean_reward > best_reward:
                    best_reward = mean_reward
                    best_path = os.path.join(args.save_dir, f"best_vwgrpo_{run_tag}.pt")
                    torch.save(_checkpoint_payload(step, best_reward), best_path)
                    print(f"saved best: {best_path} (reward={best_reward:.3f})")

                if args.save_every and step > 0 and (step % args.save_every == 0):
                    ckpt_path = os.path.join(args.save_dir, f"ckpt_vwgrpo_{run_tag}_step{step}.pt")
                    torch.save(_checkpoint_payload(step, mean_reward), ckpt_path)
                    print(f"saved: {ckpt_path}")
        finally:
            # Persist any buffered CSV rows even if training is interrupted.
            if eval_log_buffer:
                eval_log_f.write("".join(eval_log_buffer))
                eval_log_f.flush()


def main() -> None:
    """Parse command-line options, validate them, and run ``train``.

    Raises:
        ValueError: for option values that would make training fail later. These are
            ``verify_k``, ``eval_start_mult``, ``batch_size``, ``group_size`` or ``log_every``
            below 1, and a non-positive ``temperature``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--init_ckpt", type=str, required=True)
    p.add_argument("--save_dir", type=str, default="checkpoints")
    p.add_argument("--save_every", type=int, default=0)
    p.add_argument("--save_best", action="store_true")

    p.add_argument("--n_hubs", type=int, default=0)
    p.add_argument("--m", type=int, default=0)
    p.add_argument("--max_path_len", type=int, default=0)
    p.add_argument("--verify_k", type=int, default=1)

    p.add_argument("--weight_mode", type=str, default="equal", choices=["equal", "none", "optimal"])
    p.add_argument("--weight_eps", type=float, default=1e-8)
    p.add_argument("--eval_log", type=str, default="")
    p.add_argument("--eval_every", type=int, default=40)
    p.add_argument("--eval_start_mult", type=int, default=5)

    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--group_size", type=int, default=4)
    p.add_argument("--steps", type=int, default=1000)
    p.add_argument("--lr", type=float, default=1e-5)
    p.add_argument("--weight_decay", type=float, default=0.0)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--kl_coef", type=float, default=0.02)
    p.add_argument("--clip_eps", type=float, default=0.2)
    p.add_argument("--adv_eps", type=float, default=1e-5)

    p.add_argument("--temperature", type=float, default=1.0)
    p.add_argument("--top_k", type=int, default=0)

    p.add_argument("--log_every", type=int, default=20)
    p.add_argument("--amp", action="store_true")
    p.add_argument("--cpu", action="store_true")
    p.add_argument("--seed", type=int, default=1337)

    args = p.parse_args()
    if args.verify_k < 1:
        raise ValueError("verify_k must be >= 1.")
    if args.eval_start_mult < 1:
        raise ValueError("eval_start_mult must be >= 1.")
    # Rollouts are reshaped to (batch_size, group_size, verify_k); both must be positive.
    if args.batch_size < 1 or args.group_size < 1:
        raise ValueError("batch_size and group_size must be >= 1.")
    # train() evaluates `step % args.log_every`; 0 would raise ZeroDivisionError mid-run.
    if args.log_every < 1:
        raise ValueError("log_every must be >= 1.")
    # Logits are divided by the temperature before sampling.
    if args.temperature <= 0:
        raise ValueError("temperature must be > 0.")
    train(args)


if __name__ == "__main__":
    # Script entry point: parse CLI options and run training.
    main()
