from itertools import combinations

import numpy as np
from scipy.spatial.distance import cdist as scipy_cdist

import torch

##import load_dataset


def mat_cdist(x1, x2):
    """Return the Euclidean distances between every row of *x1* and *x2*.

    Both arguments must be 2-D arrays with the same number of columns.

    Note the orientation: entry ``[j, i]`` of the result is the distance
    between ``x1[i]`` and ``x2[j]``, so the returned array has shape
    ``(len(x2), len(x1))``.

    The full ``(len(x2), len(x1), n_features)`` difference tensor is
    materialised before the norm is taken, so callers should keep at least
    one of the inputs small.
    """
    pairwise_diff = x1[np.newaxis, :, :] - x2[:, np.newaxis, :]
    return np.linalg.norm(pairwise_diff, axis=-1)


def calculate_dist(x, normalized=True, chunk_size=1024):
    """Return the ``(n, n)`` Euclidean distance matrix between the rows of *x*.

    The matrix is assembled in row blocks of at most *chunk_size* rows so
    that the intermediate difference tensor built by ``mat_cdist`` stays
    bounded at ``chunk_size * n * n_features`` elements.

    Args:
        x: 2-D array of shape ``(n, n_features)``.
        normalized: When true, standardise the distances using the mean and
            standard deviation of the strict upper triangle (each unordered
            pair counted once), then reset the diagonal to zero.
        chunk_size: Maximum number of rows processed per block.

    Returns:
        The ``(n, n)`` distance matrix, standardised if requested.

    Raises:
        ValueError: If *x* is not 2-D or *chunk_size* is not positive.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    # Unpacking into two names also rejects anything that is not 2-D.
    n, _ = x.shape

    if n <= chunk_size:
        # Small enough for a single pass; the matrix is symmetric, so the
        # (len(x2), len(x1)) orientation of mat_cdist does not matter here.
        cdist = mat_cdist(x, x)
    else:
        # mat_cdist returns (n, rows_in_block); transpose each block so the
        # blocks stack row-wise into the final (n, n) matrix.
        cdist = np.concatenate([
            mat_cdist(x[start:start + chunk_size], x).T
            for start in range(0, n, chunk_size)
        ])

    # With fewer than two rows there are no off-diagonal pairs; mean/std of
    # an empty selection would be NaN, so leave the matrix untouched.
    if normalized and n > 1:
        off_diag = cdist[np.triu_indices_from(cdist, k=1)]
        mean = np.mean(off_diag)
        std = np.std(off_diag)
        cdist = (cdist - mean) / (std + 1e-9)
        # Self-distances must stay exactly zero after standardisation.
        np.fill_diagonal(cdist, 0.0)
    return cdist


from sklearn.model_selection import train_test_split


def find_out_cdist(x_data, y_data, downsample=2000, seed=0):
    """Sub-sample the data and return its pairwise distance matrix.

    A class-stratified subset is drawn by taking the "test" partition of
    ``train_test_split`` with ``test_size=downsample``; the "train"
    partition is discarded. Each feature of the subset is standardised with
    the subset's own mean and standard deviation before distances are
    computed.

    Args:
        x_data: Feature matrix, one sample per row.
        y_data: Labels aligned with *x_data*; also used for stratification.
        downsample: Forwarded as ``test_size`` (an int is treated by
            scikit-learn as an absolute sample count).
        seed: Offset added to the base random state ``42``.

    Returns:
        A tuple ``(cdist, labels, features)`` holding the float16 pairwise
        distance matrix of the standardised subset, the subset labels, and
        the standardised subset features.
    """
    _, sampled_x, _, sampled_y = train_test_split(
        x_data,
        y_data,
        test_size=downsample,
        random_state=seed + 42,
        stratify=y_data,
    )

    # Per-feature standardisation; the epsilon guards constant features.
    feature_mean = np.mean(sampled_x, axis=0)
    feature_std = np.std(sampled_x, axis=0)
    sampled_x = (sampled_x - feature_mean) / (feature_std + 1e-9)

    # Stored as float16 to keep the saved matrix small.
    dist_matrix = scipy_cdist(sampled_x, sampled_x).astype(np.float16)
    print(f'[cdist X_test y_test] = [{dist_matrix.shape} {sampled_x.shape} {sampled_y.shape}]')
    return dist_matrix, sampled_y, sampled_x


def main():
    """Build and save distance matrices for subsets of classes.

    For every dataset listed below, the pre-extracted features and labels
    are loaded from disk. For each class count from 3 to 10, the first five
    label combinations (in the order produced by ``itertools.combinations``)
    are processed: the samples whose label is in the combination are passed
    to ``find_out_cdist`` with ``downsample=10000``, and the resulting tuple,
    extended with the selected labels, is written with ``torch.save``. The
    output path is printed after each save.
    """
    dataset_names = [
        'cifar10',
        # 'mnist',
        # 'fmnist'
    ]
    seed = 0

    for d_name in dataset_names:
        # NOTE: weights_only=False allows arbitrary objects to be unpickled;
        # only point this at trusted files.
        # Per an earlier debug print, both objects are numpy arrays, shaped
        # (50000, 512) and (50000,) for cifar10.
        features, labels = torch.load(
            f"/home/****/autovisual/prepare_data/data/{d_name}_features_clip.tar",
            weights_only=False,
        )

        for n_classes in range(3, 11):
            first_combos = list(combinations(range(10), n_classes))[:5]
            for comb_idx, selected_labels in enumerate(first_combos):
                # Boolean mask over samples whose label is in this combination.
                keep = np.isin(labels, selected_labels)
                cdist_res = find_out_cdist(
                    features[keep], labels[keep], downsample=10000, seed=seed)

                out_path = f'/mnt/data01/public/aad_data/cifar10_large/{d_name}_{n_classes}class_comb{comb_idx}_seed{seed}_clip_cdist_10000.tar'
                torch.save((*cdist_res, selected_labels), out_path)
                print(out_path)


    
##    for dataset in [
##        'arrhythmia', 'wine', 'lympho', 'glass', 'vertebral', 'wbc', 'ecoli', 'ionosphere', 'breastw', 'pima',
##                'vowels',
##        'letter',
##                     'cardio', 'seismic',
##        'musk', 'speech', 'abalone', 'pendigits', 'mammography',
##                'mulcross',
##                'forest_cover'
##                ]:
##        normalize = True
##        data, _, _ = load_dataset.load_dataset('./data', dataset)
##        data = data[np.random.permutation(data.shape[0])][: 3000]
##
##        x, y = data[:, :-1], data[:, -1].astype(int)    
##
##        mu = np.mean(x, axis=0)
##        std = np.std(x, axis=0)
##        if normalize:
##            x = (x - mu) / (std + 1e-5)
##        cdist = calculate_dist(x, normalized=False)
##        torch.save((cdist, x, y), f'./cdist_data/{dataset}_cdist_data_for_tsne2.tar', pickle_protocol=5)
        
# Run the generation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()