import sys
sys.path.append('..')

import torch

import pandas as pd
import numpy as np

from sklearn.decomposition import PCA
import dataset_utils


def get_rank(s, th=0.9):
    """Count the entries of ``s`` whose running share of the total stays within ``th``.

    The absolute values of ``s`` are accumulated along dim 0 and divided by
    their overall sum. The result is the number of positions where that
    running fraction is ``<= th``. When ``s`` holds singular values in
    descending order, this is an energy-threshold rank estimate. Callers
    presumably pass ``torch.linalg.svd``/``svd_lowrank`` output (TODO confirm).

    Args:
        s: Tensor of values; accumulated along dim 0.
        th: Cumulative-fraction threshold, inclusive.

    Returns:
        Python ``int`` count. If ``s`` sums to zero the fractions are NaN and
        the count is 0.
    """
    magnitudes = torch.abs(s)
    running_share = torch.cumsum(magnitudes, dim=0) / torch.sum(magnitudes)
    within_threshold = running_share <= th
    return int(torch.count_nonzero(within_threshold))


# Root of the shared data directory holding the extracted gene datasets.
DATA_ROOT = "/mnt/data01/public/aad_data"


def filter_rare_classes(features, labels, min_frac=0.03):
    """Drop samples of rare classes and relabel the survivors as 0..k-1.

    A class is kept when it accounts for at least ``min_frac`` of all
    samples. Surviving classes get new contiguous ids in the sorted order
    of their original label values, which is the order ``np.unique``
    produces.

    Args:
        features: Per-sample data indexed along its first axis in step with
            ``labels``. Presumably a NumPy array (TODO confirm for the
            extracted ``.tar`` files).
        labels: 1-D NumPy array with one class label per sample.
        min_frac: Minimum fraction of all samples a class must have to be kept.

    Returns:
        ``(features, new_labels, to_keep, label_map)``, where:
        - ``features`` and ``new_labels`` are restricted to the kept samples.
        - ``new_labels`` holds the new integer ids.
        - ``to_keep`` is the sorted array of retained original labels.
        - ``label_map`` maps each retained original label to its new id.
    """
    unique, counts = np.unique(labels, return_counts=True)
    to_keep = unique[counts >= labels.shape[0] * min_frac]

    # Compute the membership mask once and reuse it for both arrays.
    mask = np.isin(labels, to_keep)
    features = features[mask]
    kept_labels = labels[mask]

    # to_keep is sorted (np.unique output) and contains every surviving
    # label, so searchsorted yields each label's index in to_keep, i.e. its
    # new id. This replaces a per-element Python dict lookup.
    new_labels = np.searchsorted(to_keep, kept_labels)
    label_map = dict(zip(to_keep, range(len(to_keep))))
    return features, new_labels, to_keep, label_map


if __name__ == "__main__":
    for d_name in [
        'Campbell',
        'PBMC68K',
        'Mouse_retina',
        'Baron Human',
    ]:
        # NOTE: weights_only=False lets torch.load unpickle arbitrary Python
        # objects; only point this at trusted files.
        features, labels, _label_dict = torch.load(
            f"{DATA_ROOT}/gene/{d_name}_extracted.tar", weights_only=False
        )

        # Keep only classes with >= 3% of the samples, relabelled 0..k-1.
        features, labels, to_keep, new_dict = filter_rare_classes(features, labels)
        print(to_keep)
        print(d_name, features.shape, labels.shape, len(set(labels)))

        n_class = len(np.unique(labels))
        print(n_class)
        print(np.unique(labels))

        # --- Disabled downstream steps, kept for reference. ---
        # Low-rank projection of the features:
        # # pca = PCA(n_components=4096)
        # # features = pca.fit_transform(features)
        # features = torch.from_numpy(features)
        # # u, s, v = torch.linalg.svd(features.cuda())
        # u, s, v = torch.svd_lowrank(features.cuda(), q=4096, niter=7)
        # # r = get_rank(s, th=0.9)
        # r = 4096
        # features = u[:, :r].cpu().numpy()
        #
        # Save the filtered dataset:
        # torch.save((features, labels, new_dict), f"{DATA_ROOT}/gene_filtered_90pca/{d_name}_extracted.tar")
        #
        # Generate multi-class task datasets:
        # for n_selected_class in range(2, n_class+1):
        #     dataset_utils.generate_from_multi_class_dataset(dataset_name=d_name, x_data=features, y_data=labels,
        #                                       n_total_class=n_class, n_selected_class=n_selected_class, n_used=64,
        #                                       min_size=100, max_size=3000, n_interval=10,
        #                                       normalized=True, repeat=1,
        #                                       save_path=f'{DATA_ROOT}/gene_filtered/{d_name}',
        #                                       )
