import os
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchmetrics.image.fid import FrechetInceptionDistance
from PIL import Image
from tqdm import tqdm
import argparse

class ImageDataset(Dataset):
    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform
        # 获取所有图片文件的路径
        self.image_paths = [
            os.path.join(folder_path, f) 
            for f in os.listdir(folder_path) 
            if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))
        ]
        if not self.image_paths:
            raise ValueError(f"文件夹 {folder_path} 中没有找到图片文件")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            image = Image.open(img_path).convert('RGB')  # 确保是RGB格式
            if self.transform:
                image = self.transform(image)
            return image
        except Exception as e:
            print(f"加载图片 {img_path} 时出错: {e}")
            # 返回一张空白图片作为替代（避免中断）
            return torch.zeros(3, 299, 299, dtype=torch.uint8)

def calculate_fid(folder1, folder2, batch_size=32, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """计算两个文件夹中图片的FID分数"""
    # 定义图像预处理（直接转换为uint8张量，不进行标准化）
    transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),  # 转换为[0,1]范围的float张量
        transforms.Lambda(lambda x: (x * 255).byte())  # 转换为[0,255]范围的uint8张量
    ])
    
    # 创建自定义数据集
    dataset1 = ImageDataset(folder1, transform=transform)
    dataset2 = ImageDataset(folder2, transform=transform)
    
    # 创建数据加载器
    dataloader1 = DataLoader(dataset1, batch_size=batch_size, shuffle=False, num_workers=4)
    dataloader2 = DataLoader(dataset2, batch_size=batch_size, shuffle=False, num_workers=4)
    
    print(f"加载完成: 文件夹1包含 {len(dataset1)} 张图片, 文件夹2包含 {len(dataset2)} 张图片")
    
    # 初始化FID计算器
    fid = FrechetInceptionDistance(feature=2048).to(device)
    
    # 处理第一个文件夹的图片
    print("处理第一个文件夹的图片...")
    for batch in tqdm(dataloader1, total=len(dataloader1)):
        images = batch.to(device)
        # 确保数据类型是uint8
        if images.dtype != torch.uint8:
            images = images.to(torch.uint8)
        fid.update(images, real=True)
    
    # 处理第二个文件夹的图片
    print("处理第二个文件夹的图片...")
    for batch in tqdm(dataloader2, total=len(dataloader2)):
        images = batch.to(device)
        # 确保数据类型是uint8
        if images.dtype != torch.uint8:
            images = images.to(torch.uint8)
        fid.update(images, real=False)
    
    # 计算FID分数
    fid_score = fid.compute()
    return fid_score.item()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='计算两个文件夹中图片的FID分数')
    parser.add_argument('--folder1', type=str, required=True, help='第一个图片文件夹路径')
    parser.add_argument('--folder2', type=str, required=True, help='第二个图片文件夹路径')
    parser.add_argument('--batch_size', type=int, default=32, help='批处理大小')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', 
                      help='计算设备 (cuda 或 cpu)')
    
    args = parser.parse_args()
    
    # 检查文件夹是否存在
    if not os.path.isdir(args.folder1):
        raise ValueError(f"文件夹 {args.folder1} 不存在")
    if not os.path.isdir(args.folder2):
        raise ValueError(f"文件夹 {args.folder2} 不存在")
    
    # 计算FID
    fid_score = calculate_fid(args.folder1, args.folder2, args.batch_size, args.device)
    print(f"两个文件夹图片的FID分数: {fid_score:.4f}")
