# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Generate random images using the given model."""

import os
import re
import warnings
import click
import tqdm
import pickle
import numpy as np
import torch
import PIL.Image
import dnnlib
from training import dataset
from torch_utils import distributed as dist

from generate_images_custom import generate_images, parse_int_list, config_presets
from calculate_metrics import (
    load_stats,
    calculate_stats_for_files,
    calculate_metrics_from_stats,
    parse_metric_list,
    calculate_stats_for_iterable,
)

import argparse
import json
import os
import sys

import torch

from torch_fidelity_utils.defaults import DEFAULTS
from torch_fidelity_utils.helpers import process_deprecations
from torch_fidelity_utils.metrics import calculate_metrics
from torch_fidelity_utils.registry import (
    FEATURE_EXTRACTORS_REGISTRY,
    DATASETS_REGISTRY,
    SAMPLE_SIMILARITY_REGISTRY,
    INTERPOLATION_REGISTRY,
    NOISE_SOURCE_REGISTRY,
)

# Suppress specific noisy third-party warnings. Each entry is a regular
# expression that must match the start of the warning message; filters are
# installed in this order (each call inserts at the front of the filter list).
_IGNORED_WARNING_PATTERNS = (
    '`resume_download` is deprecated',
    'You are using `torch.load` with `weights_only=False`',
    '1Torch was not compiled with flash attention',
)
for _pattern in _IGNORED_WARNING_PATTERNS:
    warnings.filterwarnings('ignore', _pattern)
del _pattern


#----------------------------------------------------------------------------
# Calculate feature statistics for the given directory or ZIP of images
# in a distributed fashion. Returns an iterable that yields
# dnnlib.EasyDict(stats, images, batch_idx, num_batches)

def calculate_stats_for_files(
    image_path,             # Path to a directory or ZIP file containing the images.
    num_images      = None, # Number of images to use. None = all available images.
    seed            = 0,    # Random seed for selecting the images.
    max_batch_size  = 64,   # Maximum batch size.
    num_workers     = 2,    # How many subprocesses to use for data loading.
    prefetch_factor = 2,    # Number of images loaded in advance by each worker.
    verbose         = True, # Enable status prints?
    **stats_kwargs,         # Arguments for calculate_stats_for_iterable().
):
    """Return an iterable of feature statistics for the images at `image_path`.

    Rank 0 constructs and validates the dataset before every other rank does;
    the pair of barriers below enforces that ordering. The image indices are
    then split into `world_size * k` nearly equal batches (each at most
    `max_batch_size` images), dealt round-robin to ranks, and the resulting
    DataLoader is passed to calculate_stats_for_iterable().

    Raises:
        click.ClickException: if fewer than `num_images` images (when given),
            or fewer than 2 images, are found.
    """
    rank = dist.get_rank()
    is_leader = (rank == 0)

    # Non-leader ranks block here until rank 0 has finished loading below.
    # NOTE(review): if rank 0 raises before reaching its barrier, the other
    # ranks stay blocked here.
    if not is_leader:
        torch.distributed.barrier()

    if verbose:
        dist.print0(f'Loading images from {image_path} ...')
    dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_images, random_seed=seed)
    found = len(dataset_obj)
    if num_images is not None and found < num_images:
        raise click.ClickException(f'Found {found} images, but expected at least {num_images}')
    if found < 2:
        raise click.ClickException(f'Found {found} images, but need at least 2 to compute statistics')

    # Rank 0 now enters the barrier, releasing the ranks waiting above.
    if is_leader:
        torch.distributed.barrier()

    # Every rank gets the same number of batches; ceil(found / (batch * world)), min 1.
    world_size = dist.get_world_size()
    batches_per_rank = max((found - 1) // (max_batch_size * world_size) + 1, 1)
    all_batches = np.array_split(np.arange(found), batches_per_rank * world_size)
    my_batches = all_batches[rank::world_size]

    loader = torch.utils.data.DataLoader(
        dataset_obj,
        batch_sampler=my_batches,
        num_workers=num_workers,
        # prefetch_factor must be None when loading in the main process.
        prefetch_factor=prefetch_factor if num_workers > 0 else None,
    )

    # Hand the loader to the generic iterable-based statistics routine.
    return calculate_stats_for_iterable(image_iter=loader, verbose=verbose, **stats_kwargs)

#----------------------------------------------------------------------------
# Command line interface.

# W&B run-name components per sampler: (label, option) pairs appended to the
# pred_type as "-{label}{value}". Any pred_type starting with 'egg' uses 'egg'.
_WANDB_NAME_FIELDS = {
    'egg':      (('guid', 'guidance'), ('gamma', 'gamma'), ('beta1', 'beta_1'), ('beta2', 'beta_2')),
    'mom':      (('guid', 'guidance'), ('beta1', 'beta_1')),
    'vanilla':  (('guid', 'guidance'),),
    'cfgpp':    (('guid', 'guidance'),),
    'apg':      (('guid', 'guidance'), ('beta1', 'beta_1'), ('eta', 'eta'), ('rscale', 'r_scale')),
    'ag_noise': (('guid', 'guidance'),),
    'fdg':      (('wlow', 'w_low'), ('whigh', 'w_high'), ('wpar', 'w_par')),
    'ig':       (('guid', 'guidance'), ('sigmastart', 'sigma_start'), ('sigmaend', 'sigma_end')),
    'tcfg':     (('guid', 'guidance'),),
}


def _wandb_run_name(opts):
    """Return a W&B run name encoding `pred_type` and its key hyperparameters.

    A missing pred_type yields 'None'; an unknown pred_type yields just its name.
    """
    pred_type = opts.get('pred_type')
    if pred_type is None:
        return str(pred_type)
    table_key = 'egg' if pred_type.startswith('egg') else pred_type
    fields = _WANDB_NAME_FIELDS.get(table_key, ())
    return pred_type + ''.join(f'-{label}{opts[name]}' for label, name in fields)


@click.command()
@click.option('--preset',                   help='Configuration preset', metavar='STR',                             type=str, default=None)
@click.option('--net',                      help='Main network pickle filename', metavar='PATH|URL',                type=str, default=None)
@click.option('--gnet',                     help='Guiding network pickle filename', metavar='PATH|URL',             type=str, default=None)
@click.option('--outdir',                   help='Where to save the output images', metavar='DIR',                  type=str, required=True)
@click.option('--subdirs',                  help='Create subdirectory for every 1000 seeds',                        is_flag=True)
@click.option('--seeds',                    help='List of random seeds (e.g. 1,2,5-10)', metavar='LIST',            type=parse_int_list, default='16-19', show_default=True)
@click.option('--class', 'class_idx',       help='Class label  [default: random]', metavar='INT',                   type=click.IntRange(min=0), default=None)
# Single option shared by image generation and metric statistics (both read max_batch_size).
@click.option('--batch', 'max_batch_size',  help='Maximum batch size (generation and metric statistics)', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
@click.option('--steps', 'num_steps',       help='Number of sampling steps', metavar='INT',                         type=click.IntRange(min=1), default=32, show_default=True)
@click.option('--sigma_min',                help='Lowest noise level', metavar='FLOAT',                             type=click.FloatRange(min=0, min_open=True), default=0.002, show_default=True)
@click.option('--sigma_max',                help='Highest noise level', metavar='FLOAT',                            type=click.FloatRange(min=0, min_open=True), default=80, show_default=True)
@click.option('--rho',                      help='Time step exponent', metavar='FLOAT',                             type=click.FloatRange(min=0, min_open=True), default=7, show_default=True)
@click.option('--guidance',                 help='Guidance strength  [default: 1; no guidance]', metavar='FLOAT',   type=float, default=None)
@click.option('--S_churn', 'S_churn',       help='Stochasticity strength', metavar='FLOAT',                         type=click.FloatRange(min=0), default=0, show_default=True)
@click.option('--S_min', 'S_min',           help='Stoch. min noise level', metavar='FLOAT',                         type=click.FloatRange(min=0), default=0, show_default=True)
@click.option('--S_max', 'S_max',           help='Stoch. max noise level', metavar='FLOAT',                         type=click.FloatRange(min=0), default='inf', show_default=True)
@click.option('--S_noise', 'S_noise',       help='Stoch. noise inflation', metavar='FLOAT',                         type=float, default=1, show_default=True)
##### vanilla #####
@click.option('--pred_type',                help='Model prediction type', metavar='STR',                            type=str, default=None)
##### mom #####
@click.option('--beta_1',                   help='beta for ema 1', metavar='FLOAT',                                 type=float, default=0.9, show_default=True)
##### egg #####
@click.option('--beta_2',                   help='beta for ema 2', metavar='FLOAT',                                 type=float, default=0.9, show_default=True)
@click.option('--gamma',                    help='ema scale for pred_type == egg', metavar='FLOAT',                 type=float, default=0.4, show_default=True)
##### apg #####
@click.option('--eta',                      help='parallel scale for pred_type == apg', metavar='FLOAT',            type=float, default=0.0, show_default=True)
@click.option('--r_scale',                  help='scale for pred_type == apg', metavar='FLOAT',                     type=float, default=2.5, show_default=True)
##### fdg #####
@click.option('--w_low',                    help='low freq guidance scale for pred_type == fdg', metavar='FLOAT',            type=float, default=1.0, show_default=True)
@click.option('--w_high',                   help='high freq guidance scale for pred_type == fdg', metavar='FLOAT',           type=float, default=3.0, show_default=True)
@click.option('--w_par',                   help='high freq guidance scale for pred_type == fdg', metavar='FLOAT',           type=float, default=3.0, show_default=True)
##### ig #####
@click.option('--sigma_start',              help='sigma guidance scale for pred_type == ig', metavar='FLOAT',            type=float, default=0.28, show_default=True)
@click.option('--sigma_end',                help='sigma guidance scale for pred_type == ig', metavar='FLOAT',           type=float, default=2.0, show_default=True)
##### calc fid #####
@click.option('--images', 'image_path',     help='Path to the images', metavar='PATH|ZIP',                  type=str, default=None)
@click.option('--ref', 'ref_path',          help='Dataset reference statistics ', metavar='PKL|NPZ|URL',    type=str, default='https://nvlabs-fi-cdn.nvidia.com/edm2/dataset-refs/img512.pkl')
@click.option('--metrics',                  help='List of metrics to compute', metavar='LIST',              type=parse_metric_list, default='fid,fd_dinov2', show_default=True)
@click.option('--num', 'num_images',        help='Number of images to use', metavar='INT',                  type=click.IntRange(min=2), default=50000, show_default=True)
@click.option('--seed',                     help='Random seed for selecting the images', metavar='INT',     type=int, default=0, show_default=True)
@click.option('--workers', 'num_workers',   help='Subprocesses to use for data loading', metavar='INT',     type=click.IntRange(min=0), default=2, show_default=True)
##### calc prc #####
@click.option('--input1',                               help='First input: directory, registered dataset, or model path', metavar='STR', type=str, default=DEFAULTS['input1'], show_default=False)
@click.option('--input2',                               help='Second input: directory, registered dataset, or model path', metavar='STR', type=str, default=DEFAULTS['input2'], show_default=False)
@click.option('--gpu',                                  help='Use CUDA (overrides CUDA_VISIBLE_DEVICES)', metavar='STR', type=str, default=None, show_default=False)
@click.option('--cpu',                                  help='Use CPU despite capabilities', is_flag=True)
@click.option('--json',                                 help='Print scores in JSON', is_flag=True)
@click.option('--isc',                                  help='Calculate ISC (Inception Score)', is_flag=True)
@click.option('--fid',                                  help='Calculate FID (Frechet Inception Distance)', is_flag=True)
@click.option('--kid',                                  help='Calculate KID (Kernel Inception Distance)', is_flag=True)
@click.option('--prc',                                  help='Calculate PRC (Precision and Recall)', is_flag=True)
@click.option('--ppl',                                  help='Calculate PPL (Perceptual Path Length)', is_flag=True)
@click.option('--feature-extractor',                    help='Name of the feature extractor', metavar='STR', type=str, default=DEFAULTS['feature_extractor'], show_default=False)
@click.option('--feature-layer-isc',                    help='Feature layer to use with ISC', metavar='STR', type=str, default=DEFAULTS['feature_layer_isc'], show_default=False)
@click.option('--feature-layer-fid',                    help='Feature layer to use with FID', metavar='STR', type=str, default=DEFAULTS['feature_layer_fid'], show_default=False)
@click.option('--feature-layer-kid',                    help='Feature layer to use with KID', metavar='STR', type=str, default=DEFAULTS['feature_layer_kid'], show_default=False)
@click.option('--feature-layer-prc',                    help='Feature layer to use with PRC', metavar='STR', type=str, default=DEFAULTS['feature_layer_prc'], show_default=False)
@click.option('--feature-extractor-weights-path',       help='Path to feature extractor weights', metavar='PATH', type=str, default=DEFAULTS['feature_extractor_weights_path'], show_default=False)
@click.option('--feature-extractor-internal-dtype',     help='dtype for feature extractor', metavar='STR', type=click.Choice(['float32', 'float64']), default=DEFAULTS['feature_extractor_internal_dtype'], show_default=False)
@click.option('--feature-extractor-compile',            help='Compile feature extractor (experimental)', is_flag=True)
@click.option('--isc-splits',                           help='Number of splits in ISC', metavar='INT', type=int, default=DEFAULTS['isc_splits'], show_default=False)
@click.option('--kid-subsets',                          help='Number of subsets in KID', metavar='INT', type=int, default=DEFAULTS['kid_subsets'], show_default=False)
@click.option('--kid-subset-size',                      help='Subset size in KID', metavar='INT', type=int, default=DEFAULTS['kid_subset_size'], show_default=False)
@click.option('--kid-kernel',                           help='Kernel type in KID', metavar='STR', type=click.Choice(['poly', 'rbf']), default=DEFAULTS['kid_kernel'], show_default=False)
@click.option('--kid-kernel-poly-degree',               help='Degree of poly kernel', metavar='INT', type=int, default=DEFAULTS['kid_kernel_poly_degree'], show_default=False)
@click.option('--kid-kernel-poly-gamma',                help='Gamma for poly kernel', metavar='FLOAT', type=float, default=DEFAULTS['kid_kernel_poly_gamma'], show_default=False)
@click.option('--kid-kernel-poly-coef0',                help='Coef0 for poly kernel', metavar='FLOAT', type=float, default=DEFAULTS['kid_kernel_poly_coef0'], show_default=False)
@click.option('--kid-kernel-rbf-sigma',                 help='Sigma for RBF kernel', metavar='FLOAT', type=float, default=DEFAULTS['kid_kernel_rbf_sigma'], show_default=False)
@click.option('--ppl-epsilon',                          help='Interpolation step size in PPL', metavar='FLOAT', type=float, default=DEFAULTS['ppl_epsilon'], show_default=False)
@click.option('--ppl-reduction',                        help='PPL reduction type', metavar='STR', type=click.Choice(['mean', 'none']), default=DEFAULTS['ppl_reduction'], show_default=False)
@click.option('--ppl-sample-similarity',                help='Similarity method for PPL', metavar='STR', type=str, default=DEFAULTS['ppl_sample_similarity'], show_default=False)
@click.option('--ppl-sample-similarity-resize',         help='Resize samples in PPL', metavar='INT', type=int, default=DEFAULTS['ppl_sample_similarity_resize'], show_default=False)
@click.option('--ppl-sample-similarity-dtype',          help='Sample dtype check for PPL', metavar='STR', type=str, default=DEFAULTS['ppl_sample_similarity_dtype'], show_default=False)
@click.option('--ppl-discard-percentile-lower',         help='Lower discard percentile', metavar='INT', type=int, default=DEFAULTS['ppl_discard_percentile_lower'], show_default=False)
@click.option('--ppl-discard-percentile-higher',        help='Upper discard percentile', metavar='INT', type=int, default=DEFAULTS['ppl_discard_percentile_higher'], show_default=False)
@click.option('--ppl-z-interp-mode',                    help='Z interpolation mode', metavar='STR', type=str, default=DEFAULTS['ppl_z_interp_mode'], show_default=False)
@click.option('--prc-neighborhood',                     help='Nearest neighbors in PRC', metavar='INT', type=int, default=DEFAULTS['prc_neighborhood'], show_default=False)
@click.option('--prc-batch-size',                       help='Batch size for PRC', metavar='INT', type=int, default=DEFAULTS['prc_batch_size'], show_default=False)
@click.option('--no-samples-shuffle',                   help='Disable sample shuffling', is_flag=True)
@click.option('--samples-find-deep',                    help='Recursive sample search', is_flag=True)
@click.option('--samples-find-ext',                     help='File extensions to find', metavar='STR', type=str, default=DEFAULTS['samples_find_ext'], show_default=False)
@click.option('--samples-ext-lossy',                    help='Lossy extensions warning', metavar='STR', type=str, default=DEFAULTS['samples_ext_lossy'], show_default=False)
@click.option('--samples-resize-and-crop',              help='Resize and crop images', metavar='INT', type=int, default=DEFAULTS['samples_resize_and_crop'], show_default=False)
@click.option('--datasets-root',                        help='Root path for datasets', metavar='PATH', type=str, default=DEFAULTS['datasets_root'], show_default=False)
@click.option('--no-datasets-download',                 help='Disable dataset downloading', is_flag=True)
@click.option('--cache-root',                           help='Cache root path', metavar='PATH', type=str, default=DEFAULTS['cache_root'], show_default=False)
@click.option('--no-cache',                             help='Disable cache usage', is_flag=True)
@click.option('--input1-cache-name',                    help='Cache name for input1', metavar='STR', type=str, default=DEFAULTS['input1_cache_name'], show_default=False)
@click.option('--input2-cache-name',                    help='Cache name for input2', metavar='STR', type=str, default=DEFAULTS['input2_cache_name'], show_default=False)
@click.option('--input1-model-z-type',                  help='Z type for input1 model', metavar='STR', type=str, default=DEFAULTS['input1_model_z_type'], show_default=False)
@click.option('--input1-model-z-size',                  help='Z size for input1 model', metavar='INT', type=int, default=DEFAULTS['input1_model_z_size'], show_default=False)
@click.option('--input1-model-num-classes',             help='Num classes for input1', metavar='INT', type=int, default=DEFAULTS['input1_model_num_classes'], show_default=False)
@click.option('--input1-model-num-samples',             help='Sample count for input1', metavar='INT', type=int, default=DEFAULTS['input1_model_num_samples'], show_default=False)
@click.option('--input2-model-z-type',                  help='Z type for input2 model', metavar='STR', type=str, default=DEFAULTS['input2_model_z_type'], show_default=False)
@click.option('--input2-model-z-size',                  help='Z size for input2 model', metavar='INT', type=int, default=DEFAULTS['input2_model_z_size'], show_default=False)
@click.option('--input2-model-num-classes',             help='Num classes for input2', metavar='INT', type=int, default=DEFAULTS['input2_model_num_classes'], show_default=False)
@click.option('--input2-model-num-samples',             help='Sample count for input2', metavar='INT', type=int, default=DEFAULTS['input2_model_num_samples'], show_default=False)
@click.option('--rng-seed',                             help='RNG seed', metavar='INT', type=int, default=DEFAULTS['rng_seed'], show_default=False)
@click.option('--save-cpu-ram',                         help='Reduce RAM usage at cost of speed', is_flag=True)
@click.option('--silent',                               help='Suppress stderr output', is_flag=True)
##### wandb #####
@click.option('--logwandb',                             help='Log results to WandB', is_flag=True)
@click.option('--wandb-project',                        help='WandB project name', metavar='STR', type=str, default='reg')
################################################################################
def cmdline(preset, **opts):
    """Generate random images using the given model, then evaluate them.

    After generation, the metrics in --metrics are computed for the images in
    --outdir against the reference statistics in --ref. If any of
    --isc/--fid/--kid/--prc/--ppl is given, torch-fidelity metrics are then
    computed between --input1 and --input2. With --logwandb, all results are
    logged to Weights & Biases.

    Examples:

    \b
    # Generate a couple of images and save them as out/*.png
    python generate_images.py --preset=edm2-img512-s-guid-dino --outdir=out

    \b
    # Generate 50000 images using 8 GPUs and save them as out/*/*.png
    torchrun --standalone --nproc_per_node=8 generate_images.py \\
        --preset=edm2-img64-s-fid --outdir=out --subdirs --seeds=0-49999
    """
    # Split the flat click kwargs into per-stage option dicts.
    cmd_args = [
        'net', 'gnet', 'outdir', 'subdirs', 'seeds', 'class_idx',
        'max_batch_size', 'num_steps', 'sigma_min', 'sigma_max', 'rho',
        'guidance', 'S_churn', 'S_min', 'S_max', 'S_noise',
        'pred_type', 'beta_1', 'beta_2', 'gamma', 'eta', 'r_scale',
        'w_low', 'w_high', 'w_par', 'sigma_start', 'sigma_end',
    ]
    calc_args = [
        'image_path', 'num_images', 'seed', 'max_batch_size', 'num_workers',
    ]
    prc_args = [
        'input1', 'input2', 'gpu', 'cpu', 'json', 'isc', 'fid', 'kid',
        'prc', 'ppl', 'feature_extractor', 'feature_layer_isc',
        'feature_layer_fid', 'feature_layer_kid', 'feature_layer_prc',
        'feature_extractor_weights_path', 'feature_extractor_internal_dtype',
        'feature_extractor_compile', 'isc_splits', 'kid_subsets',
        'kid_subset_size', 'kid_kernel', 'kid_kernel_poly_degree',
        'kid_kernel_poly_gamma', 'kid_kernel_poly_coef0',
        'kid_kernel_rbf_sigma', 'ppl_epsilon', 'ppl_reduction',
        'ppl_sample_similarity', 'ppl_sample_similarity_resize',
        'ppl_sample_similarity_dtype', 'ppl_discard_percentile_lower',
        'ppl_discard_percentile_higher', 'ppl_z_interp_mode',
        'prc_neighborhood', 'prc_batch_size', 'no_samples_shuffle',
        'samples_find_deep', 'samples_find_ext', 'samples_ext_lossy',
        'samples_resize_and_crop', 'datasets_root', 'no_datasets_download',
        'cache_root', 'no_cache', 'input1_cache_name', 'input2_cache_name',
        'input1_model_z_type', 'input1_model_z_size', 'input1_model_num_classes',
        'input1_model_num_samples', 'input2_model_z_type',
        'input2_model_z_size', 'input2_model_num_classes',
        'input2_model_num_samples', 'rng_seed', 'save_cpu_ram', 'silent',
    ]
    wandb_args = [
        'logwandb', 'wandb_project',
    ]

    # The metric stage always evaluates exactly the freshly generated images, so
    # --images and --num are overridden here. Count the seeds rather than take
    # their span: seed lists may be non-contiguous (e.g. "1,2,5-10").
    # NOTE: `preset` is a named parameter and is therefore never in `opts`.
    opts['image_path'] = opts.get('outdir', None)
    opts['num_images'] = len(opts['seeds'])
    ref_path = opts.get('ref_path', None)
    fid_metrics = opts.get('metrics', None)

    cmd_opts = {k: opts[k] for k in cmd_args if k in opts}
    calc_opts = {k: opts[k] for k in calc_args if k in opts}
    prc_opts = {k: opts[k] for k in prc_args if k in opts}
    wandb_opts = {k: opts[k] for k in wandb_args if k in opts}

    opts = dnnlib.EasyDict(cmd_opts)

    # Apply preset: fill in only the options the user left unset.
    if preset is not None:
        if preset not in config_presets:
            raise click.ClickException(f'Invalid configuration preset "{preset}"')
        for key, value in config_presets[preset].items():
            if opts.get(key) is None:
                opts[key] = value

    # Validate options.
    if opts.net is None:
        raise click.ClickException('Please specify either --preset or --net')
    if opts.guidance is None or opts.guidance == 1:
        opts.guidance = 1
        opts.gnet = None
    elif opts.gnet is None:
        raise click.ClickException('Please specify --gnet when using guidance')
    ################################################################################
    if opts.pred_type == 'vanilla':
        opts.beta_1 = None
        opts.beta_2 = None
        opts.gamma = None
    elif opts.pred_type == 'mom':
        opts.gamma = None
    ################################################################################

    # Name the run from the final (post-preset, post-validation) options.
    wandb_name = _wandb_run_name(opts)
    log_wandb = wandb_opts['logwandb']

    # Generate.
    dist.init()
    rank = dist.get_rank()

    if log_wandb and rank == 0:
        import wandb
        wandb.init(project=wandb_opts['wandb_project'], name=wandb_name, config=opts)

    ref = None
    if rank == 0:
        ref = load_stats(path=ref_path)  # Load first so a bad --ref fails before generation.

    for _ in tqdm.tqdm(generate_images(**opts), unit='batch', disable=(rank != 0)):
        pass

    # Feature statistics over the generated images; the final item carries the stats.
    last = None
    stats_iter = calculate_stats_for_files(metrics=fid_metrics, **calc_opts)
    for last in tqdm.tqdm(stats_iter, unit='batch', disable=(rank != 0)):
        pass

    fid_results = {}
    if rank == 0:
        fid_results = calculate_metrics_from_stats(stats=last.stats, ref=ref, metrics=fid_metrics)

    torch.distributed.barrier()
    torch.distributed.destroy_process_group()

    # Only rank 0 continues to the single-process torch-fidelity stage.
    if rank != 0:
        sys.exit(0)

    def finish_wandb(extra_metrics):
        # Log the metric-stage results (plus any torch-fidelity metrics) and close the run.
        if not log_wandb:
            return
        combined_metrics = fid_results.copy()
        combined_metrics.update(extra_metrics)
        wandb.log(combined_metrics)
        wandb.finish()

    ################################################################################
    # Early exits still log the already-computed results before exiting.
    if not (prc_opts["isc"] or prc_opts["fid"] or prc_opts["kid"] or prc_opts["ppl"] or prc_opts["prc"]):
        print("No metrics to compute, exiting", file=sys.stderr)
        print("Use 'fidelity --help' to see the command line options", file=sys.stderr)
        finish_wandb({})
        sys.exit(1)

    if prc_opts["input1"] is None and prc_opts["input2"] is None:
        print("No inputs are given, exiting", file=sys.stderr)
        print("Use 'fidelity --help' to see the command line options", file=sys.stderr)
        finish_wandb({})
        sys.exit(1)

    process_deprecations(prc_opts)

    # Translate the negative CLI flags into the positive kwargs torch-fidelity expects.
    prc_opts["verbose"] = not prc_opts["silent"]
    prc_opts["datasets_download"] = not prc_opts["no_datasets_download"]
    prc_opts["samples_shuffle"] = not prc_opts["no_samples_shuffle"]
    prc_opts["cache"] = not prc_opts["no_cache"]

    if prc_opts["gpu"] is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = prc_opts["gpu"]

    prc_opts["cuda"] = not prc_opts["cpu"] and os.environ.get("CUDA_VISIBLE_DEVICES", "") != ""

    if torch.cuda.is_available() and not prc_opts["cuda"]:
        print("CUDA is available but --gpu option is not specified", file=sys.stderr)

    prc_results = calculate_metrics(**prc_opts)

    if prc_opts["json"]:
        print(json.dumps(prc_results, indent=4))
    else:
        print("\n".join(f"{k}: {v:.7g}" for k, v in prc_results.items()))

    ################################################################################

    finish_wandb(prc_results)


    

#----------------------------------------------------------------------------

if __name__ == '__main__':
    # Script entry point: `cmdline` is a click command, so calling it with no
    # arguments makes click read the process's command-line arguments.
    cmdline()

#----------------------------------------------------------------------------
