import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import os,sys
from scipy import io
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import argparse
from distutils.util import strtobool
from dagmm import DAGMM
from rdgmm import RDGMM
import tensorflow as tf
import os,sys
# Module-level figure and 3D axes; plot() below draws on `ax`.
fig = plt.figure()
# Constructing Axes3D(fig) directly no longer attaches the axes to the figure
# on current Matplotlib releases, which leaves an empty window. add_subplot
# with the 3d projection registers the axes on every Matplotlib version.
ax = fig.add_subplot(111, projection="3d")

# Axis labels for the three plotted components.
ax.set_xlabel("z0")
ax.set_ylabel("z1")
ax.set_zlabel("z2")


def plot(args):
    """Show a 3D scatter plot of the latent array stored in ``z.mat``.

    The run directory is rebuilt from the hyper-parameters in ``args``, the
    last entry of the ``'array'`` variable in ``<run dir>/z.mat`` is loaded,
    transposed, and its first three rows are drawn on the module-level ``ax``.
    The call blocks in ``plt.show()`` until the window is closed.

    Args:
        args: Parsed command-line namespace. Reads ``checkpoint_dir``,
            ``model``, ``normalize``, ``dataset``, ``comp_hiddens``,
            ``est_hiddens``, ``qstep``, ``lambda1``, ``lambda2``,
            ``learning_rate``, ``act`` and ``take``.

    Raises:
        IndexError: If the transposed array has fewer than three rows.
    """
    fdir = '%s_%s_%s/%s/ch%s_eh%s_q%.1f_l1_%s_l2_%s_lr%s_%s/take%s' % (
        args.checkpoint_dir, args.model, args.normalize,
        args.dataset,
        args.comp_hiddens, args.est_hiddens, args.qstep, args.lambda1,
        args.lambda2, args.learning_rate, args.act, args.take)

    # Only the last entry of the stored array is plotted
    # (presumably the most recent snapshot -- TODO confirm with the writer).
    z = np.array(io.loadmat(os.path.join(fdir, 'z.mat'))['array'])[-1]

    # After the transpose, z[0], z[1], z[2] are used as the x/y/z coordinates
    # (assumes the stored layout is (samples, dims) -- TODO confirm).
    z = z.transpose()

    # The scatter needs a label: legend() with no labelled artist only emits
    # a warning and draws nothing.
    ax.scatter(z[0], z[1], z[2], s=2,
               label='%s / %s' % (args.model, args.dataset))
    ax.legend(loc='upper left')

    plt.show()

if __name__ == "__main__":
    # Command-line entry point: parse the run's hyper-parameters and plot the
    # matching z.mat. Only the options that plot() reads affect the result;
    # the remaining ones are accepted but unused in this script (presumably
    # kept so the same command line works across scripts -- TODO confirm).
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # NOTE: --mode is parsed but never consulted here; this script always plots.
    parser.add_argument(
        '--mode', choices=["train", 'test', 'plot'],
        help="What to do: 'train' loads training data and trains (or continues "
        "to train) a new model. 'test' load trained model and test.")
    parser.add_argument(
        "--model", default="DAGMM",
        help="name of model")
    # Previously declared required=True, which made the default unreachable;
    # the option is now optional and falls back to the default.
    parser.add_argument(
        "--dataset", default="Arrhythmia",
        help="name of dataset(Arrhythmia, KDDCup99, KDDCup-rev) ")
    parser.add_argument(
        "--batch_size", type=int, default=128,
        help="Batch size for training.")
    parser.add_argument(
        "--learning_rate", type=float, default=1e-4,
        help="learning rate")
    parser.add_argument(
        "--epoch_size", type=int, default=10000,
        help="Train up to this number of epochs.")
    parser.add_argument(
        "--display_step", type=int, default=2000,
        help="save loss for plot every this number of steps.")
    parser.add_argument(
        "--lambda1", type=float, default=1,
        help="Lambda for rate-distortion tradeoff.")
    parser.add_argument(
        "--lambda2", type=float, default=1000,
        help="Lambda for rate-distortion tradeoff.")
    parser.add_argument(
        "--qstep", type=float, default=0.1,
        help="quantization step")
    parser.add_argument(
        "--checkpoint_dir", default="cache",
        help="Directory where to save/load model checkpoints.")
    parser.add_argument(
        "--stds", default="std\\std_frdc_loss_lambda10",
        help="Filename to save std log")
    parser.add_argument(
        "--comp_hiddens", default="10_2",
        help="number of compression layers")
    parser.add_argument(
        "--est_hiddens", default="10_2",
        help="number of hidden layers")
    parser.add_argument(
        "--take", default="0",
        help="number_of_take")
    parser.add_argument(
        "--normalize", default="min-max",
        help="normalize")
    parser.add_argument(
        "--zr", type=strtobool, default=True,
        help="use zr or not")

    # Help text used to read "normalize", copied from the option above.
    parser.add_argument(
        "--act", default="tanh",
        help="activation function")

    parser.add_argument('--model_save_step', type=int, default=1000)

    parser.add_argument('-gpu', '--gpu_id',
        help='GPU device id to use [0]', default=0, type=int)

    args = parser.parse_args()
    plot(args=args)
