import argparse

def _build_parser():
    """Create the command-line parser with every option and its default value."""
    parser = argparse.ArgumentParser()

    # ---- Run identification ----
    # 'None' (the string) is the sentinel for "no preset selected"; arguments() rejects it.
    parser.add_argument('--args_type', type=str, default='None')
    parser.add_argument('--title', type=int, default=0)
    parser.add_argument('--gpu_id', type=str, default='0')

    # ---- Fixed name lists ----
    parser.add_argument('--FIELD_list', type=str, nargs='+',
                        default=["along.x", "along.y", "along.0.5x+0.9y", "along.0.7x+0.7y", "along.0.25x-1y"])
    parser.add_argument('--PSI_list', type=str, nargs='+',
                        default=["identity", "ellipse", "fish", "kite", "custom"])

    # ---- Problem selection ----
    parser.add_argument('--arg_field', type=int, default=0)  # background field; integer index
    parser.add_argument('--arg_conformal', type=int, default=2)  # shape
    parser.add_argument('--custom_coeffs', type=float, nargs='+', default=[])  # [a1, a2, a3, ...]

    # ---- Arbitrary background field (original author's note: "when arg_field == 2") ----
    # The first value stands for the real part, the second for the imaginary part.
    parser.add_argument('--complex_background', type=float, nargs='+', default=[1, 0])

    # ---- Interface design ----
    # Flags are kept as the strings "True"/"False", not booleans.
    parser.add_argument('--p_nn_flag', type=str, default="False")  # whether the interface is a NN
    # If p is a neural network:
    parser.add_argument('--layers_interface', type=int, nargs='+', default=[2, 20, 20, 20, 20, 2])
    parser.add_argument('--activation_interface', type=str, default='tanh')
    parser.add_argument('--p_nn_order', type=int, default=30)
    # Else, if p has a coefficient expansion:
    parser.add_argument('--order_interface', type=int, default=10)  # order of interface function p(x)

    # ---- Learning ----
    parser.add_argument('--lr_pinn', type=float, default=1e-3)  # learning rate for NN
    parser.add_argument('--lr_inv', type=float, default=1e-1)  # learning rate for interface function
    parser.add_argument('--plot_interval', type=int, default=100)
    parser.add_argument('--loss_tracking_interval', type=int, default=500)
    parser.add_argument('--num_epochs', type=int, default=1)
    parser.add_argument('--S_training', type=int, default=25000)
    parser.add_argument('--gamma', type=float, default=0.7)  # learning-rate decay
    parser.add_argument('--reg_flag', type=str, default="False")  # regularization
    parser.add_argument('--reg_weight', type=float, default=1e-5)  # regularization
    parser.add_argument('--init_flag', type=str, default="False")  # initialization

    # ---- Domain ----
    parser.add_argument('--sigma_c', type=int, default=5)  # for interior
    parser.add_argument('--sigma_m', type=int, default=1)  # for exterior
    parser.add_argument('--eps_bd', type=float, default=1e-6)  # tubular neighborhood
    parser.add_argument('--L_dom', type=float, default=5.0)
    parser.add_argument('--R_bd', type=float, default=1.0)

    # ---- Sampling points ----
    parser.add_argument('--import_flag', type=str, default="False")
    parser.add_argument('--import_title', type=int, default=0)
    parser.add_argument('--Fixed', type=int, default=1)
    parser.add_argument('--S_carte', type=int, default=150)
    parser.add_argument('--S_boundary', type=int, default=6000)
    parser.add_argument('--S_int_angular', type=int, default=100)
    parser.add_argument('--S_int_radial', type=int, default=60)

    # ---- Numerical computation ----
    parser.add_argument('--Conformal_order', type=int, default=10)
    parser.add_argument('--Matrix_truncation', type=int, default=30)
    parser.add_argument('--W_expansion', type=int, default=11)

    # ---- u-NN structure ----
    parser.add_argument('--weights', type=float, nargs='+', default=[0.3, 0.2, 0.2, 0.3])
    parser.add_argument('--layers', type=int, nargs='+', default=[2, 20, 20, 20, 20, 1])
    parser.add_argument('--activation', type=str, default='tanh')

    return parser


def _coco_preset(**overrides):
    """Return the baseline CoCo-PINNs preset (coefficient-expansion interface) with *overrides* applied."""
    preset = {
        "p_nn_flag": "False",
        "reg_flag": "True",
        "order_interface": 20,
        "lr_inv": 1e-3,
        "gamma": 0.7,
        "init_flag": "True",
    }
    preset.update(overrides)
    return preset


def _classical_preset(**overrides):
    """Return the baseline classical-PINNs preset (neural-network interface) with *overrides* applied."""
    preset = {
        "p_nn_flag": "True",
        "reg_flag": "True",
        "p_nn_order": 20,
        "lr_pinn": 1e-3,
        "gamma": 0.7,
        "init_flag": "True",
    }
    preset.update(overrides)
    return preset


def _build_presets():
    """Return a fresh mapping from each known ``args_type`` to the attributes it sets."""
    presets = {
        "A": _coco_preset(),
        "aa": _classical_preset(),

        # ---- Hyperparameter tuning (upper case: CoCo-PINNs, lower case: classical-PINNs) ----
        "B": _coco_preset(lr_inv=1e-4, lr_pinn=1e-4, gamma=0.8),
        "bb": _classical_preset(lr_pinn=1e-4, gamma=0.8),
        # Similar to [A, aa] but with 50000 iterations.
        "C": _coco_preset(lr_inv=2e-3, lr_pinn=2e-3, S_training=50000),
        "cc": _classical_preset(lr_pinn=2e-3, S_training=50000),
        "D": _coco_preset(lr_inv=5e-4, lr_pinn=5e-5, gamma=0.95),
        "dd": _classical_preset(lr_pinn=5e-5, gamma=0.95),
        "E": _coco_preset(lr_inv=1e-3, lr_pinn=1e-4, gamma=0.75),
        "ee": _classical_preset(lr_pinn=1e-4, gamma=0.75),
        "F": _coco_preset(lr_inv=2e-3, lr_pinn=3e-4, gamma=0.75),
        "ff": _classical_preset(lr_pinn=3e-4, gamma=0.75),
        "G": _coco_preset(lr_inv=6e-3, lr_pinn=2e-3),
        "gg": _classical_preset(lr_pinn=2e-3),

        # ---- Additional experiments: increasing order of the interface (kite shape) ----
        "Apdx_IO_n3": _coco_preset(arg_conformal=3, order_interface=3),
        "Apdx_IO_n5": _coco_preset(arg_conformal=3, order_interface=5),
        "Apdx_IO_n7": _coco_preset(arg_conformal=3, order_interface=7),
        "Apdx_IO_n9": _coco_preset(arg_conformal=3, order_interface=9),
        "Apdx_IO_n10": _coco_preset(arg_conformal=3, order_interface=10,
                                    Conformal_order=15, Matrix_truncation=40, W_expansion=15),
        "Apdx_IO_n15": _coco_preset(arg_conformal=3, order_interface=15,
                                    Conformal_order=20, Matrix_truncation=40, W_expansion=20),
        "Apdx_IO_n20": _coco_preset(arg_conformal=3, order_interface=20,
                                    Conformal_order=25, Matrix_truncation=50, W_expansion=25),
        "Apdx_IO_n30": _coco_preset(arg_conformal=3, order_interface=30,
                                    Conformal_order=30, Matrix_truncation=50, W_expansion=30),

        # ---- Additional experiments: general background fields (kite shape) ----
        "Apdx_GB_type2": _coco_preset(arg_conformal=3, arg_field=2, S_training=50000),  # H(x, y) = 0.5x + 0.9y
        "Apdx_GB_type3": _coco_preset(arg_conformal=3, arg_field=3, S_training=50000),  # H(x, y) = 0.7x + 0.7y
        "Apdx_GB_type4": _coco_preset(arg_conformal=3, arg_field=4, S_training=50000),  # H(x, y) = 0.25x - y

        # ---- Additional experiments: general (custom) conformal mappings ----
        "Apdx_GC_type1": _coco_preset(
            arg_conformal=4,
            custom_coeffs=[0.15, 0.11, 0.07, 0.03, -0.01, -0.02, -0.03],
            S_training=50000,
        ),
        "Apdx_GC_type2": _coco_preset(
            arg_conformal=4,
            custom_coeffs=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.02, 0.01, -0.005, -0.003, 0.001],
            Conformal_order=20,
            W_expansion=20,
            S_training=50000,
        ),
        "Apdx_GC_type3": _coco_preset(
            order_interface=30,
            arg_conformal=4,
            custom_coeffs=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.001],
            Conformal_order=30,
            Matrix_truncation=50,
            W_expansion=30,
            S_training=50000,
        ),
    }

    # Stability experiments: the baseline presets with only sigma_c varied from 3 to 7.
    for sigma_c in range(3, 8):
        presets["Stability_{}_CoCo".format(sigma_c)] = _coco_preset(sigma_c=sigma_c)
        presets["Stability_{}_classical".format(sigma_c)] = _classical_preset(sigma_c=sigma_c)

    return presets


def arguments() -> argparse.Namespace:
    """Parse the command line, apply the preset named by ``--args_type`` and validate.

    Options are read from ``sys.argv`` with ``parse_known_args``, so unrecognized
    options are ignored rather than rejected. The preset is applied *after*
    parsing: every attribute a preset defines overwrites the value given on the
    command line for that option; attributes it does not define keep their
    parsed (or default) value.

    Returns:
        The parsed namespace, extended with two derived attributes:
        ``sigma`` (``[sigma_m, sigma_c]``) and ``concat_points``
        (``[S_carte, S_boundary, S_int_angular, S_int_radial]``).

    Raises:
        ValueError: if ``--args_type`` is missing or unknown, if the custom
            conformal map (``arg_conformal == 4``) is selected without
            ``custom_coeffs`` or coefficients are given for a non-custom map,
            or if ``Fixed == 0`` while ``num_epochs < 5``.
    """
    parser = _build_parser()
    args, _ = parser.parse_known_args()

    if args.args_type == "None":
        raise ValueError("args_type is not given.")

    presets = _build_presets()
    if args.args_type not in presets:
        raise ValueError("args_type is not given as a known one.")

    # Preset values deliberately take precedence over command-line values.
    for name, value in presets[args.args_type].items():
        setattr(args, name, value)

    args.sigma = [args.sigma_m, args.sigma_c]

    # The custom conformal map (index 4) and custom_coeffs must be given together.
    if args.arg_conformal == 4 and len(args.custom_coeffs) == 0:
        raise ValueError("custom conformal mapping is entered, but no coefficient is given.")
    if args.arg_conformal != 4 and len(args.custom_coeffs) >= 1:
        raise ValueError("It is not a custom conformal mapping, but you entered the custom_coeffs")

    args.concat_points = [args.S_carte, args.S_boundary, args.S_int_angular, args.S_int_radial]

    if args.Fixed == 0 and args.num_epochs < 5:
        raise ValueError("The collocation points are not fixed. More number of epochs are needed. Check the number of S_training also.")

    return args