import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import copy
from dataset import MetricDataset
from metric import Predictor
from tqdm import tqdm
import numpy as np


# Which metric variant to process; selects the dataset directory below.
tp = 'sem'
# '~' is a shell convention: open() and the os path functions do not expand it,
# so without expanduser the path would name a literal '~' directory (unless
# MetricDataset expands it itself, which is not visible here). expanduser is a
# no-op on paths that are already absolute, so this is safe either way.
data_root = os.path.expanduser('~/data/tai/metric_{}'.format(tp))
SPLITS = ('train', 'val')
metric_datasets = {
    split: MetricDataset(data_root=data_root, split=split)
    for split in SPLITS
}
# batch_size=1 is load-bearing: the statistics code indexes labels[0] as the
# label of the single sample in each batch.
# shuffle=False: sample order does not matter for mean/std, and a fixed order
# makes the floating-point accumulation reproducible from run to run.
dataloaders = {
    split: DataLoader(
        metric_datasets[split], batch_size=1, shuffle=False, num_workers=1
    )
    for split in SPLITS
}
dataset_sizes = {split: len(metric_datasets[split]) for split in SPLITS}


# ---------------------------------------------------------------------------
# Global mean / std of images and labels over the train + val splits.
#
# Two passes are used on purpose: the first computes the means, the second the
# sum of squared deviations from those means. This is numerically more stable
# than the single-pass E[x^2] - E[x]^2 formula.
#
# All accumulation is done in float64. Adding raw batches to a Python int would
# keep the batch dtype, so uint8 images would overflow silently, and float32
# sums lose precision over many images.
# ---------------------------------------------------------------------------


def _iter_batches(loaders, splits=('train', 'val')):
    """Yield every ``(inputs, labels)`` batch from ``loaders[split]``, split by split."""
    for split in splits:
        yield from tqdm(loaders[split], desc=split)


def _as_float64(value):
    """Return ``value`` (tensor or number) as a float64 tensor."""
    return torch.as_tensor(value).to(torch.float64)


# Pass 1: sums. Images are reduced to a scalar pixel sum plus a pixel count,
# so images of different sizes are also handled (pixel-weighted mean). For
# equal-size images this equals mean(sum_of_images) / N.
total_samples = 0
image_sum = torch.zeros((), dtype=torch.float64)
image_count = 0
label_sum = 0  # becomes a float64 tensor shaped like labels[0] after the first add
for inputs, labels in _iter_batches(dataloaders):
    image_sum += inputs.to(torch.float64).sum()
    image_count += inputs.numel()
    # batch_size is 1, so labels[0] is the label of the single sample.
    label_sum = label_sum + _as_float64(labels[0])
    total_samples += 1

if total_samples == 0 or image_count == 0:
    raise ValueError('train/val datasets are empty; cannot compute statistics')

image_mean = image_sum / image_count
label_mean = label_sum / total_samples

# Pass 2: sums of squared deviations from the means computed above.
image_sq_dev = torch.zeros((), dtype=torch.float64)
label_sq_dev = 0
for inputs, labels in _iter_batches(dataloaders):
    image_sq_dev += torch.square(inputs.to(torch.float64) - image_mean).sum()
    label_sq_dev = label_sq_dev + torch.square(_as_float64(labels[0]) - label_mean)

# Population (biased, divide-by-N) standard deviation, as in the original.
image_std = torch.sqrt(image_sq_dev / image_count)
label_std = torch.sqrt(label_sq_dev / total_samples)

# Report the computed statistics, one "caption value" line each.
_report_rows = (
    ('count: ', total_samples),
    ('image mean: ', image_mean),
    ('image std: ', image_std),
    ('label mean: ', label_mean),
    ('label std: ', label_std),
)
for _caption, _value in _report_rows:
    print(_caption, _value)