import argparse
import os
import subprocess
import sys


# Command-line interface for one invocation of this launcher.
parser = argparse.ArgumentParser()

# Shared default initialisation std for every per-component *_std flag below.
std_rate = 1

parser.add_argument('--optim_multiplier', type=float, default=30)
parser.add_argument('--gpu_id', type=int, default=7)
parser.add_argument('--train_method', type=str, default='NTP')
parser.add_argument('--nh', type=int, default=1)
parser.add_argument('--nl', type=int, default=2)

# One std flag per parameter group, registered in the order
# embedding, qk, vo, mlp.
# NOTE: argparse applies `type` only to *string* defaults, so the non-string
# defaults here (30 and 1) stay ints and format as "30"/"1" in f-strings,
# whereas an explicit CLI value such as "--qk_std 1" becomes the float 1.0.
for _group in ('embedding', 'qk', 'vo', 'mlp'):
    parser.add_argument(f'--{_group}_std', type=float, default=std_rate)
del _group

args = parser.parse_args()

# Run-level training configuration.
# Values supplied on the command line, copied into plain module-level names.
optim_multiplier = args.optim_multiplier
train_method = args.train_method
gpu_id = args.gpu_id

# Fixed settings for this sweep.
target = 'composition'
lr = 1e-5
batch_size = 2048
data_size = 900000

# Previously also used: 'StepLR'.
scheduler = 'GradualWarmupScheduler_CosineAnnealingLR'
# Previously also used: 'GPT2_init_for_diff_part_prenorm', 'GPT_sandwitchLN', 'GPT'.
model = 'GPT2_init_for_diff_part'

# ---------------------------------------------------------------------------
# Dataset mixture.
# Every dataset name is "<pair>_<suffix>" with suffix "xm0" or "xel".  The
# original note on the suffixes (partly lost to an encoding problem) read
# "xm0: x mod seq-1 = 0, xel: else"; presumably *_xm0 holds samples with
# x mod (seq-1) == 0 and *_xel the rest -- TODO confirm in the data generator.
# ---------------------------------------------------------------------------
_dataset_pairs = ['13', '23', '43', '31', '32', '34'] + ['12', '14', '21', '41', '24', '42'] + ['11', '22', '33', '44']
dname = [f'{pair}_xm0' for pair in _dataset_pairs] + [f'{pair}_xel' for pair in _dataset_pairs]
# The mode list is identical to the name list.  Derive it rather than repeating
# the 32-entry literal so the two can never drift apart.
dmode = list(dname)

# Previous training mask (36 entries, written for a longer dataset list):
# dtrain = [0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 0, 0] + [0, 0, 0, 0] \
#        + [1, 1, 1, 1, 1, 1] + [1, 1, 1, 1, 1, 1] + [1, 1, 1, 1] \
#        + [1, 0, 0, 1] + [1, 0, 0, 1]

# Per-dataset training flag, positional w.r.t. dname (presumably 1 = trained
# on).  All *_xm0 sets are 0; among the *_xel sets only 43_xel and 34_xel are 0.
# (Original note, partly lost to encoding: "1, 3, 4 anchor, 2".)
dtrain = [0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 0, 0] + [0, 0, 0, 0] \
       + [1, 1, 0, 1, 1, 0] + [1, 1, 1, 1, 1, 1] + [1, 1, 1, 1]

# !!! Double-check dtrain/dshow before launching a sweep.

# Per-dataset flag passed to main as -dshow; here only 43_xel is set.
dshow = [0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 0, 0] + [0, 0, 0, 0] \
        + [0, 0, 1, 0, 0, 0] + [0, 0, 0, 0, 0, 0] + [0, 0, 0, 0]
# Per-dataset value passed to main as -dp: 1 for every *_xm0 set and 9 for
# every *_xel set.  Presumably relative sampling weights -- TODO confirm in main.
dpercent = [1, 1, 1, 1, 1, 1] + [1, 1, 1, 1, 1, 1] + [1, 1, 1, 1] \
        + [9, 9, 9, 9, 9, 9] + [9, 9, 9, 9, 9, 9] + [9, 9, 9, 9]

# All per-dataset lists are positional (entry i describes dname[i]).  A length
# mismatch would hand main misaligned token counts, so fail loudly here.
_per_dataset_lists = {'dmode': dmode, 'dtrain': dtrain, 'dshow': dshow, 'dpercent': dpercent}
_bad_lengths = {key: len(values) for key, values in _per_dataset_lists.items() if len(values) != len(dname)}
if _bad_lengths:
    raise ValueError(
        f'every per-dataset list needs {len(dname)} entries (one per name in dname); got {_bad_lengths}'
    )

# Space-joined forms, interpolated into the main command line.
dn = ' '.join(map(str, dname))
dp = ' '.join(map(str, dpercent))
dmode = ' '.join(map(str, dmode))
dtrain = ' '.join(map(str, dtrain))
dshow = ' '.join(map(str, dshow))

# Feed-forward width and how many branches it is divided across.
feedforward = 1200
branch_num = 50
branch_d_feedforward = feedforward // branch_num  # width per branch (24 here)

# Curvature statistic requested from main.
# Previously also used: 'max_eigenvalue', 'precondition_max_eigenvalue'.
calculate_hessian = 'grad_eigenvalue'

# None = train from scratch; set to a path to pass it via --load_checkpoint.
checkpoint = None


# Depth (L) and head count (H) as given on the command line.
L, H = args.nl, args.nh

# Project name handed to main via -pname.
# Earlier names used for this study:
#   f'diff_init_part_{train_method}_{target}_{model}_14'
#   'last_token_model_scale_data_90w_normal_init_wo_43_34'
#   'hessian_spectrum_iter'
#   'condense_net_normal'
#   'test'
proj_name = f'{target}_spike_{model}_{calculate_hessian}_0420'

# Free-text note for this sweep (not interpolated into the launch command).
description = '420，base model'

def _main_argv(seed, L, H, WD, num_epochs, dir_suffix, suffix, load_checkpoint=None):
    """Build the argv for one `python3 -m main` training run.

    Run-invariant settings are read from the module-level configuration; the
    parameters carry what varies per run.  The multi-valued dataset options
    (dmode, dp, dn, dtrain, dshow) are space-joined strings at this point and
    are split back into one token per dataset -- the same tokens the shell
    produced when the command was passed as a single string.
    """
    argv = [
        'python3', '-m', 'main',
        '-data_size', str(data_size), '-seed', str(seed), '-func', target,
        '-lr', str(lr), '-m', model, '-scheduler', scheduler,
        '-ne', str(num_epochs), '-nl', str(L), '-nh', str(H), '-bs', str(batch_size),
        '-dir_suffix', dir_suffix, '-pname', proj_name, '-dk', '200', '-dv', '200',
        '--d_feedforward', str(feedforward), '--branch_num', str(branch_num),
        '--branch_d_feedforward', str(branch_d_feedforward),
        '-dmode', *dmode.split(), '-dp', *dp.split(), '-dn', *dn.split(),
        '-dtrain', *dtrain.split(), '-dshow', *dshow.split(),
        '-suffix', suffix, '--train_method', args.train_method, '-dm', '400',
        '-ple', '1', '-pae', '1', '-plae', '1', '-sme', '1',
        # -wd was previously passed twice with the same value; once is enough.
        '-wd', str(WD),
        '-embedding_std', str(args.embedding_std), '-qk_std', str(args.qk_std),
        '-vo_std', str(args.vo_std), '-mlp_std', str(args.mlp_std),
        '--optim_T_max', '100', '--optim_eta_min', '1e-5',
        '--optim_multiplier', str(optim_multiplier), '--optim_total_epoch', '10',
        '--calculate_hessian', calculate_hessian,
    ]
    if load_checkpoint is not None:
        argv += ['--load_checkpoint', str(load_checkpoint)]
    return argv


for seed in [1]:  # previously [1, 2, 3, 4, 5]
    # NOTE(review): depth and head count are hard-coded below, so the --nl/--nh
    # CLI flags have no effect on launched runs.  Use `for L in [args.nl]` and
    # `H = args.nh` if the command-line values should be honoured.
    for L in [8]:
        for WD in [0.1]:
            H = 1
            dir_suffix = (
                f'{L}L{H}H_embedding_std_{args.embedding_std}_qk_std_{args.qk_std}'
                f'_vo_std_{args.vo_std}_mlp_std_{args.mlp_std}_om{optim_multiplier}'
                f'_WD{WD}_bs{batch_size}_{args.train_method}_{lr}_dm400_om{optim_multiplier}'
            )
            suffix = f'seed{seed}'
            # Runs that load a checkpoint use -ne 10 instead of 100.
            num_epochs = 100 if checkpoint is None else 10
            argv = _main_argv(seed, L, H, WD, num_epochs, dir_suffix, suffix, load_checkpoint=checkpoint)
            # Pin the GPU through the child's environment instead of a shell prefix.
            env = {**os.environ, 'CUDA_VISIBLE_DEVICES': str(gpu_id)}
            result = subprocess.run(argv, env=env, check=False)
            # Keep going with the sweep, but don't let a failed run pass silently
            # (os.system's return code used to be discarded).
            if result.returncode != 0:
                print(
                    f'[launcher] run {dir_suffix}/{suffix} exited with code {result.returncode}',
                    file=sys.stderr,
                )






