import torch
import os
import torch.nn as nn
import torch.nn.functional as F

from .utils import RingComm, update_out_and_lse
from yunchang.kernels import AttnType, select_flash_attn_impl
from yunchang.globals import U_HANDLE, clear_u_handle, clear_o_handle, PROCESS_GROUP, GRAPH_FLAGS
from yunchang.comm.all_to_all import vanilla_all_to_all_4D as all_to_all_4D, SeqAllToAll4D
import logging
logger = logging.getLogger(__name__)

from typing import List, Tuple, Dict, Any


# NOTE: a `global` statement at module scope has no effect and does not bind
# any of these names. `process_group`, `attn_type`, `alibi_slopes` and
# `window_size` are injected into this module at runtime by
# FullyPipelinedAttnFunc.forward (through sys.modules[__name__]) and are read
# by fully_fused_ring_flash_attn_forward. Calling that function before the
# injection happens would raise NameError.
# NOTE(review): the remaining names (next_ulysses_qkv, a2a_available,
# ulysses_group, two_streams) are not assigned at module level anywhere in the
# visible part of this file - confirm whether they are still needed.
global process_group, attn_type, alibi_slopes, window_size, next_ulysses_qkv, a2a_available, ulysses_group, two_streams

import torch.distributed as dist

def check_nan_inf(tensor, name, rank=0):
    """Report whether ``tensor`` contains NaN or Inf entries.

    NaN is checked first; if any NaN is present the function reports it and
    returns without looking for Inf. A two-line diagnostic (shape and number
    of offending elements) is printed to stdout when something is found.

    Args:
        tensor: Tensor to inspect.
        name: Label used in the printed diagnostic.
        rank: Rank number used as a prefix in the printed diagnostic.

    Returns:
        True if a NaN or Inf value was found, False otherwise.
    """
    nan_mask = torch.isnan(tensor)
    if nan_mask.any():
        nan_count = nan_mask.sum().item()
        print(f"[RANK {rank}] NaN detected in {name}, shape: {tensor.shape}")
        print(f"[RANK {rank}] NaN locations: {nan_count}")
        return True

    inf_mask = torch.isinf(tensor)
    if not inf_mask.any():
        return False

    inf_count = inf_mask.sum().item()
    print(f"[RANK {rank}] Inf detected in {name}, shape: {tensor.shape}")
    print(f"[RANK {rank}] Inf locations: {inf_count}")
    return True

# @torch.library.custom_op("yunchang::_fully_fused_ring_flash_attn_forward", mutates_args=(), device_types="cuda")
# @torch.no_grad()
def fully_fused_ring_flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: float,
    dropout_p: float = 0,
    causal: bool = True,
    softcap: float = 0.0,
    deterministic: bool=False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Causal ring attention forward pass over the module-level ring group.

    The ring group, kernel type, alibi slopes and window size are read from
    the module globals ``process_group``, ``attn_type``, ``alibi_slopes`` and
    ``window_size``; they must have been set before this function is called.

    At every ring step the local ``q`` is attended against the k/v block that
    is currently held, while the next k/v block is exchanged through
    ``RingComm``:

    * step 0: full ``q`` against the local k/v with the causal mask;
    * ``step <= rank``: full ``q`` against the first half (dim 1) of k/v;
    * otherwise: second half (dim 1) of ``q`` against the full k/v, merged
      into the second half of the running output only.

    Args:
        q, k, v: Attention inputs; dim 1 is split in half (presumably the
            sequence dimension - TODO confirm against the caller).
        softmax_scale: Scale forwarded to the attention kernel.
        dropout_p: Dropout probability forwarded to the kernel.
        causal: Must be True.
        softcap: Soft-cap value forwarded to the kernel.
        deterministic: Accepted for signature compatibility; not used here.

    Returns:
        ``(out, lse)`` where ``out`` is cast to ``q.dtype`` and ``lse`` is the
        accumulated log-sum-exp with its last dim squeezed and dims 1 and 2
        swapped.
    """
    global process_group, attn_type, alibi_slopes, window_size

    assert causal == True, "zigzag ring is meaningless for causal=False"
    ring = RingComm(process_group)
    final_step = ring.world_size - 1

    half_len = q.shape[1] // 2
    q_second_half = q[:, half_len:]
    second_half_slice = (slice(None), slice(half_len, None))

    acc_out, acc_lse = None, None
    incoming_k, incoming_v = None, None

    def attend(q_blk, k_blk, v_blk, is_causal):
        # Look the kernel up on every call, exactly as often as it is used.
        kernel = select_flash_attn_impl(attn_type, stage="fwd-only")
        blk_out, blk_lse = kernel(
            q_blk,
            k_blk,
            v_blk,
            dropout_p,
            softmax_scale,
            causal=is_causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=True and dropout_p > 0,
        )
        return blk_out, blk_lse

    for step in range(ring.world_size):
        more_blocks = step != final_step

        # Kick off the exchange of the next k/v block before computing.
        if more_blocks:
            incoming_k = ring.send_recv(k)
            incoming_v = ring.send_recv(v)
            ring.commit()

        if step == 0:
            blk_out, blk_lse = attend(q, k, v, True)
            merge_kwargs = {}
        elif step <= ring.rank:
            blk_out, blk_lse = attend(q, k[:, :half_len], v[:, :half_len], False)
            merge_kwargs = {}
        else:
            blk_out, blk_lse = attend(q_second_half, k, v, False)
            # Only the second half of the running output receives this block.
            merge_kwargs = {"slice_": second_half_slice}

        acc_out, acc_lse = update_out_and_lse(
            acc_out, acc_lse, blk_out, blk_lse, **merge_kwargs
        )

        # The received block becomes the current one for the next step.
        if more_blocks:
            ring.wait()
            k, v = incoming_k, incoming_v

    return acc_out.to(q.dtype), acc_lse.squeeze(dim=-1).transpose(1, 2)

# @fully_fused_ring_flash_attn_forward.register_fake
# def _(
#     q: torch.Tensor,
#     k: torch.Tensor,
#     v: torch.Tensor,
#     softmax_scale: float,
#     dropout_p: float = 0,
#     causal: bool = True,
#     softcap: float = 0.0,
#     deterministic: bool=False,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
#     bs, sl, nh, d = q.shape
#     out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
#     lse = torch.empty([bs, nh, sl], dtype=q.dtype, device=q.device)
#     return out, lse

# @torch.no_grad()
def fully_fused_ring_flash_attn_backward(
    process_group,
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = AttnType.FA,
):
    """Causal ring attention backward pass.

    Two ``RingComm`` channels are opened on ``process_group``: one rotates the
    k/v blocks, the other rotates the running dk/dv accumulators so that each
    accumulator travels together with the k/v block it belongs to.

    Per ring step the attention backward kernel is run on:

    * step 0: the full local q/k/v with the causal mask;
    * ``step <= rank``: the full q against the first half (dim 1) of k/v;
    * otherwise: the second half (dim 1) of q/dout/out (and dim 2 of
      ``softmax_lse``) against the full k/v.

    Args:
        process_group: Group handed to ``RingComm``.
        dout, q, k, v, out: Tensors split in half along dim 1 (presumably the
            sequence dimension - TODO confirm against the caller).
        softmax_lse: Log-sum-exp from the forward pass; split along dim 2.
        softmax_scale, dropout_p, window_size, softcap, alibi_slopes,
        deterministic: Forwarded unchanged to the backward kernel.
        causal: Must be True.
        attn_type: Selects the kernel via ``select_flash_attn_impl``.

    Returns:
        ``(dq, dk, dv)``. With a ring of size 1 the raw kernel output buffers
        are returned; otherwise the float32 accumulators are cast to the dtype
        of ``q`` and detached.
    """
    assert causal == True, "upipe ring is meaningless for causal=False"
    kv_comm = RingComm(process_group)    # rotates k / v blocks
    d_kv_comm = RingComm(process_group)  # rotates dk / dv accumulators
    dq, dk, dv = None, None, None
    next_dk, next_dv = None, None
    next_k, next_v = None, None
    dk_comm_buffer, dv_comm_buffer = None, None

    # Second halves (along dim 1; dim 2 for softmax_lse) used when only the
    # second half of the local queries takes part in a step.
    dout1 = dout.chunk(2, dim=1)[1]
    q1 = q.chunk(2, dim=1)[1]
    out1 = out.chunk(2, dim=1)[1]
    softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous()
    block_seq_len = q.shape[1] // 2

    # Allocate the kernel output buffers once and reuse them on every step;
    # repeatedly allocating them may be slow.
    # NOTE(review): a variant with a different number of k/v heads than q heads
    # was sketched here previously and is disabled; buffers mirror q/k/v shapes.
    dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
    dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)

    def backward(dout, q, k, v, out, softmax_lse, causal):
        # Run the backward kernel; results land in the leading slices of the
        # shared buffers, sized to the (possibly halved) inputs of this call.
        seqlen_q = q.shape[1]
        seqlen_kv = k.shape[1]
        fn = select_flash_attn_impl(attn_type, stage="bwd-only")

        fn(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq_buffer[:, :seqlen_q],
            dk_buffer[:, :seqlen_kv],
            dv_buffer[:, :seqlen_kv],
            dropout_p,
            softmax_scale,
            causal,
            window_size,
            softcap,
            alibi_slopes,
            deterministic,
            rng_state=None,
        )

    for step in range(kv_comm.world_size):
        # Start exchanging the next k/v block before computing on the current one.
        if step + 1 != kv_comm.world_size:
            next_k = kv_comm.send_recv(k)
            next_v = kv_comm.send_recv(v)
            kv_comm.commit()

        if step == 0:
            backward(dout.contiguous(), q.contiguous(), k.contiguous(), v.contiguous(), out.contiguous(), softmax_lse.contiguous(), causal=True)
            if kv_comm.world_size == 1:
                # No ring: the kernel outputs are the final gradients.
                return dq_buffer, dk_buffer, dv_buffer
            # Accumulate across steps in float32.
            dq = dq_buffer.to(torch.float32)
            dk = dk_buffer.to(torch.float32)
            dv = dv_buffer.to(torch.float32)
        else:
            if step <= kv_comm.rank:
                k0 = k[:, :block_seq_len]
                v0 = v[:, :block_seq_len]
                backward(dout, q, k0, v0, out, softmax_lse, causal=False)
                dq += dq_buffer
            else:
                backward(dout1, q1, k, v, out1, softmax_lse1, causal=False)
                # always use the first half in dq_buffer.
                dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len]

            # Wait for the dk/dv accumulators sent at the end of the previous
            # step, then swap: the tensors just sent become the receive buffers
            # for the next exchange and the received ones become current.
            d_kv_comm.wait()
            dk_comm_buffer, dv_comm_buffer = dk, dv
            dk, dv = next_dk, next_dv

            # Add this rank's contribution for the k/v block processed above.
            if step <= kv_comm.rank:
                dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len]
                dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len]
            else:
                dk += dk_buffer
                dv += dv_buffer

        if step + 1 != kv_comm.world_size:
            kv_comm.wait()
            k = next_k
            v = next_v

        # Pass the dk/dv accumulators along the ring on every step, including
        # the last one.
        next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer)
        next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer)
        d_kv_comm.commit()

    # The final exchange delivers the accumulators returned below.
    d_kv_comm.wait()

    # Remember the dtype before the inputs are released.
    orig_q_dtype = q.dtype

    # Drop local references to inputs and scratch buffers before the casts.
    del dout, q, k, v, out, softmax_lse, dq_buffer, dk_buffer, dv_buffer, dk_comm_buffer, dv_comm_buffer, next_k, next_v

    return dq.to(orig_q_dtype).detach(), next_dk.to(orig_q_dtype).detach(), next_dv.to(orig_q_dtype).detach()

from torch import Tensor
# from yunchang.ring.upipe_ring_flash_attn import upipe_ring_flash_attn_forward, upipe_ring_flash_attn_backward 

class FullyPipelinedAttnFunc(torch.autograd.Function):
    
    @staticmethod
    def forward(
        ctx,
        q0: Tensor,
        q1: Tensor,
        q2: Tensor,
        q3: Tensor,
        k0: Tensor,
        k1: Tensor,
        k2: Tensor,
        k3: Tensor,
        v0: Tensor,
        v1: Tensor,
        v2: Tensor,
        v3: Tensor,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        two_streams,
        attn_type,
    ) -> Tuple[Tensor, Tensor, None] | Tensor:  
        """Pipelined Ulysses all-to-all + ring attention forward with CPU offload.

        q/k/v arrive pre-split into four chunks each (``q0..q3`` etc.); every
        chunk is unpacked as ``(bs, shard_seqlen, hc, hs)``. Chunks are processed
        one "stage" at a time. While the attention of stage ``s`` runs, the
        all-to-all of stage ``s+1`` inputs and of stage ``s-1`` outputs is in
        flight, stage ``s+2`` inputs are copied to the compute device on
        ``fetch_stream`` and already-consumed inputs are copied to CPU on
        ``offload_stream``.

        Side effects:
            * sets ``PROCESS_GROUP.ULYSSES_PG`` to ``ulysses_group``;
            * sets the module attributes ``process_group``, ``attn_type``,
              ``alibi_slopes`` and ``window_size`` read by the ring kernel;
            * saves the (offloaded) q/k/v chunks, a CPU copy of the output and
              the per-stage ``softmax_lse`` on ``ctx`` for backward.

        Returns:
            ``output`` with all stages written along dim 2, or
            ``(output, softmax_lse, None)`` when ``return_softmax`` is truthy.

        NOTE(review):
            * ``fullpipe_ring_flash_attn_forward`` is not defined in the part of
              this file visible at review time (the visible function is named
              ``fully_fused_ring_flash_attn_forward``) - confirm it resolves.
            * Several ``wait_stream`` calls use ``fetch_stream`` /
              ``offload_stream`` unconditionally, while other places guard with
              ``is not None``; ``out_cpu`` is only bound when ``offload_stream``
              is not None. Passing ``None`` streams therefore looks unsupported.
            * ``pipe_degree`` is computed, but the stage loop indexes the fixed
              four-element lists and backward slices the saved tensors by
              ``pipe_degree``; this appears to assume ``pipe_degree == 4`` -
              confirm.
            * ``deterministic`` is stored on ``ctx`` but the kernel call below
              passes ``deterministic=False``; ``two_streams`` is not used here.
            * ``ctx.softmax_scale`` is stored, but the visible part of backward
              recomputes the scale from the head dimension - confirm intent.
        """
        q = [q0, q1, q2, q3]
        k = [k0, k1, k2, k3]
        v = [v0, v1, v2, v3]

        # Make sure any transfers still pending on the side streams are done
        # before the compute stream touches the inputs.
        torch.cuda.current_stream().wait_stream(fetch_stream)
        torch.cuda.current_stream().wait_stream(offload_stream)
        

        bs, shard_seqlen, hc, hs = q[0].shape
        world_size = dist.get_world_size(ulysses_group)
        # Number of pipeline stages: total heads over Ulysses group size.
        pipe_degree = hc*len(q) // world_size

        # Recorded before any transfer below changes where the chunks live.
        orig_q_devices = [t.device for t in q]

        # Final output buffer; each stage's result is written into its slice of dim 2.
        output = torch.zeros([bs, shard_seqlen, hc*len(q), hs], device=q[0].device, dtype=q[0].dtype)
        softmax_lse = None

        assert k[0].shape[2] == q[0].shape[2], f"Pipe: num heads in key {k[0].shape[2]} must be equal to query {q[0].shape[2]}"

        # Compute device: wherever the first query chunk lives.
        orig_device = q[0].device

        # Prefetch the second chunk onto the compute device on the fetch stream.
        if fetch_stream is not None:
            fetch_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(fetch_stream):
                q[1] = q[1].to(orig_device, non_blocking=True)
                k[1] = k[1].to(orig_device, non_blocking=True)
                v[1] = v[1].to(orig_device, non_blocking=True)

        PROCESS_GROUP.ULYSSES_PG = ulysses_group
        
        # First stage: all-to-all with the last flag False.
        # NOTE(review): the last positional flag looks like an "async" switch
        # (later calls pass True and then wait on U_HANDLE) - confirm against
        # vanilla_all_to_all_4D.
        ulysses_q = all_to_all_4D(
            q[0], 2, 1, False, False # scatter 2, gather 1
        )

        ulysses_k = all_to_all_4D(
            k[0], 2, 1, False, False # scatter 2, gather 1
        )
        ulysses_v = all_to_all_4D(
            v[0], 2, 1, False, False # scatter 2, gather 1
        )

        # The first chunk has been consumed by the all-to-all; move it to CPU.
        if offload_stream is not None:
            offload_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(offload_stream):
                q[0] = q[0].to("cpu", non_blocking=True)
                k[0] = k[0].to("cpu", non_blocking=True)
                v[0] = v[0].to("cpu", non_blocking=True)
        
        # Initialize variables to avoid undefined variable errors
        block_output = None
        o_bs = o_shard_seqlen = o_hc = o_hs = None

        # Publish the ring configuration as module attributes; the ring forward
        # kernel reads them as globals instead of taking them as arguments.
        global process_group
        import sys
        current_module = sys.modules[__name__]
        current_module.process_group = ring_group
        current_module.attn_type = attn_type
        current_module.alibi_slopes = alibi_slopes
        current_module.window_size = window_size

        torch.cuda.current_stream().wait_stream(offload_stream)
        torch.cuda.current_stream().wait_stream(fetch_stream)

        for stage in range(pipe_degree):
            # NOTE(review): there is no current-stream wait on fetch_stream
            # inside this loop (a previous one is disabled), so the prefetch of
            # chunk stage+2 issued below is not explicitly ordered before its use
            # in the next iteration - confirm this is safe.
            
            if stage+1 != pipe_degree:
                # Prefetch the chunk two stages ahead.
                if stage+2 < pipe_degree:
                    fetch_stream.wait_stream(offload_stream)
                    with torch.cuda.stream(fetch_stream):
                        q[stage+2] = q[stage+2].to(orig_device, non_blocking=True)
                        k[stage+2] = k[stage+2].to(orig_device, non_blocking=True)
                        v[stage+2] = v[stage+2].to(orig_device, non_blocking=True)
                
                assert q[stage+1].device == orig_device, f"Pipe: q[stage+1] device {q[stage+1].device} must be {orig_device} at stage {stage}"
                assert k[stage+1].device == orig_device, f"Pipe: k[stage+1] device {k[stage+1].device} must be {orig_device} at stage {stage}"
                assert v[stage+1].device == orig_device, f"Pipe: v[stage+1] device {v[stage+1].device} must be {orig_device} at stage {stage}"

                # Launch the all-to-all for the next stage; it is waited on
                # (through U_HANDLE.HANDLE) after this stage's attention.
                next_ulysses_q = all_to_all_4D(
                    q[stage+1], 2, 1, False, True # scatter 2, gather 1
                )
                next_ulysses_k = all_to_all_4D(
                    k[stage+1], 2, 1, False, True # scatter 2, gather 1
                )
                next_ulysses_v = all_to_all_4D(
                    v[stage+1], 2, 1, False, True # scatter 2, gather 1
                )
                # Shapes used to reinterpret the raw all-to-all result below.
                bs, shard_seqlen, hc, hs = q[stage+1].shape
                seqlen = shard_seqlen * world_size
                shard_hc = hc // world_size

            inp_q = ulysses_q
            inp_k = ulysses_k
            inp_v = ulysses_v

            # Default scale: 1/sqrt(last dim of q); computed once, then reused.
            if softmax_scale is None:
                softmax_scale = inp_q.shape[-1] ** (-0.5)

            assert alibi_slopes is None
            inp_k = inp_k.contiguous()
            inp_v = inp_v.contiguous()

            assert inp_q.device == orig_device, f"Pipe: inp_q device {inp_q.device} must be {orig_device} at stage {stage}"
            assert inp_k.device == orig_device, f"Pipe: inp_k device {inp_k.device} must be {orig_device} at stage {stage}"
            assert inp_v.device == orig_device, f"Pipe: inp_v device {inp_v.device} must be {orig_device} at stage {stage}"

            # Ring attention for the current stage.
            out, lse = fullpipe_ring_flash_attn_forward(
                inp_q,
                inp_k,
                inp_v,
                softmax_scale=softmax_scale,
                dropout_p=dropout_p,
                causal=causal,
                softcap=softcap,
                deterministic=False,
            )

            # Allocate the per-stage LSE store lazily (dim 1 repeated
            # pipe_degree times) and fill this stage's slot.
            if softmax_lse is None:
                softmax_lse = torch.zeros_like(lse).repeat(1, pipe_degree, 1)
            softmax_lse[:, stage, :] = lse

            context_layer = out
            
            o_bs, o_seqlen, o_shard_hc, o_hs = context_layer.shape
            o_hc = o_shard_hc * world_size
            o_shard_seqlen = o_seqlen // world_size
            
            if stage > 0:
                # Finish the previous stage's output all-to-all and store it.
                if U_HANDLE.O_HANDLE!=[]:
                    assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
                    U_HANDLE.O_HANDLE[0].wait()
                clear_o_handle()
                assert block_output.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_output shape {block_output.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
                # Reinterpret the raw result as (o_hc, o_shard_seqlen, o_bs, o_hs)
                # and swap dims 0 and 2 to get (o_bs, o_shard_seqlen, o_hc, o_hs).
                block_output = block_output.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
                block_output = block_output.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
                output[:,:,(stage-1)*world_size:(stage)*world_size,:] = block_output

            # Launch the output all-to-all for this stage; consumed in the next
            # iteration, or after the loop for the last stage.
            block_output = all_to_all_4D(
                context_layer, 1, 2, False, True # scatter 1, gather 2
            )
            
            if stage+1 != pipe_degree:
                # Wait for the three input all-to-alls (q, k, v) of the next stage.
                if U_HANDLE.HANDLE!=[]:
                    assert len(U_HANDLE.HANDLE) == 3, f"Pipe: U_HANDLE.HANDLE length {len(U_HANDLE.HANDLE)} must be 3"
                    U_HANDLE.HANDLE[0].wait()
                    U_HANDLE.HANDLE[1].wait()
                    U_HANDLE.HANDLE[2].wait()
                clear_u_handle()

                # The next chunk has been consumed by its all-to-all; offload it.
                if offload_stream is not None:
                    offload_stream.wait_stream(torch.cuda.current_stream())
                    with torch.cuda.stream(offload_stream):
                        q[stage+1] = q[stage+1].to("cpu", non_blocking=True)
                        k[stage+1] = k[stage+1].to("cpu", non_blocking=True)
                        v[stage+1] = v[stage+1].to("cpu", non_blocking=True)

                # Reinterpret each raw result as (seqlen, bs, shard_hc, hs) and
                # swap dims 0 and 1 to get (bs, seqlen, shard_hc, hs).
                assert next_ulysses_q.numel()== (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_q shape {next_ulysses_q.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                next_ulysses_q = next_ulysses_q.reshape(seqlen, bs, shard_hc, hs)
                next_ulysses_q = next_ulysses_q.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                ulysses_q = next_ulysses_q

                assert next_ulysses_k.numel()== (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_k shape {next_ulysses_k.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                next_ulysses_k = next_ulysses_k.reshape(seqlen, bs, shard_hc, hs)
                next_ulysses_k = next_ulysses_k.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                ulysses_k = next_ulysses_k

                assert next_ulysses_v.numel()== (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_v shape {next_ulysses_v.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                next_ulysses_v = next_ulysses_v.reshape(seqlen, bs, shard_hc, hs)
                next_ulysses_v = next_ulysses_v.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                ulysses_v = next_ulysses_v

        # Finish the last stage's output all-to-all.
        # NOTE(review): unlike inside the loop, the reshape and the write into
        # `output` only happen when a handle is pending; if no handle was
        # registered the last slice of `output` would stay zero - confirm.
        if U_HANDLE.O_HANDLE!=[]:
            assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
            U_HANDLE.O_HANDLE[0].wait()
            clear_o_handle()
            assert block_output.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_output shape {block_output.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
            block_output = block_output.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
            block_output = block_output.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
            output[:,:, -world_size:,:] = block_output

        # Keep a CPU copy of the output for backward.
        # NOTE(review): `out_cpu` is unbound below if offload_stream is None.
        if offload_stream is not None:
            offload_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(offload_stream):
                out_cpu = output.to("cpu", non_blocking=True)

        # Saved order: all q chunks, all k chunks, all v chunks, CPU output,
        # softmax_lse. Backward indexes the saved tuple by position.
        to_save_tensors = []
        to_save_tensors.extend(q)
        to_save_tensors.extend(k)
        to_save_tensors.extend(v)
        to_save_tensors.append(out_cpu)
        to_save_tensors.append(softmax_lse)

        ctx.save_for_backward(*to_save_tensors)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.ring_group = ring_group
        ctx.ulysses_group = ulysses_group
        ctx.attn_type = attn_type
        ctx.pipe_degree = pipe_degree
        ctx.offload_stream = offload_stream
        ctx.fetch_stream = fetch_stream
        ctx.orig_q_devices = orig_q_devices
        
        assert output is not None, "Output is None!!!"
        output.requires_grad = True
        return output if not return_softmax else (output, softmax_lse, None)
        
    
    @staticmethod
    def backward(ctx, dout, *args) -> Tuple[List[Tensor], List[Tensor], List[Tensor], None, None, None, None, None, None, None, None, None, None, None, None]:

        saved_tensors = ctx.saved_tensors
        q = saved_tensors[:ctx.pipe_degree]
        k = saved_tensors[ctx.pipe_degree:2*ctx.pipe_degree]
        v = saved_tensors[2*ctx.pipe_degree:3*ctx.pipe_degree]
        out = saved_tensors[-2]
        softmax_lse = saved_tensors[-1]

        orig_q_devices = ctx.orig_q_devices

        pipe_degree = ctx.pipe_degree
        offload_stream = ctx.offload_stream
        fetch_stream = ctx.fetch_stream
        orig_device = dout.device
        world_size = dist.get_world_size(ctx.ulysses_group)

        torch.cuda.current_stream().wait_stream(fetch_stream)
        torch.cuda.current_stream().wait_stream(offload_stream)
        

        # dqkv = [torch.zeros([3*q[i].shape[0], q[i].shape[1], q[i].shape[2], q[i].shape[3]], dtype=q[i].dtype, device=q[i].device) for i in range(len(q))]
        dq = [None] * len(q)
        dk = [None] * len(q)
        dv = [None] * len(q)

        # dout = dout.to("cpu")
        # out = out.to("cpu") this is already in cpu
        
        dout = [t.to(q[i].device) for i, t in enumerate(torch.chunk(dout, pipe_degree, dim = 2))]
        out = [t.to(q[i].device) for i, t in enumerate(torch.chunk(out, pipe_degree, dim = 2))]

        

        orig_q_dtype = q[0].dtype

        doutqkvo_lse = [torch.cat([q[i], k[i], v[i], dout[i], out[i]], dim=0) for i in range(len(q))] # 5 tensors in total
        eff_bs = 5  # Fixed: This should be 5 since we concatenate 5 tensors (q, k, v, dout, out)
        
        del dout, out, q, k, v
        # doutqkvo_lse = torch.cat([doutqkvo, softmax_lse.unsqueeze(-1).repeat(eff_bs, 1, 1, 1).to(dout.dtype)], dim=3).contiguous() #piggybacking softmax_lse as the last hidden dimension for all 5 other tensors, so be mindful when unpacking
        # doutqkvo_lse  = [t for t in torch.chunk(doutqkvo_lse, pipe_degree, dim = 2)]

        softmax_lse = [t for t in torch.chunk(softmax_lse, pipe_degree, dim = 1)]

        doutqkvo_lse[0] = doutqkvo_lse[0].to(orig_device)
        with torch.cuda.stream(fetch_stream):
            doutqkvo_lse[1] = doutqkvo_lse[1].to(orig_device, non_blocking=True)
        
        PROCESS_GROUP.ULYSSES_PG = ctx.ulysses_group
        ulysses_doutqkvo_lse = all_to_all_4D(doutqkvo_lse[0], 2, 1, False, False) #note that this is scatter 2, gather 1 since this is inverse of forward pass
        
        torch.cuda.current_stream().wait_stream(fetch_stream)
        torch.cuda.current_stream().wait_stream(offload_stream)
        
        for stage in range(pipe_degree):
            # torch.cuda.current_stream().wait_stream(fetch_stream)
            # if dist.get_rank(ctx.ulysses_group) == 0:
            #     print(f"Tensor shape: {ulysses_doutqkvo_lse.shape}\t stage: {stage}")
            if stage+1 != pipe_degree:
                if stage+2 < pipe_degree:
                    fetch_stream.wait_stream(offload_stream)
                    with torch.cuda.stream(fetch_stream):
                        doutqkvo_lse[stage+2] = doutqkvo_lse[stage+2].to(orig_device, non_blocking=True)
                        # softmax_lse[stage+1].to(orig_device, non_blocking=True)
                next_ulysses_doutqkvo_lse = all_to_all_4D(doutqkvo_lse[stage+1], 2, 1, False, True) #note that this is scatter 2, gather 1 since this is inverse of forward pass
                bs, shard_seqlen, hc, hs = doutqkvo_lse[stage+1].shape
                seqlen = shard_seqlen * world_size
                shard_hc = hc // world_size

            # Chunk the tensor to separate q, k, v, dout, out components (5 components total)
            ulysses_doutqkvo_lse = torch.chunk(ulysses_doutqkvo_lse, eff_bs, dim=0)
            assert len(ulysses_doutqkvo_lse) == eff_bs, f"Pipe: ulysses_tensors length {len(ulysses_doutqkvo_lse)} must be equal to eff_bs {eff_bs}, tensor shape {ulysses_doutqkvo_lse[0].shape}"
            
            # Extract tensors and ensure they are contiguous to avoid CUDA misalignment errors
            inp_q = ulysses_doutqkvo_lse[0]#[:,:,:,:].contiguous().to(orig_q_dtype)
            inp_k = ulysses_doutqkvo_lse[1]#[:,:,:,:].contiguous().to(orig_q_dtype)
            inp_v = ulysses_doutqkvo_lse[2]#[:,:,:,:].contiguous().to(orig_q_dtype)
            inp_dout = ulysses_doutqkvo_lse[3]#[:,:,:,:].contiguous()
            inp_out = ulysses_doutqkvo_lse[4]#[:,:,:,:].contiguous().to(orig_q_dtype)
            
            # Extract softmax_lse and ensure contiguous after transpose
            # inp_softmax_lse = ulysses_doutqkvo_lse[-1][:,:,:,-1].contiguous().to(orig_q_dtype) # all 5 tensors have the same softmax_lse
            # inp_softmax_lse_transposed = inp_softmax_lse.transpose(1, 2).contiguous()
            inp_softmax_lse_transposed = softmax_lse[stage]#.transpose(1, 2).contiguous()

            softmax_scale = inp_q.shape[-1] ** (-0.5)
            
            # Verify tensor contiguity for debugging
            # assert inp_q.is_contiguous(), f"inp_q is not contiguous: {inp_q.stride()}"
            # assert inp_k.is_contiguous(), f"inp_k is not contiguous: {inp_k.stride()}"
            # assert inp_v.is_contiguous(), f"inp_v is not contiguous: {inp_v.stride()}"
            # assert inp_dout.is_contiguous(), f"inp_dout is not contiguous: {inp_dout.stride()}"
            # assert inp_out.is_contiguous(), f"inp_out is not contiguous: {inp_out.stride()}"
            # assert inp_softmax_lse_transposed.is_contiguous(), f"inp_softmax_lse_transposed is not contiguous: {inp_softmax_lse_transposed.stride()}"

            # assert not check_nan_inf(inp_q, f"inp_q", dist.get_rank()), f"Pipe: inp_q is nan or inf at stage {stage}"
            # assert not check_nan_inf(inp_k, f"inp_k", dist.get_rank()), f"Pipe: inp_k is nan or inf at stage {stage}"
            # assert not check_nan_inf(inp_v, f"inp_v", dist.get_rank()), f"Pipe: inp_v is nan or inf at stage {stage}"
            # assert not check_nan_inf(inp_dout, f"inp_dout", dist.get_rank()), f"Pipe: inp_dout is nan or inf at stage {stage}"
            # assert not check_nan_inf(inp_out, f"inp_out", dist.get_rank()), f"Pipe: inp_out is nan or inf at stage {stage}"
            # assert not check_nan_inf(inp_softmax_lse_transposed, f"inp_softmax_lse_transposed", dist.get_rank()), f"Pipe: inp_softmax_lse_transposed is nan or inf at stage {stage}"

            attn_dq, attn_dk, attn_dv = fullpipe_ring_flash_attn_backward(
                ctx.ring_group,
                inp_dout,
                inp_q,
                inp_k,
                inp_v,
                inp_out,
                inp_softmax_lse_transposed,
                softmax_scale=softmax_scale,
                dropout_p=ctx.dropout_p,
                causal=ctx.causal,
                window_size=ctx.window_size,
                softcap=ctx.softcap,
                alibi_slopes=ctx.alibi_slopes,
                deterministic=ctx.deterministic,
                attn_type=ctx.attn_type,
            )
            
            # not needed anymore
            doutqkvo_lse[stage] = None
            softmax_lse[stage] = None

            # assert not check_nan_inf(attn_dq, f"attn_dq -{stage}", dist.get_rank()), f"Pipe: attn_dq is nan or inf at stage {stage}"
            # assert not check_nan_inf(attn_dk, f"attn_dk -{stage}", dist.get_rank()), f"Pipe: attn_dk is nan or inf at stage {stage}"
            # assert not check_nan_inf(attn_dv, f"attn_dv -{stage}", dist.get_rank()), f"Pipe: attn_dv is nan or inf at stage {stage}"

            if stage > 0:
                if U_HANDLE.O_HANDLE!=[]:
                    assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
                    U_HANDLE.O_HANDLE[0].wait()
                clear_o_handle()
                assert block_grads.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_grads shape {block_grads.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
                block_grads = block_grads.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
                block_grads = block_grads.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
                # dqkv[stage-1] = block_grads
                grads = torch.chunk(block_grads, 3, dim=0)
                dq[stage-1] = grads[0]
                dk[stage-1] = grads[1]
                dv[stage-1] = grads[2]
                if offload_stream is not None:
                    offload_stream.wait_stream(torch.cuda.current_stream())
                    with torch.cuda.stream(offload_stream):
                        dq[stage-1] = dq[stage-1].to(orig_q_devices[stage-1], non_blocking=True)
                        dk[stage-1] = dk[stage-1].to(orig_q_devices[stage-1], non_blocking=True)
                        dv[stage-1] = dv[stage-1].to(orig_q_devices[stage-1], non_blocking=True)
                #block_grads = torch.chunk(block_grads, 3, dim=0)
                # dq[:,:,(stage-1)*world_size:(stage)*world_size,:] = block_grads[0]
                # dk[:,:,(stage-1)*world_size:(stage)*world_size,:] = block_grads[1]
                # dv[:,:,(stage-1)*world_size:(stage)*world_size,:] = block_grads[2]
                
            
            attn_grads = torch.cat([attn_dq, attn_dk, attn_dv], dim=0)
            o_bs, o_seqlen, o_shard_hc, o_hs = attn_grads.shape
            o_hc = o_shard_hc * world_size
            o_shard_seqlen = o_seqlen // world_size

            block_grads = all_to_all_4D(
                attn_grads, 1, 2, False, True # scatter 1, gather 2
            )

            if stage+1 != pipe_degree:
                if U_HANDLE.HANDLE!=[]:
                    assert len(U_HANDLE.HANDLE) == 1, f"Pipe: U_HANDLE.HANDLE length {len(U_HANDLE.HANDLE)} must be 1"
                    U_HANDLE.HANDLE[0].wait()
                clear_u_handle()
                if offload_stream is not None:
                    offload_stream.wait_stream(torch.cuda.current_stream())
                    with torch.cuda.stream(offload_stream):
                        doutqkvo_lse[stage+1].to("cpu", non_blocking=True)
                        # softmax_lse[stage].to("cpu", non_blocking=True)
                assert next_ulysses_doutqkvo_lse.numel() ==  (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_doutqkvo_lse shape {next_ulysses_doutqkvo_lse.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                next_ulysses_doutqkvo_lse = next_ulysses_doutqkvo_lse.reshape(seqlen, bs, shard_hc, hs)
                next_ulysses_doutqkvo_lse = next_ulysses_doutqkvo_lse.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                ulysses_doutqkvo_lse = next_ulysses_doutqkvo_lse
        
        if U_HANDLE.O_HANDLE!=[]:
            assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
            U_HANDLE.O_HANDLE[0].wait()
            clear_o_handle()
            assert block_grads.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_grads shape {block_grads.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
            block_grads = block_grads.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
            block_grads = block_grads.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
            # dqkv[-1] = block_grads
            grads = torch.chunk(block_grads, 3, dim=0)
            dq[-1] = grads[0]
            dk[-1] = grads[1]
            dv[-1] = grads[2]

            dq[-1] = dq[-1].to(orig_q_devices[-1])
            dk[-1] = dk[-1].to(orig_q_devices[-1])
            dv[-1] = dv[-1].to(orig_q_devices[-1])
            # assert dqkv.shape[0] == 3, f"Pipe: dqkv shape {dqkv.shape} must be 3 in the bs dimension"
            # dqkv = torch.chunk(dqkv, 3, dim=0)
            # block_grads = torch.chunk(block_grads, 3, dim=0)
            # dq[:,:,-(world_size):,:] = block_grads[0]
            # dk[:,:,-(world_size):,:] = block_grads[1]
            # dv[:,:,-(world_size):,:] = block_grads[2]

        # assert not check_nan_inf(dqkv[0], f"dqkv[0]", dist.get_rank()), f"Pipe: dqkv[0] is nan or inf at stage"
        # assert not check_nan_inf(dqkv[1], f"dqkv[1]", dist.get_rank()), f"Pipe: dqkv[1] is nan or inf at stage"
        # assert not check_nan_inf(dqkv[2], f"dqkv[2]", dist.get_rank()), f"Pipe: dqkv[2] is nan or inf at stage"
        # assert len(dqkv) == 3, f"Pipe: dqkv length {len(dqkv)} must be 3"
        return dq[0], dq[1], dq[2], dq[3], dk[0], dk[1], dk[2], dk[3], dv[0], dv[1], dv[2], dv[3], None, None, None, None, None, None, None, None, None, None, None, None, None, None


class FullyPipelinedRingFunc(torch.autograd.Function):
    """Autograd wrapper around the full-pipe ring flash-attention kernels.

    ``forward`` publishes the ring configuration as attributes of this module,
    runs ``fullpipe_ring_flash_attn_forward`` and records on ``ctx`` what
    ``backward`` needs; ``backward`` delegates to
    ``fullpipe_ring_flash_attn_backward``.
    """

    @staticmethod
    def forward(
        ctx,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        two_streams,
        attn_type,
    ) -> Tuple[Tensor, Tensor, None] | Tensor:  
        """Run the ring attention forward pass and save state for backward.

        Args:
            q, k, v: Attention inputs, forwarded unchanged to the kernel.
            dropout_p, causal, softcap: Forwarded to the forward kernel.
            softmax_scale: Scale factor; when ``None`` it defaults to
                ``q.shape[-1] ** -0.5``.
            window_size, alibi_slopes, attn_type, ring_group: Published as
                module attributes and stored on ``ctx`` for backward.
            deterministic: Stored on ``ctx`` for backward only; the forward
                kernel is always called with ``deterministic=False``.
            return_softmax: When truthy, return ``(out, lse, None)``.
            ulysses_group, offload_stream, fetch_stream, two_streams: Accepted
                but not used by this method.

        Returns:
            ``out`` alone, or ``(out, lse, None)`` when ``return_softmax``.
        """
        import sys

        # Publish the configuration on the module object. Presumably the
        # forward kernel reads these as globals, since it is not handed them
        # as arguments -- TODO confirm against its definition.
        this_module = sys.modules[__name__]
        for attr_name, attr_value in (
            ("process_group", ring_group),
            ("attn_type", attn_type),
            ("alibi_slopes", alibi_slopes),
            ("window_size", window_size),
        ):
            setattr(this_module, attr_name, attr_value)

        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale

        out, lse = fullpipe_ring_flash_attn_forward(
            q,
            k,
            v,
            softmax_scale=scale,
            dropout_p=dropout_p,
            causal=causal,
            softcap=softcap,
            deterministic=False,
        )

        # Tensors go through save_for_backward; plain settings live on ctx.
        ctx.save_for_backward(q, k, v, out, lse)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.ring_group = ring_group
        ctx.attn_type = attn_type

        if return_softmax:
            return (out, lse, None)
        return out

    @staticmethod
    def backward(ctx, dout, *args) -> Tuple[List[Tensor], List[Tensor], List[Tensor], None, None, None, None, None, None, None, None, None, None, None, None]:
        """Compute gradients for ``q``, ``k`` and ``v``.

        Args:
            dout: Gradient with respect to the attention output.
            *args: Gradients for any extra forward outputs; ignored.

        Returns:
            A 17-tuple: ``(dq, dk, dv)`` followed by 14 ``None`` entries, one
            for each remaining ``forward`` input.
            NOTE(review): the return annotation lists only 15 entries and
            types the gradients as lists; it does not match the actual return.
        """
        q, k, v, out, lse = ctx.saved_tensors

        attn_dq, attn_dk, attn_dv = fullpipe_ring_flash_attn_backward(
            ctx.ring_group,
            dout,
            q,
            k,
            v,
            out,
            lse,
            softmax_scale=ctx.softmax_scale,
            dropout_p=ctx.dropout_p,
            causal=ctx.causal,
            window_size=ctx.window_size,
            softcap=ctx.softcap,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
            attn_type=ctx.attn_type,
        )

        # Non-tensor forward inputs (and the unused handles) get no gradient.
        return (attn_dq, attn_dk, attn_dv) + (None,) * 14


# from torchtitan.models.llama3.model import TransformerModelArgs
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    View the frequency table so it broadcasts against ``x``.

    ``freqs_cis`` is truncated to its first ``x.shape[1]`` rows and must then
    have shape ``(x.shape[1], x.shape[-1])``. The result keeps those two
    extents at dimension 1 and at the last dimension of ``x`` and has size 1
    everywhere else.

    Args:
        freqs_cis (torch.Tensor): Frequency table, at least ``x.shape[1]`` rows.
        x (torch.Tensor): Tensor the result must broadcast with (ndim > 1).

    Returns:
        torch.Tensor: A view of the truncated ``freqs_cis`` with ``x.ndim``
        dimensions.
    """
    rank = x.ndim
    assert rank > 1
    seqlen = x.shape[1]
    freqs_cis = freqs_cis[0:seqlen]
    assert freqs_cis.shape == (seqlen, x.shape[-1]), f"freqs_cis.shape: {freqs_cis.shape} != (seqlen, x.shape[-1]): {(seqlen, x.shape[-1])}"

    # Size-1 everywhere except the sequence axis and the trailing axis.
    broadcast_shape = [1] * rank
    broadcast_shape[1] = seqlen
    broadcast_shape[-1] = x.shape[-1]
    return freqs_cis.view(*broadcast_shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor = None,
    freqs_cis: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    assert xq.device != torch.device("cpu"), "xq must be on GPU"
    assert freqs_cis.device != torch.device("cpu"), "freqs_cis must be on GPU"

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    if xk is not None:
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
        return xq_out.type_as(xq), xk_out.type_as(xk)
    else:
        return xq_out.type_as(xq)

@torch.library.custom_op("yunchang::_fully_fused_attn_forward", mutates_args=(), device_types="cuda")
def fully_fused_attn_forward(
        x: Tensor,
        wq: Tensor,
        wk: Tensor,
        wv: Tensor,
        freqs_cis: Tensor,
        head_dim: int,
        dropout_p: float = 0,
        softmax_scale: float = 0,
        causal: bool = True,
) -> list[Tensor]:
    """Fused QKV projection, rotary embedding, all-to-all and ring attention.

    Heads are processed in ``pipe_degree = n_heads // ulysses_degree`` stages,
    with the q/k/v weights chunked along dim 0 to match. For every stage the
    function projects ``x``, applies the rotary embedding to q and k, runs an
    all-to-all (scatter index 2, gather index 1, following the argument
    convention noted elsewhere in this file), calls
    ``fully_fused_ring_flash_attn_forward`` and writes the reverse all-to-all
    of the result into that stage's head slice of the output buffer.

    Work is issued round-robin on the streams in the module global
    ``two_streams``: the projection for stage ``s + 1`` is issued before the
    attention for stage ``s``, on the stream that will later run the attention
    for ``s + 1``. All side streams first wait on the current stream, and the
    current stream waits on all of them before returning.

    Reads the module globals ``two_streams`` and ``ulysses_group``.

    Args:
        x: Input of shape ``(bs, seqlen, hid_dim)``.
        wq, wk, wv: Projection weights, used as ``F.linear`` weights.
        freqs_cis: Rotary frequency tensor passed to ``apply_rotary_emb``.
        head_dim: Per-head feature size.
        dropout_p, causal: Forwarded to the ring attention kernel.
        softmax_scale: Forwarded as-is; no default is derived here.

    Returns:
        ``[final_out, lse_0, ..., lse_{pipe_degree - 1}]`` where ``final_out``
        has shape ``(bs, seqlen, n_heads, head_dim)``.
    """
    global two_streams, ulysses_group

    bs, seqlen, hid_dim = x.shape
    n_heads = hid_dim // head_dim

    ulysses_degree = dist.get_world_size(ulysses_group)
    pipe_degree = n_heads // ulysses_degree

    # Per-stage all-to-all results, released as soon as attention consumed them.
    q_out = [None] * pipe_degree
    k_out = [None] * pipe_degree
    v_out = [None] * pipe_degree

    final_out = torch.empty([bs, seqlen, n_heads, head_dim], device=x.device, dtype=x.dtype, memory_format=torch.contiguous_format)
    final_lse = []

    wq_chunks = torch.chunk(wq, pipe_degree, dim=0)
    wk_chunks = torch.chunk(wk, pipe_degree, dim=0)
    wv_chunks = torch.chunk(wv, pipe_degree, dim=0)

    def _project_and_scatter(idx):
        # Project one head group, rotate q/k, and redistribute across ranks.
        # Must be called inside the stream context the results will be used on.
        q = F.linear(x, wq_chunks[idx])
        k = F.linear(x, wk_chunks[idx])
        v = F.linear(x, wv_chunks[idx])

        q, k = apply_rotary_emb(q.view(bs, seqlen, -1, head_dim), k.view(bs, seqlen, -1, head_dim), freqs_cis=freqs_cis)
        v = v.view(bs, seqlen, -1, head_dim)

        q_out[idx] = all_to_all_4D(q, 2, 1, False, False)
        k_out[idx] = all_to_all_4D(k, 2, 1, False, False)
        v_out[idx] = all_to_all_4D(v, 2, 1, False, False)

    num_streams = len(two_streams)
    for side_stream in two_streams:
        # Inputs were produced on the current stream; side streams must see them.
        side_stream.wait_stream(torch.cuda.current_stream())

    for stage in range(pipe_degree):
        if stage == 0:
            with torch.cuda.stream(two_streams[0]):
                _project_and_scatter(0)

        if stage != pipe_degree - 1:
            # Issue the next stage's projection before this stage's attention
            # so the two can overlap when more than one stream is available.
            with torch.cuda.stream(two_streams[(stage + 1) % num_streams]):
                _project_and_scatter(stage + 1)

        with torch.cuda.stream(two_streams[stage % num_streams]):
            stage_out, stage_lse = fully_fused_ring_flash_attn_forward(
                    q_out[stage],
                    k_out[stage],
                    v_out[stage],
                    softmax_scale=softmax_scale,
                    dropout_p=dropout_p,
                    causal=causal,
                    deterministic=False,
            )

            final_lse.append(stage_lse)
            head_start = stage * ulysses_degree
            final_out[:, :, head_start:head_start + ulysses_degree, :] = all_to_all_4D(stage_out, 1, 2, False, False)

            # Drop references so the buffers can be reused by later stages.
            q_out[stage] = None
            k_out[stage] = None
            v_out[stage] = None
            del stage_out, stage_lse

    for side_stream in two_streams:
        torch.cuda.current_stream().wait_stream(side_stream)

    return [final_out, *final_lse]

@fully_fused_attn_forward.register_fake
def _(
        x: Tensor,
        wq: Tensor,
        wk: Tensor,
        wv: Tensor,
        freqs_cis: Tensor,
        head_dim: int,
        dropout_p: float = 0,
        softmax_scale: float = 0,
        causal: bool = True,
) -> Tuple[Tensor, Tensor]:
    """Fake (shape-only) implementation of ``fully_fused_attn_forward``.

    Allocates uninitialised tensors with the dtype and device of ``x``: an
    output of shape ``(bs, sl, n_heads, head_dim)`` and an lse of shape
    ``(bs, n_heads, sl)``, where ``n_heads = x.shape[-1] // head_dim``.

    NOTE(review): the real implementation returns a list holding the output
    followed by one lse tensor per pipeline stage, while this returns a fixed
    2-tuple with an lse in ``x.dtype`` -- confirm the two agree for the
    configurations that are traced.
    """
    batch, seq_len, hidden = x.shape
    n_heads = hidden // head_dim
    fake_out = torch.empty([batch, seq_len, n_heads, head_dim], dtype=x.dtype, device=x.device)
    fake_lse = torch.empty([batch, n_heads, seq_len], dtype=x.dtype, device=x.device)
    return fake_out, fake_lse

class FullyFusedAttnFunc(torch.autograd.Function):
    
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        wq: Tensor,
        wk: Tensor,
        wv: Tensor,
        freqs_cis: Tensor,
        head_dim: int,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        two_streams,
        attn_type,
        layer_id,
    ) -> Tuple[Tensor, Tensor, None] | Tensor | Dict[str, Any]: 
        """Run the fused projection + attention op and save state for backward.

        Args:
            x: Input of shape ``(bs, seqlen, hid_dim)``.
            wq, wk, wv: Projection weights handed to the fused op.
            freqs_cis: Rotary frequency tensor handed to the fused op.
            head_dim: Per-head feature size; ``hid_dim // head_dim`` heads.
            dropout_p, causal: Forwarded to the fused op and stored on ``ctx``.
            softmax_scale: Defaults to ``head_dim ** -0.5`` when ``None``.
            window_size, alibi_slopes, attn_type, ring_group, ulysses_group,
            two_streams: Published as module attributes (the fused op reads
                ``ulysses_group`` and ``two_streams`` from there) and stored
                on ``ctx``.
            softcap, deterministic, layer_id: Stored on ``ctx`` only; they are
                not passed to the fused forward op.
            return_softmax: When truthy, return ``(out, lse_list, None)``.
            offload_stream, fetch_stream: Accepted but unused by this method.

        Returns:
            The attention output, or ``(final_out, final_lse, None)`` when
            ``return_softmax`` is truthy, where ``final_lse`` is the list of
            per-stage lse tensors.

        Raises:
            ValueError: If the head count is zero or not a multiple of the
                Ulysses group size.
        """
        _, _, hid_dim = x.shape
        n_heads = hid_dim // head_dim

        ulysses_degree = dist.get_world_size(ulysses_group)
        # Each pipeline stage handles exactly `ulysses_degree` heads. With a
        # remainder, the trailing head slices of the (uninitialised) output
        # buffer would never be written, so fail loudly instead.
        if n_heads == 0 or n_heads % ulysses_degree != 0:
            raise ValueError(
                f"number of heads ({n_heads}) must be a positive multiple of "
                f"the ulysses group size ({ulysses_degree})"
            )
        pipe_degree = n_heads // ulysses_degree

        if softmax_scale is None:
            softmax_scale = head_dim ** (-0.5)

        # Publish the configuration on the module object; the fused op takes
        # only tensors/scalars and picks these up as module globals.
        import sys
        current_module = sys.modules[__name__]
        current_module.process_group = ring_group
        current_module.attn_type = attn_type
        current_module.alibi_slopes = alibi_slopes
        current_module.window_size = window_size
        current_module.ulysses_group = ulysses_group
        current_module.two_streams = two_streams

        with torch.no_grad():
            output = fully_fused_attn_forward(
                x,
                wq,
                wk,
                wv,
                freqs_cis,
                head_dim,
                dropout_p,
                softmax_scale,
                causal,
            )
        # The op returns the output first, then one lse tensor per stage.
        final_out = output[0]
        final_lse = output[1:]

        ctx.save_for_backward(x, wq, wk, wv, freqs_cis, final_out, *final_lse)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.ring_group = ring_group
        ctx.ulysses_group = ulysses_group
        ctx.attn_type = attn_type
        ctx.two_streams = two_streams
        ctx.pipe_degree = pipe_degree
        ctx.layer_id = layer_id

        return final_out if not return_softmax else (final_out, final_lse, None)
        
    
    @staticmethod
    # @torch.no_grad()
    def backward(ctx, dout, *args) -> Tuple[Tensor, Tensor, Tensor, Tensor, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]:

        # if torch.distributed.get_rank() == 0:
        #     breakpoint()
        # torch.distributed.barrier()
        # torch.cuda.cudart().cudaProfilerStart()
        fn = select_flash_attn_impl(ctx.attn_type, stage="bwd-only")

        saved_tensors = ctx.saved_tensors
        x, wq, wk, wv, freqs_cis, final_out, *final_lse = saved_tensors

        layer_id = ctx.layer_id

        # assert not check_nan_inf(final_out, "final_out", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in final_out"
        # assert not check_nan_inf(final_lse, "final_lse", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in final_lse"
        # assert not check_nan_inf(x, "x", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in x"
        # assert not check_nan_inf(wq, "wq", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in wq"
        # assert not check_nan_inf(wk, "wk", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in wk"
        # assert not check_nan_inf(wv, "wv", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in wv"
        # assert not check_nan_inf(freqs_cis, "freqs_cis", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in freqs_cis"
        # assert not check_nan_inf(dout, "dout", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in dout"
        
        bs, seqlen, n_heads, head_dim = final_out.shape
        ulysses_degree = dist.get_world_size(ctx.ulysses_group)
        pipe_degree = n_heads // ulysses_degree

        dx = None #torch.zeros_like(x)
        # dwq = []
        # dwk = []
        # dwv = []
        dwq = torch.empty_like(wq)
        dwk = torch.empty_like(wk)
        dwv = torch.empty_like(wv)

        proj_dim = wq.shape[1]

        two_streams = ctx.two_streams

        q_in, k_in, v_in = [None] * pipe_degree, [None] * pipe_degree, [None] * pipe_degree
        attn_dq, attn_dk, attn_dv = [None] * pipe_degree, [None] * pipe_degree, [None] * pipe_degree
        # attn_dq = [torch.empty([bs, seqlen*ulysses_degree, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        # attn_dk = [torch.empty([bs, seqlen*ulysses_degree, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        # attn_dv = [torch.empty([bs, seqlen*ulysses_degree, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        dq_out, dk_out, dv_out = [None] * pipe_degree, [None] * pipe_degree, [None] * pipe_degree

        a2a_events = [torch.cuda.Event() for _ in range(pipe_degree)]

        q_out = [None for _ in range(pipe_degree)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        k_out = [None for _ in range(pipe_degree)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        v_out = [None for _ in range(pipe_degree)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        out_out = [None for _ in range(pipe_degree)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        dout_out = [None for _ in range(pipe_degree)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=dout.dtype) for _ in range(pipe_degree)]

        final_out = list(torch.chunk(final_out, pipe_degree, dim = 2))
        # final_lse = list(torch.chunk(final_lse, pipe_degree, dim = 1))
        dout = list(torch.chunk(dout, pipe_degree, dim = 2))

        attn_inp = [None for _ in range(pipe_degree)]
        attn_out = [None for _ in range(pipe_degree)]
        a2a_inp = [None for _ in range(pipe_degree)]
        a2a_out = [None for _ in range(pipe_degree)]

        # final_out_shapes = [t.shape for t in torch.chunk(final_out, pipe_degree, dim = 2)]
        # dout_shapes = [t.shape for t in torch.chunk(dout, pipe_degree, dim = 2)]
        # final_out = [t.flatten() for t in torch.chunk(final_out, pipe_degree, dim = 2)]
        # dout = [t.flatten() for t in torch.chunk(dout, pipe_degree, dim = 2)]

        wq_chunks = torch.chunk(wq, pipe_degree, dim=0)
        wk_chunks = torch.chunk(wk, pipe_degree, dim=0)
        wv_chunks = torch.chunk(wv, pipe_degree, dim=0)

        num_streams = len(two_streams)

        for i in range(len(two_streams)):
            two_streams[i].wait_stream(torch.cuda.current_stream())

        for stage in range(pipe_degree):
            if stage == 0 or num_streams == 1:
                with torch.cuda.stream(two_streams[0]):
                    # q_in[stage] = x @ wq[:, stage*(proj_dim//pipe_degree):(stage+1)*(proj_dim//pipe_degree)]
                    # k_in[stage] = x @ wk[:, stage*(proj_dim//pipe_degree):(stage+1)*(proj_dim//pipe_degree)]
                    # v_in[stage] = x @ wv[:, stage*(proj_dim//pipe_degree):(stage+1)*(proj_dim//pipe_degree)]
                    q_in[stage] = F.linear(x, wq_chunks[stage])
                    k_in[stage] = F.linear(x, wk_chunks[stage])
                    v_in[stage] = F.linear(x, wv_chunks[stage])

                    q_in[stage], k_in[stage] = apply_rotary_emb(q_in[stage].view(bs, seqlen, -1, head_dim), k_in[stage].view(bs, seqlen, -1, head_dim), freqs_cis=freqs_cis)
                    v_in[stage] = v_in[stage].view(bs, seqlen, -1, head_dim)

                    # a2a_inp[stage] = torch.cat([q_in[stage], k_in[stage], v_in[stage], final_out[stage], dout[stage]], dim=0)
                    # a2a_out[stage] = all_to_all_4D(a2a_inp[stage], 2, 1, False, False)
                    # q_out[stage], k_out[stage], v_out[stage], out_out[stage], dout_out[stage] = torch.chunk(a2a_out[stage], 5, dim=0)
                    q_out[stage] = all_to_all_4D(q_in[stage], 2, 1, False, False)#, output=q_out[stage])
                    k_out[stage] = all_to_all_4D(k_in[stage], 2, 1, False, False)#, output=k_out[stage])
                    v_out[stage] = all_to_all_4D(v_in[stage], 2, 1, False, False)#, output=v_out[stage])
                    out_out[stage] = all_to_all_4D(final_out[stage], 2, 1, False, False)#, output=out_out[stage])
                    dout_out[stage] = all_to_all_4D(dout[stage], 2, 1, False, False)#, output=dout_out[stage])
                    
                    q_in[stage] = None
                    k_in[stage] = None
                    v_in[stage] = None
                    final_out[stage] = None
                    dout[stage] = None

                    a2a_inp[stage] = None

                    # a2a_events[stage].record()

            if stage != pipe_degree - 1 and num_streams > 1:
                with torch.cuda.stream(two_streams[(stage+1)%num_streams]):
                    # a2a_events[stage].wait()
                    # q_in[stage+1] = x @ wq[:, (stage+1)*(proj_dim//pipe_degree):(stage+2)*(proj_dim//pipe_degree)]
                    # k_in[stage+1] = x @ wk[:, (stage+1)*(proj_dim//pipe_degree):(stage+2)*(proj_dim//pipe_degree)]
                    # v_in[stage+1] = x @ wv[:, (stage+1)*(proj_dim//pipe_degree):(stage+2)*(proj_dim//pipe_degree)]
                    q_in[stage+1] = F.linear(x, wq_chunks[stage+1]) #XXX: this one triggers a 64MB alloc buffer, WHY??? 
                    k_in[stage+1] = F.linear(x, wk_chunks[stage+1])
                    v_in[stage+1] = F.linear(x, wv_chunks[stage+1])

                    q_in[stage+1], k_in[stage+1] = apply_rotary_emb(q_in[stage+1].view(bs, seqlen, -1, head_dim), k_in[stage+1].view(bs, seqlen, -1, head_dim), freqs_cis=freqs_cis)
                    v_in[stage+1] = v_in[stage+1].view(bs, seqlen, -1, head_dim)

                    # a2a_inp[stage+1] = torch.cat([q_in[stage+1], k_in[stage+1], v_in[stage+1], final_out[stage+1], dout[stage+1]], dim=0)
                    # a2a_out[stage+1] = all_to_all_4D(a2a_inp[stage+1], 2, 1, False, False)
                    # q_out[stage+1], k_out[stage+1], v_out[stage+1], out_out[stage+1], dout_out[stage+1] = torch.chunk(a2a_out[stage+1], 5, dim=0)
                    q_out[stage+1] = all_to_all_4D(q_in[stage+1], 2, 1, False, False)#, output=q_out[stage+1])
                    k_out[stage+1] = all_to_all_4D(k_in[stage+1], 2, 1, False, False)#, output=k_out[stage+1])
                    v_out[stage+1] = all_to_all_4D(v_in[stage+1], 2, 1, False, False)#, output=v_out[stage+1])
                    out_out[stage+1] = all_to_all_4D(final_out[stage+1], 2, 1, False, False)#, output=out_out[stage+1]) #TODO: can we make this inplace? because out/dout can't be freed really
                    dout_out[stage+1] = all_to_all_4D(dout[stage+1], 2, 1, False, False)#, output=dout_out[stage+1])

                    # a2a_events[stage+1].record()

                    q_in[stage+1] = None
                    k_in[stage+1] = None
                    v_in[stage+1] = None
                    final_out[stage+1] = None
                    dout[stage+1] = None

                    a2a_inp[stage+1] = None
            
            with torch.cuda.stream(two_streams[stage%num_streams]):
                # a2a_events[stage].wait()
                # if int(layer_id) == 30 and torch.distributed.get_rank() == 0:
                #     breakpoint()
                # torch.distributed.barrier()
                attn_dq[stage], attn_dk[stage], attn_dv[stage] = fully_fused_ring_flash_attn_backward(
                    ctx.ring_group,
                    dout_out[stage],
                    # dout[stage].view(dout_shapes[stage][0], dout_shapes[stage][1]*ulysses_degree, dout_shapes[stage][2]//ulysses_degree, -1),
                    q_out[stage],
                    k_out[stage],
                    v_out[stage],
                    out_out[stage],
                    # final_out[stage].view(final_out_shapes[stage][0], final_out_shapes[stage][1]*ulysses_degree, final_out_shapes[stage][2]//ulysses_degree, -1),
                    final_lse[stage],
                    softmax_scale=ctx.softmax_scale,
                    dropout_p=ctx.dropout_p,
                    causal=ctx.causal,
                    window_size=ctx.window_size,
                    softcap=ctx.softcap,
                    alibi_slopes=ctx.alibi_slopes,
                    deterministic=ctx.deterministic,
                    attn_type=ctx.attn_type,
                )

                
                # attn_inp[stage] = torch.cat([attn_dq[stage], attn_dk[stage], attn_dv[stage]], dim=0)
                # attn_out[stage] = all_to_all_4D(attn_inp[stage], 1, 2, False, False)
                # dq_out[stage], dk_out[stage], dv_out[stage] = torch.chunk(attn_out[stage].contiguous(), 3, dim=0)
                dq_out[stage] = all_to_all_4D(attn_dq[stage], 1, 2, False, False)#.view(bs, seqlen, -1)
                dk_out[stage] = all_to_all_4D(attn_dk[stage], 1, 2, False, False)#.view(bs, seqlen, -1)
                dv_out[stage] = all_to_all_4D(attn_dv[stage], 1, 2, False, False)#.view(bs, seqlen, -1)

                # reverse RoPE on the gradients
                dq_out[stage], dk_out[stage] = apply_rotary_emb(dq_out[stage], dk_out[stage], freqs_cis=torch.conj(freqs_cis))
                dq_out[stage] = dq_out[stage].view(bs * seqlen, -1)
                dk_out[stage] = dk_out[stage].view(bs * seqlen, -1)
                dv_out[stage] = dv_out[stage].view(bs * seqlen, -1)
                

                q_out[stage] = None
                k_out[stage] = None
                v_out[stage] = None
                out_out[stage] = None
                dout_out[stage] = None
                final_lse[stage] = None
                a2a_out[stage] = None

                attn_dq[stage] = None
                attn_dk[stage] = None
                attn_dv[stage] = None
                attn_inp[stage] = None
                # torch.cuda.empty_cache()

                # dx_q = dq_out[stage] @ wq[:, stage*(proj_dim//pipe_degree):(stage+1)*(proj_dim//pipe_degree)].T
                # dx_k = dk_out[stage] @ wk[:, stage*(proj_dim//pipe_degree):(stage+1)*(proj_dim//pipe_degree)].T
                # dx_v = dv_out[stage] @ wv[:, stage*(proj_dim//pipe_degree):(stage+1)*(proj_dim//pipe_degree)].T
                # dx += dx_q + dx_k + dx_v

                if dx is None:
                    dx = dq_out[stage].view(bs * seqlen, -1) @ wq_chunks[stage]
                else:
                    dx.addmm_(dq_out[stage].view(bs * seqlen, -1), wq_chunks[stage], alpha=1.0, beta=1.0)
                dx.addmm_(dk_out[stage].view(bs * seqlen, -1), wk_chunks[stage], alpha=1.0, beta=1.0)
                dx.addmm_(dv_out[stage].view(bs * seqlen, -1), wv_chunks[stage], alpha=1.0, beta=1.0)
                # if dx is None:
                #     dx = dx_temp.view(bs, seqlen, -1)
                # else:
                    # dx += dx_temp.view(bs, seqlen, -1)

                # if torch.distributed.get_rank() == 0:
                #     breakpoint()
                # torch.distributed.barrier()

                # dwq.append((x.transpose(-1, -2) @ dq_out[stage]).sum(dim = 0))
                # dwk.append((x.transpose(-1, -2) @ dk_out[stage]).sum(dim = 0))
                # dwv.append((x.transpose(-1, -2) @ dv_out[stage]).sum(dim = 0))

                dwq[stage*(head_dim*ulysses_degree):((stage+1))*(head_dim*ulysses_degree), :] = dq_out[stage].T @ x.view(bs * seqlen, -1)
                dwk[stage*(head_dim*ulysses_degree):((stage+1))*(head_dim*ulysses_degree), :] = dk_out[stage].T @ x.view(bs * seqlen, -1)
                dwv[stage*(head_dim*ulysses_degree):((stage+1))*(head_dim*ulysses_degree), :] = dv_out[stage].T @ x.view(bs * seqlen, -1)

                dq_out[stage] = None
                dk_out[stage] = None
                dv_out[stage] = None
                attn_out[stage] = None

                # torch.cuda.empty_cache()
        
        for i in range(len(two_streams)):
            torch.cuda.current_stream().wait_stream(two_streams[i])
        # torch.cuda.empty_cache()

        # for i in range(n_heads):
        #     for j in range(ctx.gqa_ratio):
        #         try:
        #             if dwk_new[i] is None:
        #                 dwk_new[i] = dwk[(i//ctx.gqa_ratio)*ctx.gqa_ratio + j]
        #                 dwv_new[i] = dwv[(i//ctx.gqa_ratio)*ctx.gqa_ratio + j]
        #             else:
        #                 dwk_new[i] += dwk[(i//ctx.gqa_ratio)*ctx.gqa_ratio + j]
        #                 dwv_new[i] += dwv[(i//ctx.gqa_ratio)*ctx.gqa_ratio + j]
        #         except IndexError:
        #             assert False, f"i is {i}, gqa_ratio is {ctx.gqa_ratio}, index is {i//ctx.gqa_ratio*ctx.gqa_ratio + j}, length is {len(dwk)}"

        # dwq = torch.cat(dwq, dim = 1)
        # dwk = torch.cat(dwk, dim = 1)
        # dwv = torch.cat(dwv, dim = 1)

        # hid_dim = wk.shape[1]

        # dwk = dwk.unsqueeze(0).reshape(ctx.gqa_ratio, -1, hid_dim).sum(dim = 0, keepdim = True).expand(ctx.gqa_ratio, hid_dim//ctx.gqa_ratio, hid_dim).reshape(hid_dim, hid_dim)
        # dwv = dwv.unsqueeze(0).reshape(ctx.gqa_ratio, -1, hid_dim).sum(dim = 0, keepdim = True).expand(ctx.gqa_ratio, hid_dim//ctx.gqa_ratio, hid_dim).reshape(hid_dim, hid_dim)

        # logger.info(f"Layer {layer_id} backward CUDA memory usage: {torch.cuda.memory_summary(device=torch.cuda.current_device(), abbreviated=True)}")

        # torch.cuda.cudart().cudaProfilerStop()
        # if torch.distributed.get_rank() == 0:
        #     breakpoint()
        # torch.distributed.barrier()
        final_lse = None
        final_out = None
        dout = None
        return dx.view(bs, seqlen, -1).to(x.dtype), dwq, dwk, dwv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None

# ------------------------------------------------------------
# ----------------------- GQA --------------------------------
# ------------------------------------------------------------

@torch.library.custom_op("yunchang::_fully_fused_attn_gqa_forward", mutates_args=(), device_types="cuda")
def fully_fused_attn_gqa_forward(
        x: Tensor,
        wq: Tensor,
        wk: Tensor,
        wv: Tensor,
        freqs_cis: Tensor,
        head_dim: int,
        dropout_p: float = 0,
        softmax_scale: float = 0,
        causal: bool = True,
        layer_id: int = 0,
) -> list[Tensor]:
    """Fused QKV projection + rotary embedding + all-to-all + ring attention (GQA).

    The work is split into ``pipe_degree = n_heads // ulysses_degree`` stages.
    Each stage projects ``x`` with one chunk of ``wq`` (and, when the stage
    starts a new K/V group, one chunk of ``wk``/``wv``), applies
    ``apply_rotary_emb``, exchanges the result with ``all_to_all_4D``, runs
    ``fully_fused_ring_flash_attn_forward`` and exchanges the attention output
    back into a slice of the result buffer. When more than one stream is
    configured, the projection/exchange of stage ``s + 1`` is issued on the
    next stream before the attention of stage ``s`` is issued.

    Module globals read: ``two_streams`` (list of CUDA streams) and
    ``ulysses_group`` (process group). Module global written: ``final_lse``
    (rebound to the list of per-stage lse tensors).

    Args:
        x: Input unpacked as ``(bs, seqlen, hid_dim)``.
        wq: Query projection weight; dim 0 is the output dimension and is
            chunked into ``pipe_degree`` pieces.
        wk: Key projection weight; ``wk.shape[0] // head_dim`` is taken as the
            number of KV heads. Chunked into ``pipe_degree // gqa_ratio`` pieces.
        wv: Value projection weight, chunked like ``wk``.
        freqs_cis: Rotary frequencies, moved to ``x.device`` before use.
        head_dim: Per-head feature size.
        dropout_p: Forwarded to the ring attention kernel.
        softmax_scale: Forwarded to the ring attention kernel.
        causal: Forwarded to the ring attention kernel.
        layer_id: Not used by this function.

    Returns:
        ``[final_out, lse_0, ..., lse_{pipe_degree-1}]`` where ``final_out`` is
        ``(bs, seqlen, n_heads, head_dim)`` and ``lse_i`` is whatever
        ``fully_fused_ring_flash_attn_forward`` returned for stage ``i``.

    Raises:
        AssertionError: If the KV head count is not divisible by the Ulysses
            world size.
    """
    global two_streams, attn_type, alibi_slopes, window_size, ulysses_group, final_lse

    freqs_cis = freqs_cis.to(x.device)

    bs, seqlen, hid_dim = x.shape
    n_heads = hid_dim // head_dim
    n_kv_heads = wk.shape[0] // head_dim
    gqa_ratio = n_heads // n_kv_heads

    ulysses_degree = dist.get_world_size(ulysses_group)
    pipe_degree = n_heads // ulysses_degree

    assert n_kv_heads % ulysses_degree == 0, "n_kv_heads must be divisible by ulysses_degree"

    # K/V are shared by `gqa_ratio` consecutive stages, so they need fewer slots.
    kv_stages = pipe_degree // gqa_ratio

    q_in = [None] * pipe_degree
    k_in = [None] * kv_stages
    v_in = [None] * kv_stages
    out = [None] * pipe_degree
    lse = [None] * pipe_degree

    # One event per stage, recorded right after that stage's input all-to-all.
    a2a_events = [torch.cuda.Event() for _ in range(pipe_degree)]

    # dim 0 of the weights is the output dimension.
    wq_chunks = torch.chunk(wq, pipe_degree, dim=0)
    wk_chunks = torch.chunk(wk, kv_stages, dim=0)
    wv_chunks = torch.chunk(wv, kv_stages, dim=0)

    q_out = [None] * pipe_degree
    k_out = [None] * kv_stages
    v_out = [None] * kv_stages
    final_out = torch.empty([bs, seqlen, n_heads, head_dim], device=x.device, dtype=x.dtype, memory_format=torch.contiguous_format)

    final_lse = []

    # Side streams must observe everything already queued on the caller's stream.
    num_streams = len(two_streams)
    for i in range(num_streams):
        two_streams[i].wait_stream(torch.cuda.current_stream())

    def _project_and_scatter(idx):
        """Project, rotate and all-to-all the inputs of stage `idx` on the current stream."""
        kv_idx = idx // gqa_ratio
        q_in[idx] = F.linear(x, wq_chunks[idx])

        if idx == 0 or kv_idx > (idx - 1) // gqa_ratio:
            # First stage of a K/V group: K and V are produced here and reused
            # by the following stages of the same group.
            k_in[kv_idx] = F.linear(x, wk_chunks[kv_idx])
            v_in[kv_idx] = F.linear(x, wv_chunks[kv_idx])

            q_in[idx], k_in[kv_idx] = apply_rotary_emb(
                xq=q_in[idx].view(bs, seqlen, -1, head_dim),
                xk=k_in[kv_idx].view(bs, seqlen, -1, head_dim),
                freqs_cis=freqs_cis,
            )
            v_in[kv_idx] = v_in[kv_idx].view(bs, seqlen, -1, head_dim)

            q_out[idx] = all_to_all_4D(q_in[idx], 2, 1, False, False)
            k_out[kv_idx] = all_to_all_4D(k_in[kv_idx], 2, 1, False, False)
            v_out[kv_idx] = all_to_all_4D(v_in[kv_idx], 2, 1, False, False)

            a2a_events[idx].record()

            q_in[idx] = None
            k_in[kv_idx] = None
            v_in[kv_idx] = None
        else:
            # K/V of this group were already exchanged; only Q is needed.
            q_in[idx] = apply_rotary_emb(
                xq=q_in[idx].view(bs, seqlen, -1, head_dim),
                xk=None,
                freqs_cis=freqs_cis,
            )

            q_out[idx] = all_to_all_4D(q_in[idx], 2, 1, False, False)

            a2a_events[idx].record()

            # Drop the local projection once exchanged so it does not stay
            # alive until the function returns.
            q_in[idx] = None

    for stage in range(pipe_degree):
        # Stage 0 has no predecessor to prefetch it; with a single stream
        # there is no prefetch at all, so every stage is prepared here.
        if stage == 0 or num_streams == 1:
            with torch.cuda.stream(two_streams[0]):
                _project_and_scatter(stage)

        # Prefetch the next stage's inputs on the next stream, ordered after
        # the current stage's input exchange.
        if stage != pipe_degree - 1 and num_streams > 1:
            with torch.cuda.stream(two_streams[(stage + 1) % num_streams]):
                a2a_events[stage].wait()
                _project_and_scatter(stage + 1)

        with torch.cuda.stream(two_streams[stage % num_streams]):
            a2a_events[stage].wait()
            kv_idx = stage // gqa_ratio
            k_out[kv_idx] = k_out[kv_idx].contiguous()
            v_out[kv_idx] = v_out[kv_idx].contiguous()

            out[stage], lse[stage] = fully_fused_ring_flash_attn_forward(
                q_out[stage],
                k_out[kv_idx],
                v_out[kv_idx],
                softmax_scale=softmax_scale,
                dropout_p=dropout_p,
                causal=causal,
                deterministic=False,
            )

            final_lse.append(lse[stage])
            # Stage `stage` owns `ulysses_degree` consecutive heads of the result.
            final_out[:, :, (stage * ulysses_degree):((stage + 1) * ulysses_degree), :] = all_to_all_4D(out[stage], 1, 2, False, False)

            q_out[stage] = None
            # K/V can be released only after the last stage of their group.
            if (stage + 1) // gqa_ratio != kv_idx:
                k_out[kv_idx] = None
                v_out[kv_idx] = None

            out[stage] = None

    # Make the caller's stream wait for all work issued on the side streams.
    for i in range(num_streams):
        torch.cuda.current_stream().wait_stream(two_streams[i])

    output = [final_out]
    output.extend(final_lse)

    return output

@fully_fused_attn_gqa_forward.register_fake
def _(
        x: Tensor,
        wq: Tensor,
        wk: Tensor,
        wv: Tensor,
        freqs_cis: Tensor,
        head_dim: int,
        dropout_p: float = 0,
        softmax_scale: float = 0,
        causal: bool = True,
        layer_id: int = 0,
) -> Tuple[Tensor, Tensor]:
    """Fake (shape-only) implementation registered for the GQA forward op.

    Allocates uninitialised tensors with the dtype and device of ``x``:
    an output of shape ``(bs, sl, d // head_dim, head_dim)`` and an lse of
    shape ``(bs, d // head_dim, sl)``, where ``(bs, sl, d) = x.shape``.

    NOTE(review): the real op returns a list ``[final_out, *per_stage_lse]``
    whereas this returns a 2-tuple with a single lse tensor; confirm that the
    structure, lse shape and lse dtype are what tracing is expected to see.
    """
    batch, seq, hidden = x.shape
    heads = hidden // head_dim

    fake_out = torch.empty((batch, seq, heads, head_dim), device=x.device, dtype=x.dtype)
    fake_lse = torch.empty((batch, heads, seq), device=x.device, dtype=x.dtype)
    return fake_out, fake_lse
    

class FullyFusedAttnGQAFunc(torch.autograd.Function):
    
    @staticmethod
    # @torch.no_grad()
    def forward(
        ctx,
        x: Tensor,
        wq: Tensor,
        wk: Tensor,
        wv: Tensor,
        freqs_cis: Tensor,
        head_dim: int,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        two_streams,
        attn_type,
        layer_id,
    ) -> Tuple[Tensor, Tensor, None] | Tensor | Dict[str, Any]:
        """Run the fused GQA attention forward and stash state for backward.

        Publishes the communication/kernel configuration as module-level
        globals (read by ``fully_fused_attn_gqa_forward`` and the ring kernel),
        calls the fused op under ``torch.no_grad()``, and saves
        ``x, wq, wk, wv, freqs_cis, final_out`` plus every per-stage lse tensor
        on ``ctx``.

        ``softmax_scale`` defaults to ``head_dim ** -0.5`` when ``None``.
        ``offload_stream`` and ``fetch_stream`` are accepted but not used here.

        Returns:
            ``final_out`` alone, or ``(final_out, final_lse, None)`` when
            ``return_softmax`` is truthy, where ``final_lse`` is the sequence
            of per-stage lse tensors returned by the fused op.
        """
        # Only the hidden size is needed; the unpack also requires x to be 3-D.
        _, _, hid_dim = x.shape
        n_heads = hid_dim // head_dim
        pipe_degree = n_heads // dist.get_world_size(ulysses_group)

        if softmax_scale is None:
            softmax_scale = head_dim ** (-0.5)

        # The fused op takes no group/stream arguments, so hand them over via
        # attributes of this module. Several parameter names shadow the
        # globals, hence the explicit module lookup.
        import sys
        this_module = sys.modules[__name__]
        for global_name, value in (
            ("process_group", ring_group),
            ("attn_type", attn_type),
            ("alibi_slopes", alibi_slopes),
            ("window_size", window_size),
            ("ulysses_group", ulysses_group),
            ("two_streams", two_streams),
        ):
            setattr(this_module, global_name, value)

        with torch.no_grad():
            op_result = fully_fused_attn_gqa_forward(
                x,
                wq,
                wk,
                wv,
                freqs_cis,
                head_dim,
                dropout_p,
                softmax_scale,
                causal,
                layer_id,
            )

        # First entry is the attention output; the rest are per-stage lse tensors.
        final_out = op_result[0]
        final_lse = op_result[1:]

        ctx.save_for_backward(x, wq, wk, wv, freqs_cis, final_out, *final_lse)

        # Kernel hyper-parameters.
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.attn_type = attn_type
        # Communication and scheduling state.
        ctx.ring_group = ring_group
        ctx.ulysses_group = ulysses_group
        ctx.two_streams = two_streams
        ctx.pipe_degree = pipe_degree
        ctx.layer_id = layer_id

        if return_softmax:
            return final_out, final_lse, None
        return final_out
        
    
    @staticmethod
    # @torch.no_grad()
    def backward(ctx, dout, *args) -> Tuple[Tensor, Tensor, Tensor, Tensor, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]:

        # torch.cuda.cudart().cudaProfilerStart()
        fn = select_flash_attn_impl(ctx.attn_type, stage="bwd-only")

        saved_tensors = ctx.saved_tensors
        x, wq, wk, wv, freqs_cis, final_out, *final_lse = saved_tensors

        freqs_cis = freqs_cis.to(x.device)

        layer_id = ctx.layer_id

        # assert not check_nan_inf(final_out, "final_out", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in final_out"
        # assert not check_nan_inf(final_lse, "final_lse", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in final_lse"
        # assert not check_nan_inf(x, "x", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in x"
        # assert not check_nan_inf(wq, "wq", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in wq"
        # assert not check_nan_inf(wk, "wk", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in wk"
        # assert not check_nan_inf(wv, "wv", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in wv"
        # assert not check_nan_inf(freqs_cis, "freqs_cis", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in freqs_cis"
        # assert not check_nan_inf(dout, "dout", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in dout"
        
        bs, seqlen, n_heads, head_dim = final_out.shape
        ulysses_degree = dist.get_world_size(ctx.ulysses_group)
        pipe_degree = n_heads // ulysses_degree
        n_kv_heads = wk.shape[0] // head_dim
        gqa_ratio = n_heads // n_kv_heads

        # dx = torch.zeros_like(x)
        # dx = torch.empty_like(x)
        dx = None
        # dwq = []
        # dwk = []
        # dwv = []
        dwq = torch.empty_like(wq)
        dwk = torch.empty_like(wk)
        dwv = torch.empty_like(wv)

        proj_dim = wq.shape[1]

        two_streams = ctx.two_streams

        q_in, k_in, v_in = [None] * pipe_degree, [None] * (pipe_degree//gqa_ratio), [None] * (pipe_degree//gqa_ratio)
        attn_dq, attn_dk, attn_dv = [None] * pipe_degree, [None] * pipe_degree, [None] * pipe_degree
        # attn_dq = [torch.empty([bs, seqlen*ulysses_degree, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        # attn_dk = [torch.empty([bs, seqlen*ulysses_degree, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        # attn_dv = [torch.empty([bs, seqlen*ulysses_degree, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        dq_out, dk_out, dv_out = [None] * pipe_degree, [None] * (pipe_degree), [None] * (pipe_degree)

        a2a_events = [torch.cuda.Event() for _ in range(pipe_degree)]

        # wt_idx = []
        # for stage in range(pipe_degree):
        #     if stage==0 or stage//gqa_ratio != (stage-1)//gqa_ratio:
        #         stage_idx = [(stage + i)*gqa_ratio for i in range(ulysses_degree)]
        #     else:
        #         stage_idx = [idx+1 for idx in stage_idx]

        #     for si in stage_idx:
        #         wt_idx.extend(range(si*head_dim, (si+1)*head_dim))
        
        # wq = wq[wt_idx, :]
        # wk = wk[wt_idx, :]
        # wv = wv[wt_idx, :]
        
        wq_chunks = torch.chunk(wq, pipe_degree, dim = 0) #0 is the output dimension
        wk_chunks = torch.chunk(wk, pipe_degree//gqa_ratio, dim = 0)
        wv_chunks = torch.chunk(wv, pipe_degree//gqa_ratio, dim = 0)

        q_out = [None for _ in range(pipe_degree)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        k_out = [None for _ in range(pipe_degree//gqa_ratio)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        v_out = [None for _ in range(pipe_degree//gqa_ratio)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        out_out = [None for _ in range(pipe_degree)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        dout_out = [None for _ in range(pipe_degree)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=dout.dtype) for _ in range(pipe_degree)]

        # final_out_idx = []
        # for stage in range(pipe_degree):
        #     if stage==0 or stage//gqa_ratio != (stage-1)//gqa_ratio:
        #             stage_idx = [(stage + i)*gqa_ratio for i in range(ulysses_degree)]
        #     else:
        #         stage_idx = [idx+1 for idx in stage_idx]
        #     final_out_idx.extend(stage_idx)

        # final_out = final_out[:, :, final_out_idx, :]
        # dout = dout[:, :, final_out_idx, :]

        final_out = list(torch.chunk(final_out, pipe_degree, dim = 2))
        # final_lse = list(torch.chunk(final_lse, pipe_degree, dim = 1))
        dout = list(torch.chunk(dout, pipe_degree, dim = 2))

        # xT = x.reshape(bs * seqlen, -1).transpose(0, 1).contiguous()

        num_streams = len(two_streams)

        for i in range(len(two_streams)):
            two_streams[i].wait_stream(torch.cuda.current_stream())

        for stage in range(pipe_degree):
            if stage == 0 or len(two_streams) == 1:
                with torch.cuda.stream(two_streams[0]):
                    # q_in[stage] = x @ wq[:, stage*(proj_dim//pipe_degree):(stage+1)*(proj_dim//pipe_degree)]
                    # k_in[stage] = x @ wk[:, stage*(proj_dim//pipe_degree):(stage+1)*(proj_dim//pipe_degree)]
                    # v_in[stage] = x @ wv[:, stage*(proj_dim//pipe_degree):(stage+1)*(proj_dim//pipe_degree)]
                    q_in[stage] = F.linear(x, wq_chunks[stage])
                    if stage==0 or stage//gqa_ratio > (stage-1)//gqa_ratio:
                        # k_in[(stage+1)//gqa_ratio] = x @ wk[:, ((stage+1)//gqa_ratio)*(proj_dim//pipe_degree):(((stage+1)//gqa_ratio)+1)*(proj_dim//pipe_degree)]
                        # v_in[(stage+1)//gqa_ratio] = x @ wv[:, ((stage+1)//gqa_ratio)*(proj_dim//pipe_degree):(((stage+1)//gqa_ratio)+1)*(proj_dim//pipe_degree)]
                        k_in[(stage)//gqa_ratio] = F.linear(x, wk_chunks[(stage)//gqa_ratio])
                        v_in[(stage)//gqa_ratio] = F.linear(x, wv_chunks[(stage)//gqa_ratio])

                        q_in[stage], k_in[(stage)//gqa_ratio] = apply_rotary_emb(xq = q_in[stage].view(bs, seqlen, -1, head_dim), xk = k_in[(stage)//gqa_ratio].view(bs, seqlen, -1, head_dim), freqs_cis=freqs_cis)
                        v_in[(stage)//gqa_ratio] = v_in[(stage)//gqa_ratio].view(bs, seqlen, -1, head_dim)

                        q_out[stage] = all_to_all_4D(q_in[stage], 2, 1, False, False)#, output=q_out[stage+1])
                        k_out[(stage)//gqa_ratio] = all_to_all_4D(k_in[(stage)//gqa_ratio], 2, 1, False, False)#, output=k_out[stage+1])
                        v_out[(stage)//gqa_ratio] = all_to_all_4D(v_in[(stage)//gqa_ratio], 2, 1, False, False)#, output=v_out[stage+1])

                        out_out[stage] = all_to_all_4D(final_out[stage], 2, 1, False, False)#, output=out_out[stage+1])
                        dout_out[stage] = all_to_all_4D(dout[stage], 2, 1, False, False)#, output=dout_out[stage+1])

                        a2a_events[stage+1].record()

                        q_in[stage] = None
                        k_in[(stage)//gqa_ratio] = None
                        v_in[(stage)//gqa_ratio] = None

                        final_out[stage] = None
                        dout[stage] = None
                    else:
                        q_in[stage] = apply_rotary_emb(xq = q_in[stage].view(bs, seqlen, -1, head_dim), xk = None, freqs_cis=freqs_cis)
                        
                        q_out[stage] = all_to_all_4D(q_in[stage], 2, 1, False, False)#, output=q_out[stage+1])
                        out_out[stage] = all_to_all_4D(final_out[stage], 2, 1, False, False)#, output=out_out[stage+1])
                        dout_out[stage] = all_to_all_4D(dout[stage], 2, 1, False, False)#, output=dout_out[stage+1])

                        a2a_events[stage].record()

                        q_in[stage] = None
                        final_out[stage] = None
                        dout[stage] = None

            if stage != pipe_degree - 1 and len(two_streams) > 1:
                with torch.cuda.stream(two_streams[(stage+1)%num_streams]):
                    a2a_events[stage].wait()
                    # q_in[stage+1] = x @ wq[:, (stage+1)*(proj_dim//pipe_degree):(stage+2)*(proj_dim//pipe_degree)]
                    q_in[stage+1] = F.linear(x, wq_chunks[stage+1])
                    if (stage+1)//gqa_ratio > stage//gqa_ratio:
                        # k_in[(stage+1)//gqa_ratio] = x @ wk[:, ((stage+1)//gqa_ratio)*(proj_dim//pipe_degree):(((stage+1)//gqa_ratio)+1)*(proj_dim//pipe_degree)]
                        # v_in[(stage+1)//gqa_ratio] = x @ wv[:, ((stage+1)//gqa_ratio)*(proj_dim//pipe_degree):(((stage+1)//gqa_ratio)+1)*(proj_dim//pipe_degree)]
                        k_in[(stage+1)//gqa_ratio] = F.linear(x, wk_chunks[(stage+1)//gqa_ratio])
                        v_in[(stage+1)//gqa_ratio] = F.linear(x, wv_chunks[(stage+1)//gqa_ratio])

                        q_in[stage+1], k_in[(stage+1)//gqa_ratio] = apply_rotary_emb(xq = q_in[stage+1].view(bs, seqlen, -1, head_dim), xk = k_in[(stage+1)//gqa_ratio].view(bs, seqlen, -1, head_dim), freqs_cis=freqs_cis)
                        v_in[(stage+1)//gqa_ratio] = v_in[(stage+1)//gqa_ratio].view(bs, seqlen, -1, head_dim)

                        q_out[stage+1] = all_to_all_4D(q_in[stage+1], 2, 1, False, False)#, output=q_out[stage+1])
                        k_out[(stage+1)//gqa_ratio] = all_to_all_4D(k_in[(stage+1)//gqa_ratio], 2, 1, False, False)#, output=k_out[stage+1])
                        v_out[(stage+1)//gqa_ratio] = all_to_all_4D(v_in[(stage+1)//gqa_ratio], 2, 1, False, False)#, output=v_out[stage+1])

                        out_out[stage+1] = all_to_all_4D(final_out[stage+1], 2, 1, False, False)#, output=out_out[stage+1])
                        dout_out[stage+1] = all_to_all_4D(dout[stage+1], 2, 1, False, False)#, output=dout_out[stage+1])

                        a2a_events[stage+1].record()

                        q_in[stage+1] = None
                        k_in[(stage+1)//gqa_ratio] = None
                        v_in[(stage+1)//gqa_ratio] = None

                        final_out[stage+1] = None
                        dout[stage+1] = None
                    else:
                        q_in[stage+1] = apply_rotary_emb(xq = q_in[stage+1].view(bs, seqlen, -1, head_dim), xk = None, freqs_cis=freqs_cis)
                        
                        q_out[stage+1] = all_to_all_4D(q_in[stage+1], 2, 1, False, False)#, output=q_out[stage+1])
                        out_out[stage+1] = all_to_all_4D(final_out[stage+1], 2, 1, False, False)#, output=out_out[stage+1])
                        dout_out[stage+1] = all_to_all_4D(dout[stage+1], 2, 1, False, False)#, output=dout_out[stage+1])

                        a2a_events[stage+1].record()

                        q_in[stage+1] = None
                        final_out[stage+1] = None
                        dout[stage+1] = None

            
            with torch.cuda.stream(two_streams[stage%num_streams]):
                a2a_events[stage].wait()
                # if int(layer_id) == 30 and torch.distributed.get_rank() == 0:
                #     breakpoint()
                # torch.distributed.barrier()
                k_out[stage//gqa_ratio] = k_out[stage//gqa_ratio].contiguous()
                v_out[stage//gqa_ratio] = v_out[stage//gqa_ratio].contiguous()

                attn_dq[stage], attn_dk[stage], attn_dv[stage] = fully_fused_ring_flash_attn_backward(
                    ctx.ring_group,
                    dout_out[stage],
                    q_out[stage],
                    k_out[stage//gqa_ratio],
                    v_out[stage//gqa_ratio],
                    out_out[stage],
                    final_lse[stage],
                    softmax_scale=ctx.softmax_scale,
                    dropout_p=ctx.dropout_p,
                    causal=ctx.causal,
                    window_size=ctx.window_size,
                    softcap=ctx.softcap,
                    alibi_slopes=ctx.alibi_slopes,
                    deterministic=ctx.deterministic,
                    attn_type=ctx.attn_type,
                )

                # assert not check_nan_inf(attn_dq[stage], "attn_dq[stage]", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in attn_dq[stage]"
                # assert not check_nan_inf(attn_dk[stage], "attn_dk[stage]", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in attn_dk[stage]"
                # assert not check_nan_inf(attn_dv[stage], "attn_dv[stage]", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in attn_dv[stage]"

                dq_out[stage] = all_to_all_4D(attn_dq[stage], 1, 2, False, False)#.view(bs, seqlen, -1)
                dk_out[stage] = all_to_all_4D(attn_dk[stage], 1, 2, False, False).view(bs, seqlen, -1)
                dv_out[stage] = all_to_all_4D(attn_dv[stage], 1, 2, False, False).view(bs, seqlen, -1)

                # if torch.distributed.get_rank() == 0:
                #     breakpoint()
                # torch.distributed.barrier()

                dq_out[stage] = apply_rotary_emb(dq_out[stage], freqs_cis=torch.conj(freqs_cis))
                dq_out[stage] = dq_out[stage].view(bs, seqlen, -1)

                if dk_out[stage//gqa_ratio] is None:
                    dk_out[stage//gqa_ratio] = dk_out[stage]
                    dv_out[stage//gqa_ratio] = dv_out[stage]
                else:
                    dk_out[stage//gqa_ratio] += dk_out[stage]
                    dv_out[stage//gqa_ratio] += dv_out[stage]
                # assert not check_nan_inf(dk_out[stage//gqa_ratio], f"dk_out[stage//gqa_ratio] stage {stage} Layer {layer_id}", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in dk_out[stage//gqa_ratio]"
                # assert not check_nan_inf(dv_out[stage//gqa_ratio], f"dv_out[stage//gqa_ratio] stage {stage} Layer {layer_id}", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in dv_out[stage//gqa_ratio]"


                q_out[stage] = None
                if (stage+1)//gqa_ratio != stage//gqa_ratio:
                    k_out[stage//gqa_ratio] = None
                    v_out[stage//gqa_ratio] = None
                out_out[stage] = None
                dout_out[stage] = None
                final_lse[stage] = None

                attn_dq[stage] = None
                attn_dk[stage] = None
                attn_dv[stage] = None
                # torch.cuda.empty_cache()

                # K/V at group boundary
                if dx is None:
                    dx = dq_out[stage].view(bs * seqlen, -1) @ wq_chunks[stage]
                else:
                    dx.addmm_(dq_out[stage].view(bs * seqlen, -1), wq_chunks[stage], alpha=1.0, beta=1.0)
                if (stage+1)//gqa_ratio != stage//gqa_ratio or stage == pipe_degree - 1:
                    dk = apply_rotary_emb(dk_out[stage//gqa_ratio].view(bs, seqlen, -1, head_dim), freqs_cis=torch.conj(freqs_cis)).view(bs * seqlen, -1)
                    dv = dv_out[stage//gqa_ratio].view(bs * seqlen, -1)

                    dx.addmm_(dk, wk_chunks[stage//gqa_ratio], alpha=1.0, beta=1.0)
                    dx.addmm_(dv, wv_chunks[stage//gqa_ratio], alpha=1.0, beta=1.0)
                    
                    # dwk.append(dk.T @ x.view(bs * seqlen, -1))                                  
                    # dwv.append(dv.T @ x.view(bs * seqlen, -1))
                    dwk[stage//gqa_ratio*(head_dim*ulysses_degree):((stage+1)//gqa_ratio)*(head_dim*ulysses_degree), :] = dk.T @ x.view(bs * seqlen, -1)
                    dwv[stage//gqa_ratio*(head_dim*ulysses_degree):((stage+1)//gqa_ratio)*(head_dim*ulysses_degree), :] = dv.T @ x.view(bs * seqlen, -1)
                
                # if dx is None:
                #     dx = dx_temp.view(bs, seqlen, -1)
                # else:
                #     dx += dx_temp.view(bs, seqlen, -1)

                # Q
                dq = dq_out[stage].view(bs * seqlen, -1)
                # dwq.append(dq.T @ x.view(bs * seqlen, -1))
                dwq[stage*(head_dim*ulysses_degree):(stage+1)*(head_dim*ulysses_degree), :] = dq.T @ x.view(bs * seqlen, -1)
                     

                dq_out[stage] = None
                dk_out[stage] = None
                dv_out[stage] = None

                if (stage+1)//gqa_ratio != stage//gqa_ratio:
                    dk_out[stage//gqa_ratio] = None
                    dv_out[stage//gqa_ratio] = None

                # torch.cuda.empty_cache()
        
        for i in range(len(two_streams)):
            torch.cuda.current_stream().wait_stream(two_streams[i])
        # torch.cuda.empty_cache()

        # dwq = torch.cat(dwq, dim = 0)
        # if len(dwk) > 0:
        #     dwk = torch.cat(dwk, dim = 0)
        #     dwv = torch.cat(dwv, dim = 0)
        # else:
        #     dwk = dwk[0]
        #     dwv = dwv[0]
        # torch.cuda.cudart().cudaProfilerStop()
        return dx.view(bs, seqlen, -1).to(x.dtype), dwq, dwk, dwv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None

# -------------------------------------------------------------
# ----------------------- Ultra-Fused GQA ---------------------
# -------------------------------------------------------------

@torch.library.custom_op("yunchang::_ultra_fused_attn_gqa_forward", mutates_args=(), device_types="cuda")
def ultra_fused_attn_gqa_forward(
        x_in: Tensor,
        w_rms: Tensor,
        eps_rms: float,
        wq: Tensor,
        wk: Tensor,
        wv: Tensor,
        freqs_cis: Tensor,
        head_dim: int,
        dropout_p: float = 0,
        softmax_scale: float = 0,
        causal: bool = True,
) -> Tuple[Tensor, Tensor]:
    """Fused RMSNorm + Q/K/V projection + rotary + all-to-all + ring attention (GQA).

    The query heads are processed in ``pipe_degree`` stages. While stage ``s``
    runs attention on ``two_streams[s % num_streams]``, the projection, rotary
    embedding and all-to-all for stage ``s + 1`` are issued on the next stream
    (intended to overlap with the attention of stage ``s``). One K/V chunk is
    shared by ``gqa_ratio`` consecutive query stages and is only projected when
    the first of those stages is prepared.

    Reads the module globals ``two_streams`` and ``ulysses_group`` directly;
    ``fully_fused_ring_flash_attn_forward`` additionally reads ``process_group``,
    ``attn_type``, ``alibi_slopes`` and ``window_size``. These must be set before
    the op is called (``UltraFusedAttnGQAFunc.forward`` does so).

    Args:
        x_in: Input of shape ``(bs, seqlen, hid_dim)``. It is normalised before
            the projections and also used, un-normalised, as the initial value
            of the returned output (the attention result is added onto it).
        w_rms: RMSNorm weight.
        eps_rms: RMSNorm epsilon.
        wq, wk, wv: Projection weights, applied as ``x @ w[:, cols]``.
        freqs_cis: Rotary frequencies handed to ``apply_rotary_emb``; moved to
            the device of ``x_in`` first.
        head_dim: Size of one attention head.
        dropout_p, softmax_scale, causal: Forwarded to the ring attention kernel.

    Returns:
        ``(final_out, final_lse)`` where ``final_out`` has shape
        ``(bs, seqlen, n_heads, head_dim)`` and ``final_lse`` is the per-stage
        log-sum-exp tensors concatenated along dim 1.
    """
    global two_streams, attn_type, alibi_slopes, window_size, ulysses_group

    def rms(x):
        # Use PyTorch functional to mirror intrinsic behavior and eps semantics.
        # The sequence axis is normalised in ceil(seqlen / hidden) shards rather
        # than in a single call.
        bs, seqlen, hidden = x.shape
        shards = max(1, (seqlen + hidden - 1) // hidden)

        # All shards are views of ``x``, so the weight only needs converting once.
        w = w_rms
        if w.device != x.device or w.dtype != x.dtype:
            w = w.to(dtype=x.dtype, device=x.device)

        x_shards = list(torch.chunk(x, chunks=shards, dim=1))
        with torch.no_grad():
            y_shards = [
                F.rms_norm(xs, [xs.shape[-1]], weight=w, eps=eps_rms)
                for xs in x_shards
            ]
        return torch.cat(y_shards, dim=1)

    x = rms(x_in)

    freqs_cis = freqs_cis.to(x.device)

    bs, seqlen, hid_dim = x.shape
    n_heads = hid_dim // head_dim
    n_kv_heads = wk.shape[1] // head_dim
    gqa_ratio = n_heads // n_kv_heads

    ulysses_degree = dist.get_world_size(ulysses_group)
    pipe_degree = n_heads // ulysses_degree

    assert n_kv_heads % ulysses_degree == 0, "n_kv_heads must be divisible by ulysses_degree"

    proj_dim = wq.shape[1]
    # Column width of one stage's slice of wq (and of one K/V chunk of wk / wv).
    chunk = proj_dim // pipe_degree
    n_kv_chunks = pipe_degree // gqa_ratio

    q_in = [None] * pipe_degree
    k_in = [None] * n_kv_chunks
    v_in = [None] * n_kv_chunks
    out, lse = [None] * pipe_degree, [None] * pipe_degree

    # a2a_events[s] is recorded once the all-to-all inputs of stage s are queued.
    a2a_events = [torch.cuda.Event() for _ in range(pipe_degree)]

    q_out = [None] * pipe_degree
    k_out = [None] * n_kv_chunks
    v_out = [None] * n_kv_chunks

    # Start from the (un-normalised) input so the attention output is accumulated onto it.
    final_out = x_in.clone().detach().reshape(bs, seqlen, n_heads, head_dim)

    final_lse = []

    # Index of the most recently prepared K/V chunk (-1: none yet).
    kv_idx = -1

    num_streams = len(two_streams)
    for i in range(num_streams):
        two_streams[i].wait_stream(torch.cuda.current_stream())

    for stage in range(pipe_degree):
        if stage == 0:
            # Nothing was prefetched for the first stage: prepare Q, K and V here.
            with torch.cuda.stream(two_streams[0]):
                q_in[stage] = x @ wq[:, stage*chunk:(stage+1)*chunk]
                k_in[stage] = x @ wk[:, stage*chunk:(stage+1)*chunk]
                v_in[stage] = x @ wv[:, stage*chunk:(stage+1)*chunk]

                q_in[stage], k_in[stage] = apply_rotary_emb(q_in[stage].view(bs, seqlen, -1, head_dim), k_in[stage].view(bs, seqlen, -1, head_dim), freqs_cis=freqs_cis)
                v_in[stage] = v_in[stage].view(bs, seqlen, -1, head_dim)

                q_out[stage] = all_to_all_4D(q_in[stage], 2, 1, False, False)
                k_out[stage] = all_to_all_4D(k_in[stage], 2, 1, False, False)
                v_out[stage] = all_to_all_4D(v_in[stage], 2, 1, False, False)

                kv_idx += 1

                a2a_events[stage].record()

                # Drop the pre-all-to-all projections; only *_out is needed from here on.
                q_in[stage] = None
                k_in[stage] = None
                v_in[stage] = None

        if stage != pipe_degree - 1:
            # Prefetch the inputs of the next stage on the next stream.
            with torch.cuda.stream(two_streams[(stage+1)%num_streams]):
                a2a_events[stage].wait()
                q_in[stage+1] = x @ wq[:, (stage+1)*chunk:(stage+2)*chunk]

                if (stage+1)//gqa_ratio > kv_idx:
                    # The next stage starts a new K/V group: project that chunk too.
                    # A single index is used for the whole branch so that the slot
                    # released at the end is the slot that was filled (previously
                    # the release ran after ``kv_idx`` was bumped and therefore
                    # targeted the following slot, which is out of range for the
                    # last K/V chunk).
                    kv_next = kv_idx + 1
                    k_in[kv_next] = x @ wk[:, kv_next*chunk:(kv_next+1)*chunk]
                    v_in[kv_next] = x @ wv[:, kv_next*chunk:(kv_next+1)*chunk]

                    q_in[stage+1], k_in[kv_next] = apply_rotary_emb(xq = q_in[stage+1].view(bs, seqlen, -1, head_dim), xk = k_in[kv_next].view(bs, seqlen, -1, head_dim), freqs_cis=freqs_cis)
                    v_in[kv_next] = v_in[kv_next].view(bs, seqlen, -1, head_dim)

                    q_out[stage+1] = all_to_all_4D(q_in[stage+1], 2, 1, False, False)
                    k_out[kv_next] = all_to_all_4D(k_in[kv_next], 2, 1, False, False)
                    v_out[kv_next] = all_to_all_4D(v_in[kv_next], 2, 1, False, False)

                    a2a_events[stage+1].record()

                    kv_idx = kv_next

                    q_in[stage+1] = None
                    k_in[kv_next] = None
                    v_in[kv_next] = None
                else:
                    # Same K/V group as the current stage: only Q is needed.
                    q_in[stage+1] = apply_rotary_emb(xq = q_in[stage+1].view(bs, seqlen, -1, head_dim), xk = None, freqs_cis=freqs_cis)

                    q_out[stage+1] = all_to_all_4D(q_in[stage+1], 2, 1, False, False)

                    a2a_events[stage+1].record()

                    # Release the projection here as well, as the K/V branch does.
                    q_in[stage+1] = None

        with torch.cuda.stream(two_streams[stage%num_streams]):
            a2a_events[stage].wait()
            out[stage], lse[stage] = fully_fused_ring_flash_attn_forward(
                    q_out[stage],
                    k_out[stage//gqa_ratio],
                    v_out[stage//gqa_ratio],
                    softmax_scale=softmax_scale,
                    dropout_p=dropout_p,
                    causal=causal,
                    deterministic=False,
            )

            final_lse.append(lse[stage])
            # Each stage owns a disjoint slice of ``ulysses_degree`` heads in the output.
            final_out[:, :, (stage*ulysses_degree):((stage+1)*ulysses_degree), :] += all_to_all_4D(out[stage], 1, 2, False, False)
            q_out[stage] = None
            # K/V can be released once the last query stage of their group is done.
            if (stage+1)//gqa_ratio != stage//gqa_ratio:
                k_out[stage//gqa_ratio] = None
                v_out[stage//gqa_ratio] = None

            out[stage] = None
            lse[stage] = None

    # Make the caller's stream wait for all work queued on the side streams.
    for i in range(num_streams):
        torch.cuda.current_stream().wait_stream(two_streams[i])

    final_lse = torch.cat(final_lse, dim=1)

    return final_out, final_lse

@ultra_fused_attn_gqa_forward.register_fake
def _(
        x: Tensor,
        w_rms: Tensor,
        eps_rms: float,
        wq: Tensor,
        wk: Tensor,
        wv: Tensor,
        freqs_cis: Tensor,
        head_dim: int,
        dropout_p: float = 0,
        softmax_scale: float = 0,
        causal: bool = True,
) -> Tuple[Tensor, Tensor]:
    """Fake implementation of ``ultra_fused_attn_gqa_forward``.

    Allocates uninitialised placeholders derived only from the shape, dtype
    and device of ``x``; no computation or communication is performed.

    Returns:
        ``(out, lse)`` with shapes ``(bs, sl, d // head_dim, head_dim)`` and
        ``(bs, d // head_dim, sl)``, where ``(bs, sl, d) = x.shape``.
    """
    batch, seq, hidden = x.shape
    heads = hidden // head_dim
    like_x = dict(dtype=x.dtype, device=x.device)

    fake_out = torch.empty([batch, seq, heads, head_dim], **like_x)
    fake_lse = torch.empty([batch, heads, seq], **like_x)
    return fake_out, fake_lse
    

class UltraFusedAttnGQAFunc(torch.autograd.Function):
    
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        w_rms: Tensor,
        eps_rms: float,
        wq: Tensor,
        wk: Tensor,
        wv: Tensor,
        freqs_cis: Tensor,
        head_dim: int,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        two_streams,
        attn_type,
        layer_id,
    ) -> Tuple[Tensor, Tensor, None] | Tensor | Dict[str, Any]:
        """Run the fused GQA attention op and stash what backward needs on ``ctx``.

        Before calling ``ultra_fused_attn_gqa_forward`` this publishes
        ``ring_group`` (as ``process_group``), ``attn_type``, ``alibi_slopes``,
        ``window_size``, ``ulysses_group`` and ``two_streams`` as attributes of
        this module, because the op reads them as module globals.

        ``softmax_scale`` defaults to ``head_dim ** -0.5`` when ``None``.
        ``offload_stream`` and ``fetch_stream`` are accepted but not used by
        this method. Both op outputs are checked with ``check_nan_inf``.

        Returns:
            ``final_out`` alone, or ``(final_out, final_lse, None)`` when
            ``return_softmax`` is truthy.
        """
        # Shape bookkeeping; the unpack also requires ``x`` to be 3-D.
        _bs, _seqlen, hid_dim = x.shape
        n_heads = hid_dim // head_dim

        _proj_dim = wq.shape[1]

        ulysses_degree = dist.get_world_size(ulysses_group)
        pipe_degree = n_heads // ulysses_degree

        if softmax_scale is None:
            softmax_scale = head_dim ** (-0.5)

        # Publish the per-call configuration as module-level globals.
        global process_group
        import sys
        this_module = sys.modules[__name__]
        shared_state = (
            ("process_group", ring_group),
            ("attn_type", attn_type),
            ("alibi_slopes", alibi_slopes),
            ("window_size", window_size),
            ("ulysses_group", ulysses_group),
            ("two_streams", two_streams),
        )
        for attr_name, attr_value in shared_state:
            setattr(this_module, attr_name, attr_value)

        final_out, final_lse = ultra_fused_attn_gqa_forward(
            x,
            w_rms,
            eps_rms,
            wq,
            wk,
            wv,
            freqs_cis,
            head_dim,
            dropout_p,
            softmax_scale,
            causal,
        )

        # Sanity checks on the op outputs.
        assert not check_nan_inf(
            final_out, f"final_out {layer_id}", torch.distributed.get_rank()
        ), f"Layer {layer_id} NaN detected in final_out"
        assert not check_nan_inf(
            final_lse, f"final_lse {layer_id}", torch.distributed.get_rank()
        ), f"Layer {layer_id} NaN detected in final_lse"

        # Tensors go through save_for_backward; everything else is a plain attribute.
        ctx.save_for_backward(x, w_rms, wq, wk, wv, freqs_cis, final_out, final_lse)

        # Attention hyper-parameters.
        ctx.eps_rms = eps_rms
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic

        # Parallelism / runtime configuration.
        ctx.ring_group = ring_group
        ctx.ulysses_group = ulysses_group
        ctx.attn_type = attn_type
        ctx.two_streams = two_streams
        ctx.pipe_degree = pipe_degree
        ctx.layer_id = layer_id

        if return_softmax:
            return (final_out, final_lse, None)
        return final_out
        
    
    @staticmethod
    # @torch.no_grad()
    def backward(ctx, dout, *args) -> Tuple[Tensor, Tensor, None, Tensor, Tensor, Tensor, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]:

        fn = select_flash_attn_impl(ctx.attn_type, stage="bwd-only")

        saved_tensors = ctx.saved_tensors
        x, w_rms, wq, wk, wv, freqs_cis, final_out, final_lse = saved_tensors
        eps_rms = ctx.eps_rms

        freqs_cis = freqs_cis.to(x.device)

        layer_id = ctx.layer_id

        assert not check_nan_inf(final_out, "final_out", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in final_out"
        assert not check_nan_inf(final_lse, "final_lse", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in final_lse"
        assert not check_nan_inf(x, "x", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in x"
        assert not check_nan_inf(wq, "wq", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in wq"
        assert not check_nan_inf(wk, "wk", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in wk"
        assert not check_nan_inf(wv, "wv", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in wv"
        assert not check_nan_inf(freqs_cis, "freqs_cis", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in freqs_cis"
        assert not check_nan_inf(dout, f"dout {layer_id}", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in dout"
        
        bs, seqlen, n_heads, head_dim = final_out.shape
        ulysses_degree = dist.get_world_size(ctx.ulysses_group)
        pipe_degree = n_heads // ulysses_degree
        n_kv_heads = wk.shape[1] // head_dim
        gqa_ratio = n_heads // n_kv_heads

        dx = torch.zeros_like(x)
        dwq = []
        dwk = []
        dwv = []

        proj_dim = wq.shape[1]

        two_streams = ctx.two_streams

        q_in, k_in, v_in = [None] * pipe_degree, [None] * (pipe_degree//gqa_ratio), [None] * (pipe_degree//gqa_ratio)
        attn_dq, attn_dk, attn_dv = [None] * pipe_degree, [None] * pipe_degree, [None] * pipe_degree
        # attn_dq = [torch.empty([bs, seqlen*ulysses_degree, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        # attn_dk = [torch.empty([bs, seqlen*ulysses_degree, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        # attn_dv = [torch.empty([bs, seqlen*ulysses_degree, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        dq_out, dk_out, dv_out = [None] * pipe_degree, [None] * (pipe_degree//gqa_ratio), [None] * (pipe_degree//gqa_ratio)

        a2a_events = [torch.cuda.Event() for _ in range(pipe_degree)]

        q_out = [None for _ in range(pipe_degree)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        k_out = [None for _ in range(pipe_degree//gqa_ratio)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        v_out = [None for _ in range(pipe_degree//gqa_ratio)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        out_out = [None for _ in range(pipe_degree)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        dout_out = [None for _ in range(pipe_degree)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=dout.dtype) for _ in range(pipe_degree)]

        final_out = list(torch.chunk(final_out, pipe_degree, dim = 2))
        final_lse = list(torch.chunk(final_lse, pipe_degree, dim = 1))
        dout = list(torch.chunk(dout, pipe_degree, dim = 2))

        num_streams = len(two_streams)

        for i in range(len(two_streams)):
            two_streams[i].wait_stream(torch.cuda.current_stream())

        for stage in range(pipe_degree):
            if stage == 0:
                with torch.cuda.stream(two_streams[0]):
                    q_in[stage] = x @ wq[:, stage*(proj_dim//pipe_degree):(stage+1)*(proj_dim//pipe_degree)]
                    k_in[stage] = x @ wk[:, stage*(proj_dim//pipe_degree):(stage+1)*(proj_dim//pipe_degree)]
                    v_in[stage] = x @ wv[:, stage*(proj_dim//pipe_degree):(stage+1)*(proj_dim//pipe_degree)]

                    q_in[stage], k_in[stage] = apply_rotary_emb(q_in[stage].view(bs, seqlen, -1, head_dim), k_in[stage].view(bs, seqlen, -1, head_dim), freqs_cis=freqs_cis)
                    v_in[stage] = v_in[stage].view(bs, seqlen, -1, head_dim)

                    q_out[stage] = all_to_all_4D(q_in[stage], 2, 1, False, False)#, output=q_out[stage])
                    k_out[stage] = all_to_all_4D(k_in[stage], 2, 1, False, False)#, output=k_out[stage])
                    v_out[stage] = all_to_all_4D(v_in[stage], 2, 1, False, False)#, output=v_out[stage])

                    out_out[stage] = all_to_all_4D(final_out[stage], 2, 1, False, False)#, output=out_out[stage])
                    dout_out[stage] = all_to_all_4D(dout[stage], 2, 1, False, False)#, output=dout_out[stage])

                    q_in[stage] = None
                    k_in[stage] = None
                    v_in[stage] = None

                    final_out[stage] = None
                    # dout[stage] = None

                    a2a_events[stage].record()

            if stage != pipe_degree - 1:
                with torch.cuda.stream(two_streams[(stage+1)%num_streams]):
                    a2a_events[stage].wait(two_streams[(stage+1)%num_streams])
                    q_in[stage+1] = x @ wq[:, (stage+1)*(proj_dim//pipe_degree):(stage+2)*(proj_dim//pipe_degree)]
                    if (stage+1)//gqa_ratio > stage//gqa_ratio:
                        k_in[(stage+1)//gqa_ratio] = x @ wk[:, ((stage+1)//gqa_ratio)*(proj_dim//pipe_degree):(((stage+1)//gqa_ratio)+1)*(proj_dim//pipe_degree)]
                        v_in[(stage+1)//gqa_ratio] = x @ wv[:, ((stage+1)//gqa_ratio)*(proj_dim//pipe_degree):(((stage+1)//gqa_ratio)+1)*(proj_dim//pipe_degree)]

                        q_in[stage+1], k_in[(stage+1)//gqa_ratio] = apply_rotary_emb(xq = q_in[stage+1].view(bs, seqlen, -1, head_dim), xk = k_in[(stage+1)//gqa_ratio].view(bs, seqlen, -1, head_dim), freqs_cis=freqs_cis)
                        v_in[(stage+1)//gqa_ratio] = v_in[(stage+1)//gqa_ratio].view(bs, seqlen, -1, head_dim)

                        q_out[stage+1] = all_to_all_4D(q_in[stage+1], 2, 1, False, False)#, output=q_out[stage+1])
                        k_out[(stage+1)//gqa_ratio] = all_to_all_4D(k_in[(stage+1)//gqa_ratio], 2, 1, False, False)#, output=k_out[stage+1])
                        v_out[(stage+1)//gqa_ratio] = all_to_all_4D(v_in[(stage+1)//gqa_ratio], 2, 1, False, False)#, output=v_out[stage+1])

                        out_out[stage+1] = all_to_all_4D(final_out[stage+1], 2, 1, False, False)#, output=out_out[stage+1])
                        dout_out[stage+1] = all_to_all_4D(dout[stage+1], 2, 1, False, False)#, output=dout_out[stage+1])

                        a2a_events[stage+1].record()

                        q_in[stage+1] = None
                        k_in[(stage+1)//gqa_ratio] = None
                        v_in[(stage+1)//gqa_ratio] = None

                        final_out[stage+1] = None
                        # dout[stage+1] = None
                    else:
                        q_in[stage+1] = apply_rotary_emb(xq = q_in[stage+1].view(bs, seqlen, -1, head_dim), xk = None, freqs_cis=freqs_cis)
                        
                        q_out[stage+1] = all_to_all_4D(q_in[stage+1], 2, 1, False, False)#, output=q_out[stage+1])
                        out_out[stage+1] = all_to_all_4D(final_out[stage+1], 2, 1, False, False)#, output=out_out[stage+1])
                        dout_out[stage+1] = all_to_all_4D(dout[stage+1], 2, 1, False, False)#, output=dout_out[stage+1])

                        a2a_events[stage+1].record()

                        q_in[stage+1] = None
                        final_out[stage+1] = None
                        # dout[stage+1] = None

            
            with torch.cuda.stream(two_streams[stage%num_streams]):
                a2a_events[stage].wait(two_streams[stage%num_streams])
                # if int(layer_id) == 30 and torch.distributed.get_rank() == 0:
                #     breakpoint()
                # torch.distributed.barrier()
                attn_dq[stage], attn_dk[stage], attn_dv[stage] = fully_fused_ring_flash_attn_backward(
                    ctx.ring_group,
                    dout_out[stage],
                    q_out[stage],
                    k_out[stage//gqa_ratio],
                    v_out[stage//gqa_ratio],
                    out_out[stage],
                    final_lse[stage],
                    softmax_scale=ctx.softmax_scale,
                    dropout_p=ctx.dropout_p,
                    causal=ctx.causal,
                    window_size=ctx.window_size,
                    softcap=ctx.softcap,
                    alibi_slopes=ctx.alibi_slopes,
                    deterministic=ctx.deterministic,
                    attn_type=ctx.attn_type,
                )

                assert not check_nan_inf(attn_dq[stage], f"attn_dq {layer_id}", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in attn_dq in stage {stage}"
                assert not check_nan_inf(attn_dk[stage], f"attn_dk {layer_id}", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in attn_dk in stage {stage}"
                assert not check_nan_inf(attn_dv[stage], f"attn_dv {layer_id}", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in attn_dv in stage {stage}"

                # assert attn_dq[stage].shape == q_out[stage].shape, f"attn_dq[stage].shape: {attn_dq[stage].shape}, q_out[stage].shape: {q_out[stage].shape}"

                # fn(
                #     dout_out[stage],
                #     q_out[stage],
                #     k_out[stage],
                #     v_out[stage],
                #     out_out[stage],
                #     final_lse[stage],
                #     # dq_buffer[:, :seqlen_q],
                #     # dk_buffer[:, :seqlen_kv],
                #     # dv_buffer[:, :seqlen_kv],
                #     attn_dq[stage],
                #     attn_dk[stage],
                #     attn_dv[stage],
                #     ctx.dropout_p,
                #     ctx.softmax_scale,
                #     ctx.causal,
                #     ctx.window_size,
                #     ctx.softcap,    
                #     ctx.alibi_slopes,
                #     ctx.deterministic,
                #     rng_state=None,
                # )

                dq_out[stage] = all_to_all_4D(attn_dq[stage], 1, 2, False, False).view(bs, seqlen, -1)
                if dk_out[stage//gqa_ratio] is None:
                    dk_out[stage//gqa_ratio] = all_to_all_4D(attn_dk[stage], 1, 2, False, False).view(bs, seqlen, -1)
                    dv_out[stage//gqa_ratio] = all_to_all_4D(attn_dv[stage], 1, 2, False, False).view(bs, seqlen, -1)
                else:
                    dk_out[stage//gqa_ratio] += all_to_all_4D(attn_dk[stage], 1, 2, False, False).view(bs, seqlen, -1)
                    dv_out[stage//gqa_ratio] += all_to_all_4D(attn_dv[stage], 1, 2, False, False).view(bs, seqlen, -1)


                q_out[stage] = None
                if (stage+1)//gqa_ratio != stage//gqa_ratio:
                    k_out[stage//gqa_ratio] = None
                    v_out[stage//gqa_ratio] = None
                out_out[stage] = None
                dout_out[stage] = None
                final_lse[stage] = None

                attn_dq[stage] = None
                attn_dk[stage] = None
                attn_dv[stage] = None
                # torch.cuda.empty_cache()

                dx_q = dq_out[stage] @ wq[:, stage*(proj_dim//pipe_degree):(stage+1)*(proj_dim//pipe_degree)].T
                dx_k = dk_out[stage//gqa_ratio] @ wk[:, (stage//gqa_ratio)*(proj_dim//pipe_degree):((stage//gqa_ratio)+1)*(proj_dim//pipe_degree)].T
                dx_v = dv_out[stage//gqa_ratio] @ wv[:, (stage//gqa_ratio)*(proj_dim//pipe_degree):((stage//gqa_ratio)+1)*(proj_dim//pipe_degree)].T

                dx += dx_q + dx_k + dx_v

                # if torch.distributed.get_rank() == 0:
                #     breakpoint()
                # torch.distributed.barrier()

                dwq.append((x.transpose(-1, -2) @ dq_out[stage]).sum(dim = 0))
                if (stage+1)//gqa_ratio != stage//gqa_ratio or stage == pipe_degree-1:
                    dwk.append((x.transpose(-1, -2) @ dk_out[stage//gqa_ratio]).sum(dim = 0))
                    dwv.append((x.transpose(-1, -2) @ dv_out[stage//gqa_ratio]).sum(dim = 0))

                dq_out[stage] = None
                if (stage+1)//gqa_ratio != stage//gqa_ratio:
                    dk_out[stage//gqa_ratio] = None
                    dv_out[stage//gqa_ratio] = None

                # torch.cuda.empty_cache()
        
        for i in range(len(two_streams)):
            torch.cuda.current_stream().wait_stream(two_streams[i])
        # torch.cuda.empty_cache()

        dwq = torch.cat(dwq, dim = 1)
        dwk = torch.cat(dwk, dim = 1)
        dwv = torch.cat(dwv, dim = 1)

        assert not check_nan_inf(dwq, "dwq", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in dwq"
        assert not check_nan_inf(dwk, "dwk", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in dwk"
        assert not check_nan_inf(dwv, "dwv", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in dwv"
        assert not check_nan_inf(dx, f"first dx {layer_id}", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in dx"

        
        x_requires_grad = x.requires_grad
        # x = x.detach()
        # x.requires_grad_(x_requires_grad)

        grad_out = dx

        bs, seqlen, hidden = x.shape
        shards = max(1, (seqlen + hidden - 1) // hidden)

        dout = torch.cat(dout, dim = 2).reshape(bs, seqlen, -1)

        # Flatten (bs, seqlen) to avoid stride issues when narrowing
        x_flat = x.view(-1, hidden)
        g_flat = grad_out.view(-1, hidden)
        x_grad_flat = torch.zeros_like(x_flat)

        x_shards = list(torch.chunk(x_flat, chunks=shards, dim=0))
        current_offset = 0
        for i, x_shard in enumerate(x_shards):
            # If using ZeRO, coordinate gradient readiness flag as in TiledMLP
            if hasattr(w_rms, "ds_grad_is_ready"):  
                if i + 1 < shards:
                    w_rms.ds_grad_is_ready = False
                else:
                    w_rms.ds_grad_is_ready = True

            shard_step = x_shard.shape[0]
            shard_offset = current_offset

            x_shard.requires_grad_(x_requires_grad)
            # Route autograd to write directly into the appropriate x_grad slice
            x_shard.grad = x_grad_flat.narrow(0, shard_offset, shard_step).view_as(x_shard)
            incoming_grad_shard = g_flat.narrow(0, shard_offset, shard_step).view_as(x_shard)

            with torch.enable_grad():
                w = w_rms
                if w.device != x_shard.device or w.dtype != x_shard.dtype:
                    # cast a view for compute but keep parameter for grad accumulation
                    w = w.to(dtype=x_shard.dtype, device=x_shard.device)
                y = F.rms_norm(x_shard, [hidden], weight=w, eps=eps_rms)
            torch.autograd.backward(y, incoming_grad_shard)

            current_offset += shard_step

        # Unflatten
        x_grad = (x_grad_flat.view(bs, -1, hidden) + dout) if x_requires_grad else None
        # #perform backward for RMSNorm
        # num_chunks = x.shape[1] // x.shape[2]
        # grad_out_chunks = list(torch.chunk(dx, num_chunks, dim=1))
        # x_chunks = torch.chunk(x, num_chunks, dim=1)
        # grad_weight = torch.zeros_like(w_rms, device = w_rms.device, dtype = torch.float32)
        # # if torch.distributed.get_rank() == 0:
        # #     breakpoint()
        # # torch.distributed.barrier()
        # dout = torch.chunk(torch.cat(dout, dim = 2).reshape(bs, seqlen, -1), num_chunks, dim = 1)
        # """
        # grad_norm_x = grad_out * weight
        # rms = x / norm_x
        # dot = (x * grad_norm_x).sum(dim=-1, keepdim=True) / dim
        # grad_x = (grad_norm_x / rms) - (x * dot / (rms**3))
        # """
        # use_autograd_rmsnorm = os.getenv("YUNCHANG_RMSNORM_AUTOGRAD", "0") == "1"
        # for i in range(num_chunks):
        #     # grad_x
        #     g_out_i = grad_out_chunks[i]
        #     if use_autograd_rmsnorm:
        #         x_i = x_chunks[i].requires_grad_(True)
        #         with torch.enable_grad():
        #             y_i = F.rms_norm(x_i, [x_i.shape[-1]], weight=w_rms, eps=ctx.eps_rms)
        #             dx_i, dw_i = torch.autograd.grad(y_i, [x_i, w_rms], g_out_i)#, retain_graph=True, allow_unused=True)
        #         grad_out_chunks[i] = dx_i + dout[i]
        #         grad_weight += dw_i.to(grad_weight.dtype)
        #     else:
        #         grad_norm_x_c = g_out_i * w_rms
        #         rms_c = (x_chunks[i].pow(2).mean(dim=-1, keepdim=True) + ctx.eps_rms).sqrt()
        #         dot_c = (x_chunks[i] * grad_norm_x_c).sum(dim=-1, keepdim=True) / x.shape[2]
        #         grad_x_c = (grad_norm_x_c / rms_c) - (x_chunks[i] * dot_c / (rms_c**3))
        #         grad_out_chunks[i] = grad_x_c + dout[i]

        #         # grad_weight (use upstream grad)
        #         grad_weight += (g_out_i * (x_chunks[i] / rms_c)).sum(dim=tuple(range(grad_x_c.dim()-1)))
        # dx = torch.cat(grad_out_chunks, dim=1)

        # assert not check_nan_inf(dx, f"second dx {layer_id}", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in dx"
        # assert not check_nan_inf(grad_weight, "grad_weight", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in grad_weight"

        return x_grad, None, None, dwq, dwk, dwv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


def fully_pipelined_long_context_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    ring_group=None,
    ulysses_group=None,
    offload_stream=None,
    fetch_stream=None,
    two_streams=None,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
    use_pack_qkv=True,
    dualstage=False,
):
    """Functional entry point for ``FullyPipelinedRingFunc``.

    Every argument from ``q`` through ``attn_type`` is handed, positionally
    and in declaration order, to ``FullyPipelinedRingFunc.apply``; whatever
    that call produces is returned unchanged.

    ``attn_processor``, ``use_pack_qkv`` and ``dualstage`` are part of the
    signature but are not read by this function.
    """
    # autograd Function.apply is positional-only, so the order here must
    # match the order of FullyPipelinedRingFunc.forward's parameters.
    forwarded = (
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        two_streams,
        attn_type,
    )
    return FullyPipelinedRingFunc.apply(*forwarded)

def fully_fused_attn_func(
        x,
        wq,
        wk,
        wv,
        freqs_cis,
        head_dim,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
        window_size=(-1, -1),
        softcap=0.0,
        alibi_slopes=None,
        deterministic=False,
        return_attn_probs=False,
        ring_group=None,
        ulysses_group=None,
        offload_stream=None,
        fetch_stream=None,
        two_streams=None,
        attn_type: AttnType = AttnType.FA3,
        attn_processor=None,
        use_pack_qkv=True,
        dualstage=False,
        layer_id=None,
        fused_attn_type="fully_fused",
        w_rms=None,
        eps_rms=None,
):
    """Dispatch to one of the fused attention autograd Functions.

    The variant is chosen by substring match on ``fused_attn_type``,
    checked in this order:

    * contains ``"mha"``          -> ``FullyFusedAttnFunc``
    * contains ``"ultra_fused"``  -> ``UltraFusedAttnGQAFunc``
    * anything else               -> ``FullyFusedAttnGQAFunc``

    ``w_rms`` and ``eps_rms`` are forwarded only to ``UltraFusedAttnGQAFunc``
    (placed right after ``x``); the other two variants never see them.
    ``attn_processor``, ``use_pack_qkv`` and ``dualstage`` are accepted but
    not read by this function.

    Returns whatever the selected Function's ``apply`` returns.
    """
    # Positional tail common to all three variants; Function.apply takes
    # positional arguments only, so this order must be preserved.
    common_tail = (
        wq,
        wk,
        wv,
        freqs_cis,
        head_dim,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        two_streams,
        attn_type,
        layer_id,
    )

    # "mha" is tested first, so it wins even if "ultra_fused" is also present.
    if "mha" in fused_attn_type:
        return FullyFusedAttnFunc.apply(x, *common_tail)

    if "ultra_fused" in fused_attn_type:
        # This variant additionally receives the RMSNorm weight and epsilon.
        return UltraFusedAttnGQAFunc.apply(x, w_rms, eps_rms, *common_tail)

    return FullyFusedAttnGQAFunc.apply(x, *common_tail)
