from typing import Callable, Optional, Tuple, Dict
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

# Define the z_order function
def z_order(x):
    """Map every feature vector in ``x`` to an integer Z-order style code.

    Each feature is min-max normalised along dim 1, quantised onto a
    ``grid_size``-level integer grid, bit-spread with ``interleave_bits`` and
    OR-ed into the result after being shifted left by its feature index.

    Args:
        x: 3-D tensor; the code indexes it as ``x[:, :, i]`` and reduces
            min/max over dim 1.

    Returns:
        ``torch.int64`` tensor of shape ``x.shape[:2]`` on ``x.device``.
    """
    min_vals = x.min(dim=1, keepdim=True)[0]
    max_vals = x.max(dim=1, keepdim=True)[0]

    # A feature that is constant along dim 1 (e.g. a length-1 sequence) has a
    # zero range; dividing by it gives NaN, and casting NaN to int64 produces
    # arbitrary codes. Use a denominator of 1 there so the feature maps to 0.
    value_range = max_vals - min_vals
    value_range = torch.where(value_range > 0, value_range, torch.ones_like(value_range))
    normalized_data = (x - min_vals) / value_range

    grid_size = 4096  # 2**12=4096 2**20=1048576 for million level long context learning
    discretized_data = (normalized_data * (grid_size - 1)).to(torch.int64)

    # NOTE(review): interleave_bits leaves a single zero bit between input
    # bits, so for more than two features the codes shifted by 0 and by 2
    # land on the same bit positions and are merged by the OR. Kept as-is to
    # preserve the existing ordering -- confirm whether this is intended.
    z = torch.zeros(x.shape[:2], dtype=torch.int64, device=x.device)
    for i in range(x.size(-1)):
        z |= interleave_bits(discretized_data[:, :, i]) << i
    return z

# Define the interleave_bits function
def interleave_bits(x):
    """Spread the bits of ``x`` apart, leaving a zero bit between neighbours.

    Input bit ``j`` ends up at bit ``2 * j`` of the result. Works on anything
    that supports ``|``, ``<<`` and ``&`` with Python ints (plain ints or
    integer tensors).
    """
    # Each stage copies the value shifted left and keeps only the bit groups
    # selected by the mask, halving the group width every time.
    spread_stages = (
        (16, 0x0000FFFF0000FFFF),
        (8, 0x00FF00FF00FF00FF),
        (4, 0x0F0F0F0F0F0F0F0F),
        (2, 0x3333333333333333),
        (1, 0x5555555555555555),
    )
    for shift, keep_mask in spread_stages:
        x = (x | (x << shift)) & keep_mask
    return x

class MLP(nn.Module):
    """Two-layer perceptron: ``fc2(activation(fc1(x)))``.

    The hidden width is ``d_model * hidden_mult``. When ``return_residual`` is
    true, ``forward`` returns ``(output, input)`` instead of just the output.
    Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        d_model: int,
        out_dim: int,
        hidden_mult: int=1,
        activation: callable=F.elu,
        return_residual: bool=False,
        **kwargs
    ):
        super().__init__()
        hidden_width = d_model * hidden_mult
        self.return_residual = return_residual
        # Layers are created in fc1 -> fc2 order so parameter initialisation
        # consumes the RNG in the same sequence as before.
        self.fc1 = nn.Linear(d_model, hidden_width)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_width, out_dim)

    def forward(self, x):
        """Apply the MLP to ``x``; optionally also hand back ``x`` itself."""
        out = self.fc2(self.activation(self.fc1(x)))
        if self.return_residual:
            return out, x
        return out

class OneDAttention(nn.Module):
    """Sparse attention that picks keys through a Z-order code of small k/q projections.

    Keys and queries are ``kq_dim``-dimensional per head (produced by one MLP).
    For every query position the layer looks up a window of ``2 * k`` keys
    around the query's insertion point in a chunk-wise sorted list of Z-order
    codes, and adds one extra "global" slot holding the running mean of keys
    and values. Weights are ``1 / (epsilon + mean squared distance / 8)``
    normalised to sum to one.

    Args:
        config: model config; this class reads ``num_attention_heads``,
            ``hidden_size``, ``attention_bias``, ``attention_dropout``,
            ``rope_scaling``, ``max_position_embeddings`` and
            ``rotary_emb_base``.
        onedattn_config: reads ``eps``, ``k`` and ``num_chunks``.

    Raises:
        ValueError: if ``hidden_size`` is not a multiple of the head count, or
            the RoPE scaling type is unknown.

    Note:
        Within this class the rotary embedding is constructed but not applied,
        and ``attention_dropout`` / ``masked_bias`` are created but never used.
    """

    def __init__(self, config, onedattn_config):
        super().__init__()
        self.config = config
        self.n_heads = config.num_attention_heads
        self.d_model = config.hidden_size
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                "The hidden size is not divisible by the number of attention heads! Make sure to update them"
            )
        self.head_dim = self.d_model // self.n_heads
        self.rotary_ndims = 2  # int(self.head_size * config.rotary_pct)
        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
        self._init_rope()

        # Per-head key/query width; Wkq emits k and q for every head at once.
        self.kq_dim = 3
        self.Wv = nn.Linear(
            self.d_model, self.d_model, bias=config.attention_bias
        )
        self.Wkq = MLP(self.d_model, 2 * self.kq_dim * self.n_heads)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
        self.attention_dropout = nn.Dropout(config.attention_dropout)
        self.is_causal = True

        self.causal = self.is_causal
        self.eps = onedattn_config.eps
        self.k = onedattn_config.k
        if self.causal:
            self.num_chunk = onedattn_config.num_chunks
        else:
            self.num_chunk = 1
        self.num_heads = self.n_heads
        # Learnable scalar; epsilon in the score is eps * sigmoid(gamma_sq).
        self.gamma_sq = nn.Parameter(torch.rand(1))

    def _init_bias(self, max_positions, device=None):
        """Register a lower-triangular boolean ``bias`` buffer of shape (1, 1, P, P)."""
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        if device is not None:
            self.bias = self.bias.to(device)

    def _init_rope(self):
        """Create ``self.rotary_emb`` according to ``config.rope_scaling``."""
        if self.config.rope_scaling is None:
            self.rotary_emb = GPTNeoXRotaryEmbedding(
                self.rotary_ndims, self.config.max_position_embeddings, base=self.config.rotary_emb_base
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = GPTNeoXLinearScalingRotaryEmbedding(
                    self.rotary_ndims,
                    self.config.max_position_embeddings,
                    base=self.config.rotary_emb_base,
                    scaling_factor=scaling_factor,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = GPTNeoXDynamicNTKScalingRotaryEmbedding(
                    self.rotary_ndims,
                    self.config.max_position_embeddings,
                    base=self.config.rotary_emb_base,
                    scaling_factor=scaling_factor,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: torch.FloatTensor,
            position_ids: torch.LongTensor,
            head_mask: Optional[torch.FloatTensor] = None,
            layer_past: Optional[Tuple[torch.Tensor]] = None,
            use_cache: Optional[bool] = False,
            output_attentions: Optional[bool] = False
    ):
        """Run the attention layer.

        ``attention_mask`` and ``head_mask`` are forwarded to ``_attn``, which
        does not use them.

        Returns:
            ``(attn_output, present)`` and, when ``output_attentions`` is true,
            additionally the attention weights returned by ``_attn``.
        """
        # Apply attention-specific projections and rope
        q, k, v, present = self._attn_projections_and_rope(
            hidden_states=hidden_states, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache
        )
        # The projection helper returns [batch, heads, seq, dim]; _attn wants
        # [batch, seq, heads, dim].
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        # Compute attention
        attn_output, attn_weights = self._attn(q, k, v, attention_mask, head_mask)
        attn_output = rearrange(attn_output, "... h d -> ... (h d)")
        # _attn pads with default-dtype zeros, so for half-precision inputs its
        # result is promoted to float32; cast back so the output projection
        # receives the dtype its weights expect (no-op for float32 models).
        attn_output = self.dense(attn_output.to(hidden_states.dtype))

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs

    def _attn_projections_and_rope(
        self,
        hidden_states: torch.FloatTensor,
        position_ids: torch.LongTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
    ):
        """Project hidden states to per-head query, key and value tensors.

        Despite the name, no rotary embedding is applied here and
        ``position_ids`` is unused.

        Returns:
            ``(query, key, value, present)`` with query/key/value laid out as
            [batch, heads, seq, dim]; ``present`` is ``(key, value)`` when
            ``use_cache`` is true, else ``None``.
        """
        has_layer_past = layer_past is not None

        # [batch, seq, hidden] -> kq: [batch, seq, heads * 2 * kq_dim]
        #                         v:  [batch, seq, heads * head_dim]
        kq, v = self.Wkq(hidden_states), self.Wv(hidden_states)

        # Split the flat feature axis into (heads, per-head features).
        v = rearrange(
            v, "... (h d) -> ... h d", d=self.head_dim
        )
        kq = rearrange(
            kq, "... (h d) -> ... h d", h=self.num_heads
        )
        # First kq_dim features of every head are the key, the rest the query.
        k = kq[..., :self.kq_dim]
        q = kq[..., self.kq_dim:]

        # [batch, seq, heads, dim] -> [batch, heads, seq, dim]
        query = q.permute(0, 2, 1, 3)
        key = k.permute(0, 2, 1, 3)
        value = v.permute(0, 2, 1, 3)

        # Cache QKV values
        # NOTE(review): with layer_past the key/value sequence becomes longer
        # than the query sequence, while _attn concatenates k and q along the
        # feature axis -- presumably cached decoding is unsupported; verify.
        if has_layer_past:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        present = (key, value) if use_cache else None

        return query, key, value, present

    def chunk_causal_sort(self, hash_vals):
        """Build, per chunk, the sorted hash values of all positions that chunk may see.

        Chunk ``c`` covers positions ``< (c + 1) * chunk_size``. Row ``c`` of
        the first result holds the hash values of those positions in ascending
        order, left-aligned and padded with ``1e12``; row ``c`` of the second
        result holds the matching original positions, padded with ``-seq_len``.

        Reads ``self.chunk_size``, which ``_attn`` sets before calling this.

        Args:
            hash_vals: tensor whose first two dims are (batch, seq_len); it is
                sorted along dim 1.

        Returns:
            ``(sorted_values, original_indices, chunk_positions)`` with shapes
            (batch, n_chunks, seq_len), (batch, n_chunks, seq_len) and
            (1, n_chunks, 1), where ``n_chunks = seq_len // chunk_size``.
        """
        chunk_size = self.chunk_size
        batch_size, seq_len = hash_vals.shape[:2]
        n_chunks = seq_len // chunk_size
        device = hash_vals.device

        # Sort the hash_vals tensor along the second dimension
        hash_vals_sorted, sorted_indices = torch.sort(hash_vals, dim=1)

        # Initialize the result tensors with padding sentinels
        hash_vals_causal = torch.full(
            (batch_size, n_chunks, seq_len), 1e12, dtype=hash_vals.dtype, device=device
        )
        hash_vals_causal_original_index = torch.full(
            (batch_size, n_chunks, seq_len), -seq_len, dtype=sorted_indices.dtype, device=device
        )

        indices = torch.arange(seq_len, device=device)
        # Exclusive upper position bound of every chunk: chunk_size, 2*chunk_size, ...
        chunk_positions = torch.arange(1, n_chunks + 1, device=device) * chunk_size
        chunk_bounds = chunk_positions.unsqueeze(0).unsqueeze(-1).detach()

        # Which sorted entries belong to each chunk's visible prefix, e.g.
        # sorted index [2, 3, 1, 0] -> [1, 0], [2, 3, 1, 0]
        mask = sorted_indices.unsqueeze(1).expand(-1, n_chunks, -1) < chunk_bounds
        # Which output slots receive them, e.g. [0, 1, 2, 3] -> [0, 1], [0, 1, 2, 3]
        mask2 = indices.unsqueeze(0).unsqueeze(0).expand(batch_size, n_chunks, -1) < chunk_bounds

        # Both masks select the same number of entries per row, in row-major
        # order, so the selected values fill the left-aligned slots in order.
        hash_vals_causal[mask2] = hash_vals_sorted.unsqueeze(1).expand(-1, n_chunks, -1)[mask].detach()
        hash_vals_causal_original_index[mask2] = sorted_indices.unsqueeze(1).expand(-1, n_chunks, -1)[mask].detach()

        del hash_vals_sorted, sorted_indices, mask, mask2
        return hash_vals_causal.detach(), hash_vals_causal_original_index.detach(), chunk_bounds

    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        """Z-order windowed attention.

        Args:
            q, k: [batch, seq, heads, kq_dim].
            v: [batch, seq, heads, head_dim].
            attention_mask, head_mask: accepted for interface compatibility;
                not used.

        Returns:
            ``(output, attention)``: output is [batch, seq, heads, head_dim];
            attention is [batch, seq, heads, 2 * k + 1, 1], the last slot being
            the global running-mean entry.

        Side effects:
            Sets ``self.chunk_size`` (read by ``chunk_causal_sort``).
        """
        batch_size, seq_len, num_head, feat_dim = v.shape
        num_attn = 2 * self.k
        self.chunk_size = math.ceil(seq_len / self.num_chunk)
        original_seq_len = seq_len
        # Pad the sequence so it splits into num_chunk equal chunks.
        seq_len = self.num_chunk * self.chunk_size
        pad_len = seq_len - original_seq_len

        kq = torch.cat([k, q], dim=-1)
        kq = torch.cat(
            [kq, torch.zeros([batch_size, pad_len, num_head, kq.shape[-1]], device=v.device)], dim=1
        )
        v = torch.cat([v, torch.zeros([batch_size, pad_len, num_head, feat_dim], device=v.device)], dim=1)
        kq = rearrange(
            kq, "b s h d -> (b h) s d", h=num_head
        )
        kq_dim = kq.shape[-1] // 2
        k = kq[:, :, :kq_dim]
        q = kq[:, :, kq_dim:]

        num_z_order = 1
        num_b_h = batch_size * num_head
        rows = num_b_h * num_z_order

        # Z-order codes are computed without gradient, jointly for k and q so
        # both share one min-max normalisation.
        kq_z_order = kq.detach().reshape(num_b_h, -1, kq_dim)
        kq_z_order = z_order(kq_z_order).view(rows, seq_len, -1)
        # NOTE(review): kq is laid out as [k | q] on the last axis, so after
        # the reshape above slot 0 carries the code of the k half and slot 1
        # the code of the q half -- the opposite of what these names say.
        # Kept unchanged because swapping would alter trained behaviour;
        # confirm which pairing is intended.
        q_z_order = kq_z_order[:, :, 0]
        k_z_order = kq_z_order[:, :, 1]

        hash_vals_causal, hash_vals_causal_original_index, chunk_positions = self.chunk_causal_sort(k_z_order)
        q_z_order = q_z_order.view(rows, seq_len // self.chunk_size, self.chunk_size)
        # Insertion point of every query code in its chunk's sorted list:
        # [rows, num_chunk, chunk_size]
        indices_to_keys_val = torch.searchsorted(hash_vals_causal, q_z_order.contiguous())
        # Window of 2*k sorted slots around the insertion point.
        window = torch.arange(-self.k, self.k, device=v.device).reshape(1, 1, -1)
        indices_to_keys_range_idx = indices_to_keys_val.view(rows, seq_len, 1) + window
        # Number of valid (non-padding) sorted slots for each query position.
        chunk_range = chunk_positions.expand(-1, -1, self.chunk_size).reshape(1, seq_len, 1).detach()

        out_range_mask = (indices_to_keys_range_idx < 0) | (indices_to_keys_range_idx > chunk_range - 1)
        # Clamp so the gather below stays in bounds; clamped slots are masked.
        indices_to_keys_range_idx = torch.clamp(indices_to_keys_range_idx, min=0)
        indices_to_keys_range_idx = torch.clamp(indices_to_keys_range_idx, max=seq_len - 1)

        # Translate sorted-slot indices back to original sequence positions.
        attended_indices = torch.gather(
            hash_vals_causal_original_index,
            -1,
            indices_to_keys_range_idx.view(rows, self.num_chunk, self.chunk_size * num_attn),
        ).view(rows, self.num_chunk * self.chunk_size, num_attn)
        causal_range = torch.arange(seq_len, device=v.device).reshape(1, -1, 1)
        if self.causal:
            # Negative index = padding sentinel; index > own position = future.
            mask = (attended_indices < 0) | (attended_indices > causal_range)
        else:
            mask = (attended_indices < 0)
        attended_indices[mask] = 0
        mask = mask | out_range_mask
        v = rearrange(
            v, "b s h d -> (b h) s d"
        )

        # Global slot: running mean of keys / values up to each position.
        avg_denom = 1 / (1 + torch.arange(seq_len, device=v.device).view(1, seq_len, 1))
        k_cum_avg = torch.cumsum(k, dim=1) * avg_denom  # (B * H), S, kq_dim
        v_cum_avg = torch.cumsum(v, dim=1) * avg_denom  # (B * H), S, d  # TODO: 1. without this avg 2. -dist, 3. -log dist
        global_scores = torch.mean((q.unsqueeze(2) - k_cum_avg.unsqueeze(2)) ** 2, dim=-1) / float(8)

        v = v.unsqueeze(1).expand(-1, num_z_order, -1, -1).reshape(rows, seq_len, feat_dim)
        k = k.unsqueeze(1).expand(-1, num_z_order, -1, -1).reshape(rows, seq_len, kq_dim)

        flat_indices = attended_indices.view(rows, -1, 1)
        attended_k = torch.gather(k, 1, flat_indices.expand(-1, -1, kq_dim)).view(rows, seq_len, num_attn, -1)
        attended_v = torch.gather(v, 1, flat_indices.expand(-1, -1, feat_dim)).view(rows, seq_len, num_attn, -1)

        epsilon = self.eps * torch.sigmoid(self.gamma_sq)
        q = q.unsqueeze(1).expand(-1, num_z_order, -1, -1).reshape(rows, seq_len, kq_dim)
        attended_scores = torch.mean((q.unsqueeze(2) - attended_k) ** 2, dim=-1) / float(8)  # (B*H), S, 2k
        # Masking is done by adding a large distance through a separate tensor
        # rather than writing into attended_scores in place.
        mask_to_add = torch.zeros_like(attended_scores)
        mask_to_add[mask] = 10000.0
        mask_to_add = rearrange(mask_to_add, "(bh nz) s t-> bh s (t nz)", bh=num_b_h, nz=num_z_order)
        # The global slot is never masked.
        mask_to_add = torch.cat(
            [
                mask_to_add,
                torch.zeros([mask_to_add.shape[0], mask_to_add.shape[1], 1], dtype=v.dtype, device=v.device),
            ],
            dim=-1,
        )

        attended_scores = rearrange(attended_scores, "(bh nz) s t-> bh s (t nz)", bh=num_b_h, nz=num_z_order)
        # Inverse-distance kernel: small distance -> large weight.
        scores = 1 / (epsilon + torch.cat([attended_scores, global_scores], dim=2) + mask_to_add)

        attended_v = rearrange(attended_v, "(bh nz) s t d-> bh s (t nz) d", bh=num_b_h, nz=num_z_order, d=feat_dim)

        # delete all the index and mask tensor to free gpu memory
        del kq_z_order, q_z_order, k_z_order
        del hash_vals_causal, hash_vals_causal_original_index, chunk_positions
        del indices_to_keys_val, indices_to_keys_range_idx, chunk_range, out_range_mask
        del attended_indices, flat_indices, mask, mask_to_add

        attended_v = torch.cat([attended_v, v_cum_avg.unsqueeze(2)], dim=2)
        attention = scores / torch.sum(scores.squeeze(-1), dim=-1, keepdim=True)
        attention = rearrange(
            attention, "(b h) l k -> b l h k 1", b=batch_size, h=self.num_heads, l=seq_len, k=2 * self.k * num_z_order + 1
        )
        attended_v = rearrange(
            attended_v, "(b h) l ...  -> b l h ...", b=batch_size, h=self.num_heads
        )
        output = torch.sum(attended_v * attention, dim=3)
        # Drop the padded tail positions again.
        return output[:, :original_seq_len], attention[:, :original_seq_len]

        






class GPTNeoXRotaryEmbedding(nn.Module):
    """Cached cos/sin tables for rotary position embeddings.

    ``inv_freq`` holds ``1 / base ** (2i / dim)`` for ``i`` in
    ``range(dim // 2)``. The tables have one row per position and ``dim``
    columns (the frequency block repeated twice) and are rebuilt when a longer
    sequence is requested.
    """

    # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build ``cos_cached`` / ``sin_cached`` for ``seq_len`` positions.

        ``dtype`` is accepted for signature compatibility but not used; the
        tables take the dtype of ``inv_freq``.
        """
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the first ``seq_len`` rows of the cos and sin tables.

        Args:
            x: tensor used for its device/dtype when the cache must grow;
                documented layout is [bs, num_attention_heads, seq_len, head_size].
            seq_len: number of positions needed. When omitted it is taken from
                ``x.shape[-2]`` (previously ``None`` raised ``TypeError`` in
                the comparison below).
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len],
            self.sin_cached[:seq_len],
        )


# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__
# TODO @gante bring compatibility back
class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
    """GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must be set before the parent constructor runs, because that
        # constructor already calls _set_cos_sin_cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin tables with positions divided by ``scaling_factor``."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=torch.int64).type_as(self.inv_freq)
        scaled_positions = positions / self.scaling_factor

        angles = torch.outer(scaled_positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        for buffer_name, values in (("cos_cached", table.cos()), ("sin_cached", table.sin())):
            self.register_buffer(buffer_name, values, persistent=False)


class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
    """GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    # copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__
    # TODO @gante no longer copied from
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Needed by _set_cos_sin_cache, which the parent constructor calls.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin tables, enlarging the base once ``seq_len`` exceeds the configured maximum."""
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # NOTE(review): the exponent divides by (dim - 2), so dim == 2
            # raises ZeroDivisionError on this path -- confirm callers never
            # grow the cache with dim == 2.
            stretch = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            base = self.base * stretch ** (self.dim / (self.dim - 2))
            exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim
            inv_freq = 1.0 / (base ** exponents)
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(seq_len, device=device, dtype=torch.int64).type_as(self.inv_freq)

        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        for buffer_name, values in (("cos_cached", table.cos()), ("sin_cached", table.sin())):
            self.register_buffer(buffer_name, values, persistent=False)


def rotate_half(x):
    """Split the last dim at its midpoint and return ``cat(-second_part, first_part)``."""
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((back.neg(), front), dim=-1)


# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Rotate query and key tensors with rotary position embeddings.

    The rows of ``cos`` / ``sin`` selected by ``position_ids`` get one extra
    axis inserted at ``unsqueeze_dim`` and are then combined with each input
    as ``t * cos + rotate_half(t) * sin``.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): Cosine table, indexed by position.
        sin (`torch.Tensor`): Sine table, indexed by position.
        position_ids (`torch.Tensor`):
            Positions of the tokens in ``q`` and ``k``; may carry an offset
            when a KV-cache is in use.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into the gathered tables so they broadcast against
            ``q`` and ``k``: use 1 for inputs laid out as
            [batch_size, heads, seq_len, head_dim] and 2 for
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_sel) + (rotate_half(t) * sin_sel)

    return _rotate(q), _rotate(k)

