from re import A
from typing import Final, Optional, Type

import torch
from torch import nn as nn
from torch.nn import functional as F
import einops


from .contextualization import Contextualization


def create_sliding_window_matrix(n: int, m: int, device) -> torch.Tensor:
    """
    Build an (n, m) integer matrix of sliding-window positions.

    Row ``i`` is ``[i - m/2 + 1, i - m/2 + 2, ..., i + m/2]``, i.e. every row is
    the row above it shifted by one.

    Args:
        n: The number of rows.
        m: The number of columns. Must be an even number.
        device: Device on which the result is allocated.

    Returns:
        A torch.Tensor of shape (n, m) with the window positions (int64).

    Raises:
        ValueError: If ``m`` is odd.
    """
    if m % 2:
        raise ValueError("m must be an even integer.")

    first = 1 - m // 2
    # Offsets of one window relative to its row index: shape [m].
    window_offsets = torch.arange(first, first + m, device=device)
    # Row indices as a column vector: shape [n, 1].
    row_starts = torch.arange(n, device=device)[:, None]
    # Broadcasting [m] against [n, 1] gives the full (n, m) matrix.
    return window_offsets + row_starts


class InternalMemory(nn.Module):
    """Learned product-key memory read by per-head "conceptual" search patterns.

    Each of the ``num_memory_cells`` slots stores a concatenated
    ``(query, key, value)`` row of width ``3 * head_dim`` in ``self.qkv``.
    Slots are addressed with product keys: the first half of a search pattern
    is scored against ``sqrt_num_memory_cells`` static keys, the second half
    against another ``sqrt_num_memory_cells`` keys, and a pair of winning
    indices ``(i1, i2)`` addresses slot ``i1 * sqrt_num_memory_cells + i2``.

    Args:
        dim: Model width. NOTE(review): not used by this class; kept so the
            constructor signature stays stable.
        num_heads: Number of attention heads.
        head_dim: Per-head width; must be even (it is split into two halves).
        num_memory_cells: Requested number of slots. Rounded down to the
            nearest perfect square, which is what ``self.num_memory_cells``
            then holds.
        conceptual_representation_size: Number of learned search queries per
            head, i.e. how many search patterns (``m``) are built per head.
        topk: Candidates kept per key half, and number of slots finally mixed
            for each search pattern. Must not exceed ``sqrt(num_memory_cells)``.

    Raises:
        ValueError: If ``topk`` exceeds the number of keys per product-key half.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        head_dim: int = 64,
        num_memory_cells: int = 64,
        conceptual_representation_size: int = 8,
        topk: int = 8,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Product keys need a square slot grid, so non-square sizes are rounded down.
        self.sqrt_num_memory_cells = int(num_memory_cells ** 0.5)
        self.num_memory_cells = self.sqrt_num_memory_cells ** 2
        self.conceptual_representation_size = conceptual_representation_size
        self.topk = topk

        # asserts and checks
        assert head_dim % 2 == 0, "Head dimension must be even for memory attention"
        if topk > self.sqrt_num_memory_cells:
            # Without this check torch.topk would fail much later, inside forward().
            raise ValueError(
                f"topk ({topk}) must not exceed the number of keys per product-key half "
                f"({self.sqrt_num_memory_cells} for {self.num_memory_cells} memory cells)"
            )

        # One learned search query per (head, concept): [h, m, head_dim].
        self.internal_search_pattern_query = nn.Parameter(
            torch.empty(num_heads, self.conceptual_representation_size, self.head_dim),
        )

        # Two stacked key tables (rows [:sqrt] and [sqrt:]), one per half of a pattern.
        self.memory_static_keys = nn.Parameter(
            torch.empty(2 * self.sqrt_num_memory_cells, self.head_dim // 2),
        )

        # Slot contents: concatenated (query, key, value) per memory cell.
        self.qkv = nn.Embedding(self.num_memory_cells, (self.head_dim * 3))

        # Initialize the parameters
        self.reset_parameters()

    def get_trainable_params(self):
        """Return the memory's learned tensors (search queries, static keys, slot table)."""
        return [
            self.internal_search_pattern_query,
            self.memory_static_keys,
            self.qkv.weight,
        ]

    def reset_parameters(self):
        """Re-initialise all parameters with scale ``head_dim ** -0.5``."""
        bound = self.head_dim ** -0.5
        nn.init.uniform_(self.internal_search_pattern_query, -bound, bound)
        nn.init.uniform_(self.memory_static_keys, -bound, bound)
        nn.init.normal_(self.qkv.weight, 0, bound)

    def build_mixed_search_patterns(
        self,
        base_search_pattern_keys,  # shape [B, h, t, d]
        base_search_patterns,  # shape [B, h, t, d]
        kv_padding_mask: Optional[torch.Tensor] = None,
    ):
        """Pool the per-token search patterns into ``m`` patterns per head.

        Each learned query in ``internal_search_pattern_query`` attends over the
        tokens (scaled dot product against ``base_search_pattern_keys``) and mixes
        the corresponding ``base_search_patterns``.

        Args:
            base_search_pattern_keys: Tensor of shape [B, h, t, d].
            base_search_patterns: Tensor of shape [B, h, t, d].
            kv_padding_mask: Optional boolean tensor of shape [B, t]; ``True``
                positions are excluded from the attention.

        Returns:
            Tensor of shape [B, h, m, d]. A sequence whose positions are all
            masked yields all-zero patterns instead of NaN.
        """
        scores = torch.einsum(
            'bhtd,hmd->bhmt', base_search_pattern_keys, self.internal_search_pattern_query
        ) * self.head_dim ** -0.5

        mask = None
        if kv_padding_mask is not None:
            # [B, t] -> [B, 1, 1, t] so it broadcasts over heads and concepts.
            mask = kv_padding_mask.unsqueeze(1).unsqueeze(1)
            scores = scores.masked_fill(mask, float('-inf'))
        scores = scores.softmax(dim=-1)  # [B, h, m, t]
        if mask is not None:
            # A fully padded row is a softmax over only -inf values, i.e. NaN.
            # Zeroing the padded weights fixes that case; for partially padded rows
            # those weights are already exactly 0, so this is a no-op there.
            scores = scores.masked_fill(mask, 0.0)

        return torch.einsum('bhmt,bhtd->bhmd', scores, base_search_patterns)

    def gumbel_max_sample_from_probs(self, probs):
        """Draw one index per row of ``probs`` with the Gumbel-max trick.

        Not used elsewhere in this class.
        """
        # Ensure no zeros in log (add small epsilon if needed)
        eps = 1e-10
        logits = torch.log(probs + eps)
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(probs)))
        return (logits + gumbel_noise).argmax(dim=-1)

    def search_memory(
        self,
        search_patterns,  # shape [B, h, m, d]
    ):
        """Read the memory with a product-key top-k lookup.

        Args:
            search_patterns: Tensor of shape [B, h, m, head_dim], with
                ``m == conceptual_representation_size``.

        Returns:
            ``(queries, keys, values)``, each of shape [B, h, m, head_dim]. For
            every search pattern they are the softmax-weighted sum of the
            ``topk`` best-scoring slots' stored rows.
        """
        batch_size = search_patterns.shape[0]
        half = self.head_dim // 2
        side = self.sqrt_num_memory_cells

        # One row per (batch, head, concept) pattern: [N, d] with N = B*h*m.
        flat = search_patterns.reshape(-1, search_patterns.shape[-1])

        # Score each half of every pattern against its own key table: [N, side].
        scores1 = torch.einsum('nd,sd->ns', flat[:, :half], self.memory_static_keys[:side])
        scores2 = torch.einsum('nd,sd->ns', flat[:, half:], self.memory_static_keys[side:])
        top_scores1, top_idx1 = torch.topk(scores1, self.topk, dim=-1)  # [N, topk]
        top_scores2, top_idx2 = torch.topk(scores2, self.topk, dim=-1)  # [N, topk]

        # Cartesian product of the per-half candidates, flattened row-major:
        # entry i * topk + j pairs candidate i of half 1 with candidate j of half 2.
        cand_scores = (top_scores1.unsqueeze(-1) + top_scores2.unsqueeze(-2)).flatten(1)
        cand_slots = (top_idx1.unsqueeze(-1) * side + top_idx2.unsqueeze(-2)).flatten(1)

        # Keep the best topk combined candidates and normalise their scores.
        best_scores, best = torch.topk(cand_scores, self.topk, dim=-1)  # [N, topk]
        weights = best_scores.softmax(dim=-1)  # [N, topk]
        slots = cand_slots.gather(dim=-1, index=best)  # [N, topk]

        # Weighted mix of the selected slots' (q, k, v) rows: [N, 3 * head_dim].
        qkvs = torch.einsum('nk,nkd->nd', weights, self.qkv(slots))
        qkvs = qkvs.reshape(
            batch_size, self.num_heads, self.conceptual_representation_size, 3 * self.head_dim
        )
        # split into queries, keys and values
        queries, keys, values = torch.split(qkvs, 3 * [self.head_dim], dim=-1)
        return queries, keys, values

    def forward(
        self,
        base_search_pattern_keys,  # shape [B, h, t, d]
        base_search_patterns,  # shape [B, h, t, d]
        kv_padding_mask: Optional[torch.Tensor] = None,
    ):
        """Build the per-head search patterns and read the memory with them.

        Returns:
            ``(queries, keys, values)``, each of shape [B, h, m, head_dim].
        """
        search_patterns = self.build_mixed_search_patterns(
            base_search_pattern_keys,
            base_search_patterns,
            kv_padding_mask=kv_padding_mask,
        )  # shape [B, h, m, d]

        return self.search_memory(search_patterns)


class Manar(nn.Module):
    """Attention block that augments token attention with concepts read from memory.

    Per head, an :class:`InternalMemory` yields ``conceptual_representation_size``
    memory cells (query/key/value triples). Each cell attends jointly to itself and
    to every input token (``_concatenated_attention`` with a window of 1), which
    produces the "retrieved concepts". Token outputs are then computed by the
    project-level ``Contextualization`` autograd function. It receives the token
    q/k/v, the concept keys (``self.concept_keys`` applied to the concepts), the
    concepts themselves and ``context_window_len``.
    NOTE(review): what ``Contextualization`` computes is defined outside this class.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            num_memory_cells: int = 64,
            conceptual_representation_size: int = 8,
            context_window_len: int = 64,
            qkv_bias: bool = False,
            search_patterns_bias: bool = False,
            qk_norm: bool = False,
            scale_norm: bool = False,
            proj_bias: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: Optional[Type[nn.Module]] = None,
            seperate_qkv: bool = False,
    ) -> None:
        """Initialize the Manar module.

        Args:
            dim: Input dimension of the token embeddings.
            num_heads: Number of attention heads; must divide ``dim``.
            num_memory_cells: Number of slots of the internal memory (rounded down
                to a perfect square by ``InternalMemory``).
            conceptual_representation_size: Number of memory-derived concepts per head.
            context_window_len: Window length forwarded to ``Contextualization``;
                must be even.
            qkv_bias: Whether to use bias in the query, key, value projections.
            search_patterns_bias: Whether to use bias in the search-pattern projection.
            qk_norm: Whether to apply normalization to query and key vectors.
            scale_norm: Whether to build ``self.norm``. NOTE(review): ``self.norm``
                is never applied in ``forward``; confirm whether that is intended.
            proj_bias: Whether to use bias in the output projection.
            attn_drop: NOTE(review): accepted but unused; no attention dropout is applied.
            proj_drop: Dropout rate applied after the output projection.
            norm_layer: Normalization layer constructor, required when ``qk_norm`` or
                ``scale_norm`` is enabled.
            seperate_qkv: Use three separate q/k/v projections instead of a fused one.
        """
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        if qk_norm or scale_norm:
            assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
        assert context_window_len % 2 == 0
        self.num_heads = num_heads
        self.dim = dim
        self.num_memory_cells = num_memory_cells
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.context_window_len = context_window_len  # TODO: fix it to make it dynamic

        # All-True placeholder passed to Contextualization when no padding mask is given
        # (always together with a False "has padding" flag). As a non-persistent buffer it
        # follows the module through .to()/.cuda() and stays out of the state_dict.
        # Previously it was a plain tensor pinned to whichever device existed at construction.
        self.register_buffer(
            "dummy_padding", torch.ones(512, 512, dtype=torch.bool), persistent=False,
        )

        self.memory = InternalMemory(
            dim=dim,
            num_heads=num_heads,
            head_dim=self.head_dim,
            num_memory_cells=num_memory_cells,
            conceptual_representation_size=conceptual_representation_size,
        )

        self.seperate_qkv = seperate_qkv
        if self.seperate_qkv:
            self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
            self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
            self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        else:
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Projects tokens to (search-pattern keys, search patterns, value keys).
        self.search_patterns = nn.Linear(dim, dim * 3, bias=search_patterns_bias)
        self.concept_keys = nn.Linear(self.head_dim, self.head_dim)

        # norm layers
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.norm = norm_layer(dim) if scale_norm else nn.Identity()

        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        # Cached window mask for _concatenated_attention (rebuilt when its shape/device changes).
        self.mask = None

    def get_trainable_params(self):
        """Return the search-pattern, concept-key and memory weights."""
        return [
            self.search_patterns.weight,
            self.concept_keys.weight,
            *self.memory.get_trainable_params(),
        ]

    def _pad_and_unfold(self, x: torch.Tensor, context_window_len: int, pad_value: int = 0, pad_dim: int = -1) -> torch.Tensor:
        """Pad and unfold ``x`` along ``pad_dim`` into sliding windows.

        Position ``i`` receives the window ``[i - cwl/2 + 1, ..., i + cwl/2]``
        (``cwl // 2 - 1`` pad values on the left, ``cwl // 2`` on the right). The
        window becomes a new trailing dimension of size ``context_window_len``.
        With ``context_window_len == 1`` a trailing dimension of size 1 is added.
        """
        assert context_window_len > 0, "Context window length must be positive"
        if context_window_len > 1:
            assert context_window_len % 2 == 0, "Context window length must be either 1 or even for symmetric padding"
            if pad_dim < 0:
                pad_dim = x.ndim + pad_dim
            # F.pad lists (left, right) pairs starting from the LAST dim, so emit
            # zero pairs for every dim after pad_dim first.
            zeros_to_append = 2 * (x.ndim - pad_dim - 1)
            pad = (0,) * zeros_to_append + (context_window_len // 2 - 1, context_window_len // 2)
            x = F.pad(x, pad, value=pad_value)
            x = x.unfold(
                dimension=pad_dim,
                size=context_window_len,
                step=1
            )
        else:
            x = x.unsqueeze(-1)  # if context_window_len is 1, just add a dimension
        return x  # shape (..., cwl)

    def _make_cat_compatible(self, A: torch.Tensor, B: torch.Tensor):
        """Broadcast all but the last dimension of ``A`` and ``B`` to a common shape.

        The last dimensions are kept as-is so the results can be concatenated
        along ``dim=-1``. Returns expanded views (no copies).
        """
        shared = torch.broadcast_shapes(A.shape[:-1], B.shape[:-1])
        return A.expand(*shared, A.shape[-1]), B.expand(*shared, B.shape[-1])

    def create_mask(self, score: torch.Tensor):
        """Build the padding mask for window scores of shape (..., n, cwl).

        Window slot ``(i, j)`` refers to position ``i - cwl/2 + 1 + j``. It is
        masked (``True``) when that position falls outside ``[0, n)``, i.e. on
        the padding added by ``_pad_and_unfold``. The mask is returned with
        leading singleton dims so it broadcasts against ``score``.
        """
        sequence_length = score.shape[-2]
        positions = create_sliding_window_matrix(sequence_length, score.shape[-1], device=score.device)  # (n, cwl)
        mask = (positions < 0) | (positions >= sequence_length)
        return mask.reshape((1,) * (score.ndim - 2) + mask.shape)  # (1, ..., 1, n, cwl)

    def _concatenated_attention(
        self,
        queries: torch.Tensor,  # tensor of shape (..., n, d)
        self_keys: torch.Tensor,  # tensor of shape (..., n, d)
        self_values: torch.Tensor,  # tensor of shape (..., n, d)
        external_keys: torch.Tensor,  # tensor of shape (..., m, d)
        external_values: torch.Tensor,  # tensor of shape (..., m, d)
        external_kv_padding_mask: Optional[torch.Tensor] = None,  # tensor of shape (..., m)
        context_window_len: int = 1,  # length of the context window for local attention
    ):
        """Attend jointly over a local window of ``self`` keys and all ``external`` keys.

        A single softmax is taken over the concatenation of the ``context_window_len``
        windowed self scores and the ``m`` external scores. The two parts of the
        weights then mix ``self_values`` and ``external_values`` respectively.
        All shapes must be broadcast compatible before being passed in.

        Returns:
            Tensor of shape (..., n, d).
        """
        self_values = self._pad_and_unfold(self_values, context_window_len, pad_dim=-2, pad_value=0)  # (..., n, d, cwl)
        self_keys = self._pad_and_unfold(self_keys, context_window_len, pad_dim=-2, pad_value=0)  # (..., n, d, cwl)
        self_score = einops.einsum(queries, self_keys, '... n d, ... n d cwl -> ... n cwl') * self.scale  # (..., n, cwl)
        if context_window_len > 1:
            # The window mask depends only on (n, cwl), rank and device; rebuild the
            # cached copy whenever any of those change instead of reusing a stale one.
            if (
                self.mask is None
                or self.mask.ndim != self_score.ndim
                or self.mask.shape[-2:] != self_score.shape[-2:]
                or self.mask.device != self_score.device
            ):
                self.mask = self.create_mask(self_score)  # (1, ..., 1, n, cwl)
            self_score = self_score.masked_fill(self.mask, float('-inf'))  # mask out padded window slots

        external_score = einops.einsum(queries, external_keys, '... n d, ... m d -> ... n m') * self.scale  # (..., n, m)
        if external_kv_padding_mask is not None:
            # e.g. (B, m) -> (B, 1, 1, m) to broadcast against external_score (B, h, n, m)
            mask = external_kv_padding_mask.unsqueeze(-2).unsqueeze(-2)
            external_score = external_score.masked_fill(mask, float('-inf'))

        # make self_score and external_score compatible for concatenation
        self_score, external_score = self._make_cat_compatible(self_score, external_score)  # (..., n, cwl) and (..., n, m)
        score = torch.cat([self_score, external_score], dim=-1)  # (..., n, cwl + m)
        score = score.softmax(dim=-1)
        self_score, external_score = torch.split(score, [context_window_len, external_keys.shape[-2]], dim=-1)

        # calculate the response
        self_response = einops.einsum(self_score, self_values, '... n cwl, ... n d cwl -> ... n d')  # (..., n, d)
        external_response = einops.einsum(external_score, external_values, '... n m, ... m d -> ... n d')  # (..., n, d)
        return self_response + external_response

    def forward(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
            kv_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the block on a batch of token embeddings.

        Args:
            x: Token embeddings of shape (B, N, C) with ``C == dim``.
            attn_mask: NOTE(review): accepted for interface compatibility but ignored.
            kv_padding_mask: Optional boolean mask of shape (B, N). ``True`` entries
                are filled with ``-inf`` in the memory-side attention scores and are
                forwarded to ``Contextualization``.

        Returns:
            Tensor of shape (B, N, dim), assuming ``Contextualization`` returns one
            (b, h, t, d) vector per token.
        """
        B, N, C = x.shape
        if self.seperate_qkv:
            q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
            k = self.k_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
            v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        else:
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        search_patterns = self.search_patterns(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        base_search_pattern_keys, base_search_patterns, value_keys = search_patterns.unbind(0)  # each (B, h, N, d)

        # Read the memory: one (query, key, value) triple per concept and head.
        mem_queries, mem_keys, mem_values = self.memory(
            base_search_pattern_keys,  # (b, h, t, d)
            base_search_patterns,  # (b, h, t, d)
            kv_padding_mask=kv_padding_mask
        )

        # Each memory cell attends to itself (window of 1) and to all input tokens.
        mem_retrieved_concepts = self._concatenated_attention(
            queries=mem_queries,  # (b, h, m, d)
            self_keys=mem_keys,  # (b, h, m, d)
            self_values=mem_values,  # (b, h, m, d)
            external_keys=value_keys,  # (b, h, t, d)
            external_values=v,  # (b, h, t, d)
            external_kv_padding_mask=kv_padding_mask,
            context_window_len=1
        )  # (b, h, m, d)

        has_padding = kv_padding_mask is not None
        response = Contextualization.apply(
            q, k, v,
            self.concept_keys(mem_retrieved_concepts),
            mem_retrieved_concepts,
            self.context_window_len,
            has_padding,
            kv_padding_mask if has_padding else self.dummy_padding,
        )

        response = einops.rearrange(response, 'b h t d -> b t (h d)')
        return self.proj_drop(self.proj(response))  # (b, t, d)

