import io

import h5py
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm import tqdm

# HDF5 file whose images are summarised. The validation split sits in the same
# directory as 3dshapes_val.h5.
H5_PATH = (
    "/scratch/user/repos/ssl/solo-learn-gaussianization/datasets/3dshapes/"
    "3dshapes_train_subset.h5"
)


def _decode_image(raw):
    """Return one stored image in a form ``transforms.ToTensor`` accepts.

    The on-disk layout is not fixed by this script (earlier exploration tried
    both ``np.asarray(images[i])`` and ``Image.open(io.BytesIO(images[i]))``),
    so both are handled: a 3-D ndarray is passed through unchanged, anything
    else is treated as an encoded image byte string and decoded with PIL.

    NOTE(review): 3-D arrays are assumed to be HWC (height, width, channel),
    which is the layout ``ToTensor`` expects -- confirm against the HDF5 writer.
    """
    if isinstance(raw, np.ndarray) and raw.ndim == 3:
        return raw
    if isinstance(raw, (bytes, bytearray)):
        payload = raw
    else:
        payload = np.asarray(raw).tobytes()
    return Image.open(io.BytesIO(payload)).convert("RGB")


class H5ImageDataset(Dataset):
    """Map-style dataset over image and label arrays already read into memory.

    Args:
        images: indexable sequence of stored images (see ``_decode_image``).
        labels: indexable sequence with one entry per image.
        transform: optional callable applied to each decoded image.

    Raises:
        ValueError: if ``images`` and ``labels`` differ in length.
    """

    def __init__(self, images, labels, transform=None):
        if len(images) != len(labels):
            raise ValueError(
                f"got {len(images)} images but {len(labels)} labels"
            )
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = _decode_image(self.images[index])
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[index]


def compute_channel_stats(loader, progress=None):
    """Return the per-channel mean and standard deviation over every pixel.

    Running sums of x and x**2 are accumulated in float64, so the result is the
    true dataset-wide (population) std of each channel. Averaging per-image
    stds instead ignores the variance *between* images and underestimates it.

    Args:
        loader: iterable of ``(data, target)`` batches where ``data`` has shape
            ``(B, C, ...)``; all trailing (spatial) dimensions are flattened.
        progress: optional wrapper applied to ``loader`` (e.g. ``tqdm``).

    Returns:
        ``(mean, std)``, each a float64 tensor of shape ``(C,)``.

    Raises:
        ValueError: if ``loader`` yields no pixels.
    """
    batches = loader if progress is None else progress(loader)
    total = 0.0
    total_sq = 0.0
    count = 0
    for data, _ in batches:
        # (B, C, H*W); float64 keeps E[x^2] - E[x]^2 from losing precision.
        flat = data.reshape(data.size(0), data.size(1), -1).double()
        total = total + flat.sum(dim=(0, 2))
        total_sq = total_sq + flat.pow(2).sum(dim=(0, 2))
        count += flat.size(0) * flat.size(2)

    if count == 0:
        raise ValueError("loader yielded no pixels; cannot compute statistics")

    mean = total / count
    # Clamp tiny negative values caused by floating-point cancellation.
    var = (total_sq / count - mean.pow(2)).clamp_min(0.0)
    return mean, var.sqrt()


def main(path=H5_PATH, batch_size=64, num_workers=4):
    """Print the per-channel mean/std of the images stored in the HDF5 file."""
    with h5py.File(path, "r") as f:
        # Read everything up front; the file is closed before workers start.
        images = f["images"][:]
        labels = f["labels"][:]

    # ToTensor turns HWC uint8 arrays / PIL images into CHW floats in [0, 1].
    transform = transforms.Compose([transforms.ToTensor()])
    dataset = H5ImageDataset(images, labels, transform=transform)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    mean, std = compute_channel_stats(loader, progress=tqdm)
    print("Mean:", mean)
    print("Std:", std)
    return mean, std


# The guard is required for DataLoader workers on spawn-based platforms and
# keeps importing this module free of side effects.
if __name__ == "__main__":
    main()
