import os
import torch_dct as dct
from torchvision import transforms, utils
from PIL import Image
import torch
import numpy as np
import json
from tqdm import tqdm
from IPython import embed

class dct_wrapper:
    """Transform that maps a tensor to the log-magnitude of its 2-D DCT.

    Designed to be used as a step in a ``torchvision.transforms.Compose``
    pipeline, after the image has already been converted to a tensor.
    """

    def __call__(self, tensor):
        """Return ``log(|dct_2d(tensor)| + 1e-12)``.

        ``dct_2d`` comes from ``torch_dct`` (presumably applied over the last
        two dimensions, i.e. H and W for a channels-first image — confirm).
        The tiny additive offset keeps ``log`` finite for coefficients that
        are exactly zero.
        """
        magnitude = torch.abs(dct.dct_2d(tensor))
        return torch.log(magnitude + 1e-12)


# Per-image preprocessing pipeline: convert the PIL image to a float tensor
# (torchvision ToTensor), then map it into the log-magnitude DCT domain.
_pre_transform_steps = [
    transforms.ToTensor(),
    dct_wrapper(),
]
pre_transform = transforms.Compose(_pre_transform_steps)

# Root of the dataset; the code below reads <data_path>/{real,fake}/train/*.
data_path = "./data/"

# Every file in both training splits (real first, then fake); the statistics
# below are computed over their union.
all_files = [
    os.path.join(data_path, split, "train", name)
    for split in ("real", "fake")
    for name in os.listdir(os.path.join(data_path, split, "train"))
]

images = []
for f in tqdm(all_files):
    # Use a context manager so the file handle is released deterministically;
    # a bare Image.open() can keep the file open (e.g. multi-frame formats)
    # until garbage collection.
    with Image.open(f) as image:
        images.append(pre_transform(image))

if not images:
    # torch.stack([]) would raise an opaque RuntimeError; say what went wrong.
    raise RuntimeError(f"no training images found under {data_path!r}")

# Stack once and reuse: the previous code stacked separately for the mean and
# the variance, copying the whole dataset twice.
# NOTE(review): torch.stack requires every transformed image to have the same
# shape, i.e. same mode/channel count and size — confirm for this dataset.
stacked = torch.stack(images)

# Per-element statistics across images (dim 0 indexes the images).
# torch.var defaults to the unbiased (N - 1) estimator, so a dataset with a
# single image produces NaN variance.
torch.save(stacked.mean(dim=0), "./mean.pt")
torch.save(stacked.var(dim=0), "./var.pt")