import os
import torch
import numpy as np
from itertools import combinations

# Number of point-cloud pairs drawn for the test set.
NUM_PAIRS = 10000
# Sizes of the training subsets to export (one samples.pt file per size).
num_samples_list = [10, 50, 100, 200]
# Root directory: the input tensor is read from here and all outputs are
# written beneath it.
SAVE_PATH = 'preprocessed_dataset/pcmnist'

# Make sure both output sub-directories exist before anything is written.
for split_name in ("train", "test"):
    os.makedirs(f"{SAVE_PATH}/{split_name}", exist_ok=True)

# Load the full stack of point clouds; the first dimension indexes the clouds.
point_clouds = torch.load(f"{SAVE_PATH}/all_point_clouds.pt")
print(f"{point_clouds.shape[0]} point clouds loaded.")

# ---- Test set: NUM_PAIRS distinct unordered pairs of point clouds ----
num_clouds = point_clouds.shape[0]

# Number of unordered pairs (i, j) with i < j. Computed arithmetically so the
# full list of pairs (quadratic in num_clouds) never has to be materialised.
total_pairs = num_clouds * (num_clouds - 1) // 2
print(f"Total possible pairs: {total_pairs}")

if NUM_PAIRS > total_pairs:
    raise ValueError(
        f"Cannot draw {NUM_PAIRS} distinct pairs from {num_clouds} point "
        f"clouds (only {total_pairs} pairs exist)."
    )

# Sample flat pair ranks with the same seed and the same call as before, so
# the selected pairs are unchanged.
np.random.seed(42)
selected_idx = np.random.choice(total_pairs, size=NUM_PAIRS, replace=False)

# Convert each flat rank back into (i, j) using the lexicographic order that
# itertools.combinations(range(n), 2) produces:
#   (0,1), (0,2), ..., (0,n-1), (1,2), ..., (n-2,n-1)
# Row i holds n-1-i pairs; row_starts[i] is the rank of that row's first pair.
row_counts = np.arange(num_clouds - 1, 0, -1, dtype=np.int64)
row_starts = np.concatenate(([0], np.cumsum(row_counts)[:-1])).astype(np.int64)
first_idx = np.searchsorted(row_starts, selected_idx, side="right") - 1
second_idx = selected_idx - row_starts[first_idx] + first_idx + 1
selected_pairs = list(zip(first_idx.tolist(), second_idx.tolist()))

# Gather both sides of every pair with a single indexing operation each.
pcs1 = point_clouds[torch.as_tensor(first_idx, dtype=torch.long)]
pcs2 = point_clouds[torch.as_tensor(second_idx, dtype=torch.long)]

test_save_dir = f"{SAVE_PATH}/test/num_pairs_{NUM_PAIRS}"
os.makedirs(test_save_dir, exist_ok=True)
torch.save(pcs1, f"{test_save_dir}/pcs1.pt")
torch.save(pcs2, f"{test_save_dir}/pcs2.pt")
print(f"Saved {NUM_PAIRS} test pairs to {test_save_dir}")

# ---- Training subsets: one random subset of point clouds per requested size ----
# Seeding NumPy does not seed torch, so torch.randperm needs its own seeded
# generator for the saved subsets to be reproducible across runs. A dedicated
# generator is used so torch's global RNG state is left alone.
train_rng = torch.Generator().manual_seed(42)
num_available = point_clouds.size(0)

for num_samples in num_samples_list:
    if num_samples > num_available:
        # Not enough point clouds for this size; report it and move on.
        print(
            f"Skipping num_samples={num_samples}: "
            f"only {num_available} point clouds available."
        )
        continue
    # Take the first num_samples entries of a random permutation, i.e. a
    # uniformly random subset without replacement.
    perm = torch.randperm(num_available, generator=train_rng)[:num_samples]
    samples = point_clouds[perm]
    train_save_dir = f"{SAVE_PATH}/train/num_samples_{num_samples}"
    os.makedirs(train_save_dir, exist_ok=True)
    torch.save(samples, f"{train_save_dir}/samples.pt")
    print(f"Saved {num_samples} train samples to {train_save_dir}/samples.pt")
