import os
import torch
from torchvision import datasets, transforms

# Build fixed-size 2-D point clouds from MNIST digits and save them as a
# single tensor.
#
# For each training image, the pixels brighter than THRESHOLD are turned into
# a list of pixel index coordinates. Only images whose foreground-pixel count
# is already close to N_POINTS are kept; those are randomly subsampled or
# padded (by repeating existing points) to exactly N_POINTS rows.

N_POINTS = 100      # number of points in every saved cloud
THRESHOLD = 0.5     # a pixel is foreground when its value is > THRESHOLD
# An image is kept only if |foreground count - N_POINTS| is strictly below
# this value.
MAX_COUNT_DEVIATION = 5
# Seed of the private generator used for subsampling/padding, so repeated
# runs write the same file and the global torch RNG state is left untouched.
SEED = 0
SAVE_PATH = 'preprocessed_dataset/pcmnist'
os.makedirs(SAVE_PATH, exist_ok=True)

# download=False: the MNIST files are expected to already exist under "data".
mnist = datasets.MNIST(root="data", train=True, download=False, transform=transforms.ToTensor())

rng = torch.Generator().manual_seed(SEED)
point_clouds = []
for img, _ in mnist:
    # Assumes each sample is a (1, H, W) tensor, so after squeeze(0) the rows
    # of `coords` are (row, col) index pairs -- TODO confirm for other datasets.
    img = img.squeeze(0)
    coords = torch.nonzero(img > THRESHOLD, as_tuple=False)
    n_found = coords.shape[0]

    # Skip empty images (padding needs at least one point to sample from)
    # and images whose point count is too far from the target.
    if n_found == 0 or abs(n_found - N_POINTS) >= MAX_COUNT_DEVIATION:
        continue

    if n_found > N_POINTS:
        # Too many points: keep a random subset without replacement.
        keep = torch.randperm(n_found, generator=rng)[:N_POINTS]
        coords = coords[keep]
    elif n_found < N_POINTS:
        # Too few points: append randomly chosen duplicates of existing ones.
        extra = torch.randint(0, n_found, (N_POINTS - n_found,), generator=rng)
        coords = torch.cat([coords, coords[extra]], dim=0)

    # Convert to float only for clouds that are actually kept.
    point_clouds.append(coords.float())

# torch.stack fails with an unhelpful message on an empty list; report the
# real cause instead.
if not point_clouds:
    raise RuntimeError(
        f"No image passed the point-count filter "
        f"(N_POINTS={N_POINTS}, THRESHOLD={THRESHOLD}, "
        f"MAX_COUNT_DEVIATION={MAX_COUNT_DEVIATION}); nothing to save."
    )

point_clouds = torch.stack(point_clouds)
print(f"Collected {point_clouds.shape[0]} point clouds with {N_POINTS} points each.")

output_file = f"{SAVE_PATH}/all_point_clouds.pt"
torch.save(point_clouds, output_file)
print(f"Saved all point clouds to {output_file}")
