import torch
import math
import torch.nn.functional as F
import torch.nn as nn
import torch
import numpy as np
from kernels import hilbert


def hilbert_feature_map(A):
    """Kernel-weighted average of the context tokens, queried by the last token.

    The last token of each sequence, ``A[:, -1, :]``, is used as the query.
    It is scored against every token of ``A`` with ``kernels.hilbert``. The
    final score column is dropped, so only the first ``N - 1`` tokens
    contribute. The remaining scores are divided by their sum and used to
    average those first ``N - 1`` tokens.

    Args:
        A: tensor unpacked as ``(B, N, d)``.

    Returns:
        Tensor of shape ``(B, 1, d)`` from ``torch.bmm``. This assumes
        ``hilbert`` returns ``(B, 1, N)`` scores (TODO confirm against
        ``kernels.hilbert``).
    """
    _, _, dim = A.shape

    query = A[:, -1, :].unsqueeze(1)
    context = A[:, :-1, :]

    # NOTE(review): the third argument of `hilbert` is the feature size `d`.
    # Its exact role (bandwidth, normalisation, ...) is defined in `kernels`.
    # Dropping the last column presumably discards the query-vs-itself score,
    # so it lines up with `context`.
    weights = hilbert(query, A, dim)[:, :, :-1]
    # Normalise the weights to sum to 1 over the context axis. There is no
    # guard against a zero sum.
    weights = weights / weights.sum(dim=2, keepdim=True)

    return torch.bmm(weights, context)



def gd_feature_map(A):
    """Dot-product-weighted sum of all tokens, queried by the last token.

    With ``q = A[:, -1, :]`` the output for each batch element is::

        sum_j (q . a_j / sqrt(d)) * a_j / sqrt(N)

    The sum runs over all ``N`` tokens, including the query token itself.
    No softmax or other normalisation is applied to the scores.

    Args:
        A: tensor unpacked as ``(B, N, d)``.

    Returns:
        Tensor of shape ``(B, 1, d)``.
    """
    _, seq_len, dim = A.shape

    query = A[:, -1:, :]  # keep the token axis: (B, 1, d)
    scores = torch.bmm(query, A.transpose(1, 2)) / math.sqrt(dim)  # (B, 1, N)

    aggregated = torch.bmm(scores, A)
    return aggregated / math.sqrt(seq_len)
