
from __future__ import division, print_function, absolute_import

import tensorflow as tf
import argparse
import numpy as np
import os, sys
import matplotlib.pyplot as plt
import math,time
#import scipy.stats.mstats.gmean as gmean
from sklearn.linear_model import LinearRegression


# Silence TensorFlow's INFO/WARNING output while the MNIST tutorial reader runs;
# the verbosity saved in old_v is restored right after the dataset is loaded.
old_v = tf.logging.get_verbosity()
tf.logging.set_verbosity(tf.logging.ERROR)

def mse2psnr(mse, max_val=255.):
    """Convert a mean squared error into PSNR in decibels.

    PSNR = 20 * log10(max_val) - 10 * log10(mse).

    Args:
        mse: mean squared error; a scalar or a NumPy array (element-wise).
        max_val: peak signal value. The default 255 matches the 0-255 pixel
            scale on which train() computes its MSE.

    Returns:
        The PSNR in dB (same shape as ``mse``). An ``mse`` of 0 gives +inf and
        NumPy emits a divide-by-zero warning.
    """
    return 20. * np.log10(max_val) - 10. * np.log10(mse)

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data

# Loaded once at import time and shared by train(), plot_analysis() and
# sample_image(), which all draw batches from mnist.train.
# NOTE(review): the directory is hard-coded to "mnist/"; the --img_path option
# (same default) is not consulted here.
mnist = input_data.read_data_sets("mnist/", one_hot=False)

# Restore the logging verbosity saved before the dataset was loaded.
tf.logging.set_verbosity(old_v)

# Layer sizes for the module-level weights2/biases2 variables below.
num_hidden_1 = 1000  # 1st layer num features
num_hidden_2 = 1000  # 2nd layer num features
num_hidden_3 = 128  # 3rd layer num features (the latent dim of this layout)
num_input = 784  # flattened 28x28 MNIST image; also the width of the input placeholders

# initializer=tf.contrib.layers.variance_scaling_initializer()
initializer = tf.contrib.layers.xavier_initializer()

# NOTE(review): encoder()/decoder() build their layers with tf.layers.dense and
# never read weights2/biases2. These tf.Variables are nevertheless created in the
# default graph at import time, so tf.global_variables_initializer() initializes
# them and tf.train.Saver() writes/restores them with every checkpoint. Their
# auto-generated names presumably depend on creation order, so do not reorder or
# remove them without confirming existing checkpoints still restore.
weights2 = {
    'encoder_h1': tf.Variable(initializer([num_input, num_hidden_1])),
    'encoder_h2': tf.Variable(initializer([num_hidden_1, num_hidden_2])),
    'encoder_h3': tf.Variable(initializer([num_hidden_2, num_hidden_3])),
    'decoder_h1': tf.Variable(initializer([num_hidden_3, num_hidden_2])),
    'decoder_h2': tf.Variable(initializer([num_hidden_2, num_hidden_1])),
    'decoder_h3': tf.Variable(initializer([num_hidden_1, num_input])),
}
# Biases also use the Xavier initializer, so they start non-zero.
biases2 = {
    'encoder_b1': tf.Variable(initializer([num_hidden_1])),
    'encoder_b2': tf.Variable(initializer([num_hidden_2])),
    'encoder_b3': tf.Variable(initializer([num_hidden_3])),

    'decoder_b1': tf.Variable(initializer([num_hidden_2])),
    'decoder_b2': tf.Variable(initializer([num_hidden_1])),
    'decoder_b3': tf.Variable(initializer([num_input])),
}


def encoder(tensor):
    """Builds the analysis transform (the VAE encoder).

    Two dense hidden layers of ``args.dim1`` and ``args.dim2`` units, each followed
    by the activation named in ``args.activation``. They feed two linear heads of
    ``args.z`` units, ``mean`` and ``sigma``. ``sigma`` is used as a log-variance:
    the returned sample uses the reparameterization trick
    ``z = mean + exp(sigma / 2) * eps`` with ``eps ~ N(0, I)``.

    Args:
        tensor: input batch tensor (flattened images).

    Returns:
        Tuple ``(z, mean, sigma)`` of tensors.

    Raises:
        ValueError: if ``args.activation`` is not 'sigmoid', 'softplus' or 'relu'.
    """
    # activation name -> (activation fn, inner variable-scope name). The inner scope
    # name becomes part of every variable name stored in checkpoints; 'sigmoid' has
    # always used 'encoder1', so it must keep doing so for existing checkpoints.
    variants = {
        'sigmoid': (tf.nn.sigmoid, 'encoder1'),
        'softplus': (tf.nn.softplus, 'encoder'),
        'relu': (tf.nn.relu, 'encoder'),
    }
    if args.activation not in variants:
        raise ValueError("Unsupported activation '%s'; expected one of %s"
                         % (args.activation, sorted(variants)))
    activation, inner_scope = variants[args.activation]

    with tf.variable_scope("encoder"):
        with tf.variable_scope(inner_scope):
            # Layers are created in the same order as before so the auto-generated
            # dense layer names (and therefore checkpoint keys) are unchanged.
            tensor = activation(tf.layers.dense(
                tensor, args.dim1, kernel_initializer=tf.contrib.layers.xavier_initializer()))
            tensor = activation(tf.layers.dense(
                tensor, args.dim2, kernel_initializer=tf.contrib.layers.xavier_initializer()))
            mean = tf.layers.dense(tensor, args.z)
            sigma = tf.layers.dense(tensor, args.z)  # log-variance head
            eps = tf.random_normal(tf.shape(mean), dtype=tf.float32, mean=0., stddev=1.0,
                                   name='epsilon')
            # reparameterization trick
            z = mean + tf.exp(sigma / 2) * eps
    return z, mean, sigma


def decoder(tensor):
    """Builds the synthesis transform (the VAE decoder).

    Two dense hidden layers, of ``args.dim2`` and then ``args.dim1`` units, each
    followed by the activation named in ``args.activation``. A final linear dense
    layer has ``num_input`` (784) units. The outer scope uses
    ``reuse=tf.AUTO_REUSE`` so repeated calls share the same weights.

    Args:
        tensor: latent batch tensor.

    Returns:
        The linear output of the last dense layer. train() and sample_image()
        apply tf.nn.sigmoid to it themselves.

    Raises:
        ValueError: if ``args.activation`` is not 'sigmoid', 'relu' or 'softplus'.
            Previously an unknown value silently returned the input unchanged.
    """
    activations = {
        'sigmoid': tf.nn.sigmoid,
        'relu': tf.nn.relu,
        'softplus': tf.nn.softplus,
    }
    if args.activation not in activations:
        raise ValueError("Unsupported activation '%s'; expected one of %s"
                         % (args.activation, sorted(activations)))
    activation = activations[args.activation]

    with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
        # Same nested scope name for every activation; it is part of checkpoint keys.
        with tf.variable_scope('decoder'):
            tensor = activation(tf.layers.dense(
                tensor, args.dim2, kernel_initializer=tf.contrib.layers.xavier_initializer()))
            tensor = activation(tf.layers.dense(
                tensor, args.dim1, kernel_initializer=tf.contrib.layers.xavier_initializer()))
            tensor = tf.layers.dense(tensor, num_input)
    return tensor


def _train_write_params_log():
    """Write every parsed argument to <checkpoint_dir>/params.log and echo it to stdout."""
    log_name = os.path.join(args.checkpoint_dir, 'params.log')
    if os.path.exists(log_name):
        print('remove file:%s' % log_name)
        os.remove(log_name)
    with open(log_name, 'w') as params:
        for arg in vars(args):
            str_ = '%s: %s.\n' % (arg, getattr(args, arg))
            print(str_)
            params.write(str_)


def _train_append(path, text):
    """Append ``text`` to the file at ``path``, closing it even if the write fails."""
    with open(path, 'a') as f:
        f.write(text)


def _train_save_reconstructions(batch_x, recon, step, parameter):
    """Save a side-by-side PNG grid of original and reconstructed digits.

    Uses the first n*n rows of the batch, where n = floor(sqrt(args.batch_size)).
    Each row of ``batch_x`` and ``recon`` is reshaped to 28x28.
    """
    n = int(np.sqrt(args.batch_size))
    canvas_orig = np.empty((28 * n, 28 * n))
    canvas_recon = np.empty((28 * n, 28 * n))
    for row in range(n):
        for col in range(n):
            k = row * n + col
            tile = (slice(row * 28, (row + 1) * 28), slice(col * 28, (col + 1) * 28))
            canvas_orig[tile] = batch_x[k].reshape([28, 28])
            canvas_recon[tile] = recon[k].reshape([28, 28])

    fig = plt.figure()
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.imshow(canvas_orig, origin="upper", cmap="gray")
    ax2 = fig.add_subplot(1, 2, 2)
    ax2.imshow(canvas_recon, origin="upper", cmap="gray")
    fig_pass = os.path.join(args.checkpoint_dir,
                            'rdae_rec_%s_step_%s.png' % (parameter, step))
    fig.savefig(fig_pass, format='png', dpi=1000, bbox_inches='tight')
    # pyplot keeps every figure alive until closed; without this each save step leaks one.
    plt.close(fig)


def train():
    """Build the VAE training graph and optimise it with Adam for ``args.num_steps`` steps.

    Input pixels are squashed to ``[alpha, 1 - alpha]`` (``alpha = args.alpha``).
    Decoder outputs are clipped to slightly inside that range so every log() in
    the Bernoulli likelihood stays finite.

    The loss is
    ``KL(q(z|x) || N(0, I)) + lambda1 * rec_loss1 (+ lambda2 * rec_loss2)``:

    * When ``args.split == 'True'``, rec_loss1 is the pixel-averaged BCE of the
      input against the decoding of the posterior mean. rec_loss2 is the BCE of
      that mean-decoding against the decoding of a sampled z.
    * Otherwise, rec_loss1 is the BCE of the input against the sampled-z decoding.

    Writes the following to ``args.checkpoint_dir``:

    * ``params.log``;
    * a CSV loss log every ``display_step`` steps;
    * every ``save_steps`` steps, a checkpoint, a PSNR CSV row and a
      reconstruction PNG.

    Training always starts from freshly initialised variables. It stops early,
    after logging 'nan', if the loss becomes NaN.
    """
    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)
    _train_write_params_log()

    # Training Parameters
    learning_rate = 0.0001
    alpha = args.alpha
    clip_lo = alpha + 1e-8
    # 1 - 1e-8 rounds to exactly 1.0 in float32, so the old upper bound did not clip
    # at all. A saturated sigmoid then gave log(1 - 1.0) = -inf, and 0 * -inf = NaN
    # for pixels equal to 1.0. A margin of 1e-7 survives float32 rounding.
    clip_hi = 1 - alpha - 1e-7

    # Construct model.
    # Feed the raw placeholder, not the squashed tensor: Session.run lets a feed
    # override any tensor, so feeding the derived one silently skipped the rescale.
    X_in = tf.placeholder("float", [None, num_input])
    X = X_in * (1 - 2 * alpha) + alpha
    encoder_op, mean, var = encoder(X)  # ``var`` is the encoder's log-variance head
    X_pred = tf.nn.sigmoid(decoder(mean))         # decoding of the posterior mean
    X_pred2 = tf.nn.sigmoid(decoder(encoder_op))  # decoding of a sampled z
    X_pred = tf.clip_by_value(X_pred, clip_lo, clip_hi)
    X_pred2 = tf.clip_by_value(X_pred2, clip_lo, clip_hi)

    # Closed-form KL(N(mean, exp(var)) || N(0, I)): summed over latents, averaged over the batch.
    kl_div_loss = 1 + var - tf.square(mean) - tf.exp(var)
    kl_div_loss = -0.5 * tf.reduce_sum(kl_div_loss, 1)
    kl_div_loss = tf.reduce_mean(kl_div_loss)

    def bernoulli_ll(target, pred):
        """Per-sample Bernoulli log-likelihood of ``target`` under ``pred``, averaged over pixels."""
        return tf.reduce_sum(target * tf.log(pred) + (1 - target) * tf.log(1 - pred), 1) / num_input

    split = args.split == 'True'
    if split:
        rec_loss1 = -tf.reduce_mean(bernoulli_ll(X, X_pred))
        rec_loss2 = -tf.reduce_mean(bernoulli_ll(X_pred, X_pred2))
        loss = kl_div_loss + args.lambda1 * rec_loss1 + args.lambda2 * rec_loss2
    else:
        marginal_likelihood = bernoulli_ll(X, X_pred2)
        print('marginal', marginal_likelihood.shape)
        rec_loss1 = -tf.reduce_mean(marginal_likelihood)
        loss = kl_div_loss + args.lambda1 * rec_loss1

    # MSE on the 0-255 scale between the (squashed) input and the mean decoding; used for PSNR.
    mse = tf.reduce_mean(tf.squared_difference(255 * X, 255 * X_pred))

    step = tf.train.create_global_step()
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=step)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    parameter = ''
    loss_log = '%s/MNIST_VAE_log_%s.csv' % (args.checkpoint_dir, parameter)
    psnr_log = '%s/rdae_MNIST_psnr_%s.csv' % (args.checkpoint_dir, parameter)

    with tf.Session() as sess:
        sess.run(init)
        for i in range(1, args.num_steps + 1):
            # Only images are needed, not labels.
            batch_x, _ = mnist.train.next_batch(args.batch_size)

            if split:
                _, train_loss_, d1_, d2_, rate_, mse_, X_pred_ = sess.run(
                    [optimizer, loss, rec_loss1, rec_loss2, kl_div_loss, mse, X_pred],
                    feed_dict={X_in: batch_x})
            else:
                _, train_loss_, d1_, rate_, mse_, X_pred_ = sess.run(
                    [optimizer, loss, rec_loss1, kl_div_loss, mse, X_pred],
                    feed_dict={X_in: batch_x})

            if np.isnan(train_loss_):
                print('nan!!!')
                _train_append(loss_log, 'nan\n')
                break

            if i % args.display_step == 0 and i > 1:
                # NOTE: the logged step is i + 1, one past the number of completed steps.
                if split:
                    line = '%d,loss,%f,rate,%f,d1,%f,d2,%f\n' % (i + 1, train_loss_, rate_, d1_, d2_)
                else:
                    line = '%d,loss,%f,rate,%f,d1,%f\n' % (i + 1, train_loss_, rate_, d1_)
                _train_append(loss_log, line)
                print(parameter)
                print(line)

            # save model
            if i % args.save_steps == 0:
                print('savestep!!')
                saved_dir = os.path.join(args.checkpoint_dir, 'model')
                saver.save(sess, saved_dir, global_step=i)
                print('Save model to %s, step: %d.' % (saved_dir, i))
                psnr = mse2psnr(mse_)  # measured on the current training batch
                print('Test psnr:%f' % psnr)
                _train_append(psnr_log, '%s,%d,PSNR (dB), %.2f\n' % (parameter, i, psnr))
                _train_save_reconstructions(batch_x, X_pred_, i, parameter)


def _analysis_plot_sorted(values_sorted, fig_name):
    """Bar-plot ``values_sorted`` with its normalised cumulative sum on a twin y-axis and save it."""
    cumulative = values_sorted.cumsum() / values_sorted.sum()
    x_axis = np.arange(0, values_sorted.shape[0], 1)
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax1.bar(x_axis, values_sorted)
    ax2.plot(x_axis, cumulative, color="red")
    ax1.set_xlabel('z(sorted by variance(descending order))')
    ax1.set_ylabel('variance of z')
    ax2.set_ylabel('Cumulative sum of variance')
    fig.savefig(fig_name)
    plt.close(fig)


def plot_analysis():
    """Restore the latest checkpoint and save per-latent statistics over the training set.

    Encodes one unshuffled pass over ``mnist.train`` in whole batches. Two sets of
    statistics are computed:

    * Over the sampled codes z: the per-dimension std and mean. The std values and
      their descending-sort order are saved, and sorted variances are plotted.
    * Over the per-sample precision 1 / exp(var) (``var`` is the encoder's
      log-variance head): the per-dimension mean. Its descending-sort order is
      saved as ``sigma_index_.npy``, which sample_image() reads.

    All outputs go to ``args.checkpoint_dir``.

    Raises:
        ValueError: if no checkpoint is found in ``args.checkpoint_dir``.
    """
    # Construct model
    X = tf.placeholder("float", [None, num_input])
    encoder_op, mean, var = encoder(X)
    # Whole batches covering exactly one pass over the training split. The previous
    # hard-coded 60000 exceeded its size (55000 with read_data_sets' default
    # validation split), so next_batch wrapped around and re-counted the first images.
    batch_num = mnist.train.num_examples // args.batch_size

    parameter = ""

    with tf.Session() as sess:
        # restore model
        latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
        if latest is None:
            raise ValueError('No checkpoint found in %s' % args.checkpoint_dir)
        tf.train.Saver().restore(sess, save_path=latest)

        z_chunks = []
        precision_chunks = []
        for _ in range(batch_num):
            batch_x, _labels = mnist.train.next_batch(args.batch_size, shuffle=False)
            z, var_ = sess.run([encoder_op, var], feed_dict={X: batch_x})
            z_chunks.append(z)
            # 1 / exp(var / 2) ** 2 == 1 / variance, since ``var`` is a log-variance.
            precision_chunks.append(1 / (np.power(np.exp(var_ * 0.5), 2)))
        # Stack once instead of re-allocating a growing array every batch (quadratic).
        zs = np.vstack(z_chunks)
        vars_ = np.vstack(precision_chunks)

        print('zs', zs.shape)
        stds = np.std(zs, axis=0)
        means = np.mean(zs, axis=0)
        print('stds', stds.shape, 'means', means.shape)

        std_index = np.argsort(stds)[::-1]
        var_sorted = np.power(np.sort(stds)[::-1], 2)
        _analysis_plot_sorted(
            var_sorted, os.path.join(args.checkpoint_dir, 'VAE_variance_df_%s.png' % (parameter)))

        np.save(os.path.join(args.checkpoint_dir, 'std_df_%s.npy' % (parameter)), stds)
        np.save(os.path.join(args.checkpoint_dir, 'std_index_%s.npy' % (parameter)), std_index)
        np.save(os.path.join(args.checkpoint_dir, 'mean_%s.npy' % (parameter)), means)

        ######## per-dimension precision (1 / variance) ########
        sigma_stds = np.std(vars_, axis=0)
        sigma_means = np.mean(vars_, axis=0)
        print('stds', sigma_stds.shape, 'means', sigma_means.shape)

        sigma_std_index = np.argsort(sigma_means)[::-1]
        sigma_var_sorted = np.sort(sigma_means)[::-1]
        _analysis_plot_sorted(
            sigma_var_sorted,
            os.path.join(args.checkpoint_dir, 'VAE_variance_df_inv_sigma_%s.png' % (parameter)))

        csv_name = os.path.join(args.checkpoint_dir, 'VAE_variance_df_inv_sigma2_%s.csv' % (parameter))
        with open(csv_name, 'w') as fcsv:
            for value in sigma_var_sorted:
                fcsv.write(str(value) + '\n')

        # sigma_ and sigma2_ intentionally hold the same per-dimension mean precision.
        np.save(os.path.join(args.checkpoint_dir, 'sigma_%s.npy' % (parameter)), sigma_means)
        np.save(os.path.join(args.checkpoint_dir, 'sigma2_%s.npy' % (parameter)), sigma_means)
        np.save(os.path.join(args.checkpoint_dir, 'sigma2_sorted_%s.npy' % (parameter)), sigma_var_sorted)
        np.save(os.path.join(args.checkpoint_dir, 'sigma_index_%s.npy' % (parameter)), sigma_std_index)

def _sample_save_dim_stats(per_dim_values, csv_path, fig_path):
    """Write each latent dimension's mean/std to CSV and plot the means.

    ``per_dim_values[k]`` holds every measured value for the k-th (sorted) latent
    dimension. Each CSV row is ``mean,std`` with 10 decimals; the figure plots only
    the means. Returns the list of means.
    """
    means = [np.mean(values) for values in per_dim_values]
    stds = [np.std(values) for values in per_dim_values]
    with open(csv_path, 'w') as fw:
        for mean_val, std_val in zip(means, stds):
            fw.write('%.10f,%.10f\n' % (mean_val, std_val))

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(np.arange(len(means)), np.array(means))
    ax.set(ylabel = 'Mean and SD of D(z, z+Δ)')
    fig.savefig(fig_path)
    plt.close(fig)
    return means


def sample_image():
    """Measure how sensitive the decoder output is to a small shift of each latent dimension.

    Restores the latest checkpoint and encodes ``dist_num`` consecutive unshuffled
    training batches to their posterior means ``z``. Latent dimensions are visited in
    the order stored in ``sigma_index_.npy`` (written by the 'analy' command). For
    each dimension k, ``args.delta`` is added to that dimension and the following is
    computed per sample:

        (BCE(p, p') - BCE(p, p)) / delta**2

    where p = sigmoid(decoder(z)) and p' = sigmoid(decoder(z + delta * e_k)). The
    per-dimension mean/std of these values is written to a CSV and a PNG. The
    '_with_sigma' variant first multiplies each sample's value by exp(log-variance)
    of that dimension.

    Raises:
        ValueError: if no checkpoint is found in ``args.checkpoint_dir``.
    """
    inputs = tf.placeholder("float", [None, num_input])
    _, mean_op, var_op = encoder(inputs)

    decoder_inputs_loss1 = tf.placeholder(tf.float32, [args.batch_size, args.z])
    decoder_inputs_loss2 = tf.placeholder(tf.float32, [args.batch_size, args.z])
    x_pred_loss1 = tf.nn.sigmoid(decoder(decoder_inputs_loss1))
    x_pred_loss2 = tf.nn.sigmoid(decoder(decoder_inputs_loss2))

    alpha = args.alpha
    # 1 - 1e-10 rounds to exactly 1.0 in float32, leaving log(1 - p) = -inf and
    # 0 * -inf = NaN for saturated pixels. A margin of 1e-7 survives the rounding.
    x_pred_loss1 = tf.clip_by_value(x_pred_loss1, alpha + 1e-10, 1 - 1e-7 - alpha)
    x_pred_loss2 = tf.clip_by_value(x_pred_loss2, alpha + 1e-10, 1 - 1e-7 - alpha)

    # Per-sample pixel-averaged log-likelihood of x_pred_loss1 under x_pred_loss2,
    # and of x_pred_loss1 under itself (the entropy baseline subtracted below).
    marginal_likelihood = tf.reduce_sum(
        x_pred_loss1 * tf.log(x_pred_loss2) + (1 - x_pred_loss1) * tf.log(1 - x_pred_loss2), 1) / num_input
    marginal_likelihood_gt = tf.reduce_sum(
        x_pred_loss1 * tf.log(x_pred_loss1) + (1 - x_pred_loss1) * tf.log(1 - x_pred_loss1), 1) / num_input
    loss_2_s = -marginal_likelihood
    loss_2_gt = -marginal_likelihood_gt
    print(loss_2_s.shape)

    parameter = ""
    std_index = np.load(os.path.join(args.checkpoint_dir, 'sigma_index_%s.npy' % (parameter)))

    dist_num = 10  # number of consecutive training batches to measure
    d_num = args.z
    delta = args.delta
    delta_sq = np.power(delta, 2)
    ds_per_dim = [[] for _ in range(d_num)]
    ds_sigma_per_dim = [[] for _ in range(d_num)]

    with tf.Session() as sess:
        # restore model
        latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
        if latest is None:
            raise ValueError('No checkpoint found in %s' % args.checkpoint_dir)
        tf.train.Saver().restore(sess, save_path=latest)

        for n in range(dist_num):
            print(n)
            x, _labels = mnist.train.next_batch(args.batch_size, shuffle=False)
            print(x.shape)
            base, log_var = sess.run([mean_op, var_op], feed_dict={inputs: x})
            for d in range(d_num):
                dim = std_index[d]
                shifted = base.copy()
                shifted[:, dim] = base[:, dim] + delta
                dist_s, dist_offset = sess.run(
                    [loss_2_s, loss_2_gt],
                    feed_dict={decoder_inputs_loss1: base, decoder_inputs_loss2: shifted})
                excess = dist_s - dist_offset
                ds_per_dim[d].append(excess / delta_sq)
                ds_sigma_per_dim[d].append(excess * np.exp(log_var[:, dim]) / delta_sq)

    ds_all = [np.concatenate(chunks) for chunks in ds_per_dim]
    ds_sigma_all = [np.concatenate(chunks) for chunks in ds_sigma_per_dim]

    # The output filenames have always embedded the last latent index (d_num - 1).
    last_d = d_num - 1
    means = _sample_save_dim_stats(
        ds_all,
        os.path.join(args.checkpoint_dir, 'VAE_rec_%s_dist_z%d_s_hakohige_delta%s.csv' % (parameter, last_d, args.delta)),
        os.path.join(args.checkpoint_dir, 'VAE_rec_%s_dist_z%d_s_hakohige_delta%s.png' % (parameter, last_d, args.delta)))
    print(means)
    _sample_save_dim_stats(
        ds_sigma_all,
        os.path.join(args.checkpoint_dir, 'VAE_rec_%s_dist_z%d_s_hakohige_delta%s_with_sigma.csv' % (parameter, last_d, args.delta)),
        os.path.join(args.checkpoint_dir, 'VAE_rec_%s_dist_z%d_s_hakohige_delta%s_with_sigma.png' % (parameter, last_d, args.delta)))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "command", choices=["train", 'test', 'plot', 'analy', 'sample'],
        help="What to do: 'train' trains a new model from freshly initialised "
             "variables; 'analy' restores the latest checkpoint and saves latent "
             "statistics; 'sample' restores it and measures per-latent decoder "
             "sensitivity (requires the outputs of 'analy'). 'test' and 'plot' "
             "are not implemented.")
    parser.add_argument(
        "--batch_size", type=int, default=256,
        help="Batch size for training.")
    parser.add_argument(
        "--image_size", type=int, default=28,
        help="Image side length in pixels (not used; the model assumes 28x28).")
    parser.add_argument(
        "--num_steps", type=int, default=30000,
        help="Train up to this number of steps.")
    parser.add_argument(
        "--img_path", default="mnist/",
        help="MNIST data directory (not used; data is always read from 'mnist/').")
    parser.add_argument(
        "--display_step", type=int, default=100,
        help="save loss for plot every this number of steps.")
    parser.add_argument(
        "--save_steps", type=int, default=10000,
        help="Save a checkpoint, PSNR and reconstruction image every this number of steps.")
    parser.add_argument(
        "--lambda1", type=float, default=10,
        help="Weight of the input reconstruction loss (rec_loss1).")
    parser.add_argument(
        "--lambda2", type=float, default=1,
        help="Weight of the second reconstruction loss (rec_loss2, only with --split True).")
    parser.add_argument(
        "--alpha", type=float, default=0.0,
        help="Inputs are squashed to [alpha, 1 - alpha] and decoder outputs clipped to match.")
    parser.add_argument(
        "--loss1", type=str, default='bce',
        help="Not used; the reconstruction loss is always binary cross-entropy.")
    parser.add_argument(
        "--loss2", type=str, default='bce',
        help="Not used; the reconstruction loss is always binary cross-entropy.")
    parser.add_argument(
        "--z", type=int, default=128,
        help="bottleneck number.")
    parser.add_argument(
        "--dim1", type=int, default=1000,
        help="AE layer1.")
    parser.add_argument(
        "--dim2", type=int, default=1000,
        help="AE layer2.")
    parser.add_argument(
        "--split", default='True',
        help="'True' reconstructs the input from the posterior mean and adds "
             "rec_loss2; any other value reconstructs it from a sampled z only.")
    parser.add_argument(
        "--checkpoint_dir", default="train",
        help="Directory where to save/load model checkpoints.")
    parser.add_argument(
        "--stds", default="std\\std_frdc_loss_lambda10",
        help="Filename to save std log")
    parser.add_argument(
        "--rate", default="rate",
        help="rate or AE")
    parser.add_argument(
        "--delta", type=float, default=0.01,
        help="Shift added to one latent dimension at a time by the 'sample' command.")

    parser.add_argument(
        "--activation", default="relu", choices=['sigmoid', 'softplus', 'relu'],
        help="Hidden-layer activation of the encoder and decoder.")
    parser.add_argument('-gpu', '--gpu_id',
                        help='GPU device id to use [0]', default=0, type=int)

    args = parser.parse_args()
    # Device selection via environment variables. This is presumably effective because
    # no tf.Session has been created yet (TF1 initialises devices lazily) — verify.
    if args.gpu_id < 0:
        # cpu mode
        os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

    commands = {'train': train, 'analy': plot_analysis, 'sample': sample_image}
    if args.command not in commands:
        # 'test' and 'plot' are accepted by the parser but have no implementation;
        # fail loudly instead of exiting silently without doing anything.
        parser.error("command '%s' is not implemented" % args.command)
    commands[args.command]()


