import numpy as np
from scipy import linalg
import torch

def calculate_activation_statistics(features):
    """Compute the Gaussian statistics (mean and covariance) compared by FID.

    Params:
    -- features : Numpy array of features, shape (N, D); one row per sample.

    Returns:
    -- mu    : Mean of the features taken over the sample axis (axis 0).
    -- sigma : Covariance matrix of the features as returned by `np.cov`
               with columns treated as variables.
    """
    # rowvar=False: every column is a variable, every row an observation.
    covariance = np.cov(features, rowvar=False)
    mean = np.mean(features, axis=0)
    return mean, covariance

def calculate_activation_statistics_torch(features, device='cuda'):
    """Compute the FID statistics (mean and covariance) with PyTorch.

    Params:
    -- features : PyTorch tensor of features, shape (N, D)
    -- device   : Device the tensor is moved to before any computation
                  ('cuda' or 'cpu')

    Returns:
    -- mu    : Mean of the features over dim 0
    -- sigma : Covariance matrix of the features, normalised by N - 1
    """
    feats = features.to(device)
    denominator = feats.size(0) - 1

    mu = feats.mean(dim=0)

    # Remove the mean from every row, then form the scaled Gram matrix of the
    # deviations.
    deviations = feats - mu.unsqueeze(0)
    sigma = (deviations.t() @ deviations) / denominator

    return mu, sigma

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : Mean vector of the first Gaussian (mu_1 in the formula above).
    -- mu2   : Mean vector of the second Gaussian (mu_2), same shape as mu1.
    -- sigma1: Covariance matrix of the first Gaussian (C_1).
    -- sigma2: Covariance matrix of the second Gaussian (C_2), same shape as
               sigma1.
    -- eps   : Value added to the diagonal of both covariance matrices when
               the matrix square root of their product is not finite.

    Returns:
    --   : The Frechet Distance.

    Raises:
    -- AssertionError: if the mean vectors or the covariance matrices differ
                       in shape.
    -- ValueError    : if the matrix square root has an imaginary component
                       on its diagonal larger than the 1e-3 tolerance.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert (
        mu1.shape == mu2.shape
    ), "Training and test mean vectors have different lengths"
    assert (
        sigma1.shape == sigma2.shape
    ), "Training and test covariances have different dimensions"

    diff = mu1 - mu2

    # Product might be almost singular.
    # Only the square root itself is needed, so the default call form is used
    # (the same form as the fallback below) instead of asking sqrtm for an
    # error estimate that would be discarded.
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        msg = (
            "fid calculation produces singular product; "
            "adding %s to diagonal of cov estimates"
        ) % eps
        print(msg)
        # Regularise both covariances and retry the square root.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        # Only the diagonal contributes to the trace, so only the diagonal is
        # checked; the reported magnitude is the largest over the whole matrix.
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError("Imaginary component {}".format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

def calculate_frechet_distance_torch(mu_x, sigma_x, mu_y, sigma_y, device=None):
    r"""Computes the Fréchet distance between two multivariate normal distributions :cite:`dowson1982frechet`.

    Concretely, for multivariate Gaussians :math:`X(\mu_X, \Sigma_X)`
    and :math:`Y(\mu_Y, \Sigma_Y)`, the function computes and returns :math:`F` as

    .. math::
        F(X, Y) = || \mu_X - \mu_Y ||_2^2
        + \text{Tr}\left( \Sigma_X + \Sigma_Y - 2 \sqrt{\Sigma_X \Sigma_Y} \right)

    Args:
        mu_x (torch.Tensor): mean :math:`\mu_X` of multivariate Gaussian :math:`X`, with shape `(N,)`.
        sigma_x (torch.Tensor): covariance matrix :math:`\Sigma_X` of :math:`X`, with shape `(N, N)`.
        mu_y (torch.Tensor): mean :math:`\mu_Y` of multivariate Gaussian :math:`Y`, with shape `(N,)`.
        sigma_y (torch.Tensor): covariance matrix :math:`\Sigma_Y` of :math:`Y`, with shape `(N, N)`.
        device (optional): if given, all four inputs are moved to this device before
            the computation; if ``None`` (default) they are used where they are.

    Returns:
        torch.Tensor: the Fréchet distance between :math:`X` and :math:`Y`.

    Raises:
        ValueError: if ``mu_x`` is not one-dimensional, ``sigma_x`` is not two-dimensional,
            either dimension of ``sigma_x`` differs from the size of ``mu_x``, or the
            shapes of the ``x`` and ``y`` inputs differ.
    """
    if mu_x.dim() != 1:
        raise ValueError(f"Input mu_x must be one-dimensional; got dimension {len(mu_x.size())}.")
    if sigma_x.dim() != 2:
        raise ValueError(f"Input sigma_x must be two-dimensional; got dimension {len(sigma_x.size())}.")
    # Each dimension is compared against mu_x separately: a chained
    # `a != b != c` means `(a != b) and (b != c)` and would accept, e.g., a
    # square sigma_x whose size differs from mu_x.
    n_features = mu_x.size(0)
    if sigma_x.size(0) != n_features or sigma_x.size(1) != n_features:
        raise ValueError("Each of sigma_x's dimensions must match mu_x's size.")
    if mu_x.size() != mu_y.size():
        raise ValueError(f"Inputs mu_x and mu_y must have the same shape; got {mu_x.size()} and {mu_y.size()}.")
    if sigma_x.size() != sigma_y.size():
        raise ValueError(
            f"Inputs sigma_x and sigma_y must have the same shape; got {sigma_x.size()} and {sigma_y.size()}."
        )

    if device is not None:
        mu_x, sigma_x, mu_y, sigma_y = (
            t.to(device) for t in (mu_x, sigma_x, mu_y, sigma_y)
        )

    # Squared Euclidean distance between the means.
    a = (mu_x - mu_y).square().sum()
    # Tr(Sigma_X) + Tr(Sigma_Y).
    b = sigma_x.trace() + sigma_y.trace()
    # Trace of the matrix square root, taken as the sum of the square roots of
    # the eigenvalues of the product; eigvals returns complex values, so only
    # the real part of each root is summed.
    c = torch.linalg.eigvals(sigma_x @ sigma_y).sqrt().real.sum()
    return a + b - 2 * c

if __name__ == "__main__":
    # Self-check: compare the NumPy and PyTorch FID implementations on random
    # Gaussian features of several dimensionalities.

    # Set random seed for reproducibility
    np.random.seed(42)
    torch.manual_seed(42)

    # Neither the device nor the sample count depends on the dimension, so
    # both are resolved once before the loop.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    n_samples = 1000

    # Test with different dimensions
    for dim in [10, 64, 128, 256]:
        print(f"\nTesting with dimension {dim}")

        # Create random feature vectors
        features1_np = np.random.randn(n_samples, dim).astype(np.float32)
        features2_np = np.random.randn(n_samples, dim).astype(np.float32)

        print(f"Using device: {device}")

        # Convert to PyTorch tensors
        features1_torch = torch.tensor(features1_np).to(device)
        features2_torch = torch.tensor(features2_np).to(device)

        # Calculate statistics with NumPy
        mu1_np, sigma1_np = calculate_activation_statistics(features1_np)
        mu2_np, sigma2_np = calculate_activation_statistics(features2_np)

        # Calculate FID with NumPy
        fid_np = calculate_frechet_distance(mu1_np, sigma1_np, mu2_np, sigma2_np)

        # Calculate statistics with PyTorch
        mu1_torch, sigma1_torch = calculate_activation_statistics_torch(features1_torch, device)
        mu2_torch, sigma2_torch = calculate_activation_statistics_torch(features2_torch, device)

        # Calculate FID with PyTorch. The statistics already live on `device`,
        # so no device argument is passed here.
        fid_torch = calculate_frechet_distance_torch(mu1_torch, sigma1_torch, mu2_torch, sigma2_torch)

        # Pull the scalar result back to the host as a Python float for comparison
        fid_torch_value = fid_torch.item()

        # Compare results
        abs_diff = abs(fid_np - fid_torch_value)
        print(f"NumPy FID: {fid_np:.6f}")
        print(f"PyTorch FID: {fid_torch_value:.6f}")
        print(f"Absolute difference: {abs_diff:.6f}")
        print(f"Relative difference: {abs_diff / fid_np * 100:.6f}%")