import os
import torch

# Requested sizes of the class-balanced training subsets that are sampled and saved.
LIST_NUM_SAMPLES = [10, 50, 100, 200]

# Directory that holds the per-category "<category>.pt" files; the train/ and
# test/ splits produced by this script are written beneath it as well.
SAVE_PATH = "preprocessed_dataset/split_point_cloud"

# Every known category name, in priority order. Only the first
# _NUM_CATEGORIES entries are loaded and used by this script.
_ALL_CATEGORIES = (
    "table", "chair", "airplane", "car", "sofa",
    "rifle", "lamp", "vessel", "bench", "speaker",
    "cabinet", "monitor", "bus", "bathtub", "guitar",
    "faucet", "clock", "pot", "cellphone", "jar",
    "bottle", "telephone", "laptop", "bookshelf", "knife",
    "train", "motorcycle", "can", "file", "pistol",
    "bed", "piano", "stove", "mug", "bowl",
    "washer", "printer", "helmet", "microwave", "skateboard",
    "tower", "camera", "basket", "tin_can", "pillow",
    "dishwasher", "mailbox", "rocket", "bag", "earphone",
    "birdhouse", "microphone", "remote_control", "keyboard", "cap",
)
_NUM_CATEGORIES = 10
categories = list(_ALL_CATEGORIES[:_NUM_CATEGORIES])
# Load each selected category's saved object onto the CPU, keyed by category name.
# NOTE(review): assumes every "<category>.pt" holds a single tensor whose first
# dimension indexes samples (the sampling code calls index_select on dim 0) — confirm.
data_map = {
    cat: torch.load(os.path.join(SAVE_PATH, f"{cat}.pt"), map_location="cpu")
    for cat in categories
}

# Seed the global torch RNG so that re-running the script reproduces the same splits.
SEED = 0
torch.manual_seed(SEED)

# Per category, a boolean mask marking the rows drawn into ANY training subset.
# The test pairs are then drawn only from rows outside these masks, so the test
# set never contains a training sample (no train/test leakage).
used_in_train = {
    cat: torch.zeros(len(data_map[cat]), dtype=torch.bool) for cat in categories
}

# --- Training subsets -------------------------------------------------------
for num_samples in LIST_NUM_SAMPLES:
    # Split num_samples as evenly as possible across categories: each category
    # gets `base` rows and the first `extra` categories get one more. The total
    # is then exactly num_samples even when it is not divisible by the number
    # of categories (plain floor division would silently drop the remainder).
    base, extra = divmod(num_samples, len(categories))
    all_samples = []
    for i, cat in enumerate(categories):
        data = data_map[cat]
        num_per_cat = base + (1 if i < extra else 0)
        # If a category has fewer rows than requested, randperm(...)[:k] returns
        # all of them, so the subset may end up smaller than num_samples.
        idx = torch.randperm(len(data))[:num_per_cat]
        used_in_train[cat][idx] = True
        all_samples.append(data.index_select(0, idx))
    final_tensor = torch.cat(all_samples, dim=0)
    # Shuffle so that categories are interleaved rather than stored in contiguous runs.
    final_tensor = final_tensor.index_select(0, torch.randperm(final_tensor.size(0)))

    # Name the directory after the actual row count, which can fall below
    # num_samples when a category is too small.
    train_dir = os.path.join(SAVE_PATH, "train", f"num_samples_{final_tensor.size(0)}")
    os.makedirs(train_dir, exist_ok=True)
    torch.save(final_tensor, os.path.join(train_dir, "samples.pt"))

# --- Test pairs -------------------------------------------------------------
# The test pool is every row, from every category, that no training subset used.
X_all = torch.cat([data_map[cat][~used_in_train[cat]] for cat in categories], dim=0)
N = X_all.size(0)

# Build up to 10,000 pairs (pcs1[i], pcs2[i]) from 2*P distinct held-out rows.
P = min(10_000, N // 2)
# Draw only the 2*P rows that are needed, instead of materialising a permuted
# copy of the whole pool.
pair_idx = torch.randperm(N)[: 2 * P]
pairs = X_all.index_select(0, pair_idx)

# Call .clone() so each file stores only its own P rows. torch.save on a slice
# view serializes the view's entire underlying storage (here, both halves).
pcs1 = pairs[:P].clone()
pcs2 = pairs[P:].clone()

test_dir = os.path.join(SAVE_PATH, "test", f"num_pairs_{P}")
os.makedirs(test_dir, exist_ok=True)
torch.save(pcs1, os.path.join(test_dir, "pcs1.pt"))
torch.save(pcs2, os.path.join(test_dir, "pcs2.pt"))
