import torch
import os
import torch.nn as nn
import torch.nn.functional as F

from .utils import RingComm, update_out_and_lse
from yunchang.kernels import AttnType, select_flash_attn_impl
from yunchang.globals import U_HANDLE, clear_u_handle, clear_o_handle, PROCESS_GROUP
from yunchang.comm.all_to_all import vanilla_all_to_all_4D as all_to_all_4D, SeqAllToAll4D

import logging
logger = logging.getLogger(__name__)


from typing import List, Tuple, Dict, Any


# Names of the module-level settings that functions in this file read as
# globals (fully_fused_ring_flash_attn_forward reads process_group, attn_type,
# alibi_slopes and window_size).
# NOTE(review): a `global` statement at module scope has no effect and binds
# nothing; none of these names is assigned in the visible part of the file, so
# they presumably get set elsewhere before the functions run — confirm where.
global process_group, attn_type, alibi_slopes, window_size, next_ulysses_qkv, a2a_available, ulysses_group, two_streams

import torch.distributed as dist

def check_nan_inf(tensor, name, rank=0):
    """Report whether ``tensor`` contains any NaN or Inf entries.

    NaN is tested first; Inf is only tested when no NaN is present. When a
    problem is found, two diagnostic lines (shape and number of offending
    elements) are printed to stdout, prefixed with ``rank``.

    Args:
        tensor: Tensor to inspect.
        name: Label used in the printed diagnostics.
        rank: Rank number shown in the diagnostics prefix.

    Returns:
        True if a NaN or an Inf was found, False otherwise.
    """
    nan_mask = torch.isnan(tensor)
    if nan_mask.any():
        print(f"[RANK {rank}] NaN detected in {name}, shape: {tensor.shape}")
        print(f"[RANK {rank}] NaN locations: {nan_mask.sum().item()}")
        return True

    inf_mask = torch.isinf(tensor)
    if not inf_mask.any():
        return False

    print(f"[RANK {rank}] Inf detected in {name}, shape: {tensor.shape}")
    print(f"[RANK {rank}] Inf locations: {inf_mask.sum().item()}")
    return True

# @torch.library.custom_op("yunchang::_fully_fused_ring_flash_attn_forward", mutates_args=(), device_types="cuda")
# @torch.no_grad()
def fully_fused_ring_flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: float,
    dropout_p: float = 0,
    causal: bool = True,
    softcap: float = 0.0,
    deterministic: bool=False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the causal ring-attention forward pass over the module-level process group.

    The local ``k``/``v`` block is passed around the ring; at every step the
    block currently held is attended to and the partial result is merged into
    the running ``out``/``lse`` via ``update_out_and_lse``.

    Reads the module globals ``process_group``, ``attn_type``,
    ``alibi_slopes`` and ``window_size`` (they are not parameters).

    Args:
        q, k, v: Local query/key/value tensors; dim 1 is split in half, so it
            is presumably the sequence dimension — confirm against the caller.
        softmax_scale: Scale forwarded to the attention kernel.
        dropout_p: Dropout probability forwarded to the kernel.
        causal: Must be True; any other value trips the assertion.
        softcap: Soft-cap value forwarded to the kernel.
        deterministic: Accepted for interface compatibility; not used here.

    Returns:
        ``(out, lse)`` where ``out`` is cast to ``q.dtype`` and ``lse`` has its
        last dimension squeezed and dims 1 and 2 swapped.
    """
    assert causal == True, "zigzag ring is meaningless for causal=False"

    ring = RingComm(process_group)
    world_size = ring.world_size

    half_len = q.shape[1] // 2
    q_back_half = q[:, half_len:]

    def attend(q_blk, k_blk, v_blk, is_causal):
        # The kernel is looked up on every call, exactly as the block-wise
        # schedule below requires one invocation per ring step.
        kernel = select_flash_attn_impl(attn_type, stage="fwd-only")
        blk_out, blk_lse = kernel(
            q_blk,
            k_blk,
            v_blk,
            dropout_p,
            softmax_scale,
            causal=is_causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=dropout_p > 0,
        )
        return blk_out, blk_lse

    out, lse = None, None
    incoming_k, incoming_v = None, None

    for step in range(world_size):
        has_next = step + 1 != world_size

        if has_next:
            # Post the exchange of the current k/v before computing on them
            # (presumably non-blocking until wait() — see RingComm).
            incoming_k = ring.send_recv(k)
            incoming_v = ring.send_recv(v)
            ring.commit()

        merge_kwargs = {}
        if step == 0:
            # Local block: ordinary causal attention.
            block_out, block_lse = attend(q, k, v, is_causal=True)
        elif step <= ring.rank:
            # All queries against the first half of the received k/v.
            block_out, block_lse = attend(
                q, k[:, :half_len], v[:, :half_len], is_causal=False
            )
        else:
            # Only the second half of the queries against the full received
            # k/v; the merge is restricted to the matching slice of out/lse.
            block_out, block_lse = attend(q_back_half, k, v, is_causal=False)
            merge_kwargs["slice_"] = (slice(None), slice(half_len, None))

        out, lse = update_out_and_lse(out, lse, block_out, block_lse, **merge_kwargs)

        if has_next:
            ring.wait()
            k, v = incoming_k, incoming_v

    return out.to(q.dtype), lse.squeeze(dim=-1).transpose(1, 2)

# @fully_fused_ring_flash_attn_forward.register_fake
# def _(
#     q: torch.Tensor,
#     k: torch.Tensor,
#     v: torch.Tensor,
#     softmax_scale: float,
#     dropout_p: float = 0,
#     causal: bool = True,
#     softcap: float = 0.0,
#     deterministic: bool=False,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
#     bs, sl, nh, d = q.shape
#     out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
#     lse = torch.empty([bs, nh, sl], dtype=q.dtype, device=q.device)
#     return out, lse

# @torch.no_grad()
def fully_fused_ring_flash_attn_backward(
    process_group,
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = AttnType.FA,
):
    """Compute dq, dk and dv for causal ring attention across ``process_group``.

    Two ``RingComm`` instances are used: one passes the k/v blocks around the
    ring, the other passes the partially accumulated dk/dv. The per-step
    schedule is: step 0 runs causal attention backward on the local block;
    steps ``1..rank`` use all queries against the first half of the received
    k/v; later steps use the second half of the queries against the full
    received k/v.

    Args:
        process_group: Process group handed to both ``RingComm`` instances.
        dout: Gradient of the attention output.
        q, k, v: Local query/key/value tensors; dim 1 is split in half
            (presumably the sequence dimension — confirm against the caller).
        out: Forward output matching ``q``.
        softmax_lse: Forward log-sum-exp; split in half along dim 2.
        softmax_scale, dropout_p, window_size, softcap, alibi_slopes,
        deterministic: Forwarded unchanged to the backward kernel.
        causal: Must be True; any other value trips the assertion.
        attn_type: Selects the kernel via ``select_flash_attn_impl``.

    Returns:
        ``(dq, dk, dv)``. With a ring of size 1 the kernel's output buffers are
        returned directly; otherwise the accumulated gradients are cast to
        ``q``'s dtype and detached.
    """
    assert causal == True, "upipe ring is meaningless for causal=False"
    kv_comm = RingComm(process_group)
    d_kv_comm = RingComm(process_group)
    dq, dk, dv = None, None, None
    next_dk, next_dv = None, None
    next_k, next_v = None, None
    dk_comm_buffer, dv_comm_buffer = None, None

    # Second halves (along dim 1, dim 2 for the lse) used by the steps that
    # only involve the back half of the queries.
    dout1 = dout.chunk(2, dim=1)[1]
    q1 = q.chunk(2, dim=1)[1]
    out1 = out.chunk(2, dim=1)[1]
    softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous()
    block_seq_len = q.shape[1] // 2

    # Scratch buffers the kernel writes into on every step; allocated once
    # because repeatedly allocating them may be slow.
    # if k.shape[2] == q.shape[2]:
    dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
    dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
    # else:
    #     dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    #     dk_buffer = torch.empty((k.shape[0], k.shape[1], q.shape[2], k.shape[3]), dtype=k.dtype, device=k.device)
    #     dv_buffer = torch.empty((v.shape[0], v.shape[1], q.shape[2], v.shape[3]), dtype=v.dtype, device=v.device)

    def backward(dout, q, k, v, out, softmax_lse, causal):
        # Run the backward kernel for one block pair; results land in the
        # leading slices of the shared scratch buffers.
        seqlen_q = q.shape[1]
        seqlen_kv = k.shape[1]
        fn = select_flash_attn_impl(attn_type, stage="bwd-only")

        fn(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq_buffer[:, :seqlen_q],
            dk_buffer[:, :seqlen_kv],
            dv_buffer[:, :seqlen_kv],
            dropout_p,
            softmax_scale,
            causal,
            window_size,
            softcap,
            alibi_slopes,
            deterministic,
            rng_state=None,
        )

    for step in range(kv_comm.world_size):
        # Debug aid: check_nan_inf(...) can be asserted on dout/q/k/v/out/
        # softmax_lse here, and on dq/dk/dv after each accumulation below.

        if step + 1 != kv_comm.world_size:
            # Start passing the current k/v to the neighbour before computing.
            next_k = kv_comm.send_recv(k)
            next_v = kv_comm.send_recv(v)
            kv_comm.commit()

        if step == 0:
            backward(dout.contiguous(), q.contiguous(), k.contiguous(), v.contiguous(), out.contiguous(), softmax_lse.contiguous(), causal=True)
            if kv_comm.world_size == 1:
                # Single rank: nothing to accumulate or exchange.
                return dq_buffer, dk_buffer, dv_buffer
            # The accumulators must NOT alias the scratch buffers: every later
            # backward() call overwrites the buffers, and dk/dv are handed to
            # the ring as send payloads / receive buffers. Binding them to the
            # buffers directly would discard earlier contributions and let the
            # kernel write into tensors that are still in flight.
            dq = dq_buffer.clone()
            dk = dk_buffer.clone()
            dv = dv_buffer.clone()
        else:
            if step <= kv_comm.rank:
                k0 = k[:, :block_seq_len]
                v0 = v[:, :block_seq_len]
                backward(dout, q, k0, v0, out, softmax_lse, causal=False)
                dq += dq_buffer
            else:
                backward(dout1, q1, k, v, out1, softmax_lse1, causal=False)
                # always use the first half in dq_buffer.
                dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len]

            # Wait for the dk/dv exchange posted at the end of the previous
            # step; the tensors just sent become the next receive buffers.
            d_kv_comm.wait()
            dk_comm_buffer, dv_comm_buffer = dk, dv
            dk, dv = next_dk, next_dv

            if step <= kv_comm.rank:
                dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len]
                dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len]
            else:
                dk += dk_buffer
                dv += dv_buffer

        if step + 1 != kv_comm.world_size:
            kv_comm.wait()
            k = next_k
            v = next_v

        # Pass the accumulated dk/dv on around the ring (also after the last
        # step; the result of that exchange is what gets returned).
        next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer)
        next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer)
        d_kv_comm.commit()

    d_kv_comm.wait()

    orig_q_dtype = q.dtype

    # Drop local references to the large tensors before building the results.
    del dout, q, k, v, out, softmax_lse, dq_buffer, dk_buffer, dv_buffer, dk_comm_buffer, dv_comm_buffer, next_k, next_v

    # if k.shape[2] != q.shape[2]:
    #     bs, slen, nqh, hdim = q.shape
    #     return dq.to(q.dtype), next_dk.view(bs, slen, k.shape[2], (q.shape[2]//k.shape[2]), hdim).sum(dim=3).to(q.dtype), next_dv.view(bs, slen, v.shape[2], (q.shape[2]//v.shape[2]), hdim).sum(dim=3).to(q.dtype)
    # else:
    return dq.to(orig_q_dtype).detach(), next_dk.to(orig_q_dtype).detach(), next_dv.to(orig_q_dtype).detach()

from torch import Tensor
# from yunchang.ring.upipe_ring_flash_attn import upipe_ring_flash_attn_forward, upipe_ring_flash_attn_backward 

from deepspeed.sequence.fpdt_layer import FPDT_Attention, FPDT_FFN, FPDT_InputConstruct, FPDT_LogitsLoss

def fully_fused_attn_func(
        x,
        wq,
        wk,
        wv,
        freqs_cis,
        head_dim,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
        window_size=(-1, -1),
        softcap=0.0,
        alibi_slopes=None,
        deterministic=False,
        return_attn_probs=False,
        ring_group=None,
        ulysses_group=None,
        offload_stream=None,
        fetch_stream=None,
        two_streams=None,
        attn_type: AttnType = AttnType.FA,
        attn_processor=None,
        use_pack_qkv=True,
        dualstage=False,
        layer_id=None,
        fused_attn_type="fully_fused",
        w_rms=None,
        eps_rms=None,
):
    """Dispatch to one of the fused attention autograd Functions.

    The choice is made from substrings of ``fused_attn_type``, checked in this
    order:

    * contains ``"mha"``          -> ``FullyFusedAttnFunc``
    * contains ``"ultra_fused"``  -> ``UltraFusedAttnGQAFunc`` (additionally
      receives ``w_rms`` and ``eps_rms`` right after ``x``)
    * anything else               -> ``FullyFusedAttnGQAFunc``

    All arguments are forwarded positionally. ``attn_processor``,
    ``use_pack_qkv`` and ``dualstage`` are accepted but not forwarded to any of
    the Functions.

    Returns:
        Whatever the selected Function's ``apply`` returns.
    """
    # Positional tail that is identical for all three Functions.
    shared_args = (
        wq,
        wk,
        wv,
        freqs_cis,
        head_dim,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        ring_group,
        ulysses_group,
        offload_stream,
        fetch_stream,
        two_streams,
        attn_type,
        layer_id,
    )

    if "mha" in fused_attn_type:
        return FullyFusedAttnFunc.apply(x, *shared_args)

    if "ultra_fused" in fused_attn_type:
        return UltraFusedAttnGQAFunc.apply(x, w_rms, eps_rms, *shared_args)

    return FullyFusedAttnGQAFunc.apply(x, *shared_args)
