import numpy as np
import os
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.datasets import CIFAR10 
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
from einops import rearrange
from tqdm import tqdm
# Prefer the GPU, but fall back to CPU so the script still runs (slowly) without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ToTensor converts PIL images to float CHW tensors scaled into [0, 1];
# Resize(32) resizes the shorter edge to 32 pixels.
transform = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.ToTensor(),
    ]
)
real_dataset = CIFAR10(
    os.path.join("data", "cifar10"),
    train=True,
    download=True,
    transform=transform,
)
chunk_batch = 500
# NOTE(review): not referenced anywhere below; kept so the name stays defined.
max_size = 50000
# Number of samples tiled horizontally into each generated image file (see the
# rearrange below). Must match how the sampler saved its strips.
samples_per_grid = 500
fid = FrechetInceptionDistance().to(device)
real_dataloader = DataLoader(real_dataset, batch_size=chunk_batch)

# --- Real-image statistics ---------------------------------------------------
img_list = []
for images, _labels in tqdm(real_dataloader):
    images = images.to(device)
    # With the default normalize=False, FrechetInceptionDistance expects uint8
    # in [0, 255]. Round before casting: a bare cast truncates, and the
    # /255 -> *255 float round trip can land just below the original integer.
    images = (images * 255).round().to(torch.uint8)
    img_list.append(images)
    fid.update(images, real=True)

# These are the quantized real *images* (N, C, H, W), not Inception features.
# The file name is kept unchanged for compatibility with existing consumers.
img_list = torch.concat(img_list, dim=0)
print(img_list.shape)
torch.save(img_list, "real_features.pt")

# --- Generated-image statistics ----------------------------------------------
image_path = "outputs/cifar10/version_6/save/test"

# Sorted for a reproducible processing order; sub-directories are skipped.
img_files = sorted(
    name
    for name in os.listdir(image_path)
    if os.path.isfile(os.path.join(image_path, name))
)
if not img_files:
    # Fail early with a clear message instead of an opaque error from compute().
    raise FileNotFoundError(f"No generated images found in {image_path!r}")

for img_file in tqdm(img_files):
    file = os.path.join(image_path, img_file)
    # The context manager closes the file handle. convert("RGB") guarantees three
    # channels, because grayscale or RGBA files would break the "c" axis below.
    with Image.open(file) as pil_im:
        im = np.array(pil_im.convert("RGB"))
    # Each file is a horizontal strip of `samples_per_grid` tiles:
    # (H, B*W, C) -> (B, C, H, W).
    im = rearrange(im, "h (b w) c -> b c h w", b=samples_per_grid)
    im = torch.tensor(im, device=device)
    fid.update(im, real=False)

print(fid.compute())