"""Evaluate distribution-matching accuracy with MMD/FID/KID metrics."""
import argparse
from glob import glob
import os
import sys
sys.path.append('..')

import numpy as np
from pytorch_fid.fid_score import calculate_frechet_distance
from pytorch_fid.inception import InceptionV3
from torch.nn.functional import adaptive_avg_pool2d
from sklearn.metrics.pairwise import polynomial_kernel
import torch
from tqdm import tqdm

import mmd_torch


# Command-line interface for this script.
parser = argparse.ArgumentParser()
# Which results subdirectory / dataset to evaluate.
parser.add_argument(
  '-c', '--constraint', type=str, default='flux', required=False,
  choices=['flux', 'burgers', 'incompress', 'periodic', 'count'])
# Metric to compute. argparse does not validate defaults against `choices`,
# so the default must itself be one of the listed metrics; otherwise no
# evaluation branch in `__main__` would run.
parser.add_argument(
  '-m', '--metric', type=str, default='mmd', required=False,
  choices=['mmd', 'kid', 'fid'])
# Evaluate the MVAE/VAE sample sets instead of the before-FT/FT/DM sets.
parser.add_argument('--vae', default=False, action=argparse.BooleanOptionalAction)
# Torch device for the Inception network (only used by FID/KID).
parser.add_argument('-d', '--device', type=str, default='cpu')

# Root directory holding sample arrays, laid out as RESULTSDIR/<constraint>/...
RESULTSDIR = 'results'

# Number of samples drawn from each set for the respective metric.
MMD_N_SAMPLES = 10_000
KID_N_SAMPLES = 10_000
FID_N_SAMPLES = 50_000

# Dataset name per constraint. Constraints mapping to 'Kolmogorov' read their
# `vor_set*.npy` files (instead of `set*.npy`) in FID/KID evaluation.
DATASET_PER_CONSTRAINT = dict(
  flux='GRMHD',
  burgers='Burgers',
  periodic='Periodic',
  incompress='Kolmogorov',
  kolmogorov='Kolmogorov',
  count='Galaxies',
)

# (low, high) range that data is clipped to and rescaled from when converting
# samples to RGB images for the Inception network.
CLIM_PER_CONSTRAINT = dict(
  flux=(0, 1),
  burgers=(0, 1),
  incompress=(-20, 20),
  periodic=(0, 1),
  count=(0, 1),
)

# Precomputed kernel bandwidths (sigma) per constraint, used in place of
# estimating sigma from the data.
SIGMA_PER_CONSTRAINT = dict(
  flux=2.48111,
  burgers=12.40963,
  incompress=230.86373,
  periodic=32.0,
  count=14.17108,
)


def convert_data_to_rgb_images(data, clim=(0, 1), as_uint8=True):
  """Map a 4-D channels-last array to 3-channel images in a fixed value range.

  Values are clipped to `clim`, linearly rescaled to [0, 1], and the channel
  axis is tiled three times (so single-channel input becomes RGB).

  Args:
    data: 4-D array; the last axis is treated as the channel axis.
    clim: (low, high) clipping range mapped onto [0, 1].
    as_uint8: If True, return `uint8` values in [0, 255] (scaled by 255 and
      truncated); otherwise return `float32` values in [0, 1].

  Returns:
    Array with the same leading axes and three times the channels.
  """
  assert len(data.shape) == 4
  lo, hi = clim[0], clim[1]
  unit = (np.clip(data, lo, hi) - lo) / (hi - lo)
  rgb = np.concatenate([unit] * 3, axis=-1)
  if not as_uint8:
    return rgb.astype(np.float32)
  return (rgb * 255.).astype(np.uint8)
  

def estimate_sigma(data, constraint):
  """Return the kernel bandwidth sigma to use for `constraint`.

  Uses the precomputed value from `SIGMA_PER_CONSTRAINT` when the constraint
  is listed there, and otherwise falls back to
  `mmd_torch.estimate_sigma(data)`.

  Args:
    data: Samples passed to `mmd_torch.estimate_sigma` in the fallback case.
    constraint: Key into `SIGMA_PER_CONSTRAINT`.
  """
  try:
    return SIGMA_PER_CONSTRAINT[constraint]
  except KeyError:
    # Only a missing entry should trigger the data-driven estimate; a bare
    # `except` would also hide unrelated errors and KeyboardInterrupt.
    return mmd_torch.estimate_sigma(data)


def mmd_averages(X_gen, X_real, sigma_list=None, n_subsets=50, subset_size=1000,
                 seed=0):
  """Evaluate MMD between two sample sets on repeated random subsets.

  In each of `n_subsets` rounds, `subset_size` rows are drawn without
  replacement from `X_gen` and from `X_real` (reproducibly, via `seed`), and
  the square root of `mmd_torch.mix_rbf_mmd2` on those rows is recorded.

  Args:
    X_gen: Generated samples, one flattened sample per row; must support
      indexing by an integer index array (e.g. a torch tensor).
    X_real: Real samples in the same layout.
    sigma_list: Bandwidths passed to `mix_rbf_mmd2`. If None, a single
      bandwidth is estimated from each real subset with
      `mmd_torch.estimate_sigma`.
    n_subsets: Number of random subsets.
    subset_size: Rows drawn from each set per subset.
    seed: Seed for the `np.random.RandomState` used to draw subsets.

  Returns:
    NumPy array of shape (n_subsets,) with the per-subset MMD values.
  """
  per_subset_sigma = sigma_list is None
  rng = np.random.RandomState(seed)
  mmds = np.zeros(n_subsets)
  # Estimated bandwidths, tracked only to display their running mean.
  sigs = np.zeros(n_subsets) if per_subset_sigma else None

  with tqdm(range(n_subsets), desc='MMD') as pbar:
    for k in pbar:
      gen_rows = rng.choice(len(X_gen), subset_size, replace=False)
      real_rows = rng.choice(len(X_real), subset_size, replace=False)
      g, r = X_gen[gen_rows], X_real[real_rows]

      if per_subset_sigma:
        sig = mmd_torch.estimate_sigma(r)
        sigs[k] = sig
        bandwidths = [sig]
      else:
        bandwidths = sigma_list
      mmd2 = mmd_torch.mix_rbf_mmd2(g, r, bandwidths, return_var=False)
      mmds[k] = np.sqrt(mmd2.item())

      postfix = {'mean': mmds[:k + 1].mean()}
      if per_subset_sigma:
        postfix['sig'] = sigs[:k + 1].mean()
      pbar.set_postfix(postfix)
  return mmds


def eval_mmd(constraint, vae=False):
  """Print MMD between the real samples and each generated sample set.

  Loads `RESULTSDIR/<constraint>/<subdir>/set_001.npy` for the real samples
  (`real_images`) and for each generated set: `before_ft_images`,
  `ft_images` and `dm_images`, or `mvae_images` and `vae_images` when `vae`
  is True. Each sample is flattened to one row and only the first
  `MMD_N_SAMPLES` rows are kept. MMD is computed by `mmd_averages` over
  random subsets; the mean ± std across subsets is printed per set.

  Args:
    constraint: Subdirectory of `RESULTSDIR` to evaluate (e.g. 'flux').
    vae: If True, evaluate the MVAE/VAE sets instead of before-FT/FT/DM.
  """
  def _load(subdir):
    return np.load(os.path.join(RESULTSDIR, constraint, f'{subdir}/set_001.npy'))

  def _flatten(samples):
    # One row per sample, truncated to the evaluation budget.
    return torch.from_numpy(
      samples.reshape(samples.shape[0], -1)[:MMD_N_SAMPLES])

  real_samples = _load('real_images')
  print('Real samples shape     :', real_samples.shape)
  X_real = _flatten(real_samples)

  # `None` makes `mmd_averages` estimate a bandwidth on every real subset.
  # To use the fixed per-constraint bandwidth instead, pass
  # `[SIGMA_PER_CONSTRAINT[constraint]]`.
  sigma_list = None

  # (subdirectory, shape-print label, result label) per generated set.
  if vae:
    sets = [
      ('mvae_images', 'MVAE samples shape:', 'MVAE'),
      ('vae_images', 'VAE samples shape :', 'VAE '),
    ]
  else:
    sets = [
      ('before_ft_images', 'Before FT samples shape:', 'Before FT'),
      ('ft_images', 'FT samples shape       :', '       FT'),
      ('dm_images', 'DM samples shape       :', '       DM'),
    ]

  # Load every set (reporting shapes) before running any MMD computation.
  X_gens = []
  for subdir, shape_label, _ in sets:
    samples = _load(subdir)
    print(shape_label, samples.shape)
    X_gens.append(_flatten(samples))

  for (_, _, label), X_gen in zip(sets, X_gens):
    mmds = mmd_averages(X_gen, X_real, sigma_list)
    mmd_mean = np.mean(mmds)
    mmd_std = np.std(mmds)
    print(f'{label}: {mmd_mean:.5f} ± {str(np.around(mmd_std, 5))}')


def _get_activation_statistics(acts):
  mu = np.mean(acts, axis=0)
  sigma = np.cov(acts, rowvar=False)
  return  mu, sigma


def _get_activations(paths, model, n_samples, batch_size, total_n_batches, clim,
                     pbar_desc='', dims=2048, device='cuda'):
  """Get Inception activations for the first `n_samples` samples in `paths`.

  Files are read in the order given. Each batch is converted to RGB with
  `convert_data_to_rgb_images(..., clim, as_uint8=False)`, moved to `device`
  in channels-first layout and passed through `model`.

  Args:
    paths: Paths to `.npy` sample arrays of shape (N, H, W) or channels-last
      (N, H, W, C). C is presumably 1, so the RGB conversion yields the
      3 channels Inception expects; TODO confirm.
    model: Network whose first output is a (B, dims, h, w) feature map
      (e.g. a pytorch-fid `InceptionV3`).
    n_samples: Number of activation rows to collect.
    batch_size: Samples per forward pass. The last batch of a file, and the
      final batch overall, may be smaller.
    total_n_batches: Expected number of batches; used only by the progress bar.
    clim: (low, high) range passed to `convert_data_to_rgb_images`.
    pbar_desc: Progress-bar label.
    dims: Activation dimensionality (number of columns of the result).
    device: Torch device the batches are moved to.

  Returns:
    float64 array of shape (n_samples, dims).

  Raises:
    ValueError: if the files in `paths` hold fewer than `n_samples` samples.
  """
  pred_arr = np.empty((n_samples, dims))

  start_idx = 0
  with tqdm(total=total_n_batches, desc=pbar_desc) as pbar:
    for path in paths:
      if start_idx >= n_samples:
        break
      samples = np.load(path)
      if len(samples.shape) == 3:
        # Add a trailing channel axis.
        samples = np.expand_dims(samples, axis=-1)

      for batch_start in range(0, len(samples), batch_size):
        remaining = n_samples - start_idx
        if remaining <= 0:
          break
        # Never take more samples than are still needed.
        chunk = samples[batch_start:batch_start + min(batch_size, remaining)]
        # Convert only this batch (not the whole file) to bound peak memory;
        # the conversion is elementwise, so results are unchanged.
        rgb = convert_data_to_rgb_images(chunk, clim, as_uint8=False)
        batch = torch.from_numpy(rgb).permute((0, 3, 1, 2)).to(device)

        with torch.no_grad():
          pred = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if pred.size(2) != 1 or pred.size(3) != 1:
          pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

        pred = pred.squeeze(3).squeeze(2).cpu().numpy()
        pred_arr[start_idx:start_idx + pred.shape[0]] = pred
        start_idx += pred.shape[0]

        pbar.update(1)

  if start_idx != n_samples:
    raise ValueError(
      f'Needed {n_samples} samples for {pbar_desc!r}, but the given files '
      f'provide only {start_idx}.')
  return pred_arr


def eval_fid(constraint, device, vae=False):
  """Compute and print FID between the real and each generated sample set.

  Borrowed from https://github.com/mseitzer/pytorch-fid.

  Real samples come from `RESULTSDIR/<constraint>/real_images/`; generated
  samples come from `before_ft_images/`, `ft_images/` and `dm_images/`, or
  from `mvae_images/` and `vae_images/` when `vae` is True. Kolmogorov
  constraints read `vor_set*.npy` files, all others read `set*.npy`. Files
  are processed in sorted order, so the first `FID_N_SAMPLES` samples (and
  hence the score) are deterministic.

  Args:
    constraint: Key into `CLIM_PER_CONSTRAINT` / `DATASET_PER_CONSTRAINT`.
    device: Torch device for the Inception network.
    vae: If True, evaluate the MVAE/VAE sets instead of before-FT/FT/DM.

  Raises:
    FileNotFoundError: if a sample directory has no matching files.
  """
  dims = 2048
  batch_size = 50
  block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
  model = InceptionV3([block_idx]).to(device)
  model.eval()
  clim = CLIM_PER_CONSTRAINT[constraint]
  total_n_batches = FID_N_SAMPLES // batch_size
  is_kolmogorov = (DATASET_PER_CONSTRAINT[constraint] == 'Kolmogorov')
  pattern = 'vor_set*.npy' if is_kolmogorov else 'set*.npy'

  def _paths(subdir):
    # Sorted: glob order is filesystem-dependent and would otherwise change
    # which samples fall within the first FID_N_SAMPLES.
    paths = sorted(glob(os.path.join(RESULTSDIR, constraint, f'{subdir}/{pattern}')))
    if not paths:
      raise FileNotFoundError(
        f'No files matching {pattern!r} in '
        f'{os.path.join(RESULTSDIR, constraint, subdir)}')
    return paths

  def _stats(subdir, desc):
    acts = _get_activations(
      _paths(subdir), model, FID_N_SAMPLES, batch_size, total_n_batches, clim,
      pbar_desc=desc, dims=dims, device=device
    )
    return _get_activation_statistics(acts)

  real_mu, real_sigma = _stats('real_images', 'Real')

  # (subdirectory, progress-bar label, printed label) per generated set.
  if vae:
    runs = [
      ('mvae_images', 'MVAE', 'MVAE'),
      ('vae_images', 'VAE', 'VAE '),
    ]
  else:
    runs = [
      ('before_ft_images', 'Before FT', 'Before FT'),
      ('ft_images', 'FT', 'FT       '),
      ('dm_images', 'DM', 'DM       '),
    ]

  for subdir, desc, label in runs:
    gen_mu, gen_sigma = _stats(subdir, desc)
    fid = calculate_frechet_distance(gen_mu, gen_sigma, real_mu, real_sigma)
    print(f'{label}: {fid:.5f}')


def _sqn(arr):
  flat = np.ravel(arr)
  return flat.dot(flat)


def _mmd2_and_variance(K_XX, K_XY, K_YY, unit_diagonal=False,
                       mmd_est='unbiased', block_size=1024,
                       var_at_m=None, ret_var=True):
  """Estimate MMD^2, and optionally its variance, from precomputed kernels.

  Args:
    K_XX: (m, m) kernel matrix within the first sample set X.
    K_XY: (m, m) kernel matrix between X and the second sample set Y.
    K_YY: (m, m) kernel matrix within Y.
    unit_diagonal: If True, treat the diagonals of K_XX and K_YY as all ones
      instead of reading them from the matrices.
    mmd_est: Estimator for MMD^2:
      'biased'      -- averages every entry of K_XX, K_YY and K_XY;
      'unbiased'    -- drops the diagonals of K_XX and K_YY but keeps all
                       of K_XY;
      'u-statistic' -- additionally drops the diagonal (trace) of K_XY.
    block_size: Not used by this implementation.
    var_at_m: Sample size at which the variance is evaluated; defaults to m.
    ret_var: If False, skip the variance computation.

  Returns:
    `mmd2` when `ret_var` is False, otherwise the tuple `(mmd2, var_est)`.

  Raises:
    AssertionError: if the matrices are not all (m, m), or `mmd_est` is not
      one of the supported values.
  """
  # Based on
  # https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py
  # NOTE(review): the adaptation this came from was described as "changed to
  # not compute the full kernel matrix at once", but in this function the full
  # (m, m) matrices are passed in and `block_size` is unused.
  m = K_XX.shape[0]
  assert K_XX.shape == (m, m)
  assert K_XY.shape == (m, m)
  assert K_YY.shape == (m, m)
  if var_at_m is None:
      var_at_m = m

  # Get the various sums of kernels that we'll use.
  # "Kt" denotes a kernel matrix with its diagonal removed; it is never built
  # explicitly -- the diagonal is subtracted from the sums instead.
  if unit_diagonal:
    diag_X = diag_Y = 1
    sum_diag_X = sum_diag_Y = m
    sum_diag2_X = sum_diag2_Y = m
  else:
    diag_X = np.diagonal(K_XX)
    diag_Y = np.diagonal(K_YY)

    sum_diag_X = diag_X.sum()
    sum_diag_Y = diag_Y.sum()

    # Sums of squared diagonal entries.
    sum_diag2_X = _sqn(diag_X)
    sum_diag2_Y = _sqn(diag_Y)

  # Row sums excluding the diagonal (XX, YY); column (axis=0) and row
  # (axis=1) sums of K_XY.
  Kt_XX_sums = K_XX.sum(axis=1) - diag_X
  Kt_YY_sums = K_YY.sum(axis=1) - diag_Y
  K_XY_sums_0 = K_XY.sum(axis=0)
  K_XY_sums_1 = K_XY.sum(axis=1)

  Kt_XX_sum = Kt_XX_sums.sum()
  Kt_YY_sum = Kt_YY_sums.sum()
  K_XY_sum = K_XY_sums_0.sum()

  if mmd_est == 'biased':
    # Diagonals are added back: all m*m pairs, including i == j, are averaged.
    mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
            + (Kt_YY_sum + sum_diag_Y) / (m * m)
            - 2 * K_XY_sum / (m * m))
  else:
    assert mmd_est in {'unbiased', 'u-statistic'}
    # Within-set terms average over the m*(m-1) off-diagonal pairs.
    mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m-1))
    if mmd_est == 'unbiased':
      mmd2 -= 2 * K_XY_sum / (m * m)
    else:
      # U-statistic: also exclude the paired terms k(x_i, y_i).
      mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m-1))

  if not ret_var:
    return mmd2

  # Sums of squared off-diagonal entries (XX, YY) and of all squared K_XY
  # entries.
  Kt_XX_2_sum = _sqn(K_XX) - sum_diag2_X
  Kt_YY_2_sum = _sqn(K_YY) - sum_diag2_Y
  K_XY_2_sum = _sqn(K_XY)

  dot_XX_XY = Kt_XX_sums.dot(K_XY_sums_1)
  dot_YY_YX = Kt_YY_sums.dot(K_XY_sums_0)

  # zeta1/zeta2 estimates and their combination into the variance at sample
  # size `var_at_m`, following the opt-mmd reference above (not re-derived
  # here).
  m1 = m - 1
  m2 = m - 2
  zeta1_est = (
    1 / (m * m1 * m2) * (
        _sqn(Kt_XX_sums) - Kt_XX_2_sum + _sqn(Kt_YY_sums) - Kt_YY_2_sum)
    - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2)
    + 1 / (m * m * m1) * (
        _sqn(K_XY_sums_1) + _sqn(K_XY_sums_0) - 2 * K_XY_2_sum)
    - 2 / m**4 * K_XY_sum**2
    - 2 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
    + 2 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
  )
  zeta2_est = (
    1 / (m * m1) * (Kt_XX_2_sum + Kt_YY_2_sum)
    - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2)
    + 2 / (m * m) * K_XY_2_sum
    - 2 / m**4 * K_XY_sum**2
    - 4 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
    + 4 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
  )
  var_est = (4 * (var_at_m - 2) / (var_at_m * (var_at_m - 1)) * zeta1_est
             + 2 / (var_at_m * (var_at_m - 1)) * zeta2_est)

  return mmd2, var_est


def polynomial_mmd2(codes_g, codes_r, degree=3, gamma=None, coef0=1,
                    var_at_m=None, ret_var=True):
  """MMD^2 between two feature sets under a polynomial kernel.

  The kernel is k(x, y) = (gamma * <x, y> + coef0) ** degree, computed with
  sklearn's `polynomial_kernel` (whose default gamma is 1 / n_features).
  The estimate uses `_mmd2_and_variance` with its default estimator
  ('unbiased').

  Args:
    codes_g: 2-D array, one feature vector per row (first sample set).
    codes_r: 2-D array in the same layout (second sample set).
    degree, gamma, coef0: Polynomial-kernel parameters.
    var_at_m: Sample size for the variance estimate (None means the number
      of rows).
    ret_var: Whether to also return the variance estimate.

  Returns:
    `mmd2`, or `(mmd2, var_est)` when `ret_var` is True.
  """
  def _kernel(a, b=None):
    return polynomial_kernel(a, b, degree=degree, gamma=gamma, coef0=coef0)

  gram_gg = _kernel(codes_g)
  gram_rr = _kernel(codes_r)
  gram_gr = _kernel(codes_g, codes_r)

  return _mmd2_and_variance(
    gram_gg, gram_gr, gram_rr, var_at_m=var_at_m, ret_var=ret_var)


def polynomial_mmd2_averages(codes_g, codes_r, n_subsets=50, subset_size=1000,
                             ret_var=True, seed=0, **kernel_args):
  """Polynomial-kernel MMD^2 on repeated random subsets (the KID estimator).

  In each of `n_subsets` rounds, `subset_size` rows are drawn without
  replacement from `codes_g` and from `codes_r` (reproducibly, via `seed`),
  and `polynomial_mmd2` is evaluated on them.

  Args:
    codes_g: 2-D array of features, one sample per row.
    codes_r: 2-D array of features in the same layout.
    n_subsets: Number of random subsets.
    subset_size: Rows drawn from each set per subset.
    ret_var: Whether to also collect per-subset variance estimates.
    seed: Seed for the `np.random.RandomState` used to draw subsets.
    **kernel_args: Extra keyword arguments for `polynomial_mmd2`
      (`degree`, `gamma`, `coef0`).

  Returns:
    `(mmds, variances)` if `ret_var` else `mmds`; each of shape (n_subsets,).
  """
  # Variances are evaluated at the full sample size, not at `subset_size`.
  m = min(codes_g.shape[0], codes_r.shape[0])
  mmds = np.zeros(n_subsets)
  variances = np.zeros(n_subsets) if ret_var else None
  random_state = np.random.RandomState(seed)

  with tqdm(range(n_subsets), desc='MMD') as pbar:
    for i in pbar:
      g = codes_g[random_state.choice(len(codes_g), subset_size, replace=False)]
      r = codes_r[random_state.choice(len(codes_r), subset_size, replace=False)]
      o = polynomial_mmd2(g, r, **kernel_args, var_at_m=m, ret_var=ret_var)
      if ret_var:
        mmds[i], variances[i] = o
      else:
        # Without the variance, `o` is the MMD^2 estimate itself.
        mmds[i] = o
      pbar.set_postfix({'mean': mmds[:i + 1].mean()})
  return (mmds, variances) if ret_var else mmds


def eval_kid(constraint, device, vae=False):
  """Compute and print KID between the real and each generated sample set.

  Borrowed from https://github.com/abdulfatir/gan-metrics-pytorch/blob/master/kid_score.py.

  Inception activations of the first `KID_N_SAMPLES` samples (files in sorted
  order) are compared with `polynomial_mmd2_averages`. The mean ± std of the
  per-subset MMD^2 values is printed for each generated set.

  Real samples come from `RESULTSDIR/<constraint>/real_images/`; generated
  samples come from `before_ft_images/`, `ft_images/` and `dm_images/`, or
  from `mvae_images/` and `vae_images/` when `vae` is True. Kolmogorov
  constraints read `vor_set*.npy` files, all others read `set*.npy`.

  Args:
    constraint: Key into `CLIM_PER_CONSTRAINT` / `DATASET_PER_CONSTRAINT`.
    device: Torch device for the Inception network.
    vae: If True, evaluate the MVAE/VAE sets instead of before-FT/FT/DM.

  Raises:
    FileNotFoundError: if a sample directory has no matching files.
  """
  dims = 2048
  batch_size = 50
  block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
  model = InceptionV3([block_idx]).to(device)
  model.eval()
  clim = CLIM_PER_CONSTRAINT[constraint]
  total_n_batches = KID_N_SAMPLES // batch_size
  is_kolmogorov = (DATASET_PER_CONSTRAINT[constraint] == 'Kolmogorov')
  pattern = 'vor_set*.npy' if is_kolmogorov else 'set*.npy'

  def _paths(subdir):
    paths = sorted(glob(os.path.join(RESULTSDIR, constraint, f'{subdir}/{pattern}')))
    if not paths:
      raise FileNotFoundError(
        f'No files matching {pattern!r} in '
        f'{os.path.join(RESULTSDIR, constraint, subdir)}')
    return paths

  def _acts(subdir, desc):
    return _get_activations(
      _paths(subdir), model, KID_N_SAMPLES, batch_size, total_n_batches, clim,
      pbar_desc=desc, dims=dims, device=device
    )

  real_acts = _acts('real_images', 'Real')

  # (subdirectory, progress-bar label, printed label) per generated set.
  if vae:
    runs = [
      ('mvae_images', 'MVAE', 'MVAE'),
      ('vae_images', 'VAE', 'VAE '),
    ]
  else:
    runs = [
      ('before_ft_images', 'Before FT', 'Before FT'),
      ('ft_images', 'FT', 'FT       '),
      ('dm_images', 'DM', 'DM       '),
    ]

  for subdir, desc, label in runs:
    mmds, _ = polynomial_mmd2_averages(real_acts, _acts(subdir, desc))
    kid_mean = np.mean(mmds)
    kid_std = np.std(mmds)
    print(f'{label}: {kid_mean:.5f} ± {str(np.around(kid_std, 5))}')


if __name__ == '__main__':
  args = parser.parse_args()

  if args.metric == 'mmd':
    eval_mmd(args.constraint, args.vae)
  elif args.metric == 'kid':
    eval_kid(args.constraint, args.device, args.vae)
  elif args.metric == 'fid':
    eval_fid(args.constraint, args.device, args.vae)
  else:
    # argparse does not check defaults against `choices`, so an unlisted
    # default could reach here; fail loudly instead of exiting silently.
    parser.error(f'unsupported metric: {args.metric!r}')
