import torch
import math
from IPython import embed

def attn_core(query_chunk, key_chunk, value_chunk, mask_chunk=None, offset_chunk=None, query_r=None, key_r=None, offset_clip_range=128, normalized=False):
    """Dense attention over one (query, key) tile; used as the autograd reference.

    Scores are ``q.k / sqrt(dv)`` plus an optional additive ``mask_chunk`` and, when
    ``offset_chunk`` is given, two relative-position terms
    ``query_r[o].k / sqrt(dv)`` and ``q.key_r[o] / sqrt(dv)`` where
    ``o = clamp(offset, -offset_clip_range, offset_clip_range) + offset_clip_range``.
    ``dv`` is ``value_chunk.shape[-1]`` (the value head dim is used for the scale).

    Args:
        query_chunk: (b, n, q, d) queries (layout fixed by the einsum strings).
        key_chunk: (b, n, k, d) keys.
        value_chunk: (b, n, k, dv) values.
        mask_chunk: optional additive mask broadcast over heads via ``unsqueeze(1)``;
            (b, q, k) as built in ``test_backward``.
        offset_chunk: optional (b, q, k) integer relative offsets; any integer dtype.
        query_r, key_r: (bins, n, d) relative-position tables, required with
            ``offset_chunk``; need ``bins >= 2 * offset_clip_range + 1``.
        offset_clip_range: offsets are clamped to ``[-offset_clip_range, offset_clip_range]``.
        normalized: divide the weighted values by the row sum of the weights,
            i.e. return ordinary softmax attention.

    Returns:
        ``(weights, row_max, values)``: ``weights = exp(scores - row_max)`` in the
        value dtype, ``row_max`` the float32 per-row score maximum (b, n, q), and
        ``values = weights @ v`` (normalized when requested).
    """
    norm = math.sqrt(value_chunk.shape[-1])
    attn_scores = torch.einsum('bnqd,bnkd->bnqk', query_chunk, key_chunk) / norm

    if mask_chunk is not None:
        attn_scores = attn_scores + mask_chunk.unsqueeze(1)  # broadcast over heads
    if offset_chunk is not None:
        # Shift clamped offsets into table rows; gather() requires int64 indices,
        # so cast explicitly (offsets may arrive as int32).
        indexed_offset = (offset_chunk.clamp(min=-offset_clip_range, max=offset_clip_range)
                          + offset_clip_range).long()
        indexed_offset = indexed_offset.unsqueeze(1).expand(-1, query_r.size(1), -1, -1)
        # Score every table row against every key / query, then pick the row each
        # (q, k) pair's offset selects.
        attn_scores_qbias = torch.einsum('ind,bnkd->bnik', query_r, key_chunk)
        attn_scores_kbias = torch.einsum('bnqd,jnd->bnqj', query_chunk, key_r)
        attn_scores = (attn_scores
                       + torch.gather(attn_scores_qbias, dim=2, index=indexed_offset) / norm
                       + torch.gather(attn_scores_kbias, dim=3, index=indexed_offset) / norm)

    # Subtract the per-row max before exponentiating to avoid overflow.
    row_max = attn_scores.amax(dim=-1).to(torch.float32)  # bnq
    weights = torch.exp(attn_scores - row_max.unsqueeze(dim=-1))
    reduced = torch.einsum('bnkd,bnqk->bnqd', value_chunk, weights.to(value_chunk.dtype))

    if normalized:
        reduced = reduced / weights.sum(dim=-1, keepdim=True)
    return weights.to(value_chunk.dtype), row_max, reduced

def attn_core_logsumexp(query_chunk, key_chunk, value_chunk,
                        mask_chunk=None, offset_chunk=None, query_r=None, key_r=None, offset_clip_range=128):
    """Softmax attention over one (query, key) tile, plus the tile's log-sum-exp.

    Scores are built exactly as in ``attn_core``: ``q.k / sqrt(dv)`` with
    ``dv = value_chunk.shape[-1]``, an optional additive mask broadcast over heads,
    and optional relative-position terms looked up from ``query_r`` / ``key_r`` at
    ``clamp(offset, -offset_clip_range, offset_clip_range) + offset_clip_range``.
    The softmax is normalised over this key tile only; callers merge tiles using the
    returned log-sum-exp.

    Args:
        query_chunk: (b, n, q, d); key_chunk: (b, n, k, d); value_chunk: (b, n, k, dv).
        mask_chunk: optional additive mask, unsqueezed at dim 1 (presumably (b, q, k)).
        offset_chunk: optional (b, q, k) integer offsets of any integer dtype.
        query_r, key_r: (bins, n, d) tables; ``bins >= 2 * offset_clip_range + 1``.
        offset_clip_range: clamp bound for the offsets.

    Returns:
        ``(lse, probs, values)``: float32 log-sum-exp of the scores (b, n, q),
        float32 tile-local softmax (b, n, q, k), and ``probs @ v`` in the value dtype.
    """
    norm = math.sqrt(value_chunk.shape[-1])
    attn_scores = torch.einsum('bnqd,bnkd->bnqk', query_chunk, key_chunk) / norm

    if mask_chunk is not None:
        attn_scores = attn_scores + mask_chunk.unsqueeze(1)
    if offset_chunk is not None:
        # gather() requires int64 indices; cast so int32 offset matrices also work.
        indexed_offset = (offset_chunk.clamp(min=-offset_clip_range, max=offset_clip_range)
                          + offset_clip_range).long()
        indexed_offset = indexed_offset.unsqueeze(1).expand(-1, query_r.size(1), -1, -1)
        attn_scores_qbias = torch.einsum('ind,bnkd->bnik', query_r, key_chunk)
        attn_scores_kbias = torch.einsum('bnqd,jnd->bnqj', query_chunk, key_r)
        attn_scores = (attn_scores
                       + torch.gather(attn_scores_qbias, dim=2, index=indexed_offset) / norm
                       + torch.gather(attn_scores_kbias, dim=3, index=indexed_offset) / norm)

    # Reductions in float32 regardless of the input dtype.
    attn_scores = attn_scores.to(torch.float32)
    attn_weight = attn_scores.logsumexp(dim=-1)  # bnq
    attn_distro = attn_scores.softmax(dim=-1)

    chunk_reduced_value = torch.einsum('bnkd,bnqk->bnqd', value_chunk, attn_distro.to(value_chunk.dtype))
    return attn_weight, attn_distro, chunk_reduced_value

def forward_core(query, key, value, query_chunk_size, key_chunk_size,
                 mask=None, offset_matrix=None, query_r=None, key_r=None,
                 offset_clip_range=128):
    """Tiled attention forward pass that never materialises the full score matrix.

    Queries are processed in tiles of ``query_chunk_size`` rows. For each query tile
    the keys/values are swept in tiles of ``key_chunk_size``, and the tile results of
    ``attn_core_logsumexp`` are merged with a running log-sum-exp (online softmax).

    Args:
        query, key, value: (b, n, seq, d) tensors; tiling runs along dim 2.
        query_chunk_size, key_chunk_size: tile sizes along the query / key sequence.
        mask, offset_matrix: optional per-(query, key) tensors sliced as
            ``[:, q_rows, k_cols]`` and moved to the query's device.
        query_r, key_r, offset_clip_range: forwarded to ``attn_core_logsumexp``.

    Returns:
        ``(output, lse)``: attention output in ``query``'s dtype and the float32
        log-sum-exp of the scores over all keys for each query row (b, n, q_len);
        ``backward_core`` uses the latter to recompute the softmax.
    """
    reduced_values = []
    all_attn_weights = []
    with torch.no_grad():
        for i in range(0, query.size(2), query_chunk_size):
            rows = slice(i, i + query_chunk_size)
            query_chunk = query[:, :, rows, :]
            mask_qchunk = None if mask is None else mask[:, rows, :].to(query_chunk.device)
            offset_qchunk = (None if offset_matrix is None
                             else offset_matrix[:, rows, :].to(query_chunk.device))

            running_out = None  # float32, softmax-weighted values over the keys seen so far
            running_lse = None  # float32, log-sum-exp of the scores over the keys seen so far
            for j in range(0, key.size(2), key_chunk_size):
                cols = slice(j, j + key_chunk_size)
                tile_lse, _, tile_out = attn_core_logsumexp(
                    query_chunk, key[:, :, cols, :], value[:, :, cols, :],
                    mask_chunk=None if mask_qchunk is None else mask_qchunk[:, :, cols],
                    offset_chunk=None if offset_qchunk is None else offset_qchunk[:, :, cols],
                    query_r=query_r, key_r=key_r,
                    offset_clip_range=offset_clip_range,
                )
                # Accumulate explicitly in float32 instead of relying on stack() promotion.
                tile_out = tile_out.to(torch.float32)
                if running_out is None:
                    running_out, running_lse = tile_out, tile_lse
                else:
                    # Online-softmax merge: each partial result is weighted by
                    # exp(its lse - combined lse), i.e. softmax over the two lse values.
                    merged_lse = torch.logaddexp(running_lse, tile_lse)
                    running_out = (running_out * (running_lse - merged_lse).exp().unsqueeze(-1)
                                   + tile_out * (tile_lse - merged_lse).exp().unsqueeze(-1))
                    running_lse = merged_lse
            reduced_values.append(running_out.to(query_chunk.dtype))
            all_attn_weights.append(running_lse)

        return torch.cat(reduced_values, dim=2), torch.cat(all_attn_weights, dim=2)

def _key_tile(key, value, mask_qchunk, offset_qchunk, start, size):
    """Slice key/value and the query-tile mask/offset rows for keys [start, start + size)."""
    cols = slice(start, start + size)
    mask_tile = None if mask_qchunk is None else mask_qchunk[:, :, cols]
    offset_tile = None if offset_qchunk is None else offset_qchunk[:, :, cols]
    return key[:, :, cols, :], value[:, :, cols, :], mask_tile, offset_tile


def _global_probs(query_chunk, key_chunk, value_chunk, lse_chunk, mask_chunk, offset_chunk,
                  query_r, key_r, offset_clip_range):
    """Recompute softmax-over-all-keys probabilities for one (query tile, key tile) pair.

    ``attn_core_logsumexp`` normalises over the key tile only; multiplying by
    ``exp(lse_tile - lse_global)`` turns that into ``exp(scores - lse_global)``.
    """
    tile_lse, tile_probs, _ = attn_core_logsumexp(
        query_chunk, key_chunk, value_chunk,
        mask_chunk=mask_chunk, offset_chunk=offset_chunk,
        query_r=query_r, key_r=key_r,
        offset_clip_range=offset_clip_range,
    )
    return tile_probs * (tile_lse - lse_chunk).exp().unsqueeze(dim=-1)


def backward_core(grad_output, query, key, value, all_attn_weights, chunk_size,
                  mask=None, offset_matrix=None, query_r=None, key_r=None, offset_clip_range=128):
    """Tiled attention backward pass that recomputes probabilities instead of storing them.

    For each query tile, two sweeps over the key tiles are made:
      1. accumulate ``dV = P^T dO`` and the softmax-backward row term ``D = sum_k P * dP``;
      2. form ``dS = P * (dP - D)`` and chain it through the ``1/sqrt(dv)`` score scale
         into ``dQ``, ``dK`` and, with relative positions, ``d query_r`` / ``d key_r``.

    Args:
        grad_output: gradient of the attention output, same layout as ``query``.
        query, key, value: the forward inputs, (b, n, seq, d).
        all_attn_weights: float32 per-row log-sum-exp returned by ``forward_core``.
        chunk_size: tile size used for both the query and key sequences.
        mask, offset_matrix, query_r, key_r, offset_clip_range: as in the forward pass.

    Returns:
        float32 ``(grad_query, grad_key, grad_value, grad_query_r, grad_key_r)``. The
        rel-pos gradients are None when their table is None, and stay zero when
        ``offset_matrix`` is None.
    """
    grad_query = torch.zeros_like(query, dtype=torch.float32)
    grad_key = torch.zeros_like(key, dtype=torch.float32)
    grad_value = torch.zeros_like(value, dtype=torch.float32)
    grad_query_r = torch.zeros_like(query_r, dtype=torch.float32) if query_r is not None else None
    grad_key_r = torch.zeros_like(key_r, dtype=torch.float32) if key_r is not None else None
    scale = 1.0 / math.sqrt(value.shape[-1])
    use_rel_pos = offset_matrix is not None and query_r is not None and key_r is not None

    with torch.no_grad():
        for i in range(0, query.size(2), chunk_size):
            query_chunk = query[:, :, i:i + chunk_size, :].contiguous()
            lse_chunk = all_attn_weights[:, :, i:i + chunk_size].contiguous()
            grad_output_chunk = grad_output[:, :, i:i + chunk_size, :].contiguous()
            mask_qchunk = None if mask is None else mask[:, i:i + chunk_size, :].to(query.device)
            offset_qchunk = (None if offset_matrix is None
                             else offset_matrix[:, i:i + chunk_size, :].to(query.device))

            # Pass 1: dV and the row term D, which needs every key tile before any dS exists.
            row_term = 0.
            for j in range(0, key.size(2), chunk_size):
                key_chunk, value_chunk, mask_chunk, offset_chunk = _key_tile(
                    key, value, mask_qchunk, offset_qchunk, j, chunk_size)
                probs = _global_probs(query_chunk, key_chunk, value_chunk, lse_chunk,
                                      mask_chunk, offset_chunk, query_r, key_r, offset_clip_range)
                grad_probs = torch.einsum("bnqd,bnkd->bnqk", grad_output_chunk, value_chunk)
                grad_value[:, :, j:j + chunk_size, :] += torch.einsum(
                    'bnqk,bnqd->bnkd', probs.to(grad_output_chunk.dtype), grad_output_chunk
                ).to(torch.float32)
                row_term = row_term + torch.einsum(
                    "bnqk,bnqk->bnq", grad_probs.to(torch.float32), probs).unsqueeze(dim=-1)

            # Pass 2: score gradients and everything that depends on them.
            for j in range(0, key.size(2), chunk_size):
                key_chunk, value_chunk, mask_chunk, offset_chunk = _key_tile(
                    key, value, mask_qchunk, offset_qchunk, j, chunk_size)
                probs = _global_probs(query_chunk, key_chunk, value_chunk, lse_chunk,
                                      mask_chunk, offset_chunk, query_r, key_r, offset_clip_range)
                grad_probs = torch.einsum("bnqd,bnkd->bnqk", grad_output_chunk, value_chunk)

                # dS * scale, kept in float32: the bias-table gradients below are
                # scatter-summed over many (q, k) pairs and must not be taken from a
                # low-precision copy.
                grad_scores = probs * (grad_probs - row_term) * scale
                grad_scores_lp = grad_scores.to(query_chunk.dtype)
                grad_query_chunk = torch.einsum('bnqk,bnkd->bnqd', grad_scores_lp, key_chunk).to(torch.float32)
                grad_key_chunk = torch.einsum('bnqk,bnqd->bnkd', grad_scores_lp, query_chunk).to(torch.float32)

                if use_rel_pos:
                    # scatter_add_ requires int64 indices.
                    index = (offset_chunk.clamp(min=-offset_clip_range, max=offset_clip_range)
                             + offset_clip_range).long()
                    index = index.unsqueeze(1).expand(-1, query_r.size(1), -1, -1)
                    n_batch, n_heads, n_q, n_k = grad_scores.shape
                    # Un-gather: route each (q, k) score gradient back to the table row
                    # its offset selected in the forward pass.
                    grad_qbias = grad_scores.new_zeros(
                        n_batch, n_heads, query_r.size(0), n_k).scatter_add_(2, index, grad_scores)
                    grad_kbias = grad_scores.new_zeros(
                        n_batch, n_heads, n_q, key_r.size(0)).scatter_add_(3, index, grad_scores)

                    grad_query_r += torch.einsum('bnik,bnkd->ind', grad_qbias, key_chunk.to(torch.float32))
                    grad_key_chunk += torch.einsum('bnik,ind->bnkd', grad_qbias, query_r.to(torch.float32))
                    grad_key_r += torch.einsum('bnqj,bnqd->jnd', grad_kbias, query_chunk.to(torch.float32))
                    grad_query_chunk += torch.einsum('bnqj,jnd->bnqd', grad_kbias, key_r.to(torch.float32))

                grad_query[:, :, i:i + chunk_size, :] += grad_query_chunk
                grad_key[:, :, j:j + chunk_size, :] += grad_key_chunk

    return grad_query, grad_key, grad_value, grad_query_r, grad_key_r




class MemoryEfficientAttention(torch.autograd.Function):
    """Tiled attention as a custom autograd Function.

    The forward pass keeps only the per-row log-sum-exp of the scores. The backward
    pass recomputes the attention probabilities tile by tile from the saved inputs.
    ``mask``, ``offset_matrix``, ``chunk_size`` and ``offset_clip_range`` receive no
    gradient.
    """

    @staticmethod
    def forward(ctx, query, key, value, mask=None, offset_matrix=None, query_r=None, key_r=None,
                chunk_size=1024, offset_clip_range=128):
        # The forward pass sweeps keys in tiles twice as long as the query tiles.
        key_chunk_size = min(chunk_size * 2, key.size(2))
        query_chunk_size = min(chunk_size, query.size(2))

        with torch.no_grad():
            reduced_values, all_attn_weights = forward_core(
                query, key, value,
                query_chunk_size, key_chunk_size,
                mask, offset_matrix,
                query_r, key_r,
                offset_clip_range=offset_clip_range,
            )

        # Save the inputs and the per-row log-sum-exp, not the probability tiles.
        ctx.save_for_backward(query, key, value, mask, offset_matrix, all_attn_weights, query_r, key_r)
        ctx.chunk_size = chunk_size
        ctx.offset_clip_range = offset_clip_range

        return reduced_values.to(value.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        query, key, value, mask, offset_matrix, all_attn_weights, query_r, key_r = ctx.saved_tensors

        with torch.no_grad():
            grad_query, grad_key, grad_value, grad_query_r, grad_key_r = backward_core(
                grad_output, query, key, value,
                all_attn_weights, ctx.chunk_size,
                mask=mask, offset_matrix=offset_matrix, query_r=query_r, key_r=key_r,
                offset_clip_range=ctx.offset_clip_range,
            )

        def cast_like(grad, ref):
            # backward_core accumulates in float32; return each gradient in its input's dtype.
            return None if grad is None else grad.to(ref.dtype)

        # One entry per forward input. Trailing Nones beyond the number of arguments
        # actually passed to apply() are dropped by autograd.
        return (cast_like(grad_query, query), cast_like(grad_key, key), cast_like(grad_value, value),
                None, None,
                cast_like(grad_query_r, query_r), cast_like(grad_key_r, key_r),
                None, None)

def memory_efficient_attention(query, key, value, mask=None, offset_matrix=None, query_r=None, key_r=None, chunk_size=2048):
    """Functional entry point for the tiled, differentiable attention.

    All arguments are forwarded positionally to ``MemoryEfficientAttention.apply``,
    in the order its ``forward`` declares them.
    """
    forwarded = (query, key, value, mask, offset_matrix, query_r, key_r, chunk_size)
    return MemoryEfficientAttention.apply(*forwarded)

def test_backward():
    """Compare ``memory_efficient_attention`` against a dense float32 ``attn_core`` reference.

    Builds random bfloat16 inputs, including an additive mask, random integer relative
    offsets and relative-position tables. The dense reference runs on float32 copies
    and the tiled autograd Function runs on bfloat16 copies. ``output.sum()`` is
    back-propagated through both, and the max absolute difference of the outputs and
    of every gradient is printed.
    """
    # Problem size.
    batch_size = 2
    num_heads = 4
    seq_len_q = 4096
    seq_len_kv = 2048
    head_dim = 16
    offset_clip_range = 128
    num_rel_pos_bins = 2 * offset_clip_range + 1  # 257: one row per clamped offset
    chunk_size = 512  # Adjust as needed

    # Works on CPU too; no explicit CUDA synchronisation is needed because .item() syncs.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.manual_seed(187)
    dtype = torch.bfloat16

    # Random inputs.
    query = torch.randn(batch_size, num_heads, seq_len_q, head_dim, requires_grad=True, device=device, dtype=dtype)
    key = torch.randn(batch_size, num_heads, seq_len_kv, head_dim, requires_grad=True, device=device, dtype=dtype)
    value = torch.randn(batch_size, num_heads, seq_len_kv, head_dim, requires_grad=True, device=device, dtype=dtype)
    mask = torch.randn(batch_size, seq_len_q, seq_len_kv, device=device, dtype=dtype)
    # Keep the offsets random so the gather/scatter paths are exercised across many bins.
    offset_matrix = torch.randint(-offset_clip_range, offset_clip_range + 1,
                                  (batch_size, seq_len_q, seq_len_kv), device=device)
    query_r = torch.randn(num_rel_pos_bins, num_heads, head_dim, requires_grad=True, device=device, dtype=dtype)
    key_r = torch.randn(num_rel_pos_bins, num_heads, head_dim, requires_grad=True, device=device, dtype=dtype)

    # Independent leaf copies for each implementation.
    query1, key1, value1, query_r1, key_r1 = (
        t.detach().clone().to(torch.float32).requires_grad_(True)
        for t in (query, key, value, query_r, key_r)
    )
    query2, key2, value2, query_r2, key_r2 = (
        t.detach().clone().requires_grad_(True)
        for t in (query, key, value, query_r, key_r)
    )

    # Dense float32 reference through regular PyTorch autograd.
    _, _, output1 = attn_core(
        query1, key1, value1,
        mask_chunk=mask,
        offset_chunk=offset_matrix,
        query_r=query_r1, key_r=key_r1,
        offset_clip_range=offset_clip_range,
        normalized=True
    )
    output1.sum().backward()

    # Tiled implementation with the hand-written backward.
    output2 = memory_efficient_attention(
        query2, key2, value2,
        mask=mask,
        offset_matrix=offset_matrix,
        query_r=query_r2, key_r=key_r2,
        chunk_size=chunk_size
    )
    output2.sum().backward()

    comparisons = {
        "forward": (output1, output2),
        "grad_query": (query1.grad, query2.grad),
        "grad_key": (key1.grad, key2.grad),
        "grad_value": (value1.grad, value2.grad),
        "grad_query_r": (query_r1.grad, query_r2.grad),
        "grad_key_r": (key_r1.grad, key_r2.grad),
    }
    for name, (reference, chunked) in comparisons.items():
        max_diff = (reference.detach() - chunked.detach()).abs().max().item()
        print(f"Max absolute difference in {name}:", max_diff)

if __name__ == "__main__":  # script entry only; importing this module runs nothing
    test_backward()
