

import torch

# Environment sanity check: is a CUDA device visible, and which CUDA version
# was this torch build compiled against (None for CPU-only builds)?
cuda_available = torch.cuda.is_available()
cuda_build_version = torch.version.cuda
print(cuda_available, cuda_build_version)




import tempfile
import os
from cruise.utilities.hdfs_io import hcopy
def load_hdfs_path(ckpt_path):
    """Return a local filesystem path for ``ckpt_path``, staging it from HDFS if needed.

    Paths starting with ``"hdfs"`` are copied with ``hcopy`` into the system
    temporary directory, under an entry named after the last component of
    ``ckpt_path``, and the path of that local copy is returned. Any other
    path is treated as already local and returned unchanged.

    Args:
        ckpt_path: An HDFS URI (``hdfs://...``) or a local path to a dataset
            or model checkpoint.

    Returns:
        The local path to use in place of ``ckpt_path``.
    """
    if not ckpt_path.startswith("hdfs"):
        return ckpt_path

    # Strip trailing slashes before taking the basename: for
    # "hdfs://.../model/" os.path.basename() returns "", which would make the
    # destination the temp directory itself and hand back all of /tmp instead
    # of the copied checkpoint.
    name = os.path.basename(ckpt_path.rstrip("/"))
    local_dir = os.path.join(tempfile.gettempdir(), name)
    # NOTE(review): a local_dir left over from an earlier run is not removed;
    # how hcopy treats an existing destination depends on its implementation.
    hcopy(ckpt_path, local_dir)
    return local_dir




from datasets import load_dataset

# Resolve the dataset location: load_hdfs_path copies HDFS paths to a local
# temp directory and passes local paths through untouched.
dir_path = load_hdfs_path("....../home/.../.../user/.../metadata/dpo-mix-7k")
dataset = load_dataset(dir_path, split="train")

# Peek at the content of the first message of the first record's "chosen" list.
print("Example dataset:")
first_record = dataset[0]
print(first_record["chosen"][0]["content"])





from datasets import load_dataset
from transformers import AutoTokenizer

def chatml_format(example):
    """Convert one preference record into the prompt/chosen/rejected strings used for DPO.

    ``example['chosen']`` and ``example['rejected']`` are lists of chat
    messages, each a dict with ``'role'`` and ``'content'`` keys. Every
    message before the last one of the chosen conversation forms the prompt,
    rendered with the module-level ``tokenizer``'s chat template and ending
    with the generation prompt. The content of the final message of each
    conversation is the completion, with the tokenizer's EOS token appended.

    For a two-message ``[user, assistant]`` conversation this produces the
    same strings as "message 0 is the prompt, message 1 is the answer".

    Relies on the module-level ``tokenizer`` being defined before the dataset
    is mapped.

    Args:
        example: One dataset row.

    Returns:
        A dict with keys ``"prompt"``, ``"chosen"`` and ``"rejected"``.

    Raises:
        ValueError: if either conversation has fewer than two messages, so
            there is no prompt/completion split.
    """
    chosen_msgs = example['chosen']
    rejected_msgs = example['rejected']
    if len(chosen_msgs) < 2 or len(rejected_msgs) < 2:
        raise ValueError(
            "expected at least a prompt message and a reply in both 'chosen' "
            f"and 'rejected' (got {len(chosen_msgs)} and {len(rejected_msgs)} messages)"
        )

    # Use the whole history as the prompt and the *final* turn as the
    # completion. Hard-coding index 1 would, for a multi-turn record, compare
    # the first assistant reply (likely identical in both conversations, hence
    # no preference signal) and drop the later context.
    # NOTE(review): assumes chosen and rejected share the same history and
    # differ only in their last reply — confirm for this dataset.
    history = [
        {"role": msg["role"], "content": msg["content"]} for msg in chosen_msgs[:-1]
    ]
    prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)

    # NOTE(review): depending on the TRL version, DPOTrainer may append EOS to
    # completions itself, in which case this yields a doubled EOS — check.
    chosen = chosen_msgs[-1]['content'] + tokenizer.eos_token
    rejected = rejected_msgs[-1]['content'] + tokenizer.eos_token

    return {
        "prompt": prompt,
        "chosen": chosen,
        "rejected": rejected,
    }

# Reload the raw split so this section can be re-run on its own: the map()
# below drops the original columns, and chatml_format would fail on an
# already-converted record (its "chosen" is then a plain string).
dataset = load_dataset(dir_path, split="train")
original_columns = dataset.column_names

# Tokenizer from the same checkpoint that is loaded for training further down.
model_name = load_hdfs_path("....../home/.../.../user/.../debug/gemma-2b/gemma-2b")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Pad with the EOS token, on the left.
# NOTE(review): this overrides any pad token the checkpoint already defines —
# confirm that is intended.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Turn each record into prompt/chosen/rejected strings and drop the original
# columns. chatml_format reads the tokenizer configured above, so this must
# run after it.
dataset = dataset.map(function=chatml_format, remove_columns=original_columns)

# Show one converted record, then the dataset summary.
print(dataset[1])
print(dataset)

import os
import gc
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, PeftModel
import wandb
from trl import ORPOTrainer, DPOTrainer
from trl import ORPOConfig, DPOConfig
# Start the Weights & Biases run that the trainer reports to (report_to="wandb").
wandb.init(
    project="....dpo",
    name="dpo-gemma2b-mixdpo7k",
    notes="this is a DPO test with gemma-2b + mix-dpo-7k",
)

# The policy is loaded from the same (possibly HDFS-staged) checkpoint
# directory as the tokenizer.
local_model_path = model_name

# NOTE(review): trust_remote_code=True executes Python code shipped with the
# checkpoint; Gemma is supported natively by transformers, so this may be
# unnecessary — consider dropping it.
model = AutoModelForCausalLM.from_pretrained(
    local_model_path,
    trust_remote_code=True,
    torch_dtype="auto",
)
# The KV cache is not needed for training, and gradient checkpointing is
# enabled in the training config.
model.config.use_cache = False

# Output directory for the run. With save_strategy="no", nothing is written
# here during training; the final model is only saved by an explicit call.
new_model = "/.../.../.../gemma2b/saved"

training_args = DPOConfig(
    # Output and reporting
    output_dir=new_model,
    report_to="wandb",
    logging_steps=1,
    save_strategy="no",
    # Batching: 1 pair per step x 16 accumulation steps = 16 preference pairs
    # per optimizer step per device.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    # Memory: recompute activations during the backward pass.
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant': True},
    # Optimisation schedule
    learning_rate=5e-5,
    lr_scheduler_type="cosine",
    warmup_steps=80,
    # NOTE(review): 50 epochs over a preference set is unusually long for DPO
    # and risks heavy overfitting — confirm this is intended.
    num_train_epochs=50,
    # Token-length limits for the prompt and the full prompt+completion.
    max_prompt_length=256,
    max_length=1024,
    # Keep the prompt/chosen/rejected text columns, which the model's
    # forward() does not name, so the DPO trainer can tokenize them.
    remove_unused_columns=False,
)

# Build the DPO trainer. No ref_model is passed, so (per the TRL docs)
# DPOTrainer creates the frozen reference model from ``model`` itself.
dpo_trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    # NOTE(review): newer TRL releases renamed this argument to
    # ``processing_class``; update it if the installed version rejects it.
    tokenizer=tokenizer,
)

dpo_trainer.train()

# training_args sets save_strategy="no", so nothing is written during
# training. Persist the final weights and the tokenizer explicitly, otherwise
# the result of the run is lost when the process exits.
dpo_trainer.save_model(new_model)
tokenizer.save_pretrained(new_model)

# Close the W&B run explicitly so it is marked finished even when this runs
# inside a long-lived notebook session.
wandb.finish()
