import os
import subprocess
from concurrent.futures import ThreadPoolExecutor
from itertools import product
from transformers import AutoModelForCausalLM
from peft import PeftModel
import torch

def run_command(command, check=False):
    """Run *command* through the shell and return the finished process.

    ``shell=True`` is needed because the launch commands built in this script
    begin with inline environment assignments
    (``CUDA_VISIBLE_DEVICES=... ACCELERATE_LOG_LEVEL=info accelerate launch ...``).
    The command strings are assembled from this file's own settings, not from
    external input.

    Args:
        command: full shell command line.
        check: when True, raise ``subprocess.CalledProcessError`` on a non-zero
            exit status. The default (False) keeps the fire-and-forget behavior.

    Returns:
        The ``subprocess.CompletedProcess``; callers may inspect ``returncode``.
    """
    result = subprocess.run(command, shell=True, check=check)
    if result.returncode != 0:
        # Previously a failed training job disappeared without a trace.
        print(f"Command exited with status {result.returncode}: {command}")
    return result

# Training-recipe template (model, LoRA, data and trainer settings) for one run.
# create_config_file renders it with str.format(), so every {name} below is a
# placeholder and any literal brace would have to be doubled ({{ / }}).
# Placeholder notes:
#   {chat_template}       - inserted inside double quotes, so the Jinja template
#                           passed in must not contain unescaped double quotes.
#   {learning_rate_float} - the learning rate as rewritten by create_config_file
#                           (e.g. '5e-7' -> '5.0e-7'), presumably so YAML 1.1
#                           loaders such as PyYAML read it as a float.
#   {learning_rate}, {beta_name}, {version}, {model}, {loss}, {datatype}
#                         - raw strings used only to build output_dir.
#   po_cutoff / po_cutoff_default / sft_threshold / alpha_schedule / beta are
#   passed through unchanged; their meaning lives in the training scripts
#   (scripts/run_*.py), which are not visible here.
yaml_template = """
# Model arguments
model_name_or_path: {model_formal_name}
model_revision: main
tokenizer_name_or_path: {tokenizer}
torch_dtype: bfloat16
attn_implementation: flash_attention_2
  

# LoRA arguments
use_peft: true
lora_r: 128
lora_alpha: 256
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- gate_proj
- up_proj
- down_proj

# Data training arguments
dataset_mixer:
  {dataset_name}: 1.0
auto_insert_empty_system_msg: false

resume_from_checkpoint: false
dataset_splits:
- train
preprocessing_num_workers: 12
chat_template: "{chat_template}"

# SFT trainer config
bf16: true
do_eval: False
po_cutoff: {po_cutoff}
po_cutoff_default: {po_cutoff_default}
sft_threshold: {sft_threshold}
alpha_schedule: {alpha_schedule}
beta: {beta}
gradient_accumulation_steps: {gradient_accumulation_steps}
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_strategy: every_save
learning_rate: {learning_rate_float}
log_level: info
logging_steps: 5  
logging_strategy: steps
lr_scheduler_type: cosine
max_length: 2048
max_steps: -1
num_train_epochs: {epochs}
loss_type: {loss}
output_dir:  data/{model}-{loss}-{datatype}-{learning_rate}-{beta_name}-{version}
overwrite_output_dir: false
per_device_eval_batch_size: 8
per_device_train_batch_size: {per_device_train_batch_size}
push_to_hub: true
remove_unused_columns: true
report_to:
- tensorboard
- wandb
save_strategy: "steps"
save_steps: 100
save_total_limit: 1
seed: 42
warmup_ratio: 0.1
"""

# `accelerate launch` config template, rendered once per job slot by
# update_multi_gpu_file via str.format().
#   {num_processes} - number of processes (one per GPU) for a single job.
#   {port}          - main_process_port; update_multi_gpu_file gives each slot
#                     a different port (29500 + slot index), presumably so jobs
#                     running side by side on one machine do not collide.
# gpu_ids is "all": the actual GPU subset is chosen by the CUDA_VISIBLE_DEVICES
# prefix of the launch command built in main().
multi_gpu_yaml_template = """
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: {num_processes}
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_port: {port}
"""

def update_multi_gpu_file(num_processes, num_slots=4, config_dir="recipes/gemma-7b/mpo", base_port=29500, template=None):
    """Write one `accelerate` launch config per concurrent job slot.

    Each file is ``{config_dir}/multi_gpu{slot}_{num_processes}.yaml`` and uses
    ``main_process_port = base_port + slot``, so concurrently running jobs get
    distinct ports.

    Slots 0..num_slots (inclusive) are written. The launch commands number
    slots from 1 (``generate_commands_and_run`` formats ``mi=j+1``), so with
    ``num_slots`` concurrent jobs the file ``multi_gpu{num_slots}_*.yaml`` must
    exist; the old code wrote only 0..3. Slot 0 is still written for anything
    that refers to it.

    Args:
        num_processes: processes (GPUs) per job, written as ``num_processes``.
        num_slots: highest slot number to generate (default 4).
        config_dir: directory for the files; created if missing.
        base_port: port for slot 0.
        template: format string with ``{num_processes}`` and ``{port}``;
            defaults to the module-level ``multi_gpu_yaml_template``.
    """
    if template is None:
        template = multi_gpu_yaml_template
    # The directory may not exist yet: the recipes being trained can live under
    # a different model/loss folder, which is the only one created elsewhere.
    os.makedirs(config_dir, exist_ok=True)
    for slot in range(num_slots + 1):
        path = os.path.join(config_dir, f"multi_gpu{slot}_{num_processes}.yaml")
        with open(path, 'w') as f:
            f.write(template.format(num_processes=num_processes, port=base_port + slot))
    
# Gemma chat template (<start_of_turn>/<end_of_turn>; rejects a system turn).
# Note: it is written into the YAML inside double quotes, so it must not contain '"'.
_GEMMA_CHAT_TEMPLATE = "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}"

# ChatML-style template (<|im_start|>/<|im_end|>), used for the Mistral runs and
# the Zephyr-Gemma SFT checkpoint. A Mistral [INST]-style template used to be
# kept here as commented-out code; it is not used.
_CHATML_CHAT_TEMPLATE = "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}"

# short model key -> (model path or hub id, tokenizer hub id, chat template)
_MODEL_SPECS = {
    "gemma-7b": ("google/gemma-7b", "google/gemma-7b-it", _GEMMA_CHAT_TEMPLATE),
    "gemma-7b-sft": ("alignment-handbook/data/gemma-7b-sft-basic-5e-5-005-v140-full", "google/gemma-7b-it", _GEMMA_CHAT_TEMPLATE),
    "gemma-7b-sft-capibara": ("alignment-handbook/data/gemma-7b-sft-basic_capibara-5e-5-000-v140-full", "google/gemma-7b-it", _GEMMA_CHAT_TEMPLATE),
    "gemma-7b-sft-deita": ("HuggingFaceH4/zephyr-7b-gemma-sft-v0.1", "google/gemma-7b-it", _CHATML_CHAT_TEMPLATE),
    "mistral-7b": ("mistralai/Mistral-7B-v0.3", "mistralai/Mistral-7B-Instruct-v0.3", _CHATML_CHAT_TEMPLATE),
    "mistral-7b-sft": ("alignment-handbook/data/mistral-7b-sft-basic-5e-5-000-v140-full", "mistralai/Mistral-7B-Instruct-v0.3", _CHATML_CHAT_TEMPLATE),
}


def get_model_name(model):
    """Resolve a short model key to ``(model_formal_name, tokenizer, chat_template)``.

    Args:
        model: one of the keys of ``_MODEL_SPECS`` (e.g. 'gemma-7b', 'mistral-7b-sft').

    Returns:
        A 3-tuple of the model path or hub id, the tokenizer hub id and the
        Jinja chat template string.

    Raises:
        ValueError: if *model* is not a known key. Previously an unknown key
            fell through every branch and failed with UnboundLocalError.
    """
    try:
        return _MODEL_SPECS[model]
    except KeyError:
        raise ValueError(f"Unknown model: {model!r}; expected one of {sorted(_MODEL_SPECS)}") from None

def _yaml_float_literal(value):
    """Return *value* spelled so that a YAML 1.1 loader resolves it as a float.

    PyYAML's float resolver only matches numbers whose mantissa contains a '.'
    and whose exponent carries an explicit sign, so '5e-7' would be loaded as
    the *string* '5e-7'. This adds the missing '.0' and exponent sign:
    '5e-7' -> '5.0e-7', '5.5e-7' -> '5.5e-7' (the old str.replace approach
    produced '5.5.0e-7'), '1e5' -> '1.0e+5'. Values without an exponent are
    returned unchanged.
    """
    text = str(value).strip()
    mantissa, has_exponent, exponent = text.lower().partition('e')
    if not has_exponent:
        return text
    if '.' not in mantissa:
        mantissa += '.0'
    if not exponent.startswith(('+', '-')):
        exponent = '+' + exponent
    return f"{mantissa}e{exponent}"


def create_config_file(config_file, model, dataset_type, loss, learning_rate, beta, po_cutoff, po_cutoff_default, sft_threshold, alpha_schedule, gpus_per_process, version_name):
    """Render ``yaml_template`` for one training run and write it to *config_file*.

    Args:
        config_file: path of the YAML recipe to (over)write.
        model: short model key understood by ``get_model_name``.
        dataset_type: 'basic', 'basic_long' or 'basic_capibara'.
        loss: 'mpo', 'tampo', 'orpo', 'cpo' (4 epochs), 'dpo', 'simpo'
            (1 epoch) or 'sft' (3 epochs).
        learning_rate: learning rate as a string (e.g. '5e-7'). It is used
            verbatim in output_dir and rewritten as a YAML float literal for
            the ``learning_rate`` key.
        beta: beta as a string; its dot-less form (e.g. '005') goes into output_dir.
        po_cutoff, po_cutoff_default, sft_threshold, alpha_schedule: written
            through unchanged.
        gpus_per_process: 1, 2 or 4. The per-device batch is 1 and gradient
            accumulation is scaled so gpus_per_process * accumulation == 64.
        version_name: run tag appended to output_dir.

    Returns:
        The model path/hub id written to ``model_name_or_path``.

    Raises:
        ValueError: for an unknown dataset type, GPU count or loss. All three
            are validated before the model is resolved or any file is written.
    """
    dataset_names = {
        'basic': 'argilla/dpo-mix-7k',
        'basic_long': 'argilla/ultrafeedback-binarized-preferences-cleaned',
        'basic_capibara': 'argilla/distilabel-capybara-dpo-7k-binarized',
    }
    accumulation_by_gpus = {1: 64, 2: 32, 4: 16}
    epochs_by_loss = {
        'mpo': 4, 'tampo': 4, 'orpo': 4, 'cpo': 4,
        'dpo': 1, 'simpo': 1,
        'sft': 3,
    }

    # Previously an unknown dataset_type surfaced as UnboundLocalError.
    if dataset_type not in dataset_names:
        raise ValueError(f"Invalid dataset type: {dataset_type!r}")
    if gpus_per_process not in accumulation_by_gpus:
        raise ValueError("Invalid number of gpus per process")
    if loss not in epochs_by_loss:
        raise ValueError("Invalid loss type")

    model_formal_name, tokenizer, chat_template = get_model_name(model)
    print("model_formal_name", model_formal_name)

    yaml_adapted = yaml_template.format(
        model_formal_name=model_formal_name,
        tokenizer=tokenizer,
        model=model,
        datatype=dataset_type,
        version=version_name,
        loss=loss,
        learning_rate=learning_rate,
        learning_rate_float=_yaml_float_literal(learning_rate),
        beta=beta,
        epochs=epochs_by_loss[loss],
        po_cutoff=po_cutoff,
        po_cutoff_default=po_cutoff_default,
        sft_threshold=sft_threshold,
        alpha_schedule=alpha_schedule,
        beta_name=beta.replace('.', ''),
        dataset_name=dataset_names[dataset_type],
        chat_template=chat_template,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=accumulation_by_gpus[gpus_per_process],
    )
    with open(config_file, 'w') as f:
        f.write(yaml_adapted)
    return model_formal_name


def _build_commands(combos, command_template, version, gpus_per_process, version_n, learning_rate):
    """Write one recipe YAML per combination and return the partially formatted launch commands."""
    commands = []
    for offset, (model, beta, dataset_type, loss, po_cutoff, po_cutoff_default, sft_threshold, alpha_schedule) in enumerate(combos):
        version_name = f"{version}{version_n + offset}"
        beta_name = beta.replace('.', '')
        config_directory = f"recipes/{model}/{loss}"
        os.makedirs(config_directory, exist_ok=True)
        config_file = f"{config_directory}/config_lora_{dataset_type}-{beta_name}-{learning_rate}-{version_name}.yaml"
        create_config_file(config_file, model, dataset_type, loss, learning_rate, beta, po_cutoff, po_cutoff_default, sft_threshold, alpha_schedule, gpus_per_process, version_name)
        # Only DPO has its own entry script; every other loss goes through run_mpo.py.
        loss_command_name = 'dpo' if loss == 'dpo' else 'mpo'
        # {gpus} and {mi} are re-emitted as placeholders and filled per slot at launch time.
        commands.append(command_template.format(
            gpus="{gpus}", mi="{mi}", model=model, loss=loss,
            dataset=dataset_type, beta_name=beta_name,
            learning_rate=learning_rate, version=version_name,
            loss_command=loss_command_name,
            num_processes=gpus_per_process,
        ))
    return commands


def _run_in_batches(commands, available_gpus, gpus_per_process, slots):
    """Run *commands* up to *slots* at a time, waiting for each batch to finish before starting the next."""
    if not commands:
        return
    with ThreadPoolExecutor(max_workers=slots) as executor:
        for start in range(0, len(commands), slots):
            futures = []
            for j, template in enumerate(commands[start:start + slots]):
                # Slot j owns a contiguous group of gpus_per_process GPUs.
                gpu_group = available_gpus[j * gpus_per_process:(j + 1) * gpus_per_process]
                command = template.format(gpus=",".join(str(g) for g in gpu_group), mi=j + 1)
                print(command)
                futures.append(executor.submit(run_command, command))
            for future in futures:
                future.result()


def _merge_sft_adapters(combos, version, version_n, learning_rate):
    """Merge the LoRA adapter of every finished 'sft' run into its base model and save it with a '-full' suffix."""
    for offset, (model, beta, dataset_type, loss, *_unused) in enumerate(combos):
        if loss != "sft":
            continue
        version_name = f"{version}{version_n + offset}"
        model_formal_name, _, _ = get_model_name(model)
        print("Saving full model for sft")
        # NOTE(review): the recipe's output_dir is 'data/...' relative to the launch
        # directory (see yaml_template), while the adapter is read from
        # 'alignment-handbook/data/...' here; confirm how the two directories relate.
        peft_model_id = f"alignment-handbook/data/{model}-{loss}-{dataset_type}-{learning_rate}-{beta.replace('.', '')}-{version_name}"
        base_model = AutoModelForCausalLM.from_pretrained(model_formal_name, torch_dtype=torch.bfloat16)
        merged_model = PeftModel.from_pretrained(base_model, peft_model_id).merge_and_unload()
        merged_model.save_pretrained(f"{peft_model_id}-full")
        print(f"Model saved to {peft_model_id}-full")
        # Release this model before the next iteration loads another one.
        del base_model, merged_model


def generate_commands_and_run(available_gpus, command_template, version, gpus_per_process, version_n, learning_rate, param_combinations, param_combinations2=None, parallel_jobs=None):
    """Write a recipe per parameter combination, launch the runs, then optionally merge SFT adapters.

    Args:
        available_gpus: GPU ids. Concurrent slot j uses
            available_gpus[j*gpus_per_process:(j+1)*gpus_per_process].
        command_template: launch command with the placeholders {gpus}, {mi},
            {model}, {loss}, {dataset}, {beta_name}, {learning_rate},
            {version}, {loss_command} and {num_processes}.
        version: version prefix; combination i is tagged f"{version}{version_n + i}".
        gpus_per_process: GPUs given to each training job.
        version_n: first version number.
        learning_rate: learning-rate string shared by every run.
        param_combinations: iterable of 8-tuples (model, beta, dataset_type,
            loss, po_cutoff, po_cutoff_default, sft_threshold, alpha_schedule).
        param_combinations2: optional iterable of the same tuples. Version
            numbers restart at version_n, and every 'sft' entry gets its LoRA
            adapter merged into the base model.
        parallel_jobs: number of jobs to run at once; defaults to the
            module-level ``max_processes``.

    Raises:
        ValueError: if fewer than one parallel job is requested, or if the
            largest batch needs more GPUs than available_gpus holds. This is
            checked before any file is written or job launched; previously it
            raised IndexError after part of the batch had already started.
    """
    slots = max_processes if parallel_jobs is None else parallel_jobs
    if slots < 1:
        raise ValueError(f"parallel_jobs must be at least 1, got {slots}")
    combos = list(param_combinations)

    widest_batch = min(slots, len(combos))
    gpus_needed = widest_batch * gpus_per_process
    if gpus_needed > len(available_gpus):
        raise ValueError(
            f"{widest_batch} concurrent job(s) x {gpus_per_process} GPU(s) need "
            f"{gpus_needed} GPUs but only {len(available_gpus)} are available"
        )

    commands = _build_commands(combos, command_template, version, gpus_per_process, version_n, learning_rate)
    _run_in_batches(commands, available_gpus, gpus_per_process, slots)

    if param_combinations2 is not None:
        _merge_sft_adapters(param_combinations2, version, version_n, learning_rate)
        
# Number of training jobs generate_commands_and_run launches concurrently; each
# job gets its own gpus_per_process-sized slice of available_gpus, so
# max_processes * gpus_per_process must not exceed len(available_gpus).
max_processes = 1
def main():
    """Regenerate the accelerate configs and launch the sweep configured below.

    Every combination of the sweep axes becomes one training run. Version tags
    start at 'v142' and increase by one per combination.
    """
    # Axis order must match the tuple unpacking in generate_commands_and_run:
    # (model, beta, dataset_type, loss, po_cutoff, po_cutoff_default,
    #  sft_threshold, alpha_schedule).
    sweep_axes = (
        ['mistral-7b-sft'],           # models
        ['0.05'],                     # betas (strings: reused verbatim in paths)
        ['basic'],                    # dataset types
        ['simpo'],                    # losses
        [0.0, 0.2, 0.5, 1.5, 2.0],    # po_cutoffs
        [0.0],                        # po_cutoffs_default
        [0.3],                        # sft_thresholds
        ['linear'],                   # alpha schedules
    )
    gpu_pool = [4, 5, 6, 7]
    gpus_each = 4
    launch_template = (
        "CUDA_VISIBLE_DEVICES={gpus} ACCELERATE_LOG_LEVEL=info "
        "accelerate launch "
        "--config_file recipes/gemma-7b/mpo/multi_gpu{mi}_{num_processes}.yaml "
        "scripts/run_{loss_command}.py "
        "recipes/{model}/{loss}/config_lora_{dataset}-{beta_name}-{learning_rate}-{version}.yaml"
    )

    update_multi_gpu_file(gpus_each)
    generate_commands_and_run(
        gpu_pool,
        launch_template,
        'v14',               # version prefix
        gpus_each,
        2,                   # first version number
        '5e-7',              # learning rate
        product(*sweep_axes),
        None,                # no SFT-adapter merge pass
    )
    

# Script entry point: regenerate the configs and launch the sweep defined in main().
if __name__ == "__main__":
    main()