import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
import triton
from torch.nn.attention.flex_attention import flex_attention as flex_attention_eager
from torch import Tensor
from torch.nn.attention.bias import CausalBias

flex_attention = torch.compile(
    flex_attention_eager, dynamic=False
)

# flex_attention = flex_attention_eager
# import torch._dynamo
# torch._dynamo.config.cache_size_limit = 4


# import torch._dynamo
# torch._dynamo.config.cache_limit = 64
flex_attention_unc = torch.compile(
    flex_attention_eager,
)



@torch.compile
def rotate_half(h_in):
    """Swap the two halves of the last dimension, negating the part moved to the front.

    With ``split = d // 2`` for a last dimension of size ``d`` this returns
    ``cat(-h_in[..., split:], h_in[..., :split])``. For odd ``d`` the leading
    (negated) part holds ``d - split`` elements, so the size stays ``d``.
    """
    split = h_in.size(-1) // 2
    leading, trailing = h_in[..., :split], h_in[..., split:]
    return torch.cat([trailing.neg(), leading], dim=-1)
def memory_efficient_attention_inner(query, key, value, mask=None, offset_matrix=None, pos_emb=None):
    """Attention with an optional signed-offset ALiBi bias over doubled, rotated keys.

    Without ``offset_matrix`` this is plain SDPA, restricted to the flash /
    memory-efficient backends. With ``offset_matrix``, every key is duplicated
    into two rotated copies:

    * the first copy may be attended only where the offset is <= 0;
    * the second copy may be attended only where the offset is >= 0;
    * each allowed score is penalised by ``|slope[head] * offset|`` inside flex
      attention.

    Args:
        query, key, value: attention inputs. ``query.size(1)`` is used as the
            head count and ``query.size(-1)`` as the feature size (see the
            ``reshape(1, H, 1, D)`` below), so presumably ``(B, H, S, D)`` —
            TODO confirm against callers.
        mask: boolean keep-mask indexed ``mask[batch, q, k]``; ``True`` means
            the pair may attend. ``None`` disables masking.
        offset_matrix: signed offsets indexed ``offset_matrix[batch, q, k]``.
            Its last dimension is the un-duplicated key length.
        pos_emb: pair of tensors, each reshapeable to ``(1, H, 1, D)``. They
            are passed through ``tanh`` to build the key rotation. Required
            when ``offset_matrix`` is given.

    Returns:
        The attention output of SDPA or flex attention.
    """
    if offset_matrix is None:
        # SDPA boolean masks use True = "take part in attention"; add the head axis.
        attn_mask = None if mask is None else mask.unsqueeze(dim=1)
        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
            return scaled_dot_product_attention(
                query=query, key=key, value=value, attn_mask=attn_mask
            )

    if mask is None:
        # No mask supplied: every (query, key) pair is allowed.
        mask = torch.ones_like(offset_matrix, dtype=torch.bool)

    num_heads = query.size(1)
    head_dim = query.size(-1)
    # Per-head slopes 2^(-8 i / H) for i = 1..H, stored negated.
    alibi_base = -torch.exp2(
        -torch.arange(1., num_heads + 1, device=query.device, dtype=torch.float32) * 8 / num_heads
    ).contiguous()

    # Un-duplicated key length: k_idx in score_mod spans [0, 2 * q_base), and
    # k_idx % q_base maps it back to the source key.
    q_base = offset_matrix.shape[2]
    rot_x = torch.tanh(pos_emb[0]).reshape(1, num_heads, 1, head_dim).contiguous()
    rot_y = torch.tanh(pos_emb[1]).reshape(1, num_heads, 1, head_dim).contiguous()

    key_half_rot = rotate_half(key)
    key_ = torch.cat(
        (key * rot_x + key_half_rot * rot_y, key * rot_x - key_half_rot * rot_y), dim=-2
    )
    value_ = torch.cat((value, value), dim=-2)

    def score_mod(
            score: Tensor,
            batch: Tensor,
            head: Tensor,
            q_idx: Tensor,
            k_idx: Tensor
    ):
        src_k = k_idx % q_base
        offset = offset_matrix[batch, q_idx, src_k]
        # First key copy handles offsets <= 0, the second offsets >= 0; a zero
        # offset passes in both. The sign test ``(k_idx - q_base) * offset >= 0``
        # would be 0 at k_idx == q_base and let that key through for either sign.
        side_ok = torch.where(k_idx < q_base, offset <= 0, offset >= 0)
        keep = torch.logical_and(mask[batch, q_idx, src_k], side_ok)
        return torch.where(
            keep,
            score - torch.abs(alibi_base[head] * offset).to(score.dtype),
            -float("inf"),
        )

    return flex_attention(
        query=query, key=key_, value=value_, score_mod=score_mod,
    )

def memory_efficient_attention(query, key, value, mask=None, offset_matrix=None, pos_emb=None):
    """Convert ``mask`` to a boolean keep-mask and delegate to the inner attention.

    ``mask`` uses the convention that 0 marks pairs that may attend. It is
    turned into ``mask == 0`` (True = attend) before calling
    ``memory_efficient_attention_inner``. The default ``None`` is passed
    through as "no mask" instead of being compared with 0, which would
    produce the Python bool ``False`` rather than a tensor.
    """
    keep_mask = None if mask is None else (mask == 0)
    return memory_efficient_attention_inner(query, key, value, keep_mask, offset_matrix, pos_emb)


def memory_efficient_attention_flex_grouprope(query, key, value, mask=None, offset_matrix=None, pos_emb=None, key_ln=None):
    """Attention over doubled, half-rotated keys with an exponential offset decay.

    Without ``offset_matrix`` this is plain SDPA, restricted to the flash /
    memory-efficient backends. With ``offset_matrix``, keys are expanded to
    ``key ± rotate_half(key) * pos_emb[0]`` along the sequence axis:

    * the first copy may be attended only where the offset is <= 0;
    * the second copy may be attended only where the offset is >= 0;
    * each allowed score is multiplied by
      ``exp(-|slope[head] * offset| * 0.7854)`` in flex attention.

    Args:
        query, key, value: attention inputs. ``query.size(1)`` is used as the
            head count and ``query.size(-1)`` as the feature size, so
            presumably ``(B, H, S, D)`` — TODO confirm against callers.
        mask: 0 marks pairs that may attend (converted with ``mask == 0``).
            Indexed ``[batch, q, k]``. ``None`` disables masking.
        offset_matrix: signed offsets indexed ``[batch, q, k]``. Its last
            dimension is the un-duplicated key length.
        pos_emb: only ``pos_emb[0]`` is used; it must be reshapeable to
            ``(1, H, 1, D)``. Required when ``offset_matrix`` is given.
        key_ln: optional callable applied to the keys before attention. It is
            applied to ``key`` on the SDPA path and to the rotated,
            duplicated keys on the flex path.

    Returns:
        The attention output of SDPA or flex attention.
    """
    keep_mask = None if mask is None else (mask == 0)

    if offset_matrix is None:
        if key_ln is not None:
            key = key_ln(key)
        attn_mask = None if keep_mask is None else keep_mask.unsqueeze(dim=1)
        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
            return scaled_dot_product_attention(
                query=query, key=key, value=value, attn_mask=attn_mask
            )

    if keep_mask is None:
        # No mask supplied: every (query, key) pair is allowed.
        keep_mask = torch.ones_like(offset_matrix, dtype=torch.bool)

    num_heads = query.size(1)
    # Per-head slopes 2^(-8 i / H) for i = 1..H, stored negated.
    alibi_base = -torch.exp2(
        -torch.arange(1., num_heads + 1, device=query.device, dtype=torch.float32) * 8 / num_heads
    ).contiguous()

    # Un-duplicated key length: k_idx in score_mod spans [0, 2 * q_base).
    q_base = offset_matrix.shape[2]
    rot_x = pos_emb[0].reshape(1, num_heads, 1, query.size(-1)).contiguous()

    key_delta = rotate_half(key) * rot_x
    key_ = torch.cat((key + key_delta, key - key_delta), dim=-2)
    if key_ln is not None:
        # Normalise the rotated, duplicated keys. Applying key_ln to the bare
        # ``key`` here would drop the rotation and leave key_ half as long as value_.
        key_ = key_ln(key_)
    value_ = torch.cat((value, value), dim=-2)

    def score_mod(
            score: Tensor,
            batch: Tensor,
            head: Tensor,
            q_idx: Tensor,
            k_idx: Tensor
    ):
        src_k = k_idx % q_base
        offset = offset_matrix[batch, q_idx, src_k]
        # First key copy handles offsets <= 0, the second offsets >= 0; a zero
        # offset passes in both. The sign test ``(k_idx - q_base) * offset >= 0``
        # would be 0 at k_idx == q_base and let that key through for either sign.
        side_ok = torch.where(k_idx < q_base, offset <= 0, offset >= 0)
        keep = torch.logical_and(keep_mask[batch, q_idx, src_k], side_ok)
        # 0.7854 (~ pi / 4) scales the decay rate inside the exponent.
        decay = torch.exp(-torch.abs(alibi_base[head] * offset) * 0.7854)
        return torch.where(keep, score * decay.to(score.dtype), -float("inf"))

    return flex_attention(
        query=query, key=key_, value=value_, score_mod=score_mod,
        kernel_options={"BLOCK_M": 64, "BLOCK_N": 32}
    )


# memory_efficient_attention_flex_grouprope = \
#     _memory_efficient_attention_flex_grouprope

def memory_efficient_attention_alibi(query, key, value, mask=None, offset_matrix=None, skip_self=False):
    """Attention with an additive ALiBi-style bias ``-|slope[head] * offset|``.

    Without ``offset_matrix`` this is plain SDPA, restricted to the flash /
    memory-efficient backends. Otherwise:

    * offsets are clamped to [-256, 256];
    * head i = 1..H uses slope ``exp2(-i * H / H) = 2^-i``.

    There are two execution paths:

    * ``query`` does not require grad, or ``offset_matrix.shape[1] < 128``:
      the dense bias of shape ``(B, H, Q, K)`` is materialised and passed to
      SDPA. Disallowed pairs are set to the dtype's most negative finite value.
    * Otherwise: flex attention computes the same bias per score, and
      disallowed pairs are set to ``-inf``.

    Args:
        query, key, value: attention inputs. ``query.size(1)`` is used as the
            head count, so presumably ``(B, H, S, D)`` — TODO confirm.
        mask: 0 marks pairs that may attend (converted with ``mask == 0``).
            Indexed ``[batch, q, k]``. ``None`` disables masking.
        offset_matrix: signed offsets indexed ``[batch, q, k]``.
        skip_self: if True, also forbid pairs whose offset is exactly 0.

    Returns:
        The attention output of SDPA or flex attention.
    """
    keep_mask = None if mask is None else (mask == 0)

    if offset_matrix is None:
        attn_mask = None if keep_mask is None else keep_mask.unsqueeze(dim=1)
        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
            return scaled_dot_product_attention(
                query=query, key=key, value=value, attn_mask=attn_mask
            )

    if keep_mask is None:
        # No mask supplied: every (query, key) pair is allowed.
        keep_mask = torch.ones_like(offset_matrix, dtype=torch.bool)
    if skip_self:
        keep_mask = torch.logical_and(keep_mask, offset_matrix != 0)

    offset_dist = offset_matrix.clamp(min=-256, max=256)

    num_heads = query.size(1)
    # Exponent spread across heads; equal to the head count, so head i gets 2^-i.
    base_arc_fact = query.size(1)
    slopes = torch.exp2(
        -torch.arange(1., num_heads + 1, device=query.device, dtype=torch.float32) * base_arc_fact / num_heads)
    # Every head carries a slope: with (H * 2) // 2 == H there is no separate
    # zero-slope head group, so the bias is simply the negated slopes.
    alibi_base = (-slopes).contiguous()

    if not query.requires_grad or offset_matrix.shape[1] < 128:
        # Dense path: materialise the (B, H, Q, K) bias and let SDPA add it.
        mask_value = torch.finfo(query.dtype).min
        alibi_rect = -torch.abs(torch.einsum(
            "bij,n->bnij", offset_dist.to(torch.float32), alibi_base.to(torch.float32)
        ).to(query.dtype))
        alibi_rect_with_mask = torch.where(
            keep_mask.unsqueeze(dim=1),
            alibi_rect,
            mask_value
        )
        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
            return scaled_dot_product_attention(
                query=query, key=key, value=value, attn_mask=alibi_rect_with_mask
            )

    # Flex path: compute the bias per score instead of materialising it.
    def score_mod(
        score: Tensor,
        batch: Tensor,
        head: Tensor,
        q_idx: Tensor,
        k_idx: Tensor
    ):
        return torch.where(
            keep_mask[batch, q_idx, k_idx],
            score - torch.abs(alibi_base[head] * offset_dist[batch, q_idx, k_idx]).to(score.dtype),
            -float("inf"))

    return flex_attention(
        query=query, key=key, value=value, score_mod=score_mod,
    )

def memory_efficient_attention_melibi(query, key, value, mask=None, offset_matrix=None, skip_self=False):
    """Attention whose scores are scaled by an offset-dependent exponential decay.

    Without ``offset_matrix`` this is plain SDPA, restricted to the flash /
    memory-efficient backends. Otherwise flex attention:

    * multiplies each allowed score by ``exp(-|slope[head] * offset|)``, with
      per-head slopes ``2^(-16 i / H)`` for i = 1..H;
    * sets disallowed scores to ``-inf``.

    Args:
        query, key, value: attention inputs. ``query.size(1)`` is used as the
            head count, so presumably ``(B, H, S, D)`` — TODO confirm.
        mask: 0 marks pairs that may attend (converted with ``mask == 0``).
            Indexed ``[batch, q, k]``. ``None`` disables masking.
        offset_matrix: signed offsets indexed ``[batch, q, k]`` (not clamped).
        skip_self: if True, also forbid pairs whose offset is exactly 0.

    Returns:
        The attention output of SDPA or flex attention.
    """
    keep_mask = None if mask is None else (mask == 0)

    if offset_matrix is None:
        attn_mask = None if keep_mask is None else keep_mask.unsqueeze(dim=1)
        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
            return scaled_dot_product_attention(
                query=query, key=key, value=value, attn_mask=attn_mask
            )

    if keep_mask is None:
        # No mask supplied: every (query, key) pair is allowed.
        keep_mask = torch.ones_like(offset_matrix, dtype=torch.bool)
    if skip_self:
        keep_mask = torch.logical_and(keep_mask, offset_matrix != 0)

    num_heads = query.size(1)
    # Every head carries a slope: with (H * 2) // 2 == H there is no separate
    # forward-head group, so the bias base is simply the negated slopes.
    slopes = torch.exp2(
        -torch.arange(1., num_heads + 1, device=query.device, dtype=torch.float32) * 16 / num_heads)
    alibi_base = (-slopes).contiguous()

    def score_mod(
        score: Tensor,
        batch: Tensor,
        head: Tensor,
        q_idx: Tensor,
        k_idx: Tensor
    ):
        decay = torch.exp(-torch.abs(alibi_base[head] * offset_matrix[batch, q_idx, k_idx]))
        return torch.where(
            keep_mask[batch, q_idx, k_idx],
            score * decay.to(score.dtype),
            -float("inf"))

    return flex_attention(
        query=query, key=key, value=value, score_mod=score_mod,
        kernel_options={"BLOCK_M": 64, "BLOCK_N": 32}
    )

# Default flex-attention variant exported by this module: an alias for
# memory_efficient_attention_melibi (multiplicative exponential decay).
memory_efficient_attention_flex = memory_efficient_attention_melibi