import os
from argparse import ArgumentParser
def _parse_bool(value):
    """Convert a command-line string to a bool.

    ``type=bool`` does not work with argparse: ``bool()`` maps every non-empty
    string, including ``'False'``, to True. This converter accepts the usual
    spellings (case-insensitive) and reports anything else as a CLI error.

    Args:
        value: Raw string passed on the command line.

    Returns:
        True for 'true'/'t'/'yes'/'y'/'1', False for 'false'/'f'/'no'/'n'/'0'.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    from argparse import ArgumentTypeError

    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ArgumentTypeError(f'expected a boolean (true/false), got {value!r}')


# Command-line interface. config_path, context_length, start_from,
# percent_name and train_token_number feed the generated config below;
# --long_context is parsed but not read elsewhere in this file as shown.
parser = ArgumentParser()
parser.add_argument('--config_path', type=str, default='/scratch/project_465000903/NLPContextScaling/nanoGPT/config/',
                    help='Directory the generated config file is written to.')
parser.add_argument('--context_length', type=int, default=384,
                    help='block_size of the generated config; must have an entry in BATCH_SIZE_DICT.')
parser.add_argument('--start_from', type=str, default='0.0',
                    help='SFrom tag embedded in the training-data file name.')
parser.add_argument('--percent_name', type=str, default='100p0',
                    help="Data-percentage tag such as '10p0'; must be a key of percent_name_to_str_dict.")
# type=_parse_bool instead of type=bool so that '--long_context False' really is False.
parser.add_argument('--long_context', type=_parse_bool, default=True,
                    help='Boolean flag; accepts true/false, yes/no, 1/0 (case-insensitive).')

parser.add_argument('--train_token_number', type=int, default=500000000,
                    help='Total number of training tokens; max_iters is derived from it.')
args = parser.parse_args()
CONTEXT_LENGTH = args.context_length
PERCENT_NAME = args.percent_name
START_FROM = args.start_from

TOKEN_NUMBER = args.train_token_number

# Training budget in millions of tokens; used in the wandb run name and in the
# generated config's file name.
TOKEN_NUMBER_M_unit = TOKEN_NUMBER//1000000

# Maps the --percent_name tag to the percentage string embedded in the
# training-data file name ('..._train_P<percent>_SFrom<start>.bin').
percent_name_to_str_dict = {
    '0p1':'0.1', # 9B * 0.1% = 9M tokens, about 92 optimizer steps per epoch at ~98K tokens/step.
    '0p5':'0.5',
    '1p0':'1.0',
    '2p0':'2.0',
    '4p0':'4.0',
    '6p0':'6.0',
    '10p0':'10.0',
    '100p0':'100.0',
}

# Micro-batch size per context length. For context lengths up to 4096,
# batch size * context length == 12288 tokens per micro-batch.
BATCH_SIZE_DICT = {
    10927:1,
    8192:1,
    6144:1,
    4096:3,
    2048:6,
    1024:12,
    512:24,
}

# Gradient-accumulation multiplier per context length. The config template
# writes int(GRADIENT_ACC * 8) as gradient_accumulation_steps, so together with
# BATCH_SIZE_DICT every entry yields about 98304 (~0.1M) tokens per optimizer step.
GRADIENT_ACC_DICT = {
    10927:9/8,
    8192:3/2,
    6144:2,
    4096:1,
    2048:1,
    1024:1,
    512:1
}

# Fail fast with a readable CLI error instead of a bare KeyError here or inside
# the config template.
# NOTE(review): the --context_length default (384) has no entry in the tables
# above, so running without --context_length stops here. An entry of batch
# size 32 / multiplier 1 would keep batch * context == 12288 like the other
# entries up to 4096 -- confirm before adding.
_supported_context_lengths = sorted(BATCH_SIZE_DICT.keys() & GRADIENT_ACC_DICT.keys())
if CONTEXT_LENGTH not in _supported_context_lengths:
    parser.error(
        f'--context_length {CONTEXT_LENGTH} is not supported; '
        f'choose one of {_supported_context_lengths}'
    )
if PERCENT_NAME not in percent_name_to_str_dict:
    parser.error(
        f'--percent_name {PERCENT_NAME!r} is not supported; '
        f'choose one of {sorted(percent_name_to_str_dict)}'
    )

BATCH_SIZE = BATCH_SIZE_DICT[CONTEXT_LENGTH]
GRADIENT_ACC = GRADIENT_ACC_DICT[CONTEXT_LENGTH]

# eval_iters for the generated config: chosen so that
# EVAL_ITERS * BATCH_SIZE * CONTEXT_LENGTH is about 4 * 7864320 (~31.5M) tokens.
# Floor division: the ratio is exact for every table entry except 10927, which
# is why a divisibility assert cannot be enforced here.
EVAL_ITERS = 4*7864320 // (BATCH_SIZE * CONTEXT_LENGTH)
# Text of the generated training config. Everything inside the r''' ... '''
# pieces is written verbatim to the config file; values computed above are
# spliced in by string concatenation:
#   - wandb_run_name embeds TOKEN_NUMBER_M_unit and CONTEXT_LENGTH;
#   - batch_size / block_size come from BATCH_SIZE and CONTEXT_LENGTH;
#   - gradient_accumulation_steps is int(GRADIENT_ACC * 8);
#   - max_iters / lr_decay_iters are written as Python expressions
#     (TOKEN_NUMBER // (batch_size * block_size * gradient_accumulation_steps)),
#     presumably evaluated when train.py loads the config -- confirm;
#   - eval_iters is EVAL_ITERS;
#   - train_file_name embeds the percent string and START_FROM;
#   - out_dir embeds TOKEN_NUMBER and CONTEXT_LENGTH.
# The last piece, r''''<newline>''', contributes the closing quote of out_dir
# followed by a trailing newline.
# NOTE(review): the wandb_project literal is split into three adjacent pieces
# with nothing spliced between them, producing
# 'NLP_Long_contextScaling_033101verpct__ctl_fp32_6layer' ('verpct__ctl');
# this looks like removed placeholders -- confirm the intended project name.
# NOTE(review): out_dir includes neither PERCENT_NAME nor START_FROM, so runs
# that differ only in those arguments share one checkpoint directory.
# NOTE(review): some comments inside the template are stale. For example, the
# eval_iters comment says 7.8M tokens, but EVAL_ITERS is sized for about
# 4 * 7864320 (~31.5M) tokens.
file_content = r'''
# config for training GPT-2 (124M) down to very nice loss of ~2.85 on 1 node of 8X A100 40GB
# launch as the following (e.g. in a screen session) and wait ~5 days:
# $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py

wandb_log = True
n_layer = 6
wandb_project = 'NLP_Long_contextScaling_033101ver'''+r'''pct_'''+r'''_ctl_fp32_6layer'
wandb_run_name='gpt6layer_'''+str(TOKEN_NUMBER_M_unit)+r'''Mtokens_'''+str(CONTEXT_LENGTH)+r'''ctl'

# these make the total batch size be ~0.1M
# batch_size * block_size * grad_acc_step = 24576
batch_size = '''+ str(BATCH_SIZE) + r'''
# block_size =  # context_length
block_size = ''' + str(CONTEXT_LENGTH) + r'''
gradient_accumulation_steps = ''' + str(int(GRADIENT_ACC*8)) + r'''

# this makes total number of tokens be TOKEN_NUMBER
max_iters = ''' + str(TOKEN_NUMBER) + r''' // (batch_size * block_size * gradient_accumulation_steps)
lr_decay_iters = ''' + str(TOKEN_NUMBER) + r''' // (batch_size * block_size * gradient_accumulation_steps)
warmup_iters = 1000


# eval stuff, eval_interval is how often to run eval, currently set to 100 * 0.5M = 50M tokens.
eval_interval = (max_iters+4)//3 # we would evaluate at the last epoch in train.py
eval_iters = ''' + str(EVAL_ITERS) + r''' # this * batch_size * block_size = 7.8M tokens
log_interval = 10

# weight decay
weight_decay = 1e-1

# 9B * 0.5% = 45M tokens
train_file_name = 'Long1000CtlSubset_0330_train_P'''+percent_name_to_str_dict[PERCENT_NAME]+r'''_SFrom'''+START_FROM+r'''.bin'

out_dir = 'ckpts/Long1000CtlSubset_0330_train_P'''+ str(TOKEN_NUMBER)+r'''ctl'''+str(CONTEXT_LENGTH)+r''''
'''

# Persist the generated config. The file name encodes the context length and
# the training-token budget in millions.
# NOTE(review): the name always says '100pct', whatever --percent_name is, and
# omits START_FROM, so such runs overwrite each other's config file.
config_file_name = f'LONGCONTEXT_train_gpt2_033101ver{CONTEXT_LENGTH}ctl_100pct_{TOKEN_NUMBER_M_unit}M_small.py'
config_file_path = os.path.join(args.config_path, config_file_name)

with open(config_file_path, 'w') as config_file:
    config_file.write(file_content)
    print(f'File written to {config_file_path}')
