import argparse
import math

parser = argparse.ArgumentParser()

################ Transformer hyper-parameter ################
parser.add_argument('--nhid', type=int, default=64)
parser.add_argument('--nlayers', type=int, default=3)
parser.add_argument('--nips', type=int, default=16)
parser.add_argument('--nhead', type=int, default=1)
parser.add_argument('--activation', type=str, default='gelu')
parser.add_argument('--num_freqs', type=int, default=5)
parser.add_argument('--AGF_depth', type=int, default=5)
parser.add_argument('--alpha', type=float, default=2.0)
parser.add_argument('--beta', type=float, default=-1.0)
parser.add_argument('--fixI', type=int, default=1)      # 1: True, 0: False
#####################################################################

################ Newly added for DeepOnet ################
parser.add_argument('--branch_depth', type=int, default=2)
parser.add_argument('--trunk_depth', type=int, default=3)
parser.add_argument('--width', type=int, default=50)
#####################################################################


################ Newly added for DeepOnet ################
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--model', type=str, default='PDE-PFN')
parser.add_argument('--data', type=str, default='2D_cdr')
parser.add_argument('--is_grid', type=int, default=0)
# 1: True, 0: False
parser.add_argument('--system', type=str, default='2D_cdr')
# Seen coeff / Inter coeff / Extra coeff / Seen coeff given noisy for famliy of 2D cdr equations
# Seen dataset / Unseen dataset for SWE
# Unseen dataset for CFD
# Initial value
parser.add_argument('--extrapolation', type=int, default=0)
# 1: Temporal extrapolation(operator learning) / 0: Temporal interpolation
parser.add_argument('--numerical', type=int, default=1)
# 1: train with numeric prior in 2D cdr / 0: train with PINN prior in 2D cdr
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--lr_decay', type=float, default=0.95)
parser.add_argument('--step', type=int, default=5)
parser.add_argument('--scheduler', type=str, default='constant')
# constant / OneCycleLR / exponential / linear
parser.add_argument('--total', type=int, default=0)
# 1: True, 0: False
parser.add_argument('--dropout', type=float, default=0.2)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--batch', type=int, default=1)

##### Data generation information #####
parser.add_argument('--N_train_shot', type=int, default=0)
parser.add_argument('--N_valid', type=int, default=0)
parser.add_argument('--N_test', type=int, default=0) 
parser.add_argument('--spatial_size', type=str, default="2pi")
parser.add_argument('--xgrid', type=int, default=0)
parser.add_argument('--nt', type=int, default=101)
parser.add_argument('--max_time', type=float, default=1.0)

##### PINN hyperparameter #####
parser.add_argument('--PINN_epoch', type=int, default=20000)
parser.add_argument('--PINN_batch', type=int, default=1000)
parser.add_argument('--threshold', type=float, default = 1e-4)
parser.add_argument('--PINN_based_prior_ratio', type=int, default = 0)

##### For 2D cdr #####
parser.add_argument('--beta_min', type=float, default = 0.0)
parser.add_argument('--beta_y_min', type=float, default = 0.0)
parser.add_argument('--nu_min', type=float, default = 0.0)
parser.add_argument('--nu_y_min', type=float, default = 0.0)
parser.add_argument('--rho_min', type=float, default = 0.0)
parser.add_argument('--epsilon_min', type=float, default = 0.0)
parser.add_argument('--theta_min', type=float, default = 0.0)

parser.add_argument('--beta_max', type=float, default = 0.0)
parser.add_argument('--beta_y_max', type=float, default = 0.0)
parser.add_argument('--nu_max', type=float, default = 0.0)
parser.add_argument('--nu_y_max', type=float, default = 0.0)
parser.add_argument('--rho_max', type=float, default = 0.0)
parser.add_argument('--epsilon_max', type=float, default = 0.0)
parser.add_argument('--theta_max', type=float, default = 0.0)

parser.add_argument('--beta_step', type=float, default=1.)
parser.add_argument('--nu_step', type=float, default=1.)
parser.add_argument('--rho_step', type=float, default=1.)
parser.add_argument('--epsilon_step', type=float, default=1.)
parser.add_argument('--theta_step', type=float, default=1.)

def get_config():
    return parser.parse_args()