import torch
from .utils import RingComm, update_out_and_lse
from yunchang.kernels import AttnType, select_flash_attn_impl
from yunchang.globals import U_HANDLE, clear_u_handle, clear_o_handle, PROCESS_GROUP
from yunchang.comm.all_to_all import vanilla_all_to_all_4D as all_to_all_4D
import logging
logger = logging.getLogger(__name__)

from typing import List, Tuple


global process_group, attn_type, alibi_slopes, window_size, next_ulysses_qkv, a2a_available

import torch.distributed as dist

def check_nan_inf(tensor, name, rank=0):
    """Report whether ``tensor`` contains any NaN or Inf entries.

    NaN is checked first; Inf is only checked when no NaN was found. When a
    problem is found, two diagnostic lines are printed (the tensor shape and
    the number of offending elements), prefixed with ``rank``.

    Args:
        tensor: Tensor to inspect.
        name: Label used in the printed diagnostics.
        rank: Rank number shown in the diagnostics prefix.

    Returns:
        True if a NaN or an Inf was found, False otherwise.
    """
    nan_mask = torch.isnan(tensor)
    if nan_mask.any():
        print(f"[RANK {rank}] NaN detected in {name}, shape: {tensor.shape}")
        print(f"[RANK {rank}] NaN locations: {nan_mask.sum().item()}")
        return True

    inf_mask = torch.isinf(tensor)
    if not inf_mask.any():
        return False

    print(f"[RANK {rank}] Inf detected in {name}, shape: {tensor.shape}")
    print(f"[RANK {rank}] Inf locations: {inf_mask.sum().item()}")
    return True

@torch.library.custom_op("yunchang::_upipe_ring_flash_attn_forward", mutates_args=(), device_types="cuda")
def upipe_ring_flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: float,
    dropout_p: float = 0,
    causal: bool = True,
    softcap: float = 0.0,
    deterministic: bool=False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Causal ring attention forward pass, registered as a torch custom op.

    K/V blocks are passed around the ring described by the module-level
    ``process_group``; each step's partial attention result is merged into the
    running ``out``/``lse`` with ``update_out_and_lse``.

    The ring group, kernel type, ALiBi slopes and window size are NOT
    parameters: they are read from the module-level globals ``process_group``,
    ``attn_type``, ``alibi_slopes`` and ``window_size``, which must be set
    before this op is called (``UpipeRingFlashAttnFuncQKVPacked.forward`` does
    so).

    Args:
        q, k, v: Local query/key/value tensors. ``q.shape[1]`` is treated as
            the sequence dimension and is split into two halves.
        softmax_scale: Scale forwarded to the attention kernel.
        dropout_p: Dropout probability forwarded to the kernel.
        causal: Must be True (asserted).
        softcap: Soft-cap value forwarded to the kernel.
        deterministic: Accepted for interface parity; not used in this body.

    Returns:
        ``(out, lse)`` where ``out`` is cast to ``q.dtype`` and ``lse`` is the
        merged log-sum-exp with its last dim squeezed and dims 1 and 2 swapped.
    """
    
    # Ring configuration is carried by module globals rather than arguments.
    global process_group, attn_type, alibi_slopes, window_size
    
    assert causal == True, "zigzag ring is meaningless for causal=False"
    comm = RingComm(process_group)

    # q1 is the second half of the local query along the sequence dimension.
    block_seq_len = q.shape[1] // 2
    q1 = q[:, block_seq_len:]

    out = None
    lse = None
    next_k, next_v = None, None
    
    def forward(q, k, v, causal):
        # Run one attention block with the kernel selected by ``attn_type``.
        fn = select_flash_attn_impl(attn_type, stage="fwd-only")
        block_out, block_lse = fn(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=True and dropout_p > 0,
        )
        return block_out, block_lse

    for step in range(comm.world_size):
        # Post the exchange of the current K/V before computing, except on the
        # last step; it is waited on at the bottom of the loop body.
        if step + 1 != comm.world_size:
            next_k: torch.Tensor = comm.send_recv(k)
            next_v: torch.Tensor = comm.send_recv(v)
            comm.commit()

        if step == 0:
            # Local block: full q against local k/v with the causal mask.
            block_out, block_lse = forward(q, k, v, causal=True)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        elif step <= comm.rank:
            # Full q against only the first half of the received k/v, unmasked.
            k0 = k[:, :block_seq_len]
            v0 = v[:, :block_seq_len]
            block_out, block_lse = forward(q, k0, v0, causal=False)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        else:
            # Only the second half of q attends to the full received k/v; the
            # result is merged into the matching half of out/lse via ``slice_``.
            block_out, block_lse = forward(q1, k, v, causal=False)
            out, lse = update_out_and_lse(
                out,
                lse,
                block_out,
                block_lse,
                slice_=(slice(None), slice(block_seq_len, None)),
            )

        if step + 1 != comm.world_size:
            # Finish the exchange and switch to the received K/V for the next step.
            comm.wait()
            k = next_k
            v = next_v

    out = out.to(q.dtype)
    # Drop the trailing singleton dim and swap dims 1 and 2 of the merged lse.
    lse = lse.squeeze(dim=-1).transpose(1, 2)
    return out, lse

@upipe_ring_flash_attn_forward.register_fake
def _(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: float,
    dropout_p: float = 0,
    causal: bool = True,
    softcap: float = 0.0,
    deterministic: bool=False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fake (shape-only) implementation of ``upipe_ring_flash_attn_forward``.

    Returns uninitialised tensors carrying the metadata of the real op's
    outputs so the op can be traced without running the kernel.

    The real op returns ``lse.squeeze(dim=-1).transpose(1, 2)``, and callers in
    this file index the result as ``(batch, heads, seq)``: the per-stage write
    ``softmax_lse[:, stage, :] = lse`` and the sequence split
    ``softmax_lse.chunk(2, dim=2)`` in the backward pass. The fake LSE is
    therefore laid out as ``[bs, nh, sl]``.

    Returns:
        ``(out, lse)`` with ``out`` shaped like ``q`` and ``lse`` shaped
        ``[bs, nh, sl]``.
    """
    bs, sl, nh, _ = q.shape
    out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    # NOTE(review): the real op does not cast lse to q.dtype before returning
    # it; if update_out_and_lse keeps lse in float32 this dtype should be
    # torch.float32 as well -- confirm against .utils.update_out_and_lse.
    lse = torch.empty([bs, nh, sl], dtype=q.dtype, device=q.device)
    return out, lse

# @torch.library.custom_op("yunchang::_zigzag_ring_flash_attn_forward_op", mutates_args=(), device_types="cuda")
# def zigzag_ring_flash_attn_forward_op(out: torch.Tensor, softmax_lse: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#     return out.clone().detach().requires_grad_(True), softmax_lse.clone().detach().requires_grad_(True)

# @zigzag_ring_flash_attn_forward_op.register_fake
# def _(out: torch.Tensor, softmax_lse: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#     return out.clone().detach().requires_grad_(True), softmax_lse.clone().detach().requires_grad_(True)

def upipe_ring_flash_attn_backward(
    process_group,
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = AttnType.FA,
):
    """Causal ring attention backward pass.

    Two ring communicators are used over ``process_group``: ``kv_comm`` passes
    the K/V blocks around, and ``d_kv_comm`` passes the partially accumulated
    ``dk``/``dv`` around so each gradient block picks up a contribution at
    every step.

    Args:
        process_group: Process group forming the ring.
        dout: Gradient w.r.t. the attention output.
        q, k, v: Local query/key/value tensors; dim 1 is split in halves.
        out: Forward attention output.
        softmax_lse: Forward log-sum-exp; dim 2 is split in halves, matching
            the halves taken from dim 1 of ``q``.
        softmax_scale, dropout_p, window_size, softcap, alibi_slopes,
        deterministic: Forwarded unchanged to the backward kernel.
        causal: Must be True (asserted).
        attn_type: Selects the kernel via ``select_flash_attn_impl``.

    Returns:
        ``(dq, dk, dv)``. With a single-rank group the raw kernel output
        buffers are returned directly; otherwise the float32 accumulators are
        cast to ``q.dtype``.
    """
    assert causal == True, "upipe ring is meaningless for causal=False"
    kv_comm = RingComm(process_group)
    d_kv_comm = RingComm(process_group)
    dq, dk, dv = None, None, None
    next_dk, next_dv = None, None
    next_k, next_v = None, None
    dk_comm_buffer, dv_comm_buffer = None, None

    # Second-half views (sequence dim) used when only the second half of q
    # takes part in a step.
    dout1 = dout.chunk(2, dim=1)[1]
    q1 = q.chunk(2, dim=1)[1]
    out1 = out.chunk(2, dim=1)[1]
    softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous()
    block_seq_len = q.shape[1] // 2

    # Allocate the gradient scratch buffers once up front; re-allocating them
    # on every step may be slow.
    # if k.shape[2] == q.shape[2]:
    dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
    dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
    # else:
    #     dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    #     dk_buffer = torch.empty((k.shape[0], k.shape[1], q.shape[2], k.shape[3]), dtype=k.dtype, device=k.device)
    #     dv_buffer = torch.empty((v.shape[0], v.shape[1], q.shape[2], v.shape[3]), dtype=v.dtype, device=v.device)

    def backward(dout, q, k, v, out, softmax_lse, causal):
        # Run the backward kernel for one block. The leading ``seqlen`` slices
        # of the scratch buffers are handed to the kernel (presumably written
        # in place -- the results are read back from the buffers afterwards).
        seqlen_q = q.shape[1]
        seqlen_kv = k.shape[1]
        fn = select_flash_attn_impl(attn_type, stage="bwd-only")

        fn(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq_buffer[:, :seqlen_q],
            dk_buffer[:, :seqlen_kv],
            dv_buffer[:, :seqlen_kv],
            dropout_p,
            softmax_scale,
            causal,
            window_size,
            softcap,
            alibi_slopes,
            deterministic,
            rng_state=None,
        )

    for step in range(kv_comm.world_size):
        # assert not check_nan_inf(dout, f"dout -{step}", dist.get_rank()), f"Pipe: dout is nan or inf at step {step}"
        # assert not check_nan_inf(q, f"q -{step}", dist.get_rank()), f"Pipe: q is nan or inf at step {step}"
        # assert not check_nan_inf(k, f"k -{step}", dist.get_rank()), f"Pipe: k is nan or inf at step {step}"
        # assert not check_nan_inf(v, f"v -{step}", dist.get_rank()), f"Pipe: v is nan or inf at step {step}"
        # assert not check_nan_inf(out, f"out -{step}", dist.get_rank()), f"Pipe: out is nan or inf at step {step}"
        # assert not check_nan_inf(softmax_lse, f"softmax_lse -{step}", dist.get_rank()), f"Pipe: softmax_lse is nan or inf at step {step}"

        # Post the exchange of the current K/V (except on the last step); it is
        # waited on near the bottom of the loop body.
        if step + 1 != kv_comm.world_size:
            next_k = kv_comm.send_recv(k)
            next_v = kv_comm.send_recv(v)
            kv_comm.commit()

        if step == 0:
            # Local block with the causal mask.
            backward(dout.contiguous(), q.contiguous(), k.contiguous(), v.contiguous(), out.contiguous(), softmax_lse.contiguous(), causal=True)
            # Single-rank group: nothing to exchange, return the kernel buffers as-is.
            if torch.distributed.get_world_size(process_group) == 1:
                return dq_buffer, dk_buffer, dv_buffer
            # Accumulate across steps in float32.
            dq = dq_buffer.to(torch.float32)
            dk = dk_buffer.to(torch.float32)
            dv = dv_buffer.to(torch.float32)
            # assert not check_nan_inf(dq, f"dq -{step}", dist.get_rank()), f"Pipe: dq is nan or inf at step {step}"
            # assert not check_nan_inf(dk, f"dk -{step}", dist.get_rank()), f"Pipe: dk is nan or inf at step {step}"
            # assert not check_nan_inf(dv, f"dv -{step}", dist.get_rank()), f"Pipe: dv is nan or inf at step {step}"
        else:
            if step <= kv_comm.rank:
                # Full q against the first half of the received k/v.
                k0 = k[:, :block_seq_len]
                v0 = v[:, :block_seq_len]
                backward(dout, q, k0, v0, out, softmax_lse, causal=False)
                dq += dq_buffer
                # assert not check_nan_inf(dq, f"dq 251 -{step}", dist.get_rank()), f"Pipe: dq is nan or inf at step {step}"
            else:
                # Second half of q against the full received k/v.
                backward(dout1, q1, k, v, out1, softmax_lse1, causal=False)
                # The kernel always writes from the start of dq_buffer, so the
                # first half of the buffer holds the gradient for q's second half.
                dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len]
                # assert not check_nan_inf(dq, f"dq 256 -{step}", dist.get_rank()), f"Pipe: dq is nan or inf at step {step}"

            # Finish the dk/dv exchange posted at the end of the previous step.
            # The old accumulators become the receive buffers for the next
            # exchange, and the received tensors become the accumulators.
            d_kv_comm.wait()
            dk_comm_buffer, dv_comm_buffer = dk, dv
            dk, dv = next_dk, next_dv

            if step <= kv_comm.rank:
                # Only the first half of k/v took part in this step.
                dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len]
                dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len]
            else:
                dk += dk_buffer
                dv += dv_buffer
                # assert not check_nan_inf(dk, f"dk 268 -{step}", dist.get_rank()), f"Pipe: dk is nan or inf at step {step}"
                # assert not check_nan_inf(dv, f"dv 269 -{step}", dist.get_rank()), f"Pipe: dv is nan or inf at step {step}"

        if step + 1 != kv_comm.world_size:
            # Finish the K/V exchange and switch to the received blocks.
            kv_comm.wait()
            k = next_k
            v = next_v

        # Pass the accumulated dk/dv along the ring (on every step, including the last).
        next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer)
        next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer)
        d_kv_comm.commit()

    # Complete the final dk/dv exchange; next_dk/next_dv are the values returned.
    d_kv_comm.wait()

    # if k.shape[2] != q.shape[2]:
    #     bs, slen, nqh, hdim = q.shape
    #     return dq.to(q.dtype), next_dk.view(bs, slen, k.shape[2], (q.shape[2]//k.shape[2]), hdim).sum(dim=3).to(q.dtype), next_dv.view(bs, slen, v.shape[2], (q.shape[2]//v.shape[2]), hdim).sum(dim=3).to(q.dtype)
    # else:
    return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype)



class UpipeRingFlashAttnFuncQKVPacked(torch.autograd.Function):
    """Head-pipelined Ulysses all-to-all combined with ring attention.

    The head dimension is cut into ``pipe_degree = num_heads // ulysses_world_size``
    chunks. For each chunk ("stage") the packed tensors go through an
    all-to-all (scatter dim 2, gather dim 1), ring attention is run on the
    result, and the attention output goes through the inverse all-to-all
    (scatter dim 1, gather dim 2). The all-to-all for the next stage and the
    output all-to-all of the current stage are issued before the following
    stage's attention call so that communication can overlap with compute.

    The backward pass asserts that the concatenated gradient tensor has a
    leading dimension of 3, i.e. it is written for batch size 1.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        attn_type,
    ):
        """Run the pipelined attention forward pass.

        Args:
            ctx: Autograd context.
            q, k, v: 4-D tensors unpacked as (batch, local seq, heads, head
                dim); ``k`` must have the same number of heads as ``q``.
            dropout_p: Dropout probability forwarded to the attention op.
            softmax_scale: Attention scale; ``None`` selects
                ``head_dim ** -0.5``.
            causal: Forwarded to the attention op (which requires True).
            window_size: Published to the module globals read by the op.
            softcap: Forwarded to the attention op.
            alibi_slopes: Must be ``None`` (asserted).
            deterministic: Saved for the backward pass.
            return_softmax: When truthy, return ``(output, softmax_lse, None)``.
            ring_group: Process group for the ring attention.
            ulysses_group: Process group for the all-to-all exchanges.
            offload_stream: Optional CUDA stream for device-to-host copies.
            fetch_stream: Optional CUDA stream for host-to-device copies.
            attn_type: Kernel selector published to the module globals.

        Returns:
            ``output`` (same shape as ``q``), or ``(output, softmax_lse, None)``
            when ``return_softmax`` is truthy.
        """
        bs, shard_seqlen, hc, hs = q.shape

        world_size = dist.get_world_size(ulysses_group)
        pipe_degree = hc // world_size

        output = torch.zeros_like(q)
        softmax_lse = None  # allocated lazily once the first stage's lse is known

        assert k.shape[2] == q.shape[2], f"Pipe: num heads in key {k.shape[2]} must be equal to query {q.shape[2]}"
        # Pack q, k, v along dim 0 so a single all-to-all moves all three.
        qkv = torch.cat([q, k, v]).contiguous()

        orig_device = qkv.device

        # One entry per pipeline stage, split along the head dimension.
        qkv = list(torch.chunk(qkv, pipe_degree, dim=2))

        # NOTE(review): Tensor.to() returns a new tensor and the result is
        # discarded here and in the other offload/fetch blocks below, so these
        # calls only issue a copy; ``qkv`` keeps referencing the originals.
        if offload_stream is not None:
            with torch.cuda.stream(offload_stream):
                for tnsr in qkv[2:]:
                    tnsr.to("cpu", non_blocking=True)

        # The first all-to-all is issued with the last flag False (presumably
        # blocking -- confirm against all_to_all_4D); later ones pass True.
        PROCESS_GROUP.ULYSSES_PG = ulysses_group
        ulysses_qkv = all_to_all_4D(
            qkv[0], 2, 1, False, False  # scatter 2, gather 1
        )

        if offload_stream is not None:
            torch.cuda.current_stream().synchronize()
            with torch.cuda.stream(offload_stream):
                qkv[0].to("cpu", non_blocking=True)

        # Initialize variables to avoid undefined variable errors
        block_output = None
        o_bs = o_shard_seqlen = o_hc = o_hs = None

        # Publish the ring configuration through module-level globals, which is
        # where upipe_ring_flash_attn_forward reads it from. The module object
        # is used because these names are also parameters of this method.
        import sys
        current_module = sys.modules[__name__]
        current_module.process_group = ring_group
        current_module.attn_type = attn_type
        current_module.alibi_slopes = alibi_slopes
        current_module.window_size = window_size

        for stage in range(pipe_degree):
            if stage+1 != pipe_degree:
                if stage+2 < pipe_degree:
                    if fetch_stream is not None:
                        fetch_stream.synchronize()  # make sure the last qkv is fetched to GPU
                        with torch.cuda.stream(fetch_stream):
                            qkv[stage+2].to(orig_device, non_blocking=True)
                # Start the all-to-all for the next stage before computing this one.
                next_ulysses_qkv = all_to_all_4D(
                    qkv[stage+1], 2, 1, False, True  # scatter 2, gather 1
                )
                # ``bs`` here is the packed leading dim of the q/k/v concatenation.
                bs, shard_seqlen, hc, hs = qkv[stage+1].shape
                seqlen = shard_seqlen * world_size
                shard_hc = hc // world_size

            # Unpack q, k, v from the leading dimension.
            ulysses_qkv = torch.chunk(ulysses_qkv, 3, dim=0)

            inp_q = ulysses_qkv[0]
            inp_k = ulysses_qkv[1]
            inp_v = ulysses_qkv[2]

            if softmax_scale is None:
                softmax_scale = inp_q.shape[-1] ** (-0.5)

            assert alibi_slopes is None
            inp_k = inp_k.contiguous()
            inp_v = inp_v.contiguous()

            # A leftover rank-0 ``breakpoint()`` and the global barrier paired
            # with it used to follow this call; they stalled every run at every
            # stage and have been removed.
            out, lse = upipe_ring_flash_attn_forward(
                inp_q,
                inp_k,
                inp_v,
                softmax_scale=softmax_scale,
                dropout_p=dropout_p,
                causal=causal,
                softcap=softcap,
                deterministic=False,
            )

            # Collect each stage's lse at index ``stage`` of dim 1.
            if softmax_lse is None:
                softmax_lse = torch.zeros_like(lse).repeat(1, pipe_degree, 1)
            softmax_lse[:, stage, :] = lse

            context_layer = out

            o_bs, o_seqlen, o_shard_hc, o_hs = context_layer.shape
            o_hc = o_shard_hc * world_size
            o_shard_seqlen = o_seqlen // world_size

            if stage > 0:
                # Finish the previous stage's output all-to-all and store it.
                if U_HANDLE.O_HANDLE!=[]:
                    assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
                    U_HANDLE.O_HANDLE[0].wait()
                clear_o_handle()
                assert block_output.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_output shape {block_output.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
                # Reinterpret the raw result as (heads, seq, batch, dim) and
                # swap it back to (batch, seq, heads, dim).
                block_output = block_output.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
                block_output = block_output.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
                output[:,:,(stage-1)*world_size:(stage)*world_size,:] = block_output

            # Start this stage's output all-to-all; it is consumed one stage
            # later (or after the loop for the last stage).
            block_output = all_to_all_4D(
                context_layer, 1, 2, False, True  # scatter 1, gather 2
            )

            if stage+1 != pipe_degree:
                # Finish the next stage's input all-to-all started above.
                if U_HANDLE.HANDLE!=[]:
                    assert len(U_HANDLE.HANDLE) == 1, f"Pipe: U_HANDLE.HANDLE length {len(U_HANDLE.HANDLE)} must be 1"
                    U_HANDLE.HANDLE[0].wait()
                clear_u_handle()
                if offload_stream is not None:
                    with torch.cuda.stream(offload_stream):
                        qkv[stage+1].to("cpu", non_blocking=True)
                assert next_ulysses_qkv.numel()== (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_qkv shape {next_ulysses_qkv.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                # Reinterpret the raw result as (seq, batch, heads, dim) and
                # swap it to (batch, seq, heads, dim).
                next_ulysses_qkv = next_ulysses_qkv.reshape(seqlen, bs, shard_hc, hs)
                next_ulysses_qkv = next_ulysses_qkv.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                ulysses_qkv = next_ulysses_qkv

        # Flush the last stage's output all-to-all into the trailing heads.
        if U_HANDLE.O_HANDLE!=[]:
            assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
            U_HANDLE.O_HANDLE[0].wait()
            clear_o_handle()
            assert block_output.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_output shape {block_output.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
            block_output = block_output.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
            block_output = block_output.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
            output[:,:,-(world_size):,:] = block_output

        # this should be out_padded
        ctx.save_for_backward(q, k, v, output, softmax_lse)
        ctx.dropout_p = dropout_p
        # At this point softmax_scale is the value actually used by the
        # attention op (the default was filled in inside the loop).
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.ring_group = ring_group
        ctx.ulysses_group = ulysses_group
        ctx.attn_type = attn_type
        ctx.pipe_degree = pipe_degree
        ctx.offload_stream = offload_stream
        ctx.fetch_stream = fetch_stream
        return output if not return_softmax else (output, softmax_lse, None)

    @staticmethod
    def backward(ctx, dout, *args):
        """Run the pipelined attention backward pass.

        Mirrors ``forward``: q, k, v, dout and out are packed along dim 0, cut
        into ``pipe_degree`` head chunks, and each chunk goes through the
        all-to-all, ``upipe_ring_flash_attn_backward``, and the inverse
        all-to-all for the packed ``(dq, dk, dv)``.

        Args:
            ctx: Autograd context populated by ``forward``.
            dout: Gradient w.r.t. ``output``.
            *args: Gradients for the extra forward outputs; ignored.

        Returns:
            ``(dq, dk, dv)`` followed by 13 ``None`` values, one per remaining
            argument of ``forward``.
        """
        q, k, v, out, softmax_lse = ctx.saved_tensors

        pipe_degree = ctx.pipe_degree
        offload_stream = ctx.offload_stream
        fetch_stream = ctx.fetch_stream
        orig_device = dout.device
        world_size = dist.get_world_size(ctx.ulysses_group)

        # Packed (dq, dk, dv) accumulator, concatenated along dim 0.
        dqkv = torch.zeros([3*q.shape[0], q.shape[1], q.shape[2], q.shape[3]], dtype=q.dtype, device=q.device)

        doutqkvo_lse = torch.cat([q, k, v, dout, out], dim=0).contiguous()
        eff_bs = doutqkvo_lse.shape[0]
        # One entry per pipeline stage, split along the head dimension.
        doutqkvo_lse = list(torch.chunk(doutqkvo_lse, pipe_degree, dim=2))

        # Per-stage lse, split along dim 1 where forward stored it.
        softmax_lse = list(torch.chunk(softmax_lse, pipe_degree, dim=1))

        # NOTE(review): as in forward, the results of these .to() calls are
        # discarded, so they only issue copies.
        if offload_stream is not None:
            with torch.cuda.stream(offload_stream):
                for tnsr in doutqkvo_lse[2:]:
                    tnsr.to("cpu", non_blocking=True)
                for tnsr in softmax_lse[1:]:
                    tnsr.to("cpu", non_blocking=True)

        PROCESS_GROUP.ULYSSES_PG = ctx.ulysses_group
        # scatter 2, gather 1: same direction as the input all-to-all in forward
        ulysses_doutqkvo_lse = all_to_all_4D(doutqkvo_lse[0], 2, 1, False, False)

        if offload_stream is not None:
            torch.cuda.current_stream().synchronize()
            with torch.cuda.stream(offload_stream):
                doutqkvo_lse[0].to("cpu", non_blocking=True)

        for stage in range(pipe_degree):
            if stage+1 != pipe_degree:
                if stage+2 < pipe_degree:
                    if fetch_stream is not None:
                        fetch_stream.synchronize()  # make sure the last qkv is fetched to GPU
                        with torch.cuda.stream(fetch_stream):
                            doutqkvo_lse[stage+2].to(orig_device, non_blocking=True)
                            softmax_lse[stage+1].to(orig_device, non_blocking=True)
                # Start the all-to-all for the next stage before computing this one.
                next_ulysses_doutqkvo_lse = all_to_all_4D(doutqkvo_lse[stage+1], 2, 1, False, True)
                bs, shard_seqlen, hc, hs = doutqkvo_lse[stage+1].shape
                seqlen = shard_seqlen * world_size
                shard_hc = hc // world_size

            # Split along dim 0 into eff_bs pieces; the first five are used as
            # q, k, v, dout, out (this matches the packing only for batch size 1).
            ulysses_doutqkvo_lse = torch.chunk(ulysses_doutqkvo_lse, eff_bs, dim=0)
            assert len(ulysses_doutqkvo_lse) == eff_bs, f"Pipe: ulysses_tensors length {len(ulysses_doutqkvo_lse)} must be equal to eff_bs {eff_bs}, tensor shape {ulysses_doutqkvo_lse[0].shape}"

            inp_q = ulysses_doutqkvo_lse[0]
            inp_k = ulysses_doutqkvo_lse[1]
            inp_v = ulysses_doutqkvo_lse[2]
            inp_dout = ulysses_doutqkvo_lse[3]
            inp_out = ulysses_doutqkvo_lse[4]

            inp_softmax_lse_transposed = softmax_lse[stage]

            # Use the scale the forward pass actually applied. Recomputing
            # head_dim ** -0.5 here would give wrong gradients whenever the
            # caller supplied a custom softmax_scale.
            softmax_scale = ctx.softmax_scale
            if softmax_scale is None:
                softmax_scale = inp_q.shape[-1] ** (-0.5)

            attn_dq, attn_dk, attn_dv = upipe_ring_flash_attn_backward(
                ctx.ring_group,
                inp_dout,
                inp_q,
                inp_k,
                inp_v,
                inp_out,
                inp_softmax_lse_transposed,
                softmax_scale=softmax_scale,
                dropout_p=ctx.dropout_p,
                causal=ctx.causal,
                window_size=ctx.window_size,
                softcap=ctx.softcap,
                alibi_slopes=ctx.alibi_slopes,
                deterministic=ctx.deterministic,
                attn_type=ctx.attn_type,
            )

            if offload_stream is not None:
                torch.cuda.current_stream().synchronize()
                with torch.cuda.stream(offload_stream):
                    softmax_lse[stage].to("cpu", non_blocking=True)

            if stage > 0:
                # Finish the previous stage's gradient all-to-all and store it.
                if U_HANDLE.O_HANDLE!=[]:
                    assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
                    U_HANDLE.O_HANDLE[0].wait()
                clear_o_handle()
                assert block_grads.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_grads shape {block_grads.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
                block_grads = block_grads.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
                block_grads = block_grads.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
                dqkv[:,:,(stage-1)*world_size:(stage)*world_size,:] = block_grads

            # Pack the three gradients along dim 0 so one all-to-all moves them.
            attn_grads = torch.cat([attn_dq, attn_dk, attn_dv], dim=0)
            o_bs, o_seqlen, o_shard_hc, o_hs = attn_grads.shape
            o_hc = o_shard_hc * world_size
            o_shard_seqlen = o_seqlen // world_size

            block_grads = all_to_all_4D(
                attn_grads, 1, 2, False, True  # scatter 1, gather 2
            )

            if stage+1 != pipe_degree:
                # Finish the next stage's input all-to-all started above.
                if U_HANDLE.HANDLE!=[]:
                    assert len(U_HANDLE.HANDLE) == 1, f"Pipe: U_HANDLE.HANDLE length {len(U_HANDLE.HANDLE)} must be 1"
                    U_HANDLE.HANDLE[0].wait()
                clear_u_handle()
                if offload_stream is not None:
                    with torch.cuda.stream(offload_stream):
                        doutqkvo_lse[stage+1].to("cpu", non_blocking=True)
                assert next_ulysses_doutqkvo_lse.numel() ==  (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_doutqkvo_lse shape {next_ulysses_doutqkvo_lse.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                next_ulysses_doutqkvo_lse = next_ulysses_doutqkvo_lse.reshape(seqlen, bs, shard_hc, hs)
                next_ulysses_doutqkvo_lse = next_ulysses_doutqkvo_lse.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                ulysses_doutqkvo_lse = next_ulysses_doutqkvo_lse

        # Flush the last stage's gradient all-to-all and unpack dq, dk, dv.
        if U_HANDLE.O_HANDLE!=[]:
            assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
            U_HANDLE.O_HANDLE[0].wait()
            clear_o_handle()
            assert block_grads.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_grads shape {block_grads.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
            block_grads = block_grads.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
            block_grads = block_grads.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
            dqkv[:,:,-(world_size):,:] = block_grads
            assert dqkv.shape[0] == 3, f"Pipe: dqkv shape {dqkv.shape} must be 3 in the bs dimension"
            dqkv = torch.chunk(dqkv, 3, dim=0)

        assert len(dqkv) == 3, f"Pipe: dqkv length {len(dqkv)} must be 3"
        # One gradient slot per forward argument: q, k, v, then 13 non-tensor args.
        return (dqkv[0], dqkv[1], dqkv[2]) + (None,) * 13


class NewPipeRingFlashAttnFuncQKVPacked(torch.autograd.Function):
    """Pipelined Ulysses + ring flash attention over per-stage q/k/v chunks.

    ``q``, ``k`` and ``v`` are Python lists with one tensor per pipeline
    stage. For every stage the packed ``[q; k; v]`` tensor goes through an
    all-to-all (scatter dim 2, gather dim 1), ring attention is run on the
    result, and the attention output goes through the inverse all-to-all
    (scatter dim 1, gather dim 2). The all-to-all for stage ``s + 1`` is
    issued before the attention call of stage ``s`` and completed afterwards
    through ``U_HANDLE``.

    NOTE(review): q/k/v are passed as lists, and ``forward`` sets
    ``requires_grad`` on the output by hand — confirm how ``backward`` is
    driven by callers (lists are not tensor inputs for autograd).
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        attn_type,
    ):
        """Run the pipelined forward pass.

        Args:
            q, k, v: Non-empty lists of equal length; each entry is unpacked
                as ``(bs, shard_seqlen, hc, hs)``. ``q[0]``, ``k[0]`` and
                ``v[0]`` are replaced IN THE CALLER'S LISTS by CPU copies.
            dropout_p, causal, softcap: Forwarded to
                ``upipe_ring_flash_attn_forward``.
            softmax_scale: Attention scale; ``None`` means ``hs ** -0.5``.
            window_size, alibi_slopes, attn_type, ring_group: Published as
                module-level globals for ``upipe_ring_flash_attn_forward``.
                ``alibi_slopes`` must be ``None``.
            deterministic: Only stored on ``ctx`` for the backward pass.
            return_softmax: If true, return ``(output, softmax_lse, None)``.
            ulysses_group: Process group used for the all-to-all exchanges
                (stored in ``PROCESS_GROUP.ULYSSES_PG``).
            offload_stream, fetch_stream: CUDA streams for device-to-host and
                host-to-device copies; both are required.

        Returns:
            ``output`` of shape ``[bs, shard_seqlen, hc * len(q), hs]``, or
            ``(output, softmax_lse, None)`` when ``return_softmax`` is set.
        """
        # Validate the containers before indexing into them.
        assert isinstance(q, list), f"Pipe: q must be a list"
        assert isinstance(k, list), f"Pipe: k must be a list"
        assert isinstance(v, list), f"Pipe: v must be a list"
        assert len(q) == len(k) == len(v), f"Pipe: q, k, v must have the same length"
        assert len(q) > 0, "Pipe: q, k, v must contain at least one chunk"
        assert offload_stream is not None, f"Pipe: offload_stream must be provided"
        assert fetch_stream is not None, f"Pipe: fetch_stream must be provided"
        assert k[0].shape[2] == q[0].shape[2], f"Pipe: num heads in key {k[0].shape[2]} must be equal to query {q[0].shape[2]}"

        bs, shard_seqlen, hc, hs = q[0].shape

        world_size = dist.get_world_size(ulysses_group)
        pipe_degree = hc*len(q) // world_size
        # The stage loop reads qkv[stage + 1] and backward slices the saved
        # tensors by pipe_degree, so one chunk per stage is required.
        assert pipe_degree == len(q), f"Pipe: pipe_degree {pipe_degree} must equal the number of q/k/v chunks {len(q)}"

        output = torch.zeros([bs, shard_seqlen, hc*len(q), hs], dtype=q[0].dtype, device=q[0].device)
        output.requires_grad = True
        softmax_lse = None

        orig_device = q[0].device

        # Pack q/k/v of every stage along dim 0 so one all-to-all moves all three.
        qkv = [torch.cat([q_i, k_i, v_i], dim=0) for q_i, k_i, v_i in zip(q, k, v)]

        # fetch the first qkv
        qkv[0] = qkv[0].to(orig_device)

        # Offload the first-stage inputs; these CPU copies are what gets saved
        # for backward below.
        with torch.cuda.stream(offload_stream):
            q[0] = q[0].to("cpu", non_blocking=True)
            k[0] = k[0].to("cpu", non_blocking=True)
            v[0] = v[0].to("cpu", non_blocking=True)

        # Prefetch the second stage (absent for a single-stage pipeline).
        if len(qkv) > 1:
            with torch.cuda.stream(fetch_stream):
                qkv[1] = qkv[1].to(orig_device, non_blocking=True)

        PROCESS_GROUP.ULYSSES_PG = ulysses_group
        ulysses_qkv = all_to_all_4D(
            qkv[0], 2, 1, False, False # scatter 2, gather 1
        )

        # Initialize variables to avoid undefined variable errors
        block_output = None
        o_bs = o_shard_seqlen = o_hc = o_hs = None

        # upipe_ring_flash_attn_forward reads these as module globals. They are
        # set through the module object because several of the names are also
        # parameters of this function and so cannot be declared `global` here.
        import sys
        current_module = sys.modules[__name__]
        current_module.process_group = ring_group
        current_module.attn_type = attn_type
        current_module.alibi_slopes = alibi_slopes
        current_module.window_size = window_size

        torch.cuda.current_stream().wait_stream(fetch_stream)

        for stage in range(pipe_degree):
            has_next_stage = stage+1 != pipe_degree
            if has_next_stage:
                # qkv[stage+1] was queued on fetch_stream earlier; make the
                # compute stream wait for that copy before consuming it, and
                # do so BEFORE queueing the next prefetch so that copy can
                # still overlap with this stage's work.
                torch.cuda.current_stream().wait_stream(fetch_stream)
                if stage+2 < pipe_degree:
                    with torch.cuda.stream(fetch_stream):
                        qkv[stage+2] = qkv[stage+2].to(orig_device, non_blocking=True)
                assert qkv[stage+1].device == orig_device, f"Pipe: qkv[stage+1] device {qkv[stage+1].device} must be {orig_device} at stage {stage}"
                # Issue the next stage's all-to-all; it is completed after the
                # attention call via U_HANDLE.HANDLE.
                next_ulysses_qkv = all_to_all_4D(
                    qkv[stage+1], 2, 1, False, True # scatter 2, gather 1
                )
                bs, shard_seqlen, hc, hs = qkv[stage+1].shape
                seqlen = shard_seqlen * world_size
                shard_hc = hc // world_size

            inp_q, inp_k, inp_v = torch.chunk(ulysses_qkv, 3, dim=0)

            if softmax_scale is None:
                softmax_scale = inp_q.shape[-1] ** (-0.5)

            assert alibi_slopes is None
            inp_k = inp_k.contiguous()
            inp_v = inp_v.contiguous()

            assert inp_q.device == orig_device, f"Pipe: inp_q device {inp_q.device} must be {orig_device} at stage {stage}"
            assert inp_k.device == orig_device, f"Pipe: inp_k device {inp_k.device} must be {orig_device} at stage {stage}"
            assert inp_v.device == orig_device, f"Pipe: inp_v device {inp_v.device} must be {orig_device} at stage {stage}"

            # NOTE(review): `deterministic` is not forwarded here; it is only
            # stored on ctx for backward.
            out, lse = upipe_ring_flash_attn_forward(
                inp_q,
                inp_k,
                inp_v,
                softmax_scale=softmax_scale,
                dropout_p=dropout_p,
                causal=causal,
                softcap=softcap,
                deterministic=False,
            )

            # One slot along dim 1 per stage, allocated on the first stage.
            if softmax_lse is None:
                softmax_lse = torch.zeros_like(lse).repeat(1, pipe_degree, 1)
            softmax_lse[:, stage, :] = lse

            context_layer = out

            o_bs, o_seqlen, o_shard_hc, o_hs = context_layer.shape
            o_hc = o_shard_hc * world_size
            o_shard_seqlen = o_seqlen // world_size

            if stage > 0:
                # Complete the previous stage's output all-to-all and store it.
                if U_HANDLE.O_HANDLE!=[]:
                    assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
                    U_HANDLE.O_HANDLE[0].wait()
                clear_o_handle()
                assert block_output.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_output shape {block_output.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
                block_output = block_output.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
                block_output = block_output.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
                output[:,:,(stage-1)*world_size:(stage)*world_size,:] = block_output

            block_output = all_to_all_4D(
                context_layer, 1, 2, False, True # scatter 1, gather 2
            )

            if has_next_stage:
                if U_HANDLE.HANDLE!=[]:
                    assert len(U_HANDLE.HANDLE) == 1, f"Pipe: U_HANDLE.HANDLE length {len(U_HANDLE.HANDLE)} must be 1"
                    U_HANDLE.HANDLE[0].wait()
                clear_u_handle()
                # NOTE(review): the result of this copy is discarded, so it
                # neither frees device memory nor keeps a host copy — confirm
                # whether it is still needed.
                with torch.cuda.stream(offload_stream):
                    qkv[stage+1].to("cpu", non_blocking=True)
                assert next_ulysses_qkv.numel()== (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_qkv shape {next_ulysses_qkv.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                next_ulysses_qkv = next_ulysses_qkv.reshape(seqlen, bs, shard_hc, hs)
                next_ulysses_qkv = next_ulysses_qkv.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                ulysses_qkv = next_ulysses_qkv

        # Complete the last stage's output all-to-all.
        if U_HANDLE.O_HANDLE!=[]:
            assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
            U_HANDLE.O_HANDLE[0].wait()
            clear_o_handle()
            assert block_output.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_output shape {block_output.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
            block_output = block_output.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
            block_output = block_output.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
            output[:,:,-(world_size):,:] = block_output

        # Saved layout: all q chunks, all k chunks, all v chunks, output, lse.
        ctx.save_for_backward(*q, *k, *v, output, softmax_lse)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.ring_group = ring_group
        ctx.ulysses_group = ulysses_group
        ctx.attn_type = attn_type
        ctx.pipe_degree = pipe_degree
        ctx.offload_stream = offload_stream
        ctx.fetch_stream = fetch_stream

        assert output is not None, "Output is None!!!"
        assert output.requires_grad, f"Pipe: output requires_grad must be True"
        return output if not return_softmax else (output, softmax_lse, None)


    @staticmethod
    def backward(ctx, dout, *args):
        """Run the pipelined backward pass.

        Mirrors ``forward``: per stage, ``[q; k; v; dout; out]`` is packed
        along dim 0, exchanged with an all-to-all (scatter 2, gather 1), fed
        to ``upipe_ring_flash_attn_backward``, and the resulting
        ``[dq; dk; dv]`` is exchanged back (scatter 1, gather 2).

        Args:
            ctx: Context populated by ``forward``.
            dout: Gradient w.r.t. the forward output; split into
                ``ctx.pipe_degree`` chunks along dim 2.
            *args: Ignored.

        Returns:
            A 16-tuple matching the arguments of ``forward``: the lists
            ``dq``, ``dk``, ``dv`` (one tensor per stage) followed by 13
            ``None`` entries.
        """
        saved_tensors = ctx.saved_tensors
        pipe_degree = ctx.pipe_degree
        q = saved_tensors[:pipe_degree]
        k = saved_tensors[pipe_degree:2*pipe_degree]
        v = saved_tensors[2*pipe_degree:3*pipe_degree]
        out = saved_tensors[-2]
        softmax_lse = saved_tensors[-1]

        offload_stream = ctx.offload_stream
        fetch_stream = ctx.fetch_stream
        orig_device = dout.device
        world_size = dist.get_world_size(ctx.ulysses_group)

        num_chunks = len(q)
        dq = [None] * num_chunks
        dk = [None] * num_chunks
        dv = [None] * num_chunks

        dout = dout.to("cpu")
        out = out.to("cpu")

        dout = list(torch.chunk(dout, pipe_degree, dim=2))
        out = list(torch.chunk(out, pipe_degree, dim=2))

        # Pack the 5 per-stage tensors along dim 0 so one all-to-all moves them all.
        doutqkvo_lse = [torch.cat([q[i], k[i], v[i], dout[i], out[i]], dim=0) for i in range(num_chunks)]
        # Size of the packed dim 0. The unpacking below splits into eff_bs
        # chunks and reads chunks 0..4, which lines up with the 5 packed
        # tensors only when each of them has size 1 along dim 0.
        eff_bs = doutqkvo_lse[0].shape[0]

        del dout, out, q, k, v

        # A list (not a tuple) because consumed entries are released below.
        softmax_lse = list(torch.chunk(softmax_lse, pipe_degree, dim=1))

        doutqkvo_lse[0] = doutqkvo_lse[0].to(orig_device)
        # Prefetch the second stage (absent for a single-stage pipeline).
        if len(doutqkvo_lse) > 1:
            with torch.cuda.stream(fetch_stream):
                doutqkvo_lse[1] = doutqkvo_lse[1].to(orig_device, non_blocking=True)

        PROCESS_GROUP.ULYSSES_PG = ctx.ulysses_group
        ulysses_doutqkvo_lse = all_to_all_4D(doutqkvo_lse[0], 2, 1, False, False) #note that this is scatter 2, gather 1 since this is inverse of forward pass

        torch.cuda.current_stream().wait_stream(fetch_stream)

        for stage in range(pipe_degree):
            has_next_stage = stage+1 != pipe_degree
            if has_next_stage:
                # Wait for the prefetch of doutqkvo_lse[stage+1] (queued on
                # fetch_stream earlier) before consuming it, and before
                # queueing the next prefetch so that copy can still overlap.
                torch.cuda.current_stream().wait_stream(fetch_stream)
                if stage+2 < pipe_degree:
                    with torch.cuda.stream(fetch_stream):
                        doutqkvo_lse[stage+2] = doutqkvo_lse[stage+2].to(orig_device, non_blocking=True)
                next_ulysses_doutqkvo_lse = all_to_all_4D(doutqkvo_lse[stage+1], 2, 1, False, True) #note that this is scatter 2, gather 1 since this is inverse of forward pass
                bs, shard_seqlen, hc, hs = doutqkvo_lse[stage+1].shape
                seqlen = shard_seqlen * world_size
                shard_hc = hc // world_size

            # Chunk the tensor to separate q, k, v, dout, out components (5 components total)
            ulysses_doutqkvo_lse = torch.chunk(ulysses_doutqkvo_lse, eff_bs, dim=0)
            assert len(ulysses_doutqkvo_lse) == eff_bs, f"Pipe: ulysses_tensors length {len(ulysses_doutqkvo_lse)} must be equal to eff_bs {eff_bs}, tensor shape {ulysses_doutqkvo_lse[0].shape}"

            # Same order as the torch.cat above: q, k, v, dout, out.
            inp_q = ulysses_doutqkvo_lse[0]
            inp_k = ulysses_doutqkvo_lse[1]
            inp_v = ulysses_doutqkvo_lse[2]
            inp_dout = ulysses_doutqkvo_lse[3]
            inp_out = ulysses_doutqkvo_lse[4]

            inp_softmax_lse_transposed = softmax_lse[stage]

            # Use the scale the forward pass used; fall back to the default
            # only if none was recorded.
            softmax_scale = ctx.softmax_scale
            if softmax_scale is None:
                softmax_scale = inp_q.shape[-1] ** (-0.5)

            attn_dq, attn_dk, attn_dv = upipe_ring_flash_attn_backward(
                ctx.ring_group,
                inp_dout,
                inp_q,
                inp_k,
                inp_v,
                inp_out,
                inp_softmax_lse_transposed,
                softmax_scale=softmax_scale,
                dropout_p=ctx.dropout_p,
                causal=ctx.causal,
                window_size=ctx.window_size,
                softcap=ctx.softcap,
                alibi_slopes=ctx.alibi_slopes,
                deterministic=ctx.deterministic,
                attn_type=ctx.attn_type,
            )

            # not needed anymore
            doutqkvo_lse[stage] = None
            softmax_lse[stage] = None

            if stage > 0:
                # Complete the previous stage's gradient all-to-all. o_* still
                # hold the previous stage's shapes at this point.
                if U_HANDLE.O_HANDLE!=[]:
                    assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
                    U_HANDLE.O_HANDLE[0].wait()
                clear_o_handle()
                assert block_grads.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_grads shape {block_grads.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
                block_grads = block_grads.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
                block_grads = block_grads.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
                dq[stage-1], dk[stage-1], dv[stage-1] = torch.chunk(block_grads, 3, dim=0)

            attn_grads = torch.cat([attn_dq, attn_dk, attn_dv], dim=0)
            o_bs, o_seqlen, o_shard_hc, o_hs = attn_grads.shape
            o_hc = o_shard_hc * world_size
            o_shard_seqlen = o_seqlen // world_size

            block_grads = all_to_all_4D(
                attn_grads, 1, 2, False, True # scatter 1, gather 2
            )

            if has_next_stage:
                if U_HANDLE.HANDLE!=[]:
                    assert len(U_HANDLE.HANDLE) == 1, f"Pipe: U_HANDLE.HANDLE length {len(U_HANDLE.HANDLE)} must be 1"
                    U_HANDLE.HANDLE[0].wait()
                clear_u_handle()
                if offload_stream is not None:
                    # NOTE(review): the result of this copy is discarded —
                    # confirm whether it is still needed.
                    with torch.cuda.stream(offload_stream):
                        doutqkvo_lse[stage+1].to("cpu", non_blocking=True)
                assert next_ulysses_doutqkvo_lse.numel() ==  (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_doutqkvo_lse shape {next_ulysses_doutqkvo_lse.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                next_ulysses_doutqkvo_lse = next_ulysses_doutqkvo_lse.reshape(seqlen, bs, shard_hc, hs)
                next_ulysses_doutqkvo_lse = next_ulysses_doutqkvo_lse.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                ulysses_doutqkvo_lse = next_ulysses_doutqkvo_lse

        # Complete the last stage's gradient all-to-all.
        if U_HANDLE.O_HANDLE!=[]:
            assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
            U_HANDLE.O_HANDLE[0].wait()
            clear_o_handle()
            assert block_grads.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_grads shape {block_grads.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
            block_grads = block_grads.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
            block_grads = block_grads.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
            dq[-1], dk[-1], dv[-1] = torch.chunk(block_grads, 3, dim=0)

        # One slot per forward() argument: q, k, v plus 13 non-differentiable ones.
        return (dq, dk, dv) + (None,) * 13


class DualStageUpipeRingFlashAttnFuncQKVPacked(torch.autograd.Function):
    
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        attn_type,
    ):  
        """Two-stage pipelined Ulysses + ring flash attention forward pass.

        The packed ``[q; k; v]`` tensor is split along the head dimension
        (dim 2) into the first ``world_size`` heads and the remaining heads.
        Each part goes through an all-to-all (scatter 2, gather 1), ring
        attention, and an inverse all-to-all (scatter 1, gather 2). The
        second part's input all-to-all is issued before the first part's
        attention call and completed afterwards through ``U_HANDLE``.

        Args:
            q, k, v: Tensors unpacked as ``(bs, shard_seqlen, q_hc, hs)``;
                ``k`` must have as many heads as ``q`` and at least
                ``2 * world_size`` of them.
            dropout_p, causal, softcap: Forwarded to
                ``upipe_ring_flash_attn_forward``.
            softmax_scale: Attention scale; ``None`` means ``hs ** -0.5``.
            window_size, alibi_slopes, attn_type, ring_group: Published as
                module-level globals. ``alibi_slopes`` must be ``None``.
            deterministic: Only stored on ``ctx``; the attention call below
                is made with ``deterministic=False``.
            return_softmax: If true, return ``(output, softmax_lse, None)``.
            ulysses_group: Process group passed to ``all_to_all_4D``.
            offload_stream: Optional CUDA stream for device-to-host copies.
            fetch_stream: Stored on ``ctx``; not otherwise used here.

        Returns:
            ``output`` (same shape as ``q``), or
            ``(output, softmax_lse, None)`` when ``return_softmax`` is set.
        """
        # Do all the ulysses stuff here
        bs, shard_seqlen, q_hc, hs = q.shape

        world_size = dist.get_world_size(ulysses_group)
        # This variant always runs exactly two stages.
        pipe_degree = 2

        output = torch.zeros_like(q)
        # Allocated lazily on the first stage, once the shape of `lse` is known.
        softmax_lse = None
        

        assert k.shape[2] == q.shape[2], f"Pipe: num heads in key {k.shape[2]} must be equal to query {q.shape[2]}"
        assert k.shape[2]//world_size >= 2, f"DualStage UPipe only works for num heads per device >= 2"
        # Pack q/k/v along dim 0 so a single all-to-all moves all three.
        qkv = torch.cat([q, k, v]).contiguous()

        orig_device = qkv.device

        qkv = torch.tensor_split(qkv, [world_size], dim = 2) # indices_or_sections: [world_size] -> split into qkv[:,:,:world_size,:], and qkv[:,:,world_size:,:]
        assert len(qkv) == 2, f"DualStage UPipe only works for QKV num heads per device >= 2"

        # first all-to-all is blocking
        # NOTE(review): ulysses_group is passed positionally here, whereas
        # NewPipeRingFlashAttnFuncQKVPacked sets PROCESS_GROUP.ULYSSES_PG and
        # passes only the two flags — confirm against the all_to_all_4D signature.
        ulysses_qkv = all_to_all_4D(
            qkv[0], 2, 1, ulysses_group, False, False # scatter 2, gather 1
        )
        
        if offload_stream is not None:
            torch.cuda.current_stream().synchronize()
            # NOTE(review): the result of this copy is discarded — confirm
            # whether it is still needed.
            with torch.cuda.stream(offload_stream):
                qkv[0].to("cpu", non_blocking=True)

        # Initialize variables to avoid undefined variable errors
        block_output = None
        o_bs = o_shard_seqlen = o_hc = o_hs = None

        # upipe_ring_flash_attn_forward reads these as module globals; they are
        # set through the module object since several of the names are also
        # parameters of this function.
        global process_group
        import sys
        current_module = sys.modules[__name__]
        current_module.process_group = ring_group
        current_module.attn_type = attn_type
        current_module.alibi_slopes = alibi_slopes
        current_module.window_size = window_size

        for stage in range(pipe_degree):
            if stage+1 != pipe_degree:
                # Issue the next stage's all-to-all; it is completed after the
                # attention call via U_HANDLE.HANDLE.
                next_ulysses_qkv = all_to_all_4D(
                    qkv[stage+1], 2, 1, ulysses_group, False, True # scatter 2, gather 1
                )
                # Shapes used after the wait to undo the layout of the raw
                # all-to-all result.
                bs, shard_seqlen, hc, hs = qkv[stage+1].shape
                seqlen = shard_seqlen * world_size
                shard_hc = hc // world_size

            # Undo the dim-0 packing done by torch.cat([q, k, v]).
            ulysses_qkv = torch.chunk(ulysses_qkv, 3, dim=0)

            inp_q = ulysses_qkv[0]
            inp_k = ulysses_qkv[1]
            inp_v = ulysses_qkv[2]

            if softmax_scale is None:
                softmax_scale = inp_q.shape[-1] ** (-0.5)

            assert alibi_slopes is None
            inp_k = inp_k.contiguous()
            inp_v = inp_v.contiguous()

            out, lse = upipe_ring_flash_attn_forward(
                inp_q,
                inp_k,
                inp_v,
                softmax_scale=softmax_scale,
                dropout_p=dropout_p,
                causal=causal,
                softcap=softcap,
                deterministic=False,
            )

            # Stage 0 allocates the buffer (lse repeated q_hc//world_size times
            # along dim 1) and fills the first slot; stage 1 fills the rest.
            if softmax_lse is None:
                softmax_lse = torch.zeros_like(lse).repeat(1, q_hc//world_size, 1)
                softmax_lse[:, :1, :] = lse
            else:
                softmax_lse[:, 1:, :] = lse

            context_layer = out
            
            # The two stages can carry different head counts, so their output
            # shapes are tracked separately (o_* for stage 0, fi_* for stage 1).
            if stage == 0:
                o_bs, o_seqlen, o_shard_hc, o_hs = context_layer.shape
                o_hc = o_shard_hc * world_size
                o_shard_seqlen = o_seqlen // world_size
            else:
                # final shapes for the output of the second stage
                fi_bs, fi_seqlen, fi_shard_hc, fi_hs = context_layer.shape
                fi_hc = fi_shard_hc * world_size
                fi_shard_seqlen = fi_seqlen // world_size
            
            if stage > 0:
                # Complete stage 0's output all-to-all and store it in the
                # first world_size heads of the output.
                if U_HANDLE.O_HANDLE!=[]:
                    assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
                    U_HANDLE.O_HANDLE[0].wait()
                clear_o_handle()
                assert block_output.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_output shape {block_output.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
                block_output = block_output.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
                block_output = block_output.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
                output[:,:,:world_size,:] = block_output

            # Issue this stage's output all-to-all; it is completed on the next
            # iteration (stage 0) or after the loop (stage 1) via U_HANDLE.O_HANDLE.
            block_output = all_to_all_4D(
                context_layer, 1, 2, ulysses_group, False, True # scatter 1, gather 2
            )
            
            if stage+1 != pipe_degree:
                # Complete the next stage's input all-to-all issued above.
                if U_HANDLE.HANDLE!=[]:
                    assert len(U_HANDLE.HANDLE) == 1, f"Pipe: U_HANDLE.HANDLE length {len(U_HANDLE.HANDLE)} must be 1"
                    U_HANDLE.HANDLE[0].wait()
                clear_u_handle()
                if offload_stream is not None:
                    # NOTE(review): the result of this copy is discarded.
                    with torch.cuda.stream(offload_stream):
                        qkv[stage+1].to("cpu", non_blocking=True)
                assert next_ulysses_qkv.numel()== (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_qkv shape {next_ulysses_qkv.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                next_ulysses_qkv = next_ulysses_qkv.reshape(seqlen, bs, shard_hc, hs)
                next_ulysses_qkv = next_ulysses_qkv.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                ulysses_qkv = next_ulysses_qkv
        
        # Complete stage 1's output all-to-all and store it in the remaining heads.
        if U_HANDLE.O_HANDLE!=[]:
            assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
            U_HANDLE.O_HANDLE[0].wait()
            clear_o_handle()
            assert block_output.numel() == (fi_bs*fi_shard_seqlen*fi_hc*fi_hs), f"Pipe: block_output shape {block_output.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {fi_bs*fi_shard_seqlen*fi_hc*fi_hs}"
            block_output = block_output.reshape(fi_hc, fi_shard_seqlen, fi_bs, fi_hs)
            block_output = block_output.transpose(0, 2).contiguous().reshape(fi_bs, fi_shard_seqlen, fi_hc, fi_hs)
            output[:,:,world_size:,:] = block_output

        # Tensors and settings that backward() reads back from ctx.
        ctx.save_for_backward(q, k, v, output, softmax_lse)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.ring_group = ring_group
        ctx.ulysses_group = ulysses_group
        ctx.attn_type = attn_type
        ctx.pipe_degree = pipe_degree
        ctx.offload_stream = offload_stream
        ctx.fetch_stream = fetch_stream

        return output if not return_softmax else (output, softmax_lse, None)
        
    
    @staticmethod
    def backward(ctx, dout, *args):
        """Compute dq, dk, dv by running the ring-attention backward in head-group stages.

        The saved q/k/v/out and the incoming ``dout`` are concatenated along
        dim 0, split along the head dimension (dim 2) into two groups, and each
        group is re-sharded with ``all_to_all_4D`` (scatter 2, gather 1) before
        being handed to ``upipe_ring_flash_attn_backward``. The per-stage
        gradients are sent back with the inverse all-to-all (scatter 1,
        gather 2) and written into ``dqkv``.

        Args:
            ctx: autograd context; reads the saved tensors plus ``pipe_degree``,
                ``offload_stream``, ``fetch_stream``, ``ulysses_group``,
                ``ring_group`` and the attention options stored on it.
            dout: gradient with respect to the forward output.
            *args: any further incoming gradients; never read here.

        Returns:
            ``dq, dk, dv`` followed by ``None`` placeholders (presumably one per
            remaining forward argument -- the forward signature is not visible
            from this method; confirm the count against it).

        NOTE(review): both split tuples below have exactly two entries and the
        ``fi_*`` shapes are only bound when ``stage > 0``, so this method only
        works for ``pipe_degree == 2``.
        """
        q, k, v, out, softmax_lse = ctx.saved_tensors

        pipe_degree = ctx.pipe_degree
        offload_stream = ctx.offload_stream
        # NOTE(review): fetch_stream, orig_device and orig_q_dtype (below) are only
        # referenced from commented-out code in this method.
        fetch_stream = ctx.fetch_stream
        orig_device = dout.device
        world_size = dist.get_world_size(ctx.ulysses_group)

        # Destination for [dq; dk; dv] stacked along dim 0, in q's layout.
        dqkv = torch.zeros([3*q.shape[0], q.shape[1], q.shape[2], q.shape[3]], dtype=q.dtype, device=q.device)

        orig_q_dtype = q.dtype
        # Pack the five tensors along dim 0 so that one all-to-all moves all of them.
        doutqkvo_lse = torch.cat([q, k, v, dout, out], dim=0).contiguous()
        eff_bs = doutqkvo_lse.shape[0]
        # Split the head dimension at index world_size:
        # [:, :, :world_size, :] and [:, :, world_size:, :].
        doutqkvo_lse  = torch.tensor_split(doutqkvo_lse, [world_size], dim = 2)
        assert len(doutqkvo_lse) == 2, f"Pipe: doutqkvo_lse length {len(doutqkvo_lse)} must be 2"
        # Split softmax_lse at index 1 along dim 1: [:, :1, :] and [:, 1:, :].
        softmax_lse = torch.tensor_split(softmax_lse, [1], dim = 1)
        assert len(softmax_lse) == 2, f"Pipe: softmax_lse length {len(softmax_lse)} must be 2"

        # First head group: scatter 2 / gather 1, i.e. the inverse of the forward
        # output all-to-all. The trailing False is presumably "not async", since
        # no handle is waited on for this call -- confirm against all_to_all_4D.
        ulysses_doutqkvo_lse = all_to_all_4D(doutqkvo_lse[0], 2, 1, ctx.ulysses_group, False, False)
        
        if offload_stream is not None:
            torch.cuda.current_stream().synchronize()
            with torch.cuda.stream(offload_stream):
                # NOTE(review): the CPU copy returned by .to() is discarded, so this
                # does not release the GPU tensor.
                doutqkvo_lse[0].to("cpu", non_blocking=True)
        
        for stage in range(pipe_degree):
            if stage+1 != pipe_degree:
                # Start the all-to-all for the next head group while this stage
                # computes. The trailing True differs from the first call and the
                # result is only used after U_HANDLE.HANDLE[0].wait() below, so it
                # is presumably asynchronous -- confirm.
                next_ulysses_doutqkvo_lse = all_to_all_4D(doutqkvo_lse[stage+1], 2, 1, ctx.ulysses_group, False, True)
                # Shapes needed to reinterpret the received buffer once it has arrived.
                bs, shard_seqlen, hc, hs = doutqkvo_lse[stage+1].shape
                seqlen = shard_seqlen * world_size
                shard_hc = hc // world_size

            # Separate the packed tensor back into q, k, v, dout, out.
            # NOTE(review): chunking into eff_bs pieces and reading indices 0..4 only
            # lines up with the five packed tensors when the batch size is 1.
            ulysses_doutqkvo_lse = torch.chunk(ulysses_doutqkvo_lse, eff_bs, dim=0)
            assert len(ulysses_doutqkvo_lse) == eff_bs, f"Pipe: ulysses_tensors length {len(ulysses_doutqkvo_lse)} must be equal to eff_bs {eff_bs}, tensor shape {ulysses_doutqkvo_lse[0].shape}"
            
            inp_q = ulysses_doutqkvo_lse[0]
            inp_k = ulysses_doutqkvo_lse[1]
            inp_v = ulysses_doutqkvo_lse[2]
            inp_dout = ulysses_doutqkvo_lse[3]
            inp_out = ulysses_doutqkvo_lse[4]
            
            # The LSE slice for this stage is used as stored; no transpose is applied
            # despite the variable name.
            inp_softmax_lse_transposed = softmax_lse[stage]

            # NOTE(review): the scale is recomputed from the head size;
            # ctx.softmax_scale is not consulted in this method.
            softmax_scale = inp_q.shape[-1] ** (-0.5)
            
            # (Contiguity and check_nan_inf debug assertions on the stage inputs
            # were previously here, commented out.)

            attn_dq, attn_dk, attn_dv = upipe_ring_flash_attn_backward(
                ctx.ring_group,
                inp_dout,
                inp_q,
                inp_k,
                inp_v,
                inp_out,
                inp_softmax_lse_transposed,
                softmax_scale=softmax_scale,
                dropout_p=ctx.dropout_p,
                causal=ctx.causal,
                window_size=ctx.window_size,
                softcap=ctx.softcap,
                alibi_slopes=ctx.alibi_slopes,
                deterministic=ctx.deterministic,
                attn_type=ctx.attn_type,
            )

            if offload_stream is not None:
                torch.cuda.current_stream().synchronize()
                with torch.cuda.stream(offload_stream):
                    # NOTE(review): result discarded, as with the offload above.
                    softmax_lse[stage].to("cpu", non_blocking=True)

            if stage > 0:
                # block_grads still refers to the previous stage's gradient
                # all-to-all (issued at the bottom of the previous iteration);
                # wait for it before reading the buffer.
                if U_HANDLE.O_HANDLE!=[]:
                    assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
                    U_HANDLE.O_HANDLE[0].wait()
                clear_o_handle()
                assert block_grads.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_grads shape {block_grads.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
                # View the received buffer as (heads, seq, batch, hs), then swap heads
                # and batch to obtain (batch, seq, heads, hs).
                block_grads = block_grads.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
                block_grads = block_grads.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
                # Gradients of the first head group (first world_size heads).
                dqkv[:,:,:world_size,:] = block_grads
                
            # Stack this stage's dq, dk, dv along dim 0 for a single all-to-all.
            attn_grads = torch.cat([attn_dq, attn_dk, attn_dv], dim=0)
            
            if stage == 0:
                # Shapes used to unpack stage 0's gradient buffer during the next iteration.
                o_bs, o_seqlen, o_shard_hc, o_hs = attn_grads.shape
                o_hc = o_shard_hc * world_size
                o_shard_seqlen = o_seqlen // world_size
            else:
                # Shapes used to unpack the final stage's gradient buffer after the loop.
                fi_bs, fi_seqlen, fi_shard_hc, fi_hs = attn_grads.shape
                fi_hc = fi_shard_hc * world_size
                fi_shard_seqlen = fi_seqlen // world_size

            # Send the gradients back: scatter 1, gather 2. The result is only read
            # after waiting on U_HANDLE.O_HANDLE.
            block_grads = all_to_all_4D(
                attn_grads, 1, 2, ctx.ulysses_group, False, True # scatter 1, gather 2
            )

            if stage+1 != pipe_degree:
                # Wait for the next stage's input all-to-all issued at the top of this iteration.
                if U_HANDLE.HANDLE!=[]:
                    assert len(U_HANDLE.HANDLE) == 1, f"Pipe: U_HANDLE.HANDLE length {len(U_HANDLE.HANDLE)} must be 1"
                    U_HANDLE.HANDLE[0].wait()
                clear_u_handle()
                if offload_stream is not None:
                    with torch.cuda.stream(offload_stream):
                        # NOTE(review): result discarded, as with the offloads above.
                        doutqkvo_lse[stage+1].to("cpu", non_blocking=True)
                assert next_ulysses_doutqkvo_lse.numel() ==  (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_doutqkvo_lse shape {next_ulysses_doutqkvo_lse.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                # View the received buffer as (seq, batch, heads, hs) and move batch to the front.
                next_ulysses_doutqkvo_lse = next_ulysses_doutqkvo_lse.reshape(seqlen, bs, shard_hc, hs)
                next_ulysses_doutqkvo_lse = next_ulysses_doutqkvo_lse.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                ulysses_doutqkvo_lse = next_ulysses_doutqkvo_lse
        
        # Drain the last stage's gradient all-to-all and fill the remaining heads.
        # NOTE(review): the chunk into (dq, dk, dv) happens only inside this branch;
        # if O_HANDLE were empty here, dqkv would still be a single tensor when
        # indexed in the return statement.
        if U_HANDLE.O_HANDLE!=[]:
            assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
            U_HANDLE.O_HANDLE[0].wait()
            clear_o_handle()
            assert block_grads.numel() == (fi_bs*fi_shard_seqlen*fi_hc*fi_hs), f"Pipe: block_grads shape {block_grads.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {fi_bs*fi_shard_seqlen*fi_hc*fi_hs}"
            block_grads = block_grads.reshape(fi_hc, fi_shard_seqlen, fi_bs, fi_hs)
            block_grads = block_grads.transpose(0, 2).contiguous().reshape(fi_bs, fi_shard_seqlen, fi_hc, fi_hs)
            dqkv[:,:,world_size:,:] = block_grads
            # dqkv was allocated with 3 * batch rows, so this assertion restricts the
            # method to batch size 1.
            assert dqkv.shape[0] == 3, f"Pipe: dqkv shape {dqkv.shape} must be 3 in the bs dimension"
            dqkv = torch.chunk(dqkv, 3, dim=0)

        assert len(dqkv) == 3, f"Pipe: dqkv length {len(dqkv)} must be 3"
        return dqkv[0], dqkv[1], dqkv[2], None, None, None, None, None, None, None, None, None, None, None, None, None


@torch.library.custom_op("yunchang::_two_streams_dualstage_forward", mutates_args=(), device_types="cuda")
def two_streams_dualstage_forward(q: torch.Tensor,
                                  k: torch.Tensor,
                                  v: torch.Tensor,
                                  dropout_p: float = 0,
                                  softmax_scale: float = 0,
                                  causal: bool = True,
                                  ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the two-stage forward pass, one head group per CUDA stream.

    q, k and v are concatenated along dim 0 and split along the head dimension
    (dim 2) at index ``world_size``. Each part is re-sharded with
    ``all_to_all_4D`` (scatter 2, gather 1), passed through
    ``upipe_ring_flash_attn_forward`` and sent back with the inverse
    all-to-all (scatter 1, gather 2). Stage 0 runs on ``two_streams[0]`` and
    stage 1 on ``two_streams[1]``.

    The process groups, the stream pair and ``alibi_slopes`` are read from
    module globals, which the caller must set before invoking the op.

    Args:
        q, k, v: 4-D tensors; the shape of ``q`` is unpacked as
            (batch, local sequence, heads, head size). ``k`` must have the
            same number of heads as ``q``.
        dropout_p, softmax_scale, causal: forwarded to the attention call.

    Returns:
        ``(output, softmax_lse)``: the stage outputs concatenated along dim 2
        and the stage LSEs concatenated along dim 1.
    """
    global ring_group, ulysses_group, two_streams, attn_type, alibi_slopes, window_size

    batch, local_seq, num_heads, head_dim = q.shape
    group_size = dist.get_world_size(ulysses_group)

    assert k.shape[2] == q.shape[2], f"Pipe: num heads in key {k.shape[2]} must be equal to query {q.shape[2]}"
    assert k.shape[2]//group_size >= 2, f"DualStage UPipe only works for num heads per device >= 2"

    # Pack q, k, v along dim 0 so each stage needs a single input all-to-all, then
    # split the heads into [:, :, :group_size, :] and [:, :, group_size:, :].
    packed = torch.cat([q, k, v]).contiguous()
    stage_inputs = list(torch.tensor_split(packed, [group_size], dim=2))
    assert len(stage_inputs) == 2, f"DualStage UPipe only works for QKV num heads per device >= 2"

    # Receive buffers handed to all_to_all_4D, one per stage. They differ only in
    # the per-rank head count: 1 for stage 0, the remaining heads for stage 1.
    heads_per_stage = (1, (num_heads - group_size) // group_size)
    recv_buffers = [
        torch.empty([group_size, local_seq, 3 * batch, stage_heads, head_dim], dtype=q.dtype, device=q.device)
        for stage_heads in heads_per_stage
    ]

    # No group is passed to all_to_all_4D below; presumably it reads these
    # globals instead -- verify against its definition.
    PROCESS_GROUP.ULYSSES_PG = ulysses_group
    PROCESS_GROUP.RING_PG = ring_group

    # Both side streams must first see everything queued on the current stream.
    for idx in (0, 1):
        two_streams[idx].wait_stream(torch.cuda.current_stream())

    # Issue both input all-to-alls (scatter 2, gather 1) before any attention
    # work, stage 0 first.
    for idx in (0, 1):
        with torch.cuda.stream(two_streams[idx]):
            recv_buffers[idx] = all_to_all_4D(
                stage_inputs[idx], 2, 1, False, False, recv_buffers[idx]
            )

    stage_results = []
    for idx in (0, 1):
        with torch.cuda.stream(two_streams[idx]):
            parts = torch.chunk(recv_buffers[idx], 3, dim=0)
            part_q, part_k, part_v = parts[0], parts[1], parts[2]

            assert alibi_slopes is None
            part_k = part_k.contiguous()
            part_v = part_v.contiguous()

            attn_out, attn_lse = upipe_ring_flash_attn_forward(
                part_q,
                part_k,
                part_v,
                softmax_scale=softmax_scale,
                dropout_p=dropout_p,
                causal=causal,
                deterministic=False,
            )

            # Return the attention output to the original sharding: scatter 1, gather 2.
            gathered_out = all_to_all_4D(attn_out, 1, 2, False, False)

        # Keep every per-stage tensor referenced until both streams are joined.
        stage_results.append((gathered_out, attn_lse, attn_out, part_q, part_k, part_v, parts))

    # Join both side streams before combining their results on the current stream.
    torch.cuda.current_stream().wait_stream(two_streams[0])
    torch.cuda.current_stream().wait_stream(two_streams[1])

    output = torch.cat([stage_results[0][0], stage_results[1][0]], dim=2)
    softmax_lse = torch.cat([stage_results[0][1], stage_results[1][1]], dim=1)
    return output, softmax_lse


@two_streams_dualstage_forward.register_fake
def _(q: torch.Tensor,
      k: torch.Tensor,
      v: torch.Tensor,
      dropout_p: float = 0,
      softmax_scale: float = 0,
      causal: bool = True,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fake (shape-only) implementation of ``two_streams_dualstage_forward``.

    Returns uninitialised tensors: an output shaped like ``q`` and an LSE of
    shape ``[q.shape[0], q.shape[1], q.shape[2]]``, both with the dtype and
    device of ``q``. ``k`` and ``v`` and the scalar options are not inspected.

    NOTE(review): the real op builds its LSE by concatenating the per-stage
    LSEs along dim 1; confirm that the shape and dtype declared here match.
    """
    # Four-way unpacking also rejects inputs that are not 4-D.
    batch, seq_len, num_heads, _head_dim = q.shape
    like_q = {"dtype": q.dtype, "device": q.device}
    fake_out = torch.empty(q.shape, **like_q)
    fake_lse = torch.empty([batch, seq_len, num_heads], **like_q)
    return fake_out, fake_lse

class DualStageUpipeRingFlashAttnFuncQKVPackedTwoStreams(torch.autograd.Function):
    """Autograd wrapper around the two-stream, two-stage attention pipeline.

    ``forward`` publishes its non-tensor configuration as module globals and
    calls the ``two_streams_dualstage_forward`` op. ``backward`` repeats the
    same two-stage / two-stream layout with ``upipe_ring_flash_attn_backward``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        two_streams,
        attn_type,
    ):
        """Run the forward op and stash everything ``backward`` needs on ``ctx``.

        Args:
            q, k, v: attention inputs; passed unchanged to the op.
            dropout_p, causal: forwarded to the op.
            softmax_scale: when ``None``, defaults to ``q.shape[-1] ** -0.5``.
            window_size, alibi_slopes, attn_type, ring_group, ulysses_group,
                two_streams: written to module globals for the op to read.
            softcap, deterministic: only stored on ``ctx`` for ``backward``.
            return_softmax: when true, return ``(output, softmax_lse, None)``
                instead of ``output`` alone.
            offload_stream, fetch_stream: stored on ``ctx``; not otherwise
                used by this class.

        Returns:
            ``output``, or ``(output, softmax_lse, None)`` if ``return_softmax``.
        """
        import sys

        # The op only accepts tensors and scalars in its signature, so process
        # groups, streams and the remaining options travel via module globals.
        current_module = sys.modules[__name__]
        current_module.process_group = ring_group
        current_module.ring_group = ring_group
        current_module.ulysses_group = ulysses_group
        current_module.two_streams = two_streams
        current_module.attn_type = attn_type
        current_module.alibi_slopes = alibi_slopes
        current_module.window_size = window_size

        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        output, softmax_lse = two_streams_dualstage_forward(q,
                                                            k,
                                                            v,
                                                            dropout_p,
                                                            softmax_scale,
                                                            causal)

        ctx.save_for_backward(q, k, v, output, softmax_lse)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.ring_group = ring_group
        ctx.ulysses_group = ulysses_group
        ctx.attn_type = attn_type
        ctx.pipe_degree = 2
        ctx.offload_stream = offload_stream
        ctx.fetch_stream = fetch_stream
        ctx.two_streams = two_streams

        return output if not return_softmax else (output, softmax_lse, None)

    @staticmethod
    def _stage_backward(ctx, gathered, stage_lse):
        """Run the attention backward for one head group and send the gradients back.

        Must be called with the stage's CUDA stream current.

        Args:
            ctx: autograd context carrying the attention options.
            gathered: result of the stage's input all-to-all, holding q, k, v,
                dout and out stacked along dim 0 in that order.
            stage_lse: the slice of the saved softmax LSE for this head group.

        Returns:
            The result of the gradient all-to-all (scatter 1, gather 2) applied
            to ``[dq; dk; dv]`` stacked along dim 0.
        """
        # Split by the number of packed tensors rather than by the packed batch
        # size, mirroring how the forward op splits its q/k/v pack into 3. This
        # is identical for batch size 1 and keeps each tensor whole otherwise.
        parts = torch.chunk(gathered, 5, dim=0)
        assert len(parts) == 5, f"Pipe: ulysses_tensors length {len(parts)} must be equal to 5 (q, k, v, dout, out), tensor shape {parts[0].shape}"
        inp_q, inp_k, inp_v, inp_dout, inp_out = parts

        attn_dq, attn_dk, attn_dv = upipe_ring_flash_attn_backward(
            ctx.ring_group,
            inp_dout,
            inp_q,
            inp_k,
            inp_v,
            inp_out,
            stage_lse,
            softmax_scale=ctx.softmax_scale,
            dropout_p=ctx.dropout_p,
            causal=ctx.causal,
            window_size=ctx.window_size,
            softcap=ctx.softcap,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
            attn_type=ctx.attn_type,
        )

        attn_grads = torch.cat([attn_dq, attn_dk, attn_dv], dim=0)
        return all_to_all_4D(
            attn_grads, 1, 2, False, False # scatter 1, gather 2
        )

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute dq, dk, dv with one head group per CUDA stream.

        Args:
            ctx: autograd context filled in by ``forward``.
            dout: gradient with respect to the forward output.
            *args: gradients for any extra forward outputs; never read.

        Returns:
            ``dq, dk, dv`` followed by one ``None`` for each of the 14
            remaining ``forward`` arguments.
        """
        q, k, v, out, softmax_lse = ctx.saved_tensors

        _, shard_seqlen, q_hc, hs = q.shape
        world_size = dist.get_world_size(ctx.ulysses_group)
        stream_0, stream_1 = ctx.two_streams[0], ctx.two_streams[1]

        # Pack the five tensors along dim 0 so that one all-to-all moves all of them.
        packed = torch.cat([q, k, v, dout, out], dim=0).contiguous()
        eff_bs = packed.shape[0]

        # Heads are split at index world_size:
        # [:, :, :world_size, :] and [:, :, world_size:, :].
        head_groups = torch.tensor_split(packed, [world_size], dim=2)
        assert len(head_groups) == 2, f"Pipe: doutqkvo length {len(head_groups)} must be 2"

        # The saved LSE is split at index 1 along dim 1: [:, :1, :] and [:, 1:, :].
        lse_groups = torch.tensor_split(softmax_lse, [1], dim=1)
        assert len(lse_groups) == 2, f"Pipe: softmax_lse length {len(lse_groups)} must be 2"

        # Receive buffers handed to all_to_all_4D, one per stage.
        recv = [
            torch.empty([world_size, 1, shard_seqlen, eff_bs, hs], dtype=q.dtype, device=q.device),
            torch.empty([world_size, (q_hc-world_size)//world_size, shard_seqlen, eff_bs, hs], dtype=q.dtype, device=q.device),
        ]

        # No group is passed to all_to_all_4D below; presumably it reads these
        # globals instead -- verify against its definition.
        PROCESS_GROUP.ULYSSES_PG = ctx.ulysses_group
        PROCESS_GROUP.RING_PG = ctx.ring_group

        # Both side streams must first see everything queued on the current stream.
        stream_0.wait_stream(torch.cuda.current_stream())
        stream_1.wait_stream(torch.cuda.current_stream())

        # Input all-to-alls: scatter 2, gather 1 (inverse of the forward output
        # exchange). Both are issued up front, stage 0 first.
        # NOTE(review): the original carried debugging notes that this path
        # "works" when the stage-1 exchange is issued inside the stage-1 block
        # below instead of here. The ordering is unchanged; confirm before
        # relying on it.
        with torch.cuda.stream(stream_0):
            recv[0] = all_to_all_4D(head_groups[0], 2, 1, False, False, recv[0])
        with torch.cuda.stream(stream_1):
            recv[1] = all_to_all_4D(head_groups[1], 2, 1, False, False, recv[1])

        cls = DualStageUpipeRingFlashAttnFuncQKVPackedTwoStreams
        with torch.cuda.stream(stream_0):
            block_grads_0 = cls._stage_backward(ctx, recv[0], lse_groups[0])
        with torch.cuda.stream(stream_1):
            block_grads_1 = cls._stage_backward(ctx, recv[1], lse_groups[1])

        # Join both side streams before combining their results on the current stream.
        torch.cuda.current_stream().wait_stream(stream_0)
        torch.cuda.current_stream().wait_stream(stream_1)

        # Reassemble the head dimension, then separate the stacked dq / dk / dv.
        dqkv = torch.chunk(torch.cat([block_grads_0, block_grads_1], dim=2), 3, dim=0)
        assert len(dqkv) == 3, f"Pipe: dqkv length {len(dqkv)} must be 3"

        # forward takes 17 inputs after ctx; the 14 that are not q/k/v get no gradient.
        return (dqkv[0], dqkv[1], dqkv[2]) + (None,) * 14
    

class UpipeRingFlashAttnFunc(torch.autograd.Function):
    """Head-chunked pipeline of Ulysses all-to-all exchanges around ring flash attention.

    ``q``/``k``/``v`` are split along dim 2 (heads) into
    ``pipe_degree = num_heads // ulysses_world_size`` chunks.  For every chunk
    ("stage") the tensors are exchanged with ``all_to_all_4D`` (scatter 2,
    gather 1), attention is computed by the ring kernel, and the result is
    exchanged back (scatter 1, gather 2).  The exchange for stage ``s + 1`` and
    the write-back of stage ``s - 1`` are issued around the compute of stage
    ``s`` and completed by waiting on the handles published in ``U_HANDLE``.

    NOTE(review): calls to ``all_to_all_4D`` whose last positional flag is
    ``True`` are only consumed after waiting on ``U_HANDLE`` handles and are
    then re-laid-out by hand (reshape + transpose); presumably that flag
    selects an asynchronous exchange returning a raw buffer -- confirm against
    ``all_to_all_4D``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        attn_type,
    ):  
        """Run the pipelined attention forward pass.

        Args:
            ctx: autograd context; receives the saved tensors and settings
                needed by ``backward``.
            q, k, v: 4-D tensors unpacked as ``(bs, shard_seqlen, hc, hs)``;
                ``k`` must have the same size as ``q`` on dim 2.
            dropout_p: forwarded to the ring attention kernel.
            softmax_scale: attention scale; when ``None`` it defaults to
                ``head_dim ** -0.5``.
            causal: forwarded to the ring attention kernel.
            window_size: published as a module global for the kernel.
            softcap: forwarded to the ring attention kernel.
            alibi_slopes: must be ``None`` (asserted).
            deterministic: only stored on ``ctx`` for ``backward``; the
                forward kernel is always called with ``deterministic=False``.
            return_softmax: when truthy, return ``(output, softmax_lse, None)``.
            ring_group: process group published as the module-level
                ``process_group`` for the ring kernel.
            ulysses_group: process group used for the all-to-all exchanges.
            offload_stream: currently ignored (overwritten with ``None`` below).
            fetch_stream: optional CUDA stream used for the ``.to(device)``
                calls on upcoming chunks.
            attn_type: published as a module global for the kernel.

        Returns:
            ``output`` (same shape as ``q``), or ``(output, softmax_lse, None)``
            when ``return_softmax`` is truthy.
        """
        bs, shard_seqlen, hc, hs = q.shape

        # NOTE: the caller-supplied offload_stream is discarded here, so every
        # `offload_stream is not None` branch below is currently inactive.
        offload_stream = None

        world_size = dist.get_world_size(ulysses_group)
        pipe_degree = hc // world_size

        output = torch.zeros_like(q)
        # Allocated lazily once the first stage reveals the lse shape.
        softmax_lse = None

        assert k.shape[2] == q.shape[2], f"Pipe: num heads in key {k.shape[2]} must be equal to query {q.shape[2]}"

        orig_device = q.device

        #TODO: Find better way to do this, clone is not memory efficient, need a partial chunk to be offload friendly
        if offload_stream is not None:
            q_chunks = [t.clone() for t in torch.chunk(q, pipe_degree, dim = 2)]
            k_chunks = [t.clone() for t in torch.chunk(k, pipe_degree, dim = 2)]
            v_chunks = [t.clone() for t in torch.chunk(v, pipe_degree, dim = 2)]
        else:
            q_chunks = [t for t in torch.chunk(q, pipe_degree, dim = 2)]
            k_chunks = [t for t in torch.chunk(k, pipe_degree, dim = 2)]
            v_chunks = [t for t in torch.chunk(v, pipe_degree, dim = 2)]

        # async offload the qkv tensors to CPU
        if offload_stream is not None:
            with torch.cuda.stream(offload_stream):
                for q_tnsr, k_tnsr, v_tnsr in zip(q_chunks[2:], k_chunks[2:], v_chunks[2:]):
                    q_tnsr.to("cpu", non_blocking=True)
                    k_tnsr.to("cpu", non_blocking=True)
                    v_tnsr.to("cpu", non_blocking=True)

        # first all-to-all is blocking
        assert q_chunks[0].ndim == 4, f"Pipe: q_chunks[0] shape {q_chunks[0].shape} must be 4D"
        PROCESS_GROUP.ULYSSES_PG = ulysses_group
        ulysses_q = all_to_all_4D(
            q_chunks[0], 2, 1, False, False # scatter 2, gather 1
        )
        ulysses_k = all_to_all_4D(
            k_chunks[0], 2, 1, False, False # scatter 2, gather 1
        )
        ulysses_v = all_to_all_4D(
            v_chunks[0], 2, 1, False, False # scatter 2, gather 1
        )
        
        if offload_stream is not None:
            torch.cuda.current_stream().synchronize()
            with torch.cuda.stream(offload_stream):
                q_chunks[0].to("cpu", non_blocking=True)
                k_chunks[0].to("cpu", non_blocking=True)
                v_chunks[0].to("cpu", non_blocking=True)

        # Initialize variables to avoid undefined variable errors
        block_output = None
        o_bs = o_shard_seqlen = o_hc = o_hs = None

        # Publish per-call settings as module globals; the custom op
        # upipe_ring_flash_attn_forward reads them from there.
        global process_group
        import sys
        current_module = sys.modules[__name__]
        current_module.process_group = ring_group
        current_module.attn_type = attn_type
        current_module.alibi_slopes = alibi_slopes
        current_module.window_size = window_size

        for stage in range(pipe_degree):
            if stage+1 != pipe_degree:
                if stage+2 < pipe_degree:
                    if fetch_stream is not None:
                        fetch_stream.synchronize() # make sure the last qkv is fetched to GPU
                        with torch.cuda.stream(fetch_stream):
                            q_chunks[stage+2].to(orig_device, non_blocking=True)
                            k_chunks[stage+2].to(orig_device, non_blocking=True)
                            v_chunks[stage+2].to(orig_device, non_blocking=True)
                
                # Start the exchange for the next stage before computing this one.
                next_ulysses_q = all_to_all_4D(
                    q_chunks[stage+1], 2, 1, False, True # scatter 2, gather 1
                )
                next_ulysses_k = all_to_all_4D(
                    k_chunks[stage+1], 2, 1, False, True # scatter 2, gather 1
                )
                next_ulysses_v = all_to_all_4D(
                    v_chunks[stage+1], 2, 1, False, True # scatter 2, gather 1
                )

                # Target layout of the prefetched tensors once they are re-laid-out below.
                bs, shard_seqlen, hc, hs = q_chunks[stage+1].shape
                seqlen = shard_seqlen * world_size
                shard_hc = hc // world_size

            inp_q = ulysses_q
            inp_k = ulysses_k
            inp_v = ulysses_v

            if softmax_scale is None:
                softmax_scale = inp_q.shape[-1] ** (-0.5)

            assert alibi_slopes is None
            inp_k = inp_k.contiguous()
            inp_v = inp_v.contiguous()

            out, lse = upipe_ring_flash_attn_forward(
                inp_q,
                inp_k,
                inp_v,
                softmax_scale=softmax_scale,
                dropout_p=dropout_p,
                causal=causal,
                softcap=softcap,
                deterministic=False,
            )

            # One slot per stage along dim 1; the buffer is created from the
            # first lse so it inherits its dtype/device.
            if softmax_lse is None:
                softmax_lse = torch.zeros_like(lse).repeat(1, pipe_degree, 1)
            softmax_lse[:, stage:stage+1, :] = lse

            context_layer = out
            
            o_bs, o_seqlen, o_shard_hc, o_hs = context_layer.shape
            o_hc = o_shard_hc * world_size
            o_shard_seqlen = o_seqlen // world_size
            
            if stage > 0:
                # Finish the previous stage's output exchange before its
                # buffer name is reused by the exchange issued below.
                if U_HANDLE.O_HANDLE!=[]:
                    assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
                    U_HANDLE.O_HANDLE[0].wait()
                clear_o_handle()
                assert block_output.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_output shape {block_output.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
                block_output = block_output.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
                block_output = block_output.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
                output[:,:,(stage-1)*world_size:(stage)*world_size,:] = block_output

            block_output = all_to_all_4D(
                context_layer, 1, 2, False, True # scatter 1, gather 2
            )
            
            if stage+1 != pipe_degree:
                if U_HANDLE.HANDLE!=[]:
                    # Complete the prefetched q/k/v exchanges and bring each
                    # buffer into (bs, seqlen, shard_hc, hs) layout.
                    assert len(U_HANDLE.HANDLE) == 3, f"Pipe: U_HANDLE.HANDLE length {len(U_HANDLE.HANDLE)} must be 3"
                    U_HANDLE.HANDLE[0].wait()
                    assert next_ulysses_q.numel()== (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_q shape {next_ulysses_q.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                    next_ulysses_q = next_ulysses_q.reshape(seqlen, bs, shard_hc, hs)
                    next_ulysses_q = next_ulysses_q.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                    ulysses_q = next_ulysses_q

                    U_HANDLE.HANDLE[1].wait()
                    assert next_ulysses_k.numel()== (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_k shape {next_ulysses_k.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                    next_ulysses_k = next_ulysses_k.reshape(seqlen, bs, shard_hc, hs)
                    next_ulysses_k = next_ulysses_k.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                    ulysses_k = next_ulysses_k

                    U_HANDLE.HANDLE[2].wait()
                    assert next_ulysses_v.numel()== (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_v shape {next_ulysses_v.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                    next_ulysses_v = next_ulysses_v.reshape(seqlen, bs, shard_hc, hs)
                    next_ulysses_v = next_ulysses_v.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                    ulysses_v = next_ulysses_v
                    clear_u_handle()
                    if offload_stream is not None:
                        with torch.cuda.stream(offload_stream):
                            q_chunks[stage+1].to("cpu", non_blocking=True)
                            k_chunks[stage+1].to("cpu", non_blocking=True)
                            v_chunks[stage+1].to("cpu", non_blocking=True)
        
        # Drain the output exchange of the final stage into the last head slice.
        if U_HANDLE.O_HANDLE!=[]:
            U_HANDLE.O_HANDLE[0].wait()
            clear_o_handle()
            assert block_output.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_output shape {block_output.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
            block_output = block_output.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
            block_output = block_output.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
            output[:,:,-(world_size):,:] = block_output

        ctx.save_for_backward(q, k, v, output, softmax_lse)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.ring_group = ring_group
        ctx.ulysses_group = ulysses_group
        ctx.attn_type = attn_type
        ctx.pipe_degree = pipe_degree
        ctx.offload_stream = offload_stream
        ctx.fetch_stream = fetch_stream
        return output if not return_softmax else (output, softmax_lse, None)
        
    
    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients for ``q``, ``k`` and ``v`` with the same stage pipeline.

        Args:
            ctx: autograd context populated by ``forward``.
            dout: gradient of the forward ``output``.
            *args: gradients of any extra forward outputs; unused.

        Returns:
            A 16-tuple matching the inputs of ``forward``: ``(dq, dk, dv)``
            followed by ``None`` for each of the 13 non-tensor arguments.
        """
        q, k, v, out, softmax_lse = ctx.saved_tensors

        pipe_degree = ctx.pipe_degree
        offload_stream = ctx.offload_stream
        fetch_stream = ctx.fetch_stream
        orig_device = dout.device
        world_size = dist.get_world_size(ctx.ulysses_group)

        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)

        if offload_stream is not None: # if we plan to offload qkv chunks, then we need to clone because otherwise torch has the same memory reference for all chunks and won't be freed from GPU
            q_chunks = [t.clone() for t in torch.chunk(q, pipe_degree, dim = 2)]
            k_chunks = [t.clone() for t in torch.chunk(k, pipe_degree, dim = 2)]
            v_chunks = [t.clone() for t in torch.chunk(v, pipe_degree, dim = 2)]
            dout_chunks = [t.clone() for t in torch.chunk(dout, pipe_degree, dim = 2)]
            out_chunks = [t.clone() for t in torch.chunk(out, pipe_degree, dim = 2)]
            softmax_lse_chunks = [t.clone() for t in torch.chunk(softmax_lse, pipe_degree, dim = 1)]
        else:
            q_chunks = [t for t in torch.chunk(q, pipe_degree, dim = 2)]
            k_chunks = [t for t in torch.chunk(k, pipe_degree, dim = 2)]
            v_chunks = [t for t in torch.chunk(v, pipe_degree, dim = 2)]
            dout_chunks = [t for t in torch.chunk(dout, pipe_degree, dim = 2)]
            out_chunks = [t for t in torch.chunk(out, pipe_degree, dim = 2)]
            softmax_lse_chunks = [t for t in torch.chunk(softmax_lse, pipe_degree, dim = 1)]

        if offload_stream is not None:
            # async offload the dout tensor to CPU
            with torch.cuda.stream(offload_stream):
                for tnsr in q_chunks[2:]:
                    tnsr.to("cpu", non_blocking=True)
                for tnsr in k_chunks[2:]:
                    tnsr.to("cpu", non_blocking=True)
                for tnsr in v_chunks[2:]:
                    tnsr.to("cpu", non_blocking=True)
                for tnsr in dout_chunks[2:]:
                    tnsr.to("cpu", non_blocking=True)
                for tnsr in out_chunks[2:]:
                    tnsr.to("cpu", non_blocking=True)
                for tnsr in softmax_lse_chunks[1:]:
                    tnsr.to("cpu", non_blocking=True)
                del q, k, v, dout, out, softmax_lse
                torch.cuda.empty_cache()
        
        # First exchange is blocking, mirroring the forward pass.
        PROCESS_GROUP.ULYSSES_PG = ctx.ulysses_group
        ulysses_dout = all_to_all_4D(dout_chunks[0], 2, 1, False, False) #note that this is scatter 2, gather 1 since this is inverse of forward pass
        ulysses_out = all_to_all_4D(out_chunks[0], 2, 1, False, False) #note that this is scatter 2, gather 1 since this is inverse of forward pass
        ulysses_q = all_to_all_4D(q_chunks[0], 2, 1, False, False) 
        ulysses_k = all_to_all_4D(k_chunks[0], 2, 1, False, False) 
        ulysses_v = all_to_all_4D(v_chunks[0], 2, 1, False, False) 

        
        if offload_stream is not None:
            torch.cuda.current_stream().synchronize()
            with torch.cuda.stream(offload_stream):
                dout_chunks[0] = None
                out_chunks[0] = None
                q_chunks[0] = None
                k_chunks[0] = None
                v_chunks[0] = None
                torch.cuda.empty_cache()
        
        for stage in range(pipe_degree):
            if stage+1 != pipe_degree:
                if stage+2 < pipe_degree:
                    if fetch_stream is not None:
                        fetch_stream.synchronize() # make sure the last qkv is fetched to GPU
                        with torch.cuda.stream(fetch_stream):
                            dout_chunks[stage+2].to(orig_device, non_blocking=True)
                            out_chunks[stage+2].to(orig_device, non_blocking=True)
                            q_chunks[stage+2].to(orig_device, non_blocking=True)
                            k_chunks[stage+2].to(orig_device, non_blocking=True)
                            v_chunks[stage+2].to(orig_device, non_blocking=True)

                            softmax_lse_chunks[stage+1].to(orig_device, non_blocking=True)

                # Start the exchanges for the next stage before computing this one.
                next_ulysses_dout = all_to_all_4D(dout_chunks[stage+1], 2, 1, False, True) #note that this is scatter 2, gather 1 since this is inverse of forward pass
                next_ulysses_out = all_to_all_4D(out_chunks[stage+1], 2, 1, False, True) #note that this is scatter 2, gather 1 since this is inverse of forward pass
                next_ulysses_q = all_to_all_4D(q_chunks[stage+1], 2, 1, False, True)
                next_ulysses_k = all_to_all_4D(k_chunks[stage+1], 2, 1, False, True)
                next_ulysses_v = all_to_all_4D(v_chunks[stage+1], 2, 1, False, True)
                bs, shard_seqlen, hc, hs = dout_chunks[stage+1].shape
                seqlen = shard_seqlen * world_size
                shard_hc = hc // world_size

            inp_q = ulysses_q
            inp_k = ulysses_k
            inp_v = ulysses_v
            inp_dout = ulysses_dout
            inp_out = ulysses_out
            
            inp_softmax_lse_transposed = softmax_lse_chunks[stage]

            # Use the scale the forward pass actually applied (saved on ctx);
            # recomputing head_dim ** -0.5 here would silently produce wrong
            # gradients whenever a custom softmax_scale was supplied.
            softmax_scale = ctx.softmax_scale
            if softmax_scale is None:
                softmax_scale = inp_q.shape[-1] ** (-0.5)

            attn_dq, attn_dk, attn_dv = upipe_ring_flash_attn_backward(
                ctx.ring_group,
                inp_dout,
                inp_q,
                inp_k,
                inp_v,
                inp_out,
                inp_softmax_lse_transposed,
                softmax_scale=softmax_scale,
                dropout_p=ctx.dropout_p,
                causal=ctx.causal,
                window_size=ctx.window_size,
                softcap=ctx.softcap,
                alibi_slopes=ctx.alibi_slopes,
                deterministic=ctx.deterministic,
                attn_type=ctx.attn_type,
            )

            if offload_stream is not None:
                torch.cuda.current_stream().synchronize()
                softmax_lse_chunks[stage] = None
                q_chunks[stage] = None
                k_chunks[stage] = None
                v_chunks[stage] = None
                dout_chunks[stage] = None
                out_chunks[stage] = None

            if stage > 0:
                # Finish the previous stage's gradient exchanges and store
                # them into that stage's head slice.
                if U_HANDLE.O_HANDLE!=[]:
                    assert len(U_HANDLE.O_HANDLE) == 3, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 3"
                    U_HANDLE.O_HANDLE[0].wait()
                    assert block_dq.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_dq shape {block_dq.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
                    block_dq = block_dq.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
                    block_dq = block_dq.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
                    dq[:,:,(stage-1)*world_size:(stage)*world_size,:] = block_dq

                    U_HANDLE.O_HANDLE[1].wait()
                    assert block_dk.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_dk shape {block_dk.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
                    block_dk = block_dk.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
                    block_dk = block_dk.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
                    dk[:,:,(stage-1)*world_size:(stage)*world_size,:] = block_dk

                    U_HANDLE.O_HANDLE[2].wait()
                    assert block_dv.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_dv shape {block_dv.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
                    block_dv = block_dv.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
                    block_dv = block_dv.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
                    dv[:,:,(stage-1)*world_size:(stage)*world_size,:] = block_dv

                    clear_o_handle()
            
            o_bs, o_seqlen, o_shard_hc, o_hs = attn_dq.shape
            o_hc = o_shard_hc * world_size
            o_shard_seqlen = o_seqlen // world_size

            block_dq = all_to_all_4D(
                attn_dq, 1, 2, False, True # scatter 1, gather 2
            )
            block_dk = all_to_all_4D(
                attn_dk, 1, 2, False, True # scatter 1, gather 2
            )
            block_dv = all_to_all_4D(
                attn_dv, 1, 2, False, True # scatter 1, gather 2
            )

            if stage+1 != pipe_degree:
                if U_HANDLE.HANDLE!=[]:
                    # Complete the five prefetched exchanges, in issue order:
                    # dout, out, q, k, v.
                    assert len(U_HANDLE.HANDLE) == 5, f"Pipe: U_HANDLE.HANDLE length {len(U_HANDLE.HANDLE)} must be 5"
                    U_HANDLE.HANDLE[0].wait()
                    assert next_ulysses_dout.numel() ==  (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_dout shape {next_ulysses_dout.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                    next_ulysses_dout = next_ulysses_dout.reshape(seqlen, bs, shard_hc, hs)
                    next_ulysses_dout = next_ulysses_dout.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                    ulysses_dout = next_ulysses_dout
                    
                    U_HANDLE.HANDLE[1].wait()
                    assert next_ulysses_out.numel() ==  (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_out shape {next_ulysses_out.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                    next_ulysses_out = next_ulysses_out.reshape(seqlen, bs, shard_hc, hs)
                    next_ulysses_out = next_ulysses_out.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                    ulysses_out = next_ulysses_out

                    U_HANDLE.HANDLE[2].wait()
                    assert next_ulysses_q.numel() ==  (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_q shape {next_ulysses_q.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                    next_ulysses_q = next_ulysses_q.reshape(seqlen, bs, shard_hc, hs)
                    next_ulysses_q = next_ulysses_q.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                    ulysses_q = next_ulysses_q

                    U_HANDLE.HANDLE[3].wait()
                    assert next_ulysses_k.numel() ==  (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_k shape {next_ulysses_k.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                    next_ulysses_k = next_ulysses_k.reshape(seqlen, bs, shard_hc, hs)
                    next_ulysses_k = next_ulysses_k.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                    ulysses_k = next_ulysses_k
                    
                    U_HANDLE.HANDLE[4].wait()
                    assert next_ulysses_v.numel() ==  (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_v shape {next_ulysses_v.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                    next_ulysses_v = next_ulysses_v.reshape(seqlen, bs, shard_hc, hs)
                    next_ulysses_v = next_ulysses_v.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                    ulysses_v = next_ulysses_v

                clear_u_handle()
                if offload_stream is not None:
                    with torch.cuda.stream(offload_stream):
                        dout_chunks[stage+1] = None
                        out_chunks[stage+1] = None
                        q_chunks[stage+1] = None
                        k_chunks[stage+1] = None
                        v_chunks[stage+1] = None
                
        
        # Drain the gradient exchanges of the final stage into the last head slice.
        if U_HANDLE.O_HANDLE!=[]:
            assert len(U_HANDLE.O_HANDLE) == 3, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 3"
            U_HANDLE.O_HANDLE[0].wait()
            assert block_dq.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_dq shape {block_dq.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
            block_dq = block_dq.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
            block_dq = block_dq.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
            dq[:,:,-(world_size):,:] = block_dq

            U_HANDLE.O_HANDLE[1].wait()
            assert block_dk.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_dk shape {block_dk.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
            block_dk = block_dk.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
            block_dk = block_dk.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
            dk[:,:,-(world_size):,:] = block_dk

            U_HANDLE.O_HANDLE[2].wait()
            assert block_dv.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_dv shape {block_dv.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
            block_dv = block_dv.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
            block_dv = block_dv.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
            dv[:,:,-(world_size):,:] = block_dv

            clear_o_handle()
            torch.cuda.empty_cache()

        # One gradient slot per forward input: dq, dk, dv, then None for the
        # 13 non-differentiable arguments (dropout_p ... attn_type).
        return (dq, dk, dv) + (None,) * 13

#--------------------------------------------------------------
# ---------------------------- GQA ----------------------------
#--------------------------------------------------------------

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every head of ``x`` ``n_rep`` times along the head axis (dim 2).

    Produces the same values as
    ``torch.repeat_interleave(x, dim=2, repeats=n_rep)``.

    Args:
        x: 4-D tensor; the shape is unpacked as
            ``(bs, slen, n_kv_heads, head_dim)``.
        n_rep: number of copies of each head.

    Returns:
        ``x`` itself when ``n_rep == 1``; otherwise a tensor of shape
        ``(bs, slen, n_kv_heads * n_rep, head_dim)``.
    """
    # Unpack first so that a non-4-D input is rejected even when n_rep == 1.
    batch, seq_len, kv_heads, dim = x.shape
    if n_rep == 1:
        return x
    # Insert a singleton axis after the head axis, broadcast it to n_rep
    # copies, then merge it back into the head axis.
    with_rep_axis = x[:, :, :, None, :]
    broadcast = with_rep_axis.expand(batch, seq_len, kv_heads, n_rep, dim)
    return broadcast.reshape(batch, seq_len, kv_heads * n_rep, dim)

@torch.library.custom_op("yunchang::_upipe_gqa_qkvpacked_forward", mutates_args=(), device_types="cuda")
def upipe_gqa_qkvpacked_forward(q: torch.Tensor,
                                  k: torch.Tensor,
                                  v: torch.Tensor,
                                  dropout_p: float = 0,
                                  softmax_scale: float = 0,
                                  causal: bool = True,
                                  ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pipelined grouped-query attention forward (single-stream variant).

    The query heads are split into ``pipe_degree = q_hc // world_size``
    stages of ``world_size`` heads each.  For every stage a packet is sent
    through ``all_to_all_4D`` (scatter dim 2, gather dim 1), attention is run
    with ``upipe_ring_flash_attn_forward`` and the result is sent back
    (scatter dim 1, gather dim 2).  The all-to-all for stage ``s + 1`` and the
    return all-to-all of stage ``s`` are issued with the last flag set to
    ``True`` and later completed by waiting on ``U_HANDLE.HANDLE`` /
    ``U_HANDLE.O_HANDLE``.

    K and V are only packed into a stage's packet when
    ``stage % (q_hc // key_hc) == 0``; the following stages reuse the K/V
    received at that stage.

    Args:
        q: query, unpacked as ``(bs, shard_seqlen, q_hc, hs)``.
        k: key with fewer heads than ``q`` (asserted).
        v: value, sliced with the same head ranges as ``k``.
        dropout_p: forwarded to the attention kernel.
        softmax_scale: forwarded to the attention kernel; replaced by
            ``head_dim ** -0.5`` if it is ``None``.
        causal: forwarded to the attention kernel.

    Returns:
        ``(output, softmax_lse)``.  ``output`` has the shape of ``q`` with
        heads in the original order; ``softmax_lse`` holds one slot per stage
        along dim 1.

    Raises:
        AssertionError: on inconsistent head counts or unexpected packet
            shapes.
    """
    global offload_stream, fetch_stream, two_streams, attn_type, alibi_slopes, window_size

    assert k.shape[2] < q.shape[2], f"PipeGQA : num heads in key {k.shape[2]} must be less than query {q.shape[2]}"

    # Read the Ulysses group from PROCESS_GROUP (set by the autograd wrapper
    # before this op is called), like the two-stream variant does.  The
    # previous bare `ulysses_group` is not a parameter or declared global here.
    world_size = dist.get_world_size(PROCESS_GROUP.ULYSSES_PG)

    if k.shape[2] < world_size:
        assert world_size % k.shape[2] == 0, f"PipeGQA : num heads in key {k.shape[2]} must be divisible by world size {world_size}"
        k = repeat_kv(k, world_size//k.shape[2])
        v = repeat_kv(v, world_size//v.shape[2]) # this is the only case where we need to repeat kv

    # Unpacking also enforces that q and k are 4-D.
    _, _, q_hc, _ = q.shape
    _, _, key_hc, _ = k.shape

    q_per_kv = q_hc // key_hc          # query heads sharing one KV head
    pipe_degree = q_hc // world_size   # number of pipeline stages

    output = torch.zeros_like(q)
    softmax_lse = None  # allocated lazily once the first lse shape is known

    # Head permutation: entry (b*q_per_kv + g)*world_size + w holds query head
    # (b*world_size + w)*q_per_kv + g, so each stage takes one query head from
    # each of world_size consecutive KV groups.
    q_idx = torch.arange(q_hc, device=q.device)
    q_idx = q_idx.reshape(-1, q_per_kv).unsqueeze(0).reshape(-1, world_size, q_per_kv).transpose(1,2).flatten(1,2).flatten()

    # Build one packet per stage: [q, k, v] stacked on dim 0 when the stage
    # starts a new KV block, q alone otherwise.
    comm_pkts = []
    for stage in range(pipe_degree):
        q_heads = q_idx[stage*world_size:(stage+1)*world_size]
        if stage % q_per_kv == 0:
            kv_start = (stage//q_per_kv)*world_size
            kv_stop = (stage//q_per_kv+1)*world_size
            try:
                comm_pkts.append(torch.cat([q[:,:,q_heads,:],
                                            k[:,:,kv_start:kv_stop,:],
                                            v[:,:,kv_start:kv_stop,:]
                                            ]).contiguous())
            except Exception as err:
                # Keep the diagnostic message but chain the real cause, and
                # do not rely on `assert`, which is stripped under -O.
                raise AssertionError(
                    f"k[:,:,(stage//(q_hc//key_hc))*world_size:(stage//(q_hc//key_hc)+1)*world_size,:].shape {k[:,:,kv_start:kv_stop,:].shape}, stage {stage}, world_size {world_size}, q_hc {q_hc}, key_hc {key_hc}"
                ) from err
        else:
            comm_pkts.append(q[:,:,q_heads,:])

    orig_device = q.device

    if offload_stream is not None:
        # async offload the qkv tensors to CPU
        # NOTE(review): the tensors returned by `.to(...)` are discarded, so
        # comm_pkts keeps referring to the original tensors - confirm intent.
        with torch.cuda.stream(offload_stream):
            for tnsr in comm_pkts[2:]:
                tnsr.to("cpu", non_blocking=True)

    # first all-to-all is blocking
    ulysses_qkv = all_to_all_4D(
        comm_pkts[0], 2, 1, False, False # scatter 2, gather 1
    )

    # NOTE(review): this requires the packed q/k/v leading dim to be exactly 3
    # while the two-stream variant only checks divisibility by 3 - confirm.
    assert ulysses_qkv.shape[0] == 3, f"INITIALLY, Got ulysses_qkv with length {ulysses_qkv.shape[0]} in stage 0"

    if offload_stream is not None:
        torch.cuda.current_stream().synchronize()
        with torch.cuda.stream(offload_stream):
            comm_pkts[0].to("cpu", non_blocking=True)

    # Initialize variables to avoid undefined variable errors
    block_output = None
    o_bs = o_shard_seqlen = o_hc = o_hs = None

    for stage in range(pipe_degree):
        if stage+1 != pipe_degree:
            if stage+2 < pipe_degree:
                if fetch_stream is not None:
                    fetch_stream.synchronize() # make sure the last qkv is fetched to GPU
                    with torch.cuda.stream(fetch_stream):
                        comm_pkts[stage+2].to(orig_device, non_blocking=True)
            # Start the next stage's input all-to-all; it is completed below
            # by waiting on U_HANDLE.HANDLE.
            next_ulysses_qkv = all_to_all_4D(
                comm_pkts[stage+1], 2, 1, False, True # scatter 2, gather 1
            )
            # Dimensions of the next packet, used to reshape the raw result.
            bs, shard_seqlen, hc, hs = comm_pkts[stage+1].shape
            seqlen = shard_seqlen * world_size
            shard_hc = hc // world_size

        if stage % q_per_kv == 0:
            # `q_hc` is used in the message because `hc` is only bound when a
            # next stage exists (it is unbound for pipe_degree == 1).
            assert ulysses_qkv.shape[0] == 3, f"Got ulysses_qkv with length {ulysses_qkv.shape[0]} in stage {stage}, with q_hc {q_hc} and key_hc {key_hc}"
            ulysses_qkv = torch.chunk(ulysses_qkv, 3, dim=0)

            inp_q = ulysses_qkv[0]
            inp_k = ulysses_qkv[1]
            inp_v = ulysses_qkv[2]
        else:
            # K/V from the last KV-carrying stage are reused.
            inp_q = ulysses_qkv

        if softmax_scale is None:
            softmax_scale = inp_q.shape[-1] ** (-0.5)

        assert alibi_slopes is None
        inp_k = inp_k.contiguous()
        inp_v = inp_v.contiguous()

        out, lse = upipe_ring_flash_attn_forward(
            inp_q,
            inp_k,
            inp_v,
            softmax_scale=softmax_scale,
            dropout_p=dropout_p,
            causal=causal,
            deterministic=False,
        )

        if softmax_lse is None:
            softmax_lse = torch.zeros_like(lse).repeat(1, pipe_degree, 1)
        softmax_lse[:, stage, :] = lse
        # CRITICAL: we need to ensure that the processing for backward pass is done in the same order as the forward pass for softmax_lse to be correctly indexed

        context_layer = out

        o_bs, o_seqlen, o_shard_hc, o_hs = context_layer.shape
        o_hc = o_shard_hc * world_size
        o_shard_seqlen = o_seqlen // world_size

        if stage > 0:
            # Complete the previous stage's output all-to-all and scatter its
            # heads back to their original positions.
            if U_HANDLE.O_HANDLE!=[]:
                assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
                U_HANDLE.O_HANDLE[0].wait()
            clear_o_handle()
            assert block_output.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_output shape {block_output.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
            block_output = block_output.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
            block_output = block_output.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
            output[:,:,q_idx[(stage-1)*world_size:(stage)*world_size],:] = block_output

        block_output = all_to_all_4D(
            context_layer, 1, 2, False, True # scatter 1, gather 2
        )

        if stage+1 != pipe_degree:
            # Complete the next stage's input all-to-all.
            if U_HANDLE.HANDLE!=[]:
                assert len(U_HANDLE.HANDLE) == 1, f"Pipe: U_HANDLE.HANDLE length {len(U_HANDLE.HANDLE)} must be 1"
                U_HANDLE.HANDLE[0].wait()
            clear_u_handle()
            if offload_stream is not None:
                with torch.cuda.stream(offload_stream):
                    comm_pkts[stage+1].to("cpu", non_blocking=True)
            assert next_ulysses_qkv.numel()== (seqlen*bs*shard_hc*hs), f"Pipe: next_ulysses_qkv shape {next_ulysses_qkv.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
            next_ulysses_qkv = next_ulysses_qkv.reshape(seqlen, bs, shard_hc, hs)
            next_ulysses_qkv = next_ulysses_qkv.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
            ulysses_qkv = next_ulysses_qkv

    # Complete the last stage's output all-to-all.
    if U_HANDLE.O_HANDLE!=[]:
        assert len(U_HANDLE.O_HANDLE) == 1, f"Pipe: U_HANDLE.O_HANDLE length {len(U_HANDLE.O_HANDLE)} must be 1"
        U_HANDLE.O_HANDLE[0].wait()
        clear_o_handle()
        assert block_output.numel() == (o_bs*o_shard_seqlen*o_hc*o_hs), f"Pipe: block_output shape {block_output.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
        block_output = block_output.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
        block_output = block_output.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
        output[:,:,q_idx[-(world_size):],:] = block_output
    return output, softmax_lse


@upipe_gqa_qkvpacked_forward.register_fake
def _(q: torch.Tensor,
      k: torch.Tensor,
      v: torch.Tensor,
      dropout_p: float = 0,
      softmax_scale: float = 0,
      causal: bool = True,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation registered for ``upipe_gqa_qkvpacked_forward``.

    Allocates uninitialised placeholders without running attention: one with
    the shape of ``q`` and one of shape ``(bs, sl, nh)`` taken from the first
    three dims of ``q``, both with ``q``'s dtype and device.
    """
    # Unpacking enforces a 4-D query; the head dim is not needed for the lse.
    batch, seq_len, num_heads, _head_dim = q.shape
    fake_out = q.new_empty(q.shape)
    fake_lse = q.new_empty((batch, seq_len, num_heads))
    return fake_out, fake_lse


@torch.library.custom_op("yunchang::_upipe_gqa_qkvpacked_two_streams_forward", mutates_args=(), device_types="cuda")
def upipe_gqa_qkvpacked_two_streams_forward(q: torch.Tensor,
                                            k: torch.Tensor,
                                            v: torch.Tensor,
                                            dropout_p: float = 0,
                                            softmax_scale: float = 0,
                                            causal: bool = True,
                                            ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pipelined grouped-query attention forward using two CUDA streams.

    The query heads are split into ``pipe_degree = q_hc // world_size``
    stages of ``world_size`` heads each.  Stage ``s`` runs on
    ``two_streams[s % 2]``: while attention for stage ``s`` is queued on one
    stream, the input all-to-all of stage ``s + 1`` is queued on the other.

    K and V are only packed into a stage's packet when
    ``stage % (q_hc // key_hc) == 0``; the following stages reuse the K/V
    received at that stage.

    Args:
        q: query, unpacked as ``(bs, shard_seqlen, q_hc, hs)``.
        k: key with fewer heads than ``q`` (asserted).
        v: value, sliced with the same head ranges as ``k``.
        dropout_p: forwarded to the attention kernel.
        softmax_scale: forwarded to the attention kernel.
        causal: forwarded to the attention kernel.

    Returns:
        ``(output, softmax_lse)``.  ``output`` holds the per-stage results
        concatenated on dim 2 and scattered back to the original query-head
        order; ``softmax_lse`` is the per-stage lse concatenated on dim 1 in
        stage order.

    Raises:
        AssertionError: if ``two_streams`` is unset, head counts are
            inconsistent, or a packet has an unexpected shape.
    """
    global offload_stream, fetch_stream, two_streams, attn_type, alibi_slopes, window_size
    assert two_streams is not None, "PipeGQA: two_streams must be initialized"

    assert k.shape[2] < q.shape[2], f"PipeGQA : num heads in key {k.shape[2]} must be less than query {q.shape[2]}"

    world_size = dist.get_world_size(PROCESS_GROUP.ULYSSES_PG)

    if k.shape[2] < world_size:
        assert world_size % k.shape[2] == 0, f"PipeGQA : num heads in key {k.shape[2]} must be divisible by world size {world_size}"
        k = repeat_kv(k, world_size//k.shape[2])
        v = repeat_kv(v, world_size//v.shape[2]) # this is the only case where we need to repeat kv

    # Unpacking also enforces that q and k are 4-D.
    _, _, q_hc, _ = q.shape
    _, _, key_hc, _ = k.shape

    q_per_kv = q_hc // key_hc          # query heads sharing one KV head
    pipe_degree = q_hc // world_size   # number of pipeline stages

    # Head permutation: entry (b*q_per_kv + g)*world_size + w holds query head
    # (b*world_size + w)*q_per_kv + g, so each stage takes one query head from
    # each of world_size consecutive KV groups.
    q_idx = torch.arange(q_hc, device=q.device)
    q_idx = q_idx.reshape(-1, q_per_kv).unsqueeze(0).reshape(-1, world_size, q_per_kv).transpose(1,2).flatten(1,2).flatten()

    # Build one packet per stage: [q, k, v] stacked on dim 0 when the stage
    # starts a new KV block, q alone otherwise.
    comm_pkts = []
    for stage in range(pipe_degree):
        q_heads = q_idx[stage*world_size:(stage+1)*world_size]
        if stage % q_per_kv == 0:
            kv_start = (stage//q_per_kv)*world_size
            kv_stop = (stage//q_per_kv+1)*world_size
            try:
                comm_pkts.append(torch.cat([q[:,:,q_heads,:],
                                            k[:,:,kv_start:kv_stop,:],
                                            v[:,:,kv_start:kv_stop,:]
                                            ]).contiguous())
            except Exception as err:
                # Keep the diagnostic message but chain the real cause, and
                # do not rely on `assert`, which is stripped under -O.
                raise AssertionError(
                    f"k[:,:,(stage//(q_hc//key_hc))*world_size:(stage//(q_hc//key_hc)+1)*world_size,:].shape {k[:,:,kv_start:kv_stop,:].shape}, stage {stage}, world_size {world_size}, q_hc {q_hc}, key_hc {key_hc}"
                ) from err
        else:
            comm_pkts.append(q[:,:,q_heads,:])

    orig_device = q.device

    # Drop the local references to the inputs; only the packets are needed.
    del q, k, v

    if offload_stream is not None:
        # async offload the qkv tensors to CPU
        # NOTE(review): the tensors returned by `.to(...)` are discarded, so
        # comm_pkts keeps referring to the original tensors - confirm intent.
        with torch.cuda.stream(offload_stream):
            for tnsr in comm_pkts[2:]:
                tnsr.to("cpu", non_blocking=True)

    # Pre-allocated buffers passed as the last argument of all_to_all_4D.
    ulysses_qkv_out = [torch.empty_like(pkt) for pkt in comm_pkts]

    softmax_lse = []
    output = []

    # Both side streams must see everything queued on the current stream.
    two_streams[0].wait_stream(torch.cuda.current_stream())
    two_streams[1].wait_stream(torch.cuda.current_stream())

    for stage in range(pipe_degree):

        if stage == 0:
            # The first packet has no previous stage to overlap with.
            with torch.cuda.stream(two_streams[(stage)%2]):
                ulysses_qkv_out[stage] = all_to_all_4D(
                    comm_pkts[stage], 2, 1, False, False, ulysses_qkv_out[stage] # scatter 2, gather 1
                )

            if offload_stream is not None:
                offload_stream.wait_stream(two_streams[(stage)%2])
                with torch.cuda.stream(offload_stream):
                    comm_pkts[0].to("cpu", non_blocking=True)

        if stage+1 != pipe_degree:

            # prefetch
            if stage+2 < pipe_degree:
                if fetch_stream is not None:
                    with torch.cuda.stream(fetch_stream):
                        comm_pkts[stage+2].to(orig_device, non_blocking=True)

            # Queue the next stage's all-to-all on the other stream.  A
            # failure propagates: continuing would run attention on the
            # uninitialised pre-allocated buffer.
            with torch.cuda.stream(two_streams[(stage+1)%2]):
                ulysses_qkv_out[stage+1] = all_to_all_4D(
                    comm_pkts[stage+1], 2, 1, False, False, ulysses_qkv_out[stage+1] # scatter 2, gather 1
                )

                if offload_stream is not None:
                    offload_stream.wait_stream(two_streams[(stage+1)%2])
                    with torch.cuda.stream(offload_stream):
                        comm_pkts[stage+1].to("cpu", non_blocking=True)

        with torch.cuda.stream(two_streams[stage%2]):
            if stage % q_per_kv == 0:
                assert ulysses_qkv_out[stage].shape[0] % 3 == 0, f"Got ulysses_qkv with length {ulysses_qkv_out[stage].shape[0]} in stage {stage}, with q_hc {q_hc} and key_hc {key_hc}"
                ulysses_qkv = torch.chunk(ulysses_qkv_out[stage], 3, dim=0)

                inp_q = ulysses_qkv[0]
                inp_k = ulysses_qkv[1]
                inp_v = ulysses_qkv[2]
            else:
                # K/V from the last KV-carrying stage are reused.
                inp_q = ulysses_qkv_out[stage]

            assert alibi_slopes is None
            inp_k = inp_k.contiguous()
            inp_v = inp_v.contiguous()

            out, lse = upipe_ring_flash_attn_forward(
                inp_q,
                inp_k,
                inp_v,
                softmax_scale=softmax_scale,
                dropout_p=dropout_p,
                causal=causal,
                deterministic=False,
            )

            softmax_lse.append(lse)

            block_output = all_to_all_4D(
                out, 1, 2, False, False # scatter 1, gather 2
            )
            output.append(block_output)

    # Rejoin the side streams before touching their results.
    torch.cuda.current_stream().wait_stream(two_streams[0])
    torch.cuda.current_stream().wait_stream(two_streams[1])

    staged_output = torch.cat(output, dim = 2)
    # Head j of the concatenation belongs to query head q_idx[j]; scatter it
    # back so the result lines up with q, as the single-stream forward does
    # and as the backward assumes when it indexes `out`/`dout` with q_idx.
    output = torch.empty_like(staged_output)
    output[:, :, q_idx, :] = staged_output

    # The lse stays in stage order; the backward chunks it per stage.
    softmax_lse = torch.cat(softmax_lse, dim = 1)
    return output, softmax_lse


@upipe_gqa_qkvpacked_two_streams_forward.register_fake
def _(q: torch.Tensor,
      k: torch.Tensor,
      v: torch.Tensor,
      dropout_p: float = 0,
      softmax_scale: float = 0,
      causal: bool = True,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation registered for the two-stream GQA forward op.

    Allocates uninitialised placeholders without running attention: one with
    the shape of ``q`` and one of shape ``(bs, sl, nh)`` taken from the first
    three dims of ``q``, both with ``q``'s dtype and device.
    """
    # Unpacking enforces a 4-D query; the head dim is not needed for the lse.
    batch, seq_len, num_heads, _head_dim = q.shape
    fake_out = q.new_empty(q.shape)
    fake_lse = q.new_empty((batch, seq_len, num_heads))
    return fake_out, fake_lse



class UpipeRingFlashAttnGQAFuncQKVPacked(torch.autograd.Function):
    """Autograd wrapper around the pipelined GQA attention custom ops.

    ``forward`` publishes its configuration through ``PROCESS_GROUP`` and
    module-level attributes (the custom ops read them as globals), then
    dispatches to the two-stream or single-stream forward op.  ``backward``
    pipelines the gradient computation stage by stage over two CUDA streams.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        two_streams,
        attn_type,
    ):
        """Run the pipelined GQA forward and stash state for ``backward``.

        Returns ``output`` alone, or ``(output, softmax_lse, None)`` when
        ``return_softmax`` is truthy.
        """
        PROCESS_GROUP.ULYSSES_PG = ulysses_group
        PROCESS_GROUP.RING_PG = ring_group

        # The offload/fetch streams passed by the caller are overridden here,
        # so the CPU offload branches are never taken for this function.
        offload_stream = None
        fetch_stream = None

        # The custom ops take only tensors and scalars, so everything else is
        # handed over through module-level attributes.
        import sys
        current_module = sys.modules[__name__]
        current_module.process_group = ring_group
        current_module.offload_stream = offload_stream
        current_module.fetch_stream = fetch_stream
        current_module.two_streams = two_streams
        current_module.attn_type = attn_type
        current_module.alibi_slopes = alibi_slopes
        current_module.window_size = window_size

        pipe_degree = q.shape[2] // dist.get_world_size(ulysses_group)
        orig_k_hc = k.shape[2]

        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        if two_streams:
            output, softmax_lse = upipe_gqa_qkvpacked_two_streams_forward(q,
                                                            k,
                                                            v,
                                                            dropout_p,
                                                            softmax_scale,
                                                            causal)
        else:
            output, softmax_lse = upipe_gqa_qkvpacked_forward(q,
                                                            k,
                                                            v,
                                                            dropout_p,
                                                            softmax_scale,
                                                            causal)

        ctx.save_for_backward(q, k, v, output, softmax_lse)
        ctx.orig_k_hc = orig_k_hc
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.ring_group = ring_group
        ctx.ulysses_group = ulysses_group
        ctx.attn_type = attn_type
        ctx.pipe_degree = pipe_degree
        ctx.offload_stream = offload_stream
        ctx.fetch_stream = fetch_stream
        ctx.two_streams = two_streams
        return output if not return_softmax else (output, softmax_lse, None)


    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients for ``q``, ``k`` and ``v``.

        For every pipeline stage a packet of ``[q, (k, v,) dout, out]``
        (stacked on dim 0) is sent through ``all_to_all_4D`` (scatter dim 2,
        gather dim 1), ``upipe_ring_flash_attn_backward`` is run, and the
        stacked ``[dq, dk, dv]`` is sent back (scatter dim 1, gather dim 2).

        Returns one entry per ``forward`` input: ``dq``, ``dk``, ``dv``
        followed by ``None`` for the 14 non-tensor arguments.  ``dk``/``dv``
        are summed over the query heads of each original KV head.
        """
        q, k, v, out, softmax_lse = ctx.saved_tensors

        world_size = dist.get_world_size(ctx.ulysses_group)

        if k.shape[2] < world_size:
            assert world_size % k.shape[2] == 0, f"PipeGQA : num heads in key {k.shape[2]} must be divisible by world size {world_size}"
            k = repeat_kv(k, world_size//k.shape[2])
            v = repeat_kv(v, world_size//v.shape[2]) # this is the only case where we need to repeat kv

        # Unpacking also enforces that q and k are 4-D.
        _, _, q_hc, _ = q.shape
        _, _, key_hc, _ = k.shape
        q_per_kv = q_hc // key_hc   # query heads sharing one (repeated) KV head

        # Same head permutation as the forward ops: entry
        # (b*q_per_kv + g)*world_size + w holds query head
        # (b*world_size + w)*q_per_kv + g.
        q_idx = torch.arange(q_hc, device=q.device)
        q_idx = q_idx.reshape(-1, q_per_kv).unsqueeze(0).reshape(-1, world_size, q_per_kv).transpose(1,2).flatten(1,2).flatten()

        pipe_degree = ctx.pipe_degree
        offload_stream = ctx.offload_stream
        fetch_stream = ctx.fetch_stream
        orig_device = dout.device

        # Use the streams this very forward call was given rather than the
        # module-level global, which a later forward may have overwritten.
        # If none were given (single-stream forward), run both pipeline slots
        # on the current stream instead of failing on `None[0]`.
        two_streams = ctx.two_streams
        if not two_streams:
            current = torch.cuda.current_stream()
            two_streams = (current, current)

        dqkv = []

        # One packet per stage; K/V are only included when the stage starts a
        # new KV block.
        doutqkvo_lse = []
        for stage in range(pipe_degree):
            q_heads = q_idx[stage*world_size:(stage+1)*world_size]
            if stage % q_per_kv == 0:
                kv_start = (stage//q_per_kv)*world_size
                kv_stop = (stage//q_per_kv+1)*world_size
                doutqkvo_lse.append(torch.cat([q[:,:,q_heads,:],
                                               k[:,:,kv_start:kv_stop,:],
                                               v[:,:,kv_start:kv_stop,:],
                                               dout[:,:,q_heads,:],
                                               out[:,:,q_heads,:]
                                               ], dim = 0).contiguous())
            else:
                doutqkvo_lse.append(torch.cat([q[:,:,q_heads,:],
                                               dout[:,:,q_heads,:],
                                               out[:,:,q_heads,:]
                                               ], dim = 0).contiguous())

        # The forward stored the lse stage by stage along dim 1.
        softmax_lse = torch.chunk(softmax_lse, pipe_degree, dim = 1)

        if offload_stream is not None:
            # async offload the dout tensor to CPU
            with torch.cuda.stream(offload_stream):
                for tnsr in doutqkvo_lse[2:]:
                    tnsr.to("cpu", non_blocking=True)
                for tnsr in softmax_lse[1:]:
                    tnsr.to("cpu", non_blocking=True)

        # Pre-allocated buffers passed as the last argument of all_to_all_4D.
        ulysses_doutqkvo_lse_out = [torch.empty_like(pkt) for pkt in doutqkvo_lse]

        two_streams[0].wait_stream(torch.cuda.current_stream())
        two_streams[1].wait_stream(torch.cuda.current_stream())

        for stage in range(pipe_degree):

            if stage == 0:
                with torch.cuda.stream(two_streams[(stage)%2]):
                    ulysses_doutqkvo_lse_out[stage] = all_to_all_4D(doutqkvo_lse[stage], 2, 1, False, False, ulysses_doutqkvo_lse_out[stage]) #note that this is scatter 2, gather 1 since this is inverse of forward pass

                    if offload_stream is not None:
                        offload_stream.wait_stream(two_streams[(stage)%2])
                        with torch.cuda.stream(offload_stream):
                            doutqkvo_lse[stage].to("cpu", non_blocking=True)

            if stage+1 != pipe_degree:

                # prefetch
                if stage+2 < pipe_degree:
                    if fetch_stream is not None:
                        with torch.cuda.stream(fetch_stream):
                            doutqkvo_lse[stage+2].to(orig_device, non_blocking=True)
                            softmax_lse[stage+1].to(orig_device, non_blocking=True)

                # Queue the next stage's all-to-all on the other stream.
                with torch.cuda.stream(two_streams[(stage+1)%2]):
                    ulysses_doutqkvo_lse_out[stage+1] = all_to_all_4D(doutqkvo_lse[stage+1], 2, 1, False, False, ulysses_doutqkvo_lse_out[stage+1]) #note that this is scatter 2, gather 1 since this is inverse of forward pass

                    if offload_stream is not None:
                        offload_stream.wait_stream(two_streams[(stage+1)%2])
                        with torch.cuda.stream(offload_stream):
                            doutqkvo_lse[stage+1].to("cpu", non_blocking=True)

            with torch.cuda.stream(two_streams[stage%2]):
                if stage % q_per_kv == 0:
                    assert ulysses_doutqkvo_lse_out[stage].shape[0] % 5 == 0, f"Got ulysses_doutqkvo_lse with length {ulysses_doutqkvo_lse_out[stage].shape[0]} in stage {stage}, with q_hc {q_hc} and key_hc {key_hc}"
                    # Split into the q, k, v, dout, out components.
                    ulysses_doutqkvo_lse = torch.chunk(ulysses_doutqkvo_lse_out[stage], 5, dim=0)

                    inp_q = ulysses_doutqkvo_lse[0]
                    inp_k = ulysses_doutqkvo_lse[1]
                    inp_v = ulysses_doutqkvo_lse[2]
                    inp_dout = ulysses_doutqkvo_lse[3]
                    inp_out = ulysses_doutqkvo_lse[4]
                else:
                    # Split into the q, dout, out components; K/V from the
                    # last KV-carrying stage are reused.
                    assert ulysses_doutqkvo_lse_out[stage].shape[0] % 3 == 0, f"Got ulysses_doutqkvo_lse with length {ulysses_doutqkvo_lse_out[stage].shape[0]} in stage {stage}, with q_hc {q_hc} and key_hc {key_hc}"
                    ulysses_doutqkvo_lse = torch.chunk(ulysses_doutqkvo_lse_out[stage], 3, dim=0)

                    inp_q = ulysses_doutqkvo_lse[0]
                    inp_dout = ulysses_doutqkvo_lse[1]
                    inp_out = ulysses_doutqkvo_lse[2]

                inp_softmax_lse_transposed = softmax_lse[stage]

                # Errors propagate: swallowing them would leave attn_dq /
                # attn_dk / attn_dv unbound for the concatenation below.
                # The scale is the one the forward actually used, so a
                # caller-supplied softmax_scale is honoured here too.
                attn_dq, attn_dk, attn_dv = upipe_ring_flash_attn_backward(
                    ctx.ring_group,
                    inp_dout,
                    inp_q,
                    inp_k,
                    inp_v,
                    inp_out,
                    inp_softmax_lse_transposed,
                    softmax_scale=ctx.softmax_scale,
                    dropout_p=ctx.dropout_p,
                    causal=ctx.causal,
                    window_size=ctx.window_size,
                    softcap=ctx.softcap,
                    alibi_slopes=ctx.alibi_slopes,
                    deterministic=ctx.deterministic,
                    attn_type=ctx.attn_type,
                )

                if offload_stream is not None:
                    with torch.cuda.stream(offload_stream):
                        softmax_lse[stage].to("cpu", non_blocking=True)

                attn_grads = torch.cat([attn_dq, attn_dk, attn_dv], dim=0)
                block_grads = all_to_all_4D(
                    attn_grads, 1, 2, False, False # scatter 1, gather 2
                )

                dqkv.append(block_grads)

        # Rejoin the side streams before touching their results.
        torch.cuda.current_stream().wait_stream(two_streams[0])
        torch.cuda.current_stream().wait_stream(two_streams[1])

        staged_grads = torch.cat(dqkv, dim=2)
        # Head j of the concatenation belongs to query head q_idx[j]
        # (the packets were gathered with q_idx).  Scatter back to the
        # original head order so that dq lines up with q and so that the
        # (key_hc, heads-per-kv) reshape below groups the right heads.
        ordered_grads = torch.empty_like(staged_grads)
        ordered_grads[:, :, q_idx, :] = staged_grads

        dq, dk_full, dv_full = torch.chunk(ordered_grads, 3, dim=0)
        bs, seqlen, nh, hs = dq.shape
        key_hc = ctx.orig_k_hc

        # Sum the per-query-head contributions of each original KV head.
        dk = dk_full.reshape(bs, seqlen, key_hc, nh//key_hc, hs).sum(3)
        dv = dv_full.reshape(bs, seqlen, key_hc, nh//key_hc, hs).sum(3)

        # forward() takes 17 inputs after ctx: 3 tensors + 14 non-tensors.
        return (dq, dk, dv) + (None,) * 14

def upipe_ring_flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
):
    """Apply ``UpipeRingFlashAttnFunc`` to a packed qkv tensor.

    ``qkv`` is indexed at positions 0, 1 and 2 of dim 2 to obtain query, key
    and value; every other argument is forwarded positionally, unchanged.
    """
    query = qkv[:, :, 0]
    key = qkv[:, :, 1]
    value = qkv[:, :, 2]
    # Positional order expected by UpipeRingFlashAttnFunc.apply.
    forwarded = (
        query, key, value,
        dropout_p, softmax_scale, causal,
        window_size, softcap, alibi_slopes,
        deterministic, return_attn_probs,
        group, attn_type,
    )
    return UpipeRingFlashAttnFunc.apply(*forwarded)


def upipe_ring_flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
):
    """Apply ``UpipeRingFlashAttnFunc`` to a query and a packed kv tensor.

    ``kv`` is indexed at positions 0 and 1 of dim 2 to obtain key and value;
    ``q`` and every other argument are forwarded positionally, unchanged.
    """
    key = kv[:, :, 0]
    value = kv[:, :, 1]
    # Positional order expected by UpipeRingFlashAttnFunc.apply.
    forwarded = (
        q, key, value,
        dropout_p, softmax_scale, causal,
        window_size, softcap, alibi_slopes,
        deterministic, return_attn_probs,
        group, attn_type,
    )
    return UpipeRingFlashAttnFunc.apply(*forwarded)


def upipe_ring_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    ring_group=None,
    ulysses_group=None,
    offload_stream=None,
    fetch_stream=None,
    two_streams=None,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
    use_pack_qkv=True,
    dualstage=False,
):
    """Dispatch to the matching Upipe ring flash-attention autograd function.

    The implementation is selected from the head counts of ``q`` and ``k``
    and from the ``use_pack_qkv`` / ``dualstage`` / ``two_streams`` switches:

    * equal head counts, ``use_pack_qkv`` and ``dualstage``:
      ``DualStageUpipeRingFlashAttnFuncQKVPacked`` when ``two_streams`` is
      ``None``, otherwise ``DualStageUpipeRingFlashAttnFuncQKVPackedTwoStreams``.
    * equal head counts, ``use_pack_qkv`` without ``dualstage``:
      ``UpipeRingFlashAttnFuncQKVPacked`` when ``q`` is a tensor, otherwise
      ``NewPipeRingFlashAttnFuncQKVPacked``.
    * equal head counts without ``use_pack_qkv``: ``UpipeRingFlashAttnFunc``.
    * more query heads than key heads: ``UpipeRingFlashAttnGQAFuncQKVPacked``
      (requires ``use_pack_qkv``).

    Args:
        q, k, v: Either tensors or indexable sequences of tensors. The head
            count is read from dim 2 of the tensor (or of element 0 of the
            sequence).
        dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes,
        deterministic, return_attn_probs: Forwarded unchanged to the selected
            autograd function.
        ring_group, ulysses_group: Forwarded unchanged; presumably the
            process groups for the ring and Ulysses dimensions -- confirm
            against the autograd functions.
        offload_stream, fetch_stream: Forwarded unchanged.
        two_streams: Selects the two-stream dual-stage variant when not
            ``None``; also forwarded to the GQA variant.
        attn_type: Forwarded unchanged as the last positional argument.
        attn_processor: Accepted for interface compatibility; not used here.
        use_pack_qkv: Selects the QKV-packed implementations.
        dualstage: Selects the dual-stage implementations.

    Returns:
        Whatever the selected autograd function's ``apply`` returns.

    Raises:
        AssertionError: If the GQA path is requested with
            ``use_pack_qkv=False``, or if the ``NewPipe`` path produces an
            output that does not require grad.
        ValueError: If ``q`` has fewer heads than ``k``; no implementation
            exists for that layout (previously this silently returned None).
    """
    # isinstance (not ``type(x) == Tensor``) so Tensor subclasses such as
    # nn.Parameter are treated as tensors rather than as sequences.
    q_is_tensor = isinstance(q, torch.Tensor)
    num_q_heads = q.shape[2] if q_is_tensor else q[0].shape[2]
    num_k_heads = k.shape[2] if isinstance(k, torch.Tensor) else k[0].shape[2]

    # Positional arguments shared by every autograd function below; each one
    # additionally takes ``attn_type`` last (preceded by ``two_streams`` for
    # the two-stream and GQA variants).
    common_args = (
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
    )

    if num_q_heads == num_k_heads:
        if not use_pack_qkv:
            return UpipeRingFlashAttnFunc.apply(*common_args, attn_type)

        if dualstage:
            if two_streams is None:
                return DualStageUpipeRingFlashAttnFuncQKVPacked.apply(
                    *common_args, attn_type
                )
            return DualStageUpipeRingFlashAttnFuncQKVPackedTwoStreams.apply(
                *common_args, two_streams, attn_type
            )

        if q_is_tensor:
            # The leftover rank-0 ``breakpoint()`` and its companion barrier
            # were removed: they stalled every rank on each call.
            return UpipeRingFlashAttnFuncQKVPacked.apply(*common_args, attn_type)

        output = NewPipeRingFlashAttnFuncQKVPacked.apply(*common_args, attn_type)
        assert output.requires_grad, f"Inside upipe_ring_flash_attn_func: output requires_grad must be True"
        return output

    if num_q_heads > num_k_heads:
        assert use_pack_qkv, "UpipeRingFlashAttnGQAFunc only supports use_pack_qkv=True"
        return UpipeRingFlashAttnGQAFuncQKVPacked.apply(
            *common_args, two_streams, attn_type
        )

    raise ValueError(
        "upipe_ring_flash_attn_func requires num_q_heads >= num_k_heads, "
        f"got num_q_heads={num_q_heads}, num_k_heads={num_k_heads}"
    )