import random
import time
import numpy                 as np
import tensorflow            as tf
import tensorflow_addons     as tfa
import gudhi                 as gd
import sys

from NeuraLayout_Training_3 import *
import pickle
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="""
    Training original model. 
    """)
    parser.add_argument("file_dir", type=str, help="directory which training results will be stored")
    parser.add_argument("--N", type=int, default=150, help="point cloud size")
    parser.add_argument("--range", type=float, default=1.0, help='point cloud range')
    parser.add_argument("--card", type=int, default=50, help="maximum number of features to optimize")
    parser.add_argument("--lam", type=float, default=0.6, help="lambda parameter")
    parser.add_argument("--h1", type=int, default=8, help="hidden dimensions 1")
    parser.add_argument("--h2", type=int, default=6, help="hidden dimensions 2")
    parser.add_argument("--h3", type=int, default=0, help="hidden dimensions 3")
    parser.add_argument("--prefit_lr", type=float, default=0.01, help="learning rate for prefitting")
    parser.add_argument("--train_lr", type=float, default=0.01, help="learning rate for training")
    parser.add_argument("--optim", type=str, default='Adam', help="optimizer used for both prefitting and training")
    parser.add_argument("--max_fit_epoch", type=int, default=200, help="maximum number of pre-fitting epochs")
    parser.add_argument("--mse_thresh", type=float, default=0.003, help="mse threshold of pre-fitting error")
    parser.add_argument("--max_train_epoch", type=int, default=1000, help="maximum number of training epochs")
    parser.add_argument("--conv_patience", type=int, default=100, help="patience parameter in convergence criteria")
    parser.add_argument("--conv_thresh", type=float, default=0.02, help="threshold parameter in convergence criteria")
    parser.add_argument("--conv_val", type=float, default=-100., help="stop value for convergence")
    parser.add_argument("--mode", type=str, default="hybrid", choices=["hybrid","regular"],help="threshold parameter in convergence criteria")
    parser.add_argument("--switch_patience", type=int, default=15, help="patience parameter in switching criteria")

    args = parser.parse_args()
    
    FILE_DIR = args.file_dir
    N = args.N
    RANGE = args.range
    assert RANGE > 0
    CARD = args.card
    LAM = args.lam
    HIDDEN_DIMS = [args.h1,args.h2,args.h3]
    PREFIT_LR = args.prefit_lr
    TRAIN_LR = args.train_lr
    OPTIM = args.optim
    MAX_FIT_EPOCH = args.max_fit_epoch
    MSE_THRESH = args.mse_thresh
    MAX_TRAIN_EPOCH = args.max_train_epoch
    CONV_PATIENCE = args.conv_patience
    CONV_THRESH = args.conv_thresh
    CONV_VAL = args.conv_val
    MODE = args.mode
    SWITCH_PATIENCE = args.switch_patience
    
    if HIDDEN_DIMS[-1] == 0:
        HIDDEN_DIMS.pop()
    print('HIDDEN_DIMS',HIDDEN_DIMS)
    
    np.random.seed(1)
    Xinit = np.array(np.random.uniform(high=RANGE, low=-RANGE, size=(N,2)), dtype=np.float32)

    # -- arg_dict: dictionary storing model parameters
            #    -- 'Rips_linear_1L': 'mel', 'dim', 'card'
            #    -- 'Rips_linear': 'hidden_dims','mel', 'dim', 'card'
            #    -- 'Rips_GCN': 'hidden_dims','mel', 'dim', 'card','lam'

    arg_dict = {'hidden_dims':HIDDEN_DIMS,'mel':12.,'dim':1,'card':CARD,'lam':LAM}

    GCN = ModelTrain('Rips_GCN',arg_dict,Xinit,FILE_DIR,r=RANGE,prefit_lr=PREFIT_LR,train_lr=TRAIN_LR,optim=OPTIM)
    
    with tf.device('/CPU:0'):
        GCN.prefit(mse_thresh=MSE_THRESH,max_train_epoch=MAX_FIT_EPOCH)
        
        if MODE == "hybrid":
            GCN.train_hybrid(
                max_train_epoch=MAX_TRAIN_EPOCH,conv_patience=CONV_PATIENCE,conv_thresh=CONV_THRESH,conv_val=CONV_VAL,switch_patience=SWITCH_PATIENCE
            )
        else:
            GCN.train_regular(
                max_train_epoch=MAX_TRAIN_EPOCH,conv_patience=CONV_PATIENCE,conv_thresh=CONV_THRESH,conv_val=CONV_VAL,
            )
    