# verification_reinforce_grpo.py
from __future__ import annotations
import os
import time
import argparse

import torch
import torch.nn.functional as F

from markov_chain import build_markov_transitions
from model import GPT, GPTConfig


def _build_allowed_matrix(next_states, n_states: int, device: torch.device) -> torch.Tensor:
    """Return the boolean adjacency matrix of permitted transitions.

    Entry ``[s, t]`` is True exactly when ``t`` is listed in ``next_states[s]``.

    Args:
        next_states: Iterable whose ``s``-th element holds the successor
            indices of state ``s`` (anything ``torch.as_tensor`` accepts).
        n_states: Side length of the square matrix.
        device: Device on which the matrix is allocated.

    Returns:
        A ``(n_states, n_states)`` bool tensor living on ``device``.
    """
    matrix = torch.zeros((n_states, n_states), dtype=torch.bool, device=device)
    for source, successors in enumerate(next_states):
        targets = torch.as_tensor(successors, device=device, dtype=torch.long)
        # Mark every listed successor of `source` in one indexed write.
        matrix[source, targets] = True
    return matrix


def _build_run_tag(args: argparse.Namespace, cfg: GPTConfig) -> str:
    """Build the identifier embedded in checkpoint and eval-log file names.

    The tag concatenates the task settings read from ``args`` (``n_hubs``,
    ``m``, ``verify_k``) with the model dimensions read from ``cfg``
    (``n_embd``, ``n_layer``, ``n_head``, ``block_size``), joined by
    underscores.
    """
    labelled_parts = [
        f"h{args.n_hubs}",
        f"m{args.m}",
        f"k{args.verify_k}",
        f"emb{cfg.n_embd}",
        f"l{cfg.n_layer}",
        f"head{cfg.n_head}",
        f"bs{cfg.block_size}",
    ]
    return "_".join(labelled_parts)


def _compute_shortest_steps(next_states, n_states: int) -> list[int]:
    """Return, for every state, the fewest transitions needed to reach the final state.

    Only non-decreasing transitions (``t >= s``) are considered, and the final
    state is ``n_states - 1``. Distances are found with a breadth-first search
    over the reversed graph, starting from the final state.

    Args:
        next_states: Iterable whose ``s``-th element holds the successor
            indices of state ``s``. Elements may be objects exposing
            ``tolist()`` (tensors, arrays) or plain Python iterables of ints.
        n_states: Number of states in the graph.

    Returns:
        A list of length ``n_states``; entry ``s`` is the step count from
        ``s`` to the final state, or ``-1`` when the final state cannot be
        reached from ``s``. An empty list is returned when ``n_states <= 0``.
    """
    from collections import deque

    if n_states <= 0:
        # No final state exists, so there is nothing to measure.
        return []

    # predecessors[t] lists every state s that has a usable edge s -> t.
    predecessors: list[list[int]] = [[] for _ in range(n_states)]
    for s, ns in enumerate(next_states):
        # Accept tensors/arrays as well as plain sequences, mirroring what
        # _build_allowed_matrix tolerates through torch.as_tensor.
        targets = ns.tolist() if hasattr(ns, "tolist") else list(ns)
        for t in targets:
            t = int(t)
            if t >= s:
                predecessors[t].append(s)

    dist = [-1] * n_states
    final_state = n_states - 1
    dist[final_state] = 0

    queue = deque([final_state])
    while queue:
        cur = queue.popleft()
        for prev in predecessors[cur]:
            if dist[prev] == -1:
                dist[prev] = dist[cur] + 1
                queue.append(prev)
    return dist


def _compute_trial_success(
    tokens: torch.Tensor,
    n_states: int,
    final_state: int,
    retry_id: int,
    pad_id: int,
    allowed: torch.Tensor,
    shortest_steps: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Score each row of ``tokens`` as a single trial.

    A trial runs from column 0 up to and including its first terminal token
    (``final_state``, ``retry_id``, ``pad_id`` or any id ``>= n_states``); a
    row without a terminal token is treated as spanning the whole width.

    A trial succeeds when all of the following hold:
      * its terminal token is ``final_state``;
      * no state-to-state transition inside the trial decreases;
      * every state-to-state transition inside the trial is set in ``allowed``;
      * the number of transitions equals ``shortest_steps`` of the (clamped)
        first token, and that entry is non-negative.

    Returns:
        ``(success, lengths)``: a bool tensor and an integer tensor, one entry
        per row, where ``lengths`` counts tokens up to and including the
        terminal one.
    """
    n_rows, seq_len = tokens.shape
    dev = tokens.device
    last_state = n_states - 1

    # Locate the first terminal token of every row.
    stops = (
        (tokens >= n_states)
        | (tokens == pad_id)
        | (tokens == retry_id)
        | (tokens == final_state)
    )
    first_stop = stops.float().argmax(dim=1)
    fallback = torch.full_like(first_stop, seq_len - 1)
    first_stop = torch.where(stops.any(dim=1), first_stop, fallback)
    lengths = first_stop + 1

    row_ids = torch.arange(n_rows, device=dev)
    reached_goal = tokens[row_ids, first_stop] == final_state

    # Consecutive token pairs; only pairs inside the trial and made of two
    # real states take part in the monotonicity / adjacency checks.
    src = tokens[:, :-1]
    dst = tokens[:, 1:]
    pair_pos = torch.arange(seq_len - 1, device=dev).unsqueeze(0)
    inside_trial = pair_pos < (lengths - 1).unsqueeze(1)
    checked = (src < n_states) & (dst < n_states) & inside_trial

    went_backwards = ((dst < src) & checked).any(dim=1)

    # Clamp before the lookup so non-state ids cannot index out of range;
    # such pairs are excluded by `checked` anyway.
    edge_ok = allowed[src.clamp(0, last_state), dst.clamp(0, last_state)]
    broke_graph = (checked & ~edge_ok).any(dim=1)

    origin = tokens[:, 0].clamp(0, last_state)
    optimal = shortest_steps[origin]
    took_optimal = (optimal >= 0) & ((lengths - 1) == optimal)

    success = reached_goal & ~went_backwards & ~broke_graph & took_optimal
    return success, lengths


def _compute_log_probs(model: GPT, tokens: torch.Tensor, amp: bool) -> torch.Tensor:
    """Return the log-probability the model assigns to each observed next token.

    The model is run on ``tokens[:, :-1]`` (under CUDA autocast when ``amp``
    is true); its output is cast to float32, normalised with ``log_softmax``
    over the last dimension, and the entry for ``tokens[:, 1:]`` is selected
    at every position.

    Returns:
        A tensor with one column fewer than ``tokens``; element ``[b, i]`` is
        the log-probability of ``tokens[b, i + 1]`` at position ``i``.
    """
    inputs, targets = tokens[:, :-1], tokens[:, 1:]
    with torch.cuda.amp.autocast(enabled=amp):
        raw_logits = model(inputs)
    token_logp = F.log_softmax(raw_logits.float(), dim=-1)
    picked = torch.gather(token_logp, 2, targets.unsqueeze(-1))
    return picked.squeeze(-1)


def _collect_example_trajectories(
    tokens: torch.Tensor,
    pad_id: int,
    retry_id: int,
    max_show: int = 5,
) -> list[str]:
    """Render up to ``max_show`` distinct rows of ``tokens`` as readable strings.

    Trailing ``pad_id`` entries are stripped from each row (the first token is
    always kept), duplicate rows are skipped, and the remaining ids are joined
    with spaces. ``retry_id`` is shown as ``|R|`` and any surviving ``pad_id``
    as ``PAD``.

    Args:
        tokens: 2-D integer tensor, one trajectory per row.
        pad_id: Id of the padding token.
        retry_id: Id of the retry separator token.
        max_show: Maximum number of examples to return; values ``<= 0`` yield
            an empty list.

    Returns:
        The formatted trajectories in first-seen order.
    """
    if max_show <= 0:
        # The limit is otherwise only checked after appending, which would
        # return one example even though none were requested.
        return []

    seen: set[tuple[int, ...]] = set()
    examples: list[str] = []
    for row in tokens.detach().cpu():
        ids = row.tolist()
        end = len(ids)
        # Drop trailing padding but never the first token.
        while end > 1 and ids[end - 1] == pad_id:
            end -= 1
        trimmed = ids[:end]

        key = tuple(trimmed)
        if key in seen:
            continue
        seen.add(key)

        parts: list[str] = []
        for tok in trimmed:
            if tok == retry_id:
                parts.append("|R|")
            elif tok == pad_id:
                parts.append("PAD")
            else:
                parts.append(str(tok))
        examples.append(" ".join(parts))
        if len(examples) >= max_show:
            break
    return examples


def _filter_logits(logits: torch.Tensor, temperature: float, top_k: int) -> torch.Tensor:
    """Apply temperature scaling and top-k truncation to ``(batch, vocab)`` logits."""
    if temperature != 1.0:
        logits = logits / temperature
    if top_k > 0:
        # torch.topk raises when k exceeds the dimension size, so cap it at
        # the vocabulary size (an oversized top_k then keeps every token).
        k = min(top_k, logits.size(-1))
        v, _ = torch.topk(logits, k)
        kth = v[:, -1].unsqueeze(-1)
        logits = torch.where(logits < kth, torch.full_like(logits, -float("inf")), logits)
    return logits


def _stack_step_values(
    step_values: list[torch.Tensor],
    total: int,
    width: int,
    device: torch.device,
) -> torch.Tensor:
    """Stack per-step ``(total,)`` tensors into ``(total, width)``, zero-padding on the right."""
    if step_values:
        stacked = torch.stack(step_values, dim=1)
    else:
        stacked = torch.zeros((total, 0), device=device, dtype=torch.float32)
    missing = width - stacked.size(1)
    if missing > 0:
        pad = torch.zeros((total, missing), device=device, dtype=stacked.dtype)
        stacked = torch.cat([stacked, pad], dim=1)
    return stacked


def _sample_verification_trajectories(
    model: GPT,
    ref_model: GPT | None,
    starts: torch.Tensor,
    max_len: int,
    verify_k: int,
    retry_id: int,
    pad_id: int,
    final_state: int,
    n_states: int,
    allowed: torch.Tensor,
    shortest_steps: torch.Tensor,
    temperature: float,
    top_k: int,
    amp: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor]:
    """Roll out up to ``verify_k`` trials per start state, separated by retry tokens.

    Every row begins with its start state. The model then samples tokens until
    the current trial ends: it emits ``final_state``, ``retry_id``, ``pad_id``
    or any id ``>= n_states``, or the trial reaches ``max_len`` tokens. The
    finished trial is judged by ``_compute_trial_success``. A successful row
    stops; a failed row that still has trials left gets a ``retry_id`` written
    into the sequence (unless the model produced it itself), followed by the
    start state again, and sampling resumes. Injected retry/start tokens are
    not model actions and carry no log-probability.

    Args:
        model: Policy; called as ``model(x)`` and expected to return logits
            indexable as ``[:, -1, :]``.
        ref_model: Optional reference policy evaluated (without gradients) on
            the same prefixes and sampled tokens.
        starts: 1-D tensor of start states, one per rollout.
        max_len: Maximum number of tokens in one trial, start state included.
        verify_k: Maximum number of trials per rollout.
        retry_id: Token id separating trials.
        pad_id: Token id used to fill unused positions.
        final_state: Goal state id.
        n_states: Number of valid state ids.
        allowed: Boolean transition matrix passed to ``_compute_trial_success``.
        shortest_steps: Per-state shortest step counts passed to
            ``_compute_trial_success``.
        temperature: Logits are divided by this value when it differs from 1.
        top_k: When positive, only the ``top_k`` largest logits stay sampleable
            (capped at the vocabulary size).
        amp: Whether to run forward passes under CUDA autocast.

    Returns:
        ``(tokens, log_probs, action_mask, log_probs_ref, success, trials_used)``
        where, with ``L = verify_k * max_len + verify_k - 1``:
          * ``tokens`` is ``(total, L)``, padded with ``pad_id``;
          * ``log_probs`` and ``action_mask`` are ``(total, L - 1)``; column
            ``t - 1`` describes the token at position ``t``, and the mask is
            1.0 only where the model sampled that token;
          * ``log_probs_ref`` has the same shape, or is ``None`` when
            ``ref_model`` is ``None``;
          * ``success`` is a bool tensor, True if any trial succeeded;
          * ``trials_used`` is the number of trials started per rollout.
        ``log_probs`` is attached to the autograd graph unless the caller
        disables gradient tracking.
    """
    device = starts.device
    total = starts.size(0)
    max_total_len = verify_k * max_len + (verify_k - 1)

    tokens = torch.full((total, max_total_len), pad_id, device=device, dtype=torch.long)
    tokens[:, 0] = starts

    # Scratch copy of the trial currently in progress for every row.
    trial_tokens = torch.full((total, max_len), pad_id, device=device, dtype=torch.long)
    trial_tokens[:, 0] = starts
    trial_lengths = torch.ones(total, device=device, dtype=torch.long)

    active = torch.ones(total, device=device, dtype=torch.bool)
    success = torch.zeros(total, device=device, dtype=torch.bool)
    trial_idx = torch.zeros(total, device=device, dtype=torch.long)
    # Pending injections: a retry token first, then the start state.
    need_retry = torch.zeros(total, device=device, dtype=torch.bool)
    need_start = torch.zeros(total, device=device, dtype=torch.bool)

    logp_steps: list[torch.Tensor] = []
    action_steps: list[torch.Tensor] = []
    logp_ref_steps: list[torch.Tensor] | None = [] if ref_model is not None else None

    for t in range(1, max_total_len):
        step_logp = torch.zeros(total, device=device, dtype=torch.float32)
        step_mask = torch.zeros(total, device=device, dtype=torch.float32)
        step_logp_ref = (
            torch.zeros(total, device=device, dtype=torch.float32) if ref_model is not None else None
        )

        if active.any():
            # Compute before mutating need_* so injected retry/start aren't overwritten this step.
            model_mask = active & ~(need_retry | need_start)
            retry_mask = active & need_retry
            if retry_mask.any():
                tokens[retry_mask, t] = retry_id
                need_retry[retry_mask] = False
                need_start[retry_mask] = True

            # Rows that received a retry token this step get their start
            # state on the next step, hence the ~retry_mask.
            start_mask = active & need_start & ~retry_mask
            if start_mask.any():
                tokens[start_mask, t] = starts[start_mask]
                need_start[start_mask] = False
                trial_tokens[start_mask] = pad_id
                trial_tokens[start_mask, 0] = starts[start_mask]
                trial_lengths[start_mask] = 1

            if model_mask.any():
                # The forward pass covers every row; only rows in model_mask
                # keep the sampled token. The clone detaches the prefix from
                # later in-place writes to `tokens`.
                x = tokens[:, :t].clone()
                with torch.cuda.amp.autocast(enabled=amp):
                    logits = model(x)[:, -1, :]
                logits = _filter_logits(logits.float(), temperature, top_k)

                logp = F.log_softmax(logits, dim=-1)
                probs = torch.exp(logp)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(1)

                tokens[model_mask, t] = next_token[model_mask]
                gathered = torch.gather(logp, 1, next_token.unsqueeze(1)).squeeze(1)
                step_logp[model_mask] = gathered[model_mask]
                step_mask[model_mask] = 1.0

                if ref_model is not None:
                    with torch.no_grad():
                        with torch.cuda.amp.autocast(enabled=amp):
                            ref_logits = ref_model(x)[:, -1, :]
                        ref_logits = _filter_logits(ref_logits.float(), temperature, top_k)
                        ref_logp = F.log_softmax(ref_logits, dim=-1)
                        ref_gathered = torch.gather(ref_logp, 1, next_token.unsqueeze(1)).squeeze(1)
                        step_logp_ref[model_mask] = ref_gathered[model_mask]

                # Append the sampled token to each sampling row's current trial.
                idx = torch.nonzero(model_mask, as_tuple=False).squeeze(1)
                pos = trial_lengths[idx]
                trial_tokens[idx, pos] = next_token[idx]
                trial_lengths[idx] = trial_lengths[idx] + 1

                is_final = next_token == final_state
                is_retry = next_token == retry_id
                is_pad = next_token == pad_id
                is_invalid = next_token >= n_states
                hit_len = trial_lengths >= max_len
                end_mask = model_mask & (is_final | is_retry | is_pad | is_invalid | hit_len)

                if end_mask.any():
                    ended_idx = torch.nonzero(end_mask, as_tuple=False).squeeze(1)
                    with torch.no_grad():
                        trial_success, _ = _compute_trial_success(
                            trial_tokens[ended_idx],
                            n_states,
                            final_state,
                            retry_id,
                            pad_id,
                            allowed,
                            shortest_steps,
                        )
                    success[ended_idx] |= trial_success

                    done_success = ended_idx[trial_success]
                    if done_success.numel() > 0:
                        active[done_success] = False
                        need_retry[done_success] = False
                        need_start[done_success] = False

                    fail_mask = ~trial_success
                    if fail_mask.any():
                        fail_idx = ended_idx[fail_mask]
                        last_trial = trial_idx[fail_idx] >= (verify_k - 1)
                        if last_trial.any():
                            # Out of trials: the rollout ends as a failure.
                            done_fail = fail_idx[last_trial]
                            active[done_fail] = False
                            need_retry[done_fail] = False
                            need_start[done_fail] = False
                        cont_mask = ~last_trial
                        if cont_mask.any():
                            cont_idx = fail_idx[cont_mask]
                            trial_idx[cont_idx] = trial_idx[cont_idx] + 1
                            # If the model itself emitted the retry token, go
                            # straight to re-injecting the start state;
                            # otherwise a retry token is injected first.
                            ended_by_retry = is_retry[cont_idx]
                            if ended_by_retry.any():
                                need_start[cont_idx[ended_by_retry]] = True
                            if (~ended_by_retry).any():
                                need_retry[cont_idx[~ended_by_retry]] = True

        logp_steps.append(step_logp)
        action_steps.append(step_mask)
        if logp_ref_steps is not None:
            logp_ref_steps.append(step_logp_ref)

        if not active.any():
            break

    # The loop may stop early; pad so the outputs always line up with
    # tokens[:, 1:].
    width = max_total_len - 1
    log_probs = _stack_step_values(logp_steps, total, width, device)
    action_mask = _stack_step_values(action_steps, total, width, device)

    log_probs_ref = None
    if logp_ref_steps is not None:
        log_probs_ref = _stack_step_values(logp_ref_steps, total, width, device)

    trials_used = trial_idx + 1
    return tokens, log_probs, action_mask, log_probs_ref, success, trials_used


def _evaluate_vrgrpo(
    model: GPT,
    eval_start_count: int,
    max_len: int,
    verify_k: int,
    retry_id: int,
    pad_id: int,
    final_state: int,
    n_states: int,
    allowed: torch.Tensor,
    shortest_steps: torch.Tensor,
    temperature: float,
    top_k: int,
    amp: bool,
    device: torch.device,
) -> tuple[float, float, float, list[int], list[int]]:
    """Sample rollouts from uniformly random start states and summarise them.

    The model is switched to eval mode and sampled without gradients (and
    without a reference model). Afterwards train mode is re-enabled if it was
    active on entry, including when sampling raises.

    Returns:
        ``(mean_reward, success_rate, avg_trials, total_attempts,
        correct_attempts)``. The reward is the success indicator, so
        ``mean_reward`` and ``success_rate`` are the same number.
        ``total_attempts[i]`` counts rollouts that used at least ``i + 1``
        trials and ``correct_attempts[i]`` counts successful rollouts that
        used exactly ``i + 1`` trials, for ``i`` in ``range(verify_k)``.
    """
    was_training = getattr(model, "training", True)
    model.eval()
    try:
        with torch.no_grad():
            starts = torch.randint(0, n_states, (eval_start_count,), device=device)
            _, _, _, _, success, trials_used = _sample_verification_trajectories(
                model,
                None,
                starts,
                max_len,
                verify_k,
                retry_id,
                pad_id,
                final_state,
                n_states,
                allowed,
                shortest_steps,
                temperature,
                top_k,
                amp,
            )
    finally:
        # Never leave the caller's model stuck in eval mode, and do not force
        # train mode on a model that was being evaluated already.
        if was_training:
            model.train()

    success_rate = success.float().mean().item()
    mean_reward = success_rate
    avg_trials = trials_used.float().mean().item()
    total_attempts = [
        int((trials_used >= attempt).sum().item()) for attempt in range(1, verify_k + 1)
    ]
    correct_attempts = [
        int((success & (trials_used == attempt)).sum().item()) for attempt in range(1, verify_k + 1)
    ]
    return mean_reward, success_rate, avg_trials, total_attempts, correct_attempts


def _make_vrgrpo_checkpoint(
    model: GPT,
    cfg: GPTConfig,
    args: argparse.Namespace,
    retry_id: int,
    pad_id: int,
    vocab_size: int,
    max_len: int,
    step: int,
    mean_reward: float,
) -> dict:
    """Assemble the dictionary passed to ``torch.save`` for a training snapshot."""
    return {
        "model": model.state_dict(),
        "config": cfg.__dict__,
        "ds": {
            "n_hubs": args.n_hubs,
            "m": args.m,
            "retry_id": retry_id,
            "pad_id": pad_id,
            "vocab_size": vocab_size,
        },
        "vrgrpo": {
            "step": step,
            "mean_reward": mean_reward,
            "max_len": max_len,
            "group_size": args.group_size,
            "verify_k": args.verify_k,
            "temperature": args.temperature,
            "top_k": args.top_k,
            "kl_coef": args.kl_coef,
        },
    }


def train(args: argparse.Namespace) -> None:
    """Fine-tune a checkpointed policy with a clipped, group-normalised objective.

    The function loads ``args.init_ckpt``, rebuilds the transition structure
    with ``build_markov_transitions(args.n_hubs, args.m)`` and then, for
    ``args.steps`` iterations:

      1. draws ``args.batch_size`` start states, each repeated
         ``args.group_size`` times, and samples retry-augmented rollouts;
      2. uses rollout success (0/1) as the reward and normalises it inside
         each group to obtain advantages;
      3. optimises a clipped ratio objective over the model-sampled tokens,
         plus ``args.kl_coef`` times a log-ratio penalty against a frozen copy
         of the initial weights when ``args.kl_coef > 0``.

    Side effects: prints progress, writes an evaluation CSV (``args.eval_log``
    or a default path inside ``args.save_dir``) and saves checkpoints
    according to ``args.save_best`` / ``args.save_every``. ``args.n_hubs`` and
    ``args.m`` are overwritten with the checkpoint values when left at 0.

    Raises:
        ValueError: If ``n_hubs``/``m`` are unknown, the checkpoint vocabulary
            size does not match, or the model's ``block_size`` cannot hold
            ``verify_k`` trials of ``max_len`` tokens plus separators.
    """
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    if device.type == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    # Autocast / loss scaling are only used on CUDA.
    use_amp = bool(args.amp) and device.type == "cuda"

    ckpt = torch.load(args.init_ckpt, map_location="cpu")
    cfg = GPTConfig(**ckpt["config"])
    model = GPT(cfg)
    model.load_state_dict(ckpt["model"])
    model.to(device)
    model.train()

    # Frozen copy of the initial policy, needed only for the KL-style penalty.
    ref_model = None
    if args.kl_coef > 0:
        ref_model = GPT(cfg)
        ref_model.load_state_dict(ckpt["model"])
        ref_model.to(device)
        ref_model.eval()
        for p in ref_model.parameters():
            p.requires_grad_(False)

    ds = ckpt.get("ds", {})
    args.n_hubs = args.n_hubs or ds.get("n_hubs", args.n_hubs)
    args.m = args.m or ds.get("m", args.m)
    if not args.n_hubs or not args.m:
        raise ValueError("n_hubs and m must be set either in checkpoint or via args.")

    n_states, next_states, _ = build_markov_transitions(args.n_hubs, args.m)
    retry_id = int(ds.get("retry_id", n_states))
    pad_id = int(ds.get("pad_id", n_states + 1))
    vocab_size = int(ds.get("vocab_size", n_states + 2))
    if cfg.vocab_size != vocab_size:
        raise ValueError(f"checkpoint vocab_size={cfg.vocab_size} != expected {vocab_size}")

    final_state = n_states - 1
    max_len = args.max_path_len or (args.n_hubs * args.m)
    max_len = max(2, min(max_len, cfg.block_size))
    # verify_k trials of max_len tokens with one retry token between trials.
    max_total_len = args.verify_k * max_len + (args.verify_k - 1)
    if max_total_len > cfg.block_size:
        raise ValueError(
            f"block_size={cfg.block_size} too small for verify_k={args.verify_k} and max_path_len={max_len} "
            f"(need >= {max_total_len})."
        )

    allowed = _build_allowed_matrix(next_states, n_states, device)
    shortest_steps = _compute_shortest_steps(next_states, n_states)
    shortest_steps_t = torch.tensor(shortest_steps, device=device, dtype=torch.long)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
        betas=(0.9, 0.95),
    )
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    os.makedirs(args.save_dir, exist_ok=True)
    run_tag = _build_run_tag(args, cfg)

    best_reward = -float("inf")
    t0 = time.time()

    eval_start_count = n_states * args.eval_start_mult
    eval_log_path = args.eval_log or os.path.join(args.save_dir, f"eval_vrgrpo_{run_tag}.csv")
    attempt_headers = [f"total_attempt{i}" for i in range(1, args.verify_k + 1)]
    correct_headers = [f"correct_attempt{i}" for i in range(1, args.verify_k + 1)]

    # The context manager also covers the header write, so the handle is
    # closed on every exit path.
    with open(eval_log_path, "w", encoding="ascii") as eval_log_f:
        eval_log_f.write(
            "step,mean_reward,success_rate,avg_trials,"
            + ",".join(attempt_headers + correct_headers)
            + "\n"
        )
        eval_log_f.flush()

        for step in range(args.steps):
            starts = torch.randint(0, n_states, (args.batch_size,), device=device)
            starts = starts.repeat_interleave(args.group_size)

            # Everything taken from the rollout is used detached (the policy
            # is re-scored below), so skip building an autograd graph for
            # every decoding step.
            with torch.no_grad():
                tokens, log_probs, action_mask, log_probs_ref, success, trials_used = (
                    _sample_verification_trajectories(
                        model,
                        ref_model,
                        starts,
                        max_len,
                        args.verify_k,
                        retry_id,
                        pad_id,
                        final_state,
                        n_states,
                        allowed,
                        shortest_steps_t,
                        args.temperature,
                        args.top_k,
                        use_amp,
                    )
                )

            # Group-relative advantages: normalise rewards within each group
            # of rollouts that share a start state.
            rewards = success.float()
            rewards = rewards.view(args.batch_size, args.group_size)
            mean = rewards.mean(dim=1, keepdim=True)
            if args.group_size > 1:
                std = rewards.std(dim=1, keepdim=True)
            else:
                # The sample std of a single value is NaN; a lone rollout has
                # no spread and therefore a zero advantage.
                std = torch.zeros_like(mean)
            adv = (rewards - mean) / (std + args.adv_eps)
            adv = adv.view(-1).detach()

            log_probs_old = (log_probs * action_mask).detach()
            # NOTE(review): the rollout log-probs (and the reference ones)
            # include temperature / top-k filtering while this re-scoring does
            # not, so the ratio and KL term only compare like with like for
            # temperature == 1.0 and top_k == 0 - confirm intended behaviour.
            log_probs_new = _compute_log_probs(model, tokens, use_amp)
            log_probs_new = log_probs_new * action_mask

            ratio = torch.ones_like(log_probs_new)
            ratio = torch.where(action_mask > 0, torch.exp(log_probs_new - log_probs_old), ratio)
            ratio_clipped = torch.clamp(ratio, 1.0 - args.clip_eps, 1.0 + args.clip_eps)

            adv_t = adv.view(-1, 1)
            surr1 = ratio * adv_t
            surr2 = ratio_clipped * adv_t
            obj = torch.minimum(surr1, surr2)

            # Average over model-sampled tokens per rollout, then over rollouts.
            token_counts = action_mask.sum(dim=1).clamp(min=1.0)
            loss = -(obj * action_mask).sum(dim=1) / token_counts
            loss = loss.mean()

            kl_mean = None
            if log_probs_ref is not None:
                log_probs_ref = log_probs_ref * action_mask
                kl = (log_probs_new - log_probs_ref) * action_mask
                kl = kl.sum(dim=1) / token_counts
                kl_mean = kl.mean()
                loss = loss + args.kl_coef * kl_mean

            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            if args.grad_clip > 0:
                # Gradients must be unscaled before their norm is clipped.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()

            # log_every <= 0 disables logging instead of dividing by zero.
            if args.log_every > 0 and step % args.log_every == 0:
                dt = time.time() - t0
                mean_reward = rewards.mean().item()
                success_rate = success.float().mean().item()
                avg_trials = trials_used.float().mean().item()
                kl_val = 0.0
                if kl_mean is not None:
                    kl_val = float(kl_mean.item())
                print(
                    f"step={step} loss={loss.item():.4f} reward={mean_reward:.3f} "
                    f"success={success_rate:.3f} trials={avg_trials:.2f} kl={kl_val:.4f} dt={dt:.2f}s"
                )
                t0 = time.time()

            if args.eval_every > 0 and step % args.eval_every == 0:
                (
                    eval_reward,
                    eval_success,
                    eval_trials,
                    eval_totals,
                    eval_corrects,
                ) = _evaluate_vrgrpo(
                    model,
                    eval_start_count,
                    max_len,
                    args.verify_k,
                    retry_id,
                    pad_id,
                    final_state,
                    n_states,
                    allowed,
                    shortest_steps_t,
                    args.temperature,
                    args.top_k,
                    use_amp,
                    device,
                )
                eval_values = [
                    str(step),
                    f"{eval_reward:.6f}",
                    f"{eval_success:.6f}",
                    f"{eval_trials:.6f}",
                ] + [str(v) for v in (eval_totals + eval_corrects)]
                eval_log_f.write(",".join(eval_values) + "\n")
                eval_log_f.flush()
                print(
                    f"eval step={step} reward={eval_reward:.3f} "
                    f"success={eval_success:.3f} trials={eval_trials:.2f}"
                )

            if args.save_best:
                # "Best" is judged on the training batch reward of this step.
                mean_reward = rewards.mean().item()
                if mean_reward > best_reward:
                    best_reward = mean_reward
                    best_path = os.path.join(args.save_dir, f"best_vrgrpo_{run_tag}.pt")
                    torch.save(
                        _make_vrgrpo_checkpoint(
                            model,
                            cfg,
                            args,
                            retry_id,
                            pad_id,
                            vocab_size,
                            max_len,
                            step,
                            best_reward,
                        ),
                        best_path,
                    )
                    print(f"saved best: {best_path} (reward={best_reward:.3f})")

            if args.save_every and step > 0 and (step % args.save_every == 0):
                ckpt_path = os.path.join(args.save_dir, f"ckpt_vrgrpo_{run_tag}_step{step}.pt")
                torch.save(
                    _make_vrgrpo_checkpoint(
                        model,
                        cfg,
                        args,
                        retry_id,
                        pad_id,
                        vocab_size,
                        max_len,
                        step,
                        rewards.mean().item(),
                    ),
                    ckpt_path,
                )
                print(f"saved: {ckpt_path}")


def main() -> None:
    """Parse the command line, validate it and launch :func:`train`.

    Raises:
        ValueError: If ``--verify_k`` or ``--eval_start_mult`` is below 1.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        # Checkpoint input / output.
        ("--init_ckpt", {"type": str, "required": True}),
        ("--save_dir", {"type": str, "default": "checkpoints"}),
        ("--save_every", {"type": int, "default": 0}),
        ("--save_best", {"action": "store_true"}),
        # Task definition.
        ("--n_hubs", {"type": int, "default": 0}),
        ("--m", {"type": int, "default": 0}),
        ("--max_path_len", {"type": int, "default": 0}),
        ("--verify_k", {"type": int, "default": 1}),
        # Evaluation.
        ("--eval_log", {"type": str, "default": ""}),
        ("--eval_every", {"type": int, "default": 40}),
        ("--eval_start_mult", {"type": int, "default": 5}),
        # Optimisation.
        ("--batch_size", {"type": int, "default": 32}),
        ("--group_size", {"type": int, "default": 4}),
        ("--steps", {"type": int, "default": 1000}),
        ("--lr", {"type": float, "default": 1e-5}),
        ("--weight_decay", {"type": float, "default": 0.0}),
        ("--grad_clip", {"type": float, "default": 1.0}),
        ("--kl_coef", {"type": float, "default": 0.02}),
        ("--clip_eps", {"type": float, "default": 0.2}),
        ("--adv_eps", {"type": float, "default": 1e-5}),
        # Sampling.
        ("--temperature", {"type": float, "default": 1.0}),
        ("--top_k", {"type": int, "default": 0}),
        # Runtime.
        ("--log_every", {"type": int, "default": 20}),
        ("--amp", {"action": "store_true"}),
        ("--cpu", {"action": "store_true"}),
        ("--seed", {"type": int, "default": 1337}),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    cli_args = parser.parse_args()
    if cli_args.verify_k < 1:
        raise ValueError("verify_k must be >= 1.")
    if cli_args.eval_start_mult < 1:
        raise ValueError("eval_start_mult must be >= 1.")
    train(cli_args)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
