"""Basic nonlinear transform coder for RGB images.

This is a close approximation of the image compression model of
Balle, Laparra, Simoncelli (2017):
End-to-end optimized image compression
https://arxiv.org/abs/1611.01704

With patches from Victor Xing <victor.t.xing@gmail.com>
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
import numpy as np
import tensorflow as tf
import tensorflow_compression as tfc
import argparse,shutil
import os,sys,glob,scipy.misc
import matplotlib.pyplot as plt
import ms_ssim
import inputpipeline
from metric import Psnr, msssim
import math
import ssim_matrix

def analysis_transform(tensor, num_filters):
    """Builds the analysis (encoder) transform.

    Four stride-2 convolutions are applied to ``tensor``; the first three are
    followed by GDN. The result is flattened and fed to the dense encoder
    selected by ``args.activation``.

    Args:
      tensor: input image batch; presumably NHWC floats (TODO confirm against
        the input pipeline).
      num_filters: number of output channels of every convolution.

    Returns:
      A ``(z, mean, sigma)`` tuple of ``(batch, args.z)`` tensors.
      - ``'softplus'``: variational encoder. ``mean`` and ``sigma`` are
        separate dense heads, ``sigma`` is used as a log-variance, and
        ``z = mean + exp(sigma / 2) * eps`` with ``eps ~ N(0, 1)``.
      - ``'sigmoid'`` / ``'None'``: deterministic encoders. ``z`` and
        ``mean`` are the same tensor and ``sigma`` is a zero log-variance.

    Raises:
      ValueError: if ``args.activation`` is not 'softplus', 'sigmoid' or
        'None'.
    """
    if args.activation not in ('softplus', 'sigmoid', 'None'):
        raise ValueError(
            "unsupported activation %r; expected 'softplus', 'sigmoid' or "
            "'None'" % (args.activation,))

    # (kernel size, use_bias, followed by GDN); every layer downsamples by 2.
    conv_specs = [((9, 9), True, True),
                  ((5, 5), True, True),
                  ((5, 5), True, True),
                  ((5, 5), False, False)]

    with tf.variable_scope("analysis"):
        for index, (kernel, use_bias, use_gdn) in enumerate(conv_specs):
            # Scope names layer_0..layer_3 must stay stable for checkpoints.
            with tf.variable_scope("layer_%d" % index):
                layer = tfc.SignalConv2D(
                    num_filters, kernel, corr=True, strides_down=2,
                    padding="same_zeros", use_bias=use_bias,
                    activation=tfc.GDN() if use_gdn else None)
                tensor = layer(tensor)

        with tf.variable_scope('reshape'):
            tensor = tf.layers.flatten(tensor)

        with tf.variable_scope('encoder'):
            if args.activation == 'softplus':
                hidden = tf.nn.softplus(tf.layers.dense(tensor, args.dim1))
                mean = tf.layers.dense(hidden, args.z)
                # Used as a log-variance by the reparameterization below.
                sigma = tf.layers.dense(hidden, args.z)
                eps = tf.random_normal(tf.shape(mean), dtype=tf.float32,
                                       mean=0., stddev=1.0, name='epsilon')
                # Reparameterization trick: the sample stays differentiable
                # with respect to mean and sigma.
                z = mean + tf.exp(sigma / 2) * eps
            elif args.activation == 'sigmoid':
                hidden = tf.nn.sigmoid(tf.layers.dense(tensor, args.dim1))
                z = mean = tf.layers.dense(hidden, args.z)
                sigma = tf.zeros_like(mean)
            else:  # 'None'
                z = mean = tf.layers.dense(tensor, args.z)
                sigma = tf.zeros_like(mean)
        return z, mean, sigma


def synthesis_transform(tensor, num_filters):
    """Builds the synthesis (decoder) transform.

    The latent batch goes through the dense decoder selected by
    ``args.activation``. It is reshaped to a 4x4 map with ``num_filters``
    channels and then upsampled by four stride-2 transposed convolutions to a
    3-channel image. Variables live under the "synthesis" scope with
    ``AUTO_REUSE``, so repeated calls share one decoder.

    Args:
      tensor: latent batch, presumably ``(batch, args.z)``.
      num_filters: channel count of the first two transposed convolutions;
        the third uses ``num_filters // 2``.

    Returns:
      The reconstructed image batch (NHWC, 3 channels, no output
      activation).

    Raises:
      ValueError: if ``args.activation`` is not 'softplus', 'sigmoid' or
        'None'.
    """
    if args.activation not in ('softplus', 'sigmoid', 'None'):
        raise ValueError(
            "unsupported activation %r; expected 'softplus', 'sigmoid' or "
            "'None'" % (args.activation,))

    fc_units = 4 * 4 * num_filters
    with tf.variable_scope("synthesis", reuse=tf.AUTO_REUSE):
        with tf.variable_scope('decoder'):
            if args.activation == 'sigmoid':
                tensor = tf.nn.sigmoid(tf.layers.dense(tensor, args.dim1))
                tensor = tf.layers.dense(tensor, fc_units)
            elif args.activation == 'softplus':
                tensor = tf.nn.softplus(tf.layers.dense(tensor, args.dim1))
                tensor = tf.layers.dense(tensor, fc_units)
                if args.ac2 == 'True':
                    tensor = tf.nn.softplus(tensor)
            else:  # 'None'
                tensor = tf.layers.dense(tensor, fc_units)

        with tf.variable_scope('reshape'):
            tensor = tf.reshape(tensor, [-1, 4, 4, num_filters])

        # (output channels, kernel size, followed by inverse GDN); every
        # layer upsamples by 2. Scope names layer_0..layer_3 must stay stable
        # for checkpoints.
        conv_specs = [(num_filters, (5, 5), True),
                      (num_filters, (5, 5), True),
                      (num_filters // 2, (5, 5), True),
                      (3, (9, 9), False)]
        for index, (filters, kernel, use_igdn) in enumerate(conv_specs):
            with tf.variable_scope("layer_%d" % index):
                layer = tfc.SignalConv2D(
                    filters, kernel, corr=False, strides_up=2,
                    padding="same_zeros", use_bias=True,
                    activation=tfc.GDN(inverse=True) if use_igdn else None)
                tensor = layer(tensor)

        return tensor

def quantize_image(image):
    """Converts a float image tensor to uint8 for image summaries.

    The input (presumably in [0, 1]) is scaled by 255 and rounded. The
    saturating cast clamps out-of-range values to [0, 255] instead of
    letting them wrap around.
    """
    rounded = tf.round(image * 255)
    return tf.saturate_cast(rounded, tf.uint8)


def _write_params_log(log_name):
    """Dumps every parsed command-line argument to ``log_name`` and stdout.

    An existing file at ``log_name`` is removed first, and the removal is
    reported.
    """
    if os.path.exists(log_name):
        print('remove file:%s' % log_name)
        os.remove(log_name)
    with open(log_name, 'w') as params:
        for arg in vars(args):
            str_ = '%s: %s.\n' % (arg, getattr(args, arg))
            print(str_)
            params.write(str_)


def _tile_images(batch, num, patch_size):
    """Tiles the first ``num * num`` images of ``batch`` into a square grid.

    Image ``num * i + j`` is placed at grid row ``i``, column ``j``. Each
    image must be ``patch_size`` x ``patch_size`` x 3. Returns a float64
    array of shape ``(num * patch_size, num * patch_size, 3)``.
    """
    grid = np.zeros([num * patch_size, num * patch_size, 3])
    for i in range(num):
        for j in range(num):
            grid[i * patch_size:(i + 1) * patch_size,
                 j * patch_size:(j + 1) * patch_size, :] = batch[num * i + j]
    return grid


def _distortion(kind, reference, prediction):
    """Builds the scalar distortion ``kind`` between two image batches.

    'mse2' is the mean squared error on the images' native scale. 'myssim'
    averages ``ssim_matrix.ssim`` evaluated on the 0..255 scale. The caller
    validates ``kind``.
    """
    if kind == 'mse2':
        return tf.reduce_mean(tf.squared_difference(reference, prediction))
    return tf.reduce_mean(ssim_matrix.ssim(
        reference * 255, (reference - prediction) * 255, prediction,
        max_val=255, mode='train', compensation=1))


def train():
    """Trains the VAE and periodically logs losses and reconstructions.

    The objective is ``lambda1 * D1 + KL``. D1 (``args.loss1``) compares the
    input batch with its reconstruction. Unless ``args.split == 'None'``:
    - the reconstruction used for D1 is decoded from the posterior mean, and
    - an extra ``lambda2 * D2`` term (``args.loss2``) compares it with the
      reconstruction decoded from the sampled latent.

    Side effects: writes params.log, VAE_log.csv, VAE_log_ssim_.csv,
    reconstruction PNGs, summaries and checkpoints under
    ``args.checkpoint_dir``.

    Raises:
      ValueError: if ``args.loss1`` (or, when splitting, ``args.loss2``) is
        not 'mse2' or 'myssim'.
    """
    supported_losses = ('mse2', 'myssim')
    split = args.split != 'None'
    # Fail before touching the file system or building the graph.
    if args.loss1 not in supported_losses:
        raise ValueError('error invalid loss1: %r' % (args.loss1,))
    if split and args.loss2 not in supported_losses:
        raise ValueError('error invalid loss2: %r' % (args.loss2,))

    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)
    _write_params_log(os.path.join(args.checkpoint_dir, 'params.log'))
    tf.logging.set_verbosity(tf.logging.INFO)

    # Every data set is read as a flat directory of PNG files (the same
    # convention plot_analysis() and sample_image() use).
    data_glob = args.img_path + '/*.png'
    print(data_glob)

    ip_train = inputpipeline.InputPipeline(
        inputpipeline.get_dataset(data_glob),
        args.patch_size, batch_size=args.batch_size,
        shuffle=True,
        num_preprocess_threads=6,
        num_crops_per_img=6)
    X = ip_train.get_batch()

    # Construct model.
    encoder_op, mean, var = analysis_transform(X, 64)
    if split:
        # D1 uses the posterior-mean reconstruction; D2 compares it with the
        # reconstruction of the sampled latent.
        X_pred = synthesis_transform(mean, 64)
        X_pred2 = synthesis_transform(encoder_op, 64)
    else:
        X_pred = synthesis_transform(encoder_op, 64)

    # Monitoring metrics on the 0..255 scale.
    mse_loss = tf.reduce_mean(tf.squared_difference(255 * X, 255 * X_pred))
    msssim_loss = ms_ssim.MultiScaleSSIM(X * 255, X_pred * 255, data_format='NHWC')

    d1 = _distortion(args.loss1, X, X_pred)

    # KL(q(z|x) || N(0, I)), treating ``var`` as a log-variance.
    kl_div_loss = -0.5 * tf.reduce_sum(
        1 + var - tf.square(mean) - tf.exp(var), 1)
    kl_div_loss = tf.reduce_mean(kl_div_loss)

    if split:
        d2 = _distortion(args.loss2, X_pred, X_pred2)
        train_loss = args.lambda1 * d1 + args.lambda2 * d2 + kl_div_loss
    else:
        train_loss = args.lambda1 * d1 + kl_div_loss

    step = tf.train.create_global_step()

    # Fine-tuning starts from a checkpoint, so use a smaller learning rate.
    learning_rate = 0.00001 if args.finetune != 'None' else 0.0001
    main_lr = tf.train.AdamOptimizer(learning_rate)
    optimizer = main_lr.minimize(train_loss, global_step=step)

    tf.summary.scalar("loss", train_loss)
    tf.summary.scalar("mse", mse_loss)
    logged_tensors = [
        tf.identity(train_loss, name="train_loss"),
        tf.identity(msssim_loss, name="ms-ssim")
    ]
    tf.summary.image("original", quantize_image(X))
    tf.summary.image("reconstruction", quantize_image(X_pred))

    hooks = [
        tf.train.StopAtStepHook(last_step=args.num_steps),
        tf.train.NanTensorHook(train_loss),
        tf.train.LoggingTensorHook(logged_tensors, every_n_secs=60),
        tf.train.SummarySaverHook(save_steps=args.save_steps, summary_op=tf.summary.merge_all()),
        tf.train.CheckpointSaverHook(save_steps=args.save_steps, checkpoint_dir=args.checkpoint_dir)
    ]

    # uint8 versions of the reconstruction and input for saving to disk.
    X_rec = tf.cast(tf.round(tf.clip_by_value(X_pred, 0, 1) * 255), tf.uint8)
    X_ori = tf.cast(tf.round(tf.clip_by_value(X, 0, 1) * 255), tf.uint8)

    init_fn_ae = None
    if args.finetune != 'None':
        init_fn_ae = tf.contrib.framework.assign_from_checkpoint_fn(
            args.finetune, tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))

    # Tag embedded in output file names; currently empty.
    parameter = ""

    fetches = {'train_op': optimizer, 'loss': train_loss, 'd1': d1,
               'kl': kl_div_loss}
    if split:
        fetches['d2'] = d2
    # Images are copied to host only on steps that save them. They come from
    # the same run so they match the batch just trained on.
    fetches_with_images = dict(fetches, rec=X_rec, ori=X_ori)

    train_count = 0
    with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
        if init_fn_ae is not None:
            init_fn_ae(sess)
            print('load from %s' % (args.finetune))
        while not sess.should_stop():
            step_num = train_count + 1
            save_now = step_num % args.save_steps == 0
            out = sess.run(fetches_with_images if save_now else fetches)

            if step_num % args.display_steps == 0:
                if split:
                    line = '%d,loss,%f, kl,%f, d1,%f, d2,%f\n' % (
                        step_num, out['loss'], out['kl'], out['d1'], out['d2'])
                else:
                    line = '%d,loss,%f, kl,%f, d1,%f\n' % (
                        step_num, out['loss'], out['kl'], out['d1'])
                with open('%s/VAE_log.csv' % (args.checkpoint_dir), 'a') as f_log:
                    f_log.write(line)
                print(line)

            if save_now:
                # int() keeps np.zeros happy even where math.floor returns a float.
                num = int(math.floor(math.sqrt(out['rec'].shape[0])))
                show_img = _tile_images(out['rec'], num, args.patch_size)
                ori_img = _tile_images(out['ori'], num, args.patch_size)
                save_name = os.path.join(
                    args.checkpoint_dir, 'VAE_rec_%s_%s.png' % (parameter, step_num))
                scipy.misc.imsave(save_name, show_img)
                psnr_ = Psnr(ori_img, show_img)
                msssim_ = msssim(ori_img, show_img)
                msssim_db = -10 * np.log10(1 - msssim_)
                print("PSNR (dB), %.2f,Multiscale SSIM, %.4f,Multiscale SSIM (dB), %.2f" % (
                    psnr_, msssim_, msssim_db))
                ssim_log = '%s/VAE_log_ssim_%s.csv' % (args.checkpoint_dir, parameter)
                with open(ssim_log, 'a') as f_log_ssim:
                    f_log_ssim.write(
                        '%s,%d,PSNR (dB), %.2f,Multiscale SSIM, %.4f,Multiscale SSIM (dB), %.2f\n' % (
                            parameter, step_num, psnr_, msssim_, msssim_db))
            train_count += 1

def read_img(n):
    """Loads the first ``n * n`` PNGs of ``args.img_path`` as an n x n grid.

    Files are taken in sorted-name order. Each image is center-cropped to a
    square and resized with ``scipy.misc.imresize`` to
    ``args.patch_size`` x ``args.patch_size``. The tiled grid is also
    written to ``<checkpoint_dir>/clic_rdae_ori.png``.

    Args:
      n: grid side length; ``n * n`` images are read.

    Returns:
      ``(imgs, show_img)``:
      - ``imgs``: float64 array of shape
        ``(n * n, patch_size, patch_size, 3)``.
      - ``show_img``: the ``(n * patch_size, n * patch_size, 3)`` tiled grid.

    Raises:
      ValueError: if fewer than ``n * n`` PNG files are found.
    """
    images = sorted(glob.glob(args.img_path + '/*.png'))
    if len(images) < n * n:
        raise ValueError('need %d images in %s, found %d'
                         % (n * n, args.img_path, len(images)))

    size = args.patch_size
    imgs = np.zeros([n * n, size, size, 3])
    for i, img_p in enumerate(images[:n * n]):
        img = scipy.misc.imread(img_p).astype(float)
        h, w = img.shape[:2]
        # Crop the central square. Slicing by the short side keeps the crop
        # square even when the size difference is odd.
        if h > w:
            top = (h - w) // 2
            crop = img[top:top + w, :, :]
        else:
            left = (w - h) // 2
            crop = img[:, left:left + h, :]
        imgs[i, :, :, :] = scipy.misc.imresize(crop, [size, size])

    show_img = np.zeros([n * size, n * size, 3])
    for i in range(n):
        for j in range(n):
            show_img[i * size:(i + 1) * size, j * size:(j + 1) * size,
                     :] = imgs[n * i + j, :, :, :]
    save_name = os.path.join(args.checkpoint_dir, 'clic_rdae_ori.png')
    scipy.misc.imsave(save_name, show_img)

    return imgs, show_img


def geo_mean_overflow(iterable):
    """Returns the geometric mean of ``iterable`` without overflowing.

    Multiplying many values directly can overflow (or underflow) a float, so
    the mean is taken in log space as ``exp(mean(log(x)))``. Every element is
    used, whatever the shape of the input.

    Args:
      iterable: array-like of positive numbers.

    Returns:
      The geometric mean as a NumPy float64.

    Raises:
      ValueError: if ``iterable`` contains no elements.
    """
    logs = np.log(np.asarray(iterable, dtype=np.float64))
    if logs.size == 0:
        raise ValueError('geo_mean_overflow() requires at least one value')
    return np.exp(logs.mean())

def read_png(filename):
    """Reads and decodes the image at ``filename`` into a float32 tensor.

    Decoding is forced to three (RGB) channels. Pixel values are cast to
    float32 and divided by 255.
    """
    encoded = tf.read_file(filename)
    pixels = tf.cast(tf.image.decode_image(encoded, channels=3), tf.float32)
    return pixels / 255

def _plot_sorted_variances(var_sorted, fig_name):
    """Bar-plots descending ``var_sorted`` with its normalized cumulative sum.

    The coding gain, 10 * log10(arithmetic mean / geometric mean) of the
    values, is written on the figure. The figure is saved to ``fig_name``
    and then closed.
    """
    cumulative = var_sorted.cumsum() / var_sorted.sum()
    coding_gain = 10 * np.log10(np.mean(var_sorted) / geo_mean_overflow(var_sorted))
    x_axis = np.arange(0, var_sorted.shape[0], 1)

    fig, ax1 = plt.subplots()
    fig.text(0.2, 0.8, 'coding_gain=%.4f dB' % coding_gain)
    ax2 = ax1.twinx()
    ax1.bar(x_axis, var_sorted)
    ax2.plot(x_axis, cumulative, color="red")
    ax1.set_xlabel('z(sorted by variance(descending order))')
    ax1.set_ylabel('variance of z')
    ax2.set_ylabel('Cumulative sum of variance')
    fig.savefig(fig_name)
    plt.close(fig)


def plot_analysis():
    """Encodes the training set with the latest checkpoint and analyses z.

    Every PNG in ``args.img_path`` is read once (random ``patch_size`` crop,
    shuffled order) and passed through the encoder. Two kinds of statistics
    are produced:
    - Per-dimension std of the sampled latent ``z``, and its mean.
    - Per-dimension mean of the precision ``exp(-var)`` of the posterior,
      with ``var`` treated as a log-variance.

    Outputs under ``args.checkpoint_dir`` (with the empty name tag):
    - VAE_variance_df_.png, std_df_.npy, std_index_.npy, mean_.npy
    - VAE_variance_df_inv_sigma_.png, VAE_variance_df_inv_sigma2_.csv
    - sigma_.npy, sigma2_.npy, sigma2_sorted_.npy, sigma_index_.npy

    Raises:
      RuntimeError: if no images are found or no batch could be encoded.
    """
    preprocess_threads = 6
    train_path = args.img_path + '/*.png'
    train_files = glob.glob(train_path)
    if not train_files:
        raise RuntimeError(
            "No training images found with glob '{}'.".format(train_path))
    train_dataset = tf.data.Dataset.from_tensor_slices(train_files)
    train_dataset = train_dataset.shuffle(buffer_size=len(train_files))
    train_dataset = train_dataset.map(
        read_png, num_parallel_calls=preprocess_threads)
    train_dataset = train_dataset.map(
        lambda x: tf.random_crop(x, (args.patch_size, args.patch_size, 3)))
    train_dataset = train_dataset.batch(args.batch_size)
    train_dataset = train_dataset.prefetch(32)
    x = train_dataset.make_one_shot_iterator().get_next()

    # Only the encoder is needed here; the decoder is not built or run.
    encoder_op, _, var = analysis_transform(x, 64)

    # Number of files and of full batches.
    print(len(train_files), len(train_files) // args.batch_size)

    # Tag embedded in output file names; currently empty.
    parameter = ""

    def out_path(name):
        return os.path.join(args.checkpoint_dir, name % (parameter,))

    z_batches = []
    precision_batches = []
    with tf.Session() as sess:
        latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        try:
            while True:
                z, var_ = sess.run([encoder_op, var])
                z_batches.append(z)
                # var is a log-variance, so exp(-var) == 1 / sigma^2.
                precision_batches.append(np.exp(-var_))
        except tf.errors.OutOfRangeError:
            print('end!')

    if not z_batches:
        raise RuntimeError('No batches were encoded; nothing to analyse.')
    # One concatenation instead of re-stacking on every batch (quadratic).
    zs = np.vstack(z_batches)
    vars_ = np.vstack(precision_batches)

    # ---- statistics of the sampled latent z ----
    print('zs', zs.shape)
    stds = np.squeeze(np.std(zs, axis=0))
    means = np.squeeze(np.mean(zs, axis=0))
    print('stds', stds.shape, 'means', means.shape)

    std_index = np.argsort(stds)[::-1]
    var_sorted = np.power(np.sort(stds)[::-1], 2)
    _plot_sorted_variances(var_sorted, out_path('VAE_variance_df_%s.png'))
    np.save(out_path('std_df_%s.npy'), stds)
    np.save(out_path('std_index_%s.npy'), std_index)
    np.save(out_path('mean_%s.npy'), means)

    # ---- statistics of the posterior precision 1 / sigma^2 ----
    sigma_stds = np.squeeze(np.std(vars_, axis=0))
    sigma_means = np.squeeze(np.mean(vars_, axis=0))
    print('stds', sigma_stds.shape, 'means', sigma_means.shape)

    sigma_std_index = np.argsort(sigma_means)[::-1]
    sigma_var_sorted = np.sort(sigma_means)[::-1]
    _plot_sorted_variances(sigma_var_sorted,
                           out_path('VAE_variance_df_inv_sigma_%s.png'))

    with open(out_path('VAE_variance_df_inv_sigma2_%s.csv'), 'w') as fcsv:
        fcsv.writelines(str(value) + '\n' for value in sigma_var_sorted)

    np.save(out_path('sigma_%s.npy'), sigma_means)
    # sigma2_ historically holds the same per-dimension mean as sigma_.
    np.save(out_path('sigma2_%s.npy'), sigma_means)
    np.save(out_path('sigma2_sorted_%s.npy'), sigma_var_sorted)
    np.save(out_path('sigma_index_%s.npy'), sigma_std_index)


def _save_distortion_bars(values, csv_path, fig_path):
    """Summarizes per-latent distortion samples as mean/std in CSV and a bar plot.

    Args:
      values: sequence with one 1-D array of samples per latent dimension.
      csv_path: output CSV; one '%.10f,%.10f' (mean, std) line per entry.
      fig_path: output figure with mean +/- std bars for the first 20
        entries, labelled z0, z1, ...
    """
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    with open(csv_path, 'w') as fw:
        for idx, samples in enumerate(values):
            mean_ = np.mean(samples)
            std_ = np.std(samples)
            fw.write('%.10f,%.10f\n' % (mean_, std_))
            if idx < 20:
                ax.bar('z%d' % idx, mean_, 0.45, yerr=std_, edgecolor="black",
                       error_kw=dict(lw=1, capsize=8, capthick=1))
    ax.set(ylabel='Mean and SD of D(z, z+Δ)')
    fig.savefig(fig_path)
    plt.close(fig)


def sample_image():
    """Renders latent traversals and per-dimension decoder sensitivity.

    Requires the statistics written by plot_analysis() in
    ``args.checkpoint_dir``: sigma_.npy, mean_.npy and sigma_index_.npy.
    Latent dimensions are ranked by ``sigma_index``.

    Outputs:
    - VAE_rec__sample_t_sigma.png: a 9 x 9 grid. Row i perturbs one latent
      dimension (ranks 0-2, z/2..z/2+2 and the last three) of the mean code
      by 0.5 * (column - 4).
    - VAE_rec__sample_t_sigma{1..4}.png: the same for ranks 0..35, clamped
      to z - 1.
    - For the first 10 * 81 images, per ranked dimension d:
      D(z, z + delta) / delta^2 (and the same weighted by exp(log-variance)),
      summarized by _save_distortion_bars() into CSV and PNG files.
    - clic_rdae_ori.png, written by read_img().

    Raises:
      ValueError: if ``args.loss1`` is not 'mse2' or 'myssim', or there are
        too few images.
    """
    if args.loss1 not in ('mse2', 'myssim'):
        raise ValueError('error invalid loss1: %r' % (args.loss1,))

    sample_num = 9                   # traversal grids are sample_num x sample_num
    batch = sample_num * sample_num  # images per forward pass
    sample_cen = int((sample_num - 1) / 2)
    # Latent offset between neighbouring grid columns (0.5 for sample_num=9).
    step = 2 / ((sample_num - 1) / 2)
    patch = args.patch_size
    dist_num = 10                    # number of image batches for statistics
    d_num = args.z                   # number of latent dimensions examined
    delta = args.delta
    parameter = ""                   # tag embedded in output file names

    # ---- graph ----
    decoder_inputs = tf.placeholder(tf.float32, [batch, args.z])
    inputs = tf.placeholder(tf.float32, [batch, patch, patch, args.cdim])
    _, mean_op, var_op = analysis_transform(inputs, 64)

    # Decoded traversal samples brought back to uint8 in 0..255.
    x_pred_2 = synthesis_transform(decoder_inputs, 64)
    x_pred_2_r = tf.cast(tf.round(tf.clip_by_value(x_pred_2, 0, 1) * 255), tf.uint8)

    decoder_inputs_loss1 = tf.placeholder(tf.float32, [batch, args.z])
    decoder_inputs_loss2 = tf.placeholder(tf.float32, [batch, args.z])
    x_pred_loss1 = synthesis_transform(decoder_inputs_loss1, 64)
    x_pred_loss2 = synthesis_transform(decoder_inputs_loss2, 64)

    # Per-sample distortion D between the two decodings.
    if args.loss1 == "mse2":
        loss_2_s = tf.reduce_mean(tf.squared_difference(x_pred_loss1, x_pred_loss2),
                                  reduction_indices=[1, 2, 3])
    else:
        # NOTE(review): the training distortion passes the second
        # reconstruction as the third argument; here the scaled difference is
        # passed instead. Confirm which is intended.
        diff = 255 * (x_pred_loss1 - x_pred_loss2)
        loss_2_s = ssim_matrix.ssim(255 * x_pred_loss1, diff, diff,
                                    max_val=255, mode='train', compensation=1)

    # ---- inputs and statistics ----
    images = sorted(glob.glob(args.img_path + '/*.png'))
    if len(images) < dist_num * batch:
        raise ValueError('need %d images in %s, found %d'
                         % (dist_num * batch, args.img_path, len(images)))

    def stats_path(prefix):
        return os.path.join(args.checkpoint_dir, '%s_%s.npy' % (prefix, parameter))

    sigma_mean = np.load(stats_path('sigma'))
    mean = np.load(stats_path('mean'))
    std_index = np.load(stats_path('sigma_index'))

    print('high_std', sigma_mean[std_index[0]] / ((sample_num - 1) / 2) * 2,
          sigma_mean[std_index[1]] / ((sample_num - 1) / 2) * 2)

    # Four grids: row i of grid s sweeps the dimension ranked i + 9 * s.
    samples_t_all = np.zeros([4, sample_num, sample_num, args.z])
    for s in range(4):
        for i in range(sample_num):
            dim = std_index[min(i + sample_num * s, args.z - 1)]
            for j in range(sample_num):
                samples_t_all[s, i, j] = mean
                samples_t_all[s, i, j][dim] += step * (j - sample_cen)

    # One grid: rows sweep the top three, middle three and last three ranks.
    samples_t = np.zeros([sample_num, sample_num, args.z])
    z_range = int(args.z / 2)
    for i in range(sample_num):
        area = int(i / 3)
        print(z_range * area + i % 3)
        if area < 2:
            dim = std_index[z_range * area + i % 3]
        else:
            dim = std_index[-sample_num + i]
        for j in range(sample_num):
            samples_t[i, j] = mean
            samples_t[i, j][dim] += step * (j - sample_cen)

    samples_t = samples_t.reshape((-1, args.z))
    samples_t_all = samples_t_all.reshape((4, batch, args.z))

    # Called for its side effect of saving clic_rdae_ori.png.
    read_img(sample_num)

    def tile(images_batch):
        grid = np.zeros([sample_num * patch, sample_num * patch, 3])
        for i in range(sample_num):
            for j in range(sample_num):
                grid[i * patch:(i + 1) * patch, j * patch:(j + 1) * patch,
                     :] = images_batch[sample_num * i + j]
        return grid

    with tf.Session() as sess:
        latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
        tf.train.Saver().restore(sess, save_path=latest)

        # Traversal pictures (top / middle / bottom ranks, then all ranks).
        sample_rec_t = sess.run(x_pred_2_r, feed_dict={decoder_inputs: samples_t})
        scipy.misc.imsave(
            os.path.join(args.checkpoint_dir, 'VAE_rec_%s_sample_t_sigma.png' % (parameter)),
            tile(sample_rec_t))
        for si in range(4):
            sample_rec_tmp = sess.run(x_pred_2_r, feed_dict={decoder_inputs: samples_t_all[si]})
            scipy.misc.imsave(
                os.path.join(args.checkpoint_dir,
                             'VAE_rec_%s_sample_t_sigma%s.png' % (parameter, str(si + 1))),
                tile(sample_rec_tmp))

        # Sensitivity statistics.
        chunks = []
        for n in range(dist_num):
            print(n)
            imgs = np.zeros([batch, patch, patch, 3])
            for i in range(batch):
                imgs[i, :, :, :] = scipy.misc.imread(images[n * batch + i]).astype(float)
            x = imgs / 255.
            latent_means, var_op_ = sess.run([mean_op, var_op], feed_dict={inputs: x})

            d_list_s = []
            for d in range(d_num):
                dim = std_index[d]
                dec_ip = latent_means.copy()
                dec_ip_d = latent_means.copy()
                dec_ip_d[:, dim] += delta
                dist_s = sess.run(loss_2_s, feed_dict={decoder_inputs_loss1: dec_ip,
                                                       decoder_inputs_loss2: dec_ip_d})
                ds = dist_s / np.power(delta, 2)
                # The sigma-weighted variant reuses the same decodings (the
                # original fed identical inputs to a second run) and scales by
                # exp(log-variance) == sigma^2.
                ds_sigma = dist_s * np.exp(var_op_[:, dim]) / np.power(delta, 2)
                d_list_s.append([ds, ds_sigma])
            chunks.append(np.array(d_list_s).transpose(0, 2, 1))

    # -> (d_num, 2, dist_num * batch): [plain, sigma-weighted] per dimension.
    d_list_all_s = np.squeeze(np.hstack(chunks)).transpose(0, 2, 1)

    # File names embed the index of the last latent examined (d_num - 1).
    base = os.path.join(args.checkpoint_dir, 'VAE_rec_%s_dist_z%d_s_hakohige_delta%s'
                        % (parameter, d_num - 1, args.delta))
    _save_distortion_bars([row[0] for row in d_list_all_s],
                          base + '.csv', base + '.png')
    _save_distortion_bars([row[1] for row in d_list_all_s],
                          base + '_with_sigma.csv', base + '_with_sigma.png')

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "command", choices=["train", 'analy','sample_1'],
        help="What to do: 'train' loads training data and trains (or continues "
             "to train) a new model. 'test' load trained model and test.")
    parser.add_argument(
        "input", nargs="?",
        help="Input filename.")
    parser.add_argument(
        "output", nargs="?",
        help="Output filename.")
    parser.add_argument(
        "--batch_size", type=int, default=64,
        help="Batch size for training.")
    parser.add_argument(
        "--patch_size", type=int, default=64,
        help="Patch size for training.")
    parser.add_argument(
        "--data_set", default='CelebA',
        help="Batch size for training.")
    parser.add_argument(
        "--checkpoint_dir", default="sanity",
        help="Directory where to save/load model checkpoints.")
    parser.add_argument(
        "--img_path", default="../../data/CelebA/centered_celeba_64_10per/",
        help="Directory where to save/load model checkpoints.")
    parser.add_argument(
        "--num_steps", type=int, default=300000,
        help="Train up to this number of steps.")
    parser.add_argument(
        "--save_steps", type=int, default=100000,
        help="Train up to this number of steps.")
    parser.add_argument(
        "--display_steps", type=int, default=100,
        help="save loss for plot every this number of steps.")
    parser.add_argument(
        "--lambda1", type=float, default=1000,
        help="Lambda for distortion tradeoff.")
    parser.add_argument(
        "--lambda2", type=float, default=1000,
        help="Lambda for rate tradeoff.")
    parser.add_argument(
        "--loss1", type=str, default='myssim',
        help="mse, logmse, ssim, logssim")
    parser.add_argument(
        "--loss2", type=str, default='myssim',
        help=" mse or ssim")
    parser.add_argument(
        "--z", type=int, default=32,
        help="bottleneck number.")
    parser.add_argument(
        "--cdim", type=int, default=3,
        help="channel.")
    parser.add_argument(
        "--delta", type=float, default=0.01,
        help="bottleneck number.")
    parser.add_argument(
        "--dim1", type=int, default=1024,
        help="AE layer1.")
    parser.add_argument(
        "--activation", default="softplus")
    parser.add_argument(
        "--ac2", default='True')
    parser.add_argument(
        "--finetune", default="None")
    parser.add_argument(
        "--split", default="True")

    parser.add_argument('-gpu', '--gpu_id',
                        help='GPU device id to use [0]', default=0, type=int)

    args = parser.parse_args()
    # cpu mode
    if args.gpu_id < 0:
        os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

    z_num = args.z

    if args.command == 'train':
        train()
        # python ae.py train --checkpoint_dir ae_718 --lambda 10 -gpu 0
    elif args.command == 'analy':
        plot_analysis()
        # if args.input is None or a
    elif args.command == 'sample_1':
        sample_image()