import os

from argparse import ArgumentParser
import torch
from datasets import load_dataset
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.transforms import v2
from tqdm import tqdm
from torch.utils.data import Dataset
from PIL import Image


class SimpleImageDataset(Dataset):
    """Map-style dataset that loads images from a list of file paths.

    Each item is opened with Pillow, converted to RGB and, when a transform
    was supplied, passed through that transform before being returned.
    """

    def __init__(self, image_paths, transform=None):
        """Store the paths and the optional per-image transform.

        Args:
            image_paths: Indexable, sized collection of image file paths.
            transform: Optional callable applied to each RGB ``PIL.Image``.
        """
        self.transform = transform
        self.image_paths = image_paths

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return it (transformed if configured)."""
        img_path = self.image_paths[idx]
        # Scope the file handle: convert() decodes the pixels and returns a
        # new image, so the underlying file can be closed right away instead
        # of being left to the garbage collector (notably on decode errors).
        with Image.open(img_path) as source:
            image = source.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image


def _build_loader(dataset, args, collate_fn, pin_memory):
    """Create a sequential (unshuffled, nothing dropped) DataLoader from CLI args."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=pin_memory,
        drop_last=False,
        # DataLoader rejects persistent_workers=True when num_workers == 0,
        # so only request it when worker processes actually exist.
        persistent_workers=args.num_workers > 0,
        collate_fn=collate_fn,
    )


def main(args):
    """Compute and print the FID between reference and generated images.

    The reference images come from the ``sayakpaul/coco-30-val-2014`` dataset;
    the generated images are read from ``args.image_folder``. Both sets go
    through the same resize / center-crop preprocessing before being fed to
    ``FrechetInceptionDistance``.

    Args:
        args: Namespace providing ``dataset_folder`` (cache directory passed
            to ``load_dataset``), ``image_folder`` (directory containing the
            generated images), ``batch_size`` and ``num_workers``.

    Raises:
        RuntimeError: If ``args.image_folder`` contains no ``.png``, ``.jpg``
            or ``.jpeg`` file.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pinned host memory only benefits host-to-GPU copies.
    pin_memory = device.type == "cuda"

    # List the generated images first so a wrong or empty folder fails
    # before the (long) pass over the reference images. The extensions carry
    # their leading dot so names such as "notes_png" are not picked up, and
    # the list is sorted because os.listdir returns entries in arbitrary order.
    generated_paths = sorted(
        os.path.join(args.image_folder, fname)
        for fname in os.listdir(args.image_folder)
        if fname.lower().endswith((".png", ".jpg", ".jpeg"))
    )
    if not generated_paths:
        raise RuntimeError(
            f"No .png/.jpg/.jpeg images found in {args.image_folder!r}"
        )

    fid = FrechetInceptionDistance()
    fid = fid.to(device)
    fid.eval()

    # NOTE(review): trust_remote_code=True lets the dataset repository's
    # loading script run locally -- confirm this source is trusted.
    gt_images = load_dataset(
        "sayakpaul/coco-30-val-2014",
        split="train",
        cache_dir=args.dataset_folder,
        trust_remote_code=True,
        num_proc=12,
    )
    gt_images = gt_images.remove_columns("caption")

    # PILToTensor keeps the uint8 pixel values (no rescaling to [0, 1]).
    preprocess_image = v2.Compose(
        [
            v2.Resize(256, interpolation=v2.InterpolationMode.BICUBIC, antialias=True),
            v2.CenterCrop(256),
            v2.PILToTensor(),
        ]
    )

    def preprocess_fn(batch):
        # Applied lazily by the dataset to each accessed batch of rows.
        batch["image"] = [
            preprocess_image(image.convert("RGB")) for image in batch["image"]
        ]
        return batch

    gt_images.set_transform(preprocess_fn)
    print("Number of real images: ", len(gt_images))
    gt_images_loader = _build_loader(
        gt_images,
        args,
        collate_fn=lambda x: torch.stack([item["image"] for item in x], dim=0),
        pin_memory=pin_memory,
    )

    # Process real images
    for batch in tqdm(gt_images_loader):
        batch = batch.to(device)
        fid.update(batch, real=True)

    # Load generated images using the custom dataset
    generated_images = SimpleImageDataset(
        image_paths=generated_paths,
        transform=preprocess_image,
    )
    print("Number of fake images: ", len(generated_images))
    generated_images_loader = _build_loader(
        generated_images,
        args,
        collate_fn=lambda x: torch.stack(x, dim=0),
        pin_memory=pin_memory,
    )

    # Process generated images
    for batch in tqdm(generated_images_loader):
        batch = batch.to(device)
        fid.update(batch, real=False)

    fid_score = fid.compute().item()
    print("FID: ", fid_score)


if __name__ == "__main__":
    # Command-line entry point: each option is a typed flag with a default,
    # registered in the order listed below.
    parser = ArgumentParser()
    option_specs = (
        ("--dataset_folder", str, "/path/to/dataset_folder"),
        ("--image_folder", str, "/path/to/image_folder"),
        ("--batch_size", int, 50),
        ("--num_workers", int, 12),
    )
    for flag, value_type, default_value in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value)
    args = parser.parse_args()
    main(args)
