import os
import subprocess
from concurrent.futures import ThreadPoolExecutor
import torch
import wandb

# Sweep grid: every combination of these values becomes one alpaca_eval config.
models = ['gemma-7b']
dataset_types = ['basic_capibara']
learning_rates = ['5e-5']
losses = ['tampo']
betas = ['025']
# GPU ids handed out to concurrent jobs, gpus_per_process ids per job slot.
available_gpus = [0, 1]
# Directory under which one folder (holding configs.yaml) is written per config.
config_dir = "alpaca_eval/src/alpaca_eval/models_configs"
# Shell command templates; {gpus} and {config_file} are filled in per job.
command_template_step_1 = "VLLM_WORKER_MULTIPROC_METHOD=spawn CUDA_VISIBLE_DEVICES={gpus} alpaca_eval generate_completions --model_configs {config_file}"
command_template_step_2 = "VLLM_WORKER_MULTIPROC_METHOD=spawn CUDA_VISIBLE_DEVICES={gpus} alpaca_eval evaluate_from_model --model_configs {config_file} --reference_model_configs gpt4 --annotators_config alpaca_eval_vllm_llama3_70b_fn"
# Version prefix; start_n is appended to it to form the config-name suffix (e.g. "v150").
version = 'v15'
# GPUs per concurrently running job; also written as parallel_size in the YAML config.
gpus_per_process = 2
start_n = 0
# Concurrent jobs; requires max_workers * gpus_per_process <= len(available_gpus).
max_workers = 1

# YAML model config for alpaca_eval, filled in with str.format.
# Placeholders: {model}, {loss}, {datatype}, {learning_rate}, {beta}, {version},
# {prompt_template}, {model_path}, {num_gpus_per_process}.
# The top-level key follows the same "{model}-{loss}-{datatype}-{learning_rate}-{beta}-{version}"
# pattern used to build the per-config directory name; presumably alpaca_eval
# resolves the config by that name -- TODO confirm.
# enable_lora/lora_path point at alignment-handbook/data/<config name>, presumably
# a fine-tuned LoRA adapter applied on top of the base model at {model_path}.
# NOTE(review): pretty_name omits learning_rate, beta and version, so configs that
# differ only in those fields end up with the same display name.
yaml_template = """
{model}-{loss}-{datatype}-{learning_rate}-{beta}-{version}:
  prompt_template: {prompt_template}
  fn_completions: "vllm_local_completions"
  completions_kwargs:
    model_name: "{model_path}"
    enable_lora: True
    lora_path: "alignment-handbook/data/{model}-{loss}-{datatype}-{learning_rate}-{beta}-{version}"
    parallel_size: {num_gpus_per_process}
    model_kwargs:
      dtype: 'bfloat16'
    max_new_tokens: 2048
    temperature: 0.0
    top_p: 1.0
    batch_size: 8
  requires_chatml: True
  pretty_name: "{model} {loss} {datatype}"
"""

def create_config_file(config_file, model, dataset_type, learning_rate, loss, beta, num_gpus_per_process, version):
    """Render ``yaml_template`` for one sweep point and write it to ``config_file``.

    Args:
        config_file: Path of the YAML file to (over)write.
        model: Short model key; must be one of the names in ``known_models`` below.
        dataset_type, learning_rate, loss, beta, version: Values substituted
            into the template (and thus into the config name and LoRA path).
        num_gpus_per_process: Written as ``parallel_size`` in the config.

    Raises:
        ValueError: if ``model`` is not a known model key. Nothing is written
            in that case.
    """
    # model key -> (base model path, alpaca_eval prompt template file)
    # NOTE(review): "gemma-7b-sft" maps to the same base model as "gemma-7b";
    # confirm this is intentional rather than a missing SFT checkpoint path.
    known_models = {
        "gemma-7b": ("google/gemma-7b", "gemma-2-9b-it-DPO/prompt.txt"),
        "gemma-7b-sft": ("google/gemma-7b", "gemma-2-9b-it-DPO/prompt.txt"),
        "mistral-7b": ("mistralai/Mistral-7B-v0.3", "Mixtral-8x7B-Instruct-v0.1/togetherai_prompt.txt"),
    }
    try:
        model_path, prompt_template = known_models[model]
    except KeyError:
        # Without this, an unknown model fell through the branches and crashed
        # later with an opaque UnboundLocalError.
        raise ValueError(
            f"Unknown model {model!r}; expected one of {sorted(known_models)}"
        ) from None

    yaml_adapted = yaml_template.format(
        model=model,
        datatype=dataset_type,
        loss=loss,
        prompt_template=prompt_template,
        model_path=model_path,
        num_gpus_per_process=num_gpus_per_process,
        learning_rate=learning_rate,
        beta=beta,
        version=version,
    )
    with open(config_file, "w", encoding="utf-8") as f:
        f.write(yaml_adapted)

def run_command(command, *, check=False):
    """Run ``command`` through the shell and return its exit status.

    The command strings built by this script begin with environment-variable
    assignments (``VAR=value cmd ...``), which is shell syntax, so
    ``shell=True`` is required. Only pass strings assembled from trusted,
    in-script constants.

    Args:
        command: Full shell command line to execute.
        check: If True, raise ``subprocess.CalledProcessError`` on a non-zero
            exit status instead of only printing a warning.

    Returns:
        The process exit status (0 on success).
    """
    completed = subprocess.run(command, shell=True, check=check)
    if completed.returncode != 0:
        # Surface failures instead of discarding the exit status silently.
        print(f"WARNING: command exited with status {completed.returncode}: {command}")
    return completed.returncode
    
def _iter_configs(version_name):
    """Yield ``(model, dataset_type, learning_rate, loss, beta, config_name)`` per sweep point.

    Iterates the cartesian product of ``models``, ``dataset_types``,
    ``learning_rates``, ``losses`` and ``betas`` in nested-loop order
    (``models`` outermost, ``betas`` innermost).
    """
    from itertools import product

    for model, dataset_type, learning_rate, loss, beta in product(
        models, dataset_types, learning_rates, losses, betas
    ):
        config_name = f"{model}-{loss}-{dataset_type}-{learning_rate}-{beta}-{version_name}"
        yield model, dataset_type, learning_rate, loss, beta, config_name


def _gpus_for_slot(slot):
    """Return the CUDA_VISIBLE_DEVICES value (comma-separated ids) for worker ``slot``.

    Slot ``j`` gets the disjoint slice
    ``available_gpus[j * gpus_per_process:(j + 1) * gpus_per_process]``.

    Raises:
        ValueError: if ``available_gpus`` is too short for this slot.
    """
    start = slot * gpus_per_process
    end = start + gpus_per_process
    if end > len(available_gpus):
        raise ValueError(
            f"Worker slot {slot} needs GPU indices {start}..{end - 1} but only "
            f"{len(available_gpus)} GPUs are listed in available_gpus; lower "
            f"max_workers or gpus_per_process."
        )
    return ",".join(str(gpu) for gpu in available_gpus[start:end])


def _run_in_batches(command_template, config_names):
    """Run ``command_template`` once per config name, at most ``max_workers`` at a time.

    Each batch is fully awaited before the next one starts, so worker slot
    ``j`` can always reuse the same GPU slice. Exceptions raised inside a job
    are re-raised here by ``future.result()``.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for batch_start in range(0, len(config_names), max_workers):
            batch = config_names[batch_start:batch_start + max_workers]
            futures = []
            for slot, config_name in enumerate(batch):
                # Format once with both fields: a config name is substituted as a
                # value and is never re-parsed as a format string.
                command = command_template.format(
                    gpus=_gpus_for_slot(slot), config_file=config_name
                )
                print(command)
                futures.append(executor.submit(run_command, command))
            for future in futures:
                future.result()


def _strip_chat_markers(config_name):
    """Remove leftover chat-template tokens from one config's generated outputs.

    Rewrites ``alpaca_eval/results/<config_name>/model_outputs.json`` with every
    ``<|im_start|>model`` (with or without a trailing space) and ``<|im_end|>``
    removed from the raw text, and keeps the unmodified file as
    ``model_outputs2.json``.
    """
    results_dir = f"alpaca_eval/results/{config_name}"
    result_file_path = f"{results_dir}/model_outputs.json"
    backup_file_path = f"{results_dir}/model_outputs2.json"
    cleaned_file_path = f"{results_dir}/model_outputs_mod.json"
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale encoding.
    with open(result_file_path, "rt", encoding="utf-8") as fin, \
            open(cleaned_file_path, "wt", encoding="utf-8") as fout:
        for line in fin:
            fout.write(line.replace('<|im_start|>model ', '').replace('<|im_start|>model', '').replace('<|im_end|>', ''))
    # Both handles are already closed by the `with` above. os.replace (unlike
    # os.rename) also overwrites an existing backup on Windows, so reruns work.
    os.replace(result_file_path, backup_file_path)
    os.replace(cleaned_file_path, result_file_path)


def main():
    """Write alpaca_eval configs, generate completions, clean them, then evaluate.

    1. Write one ``configs.yaml`` per sweep point under ``config_dir/<config name>``.
    2. Run ``command_template_step_1`` (generate_completions) for every config.
    3. Strip chat-template markers from each config's ``model_outputs.json``.
    4. Run ``command_template_step_2`` (evaluate_from_model) for every config.
    """
    # All configs share one version suffix; there is no per-config counter.
    version_name = f"{version}{start_n}"

    config_names = []
    for model, dataset_type, learning_rate, loss, beta, config_name in _iter_configs(version_name):
        config_directory = f"{config_dir}/{config_name}"
        print(config_directory)
        os.makedirs(config_directory, exist_ok=True)
        config_file_path = f"{config_directory}/configs.yaml"
        print("config file path", config_file_path)
        create_config_file(config_file_path, model, dataset_type, learning_rate, loss, beta, gpus_per_process, version_name)
        config_names.append(config_name)

    _run_in_batches(command_template_step_1, config_names)

    for config_name in config_names:
        _strip_chat_markers(config_name)

    _run_in_batches(command_template_step_2, config_names)

if __name__ == "__main__":
    main()