import torch
import numpy as np
import numbers
from tqdm import tqdm
from sklearn.kernel_approximation import Nystroem

def normalized_gaussian_kernel(x, y, sigma, batchsize):
    """Return the Gaussian (RBF) kernel matrix between ``x`` and ``y``, scaled by ``1 / sqrt(n * m)``.

    Entry ``(i, j)`` of the result is
    ``exp(-||x[i] - y[j]||^2 / (2 * sigma^2)) / sqrt(n * m)`` where
    ``n = x.shape[0]`` and ``m = y.shape[0]``.  The rows of ``y`` are
    processed ``batchsize`` at a time so that the intermediate
    ``(n, batch, d)`` difference tensor stays bounded.

    Args:
        x: 2-D tensor of shape ``(n, d)``.
        y: 2-D tensor of shape ``(m, d)``; trailing dimensions must match ``x``.
        sigma: Bandwidth of the Gaussian kernel.
        batchsize: Number of rows of ``y`` handled per step (must be >= 1).

    Returns:
        Tensor of shape ``(n, m)`` on ``x.device``.

    Raises:
        ValueError: If ``batchsize`` is smaller than 1.
        AssertionError: If the trailing shapes of ``x`` and ``y`` differ.
    """
    if batchsize < 1:
        raise ValueError(f"batchsize must be a positive integer, got {batchsize}")
    assert (x.shape[1:] == y.shape[1:])

    exponent_scale = -1 / (2 * sigma * sigma)

    # Seed with an empty (n, 0) block: an empty ``y`` still yields an (n, 0)
    # result, and dtype promotion against the default dtype is preserved.
    blocks = [torch.zeros((x.shape[0], 0), device=x.device)]
    for start in range(0, y.shape[0], batchsize):
        y_slice = y[start:start + batchsize]
        # Broadcast (n, 1, d) - (b, d) -> (n, b, d), then reduce over features.
        sq_dist = (x.unsqueeze(1) - y_slice).pow(2).sum(dim=2)
        blocks.append(torch.exp(exponent_scale * sq_dist))

    # Concatenate once instead of re-copying the growing matrix every batch.
    total_res = torch.cat(blocks, dim=1)

    return total_res / np.sqrt(x.shape[0] * y.shape[0])

def nyostrom_kernel(x_feats, kernel, t, sigma=None):
    """Build a reduced kernel matrix from a scikit-learn Nystroem feature map.

    The features are mapped with ``Nystroem(...).fit_transform`` and the
    result is ``Phi.T @ Phi / n`` (``n = x_feats.shape[0]``), returned as a
    torch tensor.

    Args:
        x_feats: Feature matrix accepted by scikit-learn; must support
            ``len()`` and ``.shape``.
        kernel: ``'gaussian'`` or ``'cosine'``.  ``'gaussian'`` is passed to
            scikit-learn as ``'rbf'``.
        t: Requested number of Nystroem components; capped at ``len(x_feats)``.
        sigma: Gaussian bandwidth.  For ``'gaussian'`` it is converted to
            ``gamma = 1 / (2 * sigma**2)``; for ``'cosine'`` it is forwarded
            unchanged as ``gamma`` (``None`` by default).

    Returns:
        ``torch.Tensor`` created from the NumPy matrix ``Phi.T @ Phi / n``.
    """
    assert kernel in ['gaussian', 'cosine']
    assert isinstance(t, int)

    # Cannot use more components than there are samples.
    n_components = min(t, len(x_feats))

    if kernel == 'gaussian':
        sk_kernel, gamma = 'rbf', 1 / (2 * sigma**2)
    else:
        sk_kernel, gamma = kernel, sigma

    # Fixed random_state keeps the sampled components reproducible.
    approximator = Nystroem(
        kernel=sk_kernel,
        gamma=gamma,
        random_state=1,
        n_components=n_components,
    )
    embedded = approximator.fit_transform(x_feats)

    gram = (embedded.T @ embedded) / x_feats.shape[0]
    return torch.from_numpy(gram)

def cosine_kernel(x, y):
    """Return the pairwise cosine-similarity matrix of ``x`` and ``y``, scaled by ``1 / sqrt(n * m)``.

    Rows are L2-normalised once and all similarities are obtained with a
    single matrix product, instead of one ``cosine_similarity`` call per
    pair of rows.

    Args:
        x: 2-D tensor of shape ``(n, d)``.
        y: 2-D tensor of shape ``(m, d)``.

    Returns:
        Tensor of shape ``(n, m)`` in torch's default dtype, where entry
        ``(i, j)`` is ``cos(x[i], y[j]) / sqrt(n * m)``.
    """
    # Same epsilon as the default of torch.nn.functional.cosine_similarity,
    # so (near-)zero rows give a similarity of 0 rather than NaN.
    eps = 1e-8
    x_unit = torch.nn.functional.normalize(x, p=2, dim=1, eps=eps)
    y_unit = torch.nn.functional.normalize(y, p=2, dim=1, eps=eps)

    # The result is kept in the default dtype, as the pre-allocated
    # torch.zeros buffer of the loop-based implementation was.
    similarities = (x_unit @ y_unit.T).to(torch.get_default_dtype())

    return similarities / np.sqrt(x.shape[0] * y.shape[0])

def calculate_stats(eigenvalues, alpha=2, t=None):
    """Return the exponential of the order-``alpha`` entropy of ``eigenvalues``, rounded to 2 decimals.

    For ``alpha == 1`` the entropy is ``-sum(p * log(p))``; otherwise it is
    ``log(sum(p ** alpha)) / (1 - alpha)``.

    Args:
        eigenvalues: 1-D tensor of kernel eigenvalues.  Values below
            ``1e-10`` are clamped to ``1e-10`` so the logarithm is defined.
        alpha: Entropy order.
        t: Optional truncation size.  When it is an ``int`` smaller than the
            number of eigenvalues, only the ``t`` largest are kept and the
            leftover mass ``1 - sum(top_t)`` is spread evenly over them.

    Returns:
        The score as a NumPy float rounded to two decimals.

    Raises:
        ValueError: If ``t`` is an ``int`` smaller than 1.
    """
    epsilon = 1e-10

    eigenvalues = torch.clamp(eigenvalues, min=epsilon)
    eigenvalues, _ = torch.sort(eigenvalues, descending=True)

    if isinstance(t, int) and t < len(eigenvalues):
        if t < 1:
            raise ValueError(f"t must be a positive integer, got {t}")

        tail = 1 - torch.sum(eigenvalues[:t])
        eigenvalues = torch.add(eigenvalues[:t], tail / t)

        # ``tail`` can be slightly negative (round-off, or a spectrum whose
        # top-t mass exceeds 1); without re-clamping, the smallest entries
        # may drop to <= 0 and turn the logarithm below into NaN.
        eigenvalues = torch.clamp(eigenvalues, min=epsilon)

    if alpha == 1:
        entropy = -torch.sum(eigenvalues * torch.log(eigenvalues))
    else:
        entropy = (1 / (1 - alpha)) * torch.log(torch.sum(eigenvalues**alpha))

    return np.around(torch.exp(entropy).item(), 2)


def compute_score(features, alpha, t=None, sigma=None, kernel='gaussian', is_nyostrom=False, batchsize=16):
    """Compute the order-``alpha`` score of ``features`` from the eigenvalues of a kernel matrix.

    The kernel matrix is built by ``nyostrom_kernel`` when ``is_nyostrom`` is
    true and ``t`` is an ``int``; otherwise the full kernel is computed with
    ``normalized_gaussian_kernel`` or ``cosine_kernel``.  Its eigenvalues are
    then passed to ``calculate_stats`` together with ``alpha`` and ``t``.

    Args:
        features: Feature matrix, one sample per row.
        alpha: Entropy order forwarded to ``calculate_stats``.
        t: Optional ``int``: number of Nystroem components, and truncation
            size forwarded to ``calculate_stats``.
        sigma: Gaussian bandwidth; must be a number when ``kernel`` is
            ``'gaussian'`` and is ignored for ``'cosine'``.
        kernel: ``'gaussian'`` or ``'cosine'``.
        is_nyostrom: Request the Nystroem approximation.  If ``t`` is not an
            ``int`` the full kernel is used instead.
        batchsize: Batch size used by the full Gaussian kernel.

    Returns:
        The value returned by ``calculate_stats``.

    Raises:
        AssertionError: On an unknown ``kernel`` or a non-numeric ``sigma``
            with the Gaussian kernel.
    """
    assert kernel in ['gaussian', 'cosine']
    if kernel == 'gaussian':
        assert isinstance(sigma, numbers.Number)

    if is_nyostrom and isinstance(t, int):
        if kernel == 'gaussian':
            K = nyostrom_kernel(features, kernel, t, sigma)
        else:
            K = nyostrom_kernel(features, kernel, t)
    elif kernel == 'gaussian':
        K = normalized_gaussian_kernel(features, features, sigma, batchsize)
    else:
        K = cosine_kernel(features, features)

    # Only the spectrum is needed, so skip computing eigenvectors.
    eigs = torch.linalg.eigvalsh(K)

    return calculate_stats(eigs, alpha, t)

#%% Load features
# 'path_to_feats' is presumably a placeholder for a saved feature tensor.
# NOTE(review): torch.load unpickles its input; only load trusted files.
feats = torch.load('path_to_feats')

# Example runs, printed in order:
#   1. cosine kernel,   alpha=2, full kernel matrix
#   2. cosine kernel,   alpha=2, Nystroem approximation with t=990
#   3. gaussian kernel, alpha=1, full kernel matrix
#   4. gaussian kernel, alpha=1, Nystroem approximation with t=990
example_runs = (
    dict(alpha=2, sigma=30, kernel='cosine'),
    dict(alpha=2, t=990, sigma=30, kernel='cosine', is_nyostrom=True),
    dict(alpha=1, sigma=30, kernel='gaussian'),
    dict(alpha=1, t=990, sigma=30, kernel='gaussian', is_nyostrom=True),
)
for run_kwargs in example_runs:
    print(compute_score(feats, **run_kwargs))
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    