import torch
from .utils import RingComm, update_out_and_lse
from yunchang.kernels import AttnType, select_flash_attn_impl
from typing import Tuple

from yunchang.ring.dist_flash_utils import maybe_send_recv_fwd_qkvo, maybe_send_recv_bwd_qkvo, maybe_send_recv_bwd_last_dkv, maybe_get_set_global_memory_buffer, is_compute_for_local_query, is_idle, is_sync_from_remote, wait_async_handles, is_last_time

# Module-level state consumed by the `dist_flash_attn_forward_balanced` custom
# op. `dist_flash_attn_forward` assigns these names on the module object right
# before it invokes the op, and the op reads them back through `global`.
# NOTE(review): at module scope a `global` statement does not create or bind
# anything, so none of these names exist until `dist_flash_attn_forward` has
# been called at least once.
global process_group, attn_type, alibi_slopes, window_size

def dist_flash_attn_forward_balanced_custom(
    process_group,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = AttnType.FA,
    buff_q=None,
    buff_k=None,
    buff_v=None,
    buff_l=None,
    buff_o=None,
):
    """Forward attention with a hand-scheduled isend/irecv exchange pattern.

    Runs ``P // 2 + 1`` time steps, where ``P`` is the world size reported by
    ``RingComm(process_group)``. In each step the rank posts point-to-point
    requests for queries, K/V and partial (out, lse) results, runs one
    attention block, waits on every posted request and, for some ranks, merges
    a received partial result into its local accumulators.

    Args:
        process_group: Group used for every point-to-point transfer.
        q, k, v: Local query/key/value tensors; ``q`` must be 4-D.
            NOTE(review): ``attn_compute`` compares ``shape[2]`` of q and k as
            the head count, presumably (batch, seq, heads, head_dim) - confirm
            against the selected kernel.
        softmax_scale, dropout_p, window_size, softcap, alibi_slopes:
            Forwarded unchanged to the kernel chosen by
            ``select_flash_attn_impl``.
        causal: Not read by the body. Step 0 always calls the kernel with
            ``causal=True`` and later steps with ``causal=False``.
        deterministic: Not read by the body.
        attn_type: Passed to ``select_flash_attn_impl(..., stage="fwd-only")``.
        buff_q, buff_k, buff_v, buff_l, buff_o: Optional two-element lists of
            staging tensors; allocated below when ``None``.

    Returns:
        ``(out, lse)``: ``out`` cast to ``q.dtype`` and ``lse`` with its
        trailing dimension squeezed and dims 1 and 2 swapped.
    """
    
    comm = RingComm(process_group)
    P = comm.world_size  # Number of processes
    p = comm.rank        # Current process rank
    
    def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
        # NOTE(review): this unconditional return makes everything below it
        # unreachable, so K/V are currently handed to the kernel unexpanded.
        return x
        bs, slen, n_kv_heads, head_dim = x.shape
        if n_rep == 1:
            return x
        return (
            torch.unsqueeze(x, dim=3)
            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
            .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
        )

    def attn_compute(q_input, k_input, v_input, causal, out_prev=None, lse_prev=None):
        """Run one attention block and merge it into ``(out_prev, lse_prev)``.

        The merge is delegated to ``update_out_and_lse``; with the default
        ``None`` accumulators the call is made with ``None`` for both.
        """
        fn = select_flash_attn_impl(attn_type, stage="fwd-only")
        
        # Handle GQA/MQA by repeating k,v if needed
        # (`repeat_kv` is currently a pass-through, see the note inside it.)
        if k_input.shape[2] != q_input.shape[2]:
            k_input = repeat_kv(k_input, q_input.shape[2] // k_input.shape[2])
            v_input = repeat_kv(v_input, q_input.shape[2] // v_input.shape[2])
            
        block_out, block_lse = fn(
            q_input,
            k_input,
            v_input,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # `True and X` evaluates to X, i.e. this is just `dropout_p > 0`.
            return_softmax=True and dropout_p > 0,
        )
        
        # if out_prev is None:
        #     return block_out, block_lse
        # else:
        #     # assert False, f"out_prev.shape: {out_prev.shape}, block_out.shape: {block_out.shape}, lse_prev.shape: {lse_prev.shape}, block_lse.shape: {block_lse.shape}"
        return update_out_and_lse(out_prev, lse_prev, block_out, block_lse)

    def send_to_rank(tensor, dest_rank):
        """Send tensor to specific rank"""
        # `dest_rank` is a rank inside `process_group`; translate it to the
        # global rank before posting the non-blocking send.
        global_dest_rank = torch.distributed.get_global_rank(process_group, dest_rank)
        req = torch.distributed.isend(tensor.contiguous(), dst=global_dest_rank, group=process_group)
        return req

    def recv_from_rank(tensor_shape, src_rank, recv_buffer, dtype=None, device=None):
        """Receive tensor from specific rank"""
        # NOTE(review): `tensor_shape`, `dtype` and `device` do not influence
        # the receive; the data is written into `recv_buffer` as given.
        if dtype is None:
            dtype = q.dtype
        if device is None:
            device = q.device
        
        # recv_tensor = torch.empty(tensor_shape, dtype=dtype, device=device)
        global_src_rank = torch.distributed.get_global_rank(process_group, src_rank)
        req = torch.distributed.irecv(recv_buffer, src=global_src_rank, group=process_group)
        return req
    
    # Outstanding isend/irecv requests of the current time step.
    reqs_handles = []

    # NOTE(review): these names read as (batch, heads, seq, head_dim), while
    # `attn_compute` treats dim 2 as the head count - confirm the layout.
    bs, nh, seq_len, hdim = q.shape

    # Allocate buffers if not provided, 2 for send/recv
    # NOTE(review): `[t] * 2` stores the *same* tensor object in both slots,
    # so index 0 and index 1 share storage until a slot is rebound.
    if buff_q is None:
        buff_q = [torch.empty(q.shape, dtype=q.dtype, device=q.device)]*2
    if buff_k is None:
        buff_k = [torch.empty(k.shape, dtype=k.dtype, device=k.device)]*2
    if buff_v is None:
        buff_v = [torch.empty(v.shape, dtype=v.dtype, device=v.device)]*2
    if buff_l is None:
        buff_l = [torch.empty([bs, nh, seq_len, 1], dtype=torch.float32, device=q.device)]*2
    if buff_o is None:
        buff_o = [torch.empty(q.shape, dtype=torch.float32, device=q.device)]*2


    for time_step in range(P // 2 + 1):

        # Block the host until all queued device work has finished before
        # posting this step's communication.
        torch.cuda.synchronize()
        
        # Handling all complicated stuff
        if time_step < P//2 - 1:
            if(p == P-time_step-1): #if I am the rank P - time_step - 1, I send my query to the rank 0
                reqs_handles.append(send_to_rank(q, 0))
            elif(p == 0): # if I am rank 0, I receive query from the rank P - time_step - 1
                reqs_handles.append(recv_from_rank(q.shape, P-time_step-1, buff_q[1]))
            elif(p < time_step): # if my rank is < time_step, I send the buff_q to my next rank
                reqs_handles.append(send_to_rank(buff_q[0], p+1))

        # Handling all simple stuff
        if(time_step == 0): # at time_step = 0, all perform their local attention, and send their K,V to the next rank as long as my rank is < P-1
            if(p < P-1):
                reqs_handles.append(send_to_rank(k, p+1))
                reqs_handles.append(send_to_rank(v, p+1))
            if(p > 0):
                reqs_handles.append(recv_from_rank(k.shape, p-1, buff_k[1]))
                reqs_handles.append(recv_from_rank(v.shape, p-1, buff_v[1]))
            out, lse = attn_compute(q, k, v, causal=True)
        else:
            # handling KV comms
            if(p < P-1):
                reqs_handles.append(send_to_rank(buff_k[0], p+1))
                reqs_handles.append(send_to_rank(buff_v[0], p+1))
            if(p > time_step):
                reqs_handles.append(recv_from_rank(k.shape, p-1, buff_k[1]))
                reqs_handles.append(recv_from_rank(v.shape, p-1, buff_v[1]))

            #handling Out, LSE comms
            if(p < P//2 - 1 and p < time_step - 1): #if my rank < time_step - 1, I have done someone else's work (P - time_step + p to be precise), so better send it to them
                reqs_handles.append(send_to_rank(buff_o[0], P-time_step+1+p))
                reqs_handles.append(send_to_rank(buff_l[0], P-time_step+1+p))
            elif(p > P//2 + 1 and time_step > P - p): # if time_step has passed the point ie, time_step > P - p, I have already sent my query and need to grab the output from them (time_step - (P - p)) to reconsolidate
                reqs_handles.append(recv_from_rank(out.shape, time_step - (P - p) - 1, buff_o[1]))
                reqs_handles.append(recv_from_rank(lse.shape, time_step - (P - p) - 1, buff_l[1]))
                # out, lse = update_out_and_lse(out, lse, buff_o[1], buff_l[1])
            
            #handling attn
            if(p>=time_step): #normal attn
                # Rebinding makes slot 0 refer to the same tensor as slot 1.
                buff_k[0] = buff_k[1]
                buff_v[0] = buff_v[1]
                out, lse = attn_compute(q, buff_k[0], buff_v[0], causal=False, out_prev=out, lse_prev=lse)
            else: #attn for someone else (the more busy workers)
                buff_q[0] = buff_q[1]
                buff_o[0], buff_l[0] = attn_compute(buff_q[0], k, v, causal=False)
        
        #sync'em all
        for rh in reqs_handles:
            rh.wait()
        reqs_handles = []

        # Same condition as the out/lse receive branch above: once the
        # requests have completed, merge the received partial result.
        if(p > P//2 + 1 and time_step > P - p):
            out, lse = update_out_and_lse(out, lse, buff_o[1], buff_l[1])
    
    out = out.to(q.dtype)
    lse = lse.squeeze(dim=-1).transpose(1, 2)
    return out, lse

@torch.library.custom_op("yunchang::_dist_flash_attn_forward_balanced", mutates_args=(), device_types="cuda")
def dist_flash_attn_forward_balanced(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: float,
    dropout_p: float=0,
    causal: bool=True,
    softcap: float=0.0,
    deterministic: bool=False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load-balanced distributed flash-attention forward, exposed as a custom op.

    The process group, kernel type, ALiBi slopes and window size are not op
    arguments. They are read from the module globals ``process_group``,
    ``attn_type``, ``alibi_slopes`` and ``window_size``, which
    ``dist_flash_attn_forward`` sets right before invoking this op.

    The op runs ``world_size // 2 + 1`` time steps. Every step posts that
    step's communication through ``maybe_send_recv_fwd_qkvo``, runs at most
    one attention block (for the local query, or for a peer's query when this
    rank is helping) and then waits for the posted requests.

    Args:
        q: Local query tensor; must be 4-D.
        k: Local key tensor.
        v: Local value tensor.
        softmax_scale: Scale forwarded to the attention kernel.
        dropout_p: Dropout probability forwarded to the attention kernel.
        causal: Not read by the body. The step-0 block always runs with
            ``causal=True`` and every later block with ``causal=False``.
        softcap: Soft-cap value forwarded to the attention kernel.
        deterministic: Not read by the body.

    Returns:
        ``(out, lse)``: ``out`` cast to ``q.dtype`` and ``lse`` with its
        trailing dimension squeezed and dims 1 and 2 swapped.
    """
    global process_group, attn_type, alibi_slopes, window_size

    comm = RingComm(process_group)
    world_size = comm.world_size

    def attn_compute(q_input, k_input, v_input, causal, out_prev=None, lse_prev=None):
        """Run one attention block and fold it into ``(out_prev, lse_prev)``."""
        fn = select_flash_attn_impl(attn_type, stage="fwd-only")
        block_out, block_lse = fn(
            q_input,
            k_input,
            v_input,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=dropout_p > 0,
        )
        return update_out_and_lse(out_prev, lse_prev, block_out, block_lse)

    # Two *distinct* staging tensors per stream. Communication of a step uses
    # slot `time_step % 2` while compute consumes the other slot, so the two
    # slots must not share storage (a `[t] * 2` list would alias them and let
    # an in-flight receive overwrite the tensor the kernel is reading).
    peer_q = [torch.empty(q.shape, dtype=q.dtype, device=q.device, layout=q.layout) for _ in range(2)]
    peer_k = [torch.empty(k.shape, dtype=k.dtype, device=k.device, layout=k.layout) for _ in range(2)]
    peer_v = [torch.empty(v.shape, dtype=v.dtype, device=v.device, layout=v.layout) for _ in range(2)]
    # LSE staging tensors share q's three leading dims plus a trailing
    # singleton; partial outputs are staged in float32.
    dim0, dim1, dim2, _ = q.shape
    peer_l = [torch.empty([dim0, dim1, dim2, 1], dtype=torch.float32, device=q.device) for _ in range(2)]
    peer_o = [torch.empty(q.shape, dtype=torch.float32, device=q.device, layout=q.layout) for _ in range(2)]

    for time_step in range(world_size // 2 + 1):
        # This is important for cuda scheduler to execute nccl calls first.
        torch.cuda.synchronize()
        # Communication uses buffer_idx_1, and compute uses buffer_idx_2,
        # which effectively holds the contents from the last time step.
        buffer_idx_1 = time_step % 2
        buffer_idx_2 = (time_step - 1) % 2

        reqs = maybe_send_recv_fwd_qkvo(process_group,
                                        q,
                                        peer_q[buffer_idx_1],
                                        k,
                                        peer_k[buffer_idx_1],
                                        v,
                                        peer_v[buffer_idx_1],
                                        [peer_o[buffer_idx_1], peer_l[buffer_idx_1]],
                                        time_step)

        if is_compute_for_local_query(process_group, time_step):
            if time_step == 0:
                # Local (diagonal) block: the only one computed causally.
                out, lse = attn_compute(q, k, v, causal=True)
            else:
                # K/V staged during the previous step, merged into the
                # running accumulators.
                out, lse = attn_compute(q, peer_k[buffer_idx_2], peer_v[buffer_idx_2], causal=False, out_prev=out, lse_prev=lse)
        elif is_idle(process_group, time_step):
            pass
        else:
            # Helping a peer: attend its staged query against the local K/V
            # and keep the partial result in the compute slot.
            peer_o[buffer_idx_2], peer_l[buffer_idx_2] = attn_compute(peer_q[buffer_idx_2], k, v, causal=False)

        # Make sure tensors for next steps are ready
        wait_async_handles(reqs)

        # NOTE(review): merging partial results obtained from other ranks into
        # the local accumulators is currently disabled (kept below for
        # reference). Presumably the tensors staged in slot `buffer_idx_1`
        # should be folded into (out, lse) here - confirm before relying on
        # this op when peers compute part of this rank's attention.
        # if is_sync_from_remote(process_group, time_step):
        #     out, lse = update_out_and_lse(out, lse, peer_o[buffer_idx_1], peer_l[buffer_idx_1])

    out = out.to(q.dtype)
    lse = lse.squeeze(dim=-1).transpose(1, 2)
    return out, lse

@dist_flash_attn_forward_balanced.register_fake
def _(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: float,
    dropout_p: float=0,
    causal: bool=True,
    softcap: float=0.0,
    deterministic: bool=False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fake (meta) implementation of ``dist_flash_attn_forward_balanced``.

    Only the metadata of the outputs matters here. The real op returns ``out``
    cast to ``q.dtype`` and ``lse.squeeze(dim=-1).transpose(1, 2)`` without
    any dtype cast, so the fake ``lse`` must have dims 1 and 2 of ``q``
    swapped and must not be forced to ``q.dtype``.

    Returns:
        ``(out, lse)`` placeholders: ``out`` shaped and typed like ``q``,
        ``lse`` shaped ``[q.shape[0], q.shape[2], q.shape[1]]`` in float32.
    """
    bs, sl, nh, _ = q.shape
    out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    # NOTE(review): float32 assumes the kernel / `update_out_and_lse` keep the
    # log-sum-exp in float32 (the real op stages it in float32 buffers) -
    # confirm against the selected attention implementation.
    lse = torch.empty([bs, nh, sl], dtype=torch.float32, device=q.device)
    return out, lse

def dist_flash_attn_forward(
    group,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = AttnType.FA,
):
    """
    Distributed flash attention with balanced algorithm.

    Publishes ``group``, ``attn_type``, ``alibi_slopes`` and ``window_size``
    as attributes of this module (under the names ``process_group``,
    ``attn_type``, ``alibi_slopes`` and ``window_size``) and then calls the
    ``dist_flash_attn_forward_balanced`` custom op, which reads them back as
    globals. The remaining arguments are passed to the op positionally.

    Returns whatever the custom op returns.
    """
    import sys

    # The op signature does not carry these values, so they travel through
    # module-level attributes instead.
    this_module = sys.modules[__name__]
    shared_state = (
        ("process_group", group),
        ("attn_type", attn_type),
        ("alibi_slopes", alibi_slopes),
        ("window_size", window_size),
    )
    for attr_name, attr_value in shared_state:
        setattr(this_module, attr_name, attr_value)

    op_args = (q, k, v, softmax_scale, dropout_p, causal, softcap, deterministic)
    return dist_flash_attn_forward_balanced(*op_args)
    
def dist_flash_attn_backward_balanced(
    process_group,
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = AttnType.FA,
):
    """
    Balanced DISTFLASHATTN backward implementation.

    Runs ``world_size // 2 + 1`` time steps. Each step posts its
    communication through ``maybe_send_recv_bwd_qkvo``, runs at most one
    backward block (for the local query, or for a peer's query when this rank
    is helping), waits for the requests, and then applies the gradient deltas
    flagged by the helper (``is_update_dq`` / ``is_update_dkv`` /
    ``is_update_last_dkv``).

    The kernel writes straight into its destination: the step-0 block fills
    the ``dq``/``dk``/``dv`` accumulators, every later block fills the delta
    slot of the current step. Accumulators, delta slots and staging slots are
    all distinct tensors, so a later kernel call cannot overwrite accumulated
    gradients or a tensor that is still being sent or received.

    Args:
        process_group: Group used for communication and scheduling queries.
        dout: Gradient of the attention output.
        q, k, v: Local query/key/value tensors from the forward pass.
        out: Forward output.
        softmax_lse: Log-sum-exp saved by the forward pass.
        softmax_scale, dropout_p, window_size, softcap, alibi_slopes,
        deterministic: Forwarded unchanged to the backward kernel.
        causal: Not read by the body. The step-0 block always runs with
            ``causal=True`` and every later block with ``causal=False``.
        attn_type: Passed to ``select_flash_attn_impl(..., stage="bwd-only")``.

    Returns:
        ``(dq, dk, dv)``, each cast to ``q.dtype``.
    """
    comm = RingComm(process_group)
    P = comm.world_size  # Number of processes

    # Gradient accumulators; allocated when the step-0 local block runs.
    dq = None
    dk = None
    dv = None

    def backward(dout, q, k, v, out, softmax_lse, causal, dq_out, dk_out, dv_out):
        """Run one backward block, writing gradients into the ``*_out`` tensors."""
        seqlen_q = q.shape[1]
        seqlen_kv = k.shape[1]
        assert seqlen_q == seqlen_kv, f"seqlen_q: {seqlen_q}, seqlen_kv: {seqlen_kv}"
        fn = select_flash_attn_impl(attn_type, stage="bwd-only")

        # The kernel fills dq_out/dk_out/dv_out in place; its return value is
        # not used.
        fn(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq_out,
            dk_out,
            dv_out,
            dropout_p,
            softmax_scale,
            causal,
            window_size,
            softcap,
            alibi_slopes,
            deterministic,
            rng_state=None,
        )

    def new_pair(like, shape=None):
        """Return two distinct contiguous tensors typed like ``like``."""
        target_shape = like.shape if shape is None else shape
        return [torch.empty(target_shape, dtype=like.dtype, device=like.device) for _ in range(2)]

    # Double-buffered gradient deltas: slot `time_step % 2` is owned by the
    # communication of a step, the other slot by its compute.
    dq_delta = new_pair(q)
    dk_delta = new_pair(k)
    dv_delta = new_pair(v)

    dk_delta_from_peer = torch.empty_like(k)
    dv_delta_from_peer = torch.empty_like(v)

    # Double-buffered staging tensors for data received from peers.
    peer_q = [torch.empty_like(q) for _ in range(2)]
    peer_k = [torch.empty_like(k) for _ in range(2)]
    peer_v = [torch.empty_like(v) for _ in range(2)]
    # Staged LSE must look like the local `softmax_lse`, since that is the
    # tensor handed to the communication helper alongside it.
    peer_l = new_pair(softmax_lse)
    peer_o = [torch.empty_like(out) for _ in range(2)]
    peer_dout = [torch.empty_like(dout) for _ in range(2)]

    # Main loop: time_step runs from 0 to floor(P / 2), same as the forward.
    for time_step in range(P // 2 + 1):
        torch.cuda.synchronize()
        buffer_idx_1 = time_step % 2
        buffer_idx_2 = (time_step - 1) % 2

        reqs, is_update_dq, is_update_dkv = maybe_send_recv_bwd_qkvo(process_group,
                                                                     dq_delta[buffer_idx_1],
                                                                     dk_delta[buffer_idx_1],
                                                                     dv_delta[buffer_idx_1],
                                                                     dk_delta_from_peer,
                                                                     dv_delta_from_peer,
                                                                     q,
                                                                     peer_q[buffer_idx_1],
                                                                     softmax_lse,
                                                                     peer_l[buffer_idx_1],
                                                                     k,
                                                                     peer_k[buffer_idx_1],
                                                                     v,
                                                                     peer_v[buffer_idx_1],
                                                                     out,
                                                                     peer_o[buffer_idx_1],
                                                                     dout,
                                                                     peer_dout[buffer_idx_1],
                                                                     time_step)

        if is_compute_for_local_query(process_group, time_step):
            if time_step == 0:
                # Local (diagonal) block: its gradients seed the accumulators.
                dq = torch.empty(q.shape, dtype=q.dtype, device=q.device)
                dk = torch.empty(k.shape, dtype=k.dtype, device=k.device)
                dv = torch.empty(v.shape, dtype=v.dtype, device=v.device)
                backward(dout, q, k, v, out, softmax_lse, True, dq, dk, dv)
            else:
                # Local query against K/V staged in the previous step. dq is
                # applied locally; the dk/dv deltas stay in the slot for the
                # communication helpers.
                backward(dout, q, peer_k[buffer_idx_2], peer_v[buffer_idx_2], out, softmax_lse, False,
                         dq_delta[buffer_idx_2], dk_delta[buffer_idx_2], dv_delta[buffer_idx_2])
                dq += dq_delta[buffer_idx_2]
        elif is_idle(process_group, time_step):
            pass
        else:
            # Helping a peer: its staged query against the local K/V. dk/dv
            # are applied locally; the dq delta stays in the slot for the
            # communication helpers.
            backward(peer_dout[buffer_idx_2], peer_q[buffer_idx_2], k, v, peer_o[buffer_idx_2], peer_l[buffer_idx_2], False,
                     dq_delta[buffer_idx_2], dk_delta[buffer_idx_2], dv_delta[buffer_idx_2])
            dk += dk_delta[buffer_idx_2]
            dv += dv_delta[buffer_idx_2]

        wait_async_handles(reqs)

        reqs, is_update_last_dkv = maybe_send_recv_bwd_last_dkv(process_group, dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], time_step)

        # apply dq_delta, dk_delta and dv_delta from remote
        if is_update_dq:
            dq += dq_delta[buffer_idx_1]
        if is_update_dkv:
            dk += dk_delta_from_peer
            dv += dv_delta_from_peer

        wait_async_handles(reqs)
        # apply dk_delta and dv_delta to sender
        if is_update_last_dkv:
            dk += dk_delta[buffer_idx_2]
            dv += dv_delta[buffer_idx_2]

    return dq.to(q.dtype), dk.to(q.dtype), dv.to(q.dtype)

def dist_flash_attn_backward(
    process_group,
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = AttnType.FA,
):
    """
    Distributed flash attention backward with balanced algorithm.

    Thin entry point: every argument is handed unchanged to
    ``dist_flash_attn_backward_balanced`` and its result is returned as is.
    """
    # Keyword pass-through; the parameter names of the two functions match
    # one-to-one.
    return dist_flash_attn_backward_balanced(
        process_group=process_group,
        dout=dout,
        q=q,
        k=k,
        v=v,
        out=out,
        softmax_lse=softmax_lse,
        softmax_scale=softmax_scale,
        dropout_p=dropout_p,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        attn_type=attn_type,
    )

class DistFlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper pairing the distributed forward with its backward.

    ``forward`` calls ``dist_flash_attn_forward`` and stashes what the
    backward needs on ``ctx``; ``backward`` calls ``dist_flash_attn_backward``
    and returns gradients for ``q``, ``k`` and ``v`` only.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        group,
        attn_type,
    ):
        # Default scale: 1 / sqrt(last dim of q).
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        assert alibi_slopes is None
        k, v = k.contiguous(), v.contiguous()

        # The forward is always invoked with deterministic=False; the
        # caller's flag is only stored for the backward pass.
        fwd_result = dist_flash_attn_forward(
            group,
            q,
            k,
            v,
            softmax_scale=softmax_scale,
            dropout_p=dropout_p,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            deterministic=False,
            attn_type=attn_type,
        )
        out, softmax_lse = fwd_result

        # Tensors go through save_for_backward, plain values onto ctx.
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.group = group
        ctx.attn_type = attn_type

        if return_softmax:
            return (out, softmax_lse, None)
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        dq, dk, dv = dist_flash_attn_backward(
            ctx.group,
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            softmax_scale=ctx.softmax_scale,
            dropout_p=ctx.dropout_p,
            causal=ctx.causal,
            window_size=ctx.window_size,
            softcap=ctx.softcap,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
            attn_type=ctx.attn_type,
        )
        # One slot per forward input: gradients for q, k, v followed by None
        # for the ten remaining (non-differentiable) arguments.
        return (dq, dk, dv) + (None,) * 10


def dist_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
):
    """Functional entry point for ``DistFlashAttnFunc``.

    Forwards the arguments positionally to ``DistFlashAttnFunc.apply`` in the
    order its ``forward`` expects; ``return_attn_probs`` takes the
    ``return_softmax`` slot. ``attn_processor`` is accepted but not used.
    """
    apply_args = (
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
        attn_type,
    )
    return DistFlashAttnFunc.apply(*apply_args)