# Driver script that runs our method. Each iteration consists of 5 stages:
#   1. generate.py: generate a set of responses for each prompt in the dataset using the latest model
#   2. get_grad.py: compute the gradient of each response w.r.t. the model's parameters
#   3. selection_iter_addpre.py: select the most informative responses based on the gradients
#   4. label.py: label the selected responses
#   5. train_iter.py: train the model on the labeled responses
# After stage 5 the next iteration starts again from stage 1.
# This script repeats the 5 stages for several iterations (5 by default) by calling the scripts above.

import argparse
import json
import math
import os
import shutil
import subprocess
# add . to the path so that preference_datasets can be imported
import sys
sys.path.append('.')
cache_dir = '.cache'
from preference_datasets import model_to_config, eval_batch_size, train_accu_step

def _read_records(path):
    """Load a JSON-lines file and return every record except the first line.

    The first line is treated as a header record and discarded; files written
    by ``_write_records`` start with a ``{'selection': False}`` header.
    """
    with open(path, 'r') as f:
        records = [json.loads(line) for line in f]
    return records[1:]


def _write_records(path, records):
    """Write ``records`` to ``path`` as JSON lines, preceded by a ``{'selection': False}`` header line."""
    with open(path, 'w') as f:
        for record in [{'selection': False}] + list(records):
            f.write(json.dumps(record) + '\n')


def _remove_path(path):
    """Delete ``path``, whether it is a regular file/symlink or a directory tree."""
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    else:
        os.remove(path)


def run_stage(log_file_path, stage, dataset, model, init_model_path, iteration, original_data_dir, dim, normalize, select_percentage, sub_iter_every, num_response, tensor_parallel_size, exp_name, avaliable_gpus, args):
    """Build and run the shell command for one stage of one iteration.

    The state file at ``log_file_path`` holds a single JSON object on its first
    line with (at least) the keys ``latest_model_path``, ``latest_generated_data``,
    ``latest_selected_data``, ``latest_labeled_data`` and ``latest_gradient_path``.
    Stages 2 and 5 also read/update ``pre_generate_path``, ``pre_data_path`` and
    ``numdp`` in that file.

    Stages:
        1. ``generate.py`` on the first ``tensor_parallel_size`` GPUs.
        2. Merge the new generations with the files in ``pre_generate_path``,
           drop the stale gradient cache, and run one ``get_grad.py`` shard per GPU.
        3. ``selection_iter_addpre.py`` on the merged data.
        4. ``label.py`` on the latest selected data.
        5. Record the latest labeled/selected files, concatenate all labeled
           files into ``<latest_labeled_data>_combined``, then run ``train_iter.py``.

    Args:
        tensor_parallel_size: ignored; recomputed from the GPU count
            (4 if >= 4 GPUs, 2 if exactly 3, otherwise the number of GPUs).
        avaliable_gpus: string of single-character GPU ids, e.g. ``'01234567'``.
        args: parsed CLI namespace; only ``args.test_data_dir`` is read (stage 5).

    Raises:
        ValueError: if ``stage`` is not 1-5.
        subprocess.CalledProcessError: if any command of the stage exits non-zero.
    """
    num_gpus = len(avaliable_gpus)
    if num_gpus >= 4:
        tensor_parallel_size = 4
    elif num_gpus == 3:
        tensor_parallel_size = 2
    else:
        tensor_parallel_size = num_gpus

    with open(log_file_path, 'r') as f:
        logs = json.loads(f.readline())
    model_path = logs['latest_model_path']
    latest_generated_data = logs['latest_generated_data']
    latest_selected_data = logs['latest_selected_data']
    latest_labeled_data = logs['latest_labeled_data']
    latest_gradient_path = logs['latest_gradient_path']
    model_ = model.split('/')[-1]
    # Every stage tees its output into logs/; stage 2 also writes its helper script there.
    os.makedirs('logs', exist_ok=True)

    if stage == 1:
        gpus_ = ','.join(avaliable_gpus[:tensor_parallel_size])
        cmd = f'CUDA_VISIBLE_DEVICES={gpus_} python generate.py --method_name ours --dataset {dataset} --model_id {model} --iter {iteration} --original_data_dir {original_data_dir} --num_response {num_response} --model_path {model_path} --tensor_parallel_size {tensor_parallel_size} --log_file_path {log_file_path} --exp_name {exp_name} 2>&1 | tee "logs/{exp_name}_generate_{model_}_{dataset}_iter_{iteration}"'
    elif stage == 2:
        # Append the data kept from earlier iterations to the fresh generations and save it in place.
        merged = _read_records(latest_generated_data)
        numdp = len(merged)  # number of newly generated records, before previous data is appended
        for path in logs['pre_generate_path']:
            merged.extend(_read_records(path))
        _write_records(latest_generated_data, merged)
        logs['numdp'] = numdp
        with open(log_file_path, 'w') as f:
            f.write(json.dumps(logs) + '\n')

        model_cfg = model_to_config[model]
        grad_run_name = f'{exp_name}_{dataset}_{model_cfg}_lora_get_grad_iter_{iteration}'
        grad_cache = f'{cache_dir}/{grad_run_name}'
        # Remove output of a previous run with the same name; it may be a file or a directory.
        if os.path.exists(grad_cache):
            _remove_path(grad_cache)
            print(f'removed the previous gradient file {grad_run_name}')

        # One background get_grad.py shard per GPU character in $1. Each shard runs in
        # a subshell under pipefail so its status is python's, not tee's, and the
        # script exits non-zero if any shard failed (a bare `wait` always returns 0).
        bash_script = f'''set -o pipefail

convert_to_array() {{
local input="$1"
local -a arr=()

# Split the input string into individual characters
for (( i=0; i<${{#input}}; i++ )); do
    arr+=("${{input:i:1}}")
done

# Echo the array elements
echo "${{arr[@]}}"
}}

# Capture the array by command substitution
gpus=($(convert_to_array "$1"))
num_gpus=${{#gpus[@]}}
cnt=0
pids=()
for gpu in "${{gpus[@]}}"; do
    (
    CUDA_VISIBLE_DEVICES="$gpu" python -u get_grad.py \
            model={model_cfg} \
            datasets=[{dataset}] \
            loss=dpo \
            loss.beta=0.1 \
            exp_name={grad_run_name} \
            batch_size=1 \
            trainer=BasicTrainerGradInfo \
            sample_during_eval=false \
            model.fsdp_policy_mp=bfloat16 \
            lora.enable=true \
            model.archive={model_path} \
            selection_enable=true \
            selection_info_path={latest_generated_data} \
            minimum_log_interval_secs=3 \
            grad_split.enable=true \
            grad_split.n_splits=$num_gpus \
            grad_split.split_idx=$cnt \
            grad_split.gradient_type=full \
            log_file_path={log_file_path} \
            wandb.enabled=false 2>&1 | tee "logs/{exp_name}_get_grad_$gpu"
    ) &
    pids+=($!)
    cnt=$((cnt+1))
done

status=0
for pid in "${{pids[@]}}"; do
    wait "$pid" || status=1
done
exit $status
'''
        script_path = f'logs/get_grad_tmp_{exp_name}_iter{iteration}.sh'
        with open(script_path, 'w') as f:
            f.write(bash_script)
        cmd = f'bash {script_path} {avaliable_gpus}'
    elif stage == 3:
        numdp = logs['numdp']
        cmd = f'python selection_iter_addpre.py --dataset {dataset} --model {model_} --iteration {iteration} --gradient_path {latest_gradient_path} --original_data_dir {latest_generated_data} --dim {dim} --normalize {normalize} --select_percentage {select_percentage} --sub_iter_every {sub_iter_every} --log_file_path {log_file_path} --exp_name {exp_name} --numdp {numdp} 2>&1 | tee "logs/{exp_name}_selection_{model_}_{dataset}_iter_{iteration}"'
    elif stage == 4:
        cmd = f'python label.py --dataset {dataset} --data_path {latest_selected_data} --log_file_path {log_file_path}  2>&1 | tee "logs/{exp_name}_label_{model_}_{dataset}_iter_{iteration}"'
    elif stage == 5:
        # Remember this iteration's labeled and selected files for later iterations.
        logs['pre_data_path'].append(latest_labeled_data)
        logs['pre_generate_path'].append(latest_selected_data)
        with open(log_file_path, 'w') as f:
            f.write(json.dumps(logs) + '\n')

        # Train on the concatenation of all labeled data collected so far.
        combined = []
        for path in logs['pre_data_path']:
            combined.extend(_read_records(path))
        path_ = latest_labeled_data + '_combined'
        _write_records(path_, combined)

        gpus_ = ','.join(avaliable_gpus)
        # Accumulation steps scale inversely with the GPU count, presumably to keep the
        # effective batch of an 8-GPU run — TODO confirm against train_accu_step.
        accu_steps = math.ceil(train_accu_step[model] * (8 / num_gpus))
        # Scale the per-model eval batch size by GPU count (relative to 4), never below 1.
        batch_size_ = max(1, int(eval_batch_size[model] * num_gpus / 4))
        cmd = f'''CUDA_LAUNCH_BLOCKING=1 \
    CUDA_VISIBLE_DEVICES={gpus_} \
    python -u train_iter.py \
                model={model_to_config[model]} \
                datasets=[{dataset}] \
                loss=dpo \
                loss.beta=0.1 \
                exp_name={exp_name}_{dataset}_dpo_{model_to_config[model]}_lora_selection_our_iter_{iteration} \
                gradient_accumulation_steps={accu_steps} \
                batch_size=50 \
                eval_batch_size={batch_size_} \
                trainer=BasicTrainer \
                sample_during_eval=false \
                lora.enable=true \
                model.archive={model_path} \
                model.archive_reference={init_model_path} \
                selection_enable=true \
                selection_info_path={path_} \
                test_selection_info_path={args.test_data_dir} \
                eval_every=150 \
                n_epochs=3 \
                lr=2e-5 \
                save_while_training=false \
                n_eval_examples=1000 \
                log_file_path={log_file_path} \
                model.reference_dtype=float32 2>&1 | tee "logs/{exp_name}_training_{model_}_{dataset}_iter_{iteration}"'''
    else:
        raise ValueError(f'Invalid stage: {stage}')

    print(f'Running stage {stage} for iteration {iteration}')
    print(cmd)
    # Without pipefail the status of `cmd | tee ...` is tee's, so check=True would
    # never see a failing stage; run under bash with pipefail so failures raise.
    subprocess.run(f'set -o pipefail; {cmd}', shell=True, check=True,
                   executable=shutil.which('bash') or '/bin/bash')

if __name__ == "__main__":
    def _str_to_bool(value):
        """Parse a command-line boolean.

        argparse's ``type=bool`` calls ``bool(str)``, which is True for any
        non-empty string (including 'False'), so parse the text explicitly.
        """
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser(description='Run the 5 stages')
    parser.add_argument('--dataset', type=str, default='tldr', help='Dataset name')
    parser.add_argument('--model', type=str, default='meta-llama/Llama-2-7b-hf', help='Model id')
    parser.add_argument('--model_path', type=str, default='tldr_sft_llama2_2024-12-13_13-16-23_968854/LATEST/hf_model', help='Model initial path')
    parser.add_argument('--original_data_dir', type=str, default='tldr_train_data_1000.jsonl', help='Path to the original data directory')
    parser.add_argument('--test_data_dir', type=str, default='tldr_test_data_labeled.jsonl', help='Path to the test data directory')
    parser.add_argument('--dim', type=int, default=8192, help='Dimension of the gradients')
    parser.add_argument('--normalize', type=_str_to_bool, default=True, help='Whether to normalize the gradients')
    parser.add_argument('--select_percentage', type=int, default=5, help='Percentage of responses to select')
    parser.add_argument('--sub_iter_every', type=int, default=1, help='Number of responses to select in each sub-iteration')
    parser.add_argument('--num_response', type=int, default=3, help='Number of responses to generate')
    # NOTE: run_stage recomputes the tensor parallel size from the GPU count, so this value is currently not used.
    parser.add_argument('--tensor_parallel_size', type=int, default=4, help='Tensor parallel size')
    parser.add_argument('--exp_name', type=str, default='3000dp3res', help='The name of the experiment')
    parser.add_argument('--avaliable_gpus', type=str, default='01234567', help='The avaliable gpus')
    parser.add_argument('--num_iterations', type=int, default=5, help='Number of times to run the 5 stages')

    args = parser.parse_args()
    init_model_path = args.model_path
    model_ = args.model.split('/')[-1]
    log_file_path = f'log_{args.exp_name}_{model_}_{args.dataset}_ours.jsonl'
    logs = {'latest_model_path': args.model_path, 'latest_generated_data': None, 'latest_selected_data': None, 'latest_labeled_data': None, 'latest_gradient_path': None, 'pre_data_path': [], 'pre_generate_path': []}
    # Always start from a fresh state file; an existing file with the same name is overwritten.
    with open(log_file_path, 'w') as f:
        f.write(json.dumps(logs) + '\n')

    for iteration in range(args.num_iterations):
        for stage in range(1, 6):
            run_stage(log_file_path, stage, args.dataset, args.model, init_model_path, iteration, args.original_data_dir, args.dim, args.normalize, args.select_percentage, args.sub_iter_every, args.num_response, args.tensor_parallel_size, args.exp_name, args.avaliable_gpus, args)

    # Free disk space: in every DPO training run directory of this experiment under the
    # cache dir, delete the files in BEST/ but keep the entry named 'adapter'.
    # Other subdirectories of BEST/ are left untouched.
    if os.path.isdir(cache_dir):
        run_dirs = [
            name for name in os.listdir(cache_dir)
            if os.path.isdir(os.path.join(cache_dir, name)) and args.exp_name in name and 'dpo' in name
        ]
        for run_dir in run_dirs:
            best_dir = os.path.join(cache_dir, run_dir, 'BEST')
            if not os.path.isdir(best_dir):
                continue
            for entry in os.listdir(best_dir):
                if entry == 'adapter':
                    continue
                entry_path = os.path.join(best_dir, entry)
                if os.path.isfile(entry_path) or os.path.islink(entry_path):
                    print(entry_path)
                    os.remove(entry_path)
    print(f'Finished running all stages for {args.num_iterations} iterations')

