# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Script for calculating Frechet Inception Distance (FID)."""

import os
import click
import tqdm
import time
import sklearn.metrics
import pickle
import numpy as np
import scipy.linalg
import torch
import dnnlib
import gudhi
from scipy.spatial import distance

from torch_utils import distributed as dist
from training import dataset
from training import classifier
import json


###########################################
#   Evaluation grid over the data's bounding box
###########################################
def set_grid(data, num_points=None):
    """Build a regular evaluation grid spanning the bounding box of ``data``.

    Each axis is sampled with ``np.linspace`` between the per-column minimum and
    maximum of ``data``. The grid is the Cartesian product of those axes, with
    the last coordinate varying fastest (the same order as nested loops over
    x, then y, then z).

    Args:
        data: array-like of shape (n, d) with d equal to 2 or 3.
        num_points: number of samples per axis. Defaults to 1000 for 2-D data
            and 100 for 3-D data (10**6 grid points in both cases).

    Returns:
        np.ndarray of shape (num_points ** d, d).

    Raises:
        ValueError: if ``data`` is not a 2-D array with 2 or 3 columns.
    """
    data = np.asarray(data)
    if data.ndim != 2 or data.shape[1] not in (2, 3):
        raise ValueError(
            f"set_grid supports 2- or 3-dimensional data, got array of shape {data.shape}"
        )
    dim = data.shape[1]
    if num_points is None:
        num_points = 1000 if dim == 2 else 100

    axes = [
        np.linspace(lo, hi, num_points)
        for lo, hi in zip(data.min(axis=0), data.max(axis=0))
    ]
    # indexing="ij" plus a C-order reshape keeps the first axis outermost,
    # matching the nested-loop ordering callers rely on.
    mesh = np.meshgrid(*axes, indexing="ij")
    return np.stack(mesh, axis=-1).reshape(-1, dim)


###########################################
#   KDE with compact kernels (Epanechnikov / cosine)
###########################################
def compact_KDE(data, position, h, kernel="cosine"):
    """Evaluate a compact-support kernel density estimate of ``data`` at ``position``.

    For every query point x the estimate is

        p_hat(x) = (1 / n) * c(h) ** d * sum_{i : ||x - X_i|| <= h} k(||x - X_i|| / h)

    where n = len(data), d is the feature dimension and

        * "epanechinikov" / "epanechnikov": c(h) = 3 / (4 h),  k(u) = 1 - u ** 2
        * "cosine":                         c(h) = pi / (4 h), k(u) = cos(pi u / 2)

    Args:
        data: (n, d) sample points.
        position: (m, d) query points.
        h: bandwidth; expected to be > 0.
        kernel: kernel name. The historical spelling "epanechinikov" is kept and
            "epanechnikov" is accepted as an alias.

    Returns:
        np.ndarray of shape (m,) with the density estimate at each query point.

    Raises:
        ValueError: if ``kernel`` is not a supported kernel name.
    """
    if kernel in ("epanechinikov", "epanechnikov"):
        norm_base = 3 / (4 * h)
    elif kernel == "cosine":
        norm_base = np.pi / (4 * h)
    else:
        raise ValueError(
            f"Unsupported kernel {kernel!r}; expected 'epanechinikov' or 'cosine'"
        )

    data = np.asarray(data)
    n, d = data.shape[0], data.shape[1]

    # Exact Euclidean distances, shape (m, n). The buffer is reused in place
    # below to avoid allocating several extra (m, n) arrays.
    u = distance.cdist(np.asarray(position), data, metric="euclidean")
    # Samples farther than h lie outside the kernel's compact support.
    outside = ~(u**2 <= h**2)
    u /= h
    if kernel == "cosine":
        u *= np.pi / 2
        np.cos(u, out=u)
    else:
        # Epanechnikov profile 1 - (d / h)^2.
        np.square(u, out=u)
        np.subtract(1.0, u, out=u)
    u[outside] = 0.0
    return (1 / n) * (norm_base**d) * u.sum(axis=1)


###########################################
# Confidence Band
###########################################
def confband_est(data, h, alpha=0.1, kernel="cosine", p_val=True, repeat=10):
    """Bootstrap the half-width of a uniform KDE confidence band at the samples.

    The KDE ``p_hat`` is evaluated at the sample points themselves. For each of
    ``repeat`` bootstrap resamples (drawn with ``np.random.choice``), the KDE of
    the resample is evaluated at the same points and the statistic
    ``sqrt(n) * max |p_hat - p_tilde|`` is recorded. The band half-width is the
    ``1 - alpha`` quantile of these statistics divided by ``sqrt(n)``; it is 0
    when ``repeat`` is 0.

    Returns:
        ``(band, p_hat)`` when ``p_val == True``, otherwise just ``band``.
    """
    data = data if isinstance(data, np.ndarray) else np.asarray(data)
    n_samples = len(data)
    scale = np.sqrt(n_samples)

    # Density of the full sample, evaluated at the sample points.
    p_hat = compact_KDE(data, data, h, kernel=kernel)

    sup_stats = []
    for _ in range(repeat):
        picks = np.random.choice(np.arange(n_samples), size=n_samples, replace=True)
        p_tilde = compact_KDE(data[picks], data, h, kernel=kernel)
        sup_stats.append(scale * np.max(np.abs(p_hat - p_tilde)))

    q_alpha = np.quantile(np.asarray(sup_stats), 1 - alpha) if sup_stats else 0
    band = q_alpha / scale
    return (band, p_hat) if p_val == True else band


###########################################
# Bandwidth estimator
###########################################
def _balloon_bandwidth_candidates(
    data, k=50, quantiles=(0.05, 0.2, 0.35, 0.5, 0.65, 0.8, 0.95)
):
    """Return candidate bandwidths taken from quantiles of k-NN ("balloon") radii.

    For each row i with at least ``k`` rows after it, the radius is the distance
    to the k-th closest of rows i+1..n-1. The radii are sorted, and for each
    quantile ``q`` the value at index ``int(len(radii) * q) - 1`` is taken.

    Raises:
        ValueError: if ``data`` has ``k`` or fewer rows, so no radius exists.
    """
    pairwise = distance.cdist(data, data, metric="euclidean")
    radii = np.sort(
        [
            np.partition(pairwise[i, i + 1 :], k - 1)[k - 1]
            for i in range(len(data) - k)
        ]
    )
    if radii.size == 0:
        raise ValueError(
            f"bandwidth_est needs more than {k} samples to estimate bandwidth "
            f"candidates, got {len(data)}"
        )
    return np.array([radii[int(len(radii) * q) - 1] for q in quantiles])


def bandwidth_est(
    data,
    bandwidth_list=None,
    confidence_band=False,
    kernel="cosine",
    alpha=0.1,
    Plot=False,
):
    """Select a KDE bandwidth by maximizing the total significant H0 persistence.

    For every candidate bandwidth h:
      * ``cn`` is the bootstrap confidence-band half-width from ``confband_est``;
      * the KDE is evaluated on a regular grid (``set_grid``), and the
        persistence diagram of the cubical complex filtered by ``-p_hat`` is
        computed with gudhi;
      * N_0(h) = 1 + the number of finite H0 lifetimes longer than ``cn``;
      * S_0(h) = the sum of (lifetime - cn) over those lifetimes.
    The bandwidth with the unique largest S_0(h) is chosen. If the maximum is
    shared, the first candidate is chosen instead.

    Args:
        data: (n, d) samples. ``set_grid`` only supports d in {2, 3}.
        bandwidth_list: candidate bandwidths. When None or empty, seven
            candidates are taken from quantiles of the 50-NN balloon radii.
        confidence_band: if True, return ``(bandwidth, cn)`` instead of
            ``bandwidth``. This applies in the tie case too.
        kernel: kernel name passed to ``compact_KDE`` / ``confband_est``.
        alpha: significance level for the confidence band.
        Plot: unused; kept for backward compatibility.

    Returns:
        The selected bandwidth, or ``(bandwidth, cn)`` when ``confidence_band``
        is truthy.

    Raises:
        ValueError: if candidates must be estimated but ``data`` has 50 rows or
            fewer.
    """
    if not isinstance(data, np.ndarray):
        data = np.asarray(data)

    if bandwidth_list is None or len(bandwidth_list) == 0:
        bandwidth_list = _balloon_bandwidth_candidates(data)

    # The evaluation grid depends only on the data, so build it once.
    grid = set_grid(data)
    grid_dim = grid.shape[1]
    side = round(len(grid) ** (1 / grid_dim))

    n_h0, s_h0, cn_list = [], [], []
    # `tqdm` is imported as a module in this file, so use the `tqdm.tqdm` class.
    for h in tqdm.tqdm(bandwidth_list):
        cn = confband_est(data, h, alpha=alpha, kernel=kernel, p_val=False)
        cn_list.append(cn)

        # Filter by -p_hat so high-density regions appear first. Use one extent
        # per grid axis so 3-D grids get the correct cubical shape.
        p_hat = compact_KDE(data, grid, h, kernel=kernel)
        PD = gudhi.CubicalComplex(
            dimensions=[side] * grid_dim,
            top_dimensional_cells=-p_hat,
        ).persistence()

        # Finite lifetimes of 0-dimensional homology classes.
        l_h0 = np.array(
            [
                abs(death - birth)
                for hom_dim, (birth, death) in PD
                if hom_dim == 0 and abs(death - birth) != float("inf")
            ],
            dtype=float,
        )

        # N(h)
        n_h0.append(float(np.sum(l_h0 > cn) + 1))
        # S(h)
        excess = l_h0 - cn
        s_h0.append(float(excess[excess > 0].sum()))
        print(
            "bandwidth: ",
            h,
            ", N_0(h): ",
            n_h0[-1],
            ", S_0(h): ",
            s_h0[-1],
            ", cn: ",
            cn,
        )

    s_h0 = np.asarray(s_h0, dtype=float)
    is_peak = s_h0 == max(s_h0)
    # Only a unique maximizer is trusted; ties fall back to the first candidate.
    best = int(np.argmax(is_peak)) if np.count_nonzero(is_peak) == 1 else 0
    if confidence_band:
        return bandwidth_list[best], cn_list[best]
    return bandwidth_list[best]


###########################################
#   Top P&R
###########################################
def compute_top_pr(
    *,
    real_features,
    fake_features,
    alpha=0.1,
    kernel="cosine",
    random_proj=True,
    f1_score=True,
    l2norm=False,
):
    """
    Computing Top Precision and Recall
        Args:
            real_features (n, d): input real features (numpy array or torch tensor)
            fake_features (n, d): input fake features (numpy array or torch tensor)
            alpha (float): significance level alpha in confidence band estimation (default=0.1)
            kernel (str): kernel for KDE, also used for bandwidth selection (default='cosine')
            random_proj (bool): If True, features are projected to 32 dimensions by a bias-free
                                torch.nn.Linear layer with Xavier-normal weights (seed 99). (default=True)
                                If the feature dimension is 32 or lower, no projection is applied
                                even when random_proj is True.
            f1_score (bool): If True, also compute an F1 score of fidelity and diversity (default=True)
            l2norm (bool): If True, L2-normalize every feature vector (after the optional projection)
                           before KDE (default=False)
        Returns:
            If f1_score is True: dict(fidelity=..., diversity=..., Top_F1=...).
            Otherwise: tuple (fidelity, diversity).
    """

    # --- helpers for robustness ---
    def _safe_bandwidth(data_np, h_candidate, floor_frac=1e-3, eps=1e-8, dmat=None):
        # Ensure a strictly positive bandwidth by flooring it at a small fraction
        # of the median positive pairwise distance. `dmat` lets callers reuse an
        # already computed distance matrix.
        if not isinstance(data_np, np.ndarray):
            data_np = np.asarray(data_np)
        if data_np.ndim != 2 or data_np.shape[0] < 2:
            return max(float(h_candidate), eps)
        if dmat is None:
            dmat = distance.cdist(data_np, data_np, metric="euclidean")
        pos = dmat[dmat > 0]
        med = float(np.median(pos)) if pos.size > 0 else 1.0
        floor = max(eps, floor_frac * med)
        return float(max(h_candidate, floor))

    def _median_knn_bandwidth(dmat, dim):
        # Median (upper median for even n) of each point's distance to its k-th
        # nearest neighbour, with k = min(5 * dim, n - 1) and self-distance excluded.
        n = dmat.shape[0]
        k = max(1, min(dim * 5, n - 1))
        radii = []
        for i in range(n):
            row = np.delete(dmat[i], i)  # drop self-distance
            if row.size == 0:
                continue
            idx = min(k - 1, row.size - 1)
            radii.append(np.partition(row, idx)[idx])
        if not radii:
            # Single sample: fall back to a small fraction of the median distance.
            pos = dmat[dmat > 0]
            med = float(np.median(pos)) if pos.size > 0 else 1.0
            return 1e-3 * med
        radii = np.sort(np.asarray(radii))
        return radii[len(radii) // 2]

    # match data format for random projection
    if not torch.is_tensor(real_features):
        real_features = torch.tensor(real_features, dtype=torch.float32)
    if not torch.is_tensor(fake_features):
        fake_features = torch.tensor(fake_features, dtype=torch.float32)

    # random projection
    if random_proj and real_features.shape[1] > 32:
        projection = torch.nn.Linear(real_features.shape[1], 32, bias=False).eval()
        # NOTE(review): this reseeds torch's global RNG as a side effect. It is kept
        # so the projection matrix (and existing results) stay reproducible.
        torch.manual_seed(99)
        torch.nn.init.xavier_normal_(projection.weight)
        for param in projection.parameters():
            param.requires_grad_(False)
        # The layer is created on the default device and dtype. Cast inputs to
        # match so CUDA or float64 tensors do not raise a device/dtype mismatch.
        weight = projection.weight
        real_features = projection(
            real_features.detach().to(device=weight.device, dtype=weight.dtype)
        )
        fake_features = projection(
            fake_features.detach().to(device=weight.device, dtype=weight.dtype)
        )

    # to numpy
    real_features = real_features.detach().cpu().numpy()
    fake_features = fake_features.detach().cpu().numpy()

    # Optional L2 normalization
    if l2norm:
        real_features = real_features / (
            np.linalg.norm(real_features, axis=1, keepdims=True) + 1e-12
        )
        fake_features = fake_features / (
            np.linalg.norm(fake_features, axis=1, keepdims=True) + 1e-12
        )

    if real_features.shape[1] <= 3:
        # Low dimension: pick bandwidths by persistent homology with the same
        # kernel used below. The scalar form is requested because the
        # confidence band is re-estimated with the final bandwidth anyway.
        bandwidth_r = bandwidth_est(
            real_features,
            bandwidth_list=[],
            confidence_band=False,
            kernel=kernel,
            alpha=alpha,
        )
        bandwidth_f = bandwidth_est(
            fake_features,
            bandwidth_list=[],
            confidence_band=False,
            kernel=kernel,
            alpha=alpha,
        )
        bandwidth_r = _safe_bandwidth(real_features, bandwidth_r)
        bandwidth_f = _safe_bandwidth(fake_features, bandwidth_f)
    else:
        # High dimension: robust balloon estimator (median k-NN radius). Each
        # distance matrix is computed once and released right after use.
        dmat_r = distance.cdist(real_features, real_features, metric="euclidean")
        bandwidth_r = _safe_bandwidth(
            real_features,
            _median_knn_bandwidth(dmat_r, real_features.shape[1]),
            dmat=dmat_r,
        )
        del dmat_r
        dmat_f = distance.cdist(fake_features, fake_features, metric="euclidean")
        bandwidth_f = _safe_bandwidth(
            fake_features,
            _median_knn_bandwidth(dmat_f, fake_features.shape[1]),
            dmat=dmat_f,
        )
        del dmat_f

    # estimation of confidence band and manifold
    c_r, score_rr = confband_est(
        data=real_features, h=bandwidth_r, alpha=alpha, kernel=kernel, p_val=True
    )
    c_g, score_gg = confband_est(
        data=fake_features, h=bandwidth_f, alpha=alpha, kernel=kernel, p_val=True
    )

    # Replace NaNs (can arise with degenerate bandwidths) by zeros so comparisons work
    if np.isnan(score_rr).any():
        score_rr = np.nan_to_num(score_rr, nan=0.0)
    if np.isnan(score_gg).any():
        score_gg = np.nan_to_num(score_gg, nan=0.0)

    # count significant real & fake samples
    num_real = np.sum(score_rr > c_r)
    num_fake = np.sum(score_gg > c_g)

    # count significant real samples that also lie on the fake manifold
    score_rg = compact_KDE(fake_features, real_features, bandwidth_f, kernel=kernel)
    inter_r = np.sum((score_rr > c_r) * (score_rg > c_g))

    # count significant fake samples that also lie on the real manifold
    score_gr = compact_KDE(real_features, fake_features, bandwidth_r, kernel=kernel)
    inter_g = np.sum((score_gg > c_g) * (score_gr > c_r))

    # Avoid divide-by-zero; if no significant samples, precision/recall are 0.0
    t_precision = (inter_g / num_fake) if num_fake > 0 else 0.0
    t_recall = (inter_r / num_real) if num_real > 0 else 0.0

    # top f1-score
    if f1_score:
        if t_precision > 0.0001 and t_recall > 0.0001:
            F1_score = 2 / ((1 / t_precision) + (1 / t_recall))
        else:
            F1_score = 0
        return dict(fidelity=t_precision, diversity=t_recall, Top_F1=F1_score)
    return t_precision, t_recall


# ----------------------------------------------------------------------------


def PCA(features_ref, features, pca_dim=100, whiten=False):
    """Project two feature sets onto the principal axes of their union.

    Both sets are stacked, centred on their joint mean, and projected onto the
    leading ``pca_dim`` singular vectors of the joint covariance matrix.

    Args:
        features_ref: (N, d) reference features.
        features: (M, d) features to compare.
        pca_dim: number of principal components to keep.
        whiten: if True, divide each component by sqrt(singular value + 1e-5).

    Returns:
        Tuple ``(features_ref_pca, features_pca)`` of shapes (N, k) and (M, k),
        where k = min(pca_dim, d).
    """
    n_ref = features_ref.shape[0]
    stacked = np.concatenate([features_ref, features], axis=0)
    centred = stacked - np.mean(stacked, axis=0, keepdims=True)
    covariance = np.cov(centred, rowvar=False)
    left_vecs, singular_vals, _ = np.linalg.svd(covariance)
    basis = left_vecs[:, :pca_dim]
    if whiten:
        basis = basis / np.sqrt(singular_vals[:pca_dim] + 1e-5)
    projected = centred @ basis
    return projected[:n_ref], projected[n_ref:]


def compute_pairwise_distances(X, Y=None):
    """
    Euclidean distances between the rows of X and the rows of Y.

    args:
        X: np.array of shape N x dim
        Y: np.array of shape M x dim, or None to compare X with itself
    returns:
        N x M np.array of float64 distances (N x N and symmetric when Y is None)

    Both inputs are cast to float64 before squaring, so integer or
    low-precision inputs cannot overflow or lose precision in the
    ||x||^2 - 2 x.y + ||y||^2 expansion.
    """
    X = np.asarray(X, dtype=np.float64)
    Y = X if Y is None else np.asarray(Y, dtype=np.float64)
    X_norm_square = np.sum(X**2, axis=1, keepdims=True)  # (N, 1)
    if Y is X:
        Y_norm_square = X_norm_square
    else:
        Y_norm_square = np.sum(Y**2, axis=1, keepdims=True)  # (M, 1)
    # Broadcasting (N, 1) against (1, M) replaces explicit np.repeat copies.
    diff_square = X_norm_square - 2 * (X @ Y.T) + Y_norm_square.T
    # Rounding can make tiny squared distances slightly negative; clamp to 0.
    np.maximum(diff_square, 0.0, out=diff_square)
    return np.sqrt(diff_square)


def distances2radii(distances, k):
    """Return each sample's k-th nearest-neighbour distance.

    Args:
        distances: matrix whose i-th row holds sample i's distances to every
            sample of the same set, including its own zero self-distance.
        k: neighbour rank; ``get_kth_value`` skips the self-distance.

    Returns:
        np.ndarray with one radius per row of ``distances``.

    When row 1000 is reached, a rough remaining-time estimate is printed via
    ``dist.print0``.
    """
    n_rows = distances.shape[0]
    radii = np.zeros(n_rows)
    started = time.time()
    for row_idx, row in enumerate(distances):
        if row_idx == 1000:
            remaining_min = (time.time() - started) / row_idx * (n_rows - row_idx) / 60
            dist.print0(f"Estimated time to finish: {remaining_min:.2f} minutes")
        radii[row_idx] = get_kth_value(row, k=k)
    return radii


def get_kth_value(np_array, k):
    """Return the k-th nearest-neighbour distance from a 1-D distance array.

    The array is expected to include the point's distance to itself (the
    closest entry), so the (k + 1)-th smallest value is returned.
    """
    count = k + 1  # +1 skips the self-distance
    smallest = np.partition(np_array, count)[:count]
    return smallest.max()


def is_in_ball(center, radius, subject):
    """Return whether ``subject`` lies strictly inside the ball of ``radius`` around ``center``.

    Uses the Euclidean norm of ``center - subject``; points exactly on the
    boundary are not inside.
    """
    # Compute the norm directly: the module-level name `distance` is the
    # scipy.spatial.distance module, not a callable.
    diff = np.asarray(center, dtype=float) - np.asarray(subject, dtype=float)
    return np.linalg.norm(diff) < radius


# def distance(feat1, feat2):
#     return np.linalg.norm(feat1 - feat2)


def calculate_prdc(real_features, fake_features, k, real_radii=None, fake_radii=None):
    """Compute k-NN based precision, recall, density and coverage.

    Each sample's "ball" has a radius equal to its distance to its k-th nearest
    neighbour within its own set (see ``distances2radii``).

    Args:
        real_features: (N, d) real features.
        fake_features: (M, d) generated features.
        k: neighbourhood size used for the radii and the density normalization.
        real_radii: optional precomputed radii for the real samples (length N).
        fake_radii: optional precomputed radii for the fake samples (length M).

    Returns:
        Tuple ``(precision, recall, density, coverage)``:
            precision: fraction of fake samples inside at least one real ball.
            recall: fraction of real samples inside at least one fake ball.
            density: mean number of real balls containing a fake sample, divided by k.
            coverage: fraction of real samples whose ball contains a fake sample.
    """
    if fake_radii is None:
        fake_distances = sklearn.metrics.pairwise_distances(
            fake_features, n_jobs=-1, metric="euclidean"
        )
        fake_radii = distances2radii(fake_distances, k=k)
        # The (M, M) matrix is no longer needed; free it before the next big allocation.
        del fake_distances
    distance_real_fake = sklearn.metrics.pairwise_distances(
        real_features, fake_features, n_jobs=-1, metric="euclidean"
    )
    if real_radii is None:
        real_distances = sklearn.metrics.pairwise_distances(
            real_features, n_jobs=-1, metric="euclidean"
        )
        real_radii = distances2radii(real_distances, k=k)
        del real_distances

    # (N, M) mask: fake sample j lies inside real sample i's ball.
    # Computed once and shared by precision and density.
    in_real_balls = distance_real_fake < np.expand_dims(real_radii, axis=1)

    precision = in_real_balls.any(axis=0).mean()

    recall = (
        (distance_real_fake < np.expand_dims(fake_radii, axis=0)).any(axis=1).mean()
    )

    density = (1.0 / float(k)) * in_real_balls.sum(axis=0).mean()

    coverage = (distance_real_fake.min(axis=1) < real_radii).mean()

    return precision, recall, density, coverage


# ----------------------------------------------------------------------------


def calculate_dino_stats(
    image_path,
    num_expected=None,
    seed=0,
    max_batch_size=64,
    num_workers=3,
    prefetch_factor=2,
    device=torch.device("cuda"),
):
    """Extract DINOv2 features for a dataset and accumulate per-label statistics.

    Must be called on every rank of an initialized ``torch.distributed`` group.
    Each rank processes an interleaved share of the batches, and the results
    are combined with ``all_reduce`` / ``all_gather_object``.

    Index 0 of every returned list covers the whole dataset. Index ``c + 1``
    covers the images whose label is class ``c``; these entries exist only when
    ``dataset_obj.label_dim > 0``.

    Args:
        image_path: image folder or zip, passed to ``dataset.ImageFolderDataset``
            (and used as ``url`` for ``classifier.FeatureExtractor``).
        num_expected: forwarded as ``max_size``; if given, at least this many
            images are required.
        seed: random seed forwarded to the dataset.
        max_batch_size: upper bound on the per-rank batch size.
        num_workers, prefetch_factor: forwarded to the ``DataLoader``.
        device: device for the feature extractor and the accumulators.

    Returns:
        ``(mus, sigmas, features)``: lists of numpy arrays holding, per group,
        the feature mean (768,), the unbiased covariance (768, 768) and the
        gathered features (num_images, 768).

    Raises:
        click.ClickException: if fewer than ``num_expected`` (or fewer than 2)
            images are found.
    """
    # Rank 0 goes first.
    if dist.get_rank() != 0:
        torch.distributed.barrier()

    # Load Dino v2 model.
    detector_net = classifier.FeatureExtractor(url=image_path).to(device)
    feature_dim = 768
    detector_kwargs = {}
    # List images.
    dist.print0(f'Loading images from "{image_path}"...')
    dataset_obj = dataset.ImageFolderDataset(
        path=image_path,
        max_size=num_expected,
        random_seed=seed,
        use_labels=True,
    )
    if num_expected is not None and len(dataset_obj) < num_expected:
        raise click.ClickException(
            f"Found {len(dataset_obj)} images, but expected at least {num_expected}"
        )
    if len(dataset_obj) < 2:
        raise click.ClickException(
            f"Found {len(dataset_obj)} images, but need at least 2 to compute statistics"
        )

    # Other ranks follow.
    if dist.get_rank() == 0:
        torch.distributed.barrier()

    # Divide images into batches; the batch count is a multiple of the world
    # size so every rank iterates the same number of times (needed for the
    # per-batch barrier below).
    num_batches = (
        (len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1
    ) * dist.get_world_size()
    all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches)
    rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
    data_loader = torch.utils.data.DataLoader(
        dataset_obj,
        batch_sampler=rank_batches,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor,
    )
    num_labels = dataset_obj.label_dim
    # Accumulate statistics.
    dist.print0(f"Calculating statistics for {len(dataset_obj)} images...")
    mu = [
        torch.zeros([feature_dim], dtype=torch.float64, device=device)
        for _ in range(num_labels + 1)
    ]
    sigma = [
        torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device)
        for _ in range(num_labels + 1)
    ]
    list_features = [[] for _ in range(num_labels + 1)]
    t0 = time.time()
    for k, (images, _labels) in enumerate(data_loader):
        if k == 100:
            dist.print0(
                f"Estimated time to finish: {(time.time() - t0) / k * (len(data_loader) - k) / 60:.2f} minutes"
            )

        torch.distributed.barrier()
        if images.shape[0] == 0:
            continue
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])
        with torch.no_grad():
            features = detector_net(images.to(device), **detector_kwargs).to(
                torch.float64
            )
        mu[0] += features.sum(0)
        sigma[0] += features.T @ features
        list_features[0].append(features.cpu())

        if num_labels == 0:
            # Unlabelled dataset: only the global statistics exist.
            continue
        # Labels arrive either as class indices (1-D) or one-hot vectors (2-D).
        # argmax over a singleton axis would map every index to class 0.
        if _labels.ndim == 1:
            class_idx = _labels.long()
        else:
            class_idx = torch.argmax(_labels, dim=1)

        for label in range(1, num_labels + 1):
            idx = class_idx == label - 1
            if not idx.any():
                continue
            mu[label] += features[idx].sum(0)
            sigma[label] += features[idx].T @ features[idx]
            list_features[label].append(features[idx].cpu())

    # Calculate grand totals.
    world_size = dist.get_world_size()
    gathered_features = []
    for label in range(num_labels + 1):
        # A rank may have seen no image of a given label; contribute an empty
        # (0, feature_dim) block instead of failing in torch.cat.
        chunks = list_features[label]
        local_features = (
            torch.cat(chunks, dim=0)
            if chunks
            else torch.zeros([0, feature_dim], dtype=torch.float64)
        )
        torch.distributed.all_reduce(mu[label])
        torch.distributed.all_reduce(sigma[label])
        gathered = [None] * world_size
        torch.distributed.all_gather_object(gathered, local_features)
        label_features = torch.cat(gathered, dim=0)
        gathered_features.append(label_features)

        count = len(label_features)
        mu[label] /= count
        # Unbiased covariance from the accumulated sum of outer products.
        sigma[label] -= torch.outer(mu[label], mu[label]) * count
        sigma[label] /= count - 1
    return (
        [x.cpu().numpy() for x in mu],
        [x.cpu().numpy() for x in sigma],
        [x.numpy() for x in gathered_features],
    )


def calculate_inception_stats(
    image_path,
    num_expected=None,
    seed=0,
    max_batch_size=64,
    num_workers=3,
    prefetch_factor=2,
    device=torch.device("cuda"),
):
    """Extract Inception-v3 features for a dataset and accumulate per-label statistics.

    Must be called on every rank of an initialized ``torch.distributed`` group.
    Each rank processes an interleaved share of the batches, and the results
    are combined with ``all_reduce`` / ``all_gather_object``.

    Index 0 of every returned list covers the whole dataset. Index ``c + 1``
    covers the images whose label is class ``c``; these entries exist only when
    ``dataset_obj.label_dim > 0``.

    Args:
        image_path: image folder or zip, passed to ``dataset.ImageFolderDataset``.
        num_expected: forwarded as ``max_size``; if given, at least this many
            images are required.
        seed: random seed forwarded to the dataset.
        max_batch_size: upper bound on the per-rank batch size.
        num_workers, prefetch_factor: forwarded to the ``DataLoader``.
        device: device for the detector network and the accumulators.

    Returns:
        ``(mus, sigmas, features)``: lists of numpy arrays holding, per group,
        the feature mean (2048,), the unbiased covariance (2048, 2048) and the
        gathered features (num_images, 2048).

    Raises:
        click.ClickException: if fewer than ``num_expected`` (or fewer than 2)
            images are found.
    """
    # Rank 0 goes first.
    if dist.get_rank() != 0:
        torch.distributed.barrier()

    # Load Inception-v3 model.
    # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    dist.print0("Loading Inception-v3 model...")
    detector_url = "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl"
    detector_kwargs = dict(return_features=True)
    feature_dim = 2048
    # NOTE(review): pickle.load can execute code from the file; this is only
    # safe while detector_url points at a trusted, fixed source.
    with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f:
        detector_net = pickle.load(f).to(device)

    # List images.
    dist.print0(f'Loading images from "{image_path}"...')
    dataset_obj = dataset.ImageFolderDataset(
        path=image_path,
        max_size=num_expected,
        random_seed=seed,
        use_labels=True,
    )
    if num_expected is not None and len(dataset_obj) < num_expected:
        raise click.ClickException(
            f"Found {len(dataset_obj)} images, but expected at least {num_expected}"
        )
    if len(dataset_obj) < 2:
        raise click.ClickException(
            f"Found {len(dataset_obj)} images, but need at least 2 to compute statistics"
        )

    # Other ranks follow.
    if dist.get_rank() == 0:
        torch.distributed.barrier()

    # Divide images into batches; the batch count is a multiple of the world
    # size so every rank iterates the same number of times (needed for the
    # per-batch barrier below).
    num_batches = (
        (len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1
    ) * dist.get_world_size()
    all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches)
    rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
    data_loader = torch.utils.data.DataLoader(
        dataset_obj,
        batch_sampler=rank_batches,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor,
    )
    num_labels = dataset_obj.label_dim
    # Accumulate statistics.
    dist.print0(f"Calculating statistics for {len(dataset_obj)} images...")
    mu = [
        torch.zeros([feature_dim], dtype=torch.float64, device=device)
        for _ in range(num_labels + 1)
    ]
    sigma = [
        torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device)
        for _ in range(num_labels + 1)
    ]
    list_features = [[] for _ in range(num_labels + 1)]
    t0 = time.time()
    for k, (images, _labels) in enumerate(data_loader):
        if k == 100:
            dist.print0(
                f"Estimated time to finish: {(time.time() - t0) / k  * (len(data_loader) - k) / 60:.2f} minutes"
            )
        torch.distributed.barrier()
        if images.shape[0] == 0:
            continue
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])
        # Inference only: don't build an autograd graph (matches the DINO path).
        with torch.no_grad():
            features = detector_net(images.to(device), **detector_kwargs).to(
                torch.float64
            )
        mu[0] += features.sum(0)
        sigma[0] += features.T @ features
        list_features[0].append(features.cpu())

        if num_labels == 0:
            # Unlabelled dataset: only the global statistics exist.
            continue
        # Labels arrive either as class indices (1-D) or one-hot vectors (2-D).
        # argmax over a singleton axis would map every index to class 0.
        if _labels.ndim == 1:
            class_idx = _labels.long()
        else:
            class_idx = torch.argmax(_labels, dim=1)

        for label in range(1, num_labels + 1):
            idx = class_idx == label - 1
            if not idx.any():
                continue
            mu[label] += features[idx].sum(0)
            sigma[label] += features[idx].T @ features[idx]
            list_features[label].append(features[idx].cpu())

    # Calculate grand totals.
    world_size = dist.get_world_size()
    gathered_features = []
    for label in range(num_labels + 1):
        # A rank may have seen no image of a given label; contribute an empty
        # (0, feature_dim) block instead of failing in torch.cat.
        chunks = list_features[label]
        local_features = (
            torch.cat(chunks, dim=0)
            if chunks
            else torch.zeros([0, feature_dim], dtype=torch.float64)
        )
        torch.distributed.all_reduce(mu[label])
        torch.distributed.all_reduce(sigma[label])
        gathered = [None] * world_size
        torch.distributed.all_gather_object(gathered, local_features)
        label_features = torch.cat(gathered, dim=0)
        gathered_features.append(label_features)

        count = len(label_features)
        mu[label] /= count
        # Unbiased covariance from the accumulated sum of outer products.
        sigma[label] -= torch.outer(mu[label], mu[label]) * count
        sigma[label] /= count - 1
    return (
        [x.cpu().numpy() for x in mu],
        [x.cpu().numpy() for x in sigma],
        [x.numpy() for x in gathered_features],
    )


# ----------------------------------------------------------------------------


def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref):
    """Frechet distance between the Gaussians N(mu, sigma) and N(mu_ref, sigma_ref).

    FID = ||mu - mu_ref||^2 + Tr(sigma + sigma_ref - 2 * sqrtm(sigma @ sigma_ref))

    ``scipy.linalg.sqrtm`` may return a complex matrix, so only the real part
    of the result is returned, as a Python float.
    """
    mean_term = np.square(mu - mu_ref).sum()
    covmean, _ = scipy.linalg.sqrtm(sigma @ sigma_ref, disp=False)
    trace_term = np.trace(sigma + sigma_ref - 2 * covmean)
    return float(np.real(mean_term + trace_term))


# ----------------------------------------------------------------------------


# Root `click` command group for this script; subcommands attach themselves
# with `@main.command()` (e.g. `calc` below). Click shows this docstring as the
# --help text, so it is user-facing. The `\b` lines are click markers that keep
# the example blocks from being re-wrapped.
@click.group()
def main():
    """Calculate Frechet Inception Distance (FID).

    Examples:

    \b
    # Generate 50000 images and save them as fid-tmp/*/*.png
    torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \\
        --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl

    \b
    # Calculate FID
    torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \\
        --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz

    \b
    # Compute dataset reference statistics
    python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz
    """


# ----------------------------------------------------------------------------


@main.command()
@click.option(
    "--images",
    "image_path",
    help="Path to the images",
    metavar="PATH|ZIP",
    type=str,
    required=True,
)
@click.option(
    "--seed",
    help="Random seed for selecting the images",
    metavar="INT",
    type=int,
    default=0,
    show_default=True,
)
@click.option(
    "--batch",
    help="Maximum batch size",
    metavar="INT",
    type=click.IntRange(min=1),
    default=64,
    show_default=True,
)
@click.option(
    "--dataset",
    help="Dataset name",
    metavar="STR",
    type=str,
    default=None,
    show_default=True,
)
@click.option(
    "--pca",
    help="Apply PCA to the features before calculating PRDC",
    is_flag=True,
    default=True,
    show_default=True,
)
@click.option(
    "--pca-dim",
    help="Dimensionality to reduce to with PCA",
    metavar="INT",
    type=click.IntRange(min=1),
    default=64,  # 64
    show_default=True,
)
@click.option(
    "--whiten",
    help="Apply whitening to the features before calculating PRDC (only if --pca is also passed)",
    is_flag=True,
    default=False,
    show_default=True,
)
# New options for TopPR
@click.option(
    "--toppr-alpha",
    help="Significance level alpha for Top P/R confidence bands (higher = less strict)",
    metavar="FLOAT",
    type=float,
    default=0.2,
    show_default=True,
)
@click.option(
    "--toppr-randproj/--no-toppr-randproj",
    help="Enable/disable 32D random projection inside Top P/R",
    default=True,
    show_default=True,
)
@click.option(
    "--toppr-n",
    help="Number of features to sample per evaluation (per label)",
    metavar="INT",
    type=click.IntRange(min=100),
    default=5000,
    show_default=True,
)
@click.option(
    "--toppr-repeats",
    help="Number of random resamples (averaged) for Top P/R",
    metavar="INT",
    type=click.IntRange(min=1),
    default=5,
    show_default=True,
)
@click.option(
    "--l2norm/--no-l2norm",
    help="L2-normalize features before KDE for Top P/R",
    default=False,
    show_default=True,
)
@click.option(
    "--reset-features",
    help="Reset cached features (if previously calculated with different settings)",
    is_flag=True,
    default=False,
    show_default=True,
)
def calc(
    image_path,
    seed,
    batch,
    dataset,
    pca,
    pca_dim,
    whiten,
    toppr_alpha,
    toppr_randproj,
    toppr_n,
    toppr_repeats,
    l2norm,
    reset_features,
):
    """Calculate FID and Top P/R for a given set of images.

    Inception and DINO features of the images under ``image_path`` are
    computed, or reloaded from the ``features.npy`` cache in
    ``os.path.dirname(image_path)`` unless ``reset_features`` is set. They are
    compared against the dataset reference statistics. Entry 0 uses the base
    reference file (reported as ``_all``). Entries 1.. use the per-label files
    ``<ref>-0.npz``, ``<ref>-1.npz``, ... that exist on disk.

    For every entry, FID uses the full accumulated mean/covariance. Top P/R is
    averaged over ``toppr_repeats`` random subsamples of up to ``toppr_n``
    features, drawn from at most the first 20k features. Density and coverage
    are not computed and are reported as 0. Rank 0 writes all metrics as one
    JSON line to ``results_eval.jsonl`` in the same directory as the cache.

    Raises:
        ValueError: if ``dataset`` is not given.
        NotImplementedError: if ``dataset`` matches no known reference set.
    """
    torch.multiprocessing.set_start_method("spawn")
    dist.init()
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Select the reference statistics for the requested dataset.
    if dataset is None:
        raise ValueError("--dataset is required to select the reference statistics")
    if "cifar10" in dataset:
        ref_path = "eval-refs/cifar10-32x32.npz"
    elif "ffhq" in dataset:
        ref_path = "eval-refs-all/ffhq-64x64.npz"
    elif "afhq" in dataset:
        ref_path = "eval-refs-all/afhqv2-64x64.npz"
    else:
        raise NotImplementedError(f"No reference statistics for dataset {dataset!r}")

    # Per-label references live next to the base file as <stem>-0.npz,
    # <stem>-1.npz, ...; collect every consecutive one that exists.
    ref_stem = os.path.splitext(ref_path)[0]
    ref_path_labels = [ref_path]
    label = 0
    while os.path.exists(f"{ref_stem}-{label}.npz"):
        ref_path_labels.append(f"{ref_stem}-{label}.npz")
        label += 1

    # Reference statistics are only loaded (and later used) on rank 0.
    ref = []
    for path in ref_path_labels:
        dist.print0(f'Using dataset reference statistics from "{path}"...')
        if dist.get_rank() == 0:
            with dnnlib.util.open_url(path) as f:
                ref.append(dict(np.load(f)))
            print(len(ref[-1]["features_inc"]), len(ref[-1]["features_dino"]))

    # Reuse previously computed features unless the caller asked for a reset.
    output_path = os.path.dirname(image_path)
    feature_settings_path = os.path.join(output_path, "features.npy")
    if os.path.exists(feature_settings_path) and not reset_features:
        if dist.get_rank() == 0:
            with dnnlib.util.open_url(feature_settings_path) as f:
                # allow_pickle is needed because the cache is a pickled dict;
                # only point this at output directories you trust.
                feature_settings = dict(np.load(f, allow_pickle=True).item())
            mus_inc = feature_settings["mus_inc"]
            sigmas_inc = feature_settings["sigmas_inc"]
            featuress_inc = feature_settings["featuress_inc"]
            mus_dino = feature_settings["mus_dino"]
            sigmas_dino = feature_settings["sigmas_dino"]
            featuress_dino = feature_settings["featuress_dino"]
    else:
        mus_inc, sigmas_inc, featuress_inc = calculate_inception_stats(
            image_path=image_path,
            seed=seed,
            max_batch_size=batch,
        )
        mus_dino, sigmas_dino, featuress_dino = calculate_dino_stats(
            image_path=image_path,
            seed=seed,
            max_batch_size=batch,
        )
        if dist.get_rank() == 0:
            feature_settings = {
                "mus_inc": mus_inc,
                "sigmas_inc": sigmas_inc,
                "featuress_inc": featuress_inc,
                "mus_dino": mus_dino,
                "sigmas_dino": sigmas_dino,
                "featuress_dino": featuress_dino,
            }
            np.save(
                feature_settings_path,
                feature_settings,
            )
            dist.print0(f"Saved calculated features to {feature_settings_path}")

    dist.print0(
        f"TopPR settings — alpha={toppr_alpha}, randproj={toppr_randproj}, l2norm={l2norm}, n={toppr_n}, repeats={toppr_repeats}"
    )
    if dist.get_rank() == 0:
        # When features come from the cache they exist only on rank 0, so all
        # post-processing must stay inside this branch.
        num_features = [f.shape[0] for f in featuress_inc]
        featuress_inc = [f[:20000] for f in featuress_inc]
        featuress_dino = [f[:20000] for f in featuress_dino]
        toppr_kwargs = dict(
            num_repeats=toppr_repeats,
            n=toppr_n,
            pca=pca,
            pca_dim=pca_dim,
            whiten=whiten,
            alpha=toppr_alpha,
            random_proj=toppr_randproj,
            l2norm=l2norm,
        )

        data = {}
        for i in range(len(mus_inc)):
            metric_name = "_all" if i == 0 else f"-{i-1}"

            dist.print0(
                f"Calculating FID, P, R, D, C for {metric_name[1:]} on {image_path} with TopPR n={toppr_n} (out of 20k)."
            )

            # Inception: Top P/R on subsamples, FID on the full accumulated stats.
            Pinc, Rinc, Pinc_std, Rinc_std = _repeated_top_pr(
                ref[i]["features_inc"], featuress_inc[i], **toppr_kwargs
            )
            # Density/coverage are not computed; reported as 0 so the output
            # schema stays stable.
            Dinc = Cinc = Dinc_std = Cinc_std = 0.0
            fid = calculate_fid_from_inception_stats(
                mus_inc[i], sigmas_inc[i], ref[i]["mu_inc"], ref[i]["sigma_inc"]
            )
            dist.print0(f"Results for {metric_name[1:]}:")
            dist.print0(
                f"Inception: fid: {fid:.2f}, P: {Pinc:.4f} pm {Pinc_std:.4f}, R: {Rinc:.4f} pm {Rinc_std:.4f}, D: {Dinc:.4f} pm {Dinc_std:.4f}, C: {Cinc:.4f} pm {Cinc_std:.4f}"
            )

            # DINO: same protocol on the DINO feature space.
            Pdino, Rdino, Pdino_std, Rdino_std = _repeated_top_pr(
                ref[i]["features_dino"], featuress_dino[i], **toppr_kwargs
            )
            Ddino = Cdino = Ddino_std = Cdino_std = 0.0
            fid_dino = calculate_fid_from_inception_stats(
                mus_dino[i], sigmas_dino[i], ref[i]["mu_dino"], ref[i]["sigma_dino"]
            )
            dist.print0(
                f"DINO: fid: {fid_dino:.2f}, P: {Pdino:.4f} pm {Pdino_std:.4f}, R: {Rdino:.4f} pm {Rdino_std:.4f}, D: {Ddino:.4f} pm {Ddino_std:.4f}, C: {Cdino:.4f} pm {Cdino_std:.4f}"
            )

            data.update(
                {
                    f"fid{metric_name}": fid,
                    f"P{metric_name}": Pinc,
                    f"R{metric_name}": Rinc,
                    f"D{metric_name}": Dinc,
                    f"C{metric_name}": Cinc,
                    f"P_std{metric_name}": Pinc_std,
                    f"R_std{metric_name}": Rinc_std,
                    f"D_std{metric_name}": Dinc_std,
                    f"C_std{metric_name}": Cinc_std,
                    f"num_features{metric_name}": num_features[i],
                    f"fid_dino{metric_name}": fid_dino,
                    f"P_dino{metric_name}": Pdino,
                    f"R_dino{metric_name}": Rdino,
                    f"D_dino{metric_name}": Ddino,
                    f"C_dino{metric_name}": Cdino,
                    f"P_dino_std{metric_name}": Pdino_std,
                    f"R_dino_std{metric_name}": Rdino_std,
                    f"D_dino_std{metric_name}": Ddino_std,
                    f"C_dino_std{metric_name}": Cdino_std,
                }
            )

        # os.path.join keeps the file in the current directory when image_path
        # has no directory part (string concatenation would target "/").
        with open(os.path.join(output_path, "results_eval.jsonl"), "w") as file:
            file.write(json.dumps(data) + "\n")

    torch.distributed.barrier()


def _repeated_top_pr(
    ref_features,
    gen_features,
    *,
    num_repeats,
    n,
    pca,
    pca_dim,
    whiten,
    alpha,
    random_proj,
    l2norm,
):
    """Average Top P/R over repeated random subsamples of two feature sets.

    Each repeat draws one permutation, uses the same indices to subsample
    both sets, optionally applies PCA, and calls ``compute_top_pr`` with a
    cosine kernel.

    Returns:
        ``(precision_mean, recall_mean, precision_std, recall_std)``.
    """
    # The same indices address both sets, so the pool must fit inside the
    # smaller one; a per-label reference can hold fewer than 20k features.
    pool_size = min(20000, len(ref_features), len(gen_features))
    use_n = min(n, pool_size)
    precisions, recalls = [], []
    for _ in range(num_repeats):
        index = np.random.permutation(pool_size)[:use_n]
        real = ref_features[index]
        fake = gen_features[index]
        if pca:
            real, fake = PCA(real, fake, pca_dim=pca_dim, whiten=whiten)
        precision, recall = compute_top_pr(
            real_features=real,
            fake_features=fake,
            alpha=alpha,
            kernel="cosine",
            random_proj=random_proj,
            f1_score=False,
            l2norm=l2norm,
        )
        precisions.append(precision)
        recalls.append(recall)
    return np.mean(precisions), np.mean(recalls), np.std(precisions), np.std(recalls)


# ----------------------------------------------------------------------------


@main.command()
@click.option(
    "--data",
    "dataset_path",
    help="Path to the dataset",
    metavar="PATH|ZIP",
    type=str,
    required=True,
)
@click.option(
    "--batch",
    help="Maximum batch size",
    metavar="INT",
    type=click.IntRange(min=1),
    default=64,
    show_default=True,
)
def ref(dataset_path, batch):
    """Calculate dataset reference statistics needed by 'calc'.

    Computes Inception and DINO statistics for the dataset at
    ``dataset_path``. Rank 0 then writes one ``.npz`` per entry into
    ``eval-refs-all/``: ``<name>.npz`` for entry 0 and ``<name>-<label>.npz``
    for each following entry. Each file holds the mean, covariance, raw
    features and per-sample radii (from ``distances2radii``) for both
    feature extractors.
    """
    torch.multiprocessing.set_start_method("spawn")
    dist.init()
    dataset_name = dataset_path.split("/")[-1].split(".")[0]
    dest_path = "eval-refs-all"

    mus_inc, sigmas_inc, featuress_inc = calculate_inception_stats(
        image_path=dataset_path,
        max_batch_size=batch,
    )

    mus_dino, sigmas_dino, featuress_dino = calculate_dino_stats(
        image_path=dataset_path,
        max_batch_size=batch,
    )

    if dist.get_rank() == 0:
        # Create the output directory once, before writing any per-label file.
        os.makedirs(dest_path, exist_ok=True)
        dist.print0("Estimating Manifold...")
        for label in range(len(mus_inc)):
            name = f"{dataset_name}-{label-1}" if label > 0 else dataset_name
            dist.print0(f"Saving dataset reference statistics for {name}...")
            dist.print0("Estimating Radii...")
            # Reduce each N x N distance matrix to radii before building the
            # next one, so only one matrix is alive at a time.
            # NOTE(review): `k` is not defined in this function and must be a
            # module-level global; confirm it exists or this raises NameError.
            distances = sklearn.metrics.pairwise_distances(
                featuress_inc[label], n_jobs=-1, metric="euclidean"
            )
            radius_inc = distances2radii(distances, k=k)
            del distances
            distances = sklearn.metrics.pairwise_distances(
                featuress_dino[label], n_jobs=-1, metric="euclidean"
            )
            radius_dino = distances2radii(distances, k=k)
            del distances
            dist.print0(f'Saving dataset reference statistics to "{dest_path}"...')
            np.savez(
                os.path.join(dest_path, f"{name}.npz"),
                mu_inc=mus_inc[label],
                sigma_inc=sigmas_inc[label],
                features_inc=featuress_inc[label],
                real_radii_inc=radius_inc,
                mu_dino=mus_dino[label],
                sigma_dino=sigmas_dino[label],
                features_dino=featuress_dino[label],
                real_radii_dino=radius_dino,
            )
            dist.print0(
                f"Saved dataset reference statistics for {name} to {dest_path}/{name}.npz"
            )

    torch.distributed.barrier()
    dist.print0("Done.")


# ----------------------------------------------------------------------------

if __name__ == "__main__":
    # Script entry point: hand control to the click command group, which
    # dispatches to the `calc` / `ref` subcommands.
    main()

# ----------------------------------------------------------------------------
