
import os
import multiprocessing

def _channel_mean_std(loader, device):
    """Return the per-channel ``(mean, std)`` over every pixel yielded by ``loader``.

    ``loader`` is iterated twice. The first pass accumulates per-channel sums,
    which give the mean. The second pass accumulates squared deviations from
    that mean (two-pass variance, which avoids the cancellation of the
    sum-of-squares formula).

    Each item is expected to be a ``(data, labels)`` pair with ``data`` shaped
    (N, C, H, W); labels are ignored. Height and width need not be equal.
    Pixel counts come from the batches themselves rather than a nominal image
    size.

    Sums are accumulated in float64 so large datasets keep their precision.
    The results are returned as float32 tensors on ``device``. The std uses
    Bessel's correction (divides by ``count - 1``).

    Raises:
        ValueError: if fewer than two pixels per channel were seen.
    """
    import torch

    channel_sum = None
    count = 0  # pixels per channel seen so far: sum of N * H * W over batches
    for data, _ in loader:
        data = data.to(device)
        batch_sum = data.sum(dim=(0, 2, 3), dtype=torch.float64)
        channel_sum = batch_sum if channel_sum is None else channel_sum + batch_sum
        count += data.numel() // data.shape[1]

    if count < 2:
        raise ValueError(
            "need at least 2 pixels per channel to compute mean/std, got {}".format(count)
        )

    mean = channel_sum / count
    # (1, C, 1, 1) broadcasts against (N, C, H, W) batches of any H and W.
    mean_b = mean.view(1, -1, 1, 1)

    sq_dev_sum = torch.zeros_like(mean)
    for data, _ in loader:
        data = data.to(device=device, dtype=torch.float64)
        sq_dev_sum += torch.square(data - mean_b).sum(dim=(0, 2, 3))

    std = torch.sqrt(sq_dev_sum / (count - 1))
    return mean.float(), std.float()


def main():
    """Parse CLI arguments, load the training split and print per-channel mean/std.

    The dataset is requested with ``mean=[0,0,0]``, ``std=[1,1,1]``,
    ``augment=False`` and ``val_split=0.0``. Presumably this makes
    ``load_dataset`` yield un-normalized, un-augmented tensors, so the printed
    statistics can be used as normalization constants. Confirm against
    ``utils.load_dataset``.

    Side effect: the parsed arguments are published as the module-level
    global ``args``.
    """
    # Imports are local, matching the original script, so that importing this
    # module (e.g. from spawned worker processes) stays cheap.
    import argparse
    import torch
    from utils.load_dataset import load_dataset

    parser = argparse.ArgumentParser(description='Calculate Mean for the dataset', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dataset',          default='cifar10',      type=str,       help='Set dataset to use')
    parser.add_argument('--batch_size',       default=512,            type=int,       help='Batch size')

    global args
    args = parser.parse_args()
    print(args)

    # Setup right device to run on
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Identity normalization and no augmentation, so the statistics describe
    # the raw training tensors.
    print('\n')
    dataset = load_dataset(dataset=args.dataset,
                           train_batch_size=args.batch_size,
                           test_batch_size=args.batch_size,
                           val_split=0.0,
                           augment=False,
                           shuffle=True,
                           random_seed=0,
                           device=device,
                           mean=[0,0,0],
                           std=[1,1,1])

    mean, std = _channel_mean_std(dataset.train_loader, device)
    print("Mean {}".format(mean))
    print("Std {}".format(std))
  

if __name__ == "__main__":
    # The __main__ guard is what lets spawn-based multiprocessing (the Windows
    # default) re-import this module in worker processes without re-running
    # main(). freeze_support() is only needed when the script is frozen into a
    # Windows executable; otherwise it has no effect.
    if os.name == 'nt':
        multiprocessing.freeze_support()

    main()