import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# Compute per-channel pixel mean and standard deviation of the CelebA training
# split, e.g. for use in a ``transforms.Normalize`` step.

DATA_ROOT = "/scratch/user/repos/ssl/solo-learn-gaussianization/datasets"


def build_transform():
    """Return the preprocessing pipeline whose output statistics are measured."""
    return transforms.Compose([
        transforms.CenterCrop(178),  # 178x178 centre crop (original CelebA crop)
        transforms.Resize(128),      # optional resize of the square crop to 128x128
        transforms.ToTensor(),       # scales pixels to [0, 1]
    ])


def compute_channel_stats(batches):
    """Return the dataset-wide per-channel (mean, std) over every pixel seen.

    Args:
        batches: Iterable of ``(images, target)`` pairs, where ``images`` is a
            tensor with the channel axis at dim 1 (e.g. ``(B, C, H, W)``).
            ``target`` is ignored.

    Returns:
        ``(mean, std)``: two float32 tensors of shape ``(C,)``. ``std`` is the
        population standard deviation (no Bessel correction) over all pixels
        of all images.

    Raises:
        ValueError: if ``batches`` yields no pixels.
    """
    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0
    for images, _ in batches:
        # (B, C, ...) -> (C, B*...) in float64: sums over billions of pixels
        # would lose precision in float32.
        flat = images.double().transpose(0, 1).reshape(images.size(1), -1)
        if channel_sum is None:
            # Allocate on the same device as the data.
            channel_sum = flat.new_zeros(flat.size(0))
            channel_sq_sum = flat.new_zeros(flat.size(0))
        channel_sum += flat.sum(dim=1)
        channel_sq_sum += (flat * flat).sum(dim=1)
        pixel_count += flat.size(1)

    if pixel_count == 0:
        raise ValueError("no pixels seen; the batch iterable was empty")

    mean = channel_sum / pixel_count
    # Var = E[x^2] - E[x]^2 over *all* pixels. Averaging per-image stds instead
    # would ignore variation between images and underestimate the spread.
    # clamp guards against tiny negative values from rounding.
    var = (channel_sq_sum / pixel_count - mean * mean).clamp(min=0)
    return mean.float(), var.sqrt().float()


def main(root=DATA_ROOT, batch_size=64, num_workers=4):
    """Load the CelebA train split, print its channel mean/std and return them."""
    celeba_train = datasets.CelebA(
        root,
        split='train',
        download=True,
        transform=build_transform(),
    )
    loader = DataLoader(celeba_train, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers)

    mean, std = compute_channel_stats(tqdm(loader))
    print("Mean:", mean)
    print("Std:", std)
    return mean, std


# Guard needed: DataLoader workers started with the "spawn" method re-import
# this module, and unguarded top-level code would run again in each worker.
if __name__ == "__main__":
    main()
