"""Calculate mean and std of dataset."""


import torch.utils.data
from typing import Optional, Callable
import glob
import os
import torchvision
import argparse
import torchvision.transforms
from tqdm import tqdm
import numpy as np



class SimpleImageDataset(torch.utils.data.Dataset):
    """Map-style dataset over the files matching ``*.*`` directly inside ``root``.

    Each item is the file at that position (sorted by path), loaded with
    ``torchvision.datasets.folder.pil_loader`` and passed through ``transform``.
    """

    def __init__(self, root: str, transform: Optional[Callable] = None):
        """
        Args:
            root: directory to scan (non-recursively) for ``*.*`` files.
            transform: callable applied to every loaded image; identity if None.
        """
        self.root = root
        # glob.escape makes characters such as "[" in the directory name match
        # literally instead of acting as wildcards; sorting fixes item order.
        self.image_paths = sorted(glob.glob(os.path.join(glob.escape(root), "*.*")))

        if transform is None:
            # A class-level function instead of a lambda keeps the dataset
            # picklable, which DataLoader workers started with "spawn" require.
            transform = self._identity
        self.transform = transform

        self.loader = torchvision.datasets.folder.pil_loader

    @staticmethod
    def _identity(x):
        """Return *x* unchanged; the default transform."""
        return x

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, item):
        """Load, transform and return the image at position ``item``.

        Raises:
            IndexError: if ``item`` is outside ``[0, len(self))``. Unlike an
                ``assert``, this survives ``python -O`` and lets plain
                ``for x in dataset`` iteration terminate normally.
        """
        if not 0 <= item < len(self):
            raise IndexError(
                f"index {item} out of range for dataset of length {len(self)}"
            )
        return self.transform(self.loader(self.image_paths[item]))

class SimpleImageDatasetNPZ(torch.utils.data.Dataset):
    """Map-style dataset over one array stored in an ``.npz`` archive.

    Item ``i`` is ``transform(array[i])``. The whole array is read into memory
    at construction time.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        *,
        key: str = "images",
    ):
        """
        Args:
            root: path to the ``.npz`` file.
            transform: callable applied to every item; identity if None.
            key: name of the array inside the archive (default ``"images"``).
        """
        self.root = root
        # NOTE(security): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code -- only load trusted files.
        # The ``with`` block closes the archive's file handle once the array has
        # been read into memory.
        with np.load(self.root, allow_pickle=True) as npz:
            # NOTE(review): presumably (N, H, W, C) images -- not checked here.
            self.data = npz[key]

        if transform is None:
            # A class-level function instead of a lambda keeps the dataset
            # picklable, which DataLoader workers started with "spawn" require.
            transform = self._identity
        self.transform = transform

    @staticmethod
    def _identity(x):
        """Return *x* unchanged; the default transform."""
        return x

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        """Return ``transform(data[item])``.

        Raises:
            IndexError: if ``item`` is outside ``[0, len(self))``. Unlike an
                ``assert``, this survives ``python -O`` and lets plain
                ``for x in dataset`` iteration terminate normally.
        """
        if not 0 <= item < len(self):
            raise IndexError(
                f"index {item} out of range for dataset of length {len(self)}"
            )
        return self.transform(self.data[item])


def compute_mean_std(batches):
    """Return the per-channel mean and standard deviation over all pixels.

    Args:
        batches: iterable of tensors shaped ``(B, C, ...)`` -- batch first,
            channel second, any number of trailing spatial dimensions.

    Returns:
        ``(mean, std)``: two float32 tensors of shape ``(C,)``. ``std`` is the
        population standard deviation of every value of a channel across the
        whole dataset. It is not the average of per-image stds, which ignores
        brightness variation between images and underestimates the spread.

    Raises:
        ValueError: if ``batches`` yields no pixels at all.
    """
    total = total_sq = None
    count = 0
    for batch in batches:
        # (B, C, *spatial) -> (C, B * prod(spatial)): one row per channel.
        # float64 keeps the running sum of squares accurate on large datasets.
        pixels = batch.transpose(0, 1).reshape(batch.shape[1], -1).double()
        if total is None:
            total = pixels.new_zeros(pixels.shape[0])
            total_sq = pixels.new_zeros(pixels.shape[0])
        total += pixels.sum(dim=1)
        total_sq += pixels.pow(2).sum(dim=1)
        count += pixels.shape[1]
    if count == 0:
        raise ValueError("cannot compute mean/std of an empty dataset")
    mean = total / count
    # Var[X] = E[X^2] - E[X]^2; the clamp absorbs tiny negative rounding error.
    var = (total_sq / count - mean.pow(2)).clamp(min=0)
    return mean.float(), var.sqrt().float()


def main():
    """Parse ``--root-folder`` and print the dataset's per-channel mean and std."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--root-folder",
        required=True,
        help="path to an .npz archive holding an 'images' array",
    )
    args = parser.parse_args()

    dataset = SimpleImageDatasetNPZ(
        args.root_folder, transform=torchvision.transforms.ToTensor()
    )

    full_loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,  # the statistics do not depend on order
        # os.cpu_count() may return None, which DataLoader rejects.
        num_workers=os.cpu_count() or 0,
        batch_size=256,
    )

    print("==> Computing mean and std..")
    mean, std = compute_mean_std(tqdm(full_loader))
    print(mean, std)


# The guard lets DataLoader worker processes started with "spawn" (e.g. on
# Windows and macOS) re-import this module without re-running the script.
if __name__ == "__main__":
    main()