# train.py
from __future__ import annotations
import os
import time
import math
import argparse

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from dataset import MarkovRetryBlockDataset
from model import GPT, GPTConfig


def mask_retry_targets(x: torch.Tensor, y: torch.Tensor, y_mask: torch.Tensor, retry_id: int) -> None:
    """Mark, in place, the targets in ``y`` that the loss must skip.

    A target is overwritten with -100 when its ``y_mask`` entry is 0
    (padding, per the original comment) or when the input token at the same
    position in ``x`` equals ``retry_id``, i.e. the target is the token that
    immediately follows a retry. Nothing is returned; ``y`` is mutated.
    """
    masked_out = y_mask.eq(0)
    follows_retry = x.eq(retry_id)
    y.masked_fill_(masked_out | follows_retry, -100)


def _build_run_tag(args: argparse.Namespace) -> str:
    """Return the run identifier used in checkpoint file names.

    The tag concatenates, in a fixed order, a short prefix and the value of
    ``n_hubs``, ``m``, ``n_embd``, ``n_layer``, ``n_head`` and ``block_size``
    read from ``args``, with the pieces separated by underscores.
    """
    labelled_values = (
        ("h", args.n_hubs),
        ("m", args.m),
        ("emb", args.n_embd),
        ("l", args.n_layer),
        ("head", args.n_head),
        ("bs", args.block_size),
    )
    return "_".join(f"{label}{value}" for label, value in labelled_values)


def _build_allowed_next(next_states, next_probs) -> list[set[int]]:
    """Return, per row, the set of successor states with positive probability.

    ``next_states`` and ``next_probs`` are iterated in lockstep; each paired
    row must expose ``.tolist()``. For every row the result holds the integer
    states whose matching probability is strictly greater than zero.
    """
    return [
        {
            int(state)
            for state, prob in zip(row_states.tolist(), row_probs.tolist())
            if prob > 0
        }
        for row_states, row_probs in zip(next_states, next_probs)
    ]


def _extract_path_starts(
    ids: torch.Tensor,
    attn: torch.Tensor,
    retry_id: int,
    pad_id: int,
    n_states: int,
) -> list[int]:
    """Collect the first token of every path found in a batch of sequences.

    For each row of ``ids`` only positions whose ``attn`` entry equals 1 are
    scanned, and scanning of a row stops at the first ``pad_id``. A token is
    recorded when it is the first scanned token of the row or directly follows
    a ``retry_id`` token, provided it is not itself ``retry_id`` and lies in
    ``[0, n_states)``. Results from all rows are returned in one flat list.
    """
    ids_host = ids.detach().cpu()
    attn_host = attn.detach().cpu()

    found: list[int] = []
    for row, row_ids in enumerate(ids_host):
        visible = row_ids[attn_host[row] == 1].tolist()
        # True at the start of the row and right after every retry token.
        at_path_start = True
        for token in visible:
            if token == pad_id:
                break
            is_state = token != retry_id and 0 <= token < n_states
            if at_path_start and is_state:
                found.append(int(token))
            at_path_start = token == retry_id
    return found


def _sample_path_is_possible(
    model: GPT,
    start_token: int,
    allowed_next: list[set[int]],
    retry_id: int,
    pad_id: int,
    max_len: int,
    device: torch.device,
    *,
    debug: bool = False,
) -> bool:
    """Sample one path from ``model`` and report whether it is a legal walk.

    Starting from ``start_token``, tokens are drawn one at a time from the
    softmax of the model's logits at the last position. Sampling stops as
    soon as the outcome is decided.

    Args:
        model: Callable that maps a ``(1, T)`` long tensor to logits indexed
            as ``[:, -1, :]``.
        start_token: First state of the path; must lie in
            ``[0, len(allowed_next))``.
        allowed_next: ``allowed_next[s]`` is the set of states that may
            follow state ``s``.
        retry_id: Token that terminates a path successfully.
        pad_id: Token that is never valid inside a path.
        max_len: Maximum number of tokens in the path, including the start
            token; at most ``max_len - 1`` tokens are sampled.
        device: Device on which the model input tensor is created.
        debug: When true, print the sampled tokens and the last probability
            vector if a disallowed transition is drawn. Off by default so
            that evaluation does not flood stdout.

    Returns:
        ``True`` if ``retry_id`` is sampled after only allowed transitions.
        ``False`` if the start token is out of range, a pad / out-of-range /
        disallowed token is sampled, or ``max_len`` is reached first.
    """
    n_states = len(allowed_next)
    if start_token < 0 or start_token >= n_states:
        return False

    tokens = [int(start_token)]
    cur = tokens[0]
    # Sampling never needs gradients; disable autograd here so the helper is
    # safe to call outside a no_grad context as well.
    with torch.no_grad():
        for _ in range(max_len - 1):
            x = torch.tensor([tokens], device=device, dtype=torch.long)
            logits = model(x)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_token = int(torch.multinomial(probs, num_samples=1).item())
            tokens.append(next_token)

            # Retry is checked first: its id may lie outside the state range.
            if next_token == retry_id:
                return True
            if next_token == pad_id or not 0 <= next_token < n_states:
                return False
            if next_token not in allowed_next[cur]:
                if debug:
                    print(tokens)
                    print(probs)
                return False

            cur = next_token

    # Ran out of budget without the model closing the path.
    return False


def _count_possible_paths_from_model(
    model: GPT,
    ids: torch.Tensor,
    attn: torch.Tensor,
    allowed_next: list[set[int]],
    retry_id: int,
    pad_id: int,
    max_len: int,
    device: torch.device,
) -> tuple[int, int]:
    """Sample one path per start token in the batch and count the legal ones.

    Start tokens are taken from ``ids``/``attn`` via ``_extract_path_starts``;
    each one seeds a single call to ``_sample_path_is_possible``.

    Returns:
        ``(valid, total)``: the number of sampled paths judged possible and
        the number of start tokens tried.
    """
    start_tokens = _extract_path_starts(ids, attn, retry_id, pad_id, len(allowed_next))
    n_valid = sum(
        1
        for start_token in start_tokens
        if _sample_path_is_possible(
            model, start_token, allowed_next, retry_id, pad_id, max_len, device
        )
    )
    return n_valid, len(start_tokens)


@torch.no_grad()
def evaluate(
    model: GPT,
    val_dl: DataLoader,
    device: torch.device,
    vocab_size: int,
    retry_id: int,
    pad_id: int,
    next_states,
    next_probs,
    max_path_len: int,
    max_batches: int | None = None,
):
    """Measure next-token NLL and sampled-path validity on validation data.

    For every batch two things are computed: (1) one path is sampled from the
    model for each path start found in the batch and checked against the
    allowed transitions, and (2) the summed cross-entropy of the model's
    predictions over all targets not masked by ``mask_retry_targets``.

    Args:
        model: Model to evaluate; called as ``model(x, attention_mask=am)``.
        val_dl: Loader yielding ``(ids, attn)`` batches.
        device: Device batches are moved to before the forward pass.
        vocab_size: Size of the last logits dimension.
        retry_id: Retry token id.
        pad_id: Padding token id.
        next_states: Per-state successor ids (see ``_build_allowed_next``).
        next_probs: Per-state successor probabilities.
        max_path_len: Upper bound on sampled path length. ``None`` falls back
            to the model's block size.
        max_batches: If given, stop after this many batches.

    Returns:
        ``(avg_nll, ppl, correctness)``: summed NLL divided by the number of
        scored targets, ``exp(avg_nll)`` (``inf`` once ``avg_nll >= 20``),
        and the fraction of sampled paths that were possible.

    The model's train/eval mode is restored to what it was on entry, even if
    evaluation raises.
    """
    allowed_next = _build_allowed_next(next_states, next_probs)

    # Generated sequences must fit in the model context and contain at least
    # a start token plus one sampled token.
    block_size = model.cfg.block_size
    if max_path_len is None:
        max_gen_len = block_size
    else:
        max_gen_len = min(max_path_len, block_size)
    max_gen_len = max(max_gen_len, 2)

    total_nll = 0.0
    total_tokens = 0
    total_valid_paths = 0
    total_paths = 0

    was_training = model.training
    model.eval()
    try:
        for b, (ids, attn) in enumerate(val_dl):
            if max_batches is not None and b >= max_batches:
                break

            valid, count = _count_possible_paths_from_model(
                model, ids, attn, allowed_next, retry_id, pad_id, max_gen_len, device
            )
            total_valid_paths += valid
            total_paths += count

            ids = ids.to(device, non_blocking=True)
            attn = attn.to(device, non_blocking=True)

            # ids/attn: (B, block_size+1)
            x = ids[:, :-1]                 # (B, block_size)
            y = ids[:, 1:].clone()          # (B, block_size)
            am = attn[:, :-1]               # attention mask for x
            y_mask = attn[:, 1:]            # mask aligned with y positions

            mask_retry_targets(x, y, y_mask, retry_id)

            logits = model(x, attention_mask=am)  # (B, T, vocab)
            nll_sum = F.cross_entropy(
                logits.reshape(-1, vocab_size),
                y.reshape(-1),
                ignore_index=-100,
                reduction="sum",
            )

            total_nll += float(nll_sum.item())
            total_tokens += int((y != -100).sum().item())
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    avg_nll = total_nll / max(1, total_tokens)  # per-token NLL
    # Guard against overflow in exp() for very poor models.
    ppl = math.exp(avg_nll) if avg_nll < 20 else float("inf")
    correctness = total_valid_paths / max(1, total_paths)
    return avg_nll, ppl, correctness


def _checkpoint_payload(model, cfg, args, train_ds, step: int, epoch: int) -> dict:
    """Build the dict stored in both periodic and best-model checkpoints."""
    return {
        "model": model.state_dict(),
        "config": cfg.__dict__,
        "ds": {
            "n_hubs": args.n_hubs,
            "m": args.m,
            "retry_id": train_ds.retry_id,
            "pad_id": train_ds.pad_id,
            "vocab_size": train_ds.vocab_size,
        },
        "step": step,
        "epoch": epoch,
    }


def train(args):
    """Train a GPT model on ``MarkovRetryBlockDataset`` using CLI options.

    Builds the train/validation datasets and loaders, the model, AdamW and a
    gradient scaler, then runs ``args.epochs`` epochs (or stops after
    ``args.max_steps`` optimizer steps). Each batch is split into
    ``args.grad_accum_steps`` micro-batches whose scaled losses are
    accumulated before one optimizer step. Training loss is printed every
    ``args.log_every`` steps, validation runs every ``args.eval_every``
    steps (including step 0), and checkpoints are written to
    ``args.save_dir`` according to ``args.save_best`` / ``args.save_every``.

    Args:
        args: Namespace with the options defined in ``main``.

    Raises:
        ValueError: If ``grad_accum_steps`` is smaller than 1 or does not
            divide ``batch_size``.
    """
    # Validate up front with real exceptions: an ``assert`` would be stripped
    # under ``python -O``. The train loader uses drop_last=True, so every
    # batch has exactly ``batch_size`` rows and one check is enough.
    micro = args.grad_accum_steps
    if micro < 1:
        raise ValueError("grad_accum_steps must be >= 1")
    if args.batch_size % micro != 0:
        raise ValueError("batch_size must be divisible by grad_accum_steps")

    torch.manual_seed(args.torch_seed)
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    on_cuda = device.type == "cuda"
    use_amp = on_cuda and args.amp
    if on_cuda:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # Train dataset (synthetic, infinite-ish)
    train_ds = MarkovRetryBlockDataset(
        n_hubs=args.n_hubs,
        m=args.m,
        block_size=args.block_size,
        num_paths=args.num_paths,
        min_path_len=args.min_path_len,
        max_path_len=args.max_path_len,
        num_samples=args.num_samples,
        seed=args.data_seed,
        pad_to_block=True,
    )

    train_dl = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=on_cuda,
        drop_last=True,
        persistent_workers=(args.num_workers > 0),
    )

    # Val dataset (fixed seed so metric is stable)
    # NOTE(review): evaluate() below receives train_ds.next_states/next_probs
    # while batches come from val_ds, which uses a different seed. If the seed
    # also determines the transition tables, the two disagree — confirm in
    # MarkovRetryBlockDataset.
    val_ds = MarkovRetryBlockDataset(
        n_hubs=args.n_hubs,
        m=args.m,
        block_size=args.block_size,
        num_paths=args.num_paths,
        min_path_len=args.min_path_len,
        max_path_len=args.max_path_len,
        num_samples=args.val_samples,
        seed=args.data_seed + 12345,
        pad_to_block=True,
    )
    val_ds.set_epoch(0)

    # DataLoader rejects persistent_workers=True when num_workers == 0, so the
    # flag must follow the halved worker count, not the training one
    # (e.g. --num_workers 1 gives 0 validation workers).
    val_workers = max(0, args.num_workers // 2)
    val_dl = DataLoader(
        val_ds,
        batch_size=args.val_batch_size or args.batch_size,
        shuffle=False,
        num_workers=val_workers,
        pin_memory=on_cuda,
        drop_last=False,
        persistent_workers=(val_workers > 0),
    )

    cfg = GPTConfig(
        vocab_size=train_ds.vocab_size,
        block_size=args.block_size,
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_embd=args.n_embd,
        dropout=args.dropout,
    )
    model = GPT(cfg).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
        betas=(0.9, 0.95),
    )
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    os.makedirs(args.save_dir, exist_ok=True)
    run_tag = _build_run_tag(args)

    best_val = float("inf")
    step = 0
    t0 = time.time()

    model.train()
    for epoch in range(args.epochs):
        train_ds.set_epoch(epoch)

        for ids, attn in train_dl:
            ids = ids.to(device, non_blocking=True)
            attn = attn.to(device, non_blocking=True)

            # Shift by one: x predicts y at the same position.
            x = ids[:, :-1]
            y = ids[:, 1:].clone()
            am = attn[:, :-1]
            y_mask = attn[:, 1:]
            mask_retry_targets(x, y, y_mask, train_ds.retry_id)

            micro_bs = x.size(0) // micro

            optimizer.zero_grad(set_to_none=True)

            loss_total = 0.0
            for i in range(micro):
                rows = slice(i * micro_bs, (i + 1) * micro_bs)
                xb = x[rows]
                yb = y[rows]
                amb = am[rows]

                with torch.cuda.amp.autocast(enabled=use_amp):
                    logits = model(xb, attention_mask=amb)
                    loss = F.cross_entropy(
                        logits.reshape(-1, train_ds.vocab_size),
                        yb.reshape(-1),
                        ignore_index=-100,
                    )
                    # Scale so the accumulated gradient is an average over
                    # micro-batches rather than a sum.
                    loss = loss / micro

                scaler.scale(loss).backward()
                loss_total += float(loss.item())

            if args.grad_clip > 0:
                # Gradients must be unscaled before their norm is clipped.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

            scaler.step(optimizer)
            scaler.update()

            # Logging (log_every <= 0 disables it instead of dividing by zero)
            if args.log_every > 0 and step > 0 and step % args.log_every == 0:
                dt = time.time() - t0
                toks = args.batch_size * args.block_size * args.log_every
                tok_s = toks / dt if dt > 0 else 0.0
                print(
                    f"epoch={epoch} step={step} train_loss={loss_total:.4f} "
                    f"tok/s≈{tok_s:,.0f} (B={args.batch_size}, T={args.block_size}, accum={args.grad_accum_steps})"
                )
                t0 = time.time()

            # Periodic validation (also runs at step 0)
            if args.eval_every > 0 and step % args.eval_every == 0:
                val_nll, val_ppl, val_corr = evaluate(
                    model,
                    val_dl,
                    device,
                    train_ds.vocab_size,
                    train_ds.retry_id,
                    train_ds.pad_id,
                    train_ds.next_states,
                    train_ds.next_probs,
                    train_ds.max_path_len,
                    max_batches=args.val_batches,
                )
                print(f"[val] step={step} nll={val_nll:.4f} ppl={val_ppl:.3f} correctness={val_corr:.3f}")

                if args.save_best and val_nll < best_val:
                    best_val = val_nll
                    best_path = os.path.join(args.save_dir, f"best_{run_tag}.pt")
                    payload = _checkpoint_payload(model, cfg, args, train_ds, step, epoch)
                    payload["best_val_nll"] = best_val
                    torch.save(payload, best_path)
                    print(f"saved best: {best_path} (best_val_nll={best_val:.4f})")

            # Optional checkpoint saving
            if args.save_every and step > 0 and (step % args.save_every == 0):
                ckpt = os.path.join(args.save_dir, f"ckpt_{run_tag}_step{step}.pt")
                torch.save(_checkpoint_payload(model, cfg, args, train_ds, step, epoch), ckpt)
                print(f"saved: {ckpt}")

            step += 1
            if args.max_steps is not None and step >= args.max_steps:
                return


def main(argv=None):
    """Parse command-line options and start training.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so running the script behaves as
            before; passing a list allows programmatic invocation.

    ``--val_batches 0`` and ``--val_batch_size 0`` are translated to ``None``
    before ``train`` is called, meaning "evaluate on all batches" and "use the
    training batch size" respectively.
    """
    p = argparse.ArgumentParser()

    # Data / Markov
    p.add_argument("--n_hubs", type=int, default=6)
    p.add_argument("--m", type=int, default=10)
    p.add_argument("--num_paths", type=int, default=6)
    p.add_argument("--min_path_len", type=int, default=2)
    p.add_argument("--max_path_len", type=int, default=None)
    p.add_argument("--block_size", type=int, default=256)
    p.add_argument("--num_samples", type=int, default=100_000)
    p.add_argument("--data_seed", type=int, default=0)

    # Model
    p.add_argument("--n_layer", type=int, default=2)  # 1..6
    p.add_argument("--n_head", type=int, default=4)
    p.add_argument("--n_embd", type=int, default=128)
    p.add_argument("--dropout", type=float, default=0.1)

    # Training
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--grad_accum_steps", type=int, default=1)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--weight_decay", type=float, default=0.1)
    p.add_argument("--epochs", type=int, default=5)
    p.add_argument("--max_steps", type=int, default=None)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--amp", action="store_true")
    p.add_argument("--cpu", action="store_true")
    p.add_argument("--num_workers", type=int, default=0)

    # Validation
    p.add_argument("--eval_every", type=int, default=200, help="0 disables validation.")
    p.add_argument("--val_samples", type=int, default=20_000)
    p.add_argument("--val_batches", type=int, default=50, help="Limit evaluation cost (None/0 = full).")
    p.add_argument("--val_batch_size", type=int, default=0, help="0 => use train batch_size.")
    p.add_argument(
        "--save_best",
        action="store_true",
        help="Save <save_dir>/best_<run_tag>.pt on improved val NLL.",
    )

    # Logging/saving
    p.add_argument("--log_every", type=int, default=50)
    p.add_argument("--save_dir", type=str, default="checkpoints")
    p.add_argument("--save_every", type=int, default=0, help="0 disables; else save every N steps.")
    p.add_argument("--torch_seed", type=int, default=1337)

    args = p.parse_args(argv)
    # 0 is the CLI spelling of "unset" for these two options.
    if args.val_batches == 0:
        args.val_batches = None
    if args.val_batch_size == 0:
        args.val_batch_size = None
    train(args)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
