from datasets import load_dataset
import wandb
import time
# import torch
# from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
# from peft import LoraConfig
# from trl import DPOTrainer
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
# --- Run configuration --------------------------------------------------------
MODEL_NAME = "openai-community/gpt2"  # Hugging Face Hub model id
TRAIN_DATASET = "/home/miria/cvxdpo/train_edu_dataset_subset.json"  # preference pairs for training
TEST_DATASET = "/home/miria/cvxdpo/test_edu_dataset_subset.json"  # preference pairs for evaluation
# NOTE(review): the directory name says "mistral" although MODEL_NAME is GPT-2;
# confirm this is intentional so runs for different models don't share checkpoints.
OUTPUT_DIR = "mistral-dpo-demo"
NUM_EPOCHS = 1
RUN_NAME = "DPO_edu_gpt2"  # W&B run name

# Both JSON files are loaded the same way. A single `data_files` path is exposed
# by `datasets` under the split name "train", which is why the evaluation file is
# also requested with split="train".
train_dataset, eval_dataset = (
    load_dataset("json", data_files=data_path, split="train")
    for data_path in (TRAIN_DATASET, TEST_DATASET)
)

# 4-bit NF4 quantization (QLoRA-style) with nested "double" quantization of the
# quantization constants; matmuls are computed in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the quantized base model. The KV cache is disabled because it is not
# needed for training and the run uses gradient checkpointing.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    use_cache=False,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# GPT-2's tokenizer has no pad token by default; reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token
# Truncate from the left so the beginning is dropped and the (important) end of
# the sequence, i.e. the assistant response, is kept.
# NOTE(review): TRL's DPOTrainer may apply its own truncation settings instead of
# tokenizer.truncation_side - confirm for the installed TRL version.
tokenizer.truncation_side = 'left'
tokenizer.padding_side = 'left'

# Token budgets for DPO (max_prompt_length / max_length in DPOConfig).
# The requested budgets (1024 prompt / 1512 total) exceed GPT-2's context window:
# GPT-2 has learned absolute position embeddings for only n_positions = 1024
# tokens, so any sequence longer than that indexes past the embedding table and
# crashes. Clamp the total to the model's context window and keep the requested
# completion budget (total - prompt = 488 tokens) for the chosen/rejected responses.
REQUESTED_PROMPT_LENGTH = 1024
REQUESTED_MAX_LENGTH = 1512

# transformers maps `max_position_embeddings` to GPT-2's `n_positions`; fall back
# to the requested length for configs that don't expose a position limit.
context_window = getattr(model.config, "max_position_embeddings", None) or REQUESTED_MAX_LENGTH
max_seq_length = min(REQUESTED_MAX_LENGTH, context_window)

completion_budget = REQUESTED_MAX_LENGTH - REQUESTED_PROMPT_LENGTH
prompt_length = min(REQUESTED_PROMPT_LENGTH, max_seq_length - completion_budget)
if prompt_length <= 0:
    raise ValueError(
        f"Context window of {context_window} tokens is too small for a "
        f"{completion_budget}-token completion budget."
    )

from peft import LoraConfig
 
# LoRA adapter settings, following the QLoRA paper and Sebastian Raschka's
# experiments. PEFT scales the adapter update by lora_alpha / r (here 128 / 256 = 0.5).
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules="all-linear",  # adapt all linear/Conv1D layers except the output layer
    r=256,                        # adapter rank
    lora_alpha=128,
    lora_dropout=0.05,
    bias="none",                  # do not train bias terms
)

from transformers import TrainingArguments
from trl import DPOTrainer, DPOConfig

# DPO-specific loss settings; unpacked into DPOConfig below.
dpo_args = {
    "beta": 0.1,            # deviation control: higher beta keeps the policy closer to the reference model
    "loss_type": "sigmoid"  # standard DPO loss
}

args = DPOConfig(
    output_dir=OUTPUT_DIR,                  # checkpoint / final-model directory
    num_train_epochs=NUM_EPOCHS,            # number of training epochs
    per_device_train_batch_size=2,          # batch size per device during training
    per_device_eval_batch_size=1,           # batch size per device during evaluation
    gradient_accumulation_steps=1,          # micro-batches whose gradients are summed before each optimizer step
    gradient_checkpointing=True,            # trade recomputation for activation memory
    optim="adamw_torch_fused",              # fused AdamW implementation
    learning_rate=5e-5,                     # note: lower than the 1e-4 to 2e-4 used in the QLoRA paper
    max_grad_norm=0.3,                      # gradient-clipping norm, as in the QLoRA paper
    warmup_ratio=0.1,                       # fraction of total steps used for linear LR warmup
    lr_scheduler_type="cosine",             # cosine decay after warmup
    logging_steps=25,                       # log every 25 steps
    save_steps=500,                         # save a checkpoint every 500 steps
    save_total_limit=2,                     # keep at most 2 checkpoints on disk
    # Evaluate every `eval_steps` (700) steps. NOTE: newer transformers releases
    # renamed this argument to `eval_strategy`; update it when upgrading.
    evaluation_strategy="steps",
    eval_steps=700,                         # evaluation interval in steps
    bf16=True,                              # bfloat16 mixed precision
    tf32=True,                              # TF32 matmuls (requires an Ampere-or-newer GPU)
    push_to_hub=False,                      # do not push to the Hugging Face Hub
    report_to="wandb",                      # log metrics to Weights & Biases
    max_length=max_seq_length,              # max tokens for prompt + response
    max_prompt_length=prompt_length,        # max tokens for the prompt alone
    **dpo_args,                             # beta and loss_type
)
 

from trl import DPOTrainer
 
# Build the DPO trainer. With a PEFT adapter, no separate reference model is
# passed: TRL uses the base model with the adapter disabled as the reference.
trainer = DPOTrainer(
    model=model,
    ref_model=None,
    args=args,
    peft_config=peft_config,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

import gc

# Start the W&B run before training so the trainer's wandb integration logs into it.
wandb.init(project="ICML_CVXDPO", name=RUN_NAME)

# Measure the duration of training plus the final save. perf_counter() is
# monotonic, so unlike time.time() differences it is not skewed by wall-clock
# adjustments during a long run.
start_time = time.perf_counter()

# With device_map="auto" the model may be sharded over several GPUs, so reset
# the peak-memory counters on every visible device, not just the current one.
if torch.cuda.is_available():
    for device_index in range(torch.cuda.device_count()):
        torch.cuda.reset_peak_memory_stats(device_index)

# start training
trainer.train()

trainer.save_model()  # saves to args.output_dir
total_run_time = time.perf_counter() - start_time

# Peak allocated GPU memory in MB, summed over all visible devices. On a single
# GPU this equals that device's peak; on multiple GPUs the per-device peaks may
# occur at different times, so the sum is an upper bound on simultaneous usage.
if torch.cuda.is_available():
    peak_bytes = sum(
        torch.cuda.max_memory_allocated(device_index)
        for device_index in range(torch.cuda.device_count())
    )
    peak_memory = peak_bytes / (1024 ** 2)  # Convert bytes to MB
else:
    peak_memory = "N/A (No GPU)"

wandb.log({
    "Total Run Time (seconds)": total_run_time,
    "Peak GPU Memory Usage (MB)": peak_memory
})

# Drop the last references to the model and trainer, collect any reference
# cycles still holding their tensors, then release cached CUDA blocks.
del model
del trainer
gc.collect()
torch.cuda.empty_cache()
wandb.finish()