import random
import time
import numpy                 as np
import tensorflow            as tf
import tensorflow_addons     as tfa
import gudhi                 as gd
import sys

from NeuraLayout_Training_3 import *
import pickle
import argparse

def get_conv_val(loss, window=100, min_len=1000):
    """Return a convergence score for a loss history: mean + std of its tail.

    The score is ``mean(tail) + std(tail)``, where ``tail`` is the last
    *window* entries of *loss* and the std is the population std
    (``ddof=0``, i.e. ``sqrt(var)``). Both a smaller mean and a smaller
    spread of the recent losses lower the score.

    Args:
        loss: Sequence of scalar loss values, oldest first.
        window: Number of trailing entries to summarize (default 100).
        min_len: Minimum history length required before a score is
            computed (default 1000).

    Returns:
        The score as a NumPy floating-point scalar.

    Raises:
        ValueError: If ``window`` is not positive, or if ``loss`` has fewer
            than ``min_len`` entries.
    """
    # loss[-0:] would silently select the whole history, so reject window < 1.
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(loss) < min_len:
        raise ValueError(
            f"loss history has {len(loss)} entries; at least {min_len} required"
        )
    # Convert the tail once and reuse it for both statistics.
    tail = np.asarray(loss[-window:])
    return tail.mean() + tail.std()

if __name__ == "__main__":
    # Command-line entry point: builds a random 2-D point cloud and trains a
    # 'Rips_linear_1L' model on it via ModelTrain.train_regular, with all
    # ops pinned to the CPU. ModelTrain comes from the star import of
    # NeuraLayout_Training_3; what it writes under file_dir is defined there.
    parser = argparse.ArgumentParser(description="""
    Training original model. 
    """)
    parser.add_argument("file_dir", type=str, help="directory which training results will be stored")
    parser.add_argument("--N", type=int, default=150, help="point cloud size")
    parser.add_argument("--range", type=float, default=1.0, help='point cloud range')
    parser.add_argument("--card", type=int, default=50, help="maximum number of features to optimize")
    parser.add_argument("--train_lr", type=float, default=0.01, help="learning rate for training")
    parser.add_argument("--optim", type=str, default='Adam', help="optimizer used for both prefitting and training")
    parser.add_argument("--max_train_epoch", type=int, default=1000, help="maximum number of training epochs")
    parser.add_argument("--conv_patience", type=int, default=100, help="patience parameter in convergence criteria")
    parser.add_argument("--conv_thresh", type=float, default=0.02, help="threshold parameter in convergence criteria")

    args = parser.parse_args()

    FILE_DIR = args.file_dir
    N = args.N
    RANGE = args.range
    # Validate CLI input with parser.error rather than `assert`: asserts are
    # stripped under `python -O`, and parser.error prints usage and exits (2).
    if RANGE <= 0:
        parser.error(f"--range must be positive, got {RANGE}")
    CARD = args.card
    TRAIN_LR = args.train_lr
    OPTIM = args.optim
    MAX_TRAIN_EPOCH = args.max_train_epoch
    CONV_PATIENCE = args.conv_patience
    CONV_THRESH = args.conv_thresh

    # Fixed NumPy seed so the initial point cloud is reproducible across runs.
    # NOTE(review): TensorFlow and `random` are not seeded here; if ModelTrain
    # uses them, training itself may still vary run to run — confirm.
    np.random.seed(1)
    # N points drawn uniformly from [-RANGE, RANGE) in each of 2 coordinates.
    Xinit = np.array(np.random.uniform(high=RANGE, low=-RANGE, size=(N,2)), dtype=np.float32)

    # arg_dict: model parameters passed to ModelTrain; expected keys per model:
    #   'Rips_linear_1L': 'mel', 'dim', 'card'
    #   'Rips_linear':    'hidden_dims', 'mel', 'dim', 'card'
    #   'RipsGCN':        'hidden_dims', 'mel', 'dim', 'card', 'lam'
    arg_dict = {'mel':12.,'dim':1,'card':CARD}
    original_adam = ModelTrain('Rips_linear_1L',arg_dict,Xinit,FILE_DIR,r=RANGE,train_lr=TRAIN_LR,optim=OPTIM)

    # Pin all TensorFlow ops created during training to the CPU.
    with tf.device('/CPU:0'):
        losses,runtime_list,output_list = original_adam.train_regular(
            max_train_epoch=MAX_TRAIN_EPOCH,conv_patience=CONV_PATIENCE,conv_thresh=CONV_THRESH
        )