# train_grpo.py
from __future__ import annotations
import os
import time
import argparse

import torch
import torch.nn.functional as F

from markov_chain import build_markov_transitions
from model import GPT, GPTConfig


def _build_allowed_matrix(next_states, n_states: int, device: torch.device) -> torch.Tensor:
    allowed = torch.zeros((n_states, n_states), dtype=torch.bool, device=device)
    for s, ns in enumerate(next_states):
        ns_t = torch.as_tensor(ns, device=device, dtype=torch.long)
        allowed[s, ns_t] = True
    return allowed


def _build_run_tag(args: argparse.Namespace, cfg: GPTConfig) -> str:
    return (
        f"h{args.n_hubs}_m{args.m}_emb{cfg.n_embd}_l{cfg.n_layer}_head{cfg.n_head}"
        f"_bs{cfg.block_size}"
    )


def _compute_shortest_steps(next_states, n_states: int) -> list[int]:
    rev: list[list[int]] = [[] for _ in range(n_states)]
    for s, ns in enumerate(next_states):
        for t in ns.tolist():
            if t >= s:
                rev[int(t)].append(s)

    dist = [-1] * n_states
    final_state = n_states - 1
    dist[final_state] = 0

    from collections import deque

    q = deque([final_state])
    while q:
        cur = q.popleft()
        for prev in rev[cur]:
            if dist[prev] == -1:
                dist[prev] = dist[cur] + 1
                q.append(prev)
    return dist


def _sample_trajectories(
    model: GPT,
    ref_model: GPT | None,
    starts: torch.Tensor,
    max_len: int,
    retry_id: int,
    pad_id: int,
    final_state: int,
    temperature: float,
    top_k: int,
    amp: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
    device = starts.device
    total = starts.size(0)

    tokens = torch.full((total, max_len), pad_id, device=device, dtype=torch.long)
    tokens[:, 0] = starts

    logp_steps: list[torch.Tensor] = []
    logp_ref_steps: list[torch.Tensor] | None = [] if ref_model is not None else None

    done = torch.zeros(total, device=device, dtype=torch.bool)

    for t in range(1, max_len):
        # Clone to avoid in-place updates to indices saved for backward.
        x = tokens[:, :t].clone()
        with torch.cuda.amp.autocast(enabled=amp):
            logits = model(x)[:, -1, :]
        logits = logits.float()

        if temperature != 1.0:
            logits = logits / temperature
        if top_k > 0:
            v, _ = torch.topk(logits, top_k)
            kth = v[:, -1].unsqueeze(-1)
            logits = torch.where(logits < kth, torch.full_like(logits, -float("inf")), logits)

        logp = F.log_softmax(logits, dim=-1)
        probs = torch.exp(logp)
        next_token = torch.multinomial(probs, num_samples=1).squeeze(1)

        next_token = torch.where(done, torch.full_like(next_token, pad_id), next_token)
        tokens[:, t] = next_token

        active = ~done
        gathered = torch.gather(logp, 1, next_token.unsqueeze(1)).squeeze(1)
        logp_steps.append(gathered)

        if ref_model is not None:
            with torch.no_grad():
                with torch.cuda.amp.autocast(enabled=amp):
                    ref_logits = ref_model(x)[:, -1, :]
                ref_logits = ref_logits.float()
                if temperature != 1.0:
                    ref_logits = ref_logits / temperature
                if top_k > 0:
                    v, _ = torch.topk(ref_logits, top_k)
                    kth = v[:, -1].unsqueeze(-1)
                    ref_logits = torch.where(ref_logits < kth, torch.full_like(ref_logits, -float("inf")), ref_logits)
                ref_logp = F.log_softmax(ref_logits, dim=-1)
                ref_gathered = torch.gather(ref_logp, 1, next_token.unsqueeze(1)).squeeze(1)
                logp_ref_steps.append(ref_gathered)

        done = done | (next_token == retry_id) | (next_token == pad_id) | (next_token == final_state)
        if done.all():
            break

    if logp_steps:
        log_probs = torch.stack(logp_steps, dim=1)
    else:
        log_probs = torch.zeros((total, 0), device=device, dtype=torch.float32)
    steps = log_probs.size(1)
    if steps < max_len - 1:
        pad = torch.zeros((total, max_len - 1 - steps), device=device, dtype=log_probs.dtype)
        log_probs = torch.cat([log_probs, pad], dim=1)

    log_probs_ref = None
    if logp_ref_steps is not None:
        if logp_ref_steps:
            log_probs_ref = torch.stack(logp_ref_steps, dim=1)
        else:
            log_probs_ref = torch.zeros((total, 0), device=device, dtype=torch.float32)
        steps_ref = log_probs_ref.size(1)
        if steps_ref < max_len - 1:
            pad_ref = torch.zeros((total, max_len - 1 - steps_ref), device=device, dtype=log_probs_ref.dtype)
            log_probs_ref = torch.cat([log_probs_ref, pad_ref], dim=1)

    terminal = (tokens == final_state) | (tokens == retry_id) | (tokens == pad_id)
    term_any = terminal.any(dim=1)
    term_idx = terminal.float().argmax(dim=1)
    term_idx = torch.where(term_any, term_idx, torch.full_like(term_idx, max_len - 1))
    lengths = term_idx + 1
    pos = torch.arange(max_len - 1, device=device)
    action_mask = (pos.unsqueeze(0) < (lengths - 1).unsqueeze(1)).float()

    log_probs = log_probs * action_mask
    if log_probs_ref is not None:
        log_probs_ref = log_probs_ref * action_mask

    return tokens, log_probs, action_mask, log_probs_ref


def _compute_rewards(
    tokens: torch.Tensor,
    n_states: int,
    final_state: int,
    retry_id: int,
    pad_id: int,
    allowed: torch.Tensor,
    shortest_steps: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    device = tokens.device
    total, T = tokens.shape

    terminal = (tokens == final_state) | (tokens == retry_id) | (tokens == pad_id) | (tokens >= n_states)
    term_any = terminal.any(dim=1)
    term_idx = terminal.float().argmax(dim=1)
    term_idx = torch.where(term_any, term_idx, torch.full_like(term_idx, T - 1))
    term_token = tokens[torch.arange(total, device=device), term_idx]

    lengths = term_idx + 1
    hit_final = term_token == final_state

    prev = tokens[:, :-1]
    nxt = tokens[:, 1:]
    pos = torch.arange(T - 1, device=device)
    valid_pos = pos.unsqueeze(0) < (lengths - 1).unsqueeze(1)
    state_mask = (prev < n_states) & (nxt < n_states) & valid_pos

    decrease = (nxt < prev) & state_mask
    has_decrease = decrease.any(dim=1)

    allowed_lookup = allowed[
        prev.clamp(0, n_states - 1),
        nxt.clamp(0, n_states - 1),
    ]
    invalid = (~allowed_lookup) & state_mask
    has_invalid = invalid.any(dim=1)

    success = hit_final & ~has_decrease & ~has_invalid

    start = tokens[:, 0].clamp(0, n_states - 1)
    shortest = shortest_steps[start]
    is_shortest = (shortest >= 0) & ((lengths - 1) == shortest)
    reward = (success & is_shortest).float()
    return reward, success, lengths


def _grpo_checkpoint(
    model: GPT,
    cfg: GPTConfig,
    args: argparse.Namespace,
    *,
    retry_id: int,
    pad_id: int,
    vocab_size: int,
    step: int,
    mean_reward: float,
    max_len: int,
) -> dict:
    """Build the dict written by every GRPO checkpoint save (best, periodic, final).

    It contains the "model", "config" and "ds" entries that ``train`` reads back,
    plus a "grpo" section recording the run settings.
    """
    return {
        "model": model.state_dict(),
        "config": cfg.__dict__,
        "ds": {
            "n_hubs": args.n_hubs,
            "m": args.m,
            "retry_id": retry_id,
            "pad_id": pad_id,
            "vocab_size": vocab_size,
        },
        "grpo": {
            "step": step,
            "mean_reward": mean_reward,
            "max_len": max_len,
            "group_size": args.group_size,
            "temperature": args.temperature,
            "top_k": args.top_k,
            "kl_coef": args.kl_coef,
        },
    }


def train(args: argparse.Namespace) -> None:
    """Fine-tune a pretrained GPT checkpoint with GRPO on Markov-chain paths.

    Each step does the following:
    - Draw ``batch_size`` start states and repeat each ``group_size`` times.
    - Sample one trajectory per row and score it with ``_compute_rewards``.
    - Minimise ``-(adv * sum log pi)``. ``adv`` is the reward standardised within
      its group.
    - When ``kl_coef > 0``, also add ``kl_coef`` times the summed log-ratio
      against a frozen copy of the initial model.

    Checkpoints are written according to ``save_best`` and ``save_every``. When
    ``save_every`` is set, the final step is always saved as well.

    Raises:
        ValueError: if ``group_size < 2``, if ``n_hubs``/``m`` cannot be resolved,
            or if the checkpoint vocabulary size does not match the task.
    """
    if args.group_size < 2:
        raise ValueError(
            f"group_size must be >= 2, got {args.group_size}: the per-group reward "
            "std is undefined for a single sample and would make the loss NaN."
        )

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    if device.type == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    use_amp = args.amp and device.type == "cuda"

    ckpt = torch.load(args.init_ckpt, map_location="cpu")
    cfg = GPTConfig(**ckpt["config"])
    model = GPT(cfg)
    model.load_state_dict(ckpt["model"])
    model.to(device)
    model.train()

    # Frozen copy of the initial weights; only needed for the KL penalty.
    ref_model = None
    if args.kl_coef > 0:
        ref_model = GPT(cfg)
        ref_model.load_state_dict(ckpt["model"])
        ref_model.to(device)
        ref_model.eval()
        for p in ref_model.parameters():
            p.requires_grad_(False)

    # CLI values win; 0 falls back to the checkpoint's dataset metadata.
    ds = ckpt.get("ds", {})
    args.n_hubs = args.n_hubs or ds.get("n_hubs", args.n_hubs)
    args.m = args.m or ds.get("m", args.m)
    if not args.n_hubs or not args.m:
        raise ValueError("n_hubs and m must be set either in checkpoint or via args.")

    n_states, next_states, _ = build_markov_transitions(args.n_hubs, args.m)
    retry_id = int(ds.get("retry_id", n_states))
    pad_id = int(ds.get("pad_id", n_states + 1))
    vocab_size = int(ds.get("vocab_size", n_states + 2))
    if cfg.vocab_size != vocab_size:
        raise ValueError(f"checkpoint vocab_size={cfg.vocab_size} != expected {vocab_size}")

    final_state = n_states - 1
    # max_len counts tokens including the start, so at most max_len - 1 transitions.
    max_len = args.max_path_len or (args.n_hubs * args.m)
    max_len = max(2, min(max_len, cfg.block_size))

    allowed = _build_allowed_matrix(next_states, n_states, device)
    shortest_steps = _compute_shortest_steps(next_states, n_states)
    shortest_steps_t = torch.tensor(shortest_steps, device=device, dtype=torch.long)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
        betas=(0.9, 0.95),
    )
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    os.makedirs(args.save_dir, exist_ok=True)
    run_tag = _build_run_tag(args, cfg)

    def save(path: str, at_step: int, mean_reward: float) -> None:
        torch.save(
            _grpo_checkpoint(
                model,
                cfg,
                args,
                retry_id=retry_id,
                pad_id=pad_id,
                vocab_size=vocab_size,
                step=at_step,
                mean_reward=mean_reward,
                max_len=max_len,
            ),
            path,
        )

    best_reward = -float("inf")
    last_step = -1
    last_saved_step = -1
    t0 = time.time()

    for step in range(args.steps):
        last_step = step
        # NOTE(review): starts may be final_state itself (a 0-step path) or a state
        # with no route to it (shortest == -1). Such groups get identical rewards
        # and therefore zero advantage.
        starts = torch.randint(0, n_states, (args.batch_size,), device=device)
        # Consecutive rows share a start state and form one GRPO group.
        starts = starts.repeat_interleave(args.group_size)

        tokens, log_probs, action_mask, log_probs_ref = _sample_trajectories(
            model,
            ref_model,
            starts,
            max_len,
            retry_id,
            pad_id,
            final_state,
            args.temperature,
            args.top_k,
            use_amp,
        )

        with torch.no_grad():
            rewards, success, lengths = _compute_rewards(
                tokens, n_states, final_state, retry_id, pad_id, allowed, shortest_steps_t
            )

        # Group-relative advantage: standardise rewards within each start-state group.
        rewards = rewards.view(args.batch_size, args.group_size)
        mean = rewards.mean(dim=1, keepdim=True)
        std = rewards.std(dim=1, keepdim=True)
        adv = ((rewards - mean) / (std + args.adv_eps)).view(-1).detach()

        logp_sum = (log_probs * action_mask).sum(dim=1)
        loss = -(adv * logp_sum).mean()

        kl_sum = None
        if log_probs_ref is not None:
            # Per-sequence sum of log pi - log pi_ref over the taken actions.
            kl_sum = ((log_probs - log_probs_ref) * action_mask).sum(dim=1)
            loss = loss + args.kl_coef * kl_sum.mean()

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        if args.grad_clip > 0:
            # Unscale first so the clipping threshold applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        scaler.step(optimizer)
        scaler.update()

        if args.log_every > 0 and step % args.log_every == 0:
            dt = time.time() - t0
            mean_reward = rewards.mean().item()
            success_rate = success.float().mean().item()
            avg_len = lengths.float().mean().item()
            kl_val = float(kl_sum.mean().item()) if kl_sum is not None else 0.0
            print(
                f"step={step} loss={loss.item():.4f} reward={mean_reward:.3f} "
                f"success={success_rate:.3f} avg_len={avg_len:.2f} kl={kl_val:.4f} dt={dt:.2f}s"
            )
            t0 = time.time()

        if args.save_best:
            mean_reward = rewards.mean().item()
            if mean_reward > best_reward:
                best_reward = mean_reward
                best_path = os.path.join(args.save_dir, f"best_grpo_{run_tag}.pt")
                save(best_path, step, best_reward)
                print(f"saved best: {best_path} (reward={best_reward:.3f})")

        if args.save_every and step > 0 and (step % args.save_every == 0):
            ckpt_path = os.path.join(args.save_dir, f"ckpt_grpo_{run_tag}_step{step}.pt")
            save(ckpt_path, step, rewards.mean().item())
            print(f"saved: {ckpt_path}")
            last_saved_step = step

    # The periodic schedule skips the last step unless it is a multiple of
    # save_every. Save it too so the final updates are not lost.
    if args.save_every and last_step >= 0 and last_saved_step != last_step:
        ckpt_path = os.path.join(args.save_dir, f"ckpt_grpo_{run_tag}_step{last_step}.pt")
        save(ckpt_path, last_step, rewards.mean().item())
        print(f"saved: {ckpt_path}")


def main() -> None:
    """Parse command-line flags and launch ``train`` with them."""
    cli_flags: list[tuple[str, dict]] = [
        # Checkpoint input/output.
        ("--init_ckpt", dict(type=str, required=True)),
        ("--save_dir", dict(type=str, default="checkpoints")),
        ("--save_every", dict(type=int, default=0)),
        ("--save_best", dict(action="store_true")),
        # Task shape; 0 falls back to the checkpoint metadata or a derived default.
        ("--n_hubs", dict(type=int, default=0)),
        ("--m", dict(type=int, default=0)),
        ("--max_path_len", dict(type=int, default=0)),
        # Optimisation.
        ("--batch_size", dict(type=int, default=32)),
        ("--group_size", dict(type=int, default=4)),
        ("--steps", dict(type=int, default=1000)),
        ("--lr", dict(type=float, default=1e-5)),
        ("--weight_decay", dict(type=float, default=0.0)),
        ("--grad_clip", dict(type=float, default=1.0)),
        ("--kl_coef", dict(type=float, default=0.02)),
        ("--adv_eps", dict(type=float, default=1e-5)),
        # Sampling.
        ("--temperature", dict(type=float, default=1.0)),
        ("--top_k", dict(type=int, default=0)),
        # Runtime.
        ("--log_every", dict(type=int, default=20)),
        ("--amp", dict(action="store_true")),
        ("--cpu", dict(action="store_true")),
        ("--seed", dict(type=int, default=1337)),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_flags:
        parser.add_argument(flag, **options)
    train(parser.parse_args())


if __name__ == "__main__":
    main()
