# Imports
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
import os
from PIL import Image
import torchvision.transforms as transforms
import argparse
# Seed torch's global RNG so any stochastic step is reproducible across runs.
_ = torch.manual_seed(123)

# Command-line arguments
parser = argparse.ArgumentParser(description='Get FID scores')

parser.add_argument(
    '--folder_ddim',
    type=str,
    default='./example/A_3D_model_of_an_adorable_cottage_with_a_thatched_roof@20240507-161523/save/it9000-test',
    help='Path to the folder containing DDIM ground truth images',
)
parser.add_argument(
    '--folder_sds',
    type=str,
    default='./example/A_baby_bunny_sitting_on_top_of_a_stack_of_pancakes@20240507-182038/save/it9000-test',
    help='Path to the folder containing SDS images',
)

# Device that runs the InceptionV3 feature extractor behind the FID metric.
# NOTE(review): the CUDA default '1' selects the *second* GPU; on a single-GPU
# machine pass `--device 0`.
parser.add_argument(
    '--device',
    type=str,
    default='1' if torch.cuda.is_available() else 'cpu',
    help="Device to run inference on: 'cpu', a CUDA index such as '0', "
         "or a torch device string such as 'cuda:0'",
)

args = parser.parse_args()


def _resolve_device(spec):
    """Turn the ``--device`` string into a ``torch.device``.

    Accepts 'cpu', a bare CUDA index ('1' -> cuda:1, the original CLI
    convention), or any string ``torch.device`` understands ('cuda:0', ...).
    Without CUDA everything runs on the CPU, regardless of *spec*.

    Exits through ``parser.error`` (status 2) when *spec* is not a valid
    device string or names a CUDA index that is not visible.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    spec = str(spec).strip()
    try:
        resolved = torch.device(f"cuda:{spec}" if spec.isdigit() else spec)
    except RuntimeError as err:
        parser.error(f"invalid --device {spec!r}: {err}")
    # Fail fast with a readable message instead of a CUDA error on first use.
    if (resolved.type == "cuda" and resolved.index is not None
            and resolved.index >= torch.cuda.device_count()):
        parser.error(
            f"--device {spec!r} requests CUDA device {resolved.index}, but only "
            f"{torch.cuda.device_count()} CUDA device(s) are visible"
        )
    return resolved


device = _resolve_device(args.device)

# Function to load and preprocess images
def load_images_from_folder(folder, crop_box=(0, 0, 512, 512), size=(299, 299)):
    """Load every image in *folder* into a single float tensor.

    Each image is converted to RGB, cropped to *crop_box*, resized to *size*,
    and converted with ``transforms.ToTensor`` (float values in [0, 1]).

    Args:
        folder: Directory to read. Only regular files whose extension Pillow
            has a registered plugin for are loaded; subdirectories and other
            files are skipped.
        crop_box: ``(left, upper, right, lower)`` box passed to
            ``PIL.Image.Image.crop``. The default keeps the top-left 512x512
            region. NOTE(review): this presumably selects the RGB panel of a
            wider rendered grid; confirm it matches your data.
        size: ``(height, width)`` passed to ``transforms.Resize``. The default
            299x299 is the InceptionV3 input size.

    Returns:
        Tensor of shape ``(N, 3, height, width)``, stacked in sorted filename
        order.

    Raises:
        ValueError: If *folder* contains no loadable images.
    """
    # Built once per call instead of once per image.
    transform = transforms.Compose([
        transforms.Lambda(lambda im: im.crop(crop_box)),
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
    # Extensions (e.g. '.png') that Pillow knows how to decode.
    image_exts = set(Image.registered_extensions())

    images = []
    for filename in sorted(os.listdir(folder)):
        path = os.path.join(folder, filename)
        if not os.path.isfile(path):
            continue
        if os.path.splitext(filename)[1].lower() not in image_exts:
            continue
        # `with` closes the file handle. convert() loads the pixel data into a
        # new image, so `rgb` stays valid after the file is closed.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        images.append(transform(rgb))

    if not images:
        raise ValueError(f"No images found in folder: {folder}")
    return torch.stack(images)

# FID metric. feature=64 selects torchmetrics' 64-dim InceptionV3 feature layer
# instead of the usual 2048-dim one, so scores are not directly comparable to
# FID values computed with the standard setting.
fid = FrechetInceptionDistance(feature=64)
fid = fid.to(device)


def _to_uint8(images):
    """Map float images in [0, 1] to uint8 in [0, 255], rounding to nearest.

    ``x / 255 * 255`` is not guaranteed to round-trip exactly in floating
    point, and a bare ``.to(torch.uint8)`` cast truncates. Without rounding,
    some pixels could drop by one intensity level.
    """
    return images.mul(255).round().clamp(0, 255).to(torch.uint8)


def _update_fid(metric, images, real, batch_size=64):
    """Feed *images* to *metric* in batches so only one batch is on *device*."""
    for start in range(0, len(images), batch_size):
        batch = _to_uint8(images[start:start + batch_size]).to(device)
        metric.update(batch, real=real)


# Load images from the folders (kept on the CPU; _update_fid moves each batch).
ddim_images = load_images_from_folder(args.folder_ddim)
sds_images = load_images_from_folder(args.folder_sds)

# Compute FID: DDIM images form the reference ("real") set, SDS the generated set.
_update_fid(fid, ddim_images, real=True)
_update_fid(fid, sds_images, real=False)
fid_score = fid.compute()

print("FID Score:", fid_score.item())


