import torch

import numpy as np
from sklearn.decomposition import PCA, randomized_svd
from scipy.spatial.distance import cdist as scipy_cdist


def find_pca(x, ratio=0.95):
    """Project ``x`` onto its leading singular directions.

    The smallest number of singular directions whose squared singular
    values account for at least ``ratio`` of the total is kept.  Note that
    ``x`` is decomposed as given: no mean-centering is applied here.

    Args:
        x: 2-D array-like (presumably one sample per row -- confirm with
            the caller).
        ratio: Target fraction of the total squared singular values to
            retain.  Values the cumulative ratio never reaches (e.g. ``1.0``
            under floating-point round-off, or anything above 1) keep every
            component instead of failing.

    Returns:
        A tuple ``(x_hat, dim, v)``:
            x_hat: ``u[:, :dim] * s[:dim]``, i.e. ``x`` expressed in the
                first ``dim`` right singular vectors (``x @ v.T[:, :dim]``).
            dim: Number of retained components (at least 1 whenever ``x``
                has at least one singular value).
            v: All right singular vectors, one per row.

    Raises:
        ValueError: If ``x`` is not 2-dimensional.
    """
    x = np.asarray(x)
    if x.ndim != 2:
        raise ValueError(f"expected a 2-D array, got shape {x.shape}")

    # Exact thin SVD: every component is needed, so a randomized
    # approximate solver brings no benefit here.
    u, s, v = np.linalg.svd(x, full_matrices=False)

    # Singular vectors are only defined up to sign.  Pin the orientation so
    # that the largest-magnitude entry of each left singular vector is
    # positive (the u-based rule also used by scikit-learn's svd_flip),
    # which keeps the output deterministic.
    pivot = np.argmax(np.abs(u), axis=0)
    signs = np.sign(u[pivot, np.arange(u.shape[1])])
    signs[signs == 0] = 1
    u = u * signs
    v = v * signs[:, np.newaxis]

    energy = s ** 2
    total = energy.sum()
    if total > 0:
        r_prefix = np.cumsum(energy / total)
        # r_prefix is non-decreasing, so this is the first index whose
        # cumulative ratio is >= ratio (len(s) when there is none).
        dim = int(np.searchsorted(r_prefix, ratio, side="left")) + 1
    else:
        # All-zero input: the ratios are undefined, keep one component.
        dim = 1
    dim = min(dim, s.shape[0])

    # Column scaling by broadcasting; same result as u[:, :dim] @ diag(s[:dim]).
    x_hat = u[:, :dim] * s[:dim]
    return x_hat, dim, v


def find_out_cdist(x):
    """Return the pairwise distance matrix of the column-standardized rows of ``x``.

    Each column is shifted by its mean and divided by its standard deviation
    (plus 1e-9 so constant columns do not divide by zero).  Distances between
    all pairs of rows are then computed with SciPy's ``cdist`` using its
    default metric, and the result is cast to ``float16``.
    """
    center = np.mean(x, axis=0)
    scale = np.std(x, axis=0) + 1e-9
    standardized = (x - center) / scale

    pairwise = scipy_cdist(standardized, standardized)
    # Half precision halves/quarters the storage of the square matrix at the
    # cost of reduced precision.
    return pairwise.astype(np.float16)


seed = 0
# Report how many components find_pca keeps (default ratio) for the stored
# per-combination archives.
for d_name in ('mnist',):
    for j in (5,):
        for i in range(252):
            archive_path = f"./features/{d_name}_{j}class_comb{i}_seed{seed}_clip_cdist_3000.tar"
            # The archive unpacks into four items; only the second one is
            # passed on to find_pca below.
            cdist, x, y, _ = torch.load(archive_path)

            dim = find_pca(x)[1]
            print(f'{d_name}_{j}class_comb{i}_seed{seed}: {dim}')
            # Unconditional break: only the first combination (i == 0) is
            # ever processed.
            break

# For every grouped archive: recompute the pairwise-distance matrix in the
# space kept by find_pca (default ratio) and store it in a new archive.
for d_name in ('mnist_group2', 'mnist_group1', 'fmnist_group2', 'fmnist_group1', 'cifar10_group1', 'cifar10_group2'):
    source_path = f"./features/{d_name}_clip_cdist_3000.tar"
    target_path = f"./features/{d_name}_clip_95pcas_cdist_3000.tar"

    # The distance matrix stored in the source archive is not reused.
    cdist, x, y = torch.load(source_path)
    x_reduced, dim, v = find_pca(x)
    cdist_reduced = find_out_cdist(x_reduced)

    # NOTE(review): the unreduced x is saved next to the reduced-space
    # distances -- confirm this is intended rather than x_reduced.
    payload = (cdist_reduced, x, y)
    torch.save(payload, target_path)
    
            
            
            

            
