
        
import os, glob, subprocess, argparse

# Command-line interface for the sweep.
parser = argparse.ArgumentParser(
    description="Launch a sweep of flan-t5 fine-tuning runs via train.py.")
parser.add_argument(
    "--datasets",
    type=str,
    default="glue_cxg",
    # Selection is a plain substring test on this string, not a list parse.
    help=("Which dataset groups to run, matched by substring: 'glue' adds the "
          "GLUE tasks, 'cxg' adds the code tasks, and any value containing "
          "'cola' runs CoLA alone. Default: %(default)s"),
)

args = parser.parse_args()

# Raw --datasets string; consumers test it with `in` (substring match).
datasets = args.datasets
# Per-task epoch settings for the GLUE tasks, in insertion order.
# Each value keeps the nested [[n, n]] shape because callers index it as
# max_epochs[task][0][0]. Presumably the two inner numbers are separate
# epoch settings that currently coincide -- TODO confirm.
# NOTE(review): there are no entries for the code tasks (clone_detection,
# code_to_text, text_to_code, defect_detection).
_EPOCHS_PER_TASK = (
    ('cola', 10),
    ('mnli', 1),
    ('rte', 20),
    ('sst2', 10),
    ('qnli', 3),
    ('qqp', 1),
    ('wnli', 20),
    ('mrpc', 10),
)
max_epochs = {task: [[n, n]] for task, n in _EPOCHS_PER_TASK}



# Sweep: for each flan-t5 size, halve the batch size from 256 down to 1 and,
# at each batch size, run train.py once per selected dataset. Each run's
# combined stdout/stderr goes to
# tencent_logs/<output_model_prefix><batch>_<dataset>.log, and
# clean_steps.py is run after every training run.

# Dataset selection is a substring test on --datasets: "glue" adds the GLUE
# tasks, "cxg" adds the code tasks, and any value containing "cola" replaces
# both with CoLA alone. It does not depend on the model size, so it is
# computed once up front.
datas = []
if 'glue' in datasets:
    datas += ['sst2', 'cola', 'rte', 'qnli', 'mrpc', 'wnli', 'qqp', 'mnli']
if 'cxg' in datasets:
    datas += ['clone_detection', 'code_to_text', 'text_to_code', 'defect_detection']
if 'cola' in datasets:
    datas = ['cola']

# Every run gets the same fixed step budget. Per-task epochs are deliberately
# not looked up in max_epochs: they are not passed to train.py, and
# max_epochs has no entries for the code tasks (a lookup raises KeyError).
train_steps = 9600

for model_size in ['base', 'large']:
    base_model = f'google/flan-t5-{model_size}'
    output_model_prefix = f'my-flan-t5{model_size}/'

    # output_model_prefix already ends with '/', so this is the run's log dir.
    os.makedirs(f'tencent_logs/{output_model_prefix}', exist_ok=True)

    batch = 256
    while batch > 0:
        for data in datas:
            log = f'tencent_logs/{output_model_prefix}{batch}_{data}.log'
            cmd = [
                'accelerate', 'launch', 'train.py',
                f'--base_model={base_model}',
                f'--output_model_prefix={output_model_prefix}',
                f'--dataset={data}',
                f'--steps={train_steps}',
                f'--batch_size={batch}',
            ]
            # Echo an equivalent, POSIX-shell-compatible command line.
            print(' '.join(cmd), '>', log, '2>&1', flush=True)

            # Run without a shell and redirect through a file handle. A shell
            # `&> log` redirect is bash-only: under a POSIX /bin/sh such as
            # dash it parses as `cmd &` followed by `> log`, which backgrounds
            # training, leaves the log empty and starts every run at once.
            with open(log, 'w') as log_file:
                returncode = subprocess.call(
                    cmd, stdout=log_file, stderr=subprocess.STDOUT)
            if returncode != 0:
                # Keep sweeping through the remaining runs, but surface the failure.
                print(f'WARNING: run exited with code {returncode}; see {log}',
                      flush=True)

            subprocess.call(['python', 'clean_steps.py'])
        batch //= 2

# Example of a single manual run:
#
#   CUDA_VISIBLE_DEVICES=0 accelerate launch train.py --base_model=google/flan-t5-base --output_model_prefix=my-flan-t5base/ --dataset=sst2 --cases=24000 --batch_size=256
#
# NOTE(review): this example passes --cases, whereas the sweep above passes
# --steps; confirm which flag train.py expects.
        
    
# datasets = ['clone_detection','code_to_text','text_to_code','defect_detection']
# template = 'accelerate launch train.py --base_model=Qwen/Qwen2.5-3B-Instruct --output_model_prefix=my-qwen3b/ --dataset=sst2 --cases=4800 --batch_size=16 &>tencent_logs/qwenlogs_sst2.log\n'

# with open('cxg.sh','w+') as f:
#     for d in datasets:
#         f.write(template.replace('sst2',d))
        
        
        
        
        
        
        
        
# template = 'accelerate launch train.py --base_model=google/t5-v1_1-large --dataset=sst2 --output_model_prefix=my-t5-large/ '
# template2 = '&>tencent_logs/t5xl_logs_sst2.log\n'

    



# with open('t5_large.sh','w+') as f:
#     for d in datasets:
#         f.write(template.replace('sst2',d)+\
#             f'  --epochs={max_epochs[d][0][0]} --batch_size={max_batches[d][1][1]}  '+\
#                 template2.replace('sst2',d))
