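"""
Memory-efficient attention alternatives for large sequences.
Use these instead of chunked_attention when running out of GPU memory.

The chunked functions never materialize the full attention output: they stream
it one query chunk at a time and return only its per-batch mean
[batch, hidden_size] and covariance [batch, hidden_size, hidden_size].
"""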

import torch
import math
from typing import Tuple


def chunked_attention_memory_efficient(
    Q,
    K,
    V,
    chunk_size=512,
    attention_head_size=64,
    all_head_size=768,
    use_fp16=True,
    temperature_scale=1.0,
    acc_dtype=torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Memory-efficient attention that chunks BOTH Q and K.

    The attention output is never materialized. It is produced one query chunk
    at a time (online softmax over key chunks) and folded into running sums.
    Only its per-batch mean and population covariance over the query positions
    are returned.

    Key differences from original chunked_attention:
    - Chunks Q as well (not just K/V), so the largest temporary is
      [batch, heads, q_chunk, k_chunk]
    - Uses float16 by default for Q/K/V and the score tiles to halve memory
    - Smaller default chunk_size (512 vs 4096)

    The online-softmax statistics (running max, normaliser, weighted sum) are
    always kept in at least float32. In float16 the normaliser, which grows
    with the number of keys, loses precision after a few thousand keys. It can
    also overflow past the float16 maximum (65504), giving NaN outputs.

    Memory usage: O(chunk_size^2) per score tile instead of
    O(seq_len * chunk_size), plus O(hidden_size^2) per batch element for the
    covariance accumulator.

    Args:
        Q: [batch, seq_len, hidden_size] queries.
        K, V: [batch, kv_len, hidden_size] keys/values (kv_len is normally
            seq_len).
        chunk_size: Size of chunks for both Q and K dimensions.
        attention_head_size: Per-head dimension.
        all_head_size: Total projected size. num_heads is
            all_head_size // attention_head_size, and
            num_heads * attention_head_size must equal hidden_size.
        use_fp16: Cast Q/K/V to half precision (recommended for memory savings).
        temperature_scale: Multiplier applied on top of 1/sqrt(attention_head_size).
        acc_dtype: Accumulator dtype for mean/cov.

    Returns:
        mean: [batch, hidden_size]
        cov:  [batch, hidden_size, hidden_size]
        When use_fp16 casts the inputs, both are converted back to the original
        input dtype; otherwise they are in acc_dtype.

    Raises:
        ValueError: if chunk_size <= 0 or the head layout does not match
            hidden_size.
    """
    batch_size, seq_len, hidden_size = Q.shape
    kv_len = K.shape[1]
    num_heads = all_head_size // attention_head_size
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if num_heads * attention_head_size != hidden_size:
        raise ValueError(
            f"hidden_size ({hidden_size}) must equal num_heads * attention_head_size "
            f"({num_heads} * {attention_head_size}); check all_head_size={all_head_size}"
        )
    base_scale = 1.0 / math.sqrt(attention_head_size)
    effective_scale = temperature_scale * base_scale

    orig_dtype = Q.dtype
    if use_fp16:
        # .half() is a no-op for tensors that are already float16.
        Q, K, V = Q.half(), K.half(), V.half()
    compute_dtype = Q.dtype
    # Running softmax statistics need more range/precision than fp16 offers.
    stat_dtype = torch.promote_types(compute_dtype, torch.float32)

    def transpose_for_scores(x):
        # Same reshape/permute as BigBird's transpose_for_scores (aka to_heads).
        x = x.view(batch_size, -1, num_heads, attention_head_size)
        return x.permute(0, 2, 1, 3)  # [batch, heads, seq, head_dim]

    # Views (no copy); key chunks are sliced from these along dim 2.
    K_heads = transpose_for_scores(K)
    V_heads = transpose_for_scores(V)

    sum_y = torch.zeros((batch_size, hidden_size), device=Q.device, dtype=acc_dtype)
    sum_yy = torch.zeros((batch_size, hidden_size, hidden_size), device=Q.device, dtype=acc_dtype)

    for q_start in range(0, seq_len, chunk_size):
        q_end = min(q_start + chunk_size, seq_len)
        q_len = q_end - q_start

        # Scale the small Q chunk once instead of every score tile: cheaper,
        # and keeps the fp16 matmul result further from overflow.
        Q_chunk = transpose_for_scores(Q[:, q_start:q_end, :]) * effective_scale

        # Online-softmax state for this Q chunk: running max, normaliser, weighted sum.
        m = torch.full((batch_size, num_heads, q_len, 1), float('-inf'), device=Q.device, dtype=stat_dtype)
        l = torch.zeros((batch_size, num_heads, q_len, 1), device=Q.device, dtype=stat_dtype)
        o = torch.zeros((batch_size, num_heads, q_len, attention_head_size), device=Q.device, dtype=stat_dtype)

        for k_start in range(0, kv_len, chunk_size):
            k_end = min(k_start + chunk_size, kv_len)
            K_chunk = K_heads[:, :, k_start:k_end, :]  # [batch, heads, k_chunk, head_dim]
            V_chunk = V_heads[:, :, k_start:k_end, :]

            # [batch, heads, q_chunk, k_chunk] tile in compute_dtype.
            scores = torch.matmul(Q_chunk, K_chunk.transpose(-1, -2))

            m_new = torch.maximum(m, scores.amax(dim=-1, keepdim=True).to(stat_dtype))
            correction = torch.exp(m - m_new)
            l = l * correction
            o = o * correction

            # Every value in m_new originates from `scores`, so casting back is exact.
            exp_scores = torch.exp(scores - m_new.to(compute_dtype))
            l = l + exp_scores.sum(dim=-1, keepdim=True, dtype=stat_dtype)
            o = o + torch.matmul(exp_scores, V_chunk).to(stat_dtype)

            m = m_new
            del scores, exp_scores

        # Finalize output for this Q chunk and accumulate first/second moments.
        y_block = (o / l).permute(0, 2, 1, 3).reshape(batch_size, q_len, hidden_size).to(acc_dtype)
        sum_y += y_block.sum(dim=1)
        sum_yy += y_block.transpose(1, 2) @ y_block

        del Q_chunk, m, l, o, y_block
        if Q.is_cuda:
            torch.cuda.empty_cache()

    mean = sum_y / seq_len
    cov = (sum_yy / seq_len) - mean.unsqueeze(2) @ mean.unsqueeze(1)
    # Remove tiny asymmetries introduced by floating-point accumulation.
    cov = 0.5 * (cov + cov.transpose(-1, -2))

    if use_fp16 and orig_dtype != torch.float16:
        mean = mean.to(orig_dtype)
        cov = cov.to(orig_dtype)

    return mean, cov


def chunked_attention_cpu_offload(
    Q,
    K,
    V,
    chunk_size=512,
    attention_head_size=64,
    all_head_size=768,
    temperature_scale=1.0,
    acc_dtype=torch.float32,
    compute_device=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Ultra memory-efficient attention using CPU offloading.

    Q/K/V are held on the CPU. Only the current Q chunk, the current K/V chunk
    and the per-chunk online-softmax state are moved to the compute device.
    The mean/cov accumulators live on the CPU. This is slower but uses minimal
    device memory.

    Use this when even chunked_attention_memory_efficient runs OOM.

    The online-softmax statistics are kept in at least float32 even for
    half-precision inputs. This avoids precision loss and float16 overflow of
    the softmax normaliser on long sequences.

    Args:
        Q: [batch, seq_len, hidden_size] queries.
        K, V: [batch, kv_len, hidden_size] keys/values (kv_len is normally
            seq_len).
        chunk_size: Size of chunks for both Q and K dimensions.
        attention_head_size: Per-head dimension.
        all_head_size: Total projected size. num_heads is
            all_head_size // attention_head_size, and
            num_heads * attention_head_size must equal hidden_size.
        temperature_scale: Multiplier applied on top of 1/sqrt(attention_head_size).
        acc_dtype: Accumulator dtype for mean/cov.
        compute_device: Device for the chunk computations. Defaults to Q's
            device. A CUDA device falls back to CPU when CUDA is unavailable.

    Returns:
        mean: [batch, hidden_size]
        cov:  [batch, hidden_size, hidden_size]
        Both are in acc_dtype, on the compute device.

    Raises:
        ValueError: if chunk_size <= 0 or the head layout does not match
            hidden_size.
    """
    batch_size, seq_len, hidden_size = Q.shape
    kv_len = K.shape[1]
    num_heads = all_head_size // attention_head_size
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if num_heads * attention_head_size != hidden_size:
        raise ValueError(
            f"hidden_size ({hidden_size}) must equal num_heads * attention_head_size "
            f"({num_heads} * {attention_head_size}); check all_head_size={all_head_size}"
        )
    base_scale = 1.0 / math.sqrt(attention_head_size)
    effective_scale = temperature_scale * base_scale

    if compute_device is None:
        device = Q.device
    else:
        device = torch.device(compute_device)
        if device.type == "cuda" and not torch.cuda.is_available():
            device = torch.device("cpu")
    on_cuda = device.type == "cuda"

    compute_dtype = Q.dtype
    # Running softmax statistics need more range/precision than fp16 offers.
    stat_dtype = torch.promote_types(compute_dtype, torch.float32)

    def transpose_for_scores(x):
        # Same reshape/permute as BigBird's transpose_for_scores (aka to_heads).
        x = x.view(batch_size, -1, num_heads, attention_head_size)
        return x.permute(0, 2, 1, 3)  # [batch, heads, seq, head_dim]

    # Host-side copies (.cpu() is a no-op for tensors already on the CPU).
    K_heads = transpose_for_scores(K).cpu()  # [batch, heads, kv_len, head_dim]
    V_heads = transpose_for_scores(V).cpu()
    Q_cpu = Q.cpu()

    # Drop this frame's references. The device memory is only actually released
    # if the caller does not hold its own references to K/V.
    del K, V
    torch.cuda.empty_cache()

    sum_y = torch.zeros((batch_size, hidden_size), device="cpu", dtype=acc_dtype)
    sum_yy = torch.zeros((batch_size, hidden_size, hidden_size), device="cpu", dtype=acc_dtype)

    num_q_chunks = (seq_len + chunk_size - 1) // chunk_size

    for q_idx in range(num_q_chunks):
        q_start = q_idx * chunk_size
        q_end = min(q_start + chunk_size, seq_len)
        q_len = q_end - q_start

        # Move the Q chunk to the compute device, pre-scaled once per chunk so
        # the score tiles need no extra multiply.
        Q_chunk = transpose_for_scores(Q_cpu[:, q_start:q_end, :].to(device)) * effective_scale

        # Online-softmax state on the compute device.
        m = torch.full((batch_size, num_heads, q_len, 1), float('-inf'), device=device, dtype=stat_dtype)
        l = torch.zeros((batch_size, num_heads, q_len, 1), device=device, dtype=stat_dtype)
        o = torch.zeros((batch_size, num_heads, q_len, attention_head_size), device=device, dtype=stat_dtype)

        for k_start in range(0, kv_len, chunk_size):
            k_end = min(k_start + chunk_size, kv_len)

            # Move only this K/V chunk to the compute device.
            K_chunk = K_heads[:, :, k_start:k_end, :].to(device)
            V_chunk = V_heads[:, :, k_start:k_end, :].to(device)

            scores = torch.matmul(Q_chunk, K_chunk.transpose(-1, -2))

            m_new = torch.maximum(m, scores.amax(dim=-1, keepdim=True).to(stat_dtype))
            correction = torch.exp(m - m_new)
            l = l * correction
            o = o * correction

            # Every value in m_new originates from `scores`, so casting back is exact.
            exp_scores = torch.exp(scores - m_new.to(compute_dtype))
            l = l + exp_scores.sum(dim=-1, keepdim=True, dtype=stat_dtype)
            o = o + torch.matmul(exp_scores, V_chunk).to(stat_dtype)

            m = m_new

            del K_chunk, V_chunk, scores, exp_scores
            if on_cuda:
                torch.cuda.empty_cache()

        y_block = (o / l).permute(0, 2, 1, 3).reshape(batch_size, q_len, hidden_size)
        y_block_cpu = y_block.to(dtype=acc_dtype, device="cpu")
        sum_y += y_block_cpu.sum(dim=1)
        sum_yy += y_block_cpu.transpose(1, 2) @ y_block_cpu

        del Q_chunk, m, l, o, y_block, y_block_cpu
        if on_cuda:
            torch.cuda.empty_cache()

        if q_idx % 10 == 0:
            print(f"  Q chunk {q_idx+1}/{num_q_chunks} done")

    mean = sum_y / seq_len
    cov = (sum_yy / seq_len) - mean.unsqueeze(2) @ mean.unsqueeze(1)
    # Remove tiny asymmetries introduced by floating-point accumulation.
    cov = 0.5 * (cov + cov.transpose(-1, -2))
    return mean.to(device), cov.to(device)


# Drop-in replacement function - auto-selects best method
def chunked_attention_auto(Q, K, V, chunk_size=None, attention_head_size=64, all_head_size=768,
                           force_cpu_offload=False, temperature_scale=1.0, acc_dtype=torch.float32):
    """
    Automatically select a chunked attention method and chunk size.

    Runs chunked_attention_memory_efficient. If it raises a CUDA out-of-memory
    error (or force_cpu_offload is set), chunked_attention_cpu_offload is used
    instead.

    Args:
        Q, K, V: [batch, seq_len, hidden_size] tensors.
        chunk_size: If None, auto-selected from the free memory on Q's CUDA
            device (128 when Q is not on CUDA). With force_cpu_offload, an
            explicit value is honoured; None means 512 there.
        attention_head_size, all_head_size, temperature_scale, acc_dtype:
            Forwarded unchanged to the selected implementation.
        force_cpu_offload: If True, always use CPU offloading (slowest but safest).

    Returns:
        mean: [batch, hidden_size]
        cov:  [batch, hidden_size, hidden_size]
    """
    common_kwargs = dict(
        attention_head_size=attention_head_size,
        all_head_size=all_head_size,
        temperature_scale=temperature_scale,
        acc_dtype=acc_dtype,
    )

    if force_cpu_offload:
        print("  Using CPU offload attention (slow but memory-safe)")
        return chunked_attention_cpu_offload(
            Q, K, V, chunk_size=512 if chunk_size is None else chunk_size, **common_kwargs)

    # Ask the driver for free memory on the device that actually holds Q. Unlike
    # total_memory - memory_allocated(), this accounts for blocks reserved by
    # the caching allocator and for other processes.
    if Q.is_cuda:
        free_bytes, _total_bytes = torch.cuda.mem_get_info(Q.device)
        free_memory_gb = free_bytes / (1024**3)
    else:
        free_memory_gb = 0

    # Auto-select chunk size based on memory
    if chunk_size is None:
        if free_memory_gb > 10:
            chunk_size = 1024
        elif free_memory_gb > 5:
            chunk_size = 512
        elif free_memory_gb > 2:
            chunk_size = 256
        else:
            chunk_size = 128

    # fp16 is used to halve GPU memory. For CPU tensors it saves no device
    # memory, and half-precision CPU matmul may be slow or unsupported
    # depending on the PyTorch build.
    use_fp16 = Q.is_cuda
    print(f"  Using memory-efficient attention (chunk_size={chunk_size}, fp16={use_fp16}, free={free_memory_gb:.1f}GB)")

    try:
        return chunked_attention_memory_efficient(Q, K, V, chunk_size=chunk_size,
                                                  use_fp16=use_fp16, **common_kwargs)
    except torch.cuda.OutOfMemoryError:
        # Recover OUTSIDE the handler: while the exception is being handled its
        # traceback keeps the failed call's frames (and their GPU tensors) alive.
        pass

    print("  GPU OOM, falling back to CPU offload...")
    torch.cuda.empty_cache()
    return chunked_attention_cpu_offload(Q, K, V, chunk_size=256, **common_kwargs)


# ============== TEST FUNCTIONS ==============

def standard_attention(Q, K, V, attention_head_size=64, all_head_size=768, temperature_scale=1.0):
    """
    Standard full attention (no chunking), used as the reference in tests.

    Materializes the full [batch, heads, seq_len, kv_len] score matrix, so only
    use it on small sequences that fit in memory.

    Args:
        Q: [batch, seq_len, all_head_size] queries.
        K, V: [batch, kv_len, all_head_size] keys/values.
        attention_head_size: Per-head dimension.
        all_head_size: num_heads * attention_head_size.
        temperature_scale: Multiplier applied on top of 1/sqrt(attention_head_size),
            matching the chunked implementations (1.0 = plain scaled attention).

    Returns:
        [batch, seq_len, all_head_size] attention output.
    """
    batch_size, seq_len, hidden_size = Q.shape
    num_heads = all_head_size // attention_head_size
    scale = temperature_scale * (1.0 / math.sqrt(attention_head_size))

    def transpose_for_scores(x):
        # Same reshape/permute as BigBird's transpose_for_scores (aka to_heads).
        x = x.view(batch_size, -1, num_heads, attention_head_size)
        return x.permute(0, 2, 1, 3)  # [batch, heads, seq, head_dim]

    Q_heads = transpose_for_scores(Q)
    K_heads = transpose_for_scores(K)
    V_heads = transpose_for_scores(V)

    # Full attention matrix: [batch, heads, seq_len, kv_len]
    scores = torch.matmul(Q_heads, K_heads.transpose(-1, -2)) * scale
    attn_weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V_heads)

    output = output.permute(0, 2, 1, 3).contiguous()
    return output.view(batch_size, seq_len, all_head_size)


def get_mean_cov(x):
    """
    Compute per-batch mean and population covariance of ``x`` over dim 1.

    Args:
        x: [b, n, d] tensor.

    Returns:
        mean: [b, d]
        cov:  [b, d, d], computed as E[x x^T] - mean mean^T (divides by n).
    """
    num_rows = x.size(1)
    mu = x.mean(dim=1)
    second_moment = torch.matmul(x.transpose(1, 2), x) / num_rows
    outer_mu = torch.matmul(mu.unsqueeze(2), mu.unsqueeze(1))
    return mu, second_moment - outer_mu


def _compare_to_reference(method, mean_ref, cov_ref, mean, cov, tol, status_note=""):
    """
    Print and return how far (mean, cov) are from the reference statistics.

    max_diff is the largest absolute elementwise difference over both mean and
    cov. mean_diff and rel_error are computed on the mean only. The status is
    "PASS" when max_diff < tol, otherwise "FAIL".

    Returns:
        dict with keys 'method', 'max_diff', 'mean_diff', 'status'.
    """
    max_diff = max(
        (mean_ref - mean).abs().max().item(),
        (cov_ref - cov).abs().max().item(),
    )
    mean_diff = (mean_ref - mean).abs().mean().item()
    rel_error = mean_diff / mean_ref.abs().mean().item()

    status = "PASS" if max_diff < tol else "FAIL"
    print(f"  Max diff:  {max_diff:.2e}")
    print(f"  Mean diff: {mean_diff:.2e}")
    print(f"  Rel error: {rel_error:.2e}")
    print(f"  Status:    {status}{status_note}")

    return {
        'method': method,
        'max_diff': max_diff,
        'mean_diff': mean_diff,
        'status': status,
    }


def test_attention_equivalence(seq_len=1024, hidden_size=768, chunk_sizes=(64, 128, 256, 512),
                               device='cuda' if torch.cuda.is_available() else 'cpu'):
    """
    Test that all attention implementations produce the same statistics.

    Computes the mean/cov of standard full attention on random Q/K/V and
    compares them with what each chunked implementation returns. Run this to
    verify correctness before using on real data. It uses a small sequence so
    the full attention matrix fits in memory.

    Args:
        seq_len: Sequence length of the random inputs.
        hidden_size: Hidden size of the random inputs. It is also passed as
            all_head_size (with the default 64-dim heads), so it should be a
            multiple of 64.
        chunk_sizes: Chunk sizes to test with the fp32 memory-efficient path.
        device: Device for the inputs (default: CUDA when available).

    Returns:
        True if every comparison passed its tolerance.
    """
    print("=" * 60)
    print("ATTENTION EQUIVALENCE TEST")
    print("=" * 60)
    print(f"seq_len={seq_len}, hidden_size={hidden_size}, device={device}")
    print()

    # Create random Q, K, V
    torch.manual_seed(42)
    Q = torch.randn(1, seq_len, hidden_size, device=device)
    K = torch.randn(1, seq_len, hidden_size, device=device)
    V = torch.randn(1, seq_len, hidden_size, device=device)

    # Ground truth: standard full attention
    print("Computing standard (full) attention...")
    with torch.no_grad():
        output_standard = standard_attention(Q, K, V, all_head_size=hidden_size)
    mean_standard, cov_standard = get_mean_cov(output_standard)
    print(f"  Shape: {output_standard.shape}")

    results = []

    # Memory-efficient chunked attention in fp32 with different chunk sizes
    for chunk_size in chunk_sizes:
        print(f"\nTesting chunked_attention_memory_efficient (chunk_size={chunk_size}, fp16=False)...")
        with torch.no_grad():
            mean_chunked, cov_chunked = chunked_attention_memory_efficient(
                Q.clone(), K.clone(), V.clone(),
                chunk_size=chunk_size,
                all_head_size=hidden_size,
                use_fp16=False,  # fp32 for a like-for-like comparison
            )
        results.append(_compare_to_reference(
            f'memory_efficient (chunk={chunk_size})',
            mean_standard, cov_standard, mean_chunked, cov_chunked,
            tol=1e-4,
        ))

    # Memory-efficient chunked attention in fp16
    print("\nTesting chunked_attention_memory_efficient (chunk_size=256, fp16=True)...")
    with torch.no_grad():
        mean_fp16, cov_fp16 = chunked_attention_memory_efficient(
            Q.clone(), K.clone(), V.clone(),
            chunk_size=256,
            all_head_size=hidden_size,
            use_fp16=True,
        )
    # fp16 has lower precision, so a looser tolerance.
    results.append(_compare_to_reference(
        'memory_efficient (fp16)',
        mean_standard, cov_standard, mean_fp16, cov_fp16,
        tol=1e-2,
        status_note=" (fp16 has lower precision, ~1e-3 expected)",
    ))

    # CPU offload version (only meaningful when the inputs live on an accelerator)
    if torch.device(device).type != 'cpu':
        print("\nTesting chunked_attention_cpu_offload (chunk_size=256)...")
        with torch.no_grad():
            mean_cpu_offload, cov_cpu_offload = chunked_attention_cpu_offload(
                Q.clone(), K.clone(), V.clone(),
                chunk_size=256,
                all_head_size=hidden_size,
            )
        results.append(_compare_to_reference(
            'cpu_offload',
            mean_standard, cov_standard, mean_cpu_offload, cov_cpu_offload,
            tol=1e-4,
        ))

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    all_pass = all(r['status'] == 'PASS' for r in results)
    for r in results:
        print(f"  {r['method']:40s} {r['status']} (max_diff={r['max_diff']:.2e})")

    print()
    if all_pass:
        print("ALL TESTS PASSED - implementations are mathematically equivalent")
    else:
        print("SOME TESTS FAILED - check the results above")

    return all_pass


if __name__ == "__main__":
    # Run the equivalence test and exit non-zero on failure so that shells/CI
    # can detect a regression instead of always seeing success.
    raise SystemExit(0 if test_attention_equivalence() else 1)




# Old version of chunked attention that returned the full attention output
# [batch, seq_len, hidden_size] instead of its mean/covariance. Kept for reference.
# def chunked_attention_memory_efficient(Q, K, V, chunk_size=512, attention_head_size=64, all_head_size=768, use_fp16=True, temperature_scale=1.0):
#     """
#     Memory-efficient attention that chunks BOTH Q and K.

#     Key differences from original chunked_attention:
#     - Chunks Q as well (not just K/V), so max tensor is [batch, heads, q_chunk, k_chunk]
#     - Uses float16 by default to halve memory
#     - Smaller default chunk_size (512 vs 4096)

#     Memory usage: O(chunk_size^2) instead of O(seq_len * chunk_size)

#     Args:
#         Q, K, V: [batch, seq_len, hidden_size] tensors
#         chunk_size: Size of chunks for both Q and K dimensions
#         use_fp16: Use half precision (recommended for memory savings)
#     """
#     batch_size, seq_len, hidden_size = Q.shape
#     num_heads = all_head_size // attention_head_size
#     base_scale = 1.0 / math.sqrt(attention_head_size)
#     effective_scale = temperature_scale * base_scale

#     orig_dtype = Q.dtype
#     if use_fp16 and Q.dtype != torch.float16:
#         Q = Q.half()
#         K = K.half()
#         V = V.half()

#     def transpose_for_scores(x):
#         # Same reshape/permute as BigBird's transpose_for_scores (aka to_heads).
#         x = x.view(batch_size, -1, num_heads, attention_head_size)
#         return x.permute(0, 2, 1, 3)  # [batch, heads, seq, head_dim]

#     # Transpose K and V once (they're accessed multiple times)
#     K_heads = transpose_for_scores(K)  # [batch, heads, seq, head_dim]
#     V_heads = transpose_for_scores(V)

#     # Output accumulator - keep on GPU
#     output = torch.zeros(batch_size, seq_len, hidden_size, device=Q.device, dtype=Q.dtype)

#     num_q_chunks = (seq_len + chunk_size - 1) // chunk_size

#     for q_idx in range(num_q_chunks):
#         q_start = q_idx * chunk_size
#         q_end = min(q_start + chunk_size, seq_len)
#         q_len = q_end - q_start

#         # Get Q chunk and transpose
#         Q_chunk = transpose_for_scores(Q[:, q_start:q_end, :])  # [batch, heads, q_chunk, head_dim]

#         # Accumulators for this Q chunk (online softmax)
#         m = torch.full((batch_size, num_heads, q_len, 1), float('-inf'), device=Q.device, dtype=Q.dtype)
#         l = torch.zeros((batch_size, num_heads, q_len, 1), device=Q.device, dtype=Q.dtype)
#         o = torch.zeros((batch_size, num_heads, q_len, attention_head_size), device=Q.device, dtype=Q.dtype)

#         num_k_chunks = (seq_len + chunk_size - 1) // chunk_size

#         for k_idx in range(num_k_chunks):
#             k_start = k_idx * chunk_size
#             k_end = min(k_start + chunk_size, seq_len)

#             # Get K, V chunks
#             K_chunk = K_heads[:, :, k_start:k_end, :]  # [batch, heads, k_chunk, head_dim]
#             V_chunk = V_heads[:, :, k_start:k_end, :]

#             # Compute attention scores for this chunk pair
#             # Shape: [batch, heads, q_chunk, k_chunk] - MUCH smaller!
#             scores = torch.matmul(Q_chunk, K_chunk.transpose(-1, -2)) * effective_scale

#             # Online softmax update
#             m_chunk = scores.max(dim=-1, keepdim=True).values
#             m_new = torch.maximum(m, m_chunk)

#             exp_m_diff = torch.exp(m - m_new)
#             l = l * exp_m_diff
#             o = o * exp_m_diff

#             exp_scores = torch.exp(scores - m_new)
#             l = l + exp_scores.sum(dim=-1, keepdim=True)
#             o = o + torch.matmul(exp_scores, V_chunk)

#             m = m_new

#             del scores, exp_scores, m_chunk

#         # Finalize output for this Q chunk
#         o_final = o / l
#         o_final = o_final.permute(0, 2, 1, 3).contiguous()  # [batch, q_chunk, heads, head_dim]
#         output[:, q_start:q_end, :] = o_final.view(batch_size, q_len, hidden_size)

#         del Q_chunk, m, l, o, o_final
#         torch.cuda.empty_cache()

#     # Convert back to original dtype if needed
#     if use_fp16 and orig_dtype != torch.float16:
#         output = output.to(orig_dtype)

#     return output
