import torch
import os
import torch.nn as nn
import torch.nn.functional as F

from .utils import RingComm, update_out_and_lse
from yunchang.kernels import AttnType, select_flash_attn_impl
from yunchang.globals import U_HANDLE, clear_u_handle, clear_o_handle, PROCESS_GROUP, GRAPH_FLAGS
from yunchang.comm.all_to_all import vanilla_all_to_all_4D as all_to_all_4D, SeqAllToAll4D
from yunchang.ring.fully_pipelined_attn_backup import FullyFusedAttnFunc
import logging
logger = logging.getLogger(__name__)

from typing import List, Tuple, Dict, Any


# Module-level state shared by the functions in this file.
# FullyFusedAttnGQAFunc.forward publishes process_group, attn_type,
# alibi_slopes, window_size, ulysses_group and two_streams on this module
# (via sys.modules[__name__]) before invoking the custom op, and
# fully_fused_ring_flash_attn_forward / fully_fused_attn_gqa_forward read them
# back through `global` declarations.
# NOTE(review): a `global` statement at module scope does not create or bind
# anything; it only documents the names. next_ulysses_qkv and a2a_available
# are not referenced in the part of the file reviewed here -- confirm whether
# they are still needed.
global process_group, attn_type, alibi_slopes, window_size, next_ulysses_qkv, a2a_available, ulysses_group, two_streams

import torch.distributed as dist

def check_nan_inf(tensor, name, rank=0):
    """Report whether ``tensor`` contains any NaN or Inf entries.

    NaN is checked first; if any is found the Inf check is skipped. On a hit,
    two diagnostic lines (tensor shape and number of offending elements) are
    printed, prefixed with ``rank``.

    Args:
        tensor: Tensor to inspect.
        name: Label used in the printed diagnostics.
        rank: Rank shown in the diagnostic prefix.

    Returns:
        True if a NaN or an Inf was found, False otherwise.
    """
    nan_mask = torch.isnan(tensor)
    if nan_mask.any():
        print(f"[RANK {rank}] NaN detected in {name}, shape: {tensor.shape}")
        print(f"[RANK {rank}] NaN locations: {nan_mask.sum().item()}")
        return True

    inf_mask = torch.isinf(tensor)
    if inf_mask.any():
        print(f"[RANK {rank}] Inf detected in {name}, shape: {tensor.shape}")
        print(f"[RANK {rank}] Inf locations: {inf_mask.sum().item()}")
        return True

    return False


def fully_fused_ring_flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: float,
    dropout_p: float = 0,
    causal: bool = True,
    softcap: float = 0.0,
    deterministic: bool=False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Ring attention forward pass that rotates the query around the ring.

    The local ``k``/``v`` stay in place; on every step except the last the
    current ``q`` is handed to ``RingComm.send_recv`` and replaced by the
    tensor received from the ring. Per-step partial results are merged into a
    single ``out``/``lse`` pair with ``update_out_and_lse``.

    The process group, attention implementation, ALiBi slopes and window size
    are read from module-level globals, not from the arguments.

    Args:
        q, k, v: Attention inputs; dimension 1 is split in half
            (presumably the sequence dimension -- TODO confirm layout).
        softmax_scale: Scale forwarded to the attention kernel.
        dropout_p: Dropout probability forwarded to the kernel.
        causal: Must be True; anything else trips the assertion below.
        softcap: Soft-capping value forwarded to the kernel.
        deterministic: Accepted but not used anywhere in this function.

    Returns:
        ``(out, lse)`` where ``out`` is cast to ``q.dtype`` and ``lse`` has its
        last dimension squeezed and dimensions 1 and 2 transposed.

    NOTE(review): after step 0 ``q`` holds a tensor received from another rank
    while results keep being merged into the same local ``out``/``lse``
    accumulators -- confirm this is the intended pass-query behaviour.
    """
    
    global process_group, attn_type, alibi_slopes, window_size
    
    assert causal == True, "zigzag ring is meaningless for causal=False"
    # The communicator is given a q-shaped scratch tensor and pass_kv=False;
    # the exact meaning of both arguments is defined by RingComm in .utils.
    comm = RingComm(process_group, torch.empty_like(q), pass_kv=False)

    # Half of dimension 1; used to address the first/second half of q, k, v.
    block_seq_len = q.shape[1] // 2
    # q1 = q[:, block_seq_len:]
    # First half of the local keys/values along dimension 1.
    k0 = k[:, :block_seq_len]
    v0 = v[:, :block_seq_len]

    out = None
    lse = None
    next_k, next_v = None, None  # unused while the pass-KV path is disabled
    next_q = None
    
    def forward(q, k, v, causal):
        # Run one attention block with the kernel selected by the global
        # attn_type and return its output and log-sum-exp.
        fn = select_flash_attn_impl(attn_type, stage="fwd-only")
        block_out, block_lse = fn(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # `True and ...` reduces to `dropout_p > 0`.
            return_softmax=True and dropout_p > 0,
        )
        return block_out, block_lse

    for step in range(comm.world_size):
        if step + 1 != comm.world_size:
            # Start exchanging the current q before computing, so that
            # communication is issued ahead of this step's attention call.
            # next_k: torch.Tensor = comm.send_recv(k)
            # next_v: torch.Tensor = comm.send_recv(v)
            next_q: torch.Tensor = comm.send_recv(q)
            comm.commit()

        if step == 0:
            # Local q against local k/v with the causal mask.
            block_out, block_lse = forward(q, k, v, causal=True)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        elif step <= comm.rank:
            # for pass KV
            # k0 = k[:, :block_seq_len]
            # v0 = v[:, :block_seq_len]
            # block_out, block_lse = forward(q, k0, v0, causal=False)
            # out, lse = update_out_and_lse(out, lse, block_out, block_lse)
            
            # for pass query
            # Second half of the current q against the full local k/v; the
            # result is merged only into the second half of out/lse.
            q1 = q[:, block_seq_len:]
            block_out, block_lse = forward(q1, k, v, causal=False)
            out, lse = update_out_and_lse(
                out,
                lse,
                block_out,
                block_lse,
                slice_=(slice(None), slice(block_seq_len, None)),
            )
        else:
            # for pass query
            # Full current q against the first half of the local k/v.
            block_out, block_lse = forward(q, k0, v0, causal=False)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
            
            # for pass KV
            # block_out, block_lse = forward(q1, k, v, causal=False)
            # out, lse = update_out_and_lse(
            #     out,
            #     lse,
            #     block_out,
            #     block_lse,
            #     slice_=(slice(None), slice(block_seq_len, None)),
            # )

        if step + 1 != comm.world_size:
            # Finish the exchange issued above and switch to the received q.
            comm.wait()
            # k = next_k
            # v = next_v
            q = next_q
    
    # del comm, k, v, next_k, next_v, next_q

    out = out.to(q.dtype)
    lse = lse.squeeze(dim=-1).transpose(1, 2)
    return out, lse

# @torch.no_grad()
def fully_fused_ring_flash_attn_backward(
    process_group,
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = AttnType.FA,
):
    """Ring attention backward pass that rotates K/V and their gradients.

    ``k``/``v`` travel around the ring through ``kv_comm``; the partially
    accumulated ``dk``/``dv`` travel through ``d_kv_comm``. ``dq`` stays local
    and is accumulated across steps.

    Args:
        process_group: Group handed to both ``RingComm`` instances.
        dout: Gradient of the attention output.
        q, k, v: Forward inputs; dimension 1 is split in half
            (presumably the sequence dimension -- TODO confirm layout).
        out: Forward attention output.
        softmax_lse: Forward log-sum-exp; split in half along dimension 2.
        softmax_scale, dropout_p, window_size, softcap, alibi_slopes,
        deterministic: Forwarded unchanged to the backward kernel.
        causal: Must be True; anything else trips the assertion below.
        attn_type: Selects the backward kernel via ``select_flash_attn_impl``.

    Returns:
        ``(dq, dk, dv)``. With a ring of size 1 the raw kernel output buffers
        are returned directly; otherwise the float32 accumulators are cast
        back to the dtype of ``q`` and detached.
    """
    assert causal == True, "upipe ring is meaningless for causal=False"
    
    # for pass KV
    # One communicator rotates k/v, the other rotates the dk/dv accumulators.
    kv_comm = RingComm(process_group)
    d_kv_comm = RingComm(process_group)
    
    # for pass query
    # q_comm = RingComm(process_group, torch.empty_like(q))
    # d_q_comm = RingComm(process_group, torch.empty_like(q))
    
    dq, dk, dv = None, None, None
    next_dk, next_dv = None, None
    next_k, next_v = None, None
    dk_comm_buffer, dv_comm_buffer = None, None
    
    # next_dq = None
    # next_q = None
    # dq_comm_buffer = None

    # Second halves (along dim 1; dim 2 for the lse) used by the branch that
    # only differentiates the second half of the queries.
    dout1 = dout.chunk(2, dim=1)[1]
    q1 = q.chunk(2, dim=1)[1]
    out1 = out.chunk(2, dim=1)[1]
    softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous()
    block_seq_len = q.shape[1] // 2
    
    # for pass query
    # k0 = k[:, :block_seq_len]
    # v0 = v[:, :block_seq_len]

    # Repeatedly allocating buffers may be slow, so the kernel output buffers
    # are allocated once here and reused by every step.
    # if k.shape[2] == q.shape[2]:
    dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
    dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
    # else:
    #     dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    #     dk_buffer = torch.empty((k.shape[0], k.shape[1], q.shape[2], k.shape[3]), dtype=k.dtype, device=k.device)
    #     dv_buffer = torch.empty((v.shape[0], v.shape[1], q.shape[2], v.shape[3]), dtype=v.dtype, device=v.device)

    def backward(dout, q, k, v, out, softmax_lse, causal):
        # Run one backward block. The kernel is handed leading slices of the
        # shared buffers sized to this call's q / kv lengths along dim 1; the
        # caller reads the gradients back from those buffers afterwards.
        seqlen_q = q.shape[1]
        seqlen_kv = k.shape[1]
        fn = select_flash_attn_impl(attn_type, stage="bwd-only")

        fn(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq_buffer[:, :seqlen_q],
            dk_buffer[:, :seqlen_kv],
            dv_buffer[:, :seqlen_kv],
            dropout_p,
            softmax_scale,
            causal,
            window_size,
            softcap,
            alibi_slopes,
            deterministic,
            rng_state=None,
        )

    for step in range(kv_comm.world_size):
        # assert not check_nan_inf(dout, f"dout -{step}", dist.get_rank()), f"Pipe: dout is nan or inf at step {step}"
        # assert not check_nan_inf(q, f"q -{step}", dist.get_rank()), f"Pipe: q is nan or inf at step {step}"
        # assert not check_nan_inf(k, f"k -{step}", dist.get_rank()), f"Pipe: k is nan or inf at step {step}"
        # assert not check_nan_inf(v, f"v -{step}", dist.get_rank()), f"Pipe: v is nan or inf at step {step}"
        # assert not check_nan_inf(out, f"out -{step}", dist.get_rank()), f"Pipe: out is nan or inf at step {step}"
        # assert not check_nan_inf(softmax_lse, f"softmax_lse -{step}", dist.get_rank()), f"Pipe: softmax_lse is nan or inf at step {step}"

        if step + 1 != kv_comm.world_size:
            # Issue the k/v exchange for the next step before computing.
            next_k = kv_comm.send_recv(k)
            next_v = kv_comm.send_recv(v)
            kv_comm.commit()
            # next_q = q_comm.send_recv(q)
            # q_comm.commit()

        if step == 0:
            # Local block with the causal mask.
            backward(dout.contiguous(), q.contiguous(), k.contiguous(), v.contiguous(), out.contiguous(), softmax_lse.contiguous(), causal=True)
            if kv_comm.world_size == 1:
                # Single-rank ring: nothing to accumulate or exchange.
                return dq_buffer, dk_buffer, dv_buffer
            # Accumulate in float32 from here on.
            # NOTE(review): if the inputs were already float32, `.to` would
            # return the buffers themselves and later kernel calls would
            # overwrite the accumulators -- presumably inputs are fp16/bf16.
            dq = dq_buffer.to(torch.float32)
            dk = dk_buffer.to(torch.float32)
            dv = dv_buffer.to(torch.float32)
            
            
            # assert not check_nan_inf(dq, f"dq -{step}", dist.get_rank()), f"Pipe: dq is nan or inf at step {step}"
            # assert not check_nan_inf(dk, f"dk -{step}", dist.get_rank()), f"Pipe: dk is nan or inf at step {step}"
            # assert not check_nan_inf(dv, f"dv -{step}", dist.get_rank()), f"Pipe: dv is nan or inf at step {step}"
        else:
            if step <= kv_comm.rank:
                # for pass KV
                # Full local q against the first half of the received k/v.
                k0 = k[:, :block_seq_len]
                v0 = v[:, :block_seq_len]
                backward(dout, q, k0, v0, out, softmax_lse, causal=False)
                dq += dq_buffer
                
                # for pass query
                # q1 = q.chunk(2, dim=1)[1]
                # backward(dout1, q1, k, v, out1, softmax_lse1, causal=False)
                # dk += dk_buffer
                # dv += dv_buffer
                
                
            else:
                # for pass query
                # backward(dout, q, k0, v0, out, softmax_lse, causal=False)
                # dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len]
                # dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len]
                

                # for pass KV
                # Second half of the local q against the full received k/v.
                backward(dout1, q1, k, v, out1, softmax_lse1, causal=False)
                # always use the first half in dq_buffer.
                dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len]

            # for pass KV
            # Complete the dk/dv exchange issued at the end of the previous
            # iteration. The accumulators just sent become the receive
            # buffers for the next exchange, and the received tensors become
            # the accumulators this step adds to.
            d_kv_comm.wait()
            dk_comm_buffer, dv_comm_buffer = dk, dv
            dk, dv = next_dk, next_dv

            # for pass query
            # d_q_comm.wait()
            # dq_comm_buffer = dq
            # dq = next_dq

            if step <= kv_comm.rank:
                # for pass KV
                # Only the first half of k/v took part in this step.
                dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len]
                dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len]

                # for pass query
                # dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len]
                
            else:
                # for pass query
                # dq += dq_buffer
                
                # for pass KV
                dk += dk_buffer
                dv += dv_buffer

        if step + 1 != kv_comm.world_size:
            # Finish the k/v exchange and switch to the received tensors.
            kv_comm.wait()
            k = next_k
            v = next_v
            
            # q_comm.wait()
            # q = next_q

        
        # Pass the dk/dv accumulators on; this runs on every step, including
        # the last one, so the final result arrives in next_dk / next_dv.
        next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer)
        next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer)
        d_kv_comm.commit()
        # next_dq = d_q_comm.send_recv(dq, dq_comm_buffer)
        # d_q_comm.commit()

    
    d_kv_comm.wait()
    # d_q_comm.wait()

    # Remember the dtype before the input references are dropped.
    orig_q_dtype = q.dtype

    del dout, q, k, v, out, softmax_lse, dq_buffer, dk_buffer, dv_buffer, dk_comm_buffer, dv_comm_buffer, next_k, next_v

    # if k.shape[2] != q.shape[2]:
    #     bs, slen, nqh, hdim = q.shape
    #     return dq.to(q.dtype), next_dk.view(bs, slen, k.shape[2], (q.shape[2]//k.shape[2]), hdim).sum(dim=3).to(q.dtype), next_dv.view(bs, slen, v.shape[2], (q.shape[2]//v.shape[2]), hdim).sum(dim=3).to(q.dtype)
    # else:
    
    return dq.to(orig_q_dtype).detach(), next_dk.to(orig_q_dtype).detach(), next_dv.to(orig_q_dtype).detach()
    # return next_dq.to(orig_q_dtype).detach(), dk.to(orig_q_dtype).detach(), dv.to(orig_q_dtype).detach()

from torch import Tensor
# from yunchang.ring.upipe_ring_flash_attn import upipe_ring_flash_attn_forward, upipe_ring_flash_attn_backward 

# from torchtitan.models.llama3.model import TransformerModelArgs
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
    and the first seqlen elements will be sliced, but dim must match x.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.
    """
    ndim = x.ndim
    assert ndim > 1
    seqlen = x.shape[1]
    freqs_cis = freqs_cis[0:seqlen]
    assert freqs_cis.shape == (seqlen, x.shape[-1]), f"freqs_cis.shape: {freqs_cis.shape} != (seqlen, x.shape[-1]): {(seqlen, x.shape[-1])}"
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor = None,
    freqs_cis: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
    """
    Rotate ``xq`` (and optionally ``xk``) by the complex frequencies ``freqs_cis``.

    Each input is converted to float, its last dimension is regrouped into
    pairs and viewed as complex numbers, multiplied by the broadcast-shaped
    frequency table, converted back to real pairs, flattened from dimension 3
    onward and cast back to the dtype of the corresponding input.

    The frequency table is reshaped once, against ``xq``, and that same
    reshaped table is reused for ``xk``.

    Args:
        xq (torch.Tensor): Query tensor; must not live on the CPU device.
        xk (torch.Tensor): Optional key tensor.
        freqs_cis (torch.Tensor): Complex frequency table; must not live on
            the CPU device.

    Returns:
        The rotated query alone when ``xk`` is None, otherwise the tuple
        ``(rotated_query, rotated_key)``.
    """
    assert xq.device != torch.device("cpu"), "xq must be on GPU"
    assert freqs_cis.device != torch.device("cpu"), "freqs_cis must be on GPU"

    def _to_complex(t):
        # Pair up the last dimension and reinterpret each pair as one complex value.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _to_complex(xq)
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)

    if xk is None:
        return q_rotated.type_as(xq)

    k_complex = _to_complex(xk)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)


# ------------------------------------------------------------
# ----------------------- GQA --------------------------------
# ------------------------------------------------------------

@torch.library.custom_op("yunchang::_fully_fused_attn_gqa_forward", mutates_args=(), device_types="cuda")
# @torch.no_grad()
def fully_fused_attn_gqa_forward(
        x: Tensor,
        wq: Tensor,
        wk: Tensor,
        wv: Tensor,
        freqs_cis: Tensor,
        head_dim: int,
        dropout_p: float = 0,
        softmax_scale: float = 0,
        causal: bool = True,
        layer_id: int = 0,
) -> list[Tensor]:
    """Fused QKV projection, rotary embedding, all-to-all and ring attention.

    The heads are processed in ``pipe_degree = n_heads // ulysses_degree``
    stages. For each stage the matching slice of ``wq`` (and, when the stage
    starts a new KV group, of ``wk``/``wv``) is applied to ``x``, rotary
    embeddings are applied, the result goes through ``all_to_all_4D`` and
    ``fully_fused_ring_flash_attn_forward``, and the attention output is sent
    back through ``all_to_all_4D`` into a slice of ``final_out``.

    When more than one CUDA stream is supplied, the projection and all-to-all
    of stage ``s + 1`` are issued on a different stream than the attention of
    stage ``s``.

    Reads the module-level globals ``two_streams``, ``attn_type``,
    ``alibi_slopes``, ``window_size`` and ``ulysses_group``, and rebinds the
    module-level ``final_lse`` to the list built here.

    Args:
        x: Input of shape (bs, seqlen, hid_dim).
        wq, wk, wv: Projection weights, chunked along dimension 0.
        freqs_cis: Rotary frequency table; moved to ``x.device``.
        head_dim: Size of one attention head.
        dropout_p, softmax_scale, causal: Forwarded to the ring attention.
        layer_id: Not used by the active code path.

    Returns:
        A list whose first element is ``final_out`` with shape
        (bs, seqlen, n_heads, head_dim), followed by one log-sum-exp tensor
        per stage.
    """
    global two_streams, attn_type, alibi_slopes, window_size, ulysses_group, final_lse

    freqs_cis = freqs_cis.to(x.device)
    
    bs, seqlen, hid_dim = x.shape
    # NOTE(review): the query head count is derived from x's hidden size, not
    # from wq -- assumes wq projects to hid_dim; confirm.
    n_heads = hid_dim // head_dim
    n_kv_heads = wk.shape[0] // head_dim
    gqa_ratio = n_heads // n_kv_heads

    ulysses_degree = dist.get_world_size(ulysses_group)
    pipe_degree = n_heads // ulysses_degree

    assert n_kv_heads % ulysses_degree == 0, f"n_kv_heads: {n_kv_heads} must be divisible by ulysses_degree: {ulysses_degree}"

    proj_dim = wq.shape[1]  # not used below

    # Per-stage slots. K/V slots are indexed by stage // gqa_ratio because
    # gqa_ratio consecutive query stages share one K/V projection.
    q_in, k_in, v_in = [None] * pipe_degree, [None] * (pipe_degree//gqa_ratio), [None] * (pipe_degree//gqa_ratio)
    out, lse = [None] * pipe_degree, [None] * pipe_degree

    # One event per stage, recorded after that stage's input all-to-all calls.
    a2a_events = [torch.cuda.Event() for _ in range(pipe_degree)]

    # wt_idx = []
    # for stage in range(pipe_degree):
    #     if stage==0 or stage//gqa_ratio != (stage-1)//gqa_ratio:
    #         stage_idx = [(stage + i)*gqa_ratio for i in range(ulysses_degree)]
    #     else:
    #         stage_idx = [idx+1 for idx in stage_idx]

    #     for si in stage_idx:
    #         wt_idx.extend(range(si*head_dim, (si+1)*head_dim))
    
    # wq = wq[wt_idx, :]
    # wk = wk[wt_idx, :]
    # wv = wv[wt_idx, :]

    wq_chunks = torch.chunk(wq, pipe_degree, dim = 0) #0 is the output dimension
    wk_chunks = torch.chunk(wk, pipe_degree//gqa_ratio, dim = 0)
    wv_chunks = torch.chunk(wv, pipe_degree//gqa_ratio, dim = 0)

    # if torch.distributed.get_rank() == 0:
    #     breakpoint()
    # torch.distributed.barrier()

    q_out = [None for _ in range(pipe_degree)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
    k_out = [None for _ in range(pipe_degree//gqa_ratio)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
    v_out = [None for _ in range(pipe_degree//gqa_ratio)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
    final_out = torch.empty([bs, seqlen, n_heads, head_dim], device=x.device, dtype=x.dtype, memory_format=torch.contiguous_format)

    final_lse = []
    # final_lse_out = torch.empty([bs, pipe_degree, seqlen*ulysses_degree], device=x.device, dtype=x.dtype, memory_format=torch.contiguous_format)
    # final_lse = list(torch.chunk(final_lse, pipe_degree, dim = 1))

    # two_streams[0].wait_stream(torch.cuda.current_stream())
    # two_streams[1].wait_stream(torch.cuda.current_stream())

    # #Hoisting QKV Projection outside the loop
    # for stage in range(pipe_degree):
    #     q_in[stage] = F.linear(x, wq_chunks[stage])
    #     if stage==0 or stage//gqa_ratio > (stage-1)//gqa_ratio:
    #         k_in[(stage)//gqa_ratio] = F.linear(x, wk_chunks[(stage)//gqa_ratio])
    #         v_in[(stage)//gqa_ratio] = F.linear(x, wv_chunks[(stage)//gqa_ratio])

    #         q_in[stage], k_in[(stage)//gqa_ratio] = apply_rotary_emb(xq = q_in[stage].view(bs, seqlen, -1, head_dim), xk = k_in[(stage)//gqa_ratio].view(bs, seqlen, -1, head_dim), freqs_cis=freqs_cis)
    #         v_in[(stage)//gqa_ratio] = v_in[(stage)//gqa_ratio].view(bs, seqlen, -1, head_dim)
    #     else:
    #         q_in[stage] = apply_rotary_emb(xq = q_in[stage].view(bs, seqlen, -1, head_dim), xk = None, freqs_cis=freqs_cis)


    # Despite its name, two_streams may hold any number of streams. Make each
    # of them wait for work already queued on the current stream.
    num_streams = len(two_streams)
    for i in range(num_streams):
        two_streams[i].wait_stream(torch.cuda.current_stream())

    # def run():
    for stage in range(pipe_degree):
        # Projection + input all-to-all for this stage. With several streams
        # only stage 0 takes this path; later stages were already issued by
        # the look-ahead block of the previous iteration.
        if stage == 0 or len(two_streams) == 1:
            with torch.cuda.stream(two_streams[0]):
                
                q_in[stage] = F.linear(x, wq_chunks[stage])

                # True when this stage is the first one of its KV group.
                if stage==0 or stage//gqa_ratio > (stage-1)//gqa_ratio:

                    k_in[(stage)//gqa_ratio] = F.linear(x, wk_chunks[(stage)//gqa_ratio])
                    v_in[(stage)//gqa_ratio] = F.linear(x, wv_chunks[(stage)//gqa_ratio])

                    q_in[stage], k_in[(stage)//gqa_ratio] = apply_rotary_emb(xq = q_in[stage].view(bs, seqlen, -1, head_dim), xk = k_in[(stage)//gqa_ratio].view(bs, seqlen, -1, head_dim), freqs_cis=freqs_cis)
                    v_in[(stage)//gqa_ratio] = v_in[(stage)//gqa_ratio].view(bs, seqlen, -1, head_dim)

                    # presumably scatters dim 2 (heads) and gathers dim 1
                    # (sequence) -- confirm against yunchang.comm.all_to_all
                    q_out[stage] = all_to_all_4D(q_in[stage], 2, 1, False, False)#, output=q_out[stage+1])
                    k_out[(stage)//gqa_ratio] = all_to_all_4D(k_in[(stage)//gqa_ratio], 2, 1, False, False)#, output=k_out[stage+1])
                    v_out[(stage)//gqa_ratio] = all_to_all_4D(v_in[(stage)//gqa_ratio], 2, 1, False, False)#, output=v_out[stage+1])

                    a2a_events[stage].record()
                    
                    # Drop the pre-all-to-all tensors as soon as possible.
                    q_in[stage] = None
                    k_in[(stage)//gqa_ratio] = None
                    v_in[(stage)//gqa_ratio] = None
                else:
                    # Same KV group as the previous stage: only Q is needed.
                    q_in[stage] = apply_rotary_emb(xq = q_in[stage].view(bs, seqlen, -1, head_dim), xk = None, freqs_cis=freqs_cis)

                    q_out[stage] = all_to_all_4D(q_in[stage], 2, 1, False, False)#, output=q_out[stage+1])

                    a2a_events[stage].record()
                    
                    q_in[stage] = None

        # Look-ahead: issue the next stage's projection and all-to-all on the
        # next stream, after the current stage's all-to-all event.
        if stage != pipe_degree - 1 and len(two_streams) > 1:
            with torch.cuda.stream(two_streams[(stage+1)%num_streams]):
                a2a_events[stage].wait()
                
                q_in[stage+1] = F.linear(x, wq_chunks[stage+1])

                # True when the next stage starts a new KV group.
                if (stage+1)//gqa_ratio > stage//gqa_ratio:
                    
                    k_in[(stage+1)//gqa_ratio] = F.linear(x, wk_chunks[(stage+1)//gqa_ratio])
                    v_in[(stage+1)//gqa_ratio] = F.linear(x, wv_chunks[(stage+1)//gqa_ratio])

                    q_in[stage+1], k_in[(stage+1)//gqa_ratio] = apply_rotary_emb(xq = q_in[stage+1].view(bs, seqlen, -1, head_dim), xk = k_in[(stage+1)//gqa_ratio].view(bs, seqlen, -1, head_dim), freqs_cis=freqs_cis)
                    v_in[(stage+1)//gqa_ratio] = v_in[(stage+1)//gqa_ratio].view(bs, seqlen, -1, head_dim)

                    q_out[stage+1] = all_to_all_4D(q_in[stage+1], 2, 1, False, False)#, output=q_out[stage+1])
                    k_out[(stage+1)//gqa_ratio] = all_to_all_4D(k_in[(stage+1)//gqa_ratio], 2, 1, False, False)#, output=k_out[stage+1])
                    v_out[(stage+1)//gqa_ratio] = all_to_all_4D(v_in[(stage+1)//gqa_ratio], 2, 1, False, False)#, output=v_out[stage+1])

                    a2a_events[stage+1].record()
                    
                    q_in[stage+1] = None
                    k_in[(stage+1)//gqa_ratio] = None
                    v_in[(stage+1)//gqa_ratio] = None
                else:
                    q_in[stage+1] = apply_rotary_emb(xq = q_in[stage+1].view(bs, seqlen, -1, head_dim), xk = None, freqs_cis=freqs_cis)

                    q_out[stage+1] = all_to_all_4D(q_in[stage+1], 2, 1, False, False)#, output=q_out[stage+1])

                    a2a_events[stage+1].record()

                    q_in[stage+1] = None

        # Attention for the current stage, on this stage's stream, once its
        # input all-to-all event has completed.
        with torch.cuda.stream(two_streams[stage%num_streams]):
            a2a_events[stage].wait()
            k_out[stage//gqa_ratio] = k_out[stage//gqa_ratio].contiguous()
            v_out[stage//gqa_ratio] = v_out[stage//gqa_ratio].contiguous()
            
            # deterministic is hard-coded to False here.
            out[stage], lse[stage] = fully_fused_ring_flash_attn_forward(
                    q_out[stage],
                    k_out[stage//gqa_ratio],
                    v_out[stage//gqa_ratio],
                    softmax_scale=softmax_scale,
                    dropout_p=dropout_p,
                    causal=causal,
                    deterministic=False,
            )

            # assert not check_nan_inf(out[stage], "out[stage]", torch.distributed.get_rank()), f"NaN detected in out[stage] stage {stage}"
            # assert not check_nan_inf(lse[stage], "lse[stage]", torch.distributed.get_rank()), f"NaN detected in lse[stage] stage {stage}"

            
            final_lse.append(lse[stage])
            # The reverse all-to-all result fills this stage's
            # ulysses_degree-wide slice of the head dimension.
            final_out[:, :, (stage*ulysses_degree):((stage+1)*ulysses_degree), :] = all_to_all_4D(out[stage], 1, 2, False, False)
            
            q_out[stage] = None
            # K/V are released only after the last stage of their KV group.
            if (stage+1)//gqa_ratio != stage//gqa_ratio:
                k_out[stage//gqa_ratio] = None
                v_out[stage//gqa_ratio] = None
            
            out[stage] = None
            
    
    # global GRAPH_FLAGS
    # if len(GRAPH_FLAGS) <= layer_id:
    #     run_graph = torch.cuda.make_graphed_callables(run, ())
    #     GRAPH_FLAGS.append(run_graph)
    # else:
    #     run_graph = GRAPH_FLAGS[layer_id]
    # final_out, final_lse = run()
    # Make the current stream wait for everything queued on the side streams.
    for i in range(num_streams):
        torch.cuda.current_stream().wait_stream(two_streams[i])

    # torch.cuda.empty_cache()

    # final_lse = torch.cat(final_lse, dim=1)

    # final_out = torch.cat(final_out, dim=2)
    # Flat list: the attention output first, then one lse tensor per stage.
    output = [final_out]
    output.extend(final_lse)

    return output

@fully_fused_attn_gqa_forward.register_fake
def _(
        x: Tensor,
        wq: Tensor,
        wk: Tensor,
        wv: Tensor,
        freqs_cis: Tensor,
        head_dim: int,
        dropout_p: float = 0,
        softmax_scale: float = 0,
        causal: bool = True,
        layer_id: int = 0,
) -> Tuple[Tensor, Tensor]:
    """Fake (shape-only) implementation registered for the GQA forward op.

    Allocates uninitialised placeholders derived solely from ``x`` and
    ``head_dim``: an output of shape (bs, sl, d // head_dim, head_dim) and a
    log-sum-exp of shape (bs, d // head_dim, sl), both with the dtype and
    device of ``x``.

    NOTE(review): the real op returns a list holding the output followed by
    one lse tensor per pipeline stage, whereas this returns exactly two
    tensors -- confirm the fake result structure matches what tracing expects.
    """
    batch, seq_len, hidden = x.shape
    num_heads = hidden // head_dim
    like_x = {"dtype": x.dtype, "device": x.device}

    fake_out = torch.empty([batch, seq_len, num_heads, head_dim], **like_x)
    fake_lse = torch.empty([batch, num_heads, seq_len], **like_x)
    return fake_out, fake_lse
    

class FullyFusedAttnGQAFunc(torch.autograd.Function):
    
    @staticmethod
    # @torch.no_grad()
    def forward(
        ctx,
        x: Tensor,
        wq: Tensor,
        wk: Tensor,
        wv: Tensor,
        freqs_cis: Tensor,
        head_dim: int,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        two_streams,
        attn_type,
        layer_id,
    ) -> Tuple[Tensor, Tensor, None] | Tensor | Dict[str, Any]: 
        """Run the fused GQA attention op and stash what backward needs.

        Publishes the ring/Ulysses configuration as module-level attributes
        (the custom op and the ring kernel read them as globals), calls
        ``fully_fused_attn_gqa_forward`` under ``torch.no_grad()``, and saves
        the inputs, the attention output and the per-stage log-sum-exp
        tensors on ``ctx``.

        ``softmax_scale`` defaults to ``head_dim ** -0.5`` when None.
        ``offload_stream`` and ``fetch_stream`` are accepted but not used
        here.

        Returns:
            ``final_out`` alone, or ``(final_out, final_lse, None)`` when
            ``return_softmax`` is truthy, where ``final_lse`` is the sequence
            of lse tensors returned by the op after its first element.
        """
        # Unpacking keeps the implicit requirement that x is 3-D.
        _bs, _seqlen, hid_dim = x.shape
        n_heads = hid_dim // head_dim

        # Unused value; the indexing keeps the requirement that wq has a
        # second dimension.
        proj_dim = wq.shape[1]

        pipe_degree = n_heads // dist.get_world_size(ulysses_group)

        if softmax_scale is None:
            softmax_scale = head_dim ** (-0.5)

        # Several of these names are also parameters of this method, so they
        # cannot be declared `global`; set them on the module object instead.
        import sys
        this_module = sys.modules[__name__]
        shared_state = (
            ("process_group", ring_group),
            ("attn_type", attn_type),
            ("alibi_slopes", alibi_slopes),
            ("window_size", window_size),
            ("ulysses_group", ulysses_group),
            ("two_streams", two_streams),
        )
        for attr_name, attr_value in shared_state:
            setattr(this_module, attr_name, attr_value)

        with torch.no_grad():
            output = fully_fused_attn_gqa_forward(
                x,
                wq,
                wk,
                wv,
                freqs_cis,
                head_dim,
                dropout_p,
                softmax_scale,
                causal,
                layer_id,
            )
            # First element is the attention output, the rest are lse tensors.
            final_out = output[0]
            final_lse = output[1:]

        ctx.save_for_backward(x, wq, wk, wv, freqs_cis, final_out, *final_lse)

        # Non-tensor configuration needed again in backward.
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.ring_group = ring_group
        ctx.ulysses_group = ulysses_group
        ctx.attn_type = attn_type
        ctx.two_streams = two_streams
        ctx.pipe_degree = pipe_degree
        ctx.layer_id = layer_id

        if return_softmax:
            return (final_out, final_lse, None)
        return final_out
        
    
    @staticmethod
    # @torch.no_grad()
    def backward(ctx, dout, *args) -> Tuple[Tensor, Tensor, Tensor, Tensor, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]:

        # torch.cuda.cudart().cudaProfilerStart()
        fn = select_flash_attn_impl(ctx.attn_type, stage="bwd-only")

        saved_tensors = ctx.saved_tensors
        x, wq, wk, wv, freqs_cis, final_out, *final_lse = saved_tensors

        freqs_cis = freqs_cis.to(x.device)

        layer_id = ctx.layer_id

        # assert not check_nan_inf(final_out, "final_out", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in final_out"
        # assert not check_nan_inf(final_lse, "final_lse", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in final_lse"
        # assert not check_nan_inf(x, "x", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in x"
        # assert not check_nan_inf(wq, "wq", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in wq"
        # assert not check_nan_inf(wk, "wk", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in wk"
        # assert not check_nan_inf(wv, "wv", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in wv"
        # assert not check_nan_inf(freqs_cis, "freqs_cis", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in freqs_cis"
        # assert not check_nan_inf(dout, "dout", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in dout"
        
        bs, seqlen, n_heads, head_dim = final_out.shape
        hidden_dim = n_heads * head_dim
        ulysses_degree = dist.get_world_size(ctx.ulysses_group)
        pipe_degree = n_heads // ulysses_degree
        n_kv_heads = wk.shape[0] // head_dim
        gqa_ratio = n_heads // n_kv_heads

        # dx = torch.zeros_like(x)
        # dx = torch.empty_like(x)
        dx = None
        # dwq = []
        # dwk = []
        # dwv = []
        dwq = torch.empty_like(wq)
        dwk = torch.empty_like(wk)
        dwv = torch.empty_like(wv)

        proj_dim = wq.shape[1]

        two_streams = ctx.two_streams

        q_in, k_in, v_in = [None] * pipe_degree, [None] * (pipe_degree//gqa_ratio), [None] * (pipe_degree//gqa_ratio)
        attn_dq, attn_dk, attn_dv = [None] * pipe_degree, [None] * pipe_degree, [None] * pipe_degree
        # attn_dq = [torch.empty([bs, seqlen*ulysses_degree, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        # attn_dk = [torch.empty([bs, seqlen*ulysses_degree, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        # attn_dv = [torch.empty([bs, seqlen*ulysses_degree, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        dq_out, dk_out, dv_out = [None] * pipe_degree, [None] * (pipe_degree), [None] * (pipe_degree)
        dk_debug = [None] * pipe_degree
        dq_debug = [None] * pipe_degree

        a2a_events = [torch.cuda.Event() for _ in range(pipe_degree)]

        # wt_idx = []
        # for stage in range(pipe_degree):
        #     if stage==0 or stage//gqa_ratio != (stage-1)//gqa_ratio:
        #         stage_idx = [(stage + i)*gqa_ratio for i in range(ulysses_degree)]
        #     else:
        #         stage_idx = [idx+1 for idx in stage_idx]

        #     for si in stage_idx:
        #         wt_idx.extend(range(si*head_dim, (si+1)*head_dim))
        
        # wq = wq[wt_idx, :]
        # wk = wk[wt_idx, :]
        # wv = wv[wt_idx, :]
        
        wq_chunks = torch.chunk(wq, pipe_degree, dim = 0) #0 is the output dimension
        wk_chunks = torch.chunk(wk, pipe_degree//gqa_ratio, dim = 0)
        wv_chunks = torch.chunk(wv, pipe_degree//gqa_ratio, dim = 0)

        q_out = [None for _ in range(pipe_degree)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        k_out = [None for _ in range(pipe_degree//gqa_ratio)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        v_out = [None for _ in range(pipe_degree//gqa_ratio)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        out_out = [None for _ in range(pipe_degree)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=x.dtype) for _ in range(pipe_degree)]
        dout_out = [None for _ in range(pipe_degree)] #[torch.empty([ulysses_degree, seqlen, bs, (n_heads//ulysses_degree)//pipe_degree, head_dim], device=x.device, dtype=dout.dtype) for _ in range(pipe_degree)]

        # final_out_idx = []
        # for stage in range(pipe_degree):
        #     if stage==0 or stage//gqa_ratio != (stage-1)//gqa_ratio:
        #             stage_idx = [(stage + i)*gqa_ratio for i in range(ulysses_degree)]
        #     else:
        #         stage_idx = [idx+1 for idx in stage_idx]
        #     final_out_idx.extend(stage_idx)

        # final_out = final_out[:, :, final_out_idx, :]
        # dout = dout[:, :, final_out_idx, :]

        final_out = list(torch.chunk(final_out, pipe_degree, dim = 2))
        # final_lse = list(torch.chunk(final_lse, pipe_degree, dim = 1))
        dout = list(torch.chunk(dout, pipe_degree, dim = 2))


        # #Hoisting QKV Projection outside the loop
        # for stage in range(pipe_degree):
        #     q_in[stage] = F.linear(x, wq_chunks[stage])
        #     if stage==0 or stage//gqa_ratio > (stage-1)//gqa_ratio:
        #         k_in[(stage)//gqa_ratio] = F.linear(x, wk_chunks[(stage)//gqa_ratio])
        #         v_in[(stage)//gqa_ratio] = F.linear(x, wv_chunks[(stage)//gqa_ratio])

        #         q_in[stage], k_in[(stage)//gqa_ratio] = apply_rotary_emb(xq = q_in[stage].view(bs, seqlen, -1, head_dim), xk = k_in[(stage)//gqa_ratio].view(bs, seqlen, -1, head_dim), freqs_cis=freqs_cis)
        #         v_in[(stage)//gqa_ratio] = v_in[(stage)//gqa_ratio].view(bs, seqlen, -1, head_dim)
        #     else:
        #         q_in[stage] = apply_rotary_emb(xq = q_in[stage].view(bs, seqlen, -1, head_dim), xk = None, freqs_cis=freqs_cis)

        num_streams = len(two_streams)
        for i in range(num_streams):
            two_streams[i].wait_stream(torch.cuda.current_stream())

        for stage in range(pipe_degree):
            if stage == 0 or len(two_streams) == 1:
                with torch.cuda.stream(two_streams[0]):

                    q_in[stage] = F.linear(x, wq_chunks[stage])
                    if stage==0 or stage//gqa_ratio > (stage-1)//gqa_ratio:
                        
                        k_in[(stage)//gqa_ratio] = F.linear(x, wk_chunks[(stage)//gqa_ratio])
                        v_in[(stage)//gqa_ratio] = F.linear(x, wv_chunks[(stage)//gqa_ratio])

                        q_in[stage], k_in[(stage)//gqa_ratio] = apply_rotary_emb(xq = q_in[stage].view(bs, seqlen, -1, head_dim), xk = k_in[(stage)//gqa_ratio].view(bs, seqlen, -1, head_dim), freqs_cis=freqs_cis)
                        v_in[(stage)//gqa_ratio] = v_in[(stage)//gqa_ratio].view(bs, seqlen, -1, head_dim)

                        q_out[stage] = all_to_all_4D(q_in[stage], 2, 1, False, False)#, output=q_out[stage+1])
                        k_out[(stage)//gqa_ratio] = all_to_all_4D(k_in[(stage)//gqa_ratio], 2, 1, False, False)#, output=k_out[stage+1])
                        v_out[(stage)//gqa_ratio] = all_to_all_4D(v_in[(stage)//gqa_ratio], 2, 1, False, False)#, output=v_out[stage+1])

                        out_out[stage] = all_to_all_4D(final_out[stage], 2, 1, False, False)#, output=out_out[stage+1])
                        dout_out[stage] = all_to_all_4D(dout[stage], 2, 1, False, False)#, output=dout_out[stage+1])

                        a2a_events[stage].record()

                        q_in[stage] = None
                        k_in[(stage)//gqa_ratio] = None
                        v_in[(stage)//gqa_ratio] = None

                        final_out[stage] = None
                        dout[stage] = None
                    else:
                        q_in[stage] = apply_rotary_emb(xq = q_in[stage].view(bs, seqlen, -1, head_dim), xk = None, freqs_cis=freqs_cis)
                        
                        q_out[stage] = all_to_all_4D(q_in[stage], 2, 1, False, False)#, output=q_out[stage+1])
                        out_out[stage] = all_to_all_4D(final_out[stage], 2, 1, False, False)#, output=out_out[stage+1])
                        dout_out[stage] = all_to_all_4D(dout[stage], 2, 1, False, False)#, output=dout_out[stage+1])

                        a2a_events[stage].record()

                        q_in[stage] = None
                        final_out[stage] = None
                        dout[stage] = None

            if stage != pipe_degree - 1 and len(two_streams) > 1:
                with torch.cuda.stream(two_streams[(stage+1)%num_streams]):
                    a2a_events[stage].wait()
                    
                    q_in[stage+1] = F.linear(x, wq_chunks[stage+1])
                    if (stage+1)//gqa_ratio > stage//gqa_ratio:
                        
                        k_in[(stage+1)//gqa_ratio] = F.linear(x, wk_chunks[(stage+1)//gqa_ratio])
                        v_in[(stage+1)//gqa_ratio] = F.linear(x, wv_chunks[(stage+1)//gqa_ratio])

                        q_in[stage+1], k_in[(stage+1)//gqa_ratio] = apply_rotary_emb(xq = q_in[stage+1].view(bs, seqlen, -1, head_dim), xk = k_in[(stage+1)//gqa_ratio].view(bs, seqlen, -1, head_dim), freqs_cis=freqs_cis)
                        v_in[(stage+1)//gqa_ratio] = v_in[(stage+1)//gqa_ratio].view(bs, seqlen, -1, head_dim)

                        q_out[stage+1] = all_to_all_4D(q_in[stage+1], 2, 1, False, False)#, output=q_out[stage+1])
                        k_out[(stage+1)//gqa_ratio] = all_to_all_4D(k_in[(stage+1)//gqa_ratio], 2, 1, False, False)#, output=k_out[stage+1])
                        v_out[(stage+1)//gqa_ratio] = all_to_all_4D(v_in[(stage+1)//gqa_ratio], 2, 1, False, False)#, output=v_out[stage+1])

                        out_out[stage+1] = all_to_all_4D(final_out[stage+1], 2, 1, False, False)#, output=out_out[stage+1])
                        dout_out[stage+1] = all_to_all_4D(dout[stage+1], 2, 1, False, False)#, output=dout_out[stage+1])

                        a2a_events[stage+1].record()

                        q_in[stage+1] = None
                        k_in[(stage+1)//gqa_ratio] = None
                        v_in[(stage+1)//gqa_ratio] = None

                        final_out[stage+1] = None
                        dout[stage+1] = None
                    else:
                        q_in[stage+1] = apply_rotary_emb(xq = q_in[stage+1].view(bs, seqlen, -1, head_dim), xk = None, freqs_cis=freqs_cis)
                        
                        q_out[stage+1] = all_to_all_4D(q_in[stage+1], 2, 1, False, False)#, output=q_out[stage+1])
                        out_out[stage+1] = all_to_all_4D(final_out[stage+1], 2, 1, False, False)#, output=out_out[stage+1])
                        dout_out[stage+1] = all_to_all_4D(dout[stage+1], 2, 1, False, False)#, output=dout_out[stage+1])

                        a2a_events[stage+1].record()

                        q_in[stage+1] = None
                        final_out[stage+1] = None
                        dout[stage+1] = None

            
            with torch.cuda.stream(two_streams[stage%num_streams]):
                a2a_events[stage].wait()
                # if int(layer_id) == 30 and torch.distributed.get_rank() == 0:
                #     breakpoint()
                # torch.distributed.barrier()
                k_out[stage//gqa_ratio] = k_out[stage//gqa_ratio].contiguous()
                v_out[stage//gqa_ratio] = v_out[stage//gqa_ratio].contiguous()

                attn_dq[stage], attn_dk[stage], attn_dv[stage] = fully_fused_ring_flash_attn_backward(
                    ctx.ring_group,
                    dout_out[stage],
                    q_out[stage],
                    k_out[stage//gqa_ratio],
                    v_out[stage//gqa_ratio],
                    out_out[stage],
                    final_lse[stage],
                    softmax_scale=ctx.softmax_scale,
                    dropout_p=ctx.dropout_p,
                    causal=ctx.causal,
                    window_size=ctx.window_size,
                    softcap=ctx.softcap,
                    alibi_slopes=ctx.alibi_slopes,
                    deterministic=ctx.deterministic,
                    attn_type=ctx.attn_type,
                )

                # assert not check_nan_inf(attn_dq[stage], "attn_dq[stage]", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in attn_dq[stage]"
                # assert not check_nan_inf(attn_dk[stage], "attn_dk[stage]", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in attn_dk[stage]"
                # assert not check_nan_inf(attn_dv[stage], "attn_dv[stage]", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in attn_dv[stage]"

                if os.environ.get("DEBUG_MODE", "0") == "1" and torch.distributed.get_rank() == 0:
                    dk_debug[stage] = torch.clone(attn_dk[stage].to(torch.float32))
                    dq_debug[stage] = torch.clone(attn_dq[stage])
                    breakpoint()
                if os.environ.get("DEBUG_MODE", "0") == "1":
                    torch.distributed.barrier()

                dq_out[stage] = all_to_all_4D(attn_dq[stage], 1, 2, False, False)#.view(bs, seqlen, -1)
                # dk_out[stage] = all_to_all_4D(attn_dk[stage], 1, 2, False, False).view(bs, seqlen, -1)
                # dv_out[stage] = all_to_all_4D(attn_dv[stage], 1, 2, False, False).view(bs, seqlen, -1)

                # if torch.distributed.get_rank() == 0:
                #     breakpoint()
                # torch.distributed.barrier()

                dq_out[stage] = apply_rotary_emb(dq_out[stage], freqs_cis=torch.conj(freqs_cis))
                dq_out[stage] = dq_out[stage].view(bs, seqlen, -1)

                if dk_out[stage//gqa_ratio] is None:
                    dk_out[stage//gqa_ratio] = attn_dk[stage]#dk_out[stage]
                    dv_out[stage//gqa_ratio] = attn_dv[stage]#dv_out[stage]
                else:
                    dk_out[stage//gqa_ratio].add_(attn_dk[stage])#dk_out[stage]
                    dv_out[stage//gqa_ratio].add_(attn_dv[stage])#dv_out[stage]
                # assert not check_nan_inf(dk_out[stage//gqa_ratio], f"dk_out[stage//gqa_ratio] stage {stage} Layer {layer_id}", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in dk_out[stage//gqa_ratio]"
                # assert not check_nan_inf(dv_out[stage//gqa_ratio], f"dv_out[stage//gqa_ratio] stage {stage} Layer {layer_id}", torch.distributed.get_rank()), f"Layer {layer_id} NaN detected in dv_out[stage//gqa_ratio]"


                q_out[stage] = None
                if (stage+1)//gqa_ratio != stage//gqa_ratio:
                    k_out[stage//gqa_ratio] = None
                    v_out[stage//gqa_ratio] = None
                out_out[stage] = None
                dout_out[stage] = None
                final_lse[stage] = None

                attn_dq[stage] = None
                attn_dk[stage] = None
                attn_dv[stage] = None
                # torch.cuda.empty_cache()

                # K/V at group boundary
                if dx is None:
                    dx = dq_out[stage].view(bs * seqlen, -1) @ wq_chunks[stage]
                else:
                    dx.addmm_(dq_out[stage].view(bs * seqlen, -1), wq_chunks[stage], alpha=1.0, beta=1.0)
                if (stage+1)//gqa_ratio != stage//gqa_ratio or stage == pipe_degree - 1:
                    
                    dk_out[stage//gqa_ratio] = all_to_all_4D(dk_out[stage//gqa_ratio], 1, 2, False, False)
                    dv_out[stage//gqa_ratio] = all_to_all_4D(dv_out[stage//gqa_ratio], 1, 2, False, False)


                    dk = apply_rotary_emb(dk_out[stage//gqa_ratio].view(bs, seqlen, -1, head_dim), freqs_cis=torch.conj(freqs_cis)).view(bs * seqlen, -1)
                    dv = dv_out[stage//gqa_ratio].view(bs * seqlen, -1)

                    dx.addmm_(dk, wk_chunks[stage//gqa_ratio], alpha=1.0, beta=1.0)
                    dx.addmm_(dv, wv_chunks[stage//gqa_ratio], alpha=1.0, beta=1.0)
                    
                    # dwk.append(dk.T @ x.view(bs * seqlen, -1))                                  
                    # dwv.append(dv.T @ x.view(bs * seqlen, -1))
                    dwk[(stage//gqa_ratio)*(head_dim*ulysses_degree):((stage+1)//gqa_ratio)*(head_dim*ulysses_degree), :] = dk.T @ x.view(bs * seqlen, -1)
                    dwv[(stage//gqa_ratio)*(head_dim*ulysses_degree):((stage+1)//gqa_ratio)*(head_dim*ulysses_degree), :] = dv.T @ x.view(bs * seqlen, -1)
                
                # if dx is None:
                #     dx = dx_temp.view(bs, seqlen, -1)
                # else:
                #     dx += dx_temp.view(bs, seqlen, -1)

                # Q
                dq = dq_out[stage].view(bs * seqlen, -1)
                # dwq.append(dq.T @ x.view(bs * seqlen, -1))
                dwq[stage*(head_dim*ulysses_degree):(stage+1)*(head_dim*ulysses_degree), :] = dq.T @ x.view(bs * seqlen, -1)
                     

                dq_out[stage] = None
                if (stage+1)//gqa_ratio != stage:
                    dk_out[stage] = None
                    dv_out[stage] = None

                if (stage+1)//gqa_ratio != stage//gqa_ratio:
                    dk_out[stage//gqa_ratio] = None
                    dv_out[stage//gqa_ratio] = None

                # torch.cuda.empty_cache()
        
        for i in range(len(two_streams)):
            torch.cuda.current_stream().wait_stream(two_streams[i])
        # torch.cuda.empty_cache()

        # dwq = torch.cat(dwq, dim = 0)
        # if len(dwk) > 0:
        #     dwk = torch.cat(dwk, dim = 0)
        #     dwv = torch.cat(dwv, dim = 0)
        # else:
        #     dwk = dwk[0]
        #     dwv = dwv[0]
        # torch.cuda.cudart().cudaProfilerStop()
        if dwk.shape[0]//head_dim > n_kv_heads:
            n_rep = (dwk.shape[0] // head_dim) // n_kv_heads
            
            dwk = dwk.unsqueeze(0).unsqueeze(0).view(n_kv_heads, n_rep, head_dim, hidden_dim)
            dwk[:, 0, :, :] = dwk.sum(dim = 1)
            for i in range(1, n_rep):
                dwk[:, i, :, :] = dwk[:, 0, :, :]
            dwk = dwk.view(-1, hidden_dim)
            
            dwv = dwv.unsqueeze(0).unsqueeze(0).view(n_kv_heads, n_rep, head_dim, hidden_dim)
            dwv[:, 0, :, :] = dwv.sum(dim = 1)
            for i in range(1, n_rep):
                dwv[:, i, :, :] = dwv[:, 0, :, :]
            dwv = dwv.view(-1, hidden_dim)

        return dx.view(bs, seqlen, -1).to(x.dtype), dwq, dwk, dwv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None

def fully_fused_attn_func(
        x,
        wq,
        wk,
        wv,
        freqs_cis,
        head_dim,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
        window_size=(-1, -1),
        softcap=0.0,
        alibi_slopes=None,
        deterministic=False,
        return_attn_probs=False,
        ring_group=None,
        ulysses_group=None,
        offload_stream=None,
        fetch_stream=None,
        two_streams=None,
        attn_type: AttnType = AttnType.FA3,
        attn_processor=None,
        use_pack_qkv=True,
        dualstage=False,
        layer_id=None,
        fused_attn_type="fully_fused",
        w_rms=None,
        eps_rms=None,
):
    """Dispatch to the fused attention autograd function named by ``fused_attn_type``.

    Selection is by substring match on ``fused_attn_type``:

    * contains ``"mha"``          -> ``FullyFusedAttnFunc.apply``
    * contains ``"ultra_fused"``  -> ``NotImplementedError`` (only reached when
      ``"mha"`` is absent, since the ``"mha"`` check takes priority)
    * anything else               -> ``FullyFusedAttnGQAFunc.apply``

    Both autograd functions receive the same positional arguments, in this
    order: ``x, wq, wk, wv, freqs_cis, head_dim, dropout_p, softmax_scale,
    causal, window_size, softcap, alibi_slopes, deterministic,
    return_attn_probs, ring_group, ulysses_group, offload_stream,
    fetch_stream, two_streams, attn_type, layer_id``. They are passed through
    untouched; this wrapper performs no validation or defaulting of its own.

    ``attn_processor``, ``use_pack_qkv``, ``dualstage``, ``w_rms`` and
    ``eps_rms`` are accepted for call-site compatibility but are not read or
    forwarded here.

    Returns:
        Whatever the selected autograd function's ``apply`` returns.

    Raises:
        NotImplementedError: if ``fused_attn_type`` contains ``"ultra_fused"``
            and does not contain ``"mha"``.
    """
    # Single source of truth for the positional argument list shared by both
    # autograd functions (autograd Function.apply takes positionals only).
    forwarded = (
        x, wq, wk, wv,
        freqs_cis, head_dim,
        dropout_p, softmax_scale, causal, window_size, softcap,
        alibi_slopes, deterministic, return_attn_probs,
        ring_group, ulysses_group,
        offload_stream, fetch_stream, two_streams,
        attn_type, layer_id,
    )

    # The "mha" check comes first so it wins even if "ultra_fused" is also present.
    if "mha" in fused_attn_type:
        return FullyFusedAttnFunc.apply(*forwarded)

    if "ultra_fused" in fused_attn_type:
        raise NotImplementedError("UltraFused is no longer here. Refer to fully_pipelined_attn_backup.py")

    # Default path: grouped-query attention variant.
    return FullyFusedAttnGQAFunc.apply(*forwarded)
