from transformers import AutoProcessor, LlavaNextForConditionalGeneration, TrainingArguments, Trainer, DataCollatorForSeq2Seq
from datasets import load_dataset
import torch
import torch.nn.functional as F
import os
from typing import Optional, Dict, Any



# ---------------------------------------------------------------------------
# Process-level setup
# ---------------------------------------------------------------------------
torch.cuda.empty_cache()
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TOKENIZERS_PARALLELISM": "false",
    }
)

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
model_name = "llava-hf/llava-v1.6-vicuna-7b-hf"  # checkpoint for processor and model
data_file = "test_tune_dpo.json"                 # JSON file loaded as the training split
image_folder = ""                                # prefix joined to each "image_path" when use_images is set
output_dir = "./results_kl"
use_images = False                               # when False, the prompt is encoded without images
num_epochs = 3
batch_size = 1
learning_rate = 5e-5
beta = 0.1                                       # scale applied inside the DPO log-sigmoid
kl_lambda = 0.5                                  # weight of the KL term in the total loss

# ---------------------------------------------------------------------------
# Model, processor and data
# ---------------------------------------------------------------------------
processor = AutoProcessor.from_pretrained(model_name)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
)

# Candidate answers for the KL term; only the FIRST sub-token id of each
# string is kept.
# NOTE(review): if the tokenizer emits a separate piece for the leading "\n"
# (or a prefix-space piece), all three ids may coincide, which would make the
# KL term constant -- verify that the resulting ids are distinct.
target_tokens = ["\nsame", "\nmale", "\nfemale"]
target_token_ids = [
    encoded[0]
    for encoded in (
        processor.tokenizer.encode(text, add_special_tokens=False)
        for text in target_tokens
    )
]

dataset = load_dataset("json", data_files=data_file, split="train")

def preprocess(example):
    """Tokenize one preference example into fixed-length id / mask sequences.

    Reads ``example["conversation"][0]["text"]`` as the prompt and the first
    entries of ``example["chosen_texts"]`` and ``example["rejected_texts"]``
    as the two candidate responses. The prompt is padded/truncated to 32
    tokens and each response to 64 tokens; no special tokens are added.
    ``example["image"]`` is forwarded to the processor only when the
    module-level ``use_images`` flag is set.

    Returns:
        dict with ``input_ids`` / ``attention_mask`` for the prompt and
        ``chosen_ids`` / ``rejected_ids`` / ``chosen_attn_mask`` /
        ``rejected_attn_mask`` for the responses (batch dimension removed).
    """
    prompt = example["conversation"][0]["text"]
    preferred = example["chosen_texts"][0]
    dispreferred = example["rejected_texts"][0]

    # Options shared by all three processor calls.
    common = {
        "return_tensors": "pt",
        "padding": "max_length",
        "truncation": True,
        "add_special_tokens": False,
    }

    prompt_enc = processor(
        text=prompt,
        images=example["image"] if use_images else None,
        max_length=32,
        **common,
    )
    chosen_enc = processor(text=preferred, max_length=64, **common)
    rejected_enc = processor(text=dispreferred, max_length=64, **common)

    # Each encoding carries a leading batch dimension of size 1; drop it.
    return {
        "input_ids": prompt_enc["input_ids"][0],
        "attention_mask": prompt_enc["attention_mask"][0],
        "chosen_ids": chosen_enc["input_ids"][0],
        "rejected_ids": rejected_enc["input_ids"][0],
        "chosen_attn_mask": chosen_enc["attention_mask"][0],
        "rejected_attn_mask": rejected_enc["attention_mask"][0],
    }

# Resolve image paths first (only when images are used), then tokenize every
# row and expose just the tensor columns the trainer reads.
if use_images:
    dataset = dataset.map(
        lambda row: {"image": os.path.join(image_folder, row["image_path"])}
    )
dataset = dataset.map(preprocess)
dataset.set_format(
    type="torch",
    columns=[
        "input_ids",
        "attention_mask",
        "chosen_ids",
        "rejected_ids",
        "chosen_attn_mask",
        "rejected_attn_mask",
    ],
)




# class CustomDataCollator:
#     def __call__(self, features):
#         batch = {}
#         for key in features[0].keys():
#             if isinstance(features[0][key], torch.Tensor):
#                 batch[key] = torch.stack([f[key] for f in features])
#             else:
#                 batch[key] = [f[key] for f in features]
#         return batch



import torch
import torch.nn.functional as F

def _sequence_logprob(model, prompt_input_ids, response_ids, response_attn_mask):
    """Return, per example, the summed log-probability of a response given its prompt."""
    batch_size = response_ids.size(0)
    if prompt_input_ids.size(1) == 0:
        raise ValueError("prompt_input_ids must contain at least one token")

    # A response position t is scored while t < (number of attended tokens).
    lengths = response_attn_mask.sum(dim=1)
    longest = int(lengths.max()) if lengths.numel() > 0 else 0
    response = response_ids[:, :longest]
    max_len = response.size(1)
    if max_len == 0:
        return torch.zeros(batch_size, device=prompt_input_ids.device)

    # Teacher forcing: feed prompt + response once. The final response token
    # is dropped from the input because nothing has to be predicted after it.
    # Assumes a causal LM (logits at position i depend only on tokens <= i),
    # so these are the same distributions as re-running the model after
    # appending each token one by one.
    full_ids = torch.cat([prompt_input_ids, response], dim=1)
    logits = model(full_ids[:, :-1]).logits

    # The last `max_len` positions predict response tokens 0 .. max_len - 1.
    log_probs = F.log_softmax(logits[:, -max_len:, :].float(), dim=-1)
    token_logprobs = log_probs.gather(2, response.unsqueeze(-1)).squeeze(-1)

    steps = torch.arange(max_len, device=lengths.device)
    valid = (steps.unsqueeze(0) < lengths.unsqueeze(1)).to(token_logprobs.device)
    # masked_fill (not multiplication) so a -inf at a padded step cannot turn into NaN.
    return token_logprobs.masked_fill(~valid, 0.0).sum(dim=1)


def compute_dpo_loss_single_forward(model, prompt_input_ids, chosen_ids, rejected_ids, chosen_attn_mask, rejected_attn_mask, beta=0.1, pad_token_id=-100):
    """
    DPO-style preference loss from teacher-forced sequence log-probabilities.

    For every example the summed log-probability of the chosen and of the
    rejected response, each conditioned on the prompt, is computed with one
    forward pass per response. The loss is
    ``mean(-logsigmoid(beta * (logp_chosen - logp_rejected)))``; no reference
    model is involved.

    The forward passes run with autograd enabled. Wrapping them in
    ``torch.no_grad()`` would turn the loss into a constant and the term would
    contribute nothing to training.

    Args:
        model: callable taking a ``(batch, seq)`` id tensor and returning an
            object with a ``.logits`` attribute of shape ``(batch, seq, vocab)``.
        prompt_input_ids: ``(batch, prompt_len)`` prompt token ids; must be
            non-empty along the sequence axis. No attention mask is passed to
            the model, so any padding inside the prompt is attended.
        chosen_ids / rejected_ids: ``(batch, resp_len)`` response token ids.
        chosen_attn_mask / rejected_attn_mask: masks of the same shape; only
            their row sums are used -- the first ``sum`` positions of each
            response are scored.
        beta: scale applied to the log-probability difference.
        pad_token_id: unused; kept for backward compatibility.

    Returns:
        Scalar loss tensor (float32).

    Raises:
        ValueError: if ``prompt_input_ids`` has zero sequence length.
    """
    chosen_logprobs = _sequence_logprob(model, prompt_input_ids, chosen_ids, chosen_attn_mask)
    rejected_logprobs = _sequence_logprob(model, prompt_input_ids, rejected_ids, rejected_attn_mask)

    diff = chosen_logprobs - rejected_logprobs  # (batch,)
    return -F.logsigmoid(beta * diff).mean()


class DPOKLTrainer(Trainer):
    """Trainer whose loss is ``dpo_loss + kl_lambda * kl_loss``.

    * ``dpo_loss`` comes from ``compute_dpo_loss_single_forward`` with the
      module-level ``beta``.
    * ``kl_loss`` restricts the next-token logits after the prompt to the
      module-level ``target_token_ids``, renormalises them with a softmax and
      takes the KL divergence from a one-hot target on index 0 (i.e. the
      negative log-probability of the first target token among the candidates).
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the combined DPO + KL loss for one batch.

        Args:
            model: the (possibly wrapped) model being trained.
            inputs: mapping with ``input_ids``, ``chosen_ids``,
                ``rejected_ids``, ``chosen_attn_mask``, ``rejected_attn_mask``
                and, optionally, ``attention_mask`` for the prompt.
            return_outputs: when true, also return the model outputs of the
                prompt forward pass used for the KL term.
            num_items_in_batch: accepted for compatibility with the ``Trainer``
                call signature; not used.

        Returns:
            ``loss`` or ``(loss, outputs)`` depending on ``return_outputs``.
        """
        input_ids = inputs["input_ids"]
        # The prompt is padded to a fixed length upstream; without the mask the
        # model would attend to padding positions.
        attention_mask = inputs.get("attention_mask")

        dpo_loss = compute_dpo_loss_single_forward(
            model=model,
            prompt_input_ids=input_ids,
            chosen_ids=inputs["chosen_ids"],
            rejected_ids=inputs["rejected_ids"],
            chosen_attn_mask=inputs["chosen_attn_mask"],
            rejected_attn_mask=inputs["rejected_attn_mask"],
            beta=beta,
        )

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        last_logits = self._last_token_logits(outputs.logits, attention_mask)

        # Renormalise over the candidate tokens only; float32 for a stable log-softmax.
        selected_logits = last_logits[:, target_token_ids].float()  # (batch, n_targets)
        pred_log_probs = F.log_softmax(selected_logits, dim=-1)

        # One-hot target on the first candidate token.
        target_distribution = torch.zeros_like(pred_log_probs)
        target_distribution[:, 0] = 1.0

        kl_loss = F.kl_div(pred_log_probs, target_distribution, reduction="batchmean", log_target=False)

        loss = dpo_loss + kl_lambda * kl_loss
        return (loss, outputs) if return_outputs else loss

    @staticmethod
    def _last_token_logits(logits, attention_mask):
        """Return the logits at the last attended position of every row.

        Falls back to the final sequence position when no mask is given or when
        the mask does not line up with the logits' ``(batch, seq)`` shape.
        """
        if attention_mask is None or attention_mask.shape != logits.shape[:2]:
            return logits[:, -1, :]

        # Highest index whose mask entry is non-zero. With left padding this is
        # the final position; with right padding it is the last real token
        # rather than a padding slot.
        positions = torch.arange(attention_mask.size(1), device=attention_mask.device)
        last_index = (attention_mask.long() * positions).argmax(dim=1).to(logits.device)
        rows = torch.arange(logits.size(0), device=logits.device)
        return logits[rows, last_index]





# NOTE(review): the module-level `output_dir` and `batch_size` constants are
# not referenced here; the literals below are what take effect. Confirm which
# values are intended.
training_args = TrainingArguments(
    # --- output / checkpointing ---
    output_dir="./results",
    save_steps=10_000,
    save_total_limit=2,
    # --- optimisation ---
    learning_rate=learning_rate,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=1,
    fp16=True,
    # --- evaluation (disabled) ---
    eval_strategy="no",
    eval_steps=0,
    # --- logging ---
    logging_dir="./logs",
    logging_steps=1,
    report_to="none",
    disable_tqdm=False,
    # --- data loading ---
    dataloader_num_workers=2,
    # Keep the extra chosen/rejected columns that the custom loss reads.
    remove_unused_columns=False,
    # --- distributed ---
    # NOTE(review): this path starts with "=" -- possibly a typo for
    # "./ds_config.json"; confirm before running.
    deepspeed="=/ds_config.json",
    ddp_find_unused_parameters=False,
)




# ---------------------------------------------------------------------------
# Build the trainer, run training, persist the result
# ---------------------------------------------------------------------------
print("warping trainer...")
# No data_collator is passed here (a custom one exists only as commented-out
# code in this file).
trainer = DPOKLTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=processor,
)

print("training starts now...")
trainer.train()
print("training finished...")

# NOTE(review): the return value of unwrap_model is discarded, so this
# statement has no visible effect on what is saved below.
trainer.accelerator.unwrap_model(trainer.model)

# NOTE(review): "path" looks like a placeholder directory name -- confirm the
# intended save location.
trainer.save_model("path")
processor.save_pretrained("path")

torch.cuda.empty_cache()
