
from itertools import combinations

import torch
import numpy as np
import random
import math
from scipy.spatial.distance import cdist as scipy_cdist
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedShuffleSplit


def stratified_sampling(x_data, y_data, n_samples, random_seed=42):
    """
    Draw a label-stratified random subset of a dataset.

    Parameters:
    - x_data: numpy array or list of feature data.
    - y_data: numpy array or list of labels; the draw is stratified on these.
    - n_samples: Size of the subset (forwarded to StratifiedShuffleSplit as
      `test_size`).
    - random_seed: Seed forwarded to StratifiedShuffleSplit as `random_state`.

    Returns:
    - sampled_x: Entries of x_data at the selected positions.
    - sampled_y: Entries of y_data at the selected positions.
    - sampled_indices: Positions of the selected samples in the inputs.
    """
    features = np.array(x_data)
    labels = np.array(y_data)

    splitter = StratifiedShuffleSplit(
        n_splits=1, test_size=n_samples, random_state=random_seed
    )
    # Exactly one split is requested; its "test" side is the stratified subset,
    # the "train" side is discarded.
    _, picked = next(splitter.split(features, labels))
    return features[picked], labels[picked], picked

def find_out_cdist(x_data, y_data, downsample=2000, seed=0, normalized=True):
    '''
    Down-sample a labelled dataset and compute the pairwise distance matrix
    of the sampled points.

    Parameters
    ----------
    x_data: original training data in numpy form
    y_data: original training data label in numpy form
    downsample: how many samples to draw (stratified on y_data) to form the graph
    seed: offset added to the base seed 42 that drives the stratified sampling
    normalized: whether to standardize each feature (mean/std computed on the
        sampled subset) before calculating the cdist

    Returns
    -------
    cdist: pairwise distances between the sampled points (torch.cdist with its
        default metric), as a numpy float16 array
    X_test: the sampled (and, if requested, standardized) features as a torch
        tensor; it lives on the GPU when CUDA is available, otherwise on the CPU
    y_test: labels of the sampled points
    indices: positions of the sampled points in x_data / y_data
    '''
    X_test, y_test, indices = stratified_sampling(
        x_data, y_data, n_samples=downsample, random_seed=42 + seed)

    if normalized:
        mu = np.mean(X_test, axis=0)
        std = np.std(X_test, axis=0)
        # The small epsilon avoids a division by zero for constant features.
        X_test = (X_test - mu) / (std + 1e-9)

    # Prefer the GPU, but fall back to the CPU instead of crashing on machines
    # without CUDA. With CUDA present this is the same as calling `.cuda()`.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    X_test = torch.from_numpy(X_test).to(device)
    cdist = torch.cdist(X_test, X_test)
    # NOTE: float16 cannot represent values above ~65504; larger distances
    # become inf in the returned matrix.
    cdist = cdist.cpu().numpy().astype(np.float16)
    return cdist, X_test, y_test, indices


def generate_size_diverse_cdist(x_data, y_data, min_size, max_size, n_interval,
                                normalized=True, repeat=1, save_path='./generated_cdist',
                                additional_save_info=(None, ), save_verbo='dataset'):
    '''
    Generate and save distance matrices for sub-samples of varied sizes.

    The range [min_size, max_size] is cut into `n_interval` equal-width
    intervals. For each interval, `repeat` sample sizes are drawn at random
    from it, and for each drawn size one stratified sub-sample is taken and
    its distance matrix is written to disk.

    Parameters
    ----------
    x_data: feature array; its first dimension is the number of samples
    y_data: labels matching x_data
    min_size: lower bound of the sample sizes
    max_size: upper bound of the sample sizes
    n_interval: number of size intervals between min_size and max_size
    normalized: forwarded to find_out_cdist
    repeat: how many sizes (and therefore files) to generate per interval
    save_path: output directory; created if it does not exist
    additional_save_info: tuple appended to the saved tuple
    save_verbo: file-name prefix

    Returns
    -------
    None. Each file holds the tuple
    (cdist, y_test, X_test, indices) + additional_save_info and is named
    '{save_path}/{save_verbo}_size{i}_seed{j}_cdist.tar', where i is the
    interval index and j the repeat index.
    '''
    import os

    os.makedirs(save_path, exist_ok=True)
    n_available = x_data.shape[0]
    interval = np.linspace(min_size, max_size, n_interval + 1)
    for i in range(n_interval):
        sampled_sizes = np.random.randint(interval[i], interval[i + 1], repeat)
        for j in range(repeat):
            size = int(sampled_sizes[j])
            # Decide per sample: comparing the whole array in an `if` raises
            # for repeat > 1. A size equal to the dataset size is skipped too,
            # because the stratified split needs an integer size strictly
            # smaller than the number of available samples.
            if size >= n_available:
                continue
            cdist, X_test, y_test, indices = find_out_cdist(x_data, y_data, downsample=size,
                                                            seed=j, normalized=normalized)
            _save = (cdist, y_test, X_test, indices) + additional_save_info
            out_file = f'{save_path}/{save_verbo}_size{i}_seed{j}_cdist.tar'
            torch.save(_save, out_file)
            # Report the path that was actually written, plus the drawn size.
            print(f'{out_file} (n_samples={size})')


def unique_random_combinations(n, k, count):
    """Generate `count` unique random C(n, k) combinations.

    Each combination is a sorted tuple of k distinct integers drawn from
    range(n). `count` is capped at the total number of combinations C(n, k).
    The result is a list in the order the combinations were first drawn.
    """
    # There are only C(n, k) distinct combinations to hand out.
    count = min(count, math.comb(n, k))

    # A dict keeps insertion order and doubles as the duplicate check.
    drawn = {}
    while len(drawn) < count:
        candidate = tuple(sorted(random.sample(range(n), k)))  # one random combination
        drawn.setdefault(candidate, None)  # no-op when already drawn
    return list(drawn)


def generate_from_multi_class_dataset(dataset_name, x_data, y_data, n_total_class, n_selected_class, n_used,
                                      min_size, max_size, n_interval,
                                      normalized=True, repeat=1,
                                      save_path='./generated_cdist',
                                      ):
    '''
    Generate distance-matrix files from random class subsets of a dataset.

    Up to `n_used` distinct combinations of `n_selected_class` labels are drawn
    from range(n_total_class). For each combination the samples carrying one
    of those labels are kept and passed to generate_size_diverse_cdist.

    Parameters
    ----------
    dataset_name: str verbo dataset name, used as file-name prefix
    x_data: full dataset feature
    y_data: full dataset label; combinations are drawn from
        range(n_total_class), so labels are matched against 0..n_total_class-1
    n_total_class: total class number in classification task
    n_selected_class: how many class data you want to use
    n_used: among all combination of C(n_total_class, n_selected_class), how many you want to use
    min_size: min number in the down sampling
    max_size: max number in the down sampling
    n_interval: to increase diversity, randomize the downsampling number
    normalized: forwarded to generate_size_diverse_cdist
    repeat: how many times to repeat the downsampling
    save_path: output directory, forwarded to generate_size_diverse_cdist

    Returns
    -------
    None; the selected label combination is stored with every saved file.
    '''
    # Only the number of combinations is needed here, so count it directly
    # instead of materializing every combination just to take the length.
    n_used = min(n_used, math.comb(n_total_class, n_selected_class))
    used_combination = unique_random_combinations(n_total_class, n_selected_class, n_used)
    print('12234556', len(used_combination))
    for i, selected_labels in enumerate(used_combination):
        # Boolean mask of the samples whose label is in the chosen combination.
        selected_indices = np.isin(y_data, selected_labels)
        print(selected_labels)
        x_selected = x_data[selected_indices]
        y_selected = y_data[selected_indices]

        save_verbo = f'{dataset_name}-{n_selected_class}class-comb{i}'
        generate_size_diverse_cdist(x_data=x_selected, y_data=y_selected,
                                    min_size=min_size, max_size=max_size, n_interval=n_interval,
                                    normalized=normalized, repeat=repeat, save_path=save_path,
                                    additional_save_info=(selected_labels,), save_verbo=save_verbo)


if __name__ == "__main__":
    # For every dataset, generate distance-matrix files for each class-subset
    # size from 2 up to 10.
    dataset_names = ('cifar10', 'mnist', 'fmnist')
    for d_name in dataset_names:
        # The file is unpacked as a (features, labels) pair.
        feature_file = f"/home/****/autovisual/prepare_data/data/{d_name}_features_clip.tar"
        features, labels = torch.load(feature_file)
        output_dir = f'/mnt/data01/public/aad_data/{d_name}'

        for n_selected_class in range(2, 11):
            generate_from_multi_class_dataset(
                dataset_name=d_name, x_data=features, y_data=labels,
                n_total_class=10, n_selected_class=n_selected_class, n_used=64,
                min_size=100, max_size=3000, n_interval=10,
                normalized=True, repeat=1,
                save_path=output_dir,
            )


