import torch
from .utils import RingComm, update_out_and_lse
from yunchang.kernels import AttnType, select_flash_attn_impl
from yunchang.globals import U_HANDLE, clear_u_handle, clear_o_handle
from yunchang.comm.all_to_all import all_to_all_4D

from typing import List, Tuple

# Module-level configuration read by ``upipe_ring_flash_attn_forward``.
# That function is a ``torch.library.custom_op`` whose signature carries only
# tensors and scalars, so the ring process group, attention backend, ALiBi
# slopes and window size are handed to it through these globals instead;
# ``UpipeRingFlashAttnFunc.forward`` assigns them before each call.
# A ``global`` statement at module scope is a no-op and does not create the
# names, so they are bound explicitly here.
process_group = None
attn_type = None
alibi_slopes = None
window_size = (-1, -1)

import torch.distributed as dist

def check_nan_inf(tensor, name, rank=0):
    """Report whether ``tensor`` holds any NaN or infinite entries.

    NaNs are looked for first, then Infs. For the first kind found, two
    diagnostic lines (the tensor's shape and the number of offending
    elements) are printed, tagged with ``rank`` and ``name``.

    Args:
        tensor: tensor to inspect.
        name: label used in the printed diagnostics.
        rank: rank number used in the printed diagnostics.

    Returns:
        True if a NaN or Inf was found, False otherwise.
    """
    for label, predicate in (("NaN", torch.isnan), ("Inf", torch.isinf)):
        offending = predicate(tensor)
        if not offending.any():
            continue
        print(f"[RANK {rank}] {label} detected in {name}, shape: {tensor.shape}")
        print(f"[RANK {rank}] {label} locations: {offending.sum().item()}")
        return True
    return False

@torch.library.custom_op("yunchang::_upipe_ring_flash_attn_forward", mutates_args=(), device_types="cuda")
def upipe_ring_flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: float,
    dropout_p: float = 0,
    causal: bool = True,
    softcap: float = 0.0,
    deterministic: bool=False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Zigzag ring-attention forward pass over the module-level ``process_group``.

    Registered as a ``torch.library`` custom op. Non-tensor configuration
    (``process_group``, ``attn_type``, ``alibi_slopes``, ``window_size``) is
    read from module globals, which ``UpipeRingFlashAttnFunc.forward`` assigns
    before calling this op.

    The local sequence (dim 1) is treated as two halves of length
    ``q.shape[1] // 2``. K/V blocks are rotated around the ring with
    ``RingComm``. At each step the partial result is merged into the running
    ``out``/``lse`` with ``update_out_and_lse``:

    * step 0 (local K/V): causal attention of all of ``q`` against ``k``/``v``;
    * steps ``1..rank``: non-causal attention of all of ``q`` against the first
      half of the received K/V;
    * later steps: non-causal attention of the second half of ``q`` against
      all of the received K/V, merged only into the second half of ``out``.

    Args:
        q, k, v: local query/key/value shards; the fake implementation unpacks
            them as ``(batch, seqlen, nheads, headdim)``.
        softmax_scale: forwarded to the flash-attention kernel.
        dropout_p: attention dropout probability, forwarded to the kernel.
        causal: must be True; the zigzag schedule requires causal masking.
        softcap: forwarded to the flash-attention kernel.
        deterministic: accepted for signature parity; unused in the forward.

    Returns:
        ``(out, lse)``. ``out`` is cast to ``q.dtype``. ``lse`` is the
        accumulated log-sum-exp with its trailing singleton dim squeezed and
        dims 1/2 swapped. It is presumably ``(batch, nheads, seqlen)``, since the
        backward pass chunks it along dim 2 as the sequence dim.
    """
    
    # Read-only access to the module-level configuration set by the caller.
    global process_group, attn_type, alibi_slopes, window_size
    
    assert causal == True, "zigzag ring is meaningless for causal=False"
    comm = RingComm(process_group)

    block_seq_len = q.shape[1] // 2
    # Second half of the local queries, used on steps after this rank's position.
    q1 = q[:, block_seq_len:]

    out = None
    lse = None
    next_k, next_v = None, None
    
    def forward(q, k, v, causal):
        """Run one flash-attention forward block and return (block_out, block_lse)."""
        fn = select_flash_attn_impl(attn_type, stage="fwd-only")
        block_out, block_lse = fn(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # Evaluates to ``dropout_p > 0``.
            return_softmax=True and dropout_p > 0,
        )
        return block_out, block_lse

    for step in range(comm.world_size):
        # Start exchanging the current K/V block with the ring neighbours so the
        # transfer overlaps this step's compute (nothing to send on the last step).
        if step + 1 != comm.world_size:
            next_k: torch.Tensor = comm.send_recv(k)
            next_v: torch.Tensor = comm.send_recv(v)
            comm.commit()

        if step == 0:
            # Local K/V: causal attention over the full local block.
            block_out, block_lse = forward(q, k, v, causal=True)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        elif step <= comm.rank:
            # Only the first half of the received K/V is used, against all of q.
            k0 = k[:, :block_seq_len]
            v0 = v[:, :block_seq_len]
            block_out, block_lse = forward(q, k0, v0, causal=False)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        else:
            # Only the second half of q attends to the received K/V; merge the
            # result into the second half of the accumulators.
            block_out, block_lse = forward(q1, k, v, causal=False)
            out, lse = update_out_and_lse(
                out,
                lse,
                block_out,
                block_lse,
                slice_=(slice(None), slice(block_seq_len, None)),
            )

        # Finish the K/V exchange and rotate to the received block.
        if step + 1 != comm.world_size:
            comm.wait()
            k = next_k
            v = next_v

    # The accumulator dtype may differ from the input; cast ``out`` back.
    # ``lse`` is not cast.
    out = out.to(q.dtype)
    lse = lse.squeeze(dim=-1).transpose(1, 2)
    return out, lse

@upipe_ring_flash_attn_forward.register_fake
def _(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: float,
    dropout_p: float = 0,
    causal: bool = True,
    softcap: float = 0.0,
    deterministic: bool=False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fake (meta) implementation of ``upipe_ring_flash_attn_forward``.

    Used during tracing, so it must report the same shapes and dtypes as the
    real op:

    * ``out`` has ``q``'s shape and is cast to ``q.dtype`` by the real op.
    * ``lse`` is squeezed and transposed(1, 2) by the real op, giving
      ``(batch, nheads, seqlen)``. This is the layout the backward pass relies
      on when it chunks the sequence along dim 2. The real op does not cast
      ``lse``, so it stays in the float32 log-sum-exp dtype produced by the
      flash-attention kernels (NOTE(review): confirm against
      ``update_out_and_lse``).
    """
    bs, sl, nh, d = q.shape
    out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    # Previously declared as [bs, sl, nh] in q.dtype, which disagreed with the
    # (bs, nh, sl) float32 tensor the real op returns.
    lse = torch.empty([bs, nh, sl], dtype=torch.float32, device=q.device)
    return out, lse

# @torch.library.custom_op("yunchang::_zigzag_ring_flash_attn_forward_op", mutates_args=(), device_types="cuda")
# def zigzag_ring_flash_attn_forward_op(out: torch.Tensor, softmax_lse: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#     return out.clone().detach().requires_grad_(True), softmax_lse.clone().detach().requires_grad_(True)

# @zigzag_ring_flash_attn_forward_op.register_fake
# def _(out: torch.Tensor, softmax_lse: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#     return out.clone().detach().requires_grad_(True), softmax_lse.clone().detach().requires_grad_(True)

def upipe_ring_flash_attn_backward(
    process_group,
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = AttnType.FA,
):
    """Zigzag ring-attention backward pass over ``process_group``.

    Mirrors the step schedule of ``upipe_ring_flash_attn_forward``. K/V blocks
    are rotated around the ring with ``kv_comm``. The partial dK/dV
    accumulators are passed along the ring with a second ``RingComm``
    (``d_kv_comm``), and each rank adds its contribution to the accumulator it
    receives. dQ is accumulated locally.

    * step 0: causal backward of all of ``q`` against the local K/V; the
      kernel outputs are copied to float32 to start the accumulators;
    * steps ``1..rank``: all of ``q`` against the first half of the received
      K/V; only the first half of dK/dV receives a contribution;
    * later steps: the second half of ``q`` against all of the received K/V;
      only the second half of dQ receives a contribution.

    Args:
        process_group: group used for both ring exchanges.
        dout, q, k, v, out: split in half along dim 1 (the sequence dim);
            presumably ``(batch, seqlen, nheads, headdim)`` — TODO confirm.
        softmax_lse: split in half along dim 2 (the sequence dim); presumably
            ``(batch, nheads, seqlen)``.
        softmax_scale, dropout_p, window_size, softcap, alibi_slopes,
            deterministic: forwarded to the flash-attention backward kernel.
        causal: must be True.
        attn_type: selects the flash-attention implementation.

    Returns:
        ``(dq, dk, dv)`` cast to ``q.dtype``. dK/dV are the accumulators
        received after the final exchange.
    """
    assert causal == True, "upipe ring is meaningless for causal=False"
    kv_comm = RingComm(process_group)
    d_kv_comm = RingComm(process_group)
    dq, dk, dv = None, None, None
    next_dk, next_dv = None, None
    next_k, next_v = None, None
    dk_comm_buffer, dv_comm_buffer = None, None

    # Second halves (along the sequence dim) used on steps after this rank's position.
    dout1 = dout.chunk(2, dim=1)[1]
    q1 = q.chunk(2, dim=1)[1]
    out1 = out.chunk(2, dim=1)[1]
    softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous()
    block_seq_len = q.shape[1] // 2

    # repeatly allocating buffer may be slow...
    # if k.shape[2] == q.shape[2]:
    # Scratch outputs for the kernel; sliced to the length actually used each step.
    dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
    dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
    # else:
    #     dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    #     dk_buffer = torch.empty((k.shape[0], k.shape[1], q.shape[2], k.shape[3]), dtype=k.dtype, device=k.device)
    #     dv_buffer = torch.empty((v.shape[0], v.shape[1], q.shape[2], v.shape[3]), dtype=v.dtype, device=v.device)

    def backward(dout, q, k, v, out, softmax_lse, causal):
        """Run one flash-attention backward block, writing into the leading
        ``seqlen_q`` / ``seqlen_kv`` rows of dq_buffer / dk_buffer / dv_buffer."""
        seqlen_q = q.shape[1]
        seqlen_kv = k.shape[1]
        fn = select_flash_attn_impl(attn_type, stage="bwd-only")

        fn(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq_buffer[:, :seqlen_q],
            dk_buffer[:, :seqlen_kv],
            dv_buffer[:, :seqlen_kv],
            dropout_p,
            softmax_scale,
            causal,
            window_size,
            softcap,
            alibi_slopes,
            deterministic,
            rng_state=None,
        )

    for step in range(kv_comm.world_size):
        # assert not check_nan_inf(dout, f"dout -{step}", dist.get_rank()), f"Pipe: dout is nan or inf at step {step}"
        # assert not check_nan_inf(q, f"q -{step}", dist.get_rank()), f"Pipe: q is nan or inf at step {step}"
        # assert not check_nan_inf(k, f"k -{step}", dist.get_rank()), f"Pipe: k is nan or inf at step {step}"
        # assert not check_nan_inf(v, f"v -{step}", dist.get_rank()), f"Pipe: v is nan or inf at step {step}"
        # assert not check_nan_inf(out, f"out -{step}", dist.get_rank()), f"Pipe: out is nan or inf at step {step}"
        # assert not check_nan_inf(softmax_lse, f"softmax_lse -{step}", dist.get_rank()), f"Pipe: softmax_lse is nan or inf at step {step}"

        # Start the K/V exchange so it overlaps this step's compute.
        if step + 1 != kv_comm.world_size:
            next_k = kv_comm.send_recv(k)
            next_v = kv_comm.send_recv(v)
            kv_comm.commit()

        if step == 0:
            backward(dout.contiguous(), q.contiguous(), k.contiguous(), v.contiguous(), out.contiguous(), softmax_lse.contiguous(), causal=True)
            # Start float32 accumulators from the local block's gradients.
            dq = dq_buffer.to(torch.float32)
            dk = dk_buffer.to(torch.float32)
            dv = dv_buffer.to(torch.float32)
            # assert not check_nan_inf(dq, f"dq -{step}", dist.get_rank()), f"Pipe: dq is nan or inf at step {step}"
            # assert not check_nan_inf(dk, f"dk -{step}", dist.get_rank()), f"Pipe: dk is nan or inf at step {step}"
            # assert not check_nan_inf(dv, f"dv -{step}", dist.get_rank()), f"Pipe: dv is nan or inf at step {step}"
        else:
            if step <= kv_comm.rank:
                k0 = k[:, :block_seq_len]
                v0 = v[:, :block_seq_len]
                backward(dout, q, k0, v0, out, softmax_lse, causal=False)
                dq += dq_buffer
                # assert not check_nan_inf(dq, f"dq 251 -{step}", dist.get_rank()), f"Pipe: dq is nan or inf at step {step}"
            else:
                backward(dout1, q1, k, v, out1, softmax_lse1, causal=False)
                # always use the first half in dq_buffer.
                dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len]
                # assert not check_nan_inf(dq, f"dq 256 -{step}", dist.get_rank()), f"Pipe: dq is nan or inf at step {step}"

            # Wait for the dK/dV accumulator from the previous exchange. The
            # tensors just sent are reused as receive buffers (second argument
            # of send_recv) on the next exchange.
            d_kv_comm.wait()
            dk_comm_buffer, dv_comm_buffer = dk, dv
            dk, dv = next_dk, next_dv

            # Add this rank's contribution to the received accumulator.
            if step <= kv_comm.rank:
                dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len]
                dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len]
            else:
                dk += dk_buffer
                dv += dv_buffer
                # assert not check_nan_inf(dk, f"dk 268 -{step}", dist.get_rank()), f"Pipe: dk is nan or inf at step {step}"
                # assert not check_nan_inf(dv, f"dv 269 -{step}", dist.get_rank()), f"Pipe: dv is nan or inf at step {step}"

        if step + 1 != kv_comm.world_size:
            kv_comm.wait()
            k = next_k
            v = next_v

        # Pass the dK/dV accumulator along the ring (including on the last step).
        next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer)
        next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer)
        d_kv_comm.commit()

    d_kv_comm.wait()

    # if k.shape[2] != q.shape[2]:
    #     bs, slen, nqh, hdim = q.shape
    #     return dq.to(q.dtype), next_dk.view(bs, slen, k.shape[2], (q.shape[2]//k.shape[2]), hdim).sum(dim=3).to(q.dtype), next_dv.view(bs, slen, v.shape[2], (q.shape[2]//v.shape[2]), hdim).sum(dim=3).to(q.dtype)
    # else:
    return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype)



def _collect_head_gather_output(block, o_bs, o_shard_seqlen, o_hc, o_hs, name):
    """Finish a pending "scatter sequence / gather heads" all-to-all.

    Waits on ``U_HANDLE.O_HANDLE`` when one is pending, clears it, and
    reorders the raw all-to-all buffer ``block`` into
    ``(o_bs, o_shard_seqlen, o_hc, o_hs)``. The reorder is applied whether or
    not a handle was recorded. ``name`` is only used in the assertion message.
    """
    if U_HANDLE.O_HANDLE is not None:
        U_HANDLE.O_HANDLE.wait()
    clear_o_handle()
    expected = o_bs * o_shard_seqlen * o_hc * o_hs
    assert block.numel() == expected, (
        f"Pipe: {name} shape {block.shape} must be compatible to "
        f"(o_bs*o_shard_seqlen*o_hc*o_hs) {expected}"
    )
    block = block.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
    return block.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)


def _collect_seq_gather_output(block, bs, seqlen, shard_hc, hs, name):
    """Finish a pending "scatter heads / gather sequence" all-to-all.

    Waits on ``U_HANDLE.HANDLE`` when one is pending, clears it, and reorders
    the raw all-to-all buffer ``block`` into ``(bs, seqlen, shard_hc, hs)``.
    ``name`` is only used in the assertion message.
    """
    if U_HANDLE.HANDLE is not None:
        U_HANDLE.HANDLE.wait()
    clear_u_handle()
    expected = seqlen * bs * shard_hc * hs
    assert block.numel() == expected, (
        f"Pipe: {name} shape {block.shape} must be compatible to "
        f"(seqlen*bs*shard_hc*hs) {expected}"
    )
    block = block.reshape(seqlen, bs, shard_hc, hs)
    return block.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)


class UpipeRingFlashAttnFunc(torch.autograd.Function):
    """Head-pipelined Ulysses all-to-all combined with zigzag ring attention.

    The local heads are processed in ``pipe_degree = nheads // ulysses_world_size``
    stages of ``ulysses_world_size`` heads each. For every stage the heads are
    exchanged over ``ulysses_group`` with an all-to-all (scatter heads, gather
    sequence), ring attention runs over ``ring_group``, and the result goes
    back through the inverse all-to-all (scatter sequence, gather heads). The
    input all-to-all of stage ``s + 1`` and the output all-to-all of stage
    ``s`` are launched with the ``all_to_all_4D`` async flag. They are only
    awaited (through ``U_HANDLE``) once their results are needed.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        attn_type,
    ):
        """Run the pipelined Ulysses + ring attention forward pass.

        Args:
            q, k, v: local shards unpacked as ``(bs, shard_seqlen, heads, head_dim)``;
                ``k``/``v`` must have as many heads as ``q``.
            dropout_p, softcap: forwarded to the ring attention kernels.
            softmax_scale: defaults to ``head_dim ** -0.5`` when None.
            causal: must be True (enforced by the ring kernels).
            window_size, alibi_slopes, attn_type: handed to the ring forward op
                through module globals; ``alibi_slopes`` must be None.
            deterministic: saved for the backward pass.
            return_softmax: also return the stacked per-stage log-sum-exp.
            ring_group: process group for the ring attention.
            ulysses_group: process group for the head/sequence all-to-alls.
            offload_stream, fetch_stream: optional CUDA streams for the
                (currently ineffective, see NOTE below) CPU offload path.

        Returns:
            ``output`` with ``q``'s shape, or ``(output, softmax_lse, None)``
            when ``return_softmax`` is set. ``softmax_lse`` stacks the per-stage
            lse tensors along dim 1.
        """
        assert k.shape[2] == q.shape[2], f"Pipe: num heads in key {k.shape[2]} must be equal to query {q.shape[2]}"
        assert alibi_slopes is None

        hc, hs = q.shape[2], q.shape[3]
        world_size = dist.get_world_size(ulysses_group)
        # Each stage handles ``world_size`` heads, i.e. one head per Ulysses rank.
        pipe_degree = hc // world_size

        if softmax_scale is None:
            softmax_scale = hs ** (-0.5)

        output = torch.zeros_like(q)
        softmax_lse = None  # allocated once the first stage reveals the lse shape
        block_output = None
        o_bs = o_shard_seqlen = o_hc = o_hs = None

        # Pack q, k and v along the batch dim so one all-to-all per stage moves
        # all three, then split the heads into one chunk per stage.
        qkv = torch.cat([q, k, v]).contiguous()
        orig_device = qkv.device
        qkv = torch.chunk(qkv, pipe_degree, dim=2)

        # NOTE(review): the tensors returned by the ``.to(...)`` calls on the
        # offload/fetch streams are discarded, so nothing is actually moved off
        # the GPU and the prefetches are no-ops. Kept unchanged pending a real
        # offload implementation.
        if offload_stream is not None:
            with torch.cuda.stream(offload_stream):
                for tnsr in qkv[2:]:
                    tnsr.to("cpu", non_blocking=True)

        # The first all-to-all is blocking: stage 0 needs its result right away.
        ulysses_qkv = all_to_all_4D(
            qkv[0], 2, 1, ulysses_group, False, False  # scatter 2, gather 1
        )

        if offload_stream is not None:
            torch.cuda.current_stream().synchronize()
            with torch.cuda.stream(offload_stream):
                qkv[0].to("cpu", non_blocking=True)

        # ``upipe_ring_flash_attn_forward`` is a custom op that reads these values
        # from module globals. ``global`` cannot be used here because
        # ``attn_type``, ``alibi_slopes`` and ``window_size`` are parameters of
        # this function, so assign through the module object instead.
        import sys
        current_module = sys.modules[__name__]
        current_module.process_group = ring_group
        current_module.attn_type = attn_type
        current_module.alibi_slopes = alibi_slopes
        current_module.window_size = window_size

        for stage in range(pipe_degree):
            is_last_stage = stage + 1 == pipe_degree

            if not is_last_stage:
                if stage + 2 < pipe_degree and fetch_stream is not None:
                    fetch_stream.synchronize()  # make sure the last qkv is fetched to GPU
                    with torch.cuda.stream(fetch_stream):
                        qkv[stage + 2].to(orig_device, non_blocking=True)
                # Launch the next stage's input all-to-all now so it overlaps
                # with this stage's ring attention.
                next_ulysses_qkv = all_to_all_4D(
                    qkv[stage + 1], 2, 1, ulysses_group, False, True  # scatter 2, gather 1
                )
                nxt_bs, nxt_shard_seqlen, nxt_hc, nxt_hs = qkv[stage + 1].shape

            inp_q, inp_k, inp_v = torch.chunk(ulysses_qkv, 3, dim=0)
            inp_k = inp_k.contiguous()
            inp_v = inp_v.contiguous()

            out, lse = upipe_ring_flash_attn_forward(
                inp_q,
                inp_k,
                inp_v,
                softmax_scale=softmax_scale,
                dropout_p=dropout_p,
                causal=causal,
                softcap=softcap,
                deterministic=False,
            )

            # ``lse`` is (bs, heads, seqlen). Stack the stages along dim 1; the
            # backward pass splits them back with ``torch.chunk(..., dim=1)``.
            # (Indexing with ``[:, stage, :]`` only worked for bs == 1.)
            n_lse_heads = lse.shape[1]
            if softmax_lse is None:
                softmax_lse = lse.new_zeros(lse.shape[0], pipe_degree * n_lse_heads, lse.shape[2])
            softmax_lse[:, stage * n_lse_heads:(stage + 1) * n_lse_heads, :] = lse

            # Collect the previous stage's output before its handle is replaced.
            if stage > 0:
                output[:, :, (stage - 1) * world_size:stage * world_size, :] = _collect_head_gather_output(
                    block_output, o_bs, o_shard_seqlen, o_hc, o_hs, "block_output"
                )

            o_bs, o_seqlen, o_shard_hc, o_hs = out.shape
            o_hc = o_shard_hc * world_size
            o_shard_seqlen = o_seqlen // world_size

            block_output = all_to_all_4D(
                out, 1, 2, ulysses_group, False, True  # scatter 1, gather 2
            )

            if not is_last_stage:
                ulysses_qkv = _collect_seq_gather_output(
                    next_ulysses_qkv,
                    nxt_bs,
                    nxt_shard_seqlen * world_size,
                    nxt_hc // world_size,
                    nxt_hs,
                    "next_ulysses_qkv",
                )
                if offload_stream is not None:
                    with torch.cuda.stream(offload_stream):
                        qkv[stage + 1].to("cpu", non_blocking=True)

        # Drain the last stage's output all-to-all. This runs unconditionally,
        # like the in-loop collection, so the last heads are never left as zeros.
        # pipe_degree >= 1 here, since torch.chunk above rejects 0 chunks.
        output[:, :, -world_size:, :] = _collect_head_gather_output(
            block_output, o_bs, o_shard_seqlen, o_hc, o_hs, "block_output"
        )

        ctx.save_for_backward(q, k, v, output, softmax_lse)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.ring_group = ring_group
        ctx.ulysses_group = ulysses_group
        ctx.attn_type = attn_type
        ctx.pipe_degree = pipe_degree
        ctx.offload_stream = offload_stream
        ctx.fetch_stream = fetch_stream
        return output if not return_softmax else (output, softmax_lse, None)

    @staticmethod
    def backward(ctx, dout, *args):
        """Pipelined backward pass mirroring ``forward``'s stage schedule.

        For each stage, q/k/v/dout/out are exchanged with one all-to-all, and
        the ring backward computes that stage's dq/dk/dv. The gradients are
        returned through the inverse all-to-all into ``dqkv``.

        Returns:
            Gradients for ``q``, ``k``, ``v``, followed by None for each of the
            13 non-tensor inputs of ``forward``.
        """
        q, k, v, out, softmax_lse = ctx.saved_tensors

        pipe_degree = ctx.pipe_degree
        offload_stream = ctx.offload_stream
        fetch_stream = ctx.fetch_stream
        orig_device = dout.device
        world_size = dist.get_world_size(ctx.ulysses_group)

        # Use the scale the forward pass actually applied. Previously the
        # default head_dim ** -0.5 was recomputed, which is wrong for a
        # user-supplied softmax_scale.
        softmax_scale = ctx.softmax_scale
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        dqkv = torch.zeros([3 * q.shape[0], q.shape[1], q.shape[2], q.shape[3]], dtype=q.dtype, device=q.device)

        # Pack the 5 components (q, k, v, dout, out) along the batch dim so one
        # all-to-all per stage moves everything, then split heads per stage.
        n_parts = 5
        doutqkvo = torch.cat([q, k, v, dout, out], dim=0).contiguous()
        doutqkvo = torch.chunk(doutqkvo, pipe_degree, dim=2)
        softmax_lse = torch.chunk(softmax_lse, pipe_degree, dim=1)

        # NOTE(review): as in forward, these ``.to(...)`` results are
        # discarded, so the offload/prefetch path moves nothing.
        if offload_stream is not None:
            with torch.cuda.stream(offload_stream):
                for tnsr in doutqkvo[2:]:
                    tnsr.to("cpu", non_blocking=True)
                for tnsr in softmax_lse[1:]:
                    tnsr.to("cpu", non_blocking=True)

        # Same direction as the forward input exchange: scatter 2, gather 1.
        ulysses_doutqkvo = all_to_all_4D(doutqkvo[0], 2, 1, ctx.ulysses_group, False, False)

        if offload_stream is not None:
            torch.cuda.current_stream().synchronize()
            with torch.cuda.stream(offload_stream):
                doutqkvo[0].to("cpu", non_blocking=True)

        block_grads = None
        o_bs = o_shard_seqlen = o_hc = o_hs = None

        for stage in range(pipe_degree):
            is_last_stage = stage + 1 == pipe_degree

            if not is_last_stage:
                if stage + 2 < pipe_degree and fetch_stream is not None:
                    fetch_stream.synchronize()  # make sure the last qkv is fetched to GPU
                    with torch.cuda.stream(fetch_stream):
                        doutqkvo[stage + 2].to(orig_device, non_blocking=True)
                        softmax_lse[stage + 1].to(orig_device, non_blocking=True)
                next_ulysses_doutqkvo = all_to_all_4D(doutqkvo[stage + 1], 2, 1, ctx.ulysses_group, False, True)
                nxt_bs, nxt_shard_seqlen, nxt_hc, nxt_hs = doutqkvo[stage + 1].shape

            # Split back into the 5 packed components. The old code split into
            # 5*bs pieces, which mis-assigned the tensors for batch size > 1.
            parts = torch.chunk(ulysses_doutqkvo, n_parts, dim=0)
            assert len(parts) == n_parts, (
                f"Pipe: ulysses_tensors length {len(parts)} must be equal to {n_parts}, "
                f"tensor shape {parts[0].shape}"
            )
            inp_q, inp_k, inp_v, inp_dout, inp_out = parts
            # Slices along dim 1 are not contiguous for batch > 1; the ring
            # backward hands this tensor to the kernel unchanged on some steps.
            inp_softmax_lse = softmax_lse[stage].contiguous()

            attn_dq, attn_dk, attn_dv = upipe_ring_flash_attn_backward(
                ctx.ring_group,
                inp_dout,
                inp_q,
                inp_k,
                inp_v,
                inp_out,
                inp_softmax_lse,
                softmax_scale=softmax_scale,
                dropout_p=ctx.dropout_p,
                causal=ctx.causal,
                window_size=ctx.window_size,
                softcap=ctx.softcap,
                alibi_slopes=ctx.alibi_slopes,
                deterministic=ctx.deterministic,
                attn_type=ctx.attn_type,
            )

            if offload_stream is not None:
                torch.cuda.current_stream().synchronize()
                with torch.cuda.stream(offload_stream):
                    softmax_lse[stage].to("cpu", non_blocking=True)

            # Collect the previous stage's gradients before its handle is replaced.
            if stage > 0:
                dqkv[:, :, (stage - 1) * world_size:stage * world_size, :] = _collect_head_gather_output(
                    block_grads, o_bs, o_shard_seqlen, o_hc, o_hs, "block_grads"
                )

            attn_grads = torch.cat([attn_dq, attn_dk, attn_dv], dim=0)
            o_bs, o_seqlen, o_shard_hc, o_hs = attn_grads.shape
            o_hc = o_shard_hc * world_size
            o_shard_seqlen = o_seqlen // world_size

            block_grads = all_to_all_4D(
                attn_grads, 1, 2, ctx.ulysses_group, False, True  # scatter 1, gather 2
            )

            if not is_last_stage:
                ulysses_doutqkvo = _collect_seq_gather_output(
                    next_ulysses_doutqkvo,
                    nxt_bs,
                    nxt_shard_seqlen * world_size,
                    nxt_hc // world_size,
                    nxt_hs,
                    "next_ulysses_doutqkvo_lse",
                )
                if offload_stream is not None:
                    with torch.cuda.stream(offload_stream):
                        doutqkvo[stage + 1].to("cpu", non_blocking=True)

        # Drain the last stage's gradients. Both this write and the split into
        # dq/dk/dv used to be conditional on a pending handle. Without a handle,
        # ``dqkv[0]`` indexed a batch element instead of returning dq.
        dqkv[:, :, -world_size:, :] = _collect_head_gather_output(
            block_grads, o_bs, o_shard_seqlen, o_hc, o_hs, "block_grads"
        )
        dq, dk, dv = torch.chunk(dqkv, 3, dim=0)

        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every head of ``x`` ``n_rep`` times along dim 2.

    Equivalent to ``torch.repeat_interleave(x, dim=2, repeats=n_rep)``, but
    built from ``expand`` + ``reshape``. ``x`` must be 4-D; it is unpacked as
    ``(bs, slen, n_kv_heads, head_dim)``.

    Returns:
        ``x`` itself when ``n_rep == 1``, otherwise a tensor of shape
        ``(bs, slen, n_kv_heads * n_rep, head_dim)``.
    """
    batch, seq_len, kv_heads, head_dim = x.shape
    if n_rep != 1:
        widened = x[:, :, :, None, :].expand(batch, seq_len, kv_heads, n_rep, head_dim)
        return widened.reshape(batch, seq_len, kv_heads * n_rep, head_dim)
    return x

def _gqa_head_order(q_hc, key_hc, world_size, device):
    """Return the query-head order used to pack heads into pipeline stages.

    Stage ``s`` ships query heads ``order[s * world_size:(s + 1) * world_size]``
    together with (on stages where ``s % (q_hc // key_hc) == 0``) the key/value
    heads ``(s // (q_hc // key_hc)) * world_size`` onward. The order is chosen so
    that the ``w``-th query head of a stage belongs to the ``w``-th key head of
    that block (query head ``i`` uses key head ``i // (q_hc // key_hc)``).

    Example: ``q_hc=8, key_hc=2, world_size=2`` -> ``[0, 4, 1, 5, 2, 6, 3, 7]``.
    """
    q_per_kv = q_hc // key_hc
    order = torch.arange(q_hc, device=device)
    return (
        order.reshape(-1, q_per_kv)
        .unsqueeze(0)
        .reshape(-1, world_size, q_per_kv)
        .transpose(1, 2)
        .flatten(1, 2)
        .flatten()
    )


class UpipeRingFlashAttnGQAFunc(torch.autograd.Function):
    """UPipe ring flash attention for inputs with fewer key/value heads than query heads.

    Query heads are processed in ``pipe_degree = q_hc // world_size`` stages of
    ``world_size`` heads each (``world_size`` = size of ``ulysses_group``). Each
    stage all-to-alls its packet over ``ulysses_group`` (scatter dim 2, gather
    dim 1), runs ``upipe_ring_flash_attn_forward`` (ring over ``ring_group``),
    and all-to-alls the result back (scatter dim 1, gather dim 2). The next
    stage's all-to-all is issued asynchronously while the current stage computes.
    Key/value heads are only shipped on stages with
    ``stage % (q_hc // key_hc) == 0``; the following stages reuse them.

    NOTE(review): the packet-size asserts (``shape[0] == 3`` / ``== 5``)
    presumably restrict the batch size to 1, since packets concatenate q/k/v
    (and dout/out) along dim 0 — confirm before relaxing them.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        attn_type,
    ):
        """Pipelined forward pass over groups of ``world_size`` query heads.

        Args:
            q, k, v: local shards, unpacked below as ``(bs, shard_seqlen, heads, hs)``.
                ``k``/``v`` must have fewer heads than ``q``; if they have fewer
                heads than the Ulysses world size they are expanded with ``repeat_kv``.
            dropout_p, causal, softcap: forwarded to ``upipe_ring_flash_attn_forward``.
            window_size, alibi_slopes, attn_type: published as module globals that
                the ring kernel reads; ``alibi_slopes`` must be ``None``.
            softmax_scale: defaults to ``hs ** -0.5`` when ``None``.
            deterministic: stored for the backward pass.
            return_softmax: also return the per-stage log-sum-exp tensor.
            ring_group: process group for the ring attention kernel.
            ulysses_group: process group for the head/sequence all-to-alls.
            offload_stream, fetch_stream: optional CUDA streams for CPU
                offload / prefetch of packets (see NOTE below).

        Returns:
            ``output`` with the same shape as ``q``, or
            ``(output, softmax_lse, None)`` when ``return_softmax`` is set.
        """
        assert k.shape[2] < q.shape[2], f"PipeGQA : num heads in key {k.shape[2]} must be less than query {q.shape[2]}"

        orig_k_hc = k.shape[2]
        world_size = dist.get_world_size(ulysses_group)

        if k.shape[2] < world_size:
            # Only case where kv heads must be replicated: every stage needs
            # at least `world_size` key heads to scatter over the Ulysses group.
            assert world_size % k.shape[2] == 0, f"PipeGQA : num heads in key {k.shape[2]} must be divisible by world size {world_size}"
            k = repeat_kv(k, world_size // k.shape[2])
            v = repeat_kv(v, world_size // v.shape[2])

        bs, shard_seqlen, q_hc, hs = q.shape
        bs, shard_seqlen, key_hc, key_hs = k.shape
        q_per_kv = q_hc // key_hc
        pipe_degree = q_hc // world_size

        output = torch.zeros_like(q)
        softmax_lse = None  # allocated once the first stage reveals the lse shape

        q_idx = _gqa_head_order(q_hc, key_hc, world_size, q.device)

        # Build one all-to-all packet per stage: q heads only, or q+k+v heads
        # (concatenated along dim 0) on stages that ship a new kv block.
        comm_pkts = []
        for stage in range(pipe_degree):
            heads = q_idx[stage * world_size:(stage + 1) * world_size]
            if stage % q_per_kv == 0:
                kv_start = (stage // q_per_kv) * world_size
                kv_end = kv_start + world_size
                try:
                    comm_pkts.append(
                        torch.cat([
                            q[:, :, heads, :],
                            k[:, :, kv_start:kv_end, :],
                            v[:, :, kv_start:kv_end, :],
                        ]).contiguous()
                    )
                except (RuntimeError, IndexError) as err:
                    raise AssertionError(
                        f"k[:,:,(stage//(q_hc//key_hc))*world_size:(stage//(q_hc//key_hc)+1)*world_size,:].shape {k[:, :, kv_start:kv_end, :].shape}, stage {stage}, world_size {world_size}, q_hc {q_hc}, key_hc {key_hc}"
                    ) from err
            else:
                comm_pkts.append(q[:, :, heads, :])

        orig_device = q.device

        # NOTE(review): `Tensor.to` returns a new tensor and the result is
        # discarded here (and in every offload/prefetch call below), so these
        # calls do not actually move packets between CPU and GPU. Confirm the
        # intended behaviour before relying on offloading.
        if offload_stream is not None:
            with torch.cuda.stream(offload_stream):
                for tnsr in comm_pkts[2:]:
                    tnsr.to("cpu", non_blocking=True)

        # The first all-to-all is blocking; later ones overlap with compute.
        ulysses_qkv = all_to_all_4D(
            comm_pkts[0], 2, 1, ulysses_group, False, False  # scatter 2, gather 1
        )

        assert ulysses_qkv.shape[0] == 3, f"INITIALLY, Got ulysses_qkv with length {ulysses_qkv.shape[0]} in stage 0"

        if offload_stream is not None:
            torch.cuda.current_stream().synchronize()
            with torch.cuda.stream(offload_stream):
                comm_pkts[0].to("cpu", non_blocking=True)

        block_output = None
        o_bs = o_shard_seqlen = o_hc = o_hs = None

        # `upipe_ring_flash_attn_forward` is a custom op with a fixed signature,
        # so these settings reach it through module-level globals.
        import sys
        current_module = sys.modules[__name__]
        current_module.process_group = ring_group
        current_module.attn_type = attn_type
        current_module.alibi_slopes = alibi_slopes
        current_module.window_size = window_size

        for stage in range(pipe_degree):
            is_last_stage = stage + 1 == pipe_degree
            if not is_last_stage:
                if stage + 2 < pipe_degree and fetch_stream is not None:
                    fetch_stream.synchronize()  # make sure the last qkv is fetched to GPU
                    with torch.cuda.stream(fetch_stream):
                        comm_pkts[stage + 2].to(orig_device, non_blocking=True)
                # Launch the next stage's all-to-all asynchronously; its raw
                # result is reshaped once U_HANDLE.HANDLE has been waited on.
                next_ulysses_qkv = all_to_all_4D(
                    comm_pkts[stage + 1], 2, 1, ulysses_group, False, True  # scatter 2, gather 1
                )
                bs, shard_seqlen, hc, hs = comm_pkts[stage + 1].shape
                seqlen = shard_seqlen * world_size
                shard_hc = hc // world_size

            if stage % q_per_kv == 0:
                assert ulysses_qkv.shape[0] == 3, f"Got ulysses_qkv with length {ulysses_qkv.shape[0]} in stage {stage}, with hc {hc} and key_hc {key_hc}"
                inp_q, inp_k, inp_v = torch.chunk(ulysses_qkv, 3, dim=0)
            else:
                # q-only packet: reuse inp_k / inp_v from the last kv stage.
                inp_q = ulysses_qkv

            if softmax_scale is None:
                softmax_scale = inp_q.shape[-1] ** (-0.5)

            assert alibi_slopes is None
            inp_k = inp_k.contiguous()
            inp_v = inp_v.contiguous()

            out, lse = upipe_ring_flash_attn_forward(
                inp_q,
                inp_k,
                inp_v,
                softmax_scale=softmax_scale,
                dropout_p=dropout_p,
                causal=causal,
                softcap=softcap,
                deterministic=False,
            )

            # Stack per-stage lse along dim 1. Backward splits it again with
            # torch.chunk(softmax_lse, pipe_degree, dim=1), so stage `s` owns the
            # s-th dim-1 slice. Slice assignment (rather than `[:, stage, :]`)
            # keeps the leading dim intact, so it does not rely on broadcasting.
            # CRITICAL: backward must process stages in this same order.
            if softmax_lse is None:
                softmax_lse = torch.zeros_like(lse).repeat(1, pipe_degree, 1)
            lse_width = lse.shape[1]
            softmax_lse[:, stage * lse_width:(stage + 1) * lse_width, :] = lse

            context_layer = out
            o_bs, o_seqlen, o_shard_hc, o_hs = context_layer.shape
            o_hc = o_shard_hc * world_size
            o_shard_seqlen = o_seqlen // world_size

            if stage > 0:
                # Finish the previous stage's output all-to-all and scatter its
                # heads back to their original positions.
                if U_HANDLE.O_HANDLE is not None:
                    U_HANDLE.O_HANDLE.wait()
                clear_o_handle()
                assert block_output.numel() == (o_bs * o_shard_seqlen * o_hc * o_hs), f"Pipe: block_output shape {block_output.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
                block_output = block_output.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
                block_output = block_output.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
                output[:, :, q_idx[(stage - 1) * world_size:stage * world_size], :] = block_output

            block_output = all_to_all_4D(
                context_layer, 1, 2, ulysses_group, False, True  # scatter 1, gather 2
            )

            if not is_last_stage:
                if U_HANDLE.HANDLE is not None:
                    U_HANDLE.HANDLE.wait()
                clear_u_handle()
                if offload_stream is not None:
                    with torch.cuda.stream(offload_stream):
                        comm_pkts[stage + 1].to("cpu", non_blocking=True)
                assert next_ulysses_qkv.numel() == (seqlen * bs * shard_hc * hs), f"Pipe: next_ulysses_qkv shape {next_ulysses_qkv.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                next_ulysses_qkv = next_ulysses_qkv.reshape(seqlen, bs, shard_hc, hs)
                next_ulysses_qkv = next_ulysses_qkv.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                ulysses_qkv = next_ulysses_qkv

        # Drain the last stage's output all-to-all exactly as the loop does for
        # earlier stages, so its heads are always written into `output`.
        if U_HANDLE.O_HANDLE is not None:
            U_HANDLE.O_HANDLE.wait()
        clear_o_handle()
        assert block_output.numel() == (o_bs * o_shard_seqlen * o_hc * o_hs), f"Pipe: block_output shape {block_output.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
        block_output = block_output.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
        block_output = block_output.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
        output[:, :, q_idx[-world_size:], :] = block_output

        # k / v are saved after any repeat_kv expansion; backward folds the
        # gradients back to `orig_k_hc` heads.
        ctx.save_for_backward(q, k, v, output, softmax_lse)
        ctx.orig_k_hc = orig_k_hc
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.ring_group = ring_group
        ctx.ulysses_group = ulysses_group
        ctx.attn_type = attn_type
        ctx.pipe_degree = pipe_degree
        ctx.offload_stream = offload_stream
        ctx.fetch_stream = fetch_stream
        return output if not return_softmax else (output, softmax_lse, None)

    @staticmethod
    def backward(ctx, dout, *args):
        """Pipelined backward pass mirroring the forward stage order.

        Per stage, q/k/v/dout/out heads are all-to-all'ed (kv only on stages that
        shipped kv in forward), ``upipe_ring_flash_attn_backward`` runs over the
        ring group, and the (dq, dk, dv) result is all-to-all'ed back into
        ``dqkv`` at the stage's query-head positions. dk / dv are accumulated per
        query head and summed down to ``ctx.orig_k_hc`` heads at the end.

        Returns:
            ``(dq, dk, dv)`` followed by ``None`` for each of the 13 non-tensor inputs.
        """
        q, k, v, out, softmax_lse = ctx.saved_tensors

        world_size = dist.get_world_size(ctx.ulysses_group)

        if k.shape[2] < world_size:
            assert world_size % k.shape[2] == 0, f"PipeGQA : num heads in key {k.shape[2]} must be divisible by world size {world_size}"
            k = repeat_kv(k, world_size // k.shape[2])
            v = repeat_kv(v, world_size // v.shape[2])

        bs, shard_seqlen, q_hc, hs = q.shape
        bs, shard_seqlen, key_hc, key_hs = k.shape
        q_per_kv = q_hc // key_hc

        q_idx = _gqa_head_order(q_hc, key_hc, world_size, q.device)

        pipe_degree = ctx.pipe_degree
        offload_stream = ctx.offload_stream
        fetch_stream = ctx.fetch_stream
        orig_device = dout.device

        # Use the scale the forward pass actually applied (a user-supplied
        # value, or the hs ** -0.5 default the forward resolved).
        softmax_scale = ctx.softmax_scale
        if softmax_scale is None:
            softmax_scale = hs ** (-0.5)

        # dq, dk, dv stacked along dim 0; dk / dv are laid out per *query* head.
        dqkv = torch.zeros([3 * q.shape[0], q.shape[1], q.shape[2], q.shape[3]], dtype=q.dtype, device=q.device)

        # One packet per stage, concatenated along dim 0:
        # [q, k, v, dout, out] on kv stages, [q, dout, out] otherwise.
        doutqkvo_lse = []
        for stage in range(pipe_degree):
            heads = q_idx[stage * world_size:(stage + 1) * world_size]
            if stage % q_per_kv == 0:
                kv_start = (stage // q_per_kv) * world_size
                kv_end = kv_start + world_size
                parts = [
                    q[:, :, heads, :],
                    k[:, :, kv_start:kv_end, :],
                    v[:, :, kv_start:kv_end, :],
                    dout[:, :, heads, :],
                    out[:, :, heads, :],
                ]
            else:
                parts = [q[:, :, heads, :], dout[:, :, heads, :], out[:, :, heads, :]]
            doutqkvo_lse.append(torch.cat(parts, dim=0).contiguous())

        softmax_lse = torch.chunk(softmax_lse, pipe_degree, dim=1)

        # NOTE(review): as in forward, the `.to(...)` results below are
        # discarded, so no tensor is actually moved.
        if offload_stream is not None:
            with torch.cuda.stream(offload_stream):
                for tnsr in doutqkvo_lse[2:]:
                    tnsr.to("cpu", non_blocking=True)
                for tnsr in softmax_lse[1:]:
                    tnsr.to("cpu", non_blocking=True)

        # Same scatter 2 / gather 1 direction as the forward input all-to-all.
        ulysses_doutqkvo_lse = all_to_all_4D(doutqkvo_lse[0], 2, 1, ctx.ulysses_group, False, False)

        if offload_stream is not None:
            torch.cuda.current_stream().synchronize()
            with torch.cuda.stream(offload_stream):
                doutqkvo_lse[0].to("cpu", non_blocking=True)

        block_grads = None
        o_bs = o_shard_seqlen = o_hc = o_hs = None

        for stage in range(pipe_degree):
            is_last_stage = stage + 1 == pipe_degree
            if not is_last_stage:
                if stage + 2 < pipe_degree and fetch_stream is not None:
                    fetch_stream.synchronize()  # make sure the last qkv is fetched to GPU
                    with torch.cuda.stream(fetch_stream):
                        doutqkvo_lse[stage + 2].to(orig_device, non_blocking=True)
                        softmax_lse[stage + 1].to(orig_device, non_blocking=True)
                next_ulysses_doutqkvo_lse = all_to_all_4D(doutqkvo_lse[stage + 1], 2, 1, ctx.ulysses_group, False, True)
                bs, shard_seqlen, hc, hs = doutqkvo_lse[stage + 1].shape
                seqlen = shard_seqlen * world_size
                shard_hc = hc // world_size

            if stage % q_per_kv == 0:
                assert ulysses_doutqkvo_lse.shape[0] == 5, f"Got ulysses_doutqkvo_lse with length {ulysses_doutqkvo_lse.shape[0]} in stage {stage}, with q_hc {q_hc} and key_hc {key_hc}"
                inp_q, inp_k, inp_v, inp_dout, inp_out = torch.chunk(ulysses_doutqkvo_lse, 5, dim=0)
            else:
                # q-only packet: reuse inp_k / inp_v from the last kv stage.
                assert ulysses_doutqkvo_lse.shape[0] == 3, f"Got ulysses_doutqkvo_lse with length {ulysses_doutqkvo_lse.shape[0]} in stage {stage}, with q_hc {q_hc} and key_hc {key_hc}"
                inp_q, inp_dout, inp_out = torch.chunk(ulysses_doutqkvo_lse, 3, dim=0)

            # Forward stored stage `s`'s lse in the s-th dim-1 chunk.
            inp_softmax_lse = softmax_lse[stage]

            attn_dq, attn_dk, attn_dv = upipe_ring_flash_attn_backward(
                ctx.ring_group,
                inp_dout,
                inp_q,
                inp_k,
                inp_v,
                inp_out,
                inp_softmax_lse,
                softmax_scale=softmax_scale,
                dropout_p=ctx.dropout_p,
                causal=ctx.causal,
                window_size=ctx.window_size,
                softcap=ctx.softcap,
                alibi_slopes=ctx.alibi_slopes,
                deterministic=ctx.deterministic,
                attn_type=ctx.attn_type,
            )

            if offload_stream is not None:
                torch.cuda.current_stream().synchronize()
                with torch.cuda.stream(offload_stream):
                    softmax_lse[stage].to("cpu", non_blocking=True)

            if stage > 0:
                # Finish the previous stage's gradient all-to-all and place it
                # at that stage's query-head positions.
                if U_HANDLE.O_HANDLE is not None:
                    U_HANDLE.O_HANDLE.wait()
                clear_o_handle()
                assert block_grads.numel() == (o_bs * o_shard_seqlen * o_hc * o_hs), f"Pipe: block_grads shape {block_grads.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
                block_grads = block_grads.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
                block_grads = block_grads.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
                try:
                    dqkv[:, :, q_idx[(stage - 1) * world_size:stage * world_size], :] = block_grads
                except (RuntimeError, IndexError) as err:
                    raise AssertionError(
                        f"dqkv shape {dqkv.shape}, q_idx shape {q_idx.shape}, stage {stage}, world_size {world_size}, block_grads shape {block_grads.shape}"
                    ) from err

            attn_grads = torch.cat([attn_dq, attn_dk, attn_dv], dim=0)
            o_bs, o_seqlen, o_shard_hc, o_hs = attn_grads.shape
            o_hc = o_shard_hc * world_size
            o_shard_seqlen = o_seqlen // world_size

            block_grads = all_to_all_4D(
                attn_grads, 1, 2, ctx.ulysses_group, False, True  # scatter 1, gather 2
            )

            if not is_last_stage:
                if U_HANDLE.HANDLE is not None:
                    U_HANDLE.HANDLE.wait()
                clear_u_handle()
                if offload_stream is not None:
                    with torch.cuda.stream(offload_stream):
                        doutqkvo_lse[stage + 1].to("cpu", non_blocking=True)
                assert next_ulysses_doutqkvo_lse.numel() == (seqlen * bs * shard_hc * hs), f"Pipe: next_ulysses_doutqkvo_lse shape {next_ulysses_doutqkvo_lse.shape} must be compatible to (seqlen*bs*shard_hc*hs) {seqlen*bs*shard_hc*hs}"
                next_ulysses_doutqkvo_lse = next_ulysses_doutqkvo_lse.reshape(seqlen, bs, shard_hc, hs)
                next_ulysses_doutqkvo_lse = next_ulysses_doutqkvo_lse.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
                ulysses_doutqkvo_lse = next_ulysses_doutqkvo_lse

        # Drain the last stage like the loop does. Its heads live at
        # q_idx[-world_size:] (not the last `world_size` contiguous heads),
        # matching where the forward wrote the last stage's output.
        if U_HANDLE.O_HANDLE is not None:
            U_HANDLE.O_HANDLE.wait()
        clear_o_handle()
        assert block_grads.numel() == (o_bs * o_shard_seqlen * o_hc * o_hs), f"Pipe: block_grads shape {block_grads.shape} must be compatible to (o_bs*o_shard_seqlen*o_hc*o_hs) {o_bs*o_shard_seqlen*o_hc*o_hs}"
        block_grads = block_grads.reshape(o_hc, o_shard_seqlen, o_bs, o_hs)
        block_grads = block_grads.transpose(0, 2).contiguous().reshape(o_bs, o_shard_seqlen, o_hc, o_hs)
        dqkv[:, :, q_idx[-world_size:], :] = block_grads

        dq, dk, dv = torch.chunk(dqkv, 3, dim=0)

        # Sum the per-query-head kv gradients over each group of query heads
        # that shares an original kv head.
        bs, seqlen, nh, hs = dq.shape
        key_hc = ctx.orig_k_hc
        dk = dk.reshape(bs, seqlen, key_hc, nh // key_hc, hs).sum(3)
        dv = dv.reshape(bs, seqlen, key_hc, nh // key_hc, hs).sum(3)
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None

def upipe_ring_flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
    ring_group=None,
    ulysses_group=None,
    offload_stream=None,
    fetch_stream=None,
):
    """UPipe ring flash attention on a packed ``qkv`` tensor.

    ``qkv[:, :, 0]``, ``qkv[:, :, 1]`` and ``qkv[:, :, 2]`` are used as q, k and
    v, and the call is delegated to ``upipe_ring_flash_attn_func``. Previously
    this wrapper passed only 13 positional arguments to
    ``UpipeRingFlashAttnFunc.apply``, which (per the 16-argument call in
    ``upipe_ring_flash_attn_func``) left the Ulysses group / streams unbound.

    Args:
        group: ring process group; kept for backward compatibility.
            ``ring_group`` takes precedence when both are given.
        ring_group, ulysses_group, offload_stream, fetch_stream: forwarded
            unchanged to ``upipe_ring_flash_attn_func``.
        Remaining arguments: forwarded unchanged; ``return_attn_probs`` is
            passed through as ``return_softmax``.
    """
    return upipe_ring_flash_attn_func(
        qkv[:, :, 0],
        qkv[:, :, 1],
        qkv[:, :, 2],
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=return_attn_probs,
        ring_group=group if ring_group is None else ring_group,
        ulysses_group=ulysses_group,
        offload_stream=offload_stream,
        fetch_stream=fetch_stream,
        attn_type=attn_type,
    )


def upipe_ring_flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
    ring_group=None,
    ulysses_group=None,
    offload_stream=None,
    fetch_stream=None,
):
    """UPipe ring flash attention with ``k`` and ``v`` packed into ``kv``.

    ``kv[:, :, 0]`` and ``kv[:, :, 1]`` are used as k and v, and the call is
    delegated to ``upipe_ring_flash_attn_func``. That function picks the GQA
    implementation when ``q`` has more heads (dim 2) than ``k``. Previously
    this wrapper always used the non-GQA Function and passed it only 13
    positional arguments.

    Args:
        group: ring process group; kept for backward compatibility.
            ``ring_group`` takes precedence when both are given.
        ring_group, ulysses_group, offload_stream, fetch_stream: forwarded
            unchanged to ``upipe_ring_flash_attn_func``.
        Remaining arguments: forwarded unchanged; ``return_attn_probs`` is
            passed through as ``return_softmax``.
    """
    return upipe_ring_flash_attn_func(
        q,
        kv[:, :, 0],
        kv[:, :, 1],
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=return_attn_probs,
        ring_group=group if ring_group is None else ring_group,
        ulysses_group=ulysses_group,
        offload_stream=offload_stream,
        fetch_stream=fetch_stream,
        attn_type=attn_type,
    )


def upipe_ring_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    ring_group=None,
    ulysses_group=None,
    offload_stream=None,
    fetch_stream=None,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
):
    """UPipe ring flash attention on separate ``q``, ``k``, ``v`` tensors.

    Dispatches on the head count in dim 2:
      * ``q.shape[2] == k.shape[2]`` -> ``UpipeRingFlashAttnFunc``;
      * ``q.shape[2] > k.shape[2]`` (GQA/MQA) -> ``UpipeRingFlashAttnGQAFunc``.
    All other arguments are forwarded positionally, and ``return_attn_probs`` is
    passed as the Function's ``return_softmax`` argument.

    ``attn_processor`` is accepted but not used by the code shown here.

    NOTE(review): no branch is visible for ``q.shape[2] < k.shape[2]``; unless
    one follows this point in the file, such inputs make the function return
    ``None``.
    """
    # Same number of query and key heads: plain (MHA) pipelined path.
    if q.shape[2]==k.shape[2]:
        return UpipeRingFlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        attn_type,
    )

    # Fewer key heads than query heads: grouped-query pipelined path.
    elif q.shape[2]>k.shape[2]:
        return UpipeRingFlashAttnGQAFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        attn_type,
    )