import os

# Pin the job to GPUs 0-3 and turn off NCCL's peer-to-peer and InfiniBand
# transports. This runs before `import torch` below, presumably so the CUDA
# and NCCL runtimes pick the settings up at initialization; keep it first.
os.environ.update(
    CUDA_VISIBLE_DEVICES="0,1,2,3",
    NCCL_P2P_DISABLE="1",
    NCCL_IB_DISABLE="1",
)

import copy
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
from datasets import load_dataset
from trl.trainer.dpo_config import DPOConfig
from trl import DPOTrainer
from transformers.trainer_callback import TrainerCallback

class SaveLogCallback(TrainerCallback):
    """Append the train/eval loss from each logging event to a text file.

    The file starts with the header ``step, train_loss, eval_loss`` and gets
    one comma-separated row per ``on_log`` call. A loss missing from that
    call's logs is written as ``NA``.

    Only the global main process (``state.is_world_process_zero``) touches
    the file, so a multi-process launch does not write duplicate or
    interleaved rows. The header is written on the first such log, which
    truncates any previous file and creates the parent directory if needed.
    """

    def __init__(self, log_file="training_log.txt"):
        self.log_file = log_file
        # Deferred to the first main-process log: __init__ runs in every
        # launched process and has no trainer state to tell them apart.
        self._header_written = False

    def _start_file(self):
        """Create the log file (and its directory) and write the header row."""
        parent = os.path.dirname(self.log_file)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(self.log_file, "w", encoding="utf-8") as f:
            f.write("step, train_loss, eval_loss\n")
        self._header_written = True

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append one row for this logging event (main process only)."""
        if logs is None or not state.is_world_process_zero:
            return
        if not self._header_written:
            self._start_file()
        train_loss = logs.get("loss", "NA")
        eval_loss = logs.get("eval_loss", "NA")
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(f"{state.global_step}, {train_loss}, {eval_loss}\n")

class ClassificationDPOTrainer(DPOTrainer):
    """DPOTrainer variant whose "completions" are class labels.

    Each example is a text ``prompt`` plus integer ``chosen`` / ``rejected``
    class indices. The policy and reference models are sequence classifiers.
    The log-probability of a completion is the log-softmax score of the
    corresponding class, computed from a single forward pass over the prompt.
    """

    @staticmethod
    def tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens):
        """Tokenize one preference example.

        Args:
            features: mapping with ``prompt`` (text) and ``chosen`` /
                ``rejected`` (class indices).
            processing_class: tokenizer used to encode the prompt.
            max_prompt_length: prompts are truncated to this many tokens.
            max_completion_length: unused here, because completions are
                single class ids. Kept for signature compatibility with
                ``DPOTrainer``.
            add_special_tokens: forwarded to the tokenizer.

        Returns:
            dict with ``prompt_input_ids`` and single-element
            ``chosen_input_ids`` / ``rejected_input_ids`` lists holding the
            class index.
        """
        # No padding here. Padding every prompt to max_prompt_length wastes
        # compute. It also hides the pads from the attention mask.
        # NOTE(review): assumes the default collator (data_collator=None)
        # pads each batch and derives prompt_attention_mask from the ids it
        # receives, so pre-padded tokens would be marked as attended.
        # Confirm against the installed trl version.
        prompt_input_ids = processing_class(
            features["prompt"],
            truncation=True,
            max_length=max_prompt_length,
            add_special_tokens=add_special_tokens,
        )["input_ids"]
        return {
            "prompt_input_ids": prompt_input_ids,
            # Wrapped in lists so the collator can tensorize them like token ids.
            "chosen_input_ids": [features["chosen"]],
            "rejected_input_ids": [features["rejected"]],
        }

    def concatenated_forward(self, model, batch):
        """Score the chosen and rejected class labels for a batch of prompts.

        Both labels share the same prompt, so one forward pass is enough.

        Args:
            model: classifier returning ``.logits``. Assumed to be shaped
                (batch, num_labels), as implied by the gather on dim 1.
            batch: collated batch with ``prompt_input_ids``, an optional
                ``prompt_attention_mask``, and ``chosen_input_ids`` /
                ``rejected_input_ids`` holding one class index per row.

        Returns:
            dict with:
            - ``chosen_logps`` / ``rejected_logps``: per-example log-softmax
              score of the labelled class.
            - ``mean_chosen_logits`` / ``mean_rejected_logits``: the raw
              logit of those classes.
            - ``nll_loss``: mean negative log-likelihood of the chosen class,
              presumably consumed by the base trainer when ``rpo_alpha`` is
              set (verify against the installed trl version).
        """
        input_ids = batch["prompt_input_ids"]
        attention_mask = batch.get("prompt_attention_mask")
        if attention_mask is None:
            # Fall back to treating every non-pad token as real.
            attention_mask = (input_ids != self.args.padding_value).long()
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        # Normalize in float32 so fp16 logits do not lose precision.
        logprobs = torch.log_softmax(logits.float(), dim=-1)
        # Labels arrive as one id per row; reshape to (batch, 1) gather indices.
        chosen_idx = batch["chosen_input_ids"].reshape(-1, 1)
        rejected_idx = batch["rejected_input_ids"].reshape(-1, 1)
        chosen_logps = logprobs.gather(1, chosen_idx).squeeze(1)
        rejected_logps = logprobs.gather(1, rejected_idx).squeeze(1)
        return {
            "chosen_logps": chosen_logps,
            "rejected_logps": rejected_logps,
            # Report the actual class logits (previously these were log-probs).
            "mean_chosen_logits": logits.gather(1, chosen_idx).squeeze(1),
            "mean_rejected_logits": logits.gather(1, rejected_idx).squeeze(1),
            "nll_loss": -chosen_logps.mean(),
        }

model_name = "/data/path/to/models/Qwen2.5-0.5B"
# Classification head with 3 output labels.
config = AutoConfig.from_pretrained(model_name, num_labels=3)
model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Ensure a pad token exists. Reuse EOS when available; otherwise register
# "[PAD]" and grow the embedding table to cover the new id.
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
    tokenizer.pad_token = tokenizer.eos_token
elif tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))
# Propagate the pad id to the tokenizer and to both config objects.
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
config.pad_token_id = model.config.pad_token_id = tokenizer.pad_token_id

# Frozen copy of the initial policy, used as the DPO reference model.
ref_model = copy.deepcopy(model).eval()
ref_model.requires_grad_(False)

# Each JSON file is loaded as a single "train" split: one for training, one for validation.
dataset = load_dataset("json", data_files="data/dpo/bird_train_dataset_simplified/classifier_train.json", split="train")
e_dataset = load_dataset("json", data_files="data/dpo/bird_train_dataset_simplified/classifier_valid.json", split="train")

def preprocess(example):
    """Map a raw JSON record to the ``prompt`` / ``chosen`` / ``rejected`` schema.

    Missing fields fall back to defaults: ``""`` for the text and ``0`` for
    either label. A field that is present but ``None`` is treated the same
    as a missing one. The ``datasets`` JSON loader can yield ``None`` for a
    key absent from some rows, and ``example.get(key, default)`` would pass
    that through, giving the prompt ``"None"`` or a ``TypeError`` from
    ``int(None)``.

    Args:
        example: mapping with optional ``text``, ``chosen`` and ``rejected`` keys.

    Returns:
        dict with ``prompt`` (str), ``chosen`` (int) and ``rejected`` (int).
    """

    def _field(key, default):
        # Use the default both when the key is absent and when it is null.
        value = example.get(key)
        return default if value is None else value

    return {
        "prompt": str(_field("text", "")),
        "chosen": int(_field("chosen", 0)),
        "rejected": int(_field("rejected", 0)),
    }


# Normalize both splits to the prompt/chosen/rejected schema, dropping the raw columns.
dataset, e_dataset = (
    ds.map(preprocess, batched=False, remove_columns=ds.column_names)
    for ds in (dataset, e_dataset)
)


# DPO run configuration, grouped by concern.
dpo_config = DPOConfig(
    # Output locations.
    output_dir="/data/xxx/saves/Qwen2.5-0.5B-router/dpo/bird_train_dataset_simplified/dpo-qwen2.5-0.5b_0220",
    logging_dir="/data/xxx/saves/Qwen2.5-0.5B-router/dpo/bird_train_dataset_simplified/dpo-qwen2.5-0.5b_0220/logs",
    # Optimization.
    num_train_epochs=12,
    per_device_train_batch_size=8,
    learning_rate=1e-4,
    fp16=True,
    # DPO objective.
    loss_type="sigmoid",
    beta=0.1,
    # Sequence lengths and padding. ClassificationDPOTrainer.tokenize_row
    # only reads max_prompt_length; completions there are single label ids.
    max_prompt_length=512,
    max_completion_length=128,
    max_length=640,
    padding_value=tokenizer.pad_token_id,
    # Data handling.
    dataset_num_proc=None,
    remove_unused_columns=False,
    generate_during_eval=False,
    # Logging, evaluation and checkpoint cadence: every 100 steps.
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=100,
    save_steps=100,
)


trainer = ClassificationDPOTrainer(
    model=model,
    ref_model=ref_model,
    args=dpo_config,
    train_dataset=dataset,
    eval_dataset=e_dataset,
    processing_class=tokenizer,
    # None: presumably DPOTrainer falls back to its built-in preference collator.
    data_collator=None,
)

trainer.add_callback(SaveLogCallback(log_file=os.path.join(dpo_config.output_dir, "training_log.txt")))

trainer.train()
# train() only writes the periodic save_steps checkpoints. Save the final
# model to output_dir explicitly so the message below is true and the
# model can be reloaded from that path.
trainer.save_model()

# Report and persist the final evaluation instead of discarding it.
final_metrics = trainer.evaluate()
trainer.log_metrics("eval", final_metrics)
trainer.save_metrics("eval", final_metrics)

# Print once, not once per launched process.
if trainer.is_world_process_zero():
    print(f"Model and checkpoints have saved to: {dpo_config.output_dir}")

