"""Normalized kernelized with RPE."""

import torch
from torch._C import device
import torch.fft
from torch.nn import Module
import torch.autograd.profiler as profiler

from ..attention_registry import AttentionRegistry, Optional, Callable, Int
from ..feature_maps import PositiveRandomFeatures

class NKARPEAttention(Module):
    """Normalized kernelized attention combined with an RPE term.

    Queries and keys are L2-normalized along their last dimension and passed
    through a feature map. The mapped keys, and their outer products with the
    values, are then weighted by ``rpe`` through an FFT-based product before
    being contracted with the mapped queries and normalized.

    Arguments
    ---------
        query_dimensions: int, forwarded to the feature map constructor
        feature_map: callable, invoked with query_dimensions to build the
                     feature map object (default: PositiveRandomFeatures)
        n_dims: int, passed as n_dims to PositiveRandomFeatures when no
                feature_map is given (default: 16)
        eps: float, a small number added to the denominator of the
             normalizer for numerical stability (default: 1e-6)
    """

    def __init__(self, query_dimensions, feature_map=None, n_dims=16, eps=1e-6):
        super().__init__()
        if feature_map:
            self.feature_map = feature_map(query_dimensions)
        else:
            self.feature_map = PositiveRandomFeatures(
                query_dimensions, n_dims=n_dims
            )
        self.eps = eps

    @staticmethod
    def _weight_by_rpe(x, rpe_freq, tgt_len):
        """Multiply x by the RPE spectrum in the frequency domain.

        The last axis of x is transformed with a real FFT of length
        2 * tgt_len, multiplied per head by rpe_freq, transformed back, and
        only the entries from index tgt_len onwards are returned.
        """
        spectrum = torch.fft.rfft(x, n=2 * tgt_len, dim=-1)
        spectrum = torch.einsum("hl,nh...l->nh...l", rpe_freq, spectrum)
        return torch.fft.irfft(spectrum, dim=-1)[..., tgt_len:]

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths, rpe):
        """Compute the attention output for the given queries, keys, values.

        query_lengths is accepted but not used by this implementation.
        Raises RuntimeError when attn_mask.all_ones is falsy.
        """
        # Rescale every query/key vector by its L2 norm; the 1e-8 keeps the
        # division finite when a vector is all zeros.
        query_norms = queries.norm(dim=-1, keepdim=True)
        key_norms = keys.norm(dim=-1, keepdim=True)
        queries = queries / (query_norms + 1e-8)
        keys = keys / (key_norms + 1e-8)

        # Set up the feature map on the inputs' device and apply it
        self.feature_map.new_feature_map(queries.device)
        Q = self.feature_map.forward_queries(queries)
        K = self.feature_map.forward_keys(keys)
        tgt_len = K.shape[1]

        # Only the key padding mask is supported; attn_mask must be all ones
        if not attn_mask.all_ones:
            raise RuntimeError(("Kernelized Attention does not support arbitrary "
                                "attention masks"))
        K = K * key_lengths.float_matrix[:, :, None, None]

        # Move the sequence axis last: [bsz, num_heads, k_dim, tgt_len]
        K_seq_last = K.permute(0, 2, 3, 1)

        # Per-position outer product of keys and values:
        # [bsz, num_heads, v_dim, k_dim, tgt_len]
        KV = torch.einsum("nhdl,nlhm->nhmdl", K_seq_last, values)

        # Spectrum of the RPE; rpe.shape = [num_heads, 2*tgt_len]
        rpe_freq = torch.fft.rfft(rpe, dim=-1)

        # Efficient matrix multiplication through the frequency domain
        weighted_KV = self._weight_by_rpe(KV, rpe_freq, tgt_len)
        weighted_K = self._weight_by_rpe(K_seq_last, rpe_freq, tgt_len)

        # Normalizer, one scalar per (batch, position, head)
        Z = 1 / (torch.einsum("nlhd,nhdl->nlh", Q, weighted_K) + self.eps)

        # Equivalent to
        # V = torch.einsum("nlhd,nhmdl,nlh->nlhm", Q, weighted_KV, Z)
        V = final_compute(Q, weighted_KV, Z)

        return V.contiguous()


# Make the implementation available to the builders under the name "nkarpe",
# listing the constructor arguments that are registered alongside it.
# NOTE(review): `eps` is not listed here, so presumably it cannot be set
# through the builders and keeps its default -- confirm.
AttentionRegistry.register(
    "nkarpe",
    NKARPEAttention,
    [
        ("query_dimensions", Int),
        ("n_dims", Int),
        ("feature_map", Optional(Callable)),
    ],
)

def final_compute(Q, KV, Z):
    """Contract the queries with the weighted KV tensor and normalize.

    Equivalent to ``torch.einsum("nlhd,nhmdl,nlh->nlhm", Q, KV, Z)``, but
    evaluated as a single batched matrix-vector product.

    Arguments
    ---------
        Q: tensor of shape (N, L, H, D)
        KV: tensor of shape (N, H, M, D, L)
        Z: tensor of shape (N, L, H), the per-position normalizer

    Returns
    -------
        Tensor of shape (N, L, H, M).
    """
    # Bring the sequence axis next to the batch axis: (N, L, H, M, D)
    KV = KV.permute(0, 4, 1, 2, 3)
    dn, dl, dh, dm, dd = KV.size()
    batch = dn * dl * dh
    # reshape (instead of view) copies only when the memory layout demands
    # it, so non-contiguous inputs (e.g. a transposed Q) are accepted too.
    KVQ = torch.bmm(
        KV.reshape(batch, dm, dd),
        Q.reshape(batch, dd, 1),
    ).view(dn, dl, dh, dm)
    # Z broadcasts over the trailing value dimension
    return KVQ * Z.unsqueeze(3)