''' Tensorflow inception score code
Derived from https://github.com/openai/improved-gan
Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
THIS CODE REQUIRES TENSORFLOW 1.3 or EARLIER to run in PARALLEL BATCH MODE 

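To use this code, run sample.py on your model with --sample_npz, and then
pass the experiment name via --experiment_name.
Pool3 stats are saved to an npz file for FID calculation; this script reads
them from <experiment_root>/<experiment_name>/TF_pool.npz and compares them
with reference stats (FID, and KID when --kid is given).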
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import os
import gzip, pickle
from imageio import imread
from scipy import linalg
import pathlib
import urllib
import warnings

import os.path
import sys
import tarfile
import math
from tqdm import tqdm, trange
from argparse import ArgumentParser
import time

import numpy as np
from six.moves import urllib
# import tensorflow as tf
import tensorflow.compat.v1 as tf

# NOTE(review): MODEL_DIR, DATA_URL and softmax are not referenced anywhere
# else in this file; they look like leftovers from the Inception-score code
# this file was derived from. Confirm no other module imports them before
# removing.
MODEL_DIR = 'model'
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
softmax = None

import os
import sys
import numpy as np
from sklearn.metrics.pairwise import rbf_kernel as polynomial_kernel  


# sess = tf.Session(config=config)

def prepare_parser():
  """Build the command-line parser for the FID/KID evaluation script.

  Returns:
    An ``ArgumentParser`` accepting --experiment_name, --experiment_root,
    --batch_size, --test and --kid.
  """
  usage = 'Parser for TF1.3- Inception Score scripts.'
  parser = ArgumentParser(description=usage)
  parser.add_argument(
    '--experiment_name', type=str, default='',
    help="Which experiment's TF_pool.npz file to pull and evaluate")
  parser.add_argument(
    '--experiment_root', type=str, default='samples',
    help='Default location where samples are stored (default: %(default)s)')
  parser.add_argument(
    '--batch_size', type=int, default=500,
    help='Default overall batchsize (default: %(default)s)')
  parser.add_argument(
    '--test', action='store_true', default=False,
    help='Compare against the reference pool statistics chosen from the '
         'experiment name (anime/face/flower) (default: %(default)s)')
  parser.add_argument(
    '--kid', action='store_true', default=False,
    help='Also compute the Kernel Inception Distance (KID) '
         '(default: %(default)s)')
  return parser


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet distance between two Gaussians.

    For X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) the returned quantity is

        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : Mean vector of the pool_3 activations for the generated samples.
    -- sigma1: Covariance matrix of the pool_3 activations for the generated
               samples.
    -- mu2   : Mean vector of the pool_3 activations, precalculated on a
               representative data set.
    -- sigma2: Covariance matrix of the pool_3 activations, precalculated on
               a representative data set.
    -- eps   : Jitter added to both covariance diagonals if the matrix square
               root of their product is not finite.
    Returns:
    --   : The Frechet distance d^2 defined above.
    Raises:
    -- AssertionError if the means or covariances have mismatched shapes.
    -- ValueError if the matrix square root keeps a non-negligible imaginary
       component on its diagonal.
    """
    mean_a = np.atleast_1d(mu1)
    mean_b = np.atleast_1d(mu2)
    cov_a = np.atleast_2d(sigma1)
    cov_b = np.atleast_2d(sigma2)

    assert mean_a.shape == mean_b.shape, "Training and test mean vectors have different lengths"
    assert cov_a.shape == cov_b.shape, "Training and test covariances have different dimensions"

    delta = mean_a - mean_b

    # The covariance product can be (nearly) singular, in which case sqrtm
    # produces non-finite entries; retry with a small diagonal jitter.
    root, _ = linalg.sqrtm(cov_a.dot(cov_b), disp=False)
    if not np.isfinite(root).all():
        warnings.warn("fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps)
        jitter = eps * np.eye(cov_a.shape[0])
        root = linalg.sqrtm((cov_a + jitter).dot(cov_b + jitter))

    # Rounding may leave a tiny imaginary part; accept it only when the
    # diagonal is numerically real, then keep the real part.
    if np.iscomplexobj(root):
        if not np.allclose(np.diagonal(root).imag, 0, atol=1e-3):
            raise ValueError("Imaginary component {}".format(np.max(np.abs(root.imag))))
        root = root.real

    return delta.dot(delta) + np.trace(cov_a) + np.trace(cov_b) - 2 * np.trace(root)
  
def _handle_path(path, low_profile=False):
  """Load pool statistics from an ``.npz`` archive.

  Args:
    path: Path to an ``.npz`` file that must contain ``pool_mean`` and
      ``pool_var`` and may contain ``pool_activations``.
    low_profile: Accepted for interface compatibility; currently unused.

  Returns:
    Tuple ``(acts, m, s)``:
    - ``acts`` is the ``pool_activations`` array, or ``None`` when the
      archive does not contain it.
    - ``m`` and ``s`` are the ``pool_mean`` and ``pool_var`` arrays.

  Raises:
    KeyError: if ``pool_mean`` or ``pool_var`` is missing.
  """
  # The context manager closes the archive even when a key lookup raises.
  with np.load(path) as f:
    print(f.files)
    acts = None
    if 'pool_activations' in f:
      print("pool" , path)
      acts = f['pool_activations']
    m, s = f['pool_mean'], f['pool_var']
  return acts, m, s

  


def polynomial_mmd_averages(codes_r, codes_g, n_subsets=100, subset_size=1000,
                            ret_var=False, replace=False, rng=None,
                            **kernel_args):
    """Estimate MMD^2 on ``n_subsets`` random subsets of the two code sets.

    Params:
    -- codes_r, codes_g: Arrays whose first axis indexes samples (real and
       generated codes respectively).
    -- n_subsets  : Number of random subsets to evaluate.
    -- subset_size: Number of rows drawn from each set per subset. With
       ``replace=False`` it must not exceed either set's length.
    -- ret_var    : If True, also return per-subset variance estimates.
    -- replace    : Whether rows are drawn with replacement.
    -- rng        : Optional ``numpy.random.Generator`` / ``RandomState``
       used to draw the subsets (for reproducibility). When None the global
       ``np.random`` state is used, as before.
    -- kernel_args: Extra keyword arguments forwarded to ``polynomial_mmd``.
    Returns:
    --   : Array of ``n_subsets`` MMD^2 estimates, or ``(mmds, variances)``
       when ``ret_var`` is True.
    """
    m = min(codes_r.shape[0], codes_g.shape[0])
    mmds = np.zeros(n_subsets)
    if ret_var:
        mmd_vars = np.zeros(n_subsets)
    choice = np.random.choice if rng is None else rng.choice

    for i in range(n_subsets):
        r = codes_r[choice(len(codes_r), subset_size, replace=replace)]
        g = codes_g[choice(len(codes_g), subset_size, replace=replace)]
        # var_at_m=m: the variance formula uses the smaller full-set size
        # rather than the subset size.
        o = polynomial_mmd(r, g, **kernel_args, var_at_m=m, ret_var=ret_var)
        if ret_var:
            mmds[i], mmd_vars[i] = o
        else:
            mmds[i] = o

    return (mmds, mmd_vars) if ret_var else mmds


def polynomial_mmd(codes_r, codes_g, degree=3, gamma=None, coef0=1,
                   var_at_m=None, ret_var=True):
    """MMD^2 (and optionally its variance) between two sets of codes.

    NOTE: despite its name, ``polynomial_kernel`` is imported in this file as
    ``sklearn.metrics.pairwise.rbf_kernel``. The kernel evaluated here is
    therefore k(x, y) = exp(-gamma * ||x - y||^2), where gamma=None means
    sklearn's default of 1 / n_features.

    Params:
    -- codes_r, codes_g: Sample matrices (rows are samples) for the two sets.
    -- degree, coef0   : Accepted for interface compatibility with the
       polynomial-kernel formulation; ignored by the RBF kernel.
    -- gamma           : Kernel bandwidth parameter forwarded to the kernel.
    -- var_at_m        : Sample size used in the variance estimate (forwarded).
    -- ret_var         : If True, also return the variance estimate.
    Returns:
    --   : mmd2, or (mmd2, var_est) when ret_var is True.
    """
    X = codes_r
    Y = codes_g
    # gamma was previously ignored; None keeps the kernel's default.
    K_XX = polynomial_kernel(X, X, gamma=gamma)
    K_YY = polynomial_kernel(Y, Y, gamma=gamma)
    K_XY = polynomial_kernel(X, Y, gamma=gamma)

    return _mmd2_and_variance(K_XX, K_XY, K_YY,
                              var_at_m=var_at_m, ret_var=ret_var)


def _sqn(arr):
    """Return the sum of the squares of every element of ``arr``."""
    values = np.ravel(arr)
    return np.dot(values, values)


def _mmd2_and_variance(K_XX, K_XY, K_YY, unit_diagonal=False,
                       mmd_est='unbiased', block_size=1024,
                       var_at_m=None, ret_var=True):
    """Estimate MMD^2 (and optionally its variance) from kernel matrices.

    Params:
    -- K_XX, K_XY, K_YY: Kernel matrices within the first set, across the
       two sets, and within the second set. All three must be (m, m), i.e.
       both sets have the same size m (checked by the asserts below).
    -- unit_diagonal: If True, assume K_XX and K_YY have ones on the
       diagonal instead of reading the diagonals from the matrices.
    -- mmd_est: 'biased', 'unbiased' (default) or 'u-statistic'.
       'unbiased' drops the K_XX/K_YY diagonals but keeps every K_XY term;
       'u-statistic' additionally drops the K_XY diagonal.
    -- block_size: Not used by this implementation.
    -- var_at_m: Sample size plugged into the variance formula; defaults to m.
    -- ret_var: If True, also compute and return the variance estimate.
    Returns:
    --   : mmd2, or (mmd2, var_est) when ret_var is True.
    Raises:
    -- AssertionError on mismatched shapes or an unknown mmd_est.
    """
    # based on
    # https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py
    # but changed to not compute the full kernel matrix at once
    # NOTE(review): this function receives the full kernel matrices and
    # block_size is unused, so the line above does not describe this code.
    m = K_XX.shape[0]
    assert K_XX.shape == (m, m)
    assert K_XY.shape == (m, m)
    assert K_YY.shape == (m, m)
    if var_at_m is None:
        var_at_m = m

    # Get the various sums of kernels that we'll use
    # Kts drop the diagonal, but we don't need to compute them explicitly
    if unit_diagonal:
        diag_X = diag_Y = 1
        sum_diag_X = sum_diag_Y = m
        sum_diag2_X = sum_diag2_Y = m
    else:
        diag_X = np.diagonal(K_XX)
        diag_Y = np.diagonal(K_YY)

        sum_diag_X = diag_X.sum()
        sum_diag_Y = diag_Y.sum()

        sum_diag2_X = _sqn(diag_X)
        sum_diag2_Y = _sqn(diag_Y)

    # Row sums with the diagonal removed ("Kt" = K with zeroed diagonal).
    Kt_XX_sums = K_XX.sum(axis=1) - diag_X
    Kt_YY_sums = K_YY.sum(axis=1) - diag_Y
    K_XY_sums_0 = K_XY.sum(axis=0)
    K_XY_sums_1 = K_XY.sum(axis=1)

    Kt_XX_sum = Kt_XX_sums.sum()
    Kt_YY_sum = Kt_YY_sums.sum()
    K_XY_sum = K_XY_sums_0.sum()

    if mmd_est == 'biased':
        mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
                + (Kt_YY_sum + sum_diag_Y) / (m * m)
                - 2 * K_XY_sum / (m * m))
    else:
        assert mmd_est in {'unbiased', 'u-statistic'}
        mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m-1))
        if mmd_est == 'unbiased':
            mmd2 -= 2 * K_XY_sum / (m * m)
        else:
            mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m-1))

    if not ret_var:
        return mmd2

    # Sums of squared kernel entries (diagonals excluded for K_XX / K_YY).
    Kt_XX_2_sum = _sqn(K_XX) - sum_diag2_X
    Kt_YY_2_sum = _sqn(K_YY) - sum_diag2_Y
    K_XY_2_sum = _sqn(K_XY)

    dot_XX_XY = Kt_XX_sums.dot(K_XY_sums_1)
    dot_YY_YX = Kt_YY_sums.dot(K_XY_sums_0)

    # zeta1_est / zeta2_est are the two variance components; they are
    # combined with var_at_m as the sample size in var_est below.
    m1 = m - 1
    m2 = m - 2
    zeta1_est = (
        1 / (m * m1 * m2) * (
            _sqn(Kt_XX_sums) - Kt_XX_2_sum + _sqn(Kt_YY_sums) - Kt_YY_2_sum)
        - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2)
        + 1 / (m * m * m1) * (
            _sqn(K_XY_sums_1) + _sqn(K_XY_sums_0) - 2 * K_XY_2_sum)
        - 2 / m**4 * K_XY_sum**2
        - 2 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
        + 2 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
    )
    zeta2_est = (
        1 / (m * m1) * (Kt_XX_2_sum + Kt_YY_2_sum)
        - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2)
        + 2 / (m * m) * K_XY_2_sum
        - 2 / m**4 * K_XY_sum**2
        - 4 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
        + 4 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
    )
    var_est = (4 * (var_at_m - 2) / (var_at_m * (var_at_m - 1)) * zeta1_est
               + 2 / (var_at_m * (var_at_m - 1)) * zeta2_est)

    return mmd2, var_est




def compute_kid(real_acts, fake_acts):
    """Return the KID estimate for two activation sets.

    The estimate is the mean of the per-subset MMD^2 values produced by
    ``polynomial_mmd_averages``. Subsets are drawn with replacement, and its
    default subset count and size are used.
    """
    subset_mmds = polynomial_mmd_averages(real_acts, fake_acts, replace=True)
    return subset_mmds.mean()

def calculate_fid_given_paths(paths, low_profile=False, kid=False):
  """Compute FID (and optionally KID) between two saved pool-statistics files.

  Args:
    paths: Sequence whose first two entries are ``.npz`` paths. ``paths[0]``
      is loaded as the real statistics and ``paths[1]`` as the generated ones.
    low_profile: Forwarded to ``_handle_path``.
    kid: If True, also compute KID from the stored ``pool_activations``.

  Returns:
    ``(fid_value, kid_value)``; ``kid_value`` is 0.0 when ``kid`` is False.

  Raises:
    ValueError: if ``kid`` is True but either file has no
      ``pool_activations`` array.
  """
  print(paths[0])
  real_acts , m1, s1 = _handle_path(paths[0], low_profile=low_profile)
  print(paths[1])
  fake_acts , m2, s2 = _handle_path(paths[1], low_profile=low_profile)
  # _handle_path yields None activations when the archive lacks
  # 'pool_activations'. KID cannot be computed without them, so fail with a
  # clear message instead of an AttributeError deep inside the MMD code.
  if kid:
    missing = [str(p) for p, acts in zip(paths[:2], (real_acts, fake_acts))
               if acts is None]
    if missing:
      raise ValueError("KID requires 'pool_activations' in: %s"
                       % ', '.join(missing))
  fid_value = calculate_frechet_distance(m1, s1, m2, s2)
  kid_value = compute_kid(real_acts, fake_acts) if kid else 0.0
  return fid_value , kid_value


      
def _reference_stats_path(experiment_name):
  """Return the reference pool-statistics file for ``experiment_name``.

  The first dataset keyword found in the lower-cased name wins, checked in
  the order anime, face, flower.

  Raises:
    ValueError: if the name contains none of the known dataset keywords.
  """
  lowered = experiment_name.lower()
  for keyword, path in (('anime', './TF_POOL/anime/TF_pool.npz'),
                        ('face', './TF_POOL/face/TF_pool.npz'),
                        ('flower', './TF_POOL/flower/TF_pool_new.npz')):
    if keyword in lowered:
      return path
  raise ValueError(
    'No reference statistics for experiment %r; its name must contain one '
    'of: anime, face, flower' % experiment_name)


def run(config):
  """Load an experiment's pool statistics and print its FID (and KID).

  Args:
    config: dict as produced by ``vars(prepare_parser().parse_args())``.
      Reads 'experiment_root', 'experiment_name', 'test' and 'kid'.

  Raises:
    ValueError: if 'test' is False (no reference statistics are selected)
      or the experiment name matches no known reference dataset.
  """
  fname = '%s/%s/TF_pool.npz' % (config['experiment_root'], config['experiment_name'])
  print('loading %s ...'%fname)
  t0 = time.time()
  # Reference statistics are only chosen in --test mode; without them there
  # is nothing to compare against (this used to surface as a NameError).
  if not config['test']:
    raise ValueError('No reference statistics selected; rerun with --test')
  print("doing fid test")
  fname2 = _reference_stats_path(config['experiment_name'])

  fid , kid = calculate_fid_given_paths([fname , fname2] , kid=config['kid'])
  t1 = time.time()

  print('FID took %3f seconds, score of %3f , %3f '%(t1-t0, fid , kid))

def main():
  """Entry point: parse the command line, echo the config, and evaluate."""
  args = prepare_parser().parse_args()
  config = vars(args)
  print(config)
  run(config)


if __name__ == '__main__':
  main()