import os
import subprocess
import argparse
import time
import concurrent.futures




# Names of the fine-tuning runs to process. The submission loop at the bottom
# of this file expands each name to /data/b_ou/ckpts/output_<name>.
checkpoints = [
    '14k_cllama_code_m2w_scrp_clueweb_wiki_7000_hist_stop_8e',
]



def create_ppl_script(checkpoint, i,
                      script_dir='/data/b_ou/agent/fine_tune/pred_batch_record'):
    """Write the SLURM batch script that runs ``pred_batch.py`` for one task.

    The script is written to ``<script_dir>/tmp_<i>.sh``; nothing is submitted
    here.

    Args:
        checkpoint: Run name. Used in the SLURM job/log names and to build the
            ``/data/b_ou/ckpts/output_<checkpoint>`` folder passed to
            ``pred_batch.py`` as ``--checkpoints-folder``.
        i: Index into the task list below (used for the job/log names) and
            the value forwarded to ``pred_batch.py`` as ``--index``.
        script_dir: Directory that receives the generated script. It is
            created if it does not exist. The default is the location the
            submission loop in this file reads from.

    Raises:
        IndexError: If ``i`` is outside the task list.
    """
    ppl_tasks = ['m2w_task', 'm2w_domain', 'm2w_website', 'wiki_test', 'clueweb_test']
    # Resolve the task name first so a bad index fails before any file I/O.
    task = ppl_tasks[i]
    # The template is a regular (non-raw) string, so each backslash-newline
    # pair below is consumed by Python: the python command ends up on a
    # single line in the generated file.
    ppl_string = f'''#!/bin/bash
#SBATCH --job-name=ppl_{checkpoint}_{task}
#SBATCH --output=ppl_{checkpoint}_{task}.out
#SBATCH --error=ppl_{checkpoint}_{task}.err


#SBATCH --partition=compute
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:A100:1
#SBATCH --time=8:00:00
#SBATCH --mail-type=ALL
#SBATCH --mail-user=bo@andrew.cmu.edu

source ~/.bashrc
conda activate agent_train

python pred_batch.py \
--checkpoints-folder /data/b_ou/ckpts/output_{checkpoint} \
--index {i}
    '''
    # Create the target directory on demand instead of failing with
    # FileNotFoundError on a fresh machine.
    os.makedirs(script_dir, exist_ok=True)
    with open(os.path.join(script_dir, f'tmp_{i}.sh'), 'w') as file:
        file.write(ppl_string)


def create_eval_script(eval_type, eval_file_name, checkpoints_folder, checkpoint, i,
                       script_dir='/data/b_ou/agent/fine_tune/validation'):
    """Write the SLURM batch script that runs prediction for one checkpoint.

    The script launches ``train_bash.py`` through ``accelerate launch`` with
    ``--do_predict`` and is written to ``<script_dir>/tmp_<eval_type>.sh``;
    nothing is submitted here.

    Args:
        eval_type: Label of the evaluation split. Used in the job/log names,
            the ``pred_<eval_type>`` output directory and the script file name.
        eval_file_name: Value passed to ``--dataset``.
        checkpoints_folder: Folder that contains the checkpoint directories.
            Its last path component is used as the model name in the job/log
            names.
        checkpoint: Name of the checkpoint directory inside
            ``checkpoints_folder``; passed as ``--checkpoint_dir``.
        i: Integer offset added to the base port 29510 to form
            ``--main_process_port``.
        script_dir: Directory that receives the generated script. It is
            created if it does not exist. The default is the location
            ``run_script`` submits from.
    """
    # Strip trailing slashes first: a plain split('/')[-1] returns '' for
    # 'path/to/run/', which would leave the job and log names without a model.
    model_name = checkpoints_folder.rstrip('/').split('/')[-1]
    # Add the offset numerically. Concatenating the digits ("2951" + str(i))
    # gives the same result for 0..9 but an out-of-range port from i = 10 on.
    main_process_port = 29510 + int(i)
    # NOTE(review): --output_dir below does not include the checkpoint name,
    # so scripts generated for different checkpoints of the same folder and
    # eval_type point at the same directory -- confirm this is intended.
    # The template is a regular (non-raw) string, so each backslash-newline
    # pair is consumed by Python: the accelerate command ends up on a single
    # line in the generated file.
    eval_string = f'''#!/bin/bash
#SBATCH --job-name=pred_{model_name}_{checkpoint}_{eval_type}
#SBATCH --output=pred_{model_name}_{checkpoint}_{eval_type}.out
#SBATCH --error=pred_{model_name}_{checkpoint}_{eval_type}.err
#SBATCH --partition=compute
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:A100:1
#SBATCH --time=16:00:00
#SBATCH --mail-type=ALL
#SBATCH --mail-user=bo@andrew.cmu.edu

source ~/.bashrc
conda activate agent_train
accelerate launch --main_process_port {main_process_port} --num_processes 1 --num_machines 1 --mixed_precision "no" /data/b_ou/agent-model/LLaMA-Factory/src/train_bash.py \
--stage sft \
--model_name_or_path /data/b_ou/ckpts/cllama/models--codellama--CodeLlama-7b-hf/snapshots/bc5283229e2fe411552f55c71657e97edf79066c/ \
--do_predict \
--dataset {eval_file_name} \
--dataset_dir /data/b_ou/agent/data/code_scrp/ \
--template llama2 \
--finetuning_type full \
--checkpoint_dir {checkpoints_folder}/{checkpoint} \
--output_dir {checkpoints_folder}/pred_{eval_type} \
--per_device_eval_batch_size 1 \
--predict_with_generate True \
--use_safetensors True \
--cutoff_len 4096 \
--do_sample False \
--bf16 True

    '''
    # Create the target directory on demand instead of failing with
    # FileNotFoundError on a fresh machine.
    os.makedirs(script_dir, exist_ok=True)
    with open(os.path.join(script_dir, f'tmp_{eval_type}.sh'), 'w') as file:
        file.write(eval_string)

def get_checkpoint_dirs(checkpoints_folder):
    """Return the names of the checkpoint sub-directories of a folder.

    The folder path is printed first. An entry is kept when its name contains
    the substring ``checkpoint`` and it is a directory; plain files are
    skipped. Names are returned in ``os.listdir`` order, without the folder
    prefix.
    """
    print(checkpoints_folder)
    return [
        entry
        for entry in os.listdir(checkpoints_folder)
        if 'checkpoint' in entry
        and os.path.isdir(os.path.join(checkpoints_folder, entry))
    ]

def run_script(i, checkpoints_folder, checkpoint, eval_types, eval_file_names):
    """Generate the evaluation script for split ``i`` and submit it via ``sbatch``.

    ``eval_types[i]`` and ``eval_file_names[i]`` select the split label and
    the dataset name; ``i`` itself is also forwarded to ``create_eval_script``.
    The submitted file is the one ``create_eval_script`` writes for that split
    label. The ``sbatch`` process receives a copy of the current environment,
    and its exit status is not inspected.
    """
    split_label = eval_types[i]
    dataset_name = eval_file_names[i]
    submit_env = dict(os.environ)

    create_eval_script(split_label, dataset_name, checkpoints_folder, checkpoint, i)

    script_path = f'/data/b_ou/agent/fine_tune/validation/tmp_{split_label}.sh'
    subprocess.run(['sbatch', script_path], env=submit_env)

# Evaluation splits and the dataset name passed to --dataset for each one.
# The two lists are index-aligned.
eval_types = ['task', 'website', 'domain']
eval_file_names = ['m2w_code_scrape_test_task_hist', 'm2w_code_scrape_test_website_hist', 'm2w_code_scrape_test_domain_hist']


# For every run: submit one perplexity job per split index, then one
# prediction job per (split, checkpoint directory) pair.
for checkpoint in checkpoints:
    checkpoints_folder = f'/data/b_ou/ckpts/output_{checkpoint}'
    # Scan the folder once per run rather than once per split, so every split
    # is evaluated on the same set of checkpoints. Sorting gives a stable
    # submission order.
    all_ckpt_nums = sorted(set(get_checkpoint_dirs(checkpoints_folder)))
    # NOTE(review): i only covers the indices of eval_types (0..2), so the
    # last two entries of the task list in create_ppl_script are never
    # submitted from here -- confirm this is intended.
    for i in range(len(eval_types)):
        create_ppl_script(checkpoint, i)
        subprocess.run(['sbatch', f'/data/b_ou/agent/fine_tune/pred_batch_record/tmp_{i}.sh'])
        for ckpt_num in all_ckpt_nums:
            run_script(i, checkpoints_folder, ckpt_num, eval_types, eval_file_names)
        
        
        