from numpy import var
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, BlockMask
from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# Module-level compiled wrapper around flex_attention, created once at import
# time and reused by the helpers in this file. dynamic=False is passed to
# torch.compile; presumably this means each new input shape triggers a fresh
# specialization/compile -- confirm against the torch.compile documentation.
flex_attention_compiled = torch.compile(flex_attention, dynamic=False)


def gen_tf_blockmask(num_denoise_block, device="cuda"):
    """Build the block-level teacher-forcing mask.

    With ``n = num_denoise_block`` the result is a boolean ``(2n, 2n)`` tensor
    assembled from four ``(n, n)`` quadrants:

      * rows ``[0, n)``, cols ``[0, n)``: lower triangle including the diagonal
      * rows ``[0, n)``, cols ``[n, 2n)``: all False
      * rows ``[n, 2n)``, cols ``[0, n)``: strictly lower triangle
      * rows ``[n, 2n)``, cols ``[n, 2n)``: identity

    Args:
        num_denoise_block: number of blocks ``n`` in each half.
        device: device on which the mask is created.

    Returns:
        ``torch.bool`` tensor of shape ``(2n, 2n)``; True marks an allowed
        (query block, key block) pair.
    """
    n = num_denoise_block
    ones = torch.ones(n, n, device=device, dtype=torch.bool)

    top_left = torch.tril(ones)
    top_right = torch.zeros(n, n, device=device, dtype=torch.bool)
    bottom_left = torch.tril(ones, diagonal=-1)
    bottom_right = torch.eye(n, device=device, dtype=torch.bool)

    upper_half = torch.cat((top_left, top_right), dim=1)
    lower_half = torch.cat((bottom_left, bottom_right), dim=1)
    return torch.cat((upper_half, lower_half), dim=0)

def get_naive_causal_block_mask_for_tf_training(seq_len, num_denoise_block, block_size=16 * 16, device="cuda"):
    """
    Dense token-level mask for Teacher-Forcing training.

    The block-level pattern from ``gen_tf_blockmask`` is expanded so that every
    block entry becomes a ``block_size x block_size`` tile, giving a boolean
    ``(seq_len, seq_len)`` tensor.

    Block-level pattern for num_denoise_block = 3:

      clean  |  noisy
    [1, 0, 0, 0, 0, 0]
    [1, 1, 0, 0, 0, 0]
    [1, 1, 1, 0, 0, 0]
    [0, 0, 0, 1, 0, 0]
    [1, 0, 0, 0, 1, 0]
    [1, 1, 0, 0, 0, 1]

    Raises:
        AssertionError: if ``seq_len`` is not a multiple of ``block_size`` or
            does not span exactly ``2 * num_denoise_block`` blocks.
    """
    num_blocks, leftover = divmod(seq_len, block_size)
    assert leftover == 0, "seq_len must be divisible by block_size"
    assert num_blocks == num_denoise_block * 2, "total_blocks must be 2 * num_denoise_block"

    coarse_mask = gen_tf_blockmask(num_denoise_block, device=device)
    # Blow every block entry up into a tile: rows first, then columns.
    row_expanded = coarse_mask.repeat_interleave(block_size, dim=-2)
    return row_expanded.repeat_interleave(block_size, dim=-1)


def BLHD_sdpa_with_mask(q, k, v, attention_mask=None):
    """Run PyTorch scaled-dot-product attention on ``(b, l, h, d)`` tensors.

    Inputs are rearranged to ``(b, h, l, d)`` for
    ``torch.nn.functional.scaled_dot_product_attention`` and the output is
    rearranged back to ``(b, l, h, d)``.

    Args:
        q, k, v: tensors laid out as ``(b, l, h, d)``.
        attention_mask: optional mask forwarded as ``attn_mask``. ``None``
            (the default) runs unmasked attention.

    Returns:
        Attention output laid out as ``(b, l, h, d)``.
    """
    # Only a real mask can be oversized; the default None must not be
    # dereferenced (it previously raised AttributeError on .numel()).
    # Masks above 2147483647 elements are routed to the chunked variant
    # (presumably an int32 indexing limit in the masked SDPA path -- confirm).
    if attention_mask is not None and attention_mask.numel() > 2147483647:
        return BLHD_sdpa_with_mask_chunk(q, k, v, attention_mask=attention_mask)
    q, k, v = (rearrange(x, "b l h d -> b h l d") for x in (q, k, v))
    o = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
    return rearrange(o, "b h l d -> b l h d")

def BLHD_sdpa_with_mask_chunk(q, k, v, attention_mask=None):
    """Masked SDPA on ``(b, l, h, d)`` tensors, processed in query-row chunks.

    The query sequence and the matching rows of ``attention_mask`` are split
    into chunks so that no single mask slice exceeds 2147483647 elements;
    each chunk attends to the full ``k``/``v`` and the chunk outputs are
    concatenated along the sequence axis.

    By default the query axis is cut into 8 pieces (rows per chunk =
    ``ceil(l / 8)``, the same split ``torch.chunk(..., 8)`` yields). When that
    would still leave a mask slice above the limit, the rows per chunk are
    reduced until each slice fits.

    Args:
        q, k, v: tensors laid out as ``(b, l, h, d)``.
        attention_mask: mask whose second-to-last axis is the query axis
            (for a 2-D ``(l_q, l_kv)`` mask this is axis 0). ``None`` runs
            plain unmasked attention, since there is nothing to chunk.

    Returns:
        Attention output laid out as ``(b, l, h, d)``.
    """
    max_mask_numel = 2147483647
    default_num_chunks = 8

    q, k, v = (rearrange(x, "b l h d -> b h l d") for x in (q, k, v))
    if attention_mask is None:
        o = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        return rearrange(o, "b h l d -> b l h d")

    q_len = q.shape[2]
    # ceil(q_len / default_num_chunks), at least one row per chunk.
    rows_per_chunk = max(1, -(-q_len // default_num_chunks))
    # Mask elements consumed by one query row (includes any leading
    # batch/head axes of the mask).
    mask_q_rows = max(1, attention_mask.shape[-2])
    numel_per_row = max(1, attention_mask.numel() // mask_q_rows)
    # Shrink the chunk further if the default split is still too large.
    rows_per_chunk = min(rows_per_chunk, max(1, max_mask_numel // numel_per_row))

    q_chunks = torch.split(q, rows_per_chunk, dim=2)
    mask_chunks = torch.split(attention_mask, rows_per_chunk, dim=-2)
    outputs = []
    for subq, mask in zip(q_chunks, mask_chunks):
        # Can only fire when a single mask row alone exceeds the limit.
        assert mask.numel() <= 2147483647, "mask.numel() must be less than 2147483647"
        outputs.append(torch.nn.functional.scaled_dot_product_attention(subq, k, v, attn_mask=mask))
    o = torch.cat(outputs, dim=2)
    return rearrange(o, "b h l d -> b l h d")

def get_flex_causal_block_mask_for_tf_training(seq_len, num_denoise_block, block_size=16 * 16, device="cuda"):
    """Build the flex-attention ``BlockMask`` for teacher-forcing training.

    Token indices are reduced to block indices by integer division with
    ``block_size``. A query block may attend to a key/value block when:

      * both are the same block; or
      * the query block index is below ``num_denoise_block`` and is not
        smaller than the key block index; or
      * the query block index is at least ``num_denoise_block`` and
        ``query_block - num_denoise_block`` is strictly greater than the key
        block index.

    Args:
        seq_len: used for both ``Q_LEN`` and ``KV_LEN``.
        num_denoise_block: block-index threshold separating the two rules.
        block_size: tokens per block.
        device: device passed to ``create_block_mask``.

    Returns:
        The object returned by ``create_block_mask`` (batch and head left as
        ``None``).
    """
    # NOTE(review): the original carried the comment "do not call mask mod at
    # all" at this spot; its intent is not evident from this function --
    # confirm with the author.
    def tf_mask_mod(b, h, q_idx, kv_idx):
        q_blk = q_idx // block_size
        kv_blk = kv_idx // block_size
        first_half_rule = q_blk >= kv_blk
        second_half_rule = q_blk - num_denoise_block > kv_blk
        off_diagonal = torch.where(q_blk < num_denoise_block, first_half_rule, second_half_rule)
        return torch.where(q_blk == kv_blk, True, off_diagonal)

    return create_block_mask(
        tf_mask_mod, None, None, Q_LEN=seq_len, KV_LEN=seq_len, device=device, _compile=True
    )

def get_flex_causal_block_mask_for_prefill(num_denoise_block, block_size=16 * 16, device="cuda"):
    """Build a block-causal flex-attention ``BlockMask`` for prefill.

    The sequence spans ``num_denoise_block * block_size`` tokens. A query token
    may attend to a key/value token when the query's block index
    (``idx // block_size``) is greater than or equal to the key's block index.

    Args:
        num_denoise_block: number of blocks in the sequence.
        block_size: tokens per block.
        device: device passed to ``create_block_mask``.

    Returns:
        The object returned by ``create_block_mask`` (batch and head left as
        ``None``).
    """
    total_len = num_denoise_block * block_size

    def prefill_mask_mod(b, h, q_idx, kv_idx):
        q_blk = q_idx // block_size
        kv_blk = kv_idx // block_size
        return q_blk >= kv_blk

    return create_block_mask(
        prefill_mask_mod, None, None, Q_LEN=total_len, KV_LEN=total_len, device=device, _compile=True
    )


def get_flex_block_mask_chunk_prefill(num_qblock, num_kvblock, block_size=16 * 16, device="cuda"):
    """
    Build the flex-attention ``BlockMask`` for chunked prefill.

    The query side has ``num_qblock`` blocks; the key/value side has
    ``num_kvblock + num_qblock`` blocks. Every query block sees the first
    ``num_kvblock`` key blocks; for the remaining key blocks, query block
    ``i`` sees key block ``j`` when ``i + num_kvblock >= j``.

    Block-level pattern for num_qblock = 2, num_kvblock = 4:
    [1, 1, 1, 1, 1, 0]
    [1, 1, 1, 1, 1, 1]
    """
    q_len = num_qblock * block_size
    kv_len = (num_kvblock + num_qblock) * block_size

    def chunk_prefill_mask_mod(b, h, q_idx, kv_idx):
        q_blk = q_idx // block_size
        kv_blk = kv_idx // block_size
        causal_in_new_blocks = q_blk + num_kvblock >= kv_blk
        return torch.where(kv_blk < num_kvblock, True, causal_in_new_blocks)

    return create_block_mask(
        chunk_prefill_mask_mod, None, None, Q_LEN=q_len, KV_LEN=kv_len, device=device, _compile=True
    )

def BLHD_flex_with_mask(q, k, v, attention_mask=None):
    """Run the compiled flex attention on ``(b, l, h, d)`` tensors.

    Each input is rearranged to a contiguous ``(b, h, l, d)`` tensor, passed
    to ``flex_attention_compiled`` with ``attention_mask`` as its
    ``block_mask``, and the result is rearranged back to ``(b, l, h, d)``.
    """
    q_bhld, k_bhld, v_bhld = (
        rearrange(t, "b l h d -> b h l d").contiguous() for t in (q, k, v)
    )
    out = flex_attention_compiled(q_bhld, k_bhld, v_bhld, block_mask=attention_mask)
    return rearrange(out, "b h l d -> b l h d")

def get_varlen_causal_block_mask_for_tf_training(seq_len, num_denoise_block, block_size=16 * 16, device="cuda"):
    """Package the teacher-forcing block mask for ``BLHD_varlen_with_mask``.

    Args:
        seq_len: accepted for signature parity with the other mask builders;
            not used here.
        num_denoise_block: forwarded to ``gen_tf_blockmask``.
        block_size: tokens per block, stored alongside the mask.
        device: device on which the block mask is created.

    Returns:
        dict with keys ``'block_mask'`` (the boolean block-level mask) and
        ``'block_size'``.
    """
    # Forward the requested device; previously it was dropped, so the mask was
    # always created on gen_tf_blockmask's "cuda" default.
    block_mask = gen_tf_blockmask(num_denoise_block, device=device)
    return {
        'block_mask': block_mask,
        'block_size': block_size,
    }


def BLHD_varlen_with_mask(q, k, v, attention_mask):
    """Block-masked attention on ``(b, l, h, d)`` tensors via flash-attn varlen.

    Every query block of every sample becomes one variable-length "sequence"
    of ``block_size`` tokens. Its keys/values are the concatenation of the
    key/value blocks the block mask marks as visible for that query block, so
    each query block attends to exactly its allowed blocks.

    Args:
        q, k, v: tensors laid out as ``(b, l, h, d)`` where ``l`` must equal
            ``block_mask.shape[0] * block_size``.
        attention_mask: dict with ``'block_mask'`` (2-D boolean block-level
            mask, rows = query blocks, columns = key blocks) and
            ``'block_size'`` (tokens per block).

    Returns:
        Attention output laid out as ``(b, l, h, d)``.

    Raises:
        ValueError: if the sequence length of ``q`` does not match the mask.
    """
    block_size = attention_mask['block_size']
    block_mask = attention_mask['block_mask']
    device = q.device
    B, seq_len = q.shape[0], q.shape[1]

    # Pull the whole mask to the host once instead of one .item() call per
    # (sample, query block, key block) triple.
    mask_rows = block_mask.tolist()
    tokens_per_sample = len(mask_rows) * block_size
    if seq_len != tokens_per_sample:
        # The flattened offsets below assume exactly this many tokens per sample.
        raise ValueError(
            f"sequence length {seq_len} does not match block mask "
            f"({len(mask_rows)} blocks x block_size {block_size} = {tokens_per_sample})"
        )

    # Visible key-block indices per query block; identical for every sample.
    visible_kv_blocks = [
        [kjj for kjj, keep in enumerate(row) if keep] for row in mask_rows
    ]
    max_len = max((len(ids) for ids in visible_kv_blocks), default=0) * block_size

    q, k, v = (rearrange(x, "b l h d -> (b l) h d") for x in (q, k, v))

    q_cu_seq_lens = [0]
    k_cu_seq_lens = [0]
    k_segs = []
    v_segs = []
    for b in range(B):
        base = b * tokens_per_sample
        for qii, kv_ids in enumerate(visible_kv_blocks):
            q_cu_seq_lens.append(base + (qii + 1) * block_size)
            k_cu_seq_lens.append(k_cu_seq_lens[-1] + len(kv_ids) * block_size)
            for kjj in kv_ids:
                start = base + kjj * block_size
                k_segs.append(k[start:start + block_size])
                v_segs.append(v[start:start + block_size])

    k = torch.cat(k_segs, dim=0)
    v = torch.cat(v_segs, dim=0)
    # Positional arguments after v: cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
    # max_seqlen_k.
    o = flash_attn_varlen_func(
        q, k, v,
        torch.tensor(q_cu_seq_lens, dtype=torch.int, device=device),
        torch.tensor(k_cu_seq_lens, dtype=torch.int, device=device),
        block_size, max_len,
    )
    o = rearrange(o, "(b l) h d -> b l h d", b=B)
    return o


def test_correctness():
    """
    Compare BLHD_flex_with_mask and BLHD_varlen_with_mask against
    BLHD_sdpa_with_mask (used as ground truth) on the teacher-forcing mask.

    For block sizes 2**5 .. 2**10 the same random q/k/v are fed to all three
    implementations. Forward outputs and the q/k/v gradients of ``out.sum()``
    are compared, and the mean / max absolute difference plus the norm of the
    difference relative to the reference are printed. Nothing is asserted;
    the numbers are meant for manual inspection.
    """
    def report_forward(label, reference, candidate):
        # Error statistics of `candidate` relative to the `reference` output.
        diff = reference - candidate
        mean_diff = diff.abs().mean().item()
        max_diff = diff.abs().max().item()
        norm_ratio = (diff.norm() / reference.norm()).item()

        print(f"Forward pass ({label}):")
        print(f"  Mean diff: {mean_diff:.4f}")
        print(f"  Max diff:  {max_diff:.4f}")
        print(f"  Norm ratio: {norm_ratio:.4f}")

    def report_backward(label, reference_inputs, candidate_inputs):
        # Compare the .grad of each (q, k, v) input pair; pairs with a missing
        # gradient are skipped.
        print(f"Backward pass ({label}):")
        for name, ref, cand in zip("qkv", reference_inputs, candidate_inputs):
            ref_grad, cand_grad = ref.grad, cand.grad
            if ref_grad is None or cand_grad is None:
                continue
            grad_diff = cand_grad - ref_grad
            mean_diff = grad_diff.abs().mean().item()
            max_diff = grad_diff.abs().max().item()
            norm_ratio = (grad_diff.norm() / ref_grad.norm()).item()
            print(f"  {name}_grad - Mean diff: {mean_diff:.5f}, "
                  f"Max diff: {max_diff:.5f}, "
                  f"Norm ratio: {norm_ratio:.5f}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_denoise_block = 4
    head_dim = 64
    num_heads = 2
    batch_size = 2

    print("Testing BLHD_flex_with_mask correctness...")
    print("=" * 60)

    # Test different block sizes from 2**5 to 2**10
    for block_size_exp in range(5, 11):
        block_size = 2 ** block_size_exp
        seq_len = block_size * num_denoise_block * 2

        print(f"\nBlock size: {block_size}, Seq len: {seq_len}")
        print("-" * 40)

        # Create test data
        torch.manual_seed(42)  # For reproducibility
        q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, requires_grad=True, dtype=torch.float16)
        k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, requires_grad=True, dtype=torch.float16)
        v = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, requires_grad=True, dtype=torch.float16)

        # Independent leaf copies so each implementation accumulates its own grads.
        q_flex, k_flex, v_flex = (t.clone().detach().requires_grad_(True) for t in (q, k, v))
        q_varlen, k_varlen, v_varlen = (t.clone().detach().requires_grad_(True) for t in (q, k, v))

        # Generate attention masks
        naive_mask = get_naive_causal_block_mask_for_tf_training(
            seq_len, num_denoise_block, block_size=block_size, device=device
        )
        flex_mask = get_flex_causal_block_mask_for_tf_training(
            seq_len, num_denoise_block, block_size=block_size, device=device
        )
        varlen_mask = get_varlen_causal_block_mask_for_tf_training(
            seq_len, num_denoise_block, block_size=block_size, device=device
        )

        # Forward pass
        out_sdpa = BLHD_sdpa_with_mask(q, k, v, attention_mask=naive_mask)
        out_flex = BLHD_flex_with_mask(q_flex, k_flex, v_flex, attention_mask=flex_mask)
        out_varlen = BLHD_varlen_with_mask(q_varlen, k_varlen, v_varlen, attention_mask=varlen_mask)

        report_forward("sdpa vs flex", out_sdpa, out_flex)
        report_forward("sdpa vs varlen", out_sdpa, out_varlen)

        # Backward pass test: a dummy scalar loss per implementation.
        out_sdpa.sum().backward()
        out_flex.sum().backward()
        out_varlen.sum().backward()

        report_backward("sdpa vs flex", (q, k, v), (q_flex, k_flex, v_flex))
        report_backward("sdpa vs varlen", (q, k, v), (q_varlen, k_varlen, v_varlen))

    print("\n" + "=" * 60)
    print("Test completed!")


def benchmark_speed():
    """Benchmark the attention back-ends on the teacher-forcing mask.

    For every combination of head dim, head count, batch size, block size and
    number of denoise blocks, each method (dense-mask SDPA, flex attention,
    unmasked flash attention as a baseline, flash-attn varlen) is timed with
    ``flash_attn.utils.benchmark`` and one result line is printed. A
    ``RuntimeError`` raised while timing a configuration is reported and the
    sweep continues.

    The reported throughput uses the dense (unmasked, non-causal) FLOP count
    for every method, so it is a dense-equivalent figure rather than a count
    of the work actually done under the sparse mask.
    """
    import math
    import itertools
    from tqdm import tqdm
    # Imported once here so the import lookup is not part of the timed call.
    from flash_attn.flash_attn_interface import flash_attn_func
    from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined

    def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
        assert mode in ["fwd", "bwd", "fwd_bwd"]
        f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
        return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)

    def efficiency(flop, time):
        # `time` is in seconds; the result is in units of 10**12 FLOP per second.
        return (flop / time / 10**12) if not math.isnan(time) else 0.0

    head_dim_ls = [64, 128]
    num_head_ls = [16]
    batch_size_ls = [1]
    block_size_ls = [1024, 1840]  # Different block sizes to test
    num_denoise_block_ls = [8, 16]  # Different numbers of denoise blocks
    device = 'cuda'
    mode = 'fwd'
    device_name = torch.cuda.get_device_name(device)
    print(f"Benchmarking on {device_name} (mode={mode})")

    # Test methods - wrapper functions to match the benchmark interface
    def sdpa_method(q, k, v, **kwargs):
        attention_mask = kwargs.get('attention_mask', None)
        return BLHD_sdpa_with_mask(q, k, v, attention_mask=attention_mask)

    def flex_method(q, k, v, **kwargs):
        attention_mask = kwargs.get('attention_mask', None)
        return BLHD_flex_with_mask(q, k, v, attention_mask=attention_mask)

    def flash_attention_method(q, k, v, **kwargs):
        # Unmasked baseline: any attention_mask passed in is ignored.
        return flash_attn_func(q, k, v)

    def varlen_method(q, k, v, **kwargs):
        attention_mask = kwargs.get('attention_mask', None)
        return BLHD_varlen_with_mask(q, k, v, attention_mask=attention_mask)

    method_ls = [sdpa_method, flex_method, flash_attention_method, varlen_method]

    # Mask builder per method. flash_attention_method has no entry because it
    # ignores the mask, so building (and compiling) one for it is wasted work.
    mask_builders = {
        sdpa_method: get_naive_causal_block_mask_for_tf_training,
        flex_method: get_flex_causal_block_mask_for_tf_training,
        varlen_method: get_varlen_causal_block_mask_for_tf_training,
    }

    # Benchmark routine and description per mode.
    runners = {
        'fwd': (benchmark_forward, 'forward'),
        'bwd': (benchmark_backward, 'backward'),
        'fwd_bwd': (benchmark_combined, 'combine'),
    }
    run_benchmark, desc = runners[mode]
    # Backward benchmarks need inputs that track gradients.
    needs_grad = mode != 'fwd'

    for headdim, head, batch, block_size, num_denoise_block, method in tqdm(list(itertools.product(
        head_dim_ls, num_head_ls, batch_size_ls, block_size_ls, num_denoise_block_ls, method_ls))):
        torch.cuda.empty_cache()

        # Calculate seq_len based on block_size and num_denoise_block
        seq_len = block_size * num_denoise_block * 2
        # Input tensors in (batch, seq_len, heads, head_dim) layout, as the
        # BLHD_* wrappers and flash_attn_func are called with here.
        q = torch.randn(batch, seq_len, head, headdim, dtype=torch.float16, device=device, requires_grad=needs_grad)
        k = torch.rand_like(q, requires_grad=needs_grad)
        v = torch.rand_like(q, requires_grad=needs_grad)

        # Generate appropriate attention mask
        build_mask = mask_builders.get(method)
        attention_mask = None
        if build_mask is not None:
            attention_mask = build_mask(seq_len, num_denoise_block, block_size=block_size, device=device)

        try:
            # Warmup
            run_benchmark(method, q, k, v, attention_mask=attention_mask, repeats=3, verbose=False, desc=desc)
            torch.cuda.synchronize()
            # Actual measurement
            _, measurement = run_benchmark(method, q, k, v, attention_mask=attention_mask, repeats=6, verbose=False, desc=desc)

            # torch.utils.benchmark reports the mean in seconds.
            elapsed_s = measurement.mean
            flop_count = flops(batch, seq_len, headdim, head, False, mode)
            efficiency_score = efficiency(flop_count, elapsed_s)
            print(f"{method.__name__}: batch={batch}, heads={head}, headdim={headdim}, seq_len={seq_len}, block_size={block_size}, num_denoise={num_denoise_block}, time={elapsed_s * 1e3:.4f}ms, efficiency={efficiency_score:.2f}TFLOPs")

        except RuntimeError as e:
            print(f"Error with {method.__name__}: {e}")
    

if __name__ == "__main__":
    # Script entry point: run the speed benchmark, open an interactive IPython
    # shell, then run one large-sequence forward pass through both flex
    # attention and dense-mask SDPA.
    # test_correctness()
    benchmark_speed()

    # large test
    # The statements below run only after the embedded IPython shell is exited.
    import IPython; IPython.embed()
    block_size = 1280
    num_denoise_block = 24
    seq_len = block_size * num_denoise_block * 2  # 61440 tokens
    batch_size = 1
    head_dim = 128
    num_heads = 4
    device = 'cuda'
    # Created directly as (batch, heads, seq_len, head_dim): the attention
    # functions are called below without the BLHD_* rearranging wrappers.
    q = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=torch.float16, device=device)
    k, v = torch.randn_like(q), torch.randn_like(q)
    attention_mask = get_flex_causal_block_mask_for_tf_training(seq_len, num_denoise_block, block_size=block_size, device=device)
    o = flex_attention_compiled(q, k, v, block_mask=attention_mask)
    # The dense mask has seq_len**2 = 3,774,873,600 elements, which is above
    # the 2147483647 threshold at which BLHD_sdpa_with_mask switches to its
    # chunked path.
    # NOTE(review): the direct SDPA call below bypasses that chunking --
    # confirm whether this is an intentional reproduction of the large-mask case.
    attention_mask_native = get_naive_causal_block_mask_for_tf_training(seq_len, num_denoise_block, block_size=block_size, device=device)
    o_native = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask_native)
