import gc
import os
import torch
import wandb
from datasets import load_dataset
import time
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
)
from trl import ORPOConfig, ORPOTrainer, setup_chat_format

# Model
model_id = "cognitivecomputations/dolphin-2.1-mistral-7b"  # Base checkpoint to fine-tune
new_model = "Orpo-Dolphin-Mistral"  # Path the trained model is saved under



# Set torch dtype and attention implementation.
# bfloat16 + FlashAttention-2 are selected only on GPUs with compute capability
# >= 8; everything else falls back to float16 + eager attention.
# CUDA availability is checked first because get_device_capability() raises
# when no CUDA device is present, and the rest of this script already guards
# its GPU calls with torch.cuda.is_available().
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
    attn_implementation = "flash_attention_2"
else:
    torch_dtype = torch.float16
    attn_implementation = "eager"

# Quantization settings (QLoRA): load the base weights in 4-bit NF4 with
# double quantization, computing in the dtype selected above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
)

# Adapter settings (LoRA): rank-16 adapters on the attention projections
# (q/k/v/o) and the MLP projections (gate/up/down).
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        "up_proj",
        "down_proj",
        "gate_proj",
        "k_proj",
        "q_proj",
        "v_proj",
        "o_proj",
    ],
)

# Tokenizer: reuse the EOS token for padding, and pad/truncate on the left.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
# Left truncation drops the beginning of a sequence, so the end of the
# assistant response (the important part) is kept.
tokenizer.truncation_side = "left"
tokenizer.padding_side = "left"


# Base model: loaded with the 4-bit quantization config and placed on the
# available devices automatically.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation=attn_implementation,
    quantization_config=bnb_config,
)
# Chat-format setup is intentionally disabled:
#model, tokenizer = setup_chat_format(model, tokenizer)
# Prepare the quantized model for k-bit (QLoRA) training.
model = prepare_model_for_kbit_training(model)

# Load local JSON dataset
train_dataset = load_dataset("json", data_files="train_dataset.json", split="train")

# # Shuffle and take a subset for quick demo (Remove `.select(range(1000))` if you want full dataset)
# train_dataset = train_dataset.shuffle(seed=42).select(range(1000))  

# # Format dataset
# def format_chat_template(row):
#     row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
#     row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
#     return row

# train_dataset = train_dataset.map(
#     format_chat_template,
#     num_proc=os.cpu_count(),
# )

# Hold out 1% of the data for evaluation. The result is a mapping with
# "train" and "test" splits. A fixed seed keeps the held-out examples the
# same on every run, so eval metrics are comparable between runs.
train_dataset = train_dataset.train_test_split(test_size=0.01, seed=42)

# For reference, the DPO run used these lengths:
# prompt_length = 1024
# max_seq_length = 1512
orpo_args = ORPOConfig(
    # Sequence length limits
    max_length=1024,
    max_prompt_length=512,
    # Optimisation
    beta=0.1,
    learning_rate=8e-6,
    lr_scheduler_type="linear",
    warmup_steps=10,
    optim="paged_adamw_8bit",
    num_train_epochs=1,
    # Batching: 2 samples per device x 4 accumulation steps per optimizer step
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    # Evaluation and logging
    evaluation_strategy="steps",
    eval_steps=0.2,  # a float below 1 is read as a fraction of total training steps
    logging_steps=1,
    report_to="wandb",
    output_dir="./results/",
)

# Build the ORPO trainer on the quantized model, attaching the LoRA adapters
# and using the train/test splits created above.
trainer = ORPOTrainer(
    model=model,
    args=orpo_args,
    peft_config=peft_config,
    tokenizer=tokenizer,
    train_dataset=train_dataset["train"],
    eval_dataset=train_dataset["test"],
)
# Open the W&B run explicitly so it gets this project and run name.
wandb.init(name="DPO_ultra_dophin1", project="ICML_CVXDPO")

# Start the wall-clock timer
start_time = time.time()

# Clear the peak-memory counter so the value read after training reflects
# allocations made from this point on.
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()


# Train, then save the result under `new_model`
trainer.train()
trainer.save_model(new_model)

# Elapsed wall-clock time since the timer was started (covers training and saving)
total_run_time = time.time() - start_time

# Peak allocated GPU memory in MB (bytes / 1024**2), or a placeholder string
# when no GPU is available.
peak_memory = (
    torch.cuda.max_memory_allocated() / (1024 ** 2)
    if torch.cuda.is_available()
    else "N/A (No GPU)"
)

# Report both figures to W&B
run_summary = {
    "Total Run Time (seconds)": total_run_time,
    "Peak GPU Memory Usage (MB)": peak_memory,
}
wandb.log(run_summary)

# Release the trainer and model, then ask Python and the CUDA allocator to
# free what they can (the collector is run twice, as in the original).
del trainer, model
gc.collect()
gc.collect()
torch.cuda.empty_cache()

# # Reload tokenizer and model
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# fp16_model = AutoModelForCausalLM.from_pretrained(
#     model_id,
#     low_cpu_mem_usage=True,
#     return_dict=True,
#     torch_dtype=torch.float16,
#     device_map="auto",
# )
# fp16_model, tokenizer = setup_chat_format(fp16_model, tokenizer)

# # Merge adapter with base model
# model = PeftModel.from_pretrained(fp16_model, new_model)
# model = model.merge_and_unload()