import torch
import torch.nn as nn
from classifier import Inception, InceptionBlock, Flatten, find_model

import numpy as np
from scipy.linalg import sqrtm
from torchvision.models import inception_v3
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Compute feature vectors with the feature extractor
def calculate_feature_vectors(images, model):
    """Run ``model`` over ``images`` with autograd disabled; return the output as NumPy.

    Side effect: ``model`` is switched to evaluation mode and left there.
    """
    model.eval()
    with torch.no_grad():
        batch_output = model(images)
    # .numpy() only works on host tensors, so copy to CPU first.
    host_copy = batch_output.cpu()
    return host_copy.numpy()

# Compute mean and covariance
def calculate_statistics(features):
    """Return the column-wise mean and covariance matrix of a feature matrix.

    Args:
        features: 2-D array-like; each row is one sample (observation) and each
            column one feature dimension (``rowvar=False`` below).

    Returns:
        ``(mu, sigma)`` where ``mu`` is the per-column mean and ``sigma`` is the
        sample covariance (N-1 normalisation, numpy's default), always returned
        as a 2-D matrix.
    """
    mu = np.mean(features, axis=0)
    # np.cov squeezes a 1x1 result down to a 0-d array; keep it 2-D so the
    # matrix operations done on it afterwards (dot, sqrtm, trace) still work.
    sigma = np.atleast_2d(np.cov(features, rowvar=False))
    return mu, sigma

# Compute the FID (Frechet Inception Distance)
def calculate_fid(real_features, fake_features, eps=1e-6):
    """Compute the Frechet distance between two sets of feature vectors.

    Each set is summarised by its mean vector and covariance matrix (via
    ``calculate_statistics``) and the distance is

        ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 * sqrtm(S_r @ S_f))

    Args:
        real_features: 2-D array, one row per real sample.
        fake_features: 2-D array, one row per generated sample; must have the
            same number of columns as ``real_features``.
        eps: diagonal offset added to both covariance matrices when the plain
            matrix square root is not finite (e.g. singular covariances when
            there are fewer samples than feature dimensions).

    Returns:
        The FID value (a NumPy floating-point scalar).
    """
    import warnings

    mu_real, sigma_real = calculate_statistics(real_features)
    mu_fake, sigma_fake = calculate_statistics(fake_features)

    # Difference of the means; its squared L2 norm is the first FID term.
    diff = mu_real - mu_fake

    cov_mean_sqrt = sqrtm(sigma_real.dot(sigma_fake))
    if not np.isfinite(cov_mean_sqrt).all():
        # The square root of a product of (near-)singular matrices can blow up;
        # regularise both covariances with a small ridge and retry.
        offset = np.eye(sigma_real.shape[0]) * eps
        cov_mean_sqrt = sqrtm((sigma_real + offset).dot(sigma_fake + offset))

    if np.iscomplexobj(cov_mean_sqrt):
        # Round-off can leave a tiny imaginary part, which is safe to drop;
        # a large one on the diagonal (the part the trace uses) means the
        # result is numerically unreliable, so say so instead of hiding it.
        max_imag = np.max(np.abs(np.diagonal(cov_mean_sqrt).imag))
        if max_imag > 1e-3:
            warnings.warn(
                'sqrtm produced a significant imaginary component '
                f'({max_imag:.3g}); FID may be inaccurate',
                RuntimeWarning,
            )
        cov_mean_sqrt = cov_mean_sqrt.real

    fid = np.sum(diff**2) + np.trace(sigma_real + sigma_fake - 2 * cov_mean_sqrt)
    return fid

# ------------------------ FID from classifier features ----------------------------
def fid_score(classifier, real_dataset, fake_dataset, batch_size=128, device=None):
    """Compute the FID between two torch datasets using a classifier as feature extractor.

    The classifier's last top-level child module is dropped (presumably its
    output head) and the remaining children are chained in an ``nn.Sequential``;
    that truncated network's output is used as the feature vector. It is
    assumed to be 2-D ``(batch, features)`` — TODO confirm for the classifiers
    used with this function.

    Args:
        classifier: torch module whose ``children()`` run in sequence.
        real_dataset: torch Dataset yielding ``(image, label)`` pairs.
        fake_dataset: torch Dataset yielding ``(image, label)`` pairs; dim 1 of
            each batch is squeezed before feature extraction.
        batch_size: mini-batch size used for feature extraction.
        device: device to run on. Defaults to the device of the classifier's
            parameters (CUDA if available, else CPU, for a parameter-less model).

    Returns:
        The FID value computed by ``calculate_fid``.

    Raises:
        ValueError: if either dataset yields no batches.
    """
    feature_extractor = nn.Sequential(*list(classifier.children())[:-1])

    if device is None:
        first_param = next(classifier.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The Sequential shares its sub-modules with `classifier`, so eval() also
    # flips the caller's layers out of training mode; restore them afterwards.
    saved_modes = [(module, module.training) for module in classifier.modules()]
    try:
        feature_extractor.eval()
        real_features = _collect_features(
            real_dataset, feature_extractor, device, batch_size, squeeze_dim1=False
        )
        fake_features = _collect_features(
            fake_dataset, feature_extractor, device, batch_size, squeeze_dim1=True
        )
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    return calculate_fid(real_features, fake_features)


def _collect_features(dataset, model, device, batch_size, squeeze_dim1):
    """Extract features for every ``(x, _)`` batch of ``dataset`` and stack them row-wise."""
    # Mean/covariance do not depend on sample order, so no shuffling is needed;
    # a fixed order also makes the result reproducible.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    batches = []
    for x, _ in loader:
        if squeeze_dim1:
            # NOTE(review): fake batches presumably carry an extra singleton
            # dim 1 (e.g. (B, 1, C, H, W)); for single-channel (B, 1, H, W)
            # images this drops the channel dim instead — confirm upstream.
            x = x.squeeze(1)
        batches.append(calculate_feature_vectors(x.to(device), model))
    if not batches:
        raise ValueError('cannot compute FID features: dataset is empty')
    return np.concatenate(batches, axis=0)