import numpy as np
import pandas as pd
from scipy.linalg import sqrtm
import torch
from torchvision import models
from torchvision import transforms
import torch
from torchvision import transforms
from torchvision import datasets
from PIL import Image

# Calculate FID
def calculate_fid(real_images, generated_images, *, model=None, eps=1e-6):
    """Compute the Frechet Inception Distance between two image batches.

    FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 * sqrt(S_r @ S_g)), where mu/S
    are the mean and covariance of per-image feature vectors.

    Args:
        real_images: Batch tensor of reference images, already preprocessed
            for ``model`` (for the default model: ImageNet mean/std normalised
            RGB at 299x299 -- assumed, TODO confirm for custom inputs).
        generated_images: Batch tensor of generated images, same format.
        model: Optional callable mapping a batch tensor to per-sample
            features. The first dimension is the batch; trailing dimensions
            are flattened. If None, torchvision's pretrained Inception-v3 is
            loaded with its classifier head replaced by identity, so the
            features are the 2048-d pooled activations the FID metric is
            defined on. A supplied model is used as-is (put it in eval mode).
        eps: Diagonal offset added to both covariances if the matrix square
            root is not finite (near-singular covariances, e.g. fewer
            samples than feature dimensions).

    Returns:
        The FID as a Python float.

    Raises:
        ValueError: if either batch yields fewer than two samples (the
            covariance is undefined).
    """
    if model is None:
        # transform_input=True: the pretrained weights expect [-1, 1] inputs,
        # and this flag converts ImageNet mean/std-normalised tensors to that.
        model = models.inception_v3(pretrained=True, transform_input=True)
        # Drop the 1000-way classifier so the output is the pooled features.
        model.fc = torch.nn.Identity()
        model.eval()

    def get_features(images):
        with torch.no_grad():
            feats = model(images)
        # float64 keeps the covariance / sqrtm arithmetic well conditioned.
        return feats.reshape(feats.shape[0], -1).cpu().numpy().astype(np.float64)

    real_features = get_features(real_images)
    generated_features = get_features(generated_images)

    for label, feats in (("real_images", real_features),
                         ("generated_images", generated_features)):
        if feats.shape[0] < 2:
            raise ValueError(
                f"{label} must contain at least 2 samples to estimate a "
                f"covariance; got {feats.shape[0]}"
            )

    mu_real = real_features.mean(axis=0)
    mu_gen = generated_features.mean(axis=0)
    # atleast_2d: np.cov returns a 0-d array when there is a single feature.
    sigma_real = np.atleast_2d(np.cov(real_features, rowvar=False))
    sigma_gen = np.atleast_2d(np.cov(generated_features, rowvar=False))

    covmean = sqrtm(sigma_real @ sigma_gen)
    if not np.isfinite(covmean).all():
        # Singular product: nudge both covariances toward positive definite.
        offset = np.eye(sigma_real.shape[0]) * eps
        covmean = sqrtm((sigma_real + offset) @ (sigma_gen + offset))
    if np.iscomplexobj(covmean):
        # Tiny imaginary parts are numerical noise from sqrtm.
        covmean = covmean.real

    diff = mu_real - mu_gen
    fid = (diff @ diff
           + np.trace(sigma_real) + np.trace(sigma_gen)
           - 2.0 * np.trace(covmean))
    return float(fid)

# Calculate Inception Score
def calculate_inception_score(generated_images, *, model=None):
    """Compute the Inception Score of a batch of generated images.

    IS = exp( mean_x KL( p(y|x) || p(y) ) ), with p(y) the batch-average of
    the per-image class distributions p(y|x).

    Args:
        generated_images: Batch tensor of images, already preprocessed for
            ``model`` (for the default model: ImageNet mean/std normalised
            RGB at 299x299 -- assumed, TODO confirm for custom inputs).
        model: Optional callable returning class logits of shape
            (batch, num_classes). If None, torchvision's pretrained
            Inception-v3 classifier is loaded. A supplied model is used as-is
            (put it in eval mode).

    Returns:
        The Inception Score as a Python float. Note that a single image
        always scores 1.0 (its distribution equals the batch marginal).

    Raises:
        ValueError: if ``generated_images`` is empty.
    """
    if len(generated_images) == 0:
        raise ValueError("generated_images must contain at least one sample")

    if model is None:
        # transform_input=True: the pretrained weights expect [-1, 1] inputs,
        # and this flag converts ImageNet mean/std-normalised tensors to that.
        model = models.inception_v3(pretrained=True, transform_input=True).eval()

    with torch.no_grad():
        logits = model(generated_images)
    p_yx = torch.softmax(logits, dim=1).cpu().numpy().astype(np.float64)
    p_y = p_yx.mean(axis=0, keepdims=True)
    # Per-image KL divergence: multiply by p(y|x) BEFORE summing over classes.
    kl_divergence = (p_yx * (np.log(p_yx + 1e-10) - np.log(p_y + 1e-10))).sum(axis=1)
    return float(np.exp(kl_divergence.mean()))



import os

# Resize to Inception-v3's 299x299 input and apply ImageNet mean/std
# normalisation.
preprocess = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load real images from the existing CIFAR-10 dataset
real_images_path = "/idas/users/liusirui/0926/Raw/data/cifar-10-batches-py"
# torchvision's CIFAR10 expects `root` to be the directory that CONTAINS
# "cifar-10-batches-py". Accept a path to that folder itself as well.
cifar_root = os.path.normpath(real_images_path)
if (os.path.basename(cifar_root) == "cifar-10-batches-py"
        and not os.path.isdir(os.path.join(cifar_root, "cifar-10-batches-py"))):
    cifar_root = os.path.dirname(cifar_root)
real_dataset = datasets.CIFAR10(root=cifar_root, train=False, download=False, transform=preprocess)
# Seeded shuffle so repeated runs compare against the same real batch.
real_loader = torch.utils.data.DataLoader(
    real_dataset,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)

# Load generated images
generated_images_path = "/idas/users/liusirui/0926/Raw/data/result/cifar10_sampler.png"
# NOTE(review): a "sampler" PNG is often a grid of many samples saved as one
# image; here the whole file becomes ONE sample. FID needs at least two
# generated samples (ideally thousands) to estimate a covariance. Confirm the
# file layout and split it into individual images if it is a grid.
with Image.open(generated_images_path) as generated_image:
    # PNGs may be RGBA/palette/greyscale; the 3-channel Normalize needs RGB.
    generated_images = preprocess(generated_image.convert("RGB")).unsqueeze(0)  # Add batch dimension

# Get a batch of real images (32 samples is a very small FID reference set;
# the estimate is strongly biased at this size).
real_images, _ = next(iter(real_loader))

# Calculate metrics
fid_value = calculate_fid(real_images, generated_images)
inception_score_value = calculate_inception_score(generated_images)

# Save results to a CSV file
results = {
    'Metric': ['FID', 'Inception Score'],
    'Value': [fid_value, inception_score_value]
}

df = pd.DataFrame(results)
df.to_csv('evaluation_metrics.csv', index=False)

print("Metrics saved to evaluation_metrics.csv")

