#!/usr/bin/env python3
import os
import math
import argparse
import torch

# Wrappers used by the training code (import currently disabled; the raw FA3 interface below is used instead):
# from yunchang.kernels.attention import (
#     flash_attn3_func_forward,
#     flash_attn3_func_backward,
# )
from flash_attn_interface import _flash_attn_backward as flash3_bwd_raw, _flash_attn_forward as flash3_fwd_raw

def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the FA3 smoke test.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behavior.

    Returns:
        Namespace with ``bs``, ``seqlen``, ``nheads``, ``head_dim``, ``dtype``
        ("bf16" or "fp16"), ``chunk``, ``causal`` and ``raw_bwd``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--bs", type=int, default=1)
    ap.add_argument("--seqlen", type=int, default=2097152)
    ap.add_argument("--nheads", type=int, default=1)
    ap.add_argument("--head-dim", type=int, default=128)
    ap.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16"])
    ap.add_argument("--chunk", type=int, default=131072, help="Set YUNCHANG_FA3_ATTENTION_CHUNK; use -1 to disable")
    # Causal masking stays on by default. A store_true flag whose default is
    # already True can never be switched off, so --no-causal is provided as
    # the explicit opt-out; both flags write to the same destination.
    ap.add_argument("--causal", dest="causal", action="store_true", default=True,
                    help="Use causal masking (default)")
    ap.add_argument("--no-causal", dest="causal", action="store_false",
                    help="Disable causal masking")
    ap.add_argument("--raw-bwd", action="store_true", help="Use raw flash_attn_interface._flash_attn_backward instead of wrapper")
    return ap.parse_args(argv)

def main():
    """Smoke-test the FA3 forward and backward passes on random inputs.

    Allocates random q/k/v tensors on CUDA device 0 with the shape and dtype
    given on the command line, runs the raw FA3 forward, then runs the
    backward either through the raw FA3 interface (``--raw-bwd``) or through
    the ``flash_attn3_func_backward`` wrapper. Each stage is bracketed by
    ``torch.cuda.synchronize()`` so asynchronous kernel failures surface in
    the stage that caused them.

    Raises:
        NameError: if the wrapper path is selected but
            ``flash_attn3_func_backward`` has not been imported at module
            level.
        Exception: any error raised by the forward or backward call is
            reported on stdout and then re-raised unchanged.
    """
    args = parse_args()

    # Export the chunk size under the name the wrapper is expected to read.
    # NOTE(review): this is set after module imports have run; if the wrapper
    # reads the variable at import time it will not see this value - confirm.
    os.environ["YUNCHANG_FA3_ATTENTION_CHUNK"] = str(args.chunk)

    device = "cuda"
    torch.cuda.set_device(0)
    torch.backends.cuda.matmul.allow_tf32 = True

    dt = torch.bfloat16 if args.dtype == "bf16" else torch.float16
    bs, seqlen, nheads, head_dim = args.bs, args.seqlen, args.nheads, args.head_dim
    softmax_scale = 1.0 / math.sqrt(head_dim)

    print(f"Allocating tensors: bs={bs}, seqlen={seqlen}, nheads={nheads}, head_dim={head_dim}, dtype={dt}, chunk={args.chunk}")
    q = torch.randn(bs, seqlen, nheads, head_dim, device=device, dtype=dt)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Forward
    try:
        torch.cuda.synchronize()
        # Only the first two returned values (output, log-sum-exp) are used.
        # NOTE(review): --chunk is forwarded verbatim as attention_chunk; the
        # help text says -1 disables chunking for the env var, whether the
        # raw kernel treats -1 the same way is not visible here - confirm.
        out, lse, *_ = flash3_fwd_raw(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            None, None, None,  # q_descale, k_descale, v_descale
            softmax_scale=softmax_scale,
            causal=args.causal,
            window_size=(-1, -1),
            attention_chunk=args.chunk,
            softcap=0.0,
            num_splits=1,
            pack_gqa=None,
            sm_margin=0,
        )
        torch.cuda.synchronize()
        print(f"Forward OK. out={tuple(out.shape)}, lse={tuple(lse.shape)}")
    except Exception as e:
        print("Forward FAILED:", repr(e))
        raise

    # Backward
    try:
        bwd_wrapper = None
        if not args.raw_bwd:
            # The wrapper is only available when its import has been enabled
            # at module level; look it up before allocating gradient buffers
            # so a missing import yields an actionable message.
            bwd_wrapper = globals().get("flash_attn3_func_backward")
            if bwd_wrapper is None:
                raise NameError(
                    "name 'flash_attn3_func_backward' is not defined: "
                    "import the wrapper at module level or pass --raw-bwd"
                )

        dout = torch.randn_like(out)
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)

        torch.cuda.synchronize()
        if args.raw_bwd:
            print("Using raw flash_attn_interface._flash_attn_backward")
            # Direct call into FA3 Python interface. The fifth argument is the
            # forward output: the backward needs it together with dout, so
            # passing q here (as before) produced incorrect gradients.
            flash3_bwd_raw(
                dout,
                q,
                k,
                v,
                out,
                lse,
                None, None,  # cu_seqlens_q, cu_seqlens_k
                None, None,  # seqused_q, seqused_k
                None, None,  # max_seqlen_q, max_seqlen_k
                dq,
                dk,
                dv,
                softmax_scale,
                args.causal,
                # Presumably window_size, softcap, deterministic, sm_margin -
                # confirm against the installed flash_attn_interface version.
                (-1, -1),
                0.0,
                True,
                0,
            )
        else:
            # Wrapper path
            bwd_wrapper(
                dout,
                q, k, v,
                out,
                lse,
                dq, dk, dv,
                dropout_p=0.0,
                softmax_scale=softmax_scale,
                bwd_causal=args.causal,
                window_size=(-1, -1),
                softcap=0.0,
                alibi_slopes=None,
                deterministic=True,
                rng_state=None,
            )
        torch.cuda.synchronize()
        print("Backward OK.",
              f"||dq||={dq.float().norm().item():.3e}",
              f"||dk||={dk.float().norm().item():.3e}",
              f"||dv||={dv.float().norm().item():.3e}")
    except Exception as e:
        print("Backward FAILED:", repr(e))
        raise

# Run the smoke test only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()