import math
import torch
from einops import rearrange

from flash_attn_triton_for_hyper import flash_attn_func


from inspect import currentframe, getframeinfo

# Frame information for this module-level statement, captured once at import time.
# NOTE(review): not referenced anywhere in the visible part of this file — confirm
# whether anything outside this file relies on it before removing.
frameinfo = getframeinfo(currentframe())

def get_linenumber():
    """Return the line number currently executing in the caller's frame."""
    caller_frame = currentframe().f_back
    return caller_frame.f_lineno


def indexing(x, indices):
    """
    Select rows of `x` along its third axis, independently for every (batch, head) pair.

    inputs:
        - x: 4d-tensor with shape [b, h, n, d]
        - indices: 3d-tensor with shape [b, h, s] where each entry should be in [0, n-1]
    output:
        - out: 4d-tensor with shape [b, h, s, d] where
          out[i, j, k, :] = x[i, j, indices[i, j, k], :]

    Equivalent loop form:
        out = torch.zeros(b, h, s, d)
        for i in range(b):
            for j in range(h):
                out[i, j] = x[i, j][indices[i, j], :]
        return out
    """
    feature_dim = x.shape[-1]
    # Repeat every selected row index across the feature axis so gather copies whole rows.
    row_selector = indices[..., None].expand(-1, -1, -1, feature_dim)
    return torch.gather(x, 2, row_selector)


def exact_attention(query, key, value, softmax_scale, causal=False, bias=None, normalize=True):
    """
    Run `flash_attn_func` on tensors whose axes 1 and 2 are swapped before the call
    and swap the output back afterwards.

    inputs:
        - query, key, value: tensors whose axes 1 and 2 are exchanged for the kernel
          (presumably [b, h, n, d] here and [b, n, h, d] inside the kernel — confirm
          against flash_attn_func)
        - softmax_scale, causal, bias: forwarded to flash_attn_func unchanged
        - normalize: when True the log-sum-exp is dropped and None is returned in its place
    output:
        - (out, None) if normalize is True
        - (out, lse) otherwise, where lse is detached, cut to out.shape[2] entries along
          axis 2 if the kernel returned a different length, and given a trailing axis of size 1
    """
    q_t, k_t, v_t = (t.transpose(1, 2) for t in (query, key, value))
    out, lse = flash_attn_func(q_t, k_t, v_t, bias, causal, softmax_scale)
    out = out.transpose(1, 2)

    if normalize:
        return out, None

    n_rows = out.shape[2]
    lse = lse.detach()
    # The kernel may return more lse entries than there are query rows; keep the leading ones.
    if lse.shape[2] != n_rows:
        lse = lse[:, :, :n_rows]
    return out, lse.unsqueeze(-1)
    

def add_self_attentions(attn1, lse1, attn2, lse2):
    """
    Merge two partial attention results computed over disjoint key sets.

    inputs:
        - attn1, attn2: 4d-tensors with shape [b, h, n, d]
        - lse1, lse2: 4d-tensors with shape [b, h, n, 1] (log-sum-exp of each part)
    output:
        - attn
        = (attn1 * exp(lse1) + attn2 * exp(lse2)) / (exp(lse1) + exp(lse2))
        = (attn1 + attn2 * exp(lse2 - lse1)) / (1 + exp(lse2 - lse1))
        = attn1 * c + attn2 * (1-c), where c = 1/(1 + exp(lse2-lse1)) = sigmoid(lse1 - lse2),
        - lse
        = log(exp(lse1) + exp(lse2)) = logaddexp(lse1, lse2)

    The weight and the merged lse are evaluated with `torch.sigmoid` / `torch.logaddexp`
    instead of an explicit `exp`: the explicit form overflows to inf once lse2 - lse1 is
    large, which makes c exactly 0 and then reports lse1 - log(eps) instead of ~lse2.
    """
    # c is cast to the attention dtype only for the mixing step; lse keeps the lse dtype.
    c = torch.sigmoid(lse1 - lse2).to(dtype=attn1.dtype)
    attn = c * attn1 + (1 - c) * attn2
    lse = torch.logaddexp(lse1, lse2)
    return attn, lse


class AngularLSH(torch.nn.Module):
    """
    Sign-of-random-projection hash that maps each row of a tensor to an integer bucket.

    args:
        - num_projs: number of random projections; must be positive. Hash values lie in
          [0, 2**num_projs - 1].
        - dim: sequence of leading sizes for the projection tensor; `proj_dir` has shape
          tuple(dim) + (num_projs,), so the last entry of `dim` is the feature size.
        - rng: optional torch.Generator used to draw the projections.

    buffers (all non-persistent, i.e. not stored in state_dict):
        - proj_dir: standard-normal projection directions
        - perm: reflected Gray code table of length 2**num_projs
        - enc_vec: powers of two [1, 2, 4, ...] with shape [1, 1, 1, num_projs]
    """

    def __init__(self, num_projs, dim, rng=None):
        super().__init__()
        self.num_projs = num_projs

        if not num_projs > 0:
            raise ValueError(
                f"Invalid value for num_projs: expected a positive integer, got {num_projs!r}")

        # tuple(dim) lets callers pass a list or torch.Size as well as a tuple.
        proj_shape = tuple(dim) + (num_projs,)
        self.register_buffer('proj_dir', torch.randn(proj_shape, generator=rng), persistent=False)
        self.register_buffer('perm', self._unit_hamming_distance_array(self.num_projs), persistent=False)
        self.register_buffer('enc_vec', 2 ** torch.arange(self.num_projs).view(1, 1, 1, -1), persistent=False)

    def _unit_hamming_distance_array(self, size_n):
        """
        Return the reflected Gray code over `size_n` bits as a 1d integer tensor of
        length 2**size_n; consecutive entries differ in exactly one bit.
        """
        codes = torch.tensor([0, 1])
        for bit in range(1, size_n):
            # Append the mirrored sequence with the next bit set.
            codes = torch.cat([codes, torch.flip(codes, dims=[0]) + 2 ** bit], 0)
        return codes

    def hash(self, mat):
        """
        Hash every row of `mat` (last axis = features) to an integer bucket id.

        The sign pattern of the projections is packed into an integer with `enc_vec`
        and then remapped through the `perm` table. The output drops the feature axis.
        """
        proj_dir = self.proj_dir
        # einsum does not promote dtypes; cast the (small) projection tensor so that
        # e.g. half/double inputs work even if the module itself was never cast.
        if mat.is_floating_point() and proj_dir.dtype != mat.dtype:
            proj_dir = proj_dir.to(dtype=mat.dtype)
        mask = torch.einsum('...nd,...dr -> ...nr', mat, proj_dir)
        mask = mask > 0
        bin_ids = (mask * self.enc_vec).sum(-1)
        return self.perm[bin_ids]

    def __repr__(self):
        return f"AngularLSH(num_proj={self.num_projs}, proj_dir.shape={self.proj_dir.shape})"



class HyperAttention(torch.nn.Module):
    """
    Approximate attention assembled from two parts that are merged per query row:
        1. block-diagonal attention over queries/keys sorted by their AngularLSH code,
        2. a residual part computed on uniformly sampled keys.
    With causal masking the sequence is split in halves recursively; sequences of
    length at most `min_seq_len` are handled by exact (flash) attention.

    args:
        - n_heads, dim: stored on the module; `dim` is also the feature size of the LSH projections
        - num_projs: number of LSH projections
        - bucket_size_ratio, min_bucket_size, max_bucket_size: the key bucket size is
          int(n_key * bucket_size_ratio) clamped to [min_bucket_size, max_bucket_size] and to n_key
        - sample_size_ratio, min_sample_size, max_sample_size: same rule for the number of sampled keys
        - scale: default softmax scale used when `forward` receives scale=None
          (falls back to d ** -0.5 when this is None as well)
        - min_seq_len: causal inputs with at most this many keys use exact attention
        - **kwargs: accepted and ignored
    """

    def __init__(self, n_heads=1, dim=64, num_projs=7, 
                min_bucket_size=128, max_bucket_size=512, bucket_size_ratio=1/32, 
                min_sample_size=128, max_sample_size=256, sample_size_ratio=1/64,
                scale=None, min_seq_len=1024, **kwargs):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.num_projs = num_projs
        self.bucket_size_ratio = bucket_size_ratio
        self.min_bucket_size = min_bucket_size
        self.max_bucket_size = max_bucket_size
        self.sample_size_ratio = sample_size_ratio
        self.min_sample_size = min_sample_size
        self.max_sample_size = max_sample_size
        self.scale = scale
        self.min_seq_len = min_seq_len
        self.lsh = AngularLSH(num_projs=self.num_projs, dim=(1, 1, dim))

    @staticmethod
    def _resolve_scale(scale, default_scale, dim):
        """Pick the softmax scale: explicit argument, else module default, else dim ** -0.5."""
        if scale is not None:
            return scale
        if default_scale is not None:
            return default_scale
        return dim ** (-0.5)

    @staticmethod
    def _clamped_size(n, ratio, lower, upper):
        """Return int(n * ratio) clamped to [lower, upper] (upper wins if lower > upper)."""
        return min(max(int(n * ratio), lower), upper)

    def forward(self, query, key, value, scale=None, mask=None, is_causal=False, normalize=True):
        """
        inputs:
            - query, key, value: 4d-tensors with shape [b, h, n, d]; n_query must equal n_key
            - scale: softmax scale; None selects self.scale, or d ** -0.5 if that is None too
            - mask: optional [b, n] tensor multiplied into `value` in the non-causal path
            - is_causal: anything other than the literal False selects the causal path
            - normalize: if True return only attn, otherwise return (attn, lse)
        output:
            - attn with shape [b, h, n, d], and lse with shape [b, h, n, 1] when normalize is False

        NOTE(review): in the causal path `mask` is not passed to the exact-attention
        fallback and is forwarded unchanged to half-length sub-problems, where its
        [b, n] shape no longer matches the halved sequence. Causal + mask looks
        unsupported — confirm before relying on it.
        """
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()

        n_query = query.shape[2]
        batch_size, n_heads, n_key, dim = key.shape
        scale = self._resolve_scale(scale, self.scale, dim)
        assert n_query == n_key

        if is_causal is False:  # without causal masking
            attn, lse = self.forward_no_causal_mask(query, key, value, scale, mask)

        else:  # causal masking
            if n_key <= self.min_seq_len:
                attn, lse = exact_attention(query, key, value, scale, causal=True, normalize=normalize)
            else:
                half_q, half_k = n_query // 2, n_key // 2

                # Fold the two sequence halves of every head into separate "heads" so both
                # diagonal (still causal) blocks are solved by one recursive call.
                q_bd = query.view(batch_size, 2 * n_heads, half_q, query.shape[-1])
                k_bd = key.view(batch_size, 2 * n_heads, half_k, key.shape[-1])
                v_bd = value.view(batch_size, 2 * n_heads, half_k, value.shape[-1])

                attn_bd, lse_bd = self.forward(q_bd, k_bd, v_bd, scale, mask, True, False)
                # reshape (not view): the exact-attention base case returns a transposed,
                # non-contiguous tensor that cannot be viewed with merged head/sequence axes.
                attn_bd = attn_bd.reshape(batch_size, n_heads, n_query, -1)
                lse_bd = lse_bd.reshape(batch_size, n_heads, n_query, -1)

                # Off-diagonal block: second-half queries against first-half keys, no causal mask.
                attn_unmasked, lse_unmasked = self.forward_no_causal_mask(
                    query[:, :, half_k:, :], key[:, :, :half_k, :],
                    value[:, :, :half_k, :], scale, mask)

                attn_up, lse_up = attn_bd[:, :, :half_q, :], lse_bd[:, :, :half_q, :]
                attn_down, lse_down = add_self_attentions(
                    attn_bd[:, :, half_q:, :], lse_bd[:, :, half_q:, :],
                    attn_unmasked, lse_unmasked)

                attn = torch.cat((attn_up, attn_down), dim=-2)
                lse = torch.cat((lse_up, lse_down), dim=-2)

        if normalize:
            return attn
        else:
            return attn, lse

    def forward_no_causal_mask(self, query, key, value, scale, mask=None):
        """
        Non-causal approximate attention.

        inputs:
            - query: [b, h, n_query, d]; key, value: [b, h, n_key, d]
            - scale: softmax scale forwarded to exact_attention
            - mask: optional [b, n_key] tensor multiplied into `value` (the input tensor
              itself is left untouched)
        output:
            - attn: [b, h, n_query, d] and lse: [b, h, n_query, 1], rows in the original query order

        If the bucket configuration yields a non-positive bucket size, the block-diagonal
        part is skipped and only the sampled part is used (no entries are hidden from it).
        Raises ValueError when neither part can be computed.
        """
        batch_size, head_size, n_query, _ = query.shape
        n_key = key.shape[2]

        # 1. Sorted block-diagonal via SortLSH
        _, query_sort_idx = torch.sort(self.lsh.hash(query), dim=2, stable=True)  # batch_size x head_size x n
        _, key_sort_idx = torch.sort(self.lsh.hash(key), dim=2, stable=True)
        query_sort_idx_inv = torch.argsort(query_sort_idx, dim=2, stable=True)  # for recovering the row order

        if mask is not None:
            # Out-of-place on purpose: `value` may be the caller's own tensor (or a view of it).
            value = value * rearrange(mask, "b n->b 1 n 1").to(value.dtype)

        query_sorted = indexing(query, query_sort_idx)
        key_sorted = indexing(key, key_sort_idx)
        value_sorted = indexing(value, key_sort_idx)

        key_bucket_size = min(
            self._clamped_size(n_key, self.bucket_size_ratio, self.min_bucket_size, self.max_bucket_size),
            n_key)
        use_blocks = key_bucket_size > 0
        if use_blocks:
            num_blocks = n_key // key_bucket_size
            query_bucket_size = n_query // num_blocks

            ## Reshape tensors to [batch_size*head_size*num_blocks, 1, bucket_size, dim] as Flash-attn only allows 4d-tensors
            # NOTE(review): assumes n_key is a multiple of key_bucket_size and n_query a
            # multiple of num_blocks (otherwise buckets straddle heads or the view fails) — not checked here.
            query_split_per_block = query_sorted.view(-1, 1, query_bucket_size, query.shape[-1])
            key_split_per_block = key_sorted.view(-1, 1, key_bucket_size, key.shape[-1])
            value_split_per_block = value_sorted.view(-1, 1, key_bucket_size, value.shape[-1])

            attn_block, lse_block = exact_attention(
                query_split_per_block, key_split_per_block, value_split_per_block,
                softmax_scale=scale, normalize=False)

            attn_block = attn_block.reshape(batch_size, head_size, n_query, -1)
            lse_block = lse_block.reshape(batch_size, head_size, n_query, -1)

            # A single bucket already covers every (query, key) pair: nothing left to sample.
            has_residual = (n_query > query_bucket_size) and (n_key > key_bucket_size)
        else:
            attn_block = lse_block = None
            has_residual = n_key > 0

        # 2. Residual low-rank part via uniform sampling
        sample_size = self._clamped_size(
            n_key, self.sample_size_ratio, self.min_sample_size, self.max_sample_size)
        if sample_size > 0 and has_residual:
            # Sample key indices (positions in the sorted order) uniformly at random, with replacement
            sampled_set = torch.randint(n_key, size=(batch_size, head_size, sample_size), device=query.device)

            if use_blocks:
                ## Compute mask for hiding A_ij computed in block-diagonal attention
                offset_n = rearrange(torch.arange(n_query, device=query.device), 'n -> 1 n 1')
                block_mask = (offset_n // query_bucket_size) == (sampled_set // key_bucket_size).view(-1, 1, sample_size)
                block_mask = block_mask.view(batch_size, head_size, -1, sample_size)
                block_mask = block_mask.to(query.dtype)
                # Additive bias: most negative finite value where hidden, zero elsewhere.
                block_mask *= torch.finfo(query.dtype).min
            else:
                block_mask = None

            # Each sampled key stands in for n_key / sample_size keys.
            weights = n_key / sample_size

            value_subset = indexing(value_sorted, sampled_set)
            key_subset = indexing(key_sorted, sampled_set)

            attn_res, lse_res = exact_attention(
                query_sorted, key_subset, value_subset, scale, False, block_mask, normalize=False)
            lse_res = lse_res + math.log(weights)

            ## Add two attentions
            if use_blocks:
                attn, lse = add_self_attentions(attn_block, lse_block, attn_res, lse_res)
            else:
                attn, lse = attn_res, lse_res
        elif use_blocks:
            attn, lse = attn_block, lse_block
        else:
            raise ValueError(
                "HyperAttention has nothing to compute: the block-diagonal part "
                f"(key_bucket_size={key_bucket_size}) and the sampled part "
                f"(sample_size={sample_size}, n_key={n_key}) are both disabled")

        ## Re-order rows with the inverse order for query_sorted -> query
        attn = indexing(attn, query_sort_idx_inv)
        lse = indexing(lse, query_sort_idx_inv)
        return attn, lse