import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

def split_heads(x, num_heads, head_dim):
    """
    Split the last (embedding) dimension of ``x`` into separate attention heads.

    The last dimension is reshaped into ``(num_heads, head_dim)``. Then the head
    axis is moved in front of the sequence/block-length axis.

    Args:
        x (tensor): Input tensor. Its last dimension must equal
            ``num_heads * head_dim``. After splitting, it must have rank 4 or 5.
        num_heads (int): Number of heads.
        head_dim (int): Dimension of embedding for each head.

    Returns:
        (tensor): A rank-4 result has shape [batch, head, seq_len, head_dim].
            A rank-5 result has shape [batch, blocks, head, block_len, head_dim].

    Raises:
        ValueError: If the split tensor does not have rank 4 or 5.
    """
    split = x.reshape(*x.shape[:-1], num_heads, head_dim)

    # Axis order that moves the new head axis ahead of the length axis.
    axis_orders = {
        5: (0, 1, 3, 2, 4),  # -> [batch, blocks, head, block_len, head_dim]
        4: (0, 2, 1, 3),     # -> [batch, head, seq_len, head_dim]
    }
    rank = split.ndim
    if rank not in axis_orders:
        raise ValueError(f'Input tensor should have rank 4 or 5, but has rank {rank}.')
    return split.permute(*axis_orders[rank])

def attention(query, key, value, casual_mask, masked_bias, dropout, scale_attn_weights, training, attn_mask=None, head_mask=None, feedback=None):
    """
    Compute dot-product attention for the given query, key and value.

    Args:
        query (tensor): Query, shape [B, num_heads, query_len, head_dim].
        key (tensor): Key, shape [B, num_heads, key_len, head_dim].
        value (tensor): Value, shape [B, num_heads, key_len, head_dim].
        casual_mask (tensor): Causal mask that is truthy where attention is
            allowed. Must broadcast to [B, num_heads, query_len, key_len], e.g.
            a [1, 1, query_len, key_len] slice of a lower-triangular buffer.
            Non-bool masks (e.g. uint8 or 0/1 float) are converted to bool.
        masked_bias (float or tensor): Score written into positions where
            ``casual_mask`` is False (typically a large negative number).
        dropout (callable): Dropout module applied to the attention weights
            (after softmax) before they are multiplied with ``value``.
        scale_attn_weights (bool): If True, divide the scores by
            sqrt(value.shape[-1]).
        training (bool): Currently unused. Dropout behaviour is governed by the
            ``dropout`` module's own train/eval mode.
        attn_mask (tensor, optional): Additive mask added to the scores before
            softmax. It must broadcast to [B, num_heads, query_len, key_len],
            e.g. shape [B, 1, 1, key_len] for per-token padding masks.
        head_mask (tensor, optional): Multiplicative mask applied to the
            attention weights after dropout. It must broadcast to the weights'
            shape.
        feedback (tensor, optional): Currently unused. Accepted for API
            compatibility.

    Returns:
        (tensor): Attention output, shape [B, num_heads, query_len, head_dim].
        (tensor): Attention probabilities (softmax output, float32, taken
            before dropout and head masking), shape
            [B, num_heads, query_len, key_len].
    """
    # Compute scores in float32 for numerical stability regardless of input dtype.
    query = query.to(torch.float32)
    key = key.to(torch.float32)
    attn_weights = torch.matmul(query, key.transpose(-1, -2))

    if scale_attn_weights:
        attn_weights = attn_weights / (float(value.shape[-1]) ** 0.5)

    # torch.where requires a bool condition. Older PyTorch releases also reject
    # a plain Python scalar as the "other" operand, so materialise masked_bias
    # as a tensor matching the scores' dtype and device.
    mask_value = torch.as_tensor(masked_bias, dtype=attn_weights.dtype, device=attn_weights.device)
    attn_weights = torch.where(casual_mask.to(torch.bool), attn_weights, mask_value)

    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask

    probs = F.softmax(attn_weights, dim=-1)
    # Cast back to the value dtype so the final matmul runs in value's precision.
    attn_weights = probs.to(value.dtype)
    attn_weights = dropout(attn_weights)

    if head_mask is not None:
        attn_weights = attn_weights * head_mask

    out = torch.matmul(attn_weights, value)
    return out, probs

def merge_heads(x, num_heads, head_dim):
    """
    Merge per-head embeddings back into a single embedding dimension.

    This undoes ``split_heads``. The head axis is moved back next to the
    per-head feature axis, and the two are flattened into one.

    Args:
        x (tensor): Input tensor, shape [B, num_heads, seq_len, head_dim] or
            [B, blocks, num_heads, block_len, head_dim].
        num_heads (int): Number of heads. Accepted for symmetry with
            ``split_heads``; the merged size is taken from ``x`` itself.
        head_dim (int): Dimension of embedding for each head. Accepted for
            symmetry with ``split_heads``; the merged size is taken from ``x``.

    Returns:
        (tensor): Output tensor, shape [B, seq_len, num_heads * head_dim] or
            [B, blocks, block_len, num_heads * head_dim].

    Raises:
        ValueError: If ``x`` does not have rank 4 or 5.
    """
    # permute() is not in-place, so its result must be assigned.
    if x.ndim == 5:
        # [B, blocks, heads, block_len, head_dim] -> [B, blocks, block_len, heads, head_dim]
        x = x.permute(0, 1, 3, 2, 4)
    elif x.ndim == 4:
        # [B, heads, seq_len, head_dim] -> [B, seq_len, heads, head_dim]
        x = x.permute(0, 2, 1, 3)
    else:
        raise ValueError(f'Input tensor should have rank 4 or 5, but has rank {x.ndim}.')

    # The permuted tensor is non-contiguous, so use reshape() (which copies when
    # needed) rather than view() to flatten the (heads, head_dim) pair.
    merged_dim = x.shape[-2] * x.shape[-1]
    return x.reshape(*x.shape[:-2], merged_dim)