import os
from dataclasses import dataclass, field
from typing import Dict, Optional
import pickle

import hydra
import setproctitle
import torch
from accelerate import Accelerator
from datasets import Dataset, load_dataset
from omegaconf import DictConfig, OmegaConf
from peft import LoraConfig, PeftModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from trl import DPOConfig, DPOTrainer

# Rename the running process at import time so this training job is easy to
# identify in process listings.
title = 'llama-attacks-train-dpo-llama3'
setproctitle.setproctitle(title)


@torch.no_grad()
def init_reference_lora(dpo_params, reference_params):
    """Overwrite each reference parameter with an independent copy of its DPO twin.

    The two iterables are walked in lockstep; pairing stops at the end of the
    shorter one. Each reference tensor's ``.data`` is rebound to a detached
    clone, so later updates to the DPO parameters do not leak into the
    reference parameters.

    Args:
        dpo_params: Iterable of source tensors/parameters to copy from.
        reference_params: Iterable of destination tensors/parameters whose
            ``.data`` is replaced.
    """
    for source_param, target_param in zip(dpo_params, reference_params):
        snapshot = source_param.detach().clone()
        target_param.data = snapshot

def get_stack_exchange_paired(
    data_dir: str,
    num_proc=24,
) -> Dataset:
    """Load a paired-preference CSV file and reshape it into prompt/chosen/rejected columns.

    The CSV is read through the ``datasets`` CSV loader and its "train" split
    is used. It must provide the columns ``instruct``, ``chosen`` and
    ``rejected``. All original columns are dropped from the result, which
    contains exactly:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts are structured as follows:
      "<s>" + instruct

    Args:
        data_dir: Path handed to ``load_dataset`` as ``data_files``.
        num_proc: Number of worker processes used by ``Dataset.map``.

    Returns:
        The mapped dataset holding only the three columns listed above.
    """
    # csv file
    raw_split = load_dataset('csv', data_files=data_dir)["train"]
    source_columns = raw_split.column_names
    print(source_columns)

    def to_preference_batch(batch) -> Dict[str, list]:
        """Turn one batch of CSV rows into the prompt/chosen/rejected layout."""
        # Alternative Llama-3 chat-template prompt, kept for reference:
        # "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + instruct + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        prompts = []
        for instruct in batch["instruct"]:
            prompts.append("<s>" + instruct)
        return {
            "prompt": prompts,
            "chosen": batch["chosen"],
            "rejected": batch["rejected"],
        }

    formatted = raw_split.map(
        to_preference_batch,
        batched=True,
        num_proc=num_proc,
        remove_columns=source_columns,
    )
    return formatted


@hydra.main(version_base=None, config_path="conf")
def main(cfg: DictConfig):
    """Run DPO training of a LoRA adapter, using a second copy of the adapter as reference.

    Steps:
      1. Seed everything and load the base causal LM in the configured dtype.
      2. Load the same LoRA checkpoint twice, as adapters "dpo-base" (trainable)
         and "reference", so no separate reference model is needed.
      3. Load the tokenizer (ensuring it has a pad token) and the CSV dataset.
      4. Train with ``DPOTrainer`` and save the result to
         ``<cfg.output_dir>/final_checkpoint``.

    Args:
        cfg: Hydra configuration. Keys read here: ``seed``, ``run_name``,
            ``output_dir``, ``log_dir``, ``llm_params.{model_path, model_dtype}``,
            ``lora_params.lora_checkpoint``, ``data_params.data_path``,
            ``dpo_params.{beta, bf16, logging_steps, per_device_train_batch_size,
            num_train_epochs}`` and ``opt_params.{lr, lr_scheduler_type}``.
    """
    tqdm.write("Starting run...")
    tqdm.write(f"Using parameters: \n{OmegaConf.to_yaml(cfg)}")

    set_seed(cfg.seed)

    # 1. load a pretrained model
    model_params = cfg.llm_params
    lora_params = cfg.lora_params

    # Any value other than "float16"/"bfloat16" falls back to full precision.
    torch_dtype = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }.get(model_params.model_dtype, torch.float)

    model = AutoModelForCausalLM.from_pretrained(
        model_params.model_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype,
        # Place the whole model on this process's local device.
        device_map={"": Accelerator().local_process_index},
    )

    # load lora adapter twice with different name [https://huggingface.co/docs/trl/main/en/dpo_trainer#using-option-3---load-the-adapter-twice]
    model = PeftModel.from_pretrained(
        model,
        lora_params.lora_checkpoint,
        is_trainable=True,
        adapter_name="dpo-base",
        autocast_adapter_dtype=False,
    )
    # Both adapters come from the same checkpoint, so they start out identical
    # and no manual weight copy (init_reference_lora) is required.
    model.load_adapter(lora_params.lora_checkpoint, adapter_name="reference")
    # NOTE(review): this cast is hard-coded to bfloat16 and therefore overrides
    # llm_params.model_dtype chosen above — confirm this is intended.
    model = model.to(torch.bfloat16)

    tokenizer = AutoTokenizer.from_pretrained(model_params.model_path)
    if tokenizer.pad_token is None:
        # Prefer the unk token as padding; fall back to eos when there is none.
        if tokenizer.unk_token is not None:
            tokenizer.pad_token = tokenizer.unk_token
        else:
            tokenizer.pad_token = tokenizer.eos_token

    # load dataset
    train_dataset = get_stack_exchange_paired(data_dir=cfg.data_params.data_path)

    os.makedirs(cfg.output_dir, exist_ok=True)
    # NOTE(review): log_dir is created but not passed to DPOConfig — confirm
    # whether it should be wired in as the logging directory.
    os.makedirs(cfg.log_dir, exist_ok=True)

    # set the dpo configuration
    dpo_params = cfg.dpo_params
    train_args = DPOConfig(
        model_adapter_name="dpo-base",
        ref_adapter_name="reference",
        beta=dpo_params.beta,
        seed=cfg.seed,
        bf16=dpo_params.bf16,
        run_name=cfg.run_name,
        logging_steps=dpo_params.logging_steps,
        output_dir=cfg.output_dir,
        per_device_train_batch_size=dpo_params.per_device_train_batch_size,
        learning_rate=cfg.opt_params.lr,
        lr_scheduler_type=cfg.opt_params.lr_scheduler_type,
        num_train_epochs=dpo_params.num_train_epochs,
    )

    # beta is supplied once, through DPOConfig above. Passing it to DPOTrainer
    # as well is deprecated in TRL and rejected by later releases.
    dpo_trainer = DPOTrainer(
        model,
        args=train_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
    )
    dpo_trainer.train()

    output_dir = os.path.join(cfg.output_dir, "final_checkpoint")
    dpo_trainer.model.save_pretrained(output_dir)
    tqdm.write("Finished!")



# Script entry point: invoke the Hydra-decorated main() only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()