
        
import os, glob, subprocess, argparse

# Command-line options: the Llama-3.1 size to fine-tune (e.g. "8B") and a
# string naming the dataset suite(s) to sweep (see the selection below).
parser = argparse.ArgumentParser()
parser.add_argument("--llama_size", default="8B", type=str)
parser.add_argument("--datasets", default="glue", type=str)
args = parser.parse_args()

llama_size = args.llama_size
datasets = args.datasets

# Derived names handed to train.py: the base checkpoint id (--base_model)
# and the prefix for saved models (--output_model_prefix, also reused as the
# log sub-directory).
base_model = 'meta-llama/Llama-3.1-{}-Instruct'.format(llama_size)
output_model_prefix = 'my-llama{}/'.format(llama_size)


# Expand the --datasets string into the ordered list of task names to run.
# Matching is a substring test, so a value such as "glue,cxg" selects both
# suites (GLUE tasks first, then the code tasks).
datas = (
    (['sst2', 'cola', 'rte', 'qnli', 'mrpc', 'wnli', 'qqp', 'mnli']
     if 'glue' in datasets else [])
    + (['clone_detection', 'code_to_text', 'text_to_code', 'defect_detection']
       if 'cxg' in datasets else [])
)
    
import sys

# Run every selected task at batch sizes 64, 32, 16, ..., 1. Each run's
# combined stdout/stderr goes to tencent_logs/<prefix><batch>_<task>.log,
# and clean_steps.py is run after every training job.
#
# Why no shell: a shell=True string ending in "&> log" is executed by /bin/sh.
# "&>" is a bash-only redirect; POSIX shells such as dash parse it as
# "run in background" followed by "> log", which detaches training, leaves the
# log empty and lets the next job start immediately. An argument list with an
# explicit file redirect behaves the same on every system and also keeps CLI
# values out of a shell command line.
os.makedirs(os.path.join('tencent_logs', output_model_prefix), exist_ok=True)

batch = 64
while batch > 0:
    for data in datas:
        # Only the code-task names (clone_detection, code_to_text, ...) contain
        # '_'; they are run with --cases=4800, the GLUE tasks with 2400.
        cases = 4800 if '_' in data else 2400
        log = f'tencent_logs/{output_model_prefix}{batch}_{data}.log'
        cmd = [
            'accelerate', 'launch', 'train.py',
            f'--base_model={base_model}',
            f'--output_model_prefix={output_model_prefix}',
            f'--dataset={data}',
            f'--cases={cases}',
            f'--batch_size={batch}',
        ]
        # 'wb' truncates the log like the shell's '>'; stderr is merged in.
        with open(log, 'wb') as log_file:
            returncode = subprocess.call(
                cmd, stdout=log_file, stderr=subprocess.STDOUT
            )
        if returncode != 0:
            # Keep sweeping, but make the failure visible instead of silent.
            print(f'[warn] {data} (batch_size={batch}) exited with '
                  f'{returncode}; see {log}', file=sys.stderr)
        # Use the interpreter running this script; a bare "python" may not
        # be on PATH (e.g. systems that only ship "python3").
        subprocess.call([sys.executable, 'clean_steps.py'])
    batch //= 2
    

        
        
    
# datasets = ['clone_detection','code_to_text','text_to_code','defect_detection']
# template = 'accelerate launch train.py --base_model=Qwen/Qwen2.5-3B-Instruct --output_model_prefix=my-qwen3b/ --dataset=sst2 --cases=4800 --batch_size=16 &>tencent_logs/qwenlogs_sst2.log\n'

# with open('cxg.sh','w+') as f:
#     for d in datasets:
#         f.write(template.replace('sst2',d))
        
        
        
        
        
        
        
        
# template = 'accelerate launch train.py --base_model=google/t5-v1_1-large --dataset=sst2 --output_model_prefix=my-t5-large/ '
# template2 = '&>tencent_logs/t5xl_logs_sst2.log\n'

    



# with open('t5_large.sh','w+') as f:
#     for d in datasets:
#         f.write(template.replace('sst2',d)+\
#             f'  --epochs={max_epochs[d][0][0]} --batch_size={max_batches[d][1][1]}  '+\
#                 template2.replace('sst2',d))
