import torch
import numpy as np
import scipy.linalg

# FID measures fidelity; lower is better.
def _features_to_statistics(features):
    """Return the float64 mean vector and covariance matrix of a 2-D feature tensor.

    Rows are treated as samples and columns as feature dimensions.

    Raises:
        ValueError: if fewer than two samples are given (the sample
            covariance would be undefined).
    """
    if features.shape[0] < 2:
        raise ValueError(
            "At least two samples are required to estimate a covariance, "
            f"got {features.shape[0]}"
        )
    # detach()/cpu() make autograd-tracked and non-CPU tensors convertible;
    # float64 keeps the mean and covariance in one, higher-precision dtype.
    samples = features.detach().cpu().to(torch.float64).numpy()
    mu = np.atleast_1d(np.mean(samples, axis=0))
    # np.cov returns a 0-d array for a single feature column, hence atleast_2d.
    sigma = np.atleast_2d(np.cov(samples, rowvar=False))
    return mu, sigma


def _regularize_covariance(sigma, eps, name):
    """Shift ``sigma`` along its diagonal if it has a negative eigenvalue.

    The matrix is returned unchanged when its smallest eigenvalue is not
    negative. Otherwise ``eps`` is added to the diagonal; if that would not be
    enough to lift the smallest eigenvalue above zero, the shift is enlarged
    so that the smallest eigenvalue becomes ``eps``.
    """
    # Covariance matrices are symmetric, so the symmetric solver applies and
    # always yields real eigenvalues in ascending order.
    min_eig = np.linalg.eigvalsh(sigma)[0]
    if not min_eig < 0:
        return sigma

    print(f"Min eigval of {name}:", min_eig)
    shift = eps if min_eig > -eps else eps - min_eig
    sigma = sigma + shift * np.eye(sigma.shape[0])
    # Report the smallest eigenvalue after the correction.
    print(f"修正后的最小特征值 ({name}):", np.linalg.eigvalsh(sigma)[0])
    return sigma


def _matrix_sqrt(matrix):
    """Return ``scipy.linalg.sqrtm(matrix)`` without SciPy's singular-matrix printout."""
    try:
        result = scipy.linalg.sqrtm(matrix, disp=False)
    except TypeError:
        # Recent SciPy releases deprecate the ``disp`` argument; fall back to
        # the plain call if the keyword is rejected.
        return scipy.linalg.sqrtm(matrix)
    # With ``disp=False`` SciPy returns ``(sqrtm, error_estimate)``.
    return result[0] if isinstance(result, tuple) else result


def calculate_fid(
    featuresdict_1, featuresdict_2, feat_layer_name
):  # using 2048 layer to calculate
    """Compute the Frechet distance between two sets of features.

    The distance is ``|mu1 - mu2|^2 + Tr(sigma1) + Tr(sigma2)
    - 2 * Tr(sqrt(sigma1 @ sigma2))``, where ``mu``/``sigma`` are the mean and
    covariance of each feature set. Lower values mean the two sets are closer.

    Args:
        featuresdict_1: mapping from layer name to a 2-D ``torch.Tensor`` of
            shape ``(num_samples, num_features)``.
        featuresdict_2: same as ``featuresdict_1`` for the second feature set.
        feat_layer_name: key selecting the tensor in both mappings.

    Returns:
        ``{"frechet_distance": float}``.

    Raises:
        AssertionError: if an input is not a 2-D tensor, the feature
            dimensions differ, or the matrix square root has a non-negligible
            imaginary component.
        ValueError: if a feature set has fewer than two samples.
    """
    eps = 1e-6
    features_1 = featuresdict_1[feat_layer_name]
    features_2 = featuresdict_2[feat_layer_name]

    assert torch.is_tensor(features_1) and features_1.dim() == 2
    assert torch.is_tensor(features_2) and features_2.dim() == 2

    mu1, sigma1 = _features_to_statistics(features_1)
    mu2, sigma2 = _features_to_statistics(features_2)

    print("Computing Frechet Distance")

    assert (
        mu1.shape == mu2.shape
    ), "Training and test mean vectors have different lengths"
    assert (
        sigma1.shape == sigma2.shape
    ), "Training and test covariances have different dimensions"

    # Check the covariance eigenvalues (positive semi-definiteness) and
    # regularize the matrices when numerical error made them indefinite.
    sigma1 = _regularize_covariance(sigma1, eps, "sigma1")
    sigma2 = _regularize_covariance(sigma2, eps, "sigma2")

    diff = mu1 - mu2

    # Product might be almost singular
    covmean = _matrix_sqrt(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        print(
            f"WARNING: fid calculation produces singular product; adding {eps} to diagonal of cov"
        )
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = _matrix_sqrt((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            # Raised explicitly (not `assert False`) so the check survives
            # `python -O`; the exception type is unchanged for callers.
            raise AssertionError("Imaginary component {}".format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    fid = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

    return {
        "frechet_distance": float(fid),
    }
