import os
import click
import tqdm
import pickle
import numpy as np
import torch
import dnnlib
from torch_utils import distributed as dist
from training import dataset

#----------------------------------------------------------------------------

def calculate_inception_stats(
    image_path, num_expected=None, seed=0, max_batch_size=64,
    num_workers=3, prefetch_factor=2, device=torch.device('cuda'),
):
    """Extract Inception-v3 features for the images found at `image_path`.

    The dataset is split into batches that are distributed round-robin over
    all ranks; every rank runs the detector on its own batches and the
    per-rank results are then exchanged with `all_gather_object`.

    Args:
        image_path: Location of the images, passed to `ImageFolderDataset`.
        num_expected: Forwarded to the dataset as `max_size`; when not None,
            a dataset with fewer images than this is rejected.
        seed: Forwarded to the dataset as `random_seed`.
        max_batch_size: Upper bound on the number of images per batch.
        num_workers: Number of `DataLoader` worker processes.
        prefetch_factor: `DataLoader` prefetch factor.
        device: Device on which the detector network is evaluated.

    Returns:
        On rank 0, a float64 CPU tensor with the features of all ranks
        concatenated along dim 0 (rank by rank, so not in dataset order).
        On every other rank, None.

    Raises:
        click.ClickException: If fewer than `num_expected` images, or fewer
            than 2 images, are found.
    """
    # Non-zero ranks wait here until rank 0 reaches the matching barrier
    # below, i.e. until rank 0 has finished loading the model and dataset.
    if dist.get_rank() != 0:
        torch.distributed.barrier()

    dist.print0('Loading Inception-v3 model...')
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True)
    feature_dim = 2048
    with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f:
        # NOTE(security): unpickling executes code embedded in the file, so
        # this is only safe as long as the hard-coded URL above is trusted.
        detector_net = pickle.load(f).to(device)

    dist.print0(f'Loading images from "{image_path}"...')
    dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed)
    if num_expected is not None and len(dataset_obj) < num_expected:
        raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}')
    if len(dataset_obj) < 2:
        raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics')

    # Rank 0 is done loading; release the ranks waiting at the first barrier.
    if dist.get_rank() == 0:
        torch.distributed.barrier()

    # The batch count is rounded up to a multiple of the world size so that
    # every rank iterates over the same number of batches. This keeps the
    # per-iteration barrier below in lockstep; some batches may be empty.
    num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
    all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches)
    rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
    data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor)

    dist.print0(f'Extracting features for {len(dataset_obj)} images...')
    features_list = []

    for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)):
        torch.distributed.barrier()
        if images.shape[0] == 0:
            continue
        # Replicate single-channel images to three channels for the detector.
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])
        # Pure inference: disable autograd so no graph is recorded, and move
        # each batch to the CPU right away instead of keeping every feature
        # tensor resident on the device until the end.
        with torch.no_grad():
            features = detector_net(images.to(device), **detector_kwargs).to(torch.float64)  # shape: [batch_size, feature_dim]
        features_list.append(features.cpu())

    if features_list:
        local_features = torch.cat(features_list, dim=0)
    else:
        # This rank only received empty batches.
        local_features = torch.empty((0, feature_dim), dtype=torch.float64, device='cpu')

    # Collective call: every rank must participate, even though only rank 0
    # uses the gathered result.
    features_gathered = [None for _ in range(dist.get_world_size())]
    torch.distributed.all_gather_object(features_gathered, local_features)

    if dist.get_rank() != 0:
        return None
    return torch.cat(features_gathered, dim=0)

#----------------------------------------------------------------------------
def calculate_kid_from_inception_features(gen_features, ref_features, subset_size=1000, num_subsets=100, device=torch.device('cuda')):
    """Estimate the Kernel Inception Distance between two feature sets.

    For each of `num_subsets` rounds, an equally sized random subset is drawn
    (without replacement) from both feature sets and the unbiased squared-MMD
    estimate under the cubic polynomial kernel `(a . b / d + 1) ** 3` is
    computed. The returned value is the mean over all rounds.

    Args:
        gen_features: NumPy array of generated-image features, one row per
            image.
        ref_features: NumPy array of reference features with the same number
            of columns as `gen_features`.
        subset_size: Maximum number of samples per subset. It is clamped to
            the size of the smaller feature set so that the subsets and the
            normalization always agree.
        num_subsets: Number of random subsets to average over.
        device: Device on which the kernels are evaluated.

    Returns:
        The KID estimate as a Python float.

    Raises:
        ValueError: If fewer than 2 samples per subset are available.
    """
    gen_features = torch.from_numpy(gen_features).to(torch.float64).to(device)
    ref_features = torch.from_numpy(ref_features).to(torch.float64).to(device)
    n = gen_features.shape[0]
    m = ref_features.shape[0]
    d = gen_features.shape[1]

    # Slicing a permutation of n (or m) items yields at most n (or m) indices,
    # so the effective subset size must be clamped; otherwise the estimator
    # below would normalize by a sample count that was never actually drawn.
    subset_size = min(subset_size, n, m)
    if subset_size < 2:
        raise ValueError(
            f'KID needs at least 2 samples per subset, but only {subset_size} '
            f'are available (gen={n}, ref={m})'
        )
    num_offdiag = subset_size * (subset_size - 1)

    def polynomial_kernel(a, b):
        return (torch.matmul(a, b.t()) / d + 1) ** 3

    def offdiag_mean(k):
        # Mean over all pairs i != j; the diagonal is excluded to keep the
        # within-set terms unbiased.
        return (k.sum() - k.diagonal().sum()) / num_offdiag

    kid_values = []
    for _ in range(num_subsets):
        gen_idx = torch.randperm(n)[:subset_size]
        ref_idx = torch.randperm(m)[:subset_size]
        X = gen_features[gen_idx]
        Y = ref_features[ref_idx]
        mmd = (
            offdiag_mean(polynomial_kernel(X, X))
            + offdiag_mean(polynomial_kernel(Y, Y))
            - 2 * polynomial_kernel(X, Y).mean()
        )
        kid_values.append(mmd.item())
    return float(np.mean(kid_values))

#----------------------------------------------------------------------------
# Root click command group. The `calc` and `ref` subcommands in this file
# attach to it via `@main.command()`; the docstring below doubles as the
# top-level `--help` text.
@click.group()
def main():
    """Calculate Kernel Inception Distance (KID).
    """

#----------------------------------------------------------------------------
@main.command()
@click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True)
@click.option('--ref', 'ref_path', help='Dataset reference statistics (features)', metavar='NPZ|URL', type=str, required=True)
@click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), default=50000, show_default=True)
@click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True)
@click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
def calc(image_path, ref_path, num_expected, seed, batch):
    """Calculate KID for a given set of images."""
    torch.multiprocessing.set_start_method('spawn')
    dist.init()

    dist.print0(f'Loading dataset reference statistics from "{ref_path}"...')
    # Only rank 0 computes the final score, so only rank 0 loads the reference.
    ref_features = None
    if dist.get_rank() == 0:
        with dnnlib.util.open_url(ref_path) as f:
            ref_data = dict(np.load(f))
        # Fail before the expensive feature extraction, with a readable
        # message, if the file does not hold the array the 'ref' command saves.
        if 'features' not in ref_data:
            raise click.ClickException(
                f'"{ref_path}" does not contain a "features" array; '
                f'found: {sorted(ref_data)}'
            )
        ref_features = ref_data['features']

    # Returns the gathered feature tensor on rank 0 and None on other ranks.
    gen_features = calculate_inception_stats(image_path=image_path, num_expected=num_expected, seed=seed, max_batch_size=batch)
    dist.print0('Calculating KID...')
    if dist.get_rank() == 0:
        kid = calculate_kid_from_inception_features(gen_features.numpy(), ref_features)
        print(f'{kid:g}')
    # Keep the other ranks alive until rank 0 has printed the result.
    torch.distributed.barrier()

#----------------------------------------------------------------------------
@main.command()
@click.option('--data', 'dataset_path', help='Path to the dataset', metavar='PATH|ZIP', type=str, required=True)
@click.option('--dest', 'dest_path', help='Destination .npz file', metavar='NPZ', type=str, required=True)
@click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
def ref(dataset_path, dest_path, batch):
    """Calculate dataset reference statistics (features) needed by 'calc'."""
    torch.multiprocessing.set_start_method('spawn')
    dist.init()

    # Every rank takes part in the extraction; the combined feature tensor
    # comes back on rank 0 only (other ranks receive None).
    features = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch)
    dist.print0(f'Saving dataset reference statistics to "{dest_path}"...')

    is_main_rank = dist.get_rank() == 0
    if is_main_rank:
        # Create the parent directory when the destination has one.
        dest_dir = os.path.dirname(dest_path)
        if dest_dir:
            os.makedirs(dest_dir, exist_ok=True)
        np.savez(dest_path, features=features.numpy())

    # Hold all ranks until rank 0 has finished writing the file.
    torch.distributed.barrier()
    dist.print0('Done.')

#----------------------------------------------------------------------------
# Script entry point: hand control to the click group, which parses the
# command line and dispatches to the selected subcommand.
if __name__ == "__main__":
    main()



# Example usage:
#   python kid.py ref --data=datasets/cifar10-32x32.zip --dest=kid-refs/cifar10-32x32.npz --batch=64
#   torchrun --standalone --nproc_per_node=1 kid.py calc --images=plots/images --ref=kid-refs/cifar10-32x32.npz --num=50000 --batch=64

