import os
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision.models import inception_v3
from PIL import Image
from scipy.linalg import sqrtm


def get_inception_model(device):
    """Load an ImageNet-pretrained Inception-v3 as a feature extractor.

    The classifier head (``model.fc``) is replaced by ``Identity``, so the
    forward pass returns the pooled features that would have fed the
    classifier (2048 values per image for Inception-v3) instead of class
    logits. ``transform_input=True`` makes torchvision re-normalize inputs that
    were normalized with ImageNet mean/std (as ``get_activations`` does) into
    the range the pretrained weights expect.

    Args:
        device: ``torch.device`` (or device string) to move the model to.

    Returns:
        The model, in eval mode, on ``device``.
    """
    try:
        # torchvision >= 0.13 exposes the ``weights`` API; ``pretrained`` is
        # deprecated there and emits a warning.
        from torchvision.models import Inception_V3_Weights
    except ImportError:
        # Older torchvision only understands the legacy flag.
        model = inception_v3(pretrained=True, transform_input=True)
    else:
        # IMAGENET1K_V1 is what the legacy ``pretrained=True`` maps to.
        model = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1,
                             transform_input=True)
    model.fc = torch.nn.Identity()
    model.to(device)
    model.eval()
    return model


def get_image_paths(directory):
    """Recursively collect image files found under ``directory``.

    A file is kept when its name ends, case-insensitively, in ``.jpg``,
    ``.jpeg`` or ``.png``. Each result is the walked directory joined with the
    file name, listed in ``os.walk`` traversal order. A missing directory
    yields an empty list.
    """
    suffixes = ('.jpg', '.jpeg', '.png')
    found = []
    for dirpath, _subdirs, filenames in os.walk(directory):
        for name in filenames:
            # str.endswith accepts a tuple and matches any of its entries.
            if name.lower().endswith(suffixes):
                found.append(os.path.join(dirpath, name))
    return found


def _load_images(paths, transform):
    """Open, RGB-convert and transform each image, skipping unreadable files."""
    tensors = []
    for path in paths:
        try:
            # The context manager closes the underlying file handle promptly.
            with Image.open(path) as img:
                tensors.append(transform(img.convert('RGB')))
        except Exception as e:
            # Deliberately best-effort: report the bad file and keep going.
            print(f"Error loading image {path}: {e}")
    return tensors


def get_activations(image_paths, model, device, batch_size=32):
    """Run ``model`` over a list of images and stack its per-image outputs.

    Each image is resized to 299x299, converted to a tensor and normalized
    with ImageNet mean/std. It is then fed through ``model`` in batches under
    ``torch.no_grad()``. Images that fail to load are reported and skipped,
    so the result may have fewer rows than ``image_paths``.

    Args:
        image_paths: sliceable sequence (e.g. list) of image file paths.
        model: network called on a batch tensor; it is switched to eval mode
            and its outputs are moved to CPU as NumPy arrays.
        device: device the input batches are moved to before the forward pass.
        batch_size: number of images per forward pass (must be >= 1).

    Returns:
        ``np.ndarray`` holding the model outputs concatenated along axis 0,
        one row per successfully loaded image (presumably
        ``(num_images, 2048)`` for the model from ``get_inception_model``).

    Raises:
        ValueError: if ``image_paths`` is empty, ``batch_size < 1``, or none
            of the images could be loaded.
    """
    if not image_paths:
        raise ValueError("No image paths given; cannot compute activations.")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    activations = []
    model.eval()

    with torch.no_grad():
        for start in range(0, len(image_paths), batch_size):
            images = _load_images(image_paths[start:start + batch_size], transform)
            if not images:
                continue  # every image in this batch failed; nothing to stack

            batch = torch.stack(images).to(device)
            activations.append(model(batch).cpu().numpy())

    if not activations:
        # np.concatenate([]) would fail with an unhelpful message.
        raise ValueError(
            f"None of the {len(image_paths)} images could be loaded.")
    return np.concatenate(activations, axis=0)


def calculate_fid(act1, act2, eps=1e-6):
    """Compute the Frechet distance between two sets of feature activations.

    Each set is summarized by its sample mean and covariance, and the
    distance between the two Gaussians is

        ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 @ sigma2))

    Args:
        act1, act2: 2-D arrays of shape ``(num_samples, num_features)``. The
            feature dimension must match; sample counts may differ.
        eps: diagonal offset added to both covariances only when the matrix
            square root of their product is not finite. This happens, for
            example, with singular covariances when there are fewer samples
            than features.

    Returns:
        The FID as a Python ``float``.

    Raises:
        ValueError: if either input is not 2-D, the feature dimensions
            differ, or either set has fewer than two samples.
    """
    act1 = np.asarray(act1, dtype=np.float64)
    act2 = np.asarray(act2, dtype=np.float64)
    if act1.ndim != 2 or act2.ndim != 2:
        raise ValueError(
            f"activations must be 2-D, got shapes {act1.shape} and {act2.shape}")
    if act1.shape[1] != act2.shape[1]:
        raise ValueError(
            f"feature dimensions differ: {act1.shape[1]} vs {act2.shape[1]}")
    if len(act1) < 2 or len(act2) < 2:
        # The sample covariance is undefined (NaN) for a single sample.
        raise ValueError("each activation set needs at least two samples")

    mu1, mu2 = act1.mean(axis=0), act2.mean(axis=0)
    # np.cov returns a 0-d array for a single feature; keep everything 2-D.
    sigma1 = np.atleast_2d(np.cov(act1, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(act2, rowvar=False))

    diff_squared = np.sum((mu1 - mu2) ** 2)

    # No ``disp`` argument: it is deprecated in recent SciPy, and without it
    # sqrtm returns only the matrix on every version.
    covmean = sqrtm(sigma1 @ sigma2)
    if not np.all(np.isfinite(covmean)):
        # Regularize only when needed, and use the real feature dimension
        # instead of a hard-coded 2048.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset) @ (sigma2 + offset))

    if np.iscomplexobj(covmean):
        # Small imaginary parts are numerical noise; large ones mean the
        # result is unreliable, so say so instead of silently dropping them.
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            import warnings
            warnings.warn(
                "sqrtm produced a significant imaginary component "
                f"(max |imag| = {np.max(np.abs(covmean.imag)):.3g}); "
                "FID may be inaccurate.",
                RuntimeWarning,
            )
        covmean = covmean.real

    fid = diff_squared + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
    return float(fid)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description="Compute the FID between two folders of images.")
    parser.add_argument('folder1', nargs='?', default='experiment3/3view16skip',
                        help="first image folder (searched recursively)")
    parser.add_argument('folder2', nargs='?', default='orig',
                        help="second image folder, e.g. 'real_object_1'")
    parser.add_argument('--runs', type=int, default=10,
                        help="number of bootstrap resamples for the mean/std")
    parser.add_argument('--seed', type=int, default=0,
                        help="random seed for the bootstrap resampling")
    parser.add_argument('--batch-size', type=int, default=32,
                        help="images per forward pass")
    args = parser.parse_args()

    image_paths1 = get_image_paths(args.folder1)
    image_paths2 = get_image_paths(args.folder2)

    print(f"Set1: {len(image_paths1)} images, Set2: {len(image_paths2)} images")
    if not image_paths1 or not image_paths2:
        # Fail before downloading/loading the model.
        raise SystemExit("Both folders must contain at least one image.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = get_inception_model(device)

    activations1 = get_activations(image_paths1, model, device,
                                   batch_size=args.batch_size)
    activations2 = get_activations(image_paths2, model, device,
                                   batch_size=args.batch_size)

    fid_full = calculate_fid(activations1, activations2)
    print(f"FID (full sets): {fid_full:.4f}")

    if args.runs > 0:
        # calculate_fid is deterministic, so repeating it on identical
        # activations always gives a std of exactly 0. Bootstrap-resample
        # each set (with replacement) so the mean/std reflect sampling
        # variability instead.
        rng = np.random.default_rng(args.seed)
        n1, n2 = len(activations1), len(activations2)
        fid_scores = []
        for _ in range(args.runs):
            idx1 = rng.integers(0, n1, size=n1)
            idx2 = rng.integers(0, n2, size=n2)
            fid_scores.append(calculate_fid(activations1[idx1], activations2[idx2]))
        fid_mean = np.mean(fid_scores)
        fid_std = np.std(fid_scores)

        print(f"FID Mean: {fid_mean:.4f}, FID Std: {fid_std:.4f}")
