import argparse
import os
import subprocess
import sys


# Shared default for the four *_std flags below.
std_rate = 0.5

parser = argparse.ArgumentParser(
    description='Launch a small sweep of `python3 -m main` training runs.')
parser.add_argument('--optim_multiplier', type=float, default=10,
                    help='forwarded to main as --optim_multiplier')
parser.add_argument('--gpu_id', type=int, default=4,
                    help='exported as CUDA_VISIBLE_DEVICES for every launched run')
parser.add_argument('--train_method', type=str, default='LTP',
                    help='forwarded to main as --train_method; also part of -dir_suffix')
parser.add_argument('--nh', type=int, default=1,
                    help='head count (-nh); NOTE: currently overridden by the sweep loop')
parser.add_argument('--nl', type=int, default=2,
                    help='layer count (-nl); NOTE: currently overridden by the sweep loop')
parser.add_argument('--embedding_std', type=float, default=std_rate,
                    help='forwarded to main as -embedding_std')
parser.add_argument('--qk_std', type=float, default=std_rate,
                    help='forwarded to main as -qk_std')
parser.add_argument('--vo_std', type=float, default=std_rate,
                    help='forwarded to main as -vo_std')
parser.add_argument('--mlp_std', type=float, default=std_rate,
                    help='forwarded to main as -mlp_std')
args = parser.parse_args()

# --- Settings taken from the command line ------------------------------------
# (the *_std values are read straight from `args` where they are used)
optim_multiplier = args.optim_multiplier
train_method = args.train_method
gpu_id = args.gpu_id

# --- Task, optimiser and model -----------------------------------------------
target = '3x_to_x'
lr = 2e-5                                                # earlier runs used 0.003
batch_size = 1800
scheduler = 'GradualWarmupScheduler_CosineAnnealingLR'   # alternatives: None, 'StepLR'
model = 'GPT2_init_for_diff_part_prenorm'               # alternatives: 'GPT_sandwitchLN', 'GPT'
data_size = 2000


# --- Dataset splits ----------------------------------------------------------
# NOTE(review): the original comment here was garbled ("xm0x mod seq_len = 0,
# xelx else"); it appears to describe a dataset naming rule -- confirm with
# the author.
def _space_join(values):
    """Return *values* rendered as a single space-separated string."""
    return ' '.join(str(v) for v in values)


# Parallel per-split lists, one entry per split named in `dname`.
dname = ['train', 'test']
dpercent = [9, 1]   # split weights; presumably 90% / 10% of the data -- confirm in main
dn = _space_join(dname)
dp = _space_join(dpercent)
# The remaining per-split settings are kept only in their joined string form.
dmode = _space_join(['train', 'test'])
doutlier = _space_join([0, 0])   # NOTE(review): built but never forwarded to main
dtrain = _space_join([1, 0])
dshow = _space_join([1, 1])

# --- Feed-forward width ------------------------------------------------------
feedforward = 1200
branch_num = 50
branch_d_feedforward = feedforward // branch_num   # per-branch share of `feedforward`

# --- Diagnostics / resumption ------------------------------------------------
# alternatives: 'precondition_max_eigenvalue', 'None'
calculate_hessian = 'grad_eigenvalue'
# None trains from scratch; otherwise forwarded as --load_checkpoint
# (presumably a checkpoint path).
checkpoint = None


# Depth and head count requested on the command line.
# NOTE(review): the sweep loop below rebinds both L and H, so --nl / --nh do
# not currently reach the launched runs.
L = args.nl
H = args.nh

# Project name forwarded to main via -pname.
# (Earlier experiments used names such as 'diff_init_part_{train_method}_{target}_{model}_14',
#  'hessian_spectrum_iter', 'condense_net_normal' and 'test'.)
proj_name = '{}_spike_{}_{}_seed1_0428'.format(target, model, calculate_hessian)

# Free-text run description; not referenced anywhere below.
description = 'cosine，30000epoch，spike，12L12H，spike，grad_eigenvalue，12L12H，weight—decay'

# --- Sweep -------------------------------------------------------------------
# Each (seed, L, WD) combination launches one blocking `python3 -m main` run
# with CUDA_VISIBLE_DEVICES set to --gpu_id. The command is built as an argv
# list and executed without a shell, so values are passed through verbatim
# instead of being re-tokenized by /bin/sh.


def _main_argv(seed, n_layers, n_heads, weight_decay, dir_suffix, suffix,
               ple_pae_plae, optim_T_max):
    """Return the argv list for one `python3 -m main` training run.

    Everything not passed in is read from the module-level settings above.
    ``ple_pae_plae`` is used as the value of each of -ple, -pae and -plae.
    ``-wd`` is emitted once (the old shell string repeated it with the same
    value).
    """
    return [
        'python3', '-m', 'main',
        '-data_size', str(data_size),
        '-seed', str(seed),
        '-func', str(target),
        '-lr', str(lr),
        '-m', str(model),
        '-scheduler', str(scheduler),
        '-ne', '30000',
        '-nl', str(n_layers),
        '-nh', str(n_heads),
        '-bs', str(batch_size),
        '-dir_suffix', dir_suffix,
        '-pname', str(proj_name),
        '-dk', '64',
        '-dv', '64',
        '--d_feedforward', str(feedforward),
        '--branch_num', str(branch_num),
        '--branch_d_feedforward', str(branch_d_feedforward),
        # The per-split settings are space-joined strings; each item becomes
        # its own argument, exactly as the shell used to split them.
        '-dmode', *dmode.split(),
        '-dp', *dp.split(),
        '-dn', *dn.split(),
        '-dtrain', *dtrain.split(),
        '-dshow', *dshow.split(),
        '-suffix', suffix,
        '--train_method', str(args.train_method),
        '-dm', '200',
        '-ple', str(ple_pae_plae),
        '-pae', str(ple_pae_plae),
        '-plae', str(ple_pae_plae),
        '-sme', '1000',
        '-wd', str(weight_decay),
        '-embedding_std', str(args.embedding_std),
        '-qk_std', str(args.qk_std),
        '-vo_std', str(args.vo_std),
        '-mlp_std', str(args.mlp_std),
        '--optim_T_max', str(optim_T_max),
        '--optim_eta_min', '1e-5',
        '--optim_multiplier', str(optim_multiplier),
        '--optim_total_epoch', '10',
        '--calculate_hessian', str(calculate_hessian),
    ]


# Child environment: the current environment plus the GPU selection.
_run_env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu_id))

# for seed in [1, 2, 3, 4, 5]:
for seed in [1]:
    # for L in [2]:
    for L in [12]:
        for WD in [0.5]:
            # NOTE(review): L and H are fixed here, overriding --nl / --nh.
            H = 12
            run_prefix = (
                f'{L}L{H}H_embedding_std_{args.embedding_std}'
                f'_qk_std_{args.qk_std}_vo_std_{args.vo_std}'
                f'_mlp_std_{args.mlp_std}_om{optim_multiplier}'
                f'_WD{WD}_bs{batch_size}'
            )
            suffix = f'seed{seed}'
            if checkpoint is None:
                # Fresh run. NOTE(review): the double underscore before the
                # train method differs from the resume name below; it is kept
                # so existing output directory names stay stable.
                dir_suffix = f'{run_prefix}__{args.train_method}'
                argv = _main_argv(seed, L, H, WD, dir_suffix, suffix,
                                  ple_pae_plae=1000, optim_T_max=30000)
            else:
                # Resume from `checkpoint`.
                dir_suffix = f'{run_prefix}_{args.train_method}_continue'
                argv = _main_argv(seed, L, H, WD, dir_suffix, suffix,
                                  ple_pae_plae=1, optim_T_max=200)
                argv += ['--load_checkpoint', str(checkpoint)]
            completed = subprocess.run(argv, env=_run_env)
            if completed.returncode != 0:
                # Continue with the rest of the sweep, but do not fail silently.
                print(f'warning: run {dir_suffix} ({suffix}) exited with '
                      f'status {completed.returncode}', file=sys.stderr)