#!/usr/bin/env python3
import os
from dataclasses import dataclass, field
from typing import Optional

import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, HfArgumentParser
from trl import DPOTrainer, DPOConfig


# -----------------------------
# Script-specific arguments
# -----------------------------

@dataclass
class ScriptArguments:
    """Options specific to this script: data locations, base model, and margin scales."""

    # --- data -----------------------------------------------------------
    train_pairs_path: str = field(
        metadata={"help": "Path to training pair JSON file (e.g., ultra_truth_pairs_train.json)."},
    )
    eval_pairs_path: str = field(
        metadata={"help": "Path to eval/val pair JSON file (e.g., ultra_truth_pairs_val.json)."},
    )

    # --- model ----------------------------------------------------------
    model_name: str = field(
        default="meta-llama/Llama-3.1-8B-Instruct",
        metadata={"help": "Base model name or path."},
    )

    # --- optional subsampling (None keeps the full split) ---------------
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Optional limit on number of training samples."},
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Optional limit on number of eval samples."},
    )

    # --- margin hyper-parameters ----------------------------------------
    margin_alpha: float = field(
        default=1.0,
        metadata={"help": "Scale for ODPO-style margin term."},
    )
    conflict_lambda: float = field(
        default=1.0,
        metadata={"help": "Scale for conflict adjustment in margin."},
    )


# -----------------------------
# CAMP trainer
# -----------------------------

class CAMPTrainer(DPOTrainer):
    """DPOTrainer whose sigmoid loss is shifted by a per-example margin.

    The margin is derived in ``get_batch_loss_metrics`` from the
    ``truth_*`` / ``help_*`` score tensors found in the batch and handed to
    ``dpo_loss`` through ``self._current_margin``. When the batch carries no
    truth scores, no margin is applied and the loss is
    ``-logsigmoid(beta * logits)``.

    Args:
        margin_alpha: Multiplier applied to the margin inside the loss.
        conflict_lambda: Weight of the helpfulness delta in the margin of
            examples where the chosen response has the lower help score.
        *args, **kwargs: Forwarded unchanged to ``DPOTrainer``.
    """

    def __init__(self, *args, margin_alpha: float = 1.0, conflict_lambda: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.margin_alpha = float(margin_alpha)
        self.conflict_lambda = float(conflict_lambda)
        # Margin for the batch currently inside get_batch_loss_metrics; None otherwise.
        self._current_margin = None

    def dpo_loss(self, chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps,
                 loss_type="sigmoid", model_output=None):
        """Return per-example losses and detached chosen/rejected rewards.

        ``loss_type`` and ``model_output`` are accepted for signature
        compatibility but are not used: the loss is always the sigmoid form,
        optionally shifted by ``margin_alpha * margin``.

        Raises:
            ValueError: If a margin is set but its element count differs from
                that of the logits (it would otherwise broadcast silently).
        """
        pi_logratios = chosen_logps - rejected_logps
        ref_logratios = ref_chosen_logps - ref_rejected_logps
        logits = pi_logratios - ref_logratios

        shifted = self.beta * logits
        margin = self._current_margin
        if margin is not None:
            margin = margin.to(device=logits.device, dtype=logits.dtype).view(-1)
            if margin.numel() != logits.numel():
                raise ValueError(
                    f"Margin has {margin.numel()} elements but logits have {logits.numel()}; "
                    "they must describe the same batch."
                )
            shifted = shifted - self.margin_alpha * margin.reshape(logits.shape)
        losses = -F.logsigmoid(shifted)

        chosen_rewards = (self.beta * (chosen_logps - ref_chosen_logps)).detach()
        rejected_rewards = (self.beta * (rejected_logps - ref_rejected_logps)).detach()
        return losses, chosen_rewards, rejected_rewards

    def _compute_margin(self, batch):
        """Build the per-example margin from the score tensors in *batch*.

        Returns ``None`` when truth scores are absent, the plain truth delta
        when only truth scores are present, and otherwise the truth delta for
        examples with ``help_delta >= 0`` and
        ``clamp(truth_delta + conflict_lambda * help_delta, min=0)`` for the rest.
        """
        truth_chosen = batch.get("truth_chosen", None)
        truth_rejected = batch.get("truth_rejected", None)
        if truth_chosen is None or truth_rejected is None:
            return None
        truth_delta = truth_chosen - truth_rejected

        help_chosen = batch.get("help_chosen", None)
        help_rejected = batch.get("help_rejected", None)
        if help_chosen is None or help_rejected is None:
            return truth_delta
        help_delta = help_chosen - help_rejected

        synergy_mask = help_delta >= 0
        conflict_margin = torch.clamp(truth_delta + (self.conflict_lambda * help_delta), min=0.0)
        margin = torch.where(synergy_mask, truth_delta, conflict_margin)
        return margin.to(dtype=torch.float32).view(-1)

    def get_batch_loss_metrics(self, model, batch, train_eval="train"):
        """Compute loss and metrics for *batch*, applying the margin if scores are present.

        Adds ``"{train_eval}/margin_mean"`` to the metrics when a margin was used.
        """
        margin = self._compute_margin(batch)
        self._current_margin = margin
        try:
            loss, metrics = super().get_batch_loss_metrics(model, batch, train_eval=train_eval)
        finally:
            # Never let this batch's margin leak into a later dpo_loss call.
            self._current_margin = None

        if margin is not None:
            metrics[f"{train_eval}/margin_mean"] = margin.detach().float().mean().item()
        return loss, metrics


# -----------------------------
# Helpers
# -----------------------------

def format_prompt(ex):
    """Return the example's ``instruction`` field as the DPO prompt ('' if the key is absent)."""
    instruction = ex.get("instruction", "")
    return instruction

def load_dpo_from_pairs(train_path: str, eval_path: str,
                        max_train_samples: Optional[int] = None,
                        max_eval_samples: Optional[int] = None):
    """Read two pair-JSON files and return them as DPO-ready datasets.

    Every record gains ``prompt``, ``chosen`` and ``rejected`` columns, and the
    four entries of its ``scores`` mapping are copied to top-level columns.

    Args:
        train_path: JSON file loaded as the "train" split.
        eval_path: JSON file loaded as the "eval" split.
        max_train_samples: When a positive integer, keep only that many
            leading training rows; ``None`` or a non-positive value keeps all.
        max_eval_samples: Same as ``max_train_samples`` for the eval split.

    Returns:
        A ``(train_dataset, eval_dataset)`` tuple.

    Raises:
        ValueError: If a record has no ``scores``, if ``scores`` lacks one of
            the four score keys, or if a required column is missing afterwards.
    """
    score_keys = ["truth_chosen", "truth_rejected", "help_chosen", "help_rejected"]

    def add_dpo_columns(ex):
        ex["prompt"] = format_prompt(ex)
        ex["chosen"] = ex["chosen_response"]
        ex["rejected"] = ex["rejected_response"]

        scores = ex.get("scores", None)
        if scores is None:
            raise ValueError("Missing 'scores' in example")

        # Validate every key before copying any of them.
        for k in score_keys:
            if k not in scores:
                raise ValueError(f"Missing '{k}' in scores")
        for k in score_keys:
            ex[k] = scores[k]

        return ex

    def head(dataset, limit):
        # Only a positive limit truncates the split.
        if limit is None or limit <= 0:
            return dataset
        return dataset.select(range(min(limit, len(dataset))))

    splits = load_dataset("json", data_files={"train": train_path, "eval": eval_path})

    train_dataset = splits["train"].map(add_dpo_columns)
    eval_dataset = splits["eval"].map(add_dpo_columns)

    train_dataset = head(train_dataset, max_train_samples)
    eval_dataset = head(eval_dataset, max_eval_samples)

    for col in ["prompt", "chosen", "rejected"] + score_keys:
        if col not in train_dataset.column_names:
            raise ValueError(f"Column '{col}' not found in train dataset. Available: {train_dataset.column_names}")
        if col not in eval_dataset.column_names:
            raise ValueError(f"Column '{col}' not found in eval dataset. Available: {eval_dataset.column_names}")

    return train_dataset, eval_dataset


# -----------------------------
# Main
# -----------------------------

def _attach_preference_scores(base_collator):
    """Wrap *base_collator* so each batch also carries the four score tensors."""
    score_keys = ("truth_chosen", "truth_rejected", "help_chosen", "help_rejected")

    def collate_with_truth_scores(features):
        scores = {}
        for key in score_keys:
            try:
                values = [float(f[key]) for f in features]
            except KeyError as err:
                # A silent 0.0 default would zero the margin and turn training
                # into plain DPO without any warning, so fail loudly instead.
                raise KeyError(
                    f"Feature is missing '{key}'; the score columns must reach the collator "
                    "(check that remove_unused_columns is False)."
                ) from err
            scores[key] = torch.tensor(values, dtype=torch.float32)

        batch = base_collator(features)
        for key, tensor in scores.items():
            batch[key] = tensor
        return batch

    return collate_with_truth_scores


def main():
    """Parse arguments, load the pair data, and run margin-aware DPO training.

    Steps: parse ``ScriptArguments`` and ``DPOConfig`` from the command line,
    fill in length defaults that were left unset, load and preprocess the
    train/eval pair files, load tokenizer and model, build a ``CAMPTrainer``
    whose collator also emits the score tensors, train, and save the result.
    """
    parser = HfArgumentParser((ScriptArguments, DPOConfig))
    script_args, training_args = parser.parse_args_into_dataclasses()
    assert isinstance(script_args, ScriptArguments)
    assert isinstance(training_args, DPOConfig)

    # Use getattr throughout: not every DPOConfig version defines all three fields.
    if getattr(training_args, "max_prompt_length", None) is None:
        training_args.max_prompt_length = 256
    if getattr(training_args, "max_completion_length", None) is None:
        training_args.max_completion_length = 256
    if getattr(training_args, "max_length", None) is None:
        training_args.max_length = training_args.max_prompt_length + training_args.max_completion_length

    # The score columns are not model inputs; if the Trainer is allowed to drop
    # "unused" columns they never reach the collator and no margin is applied.
    training_args.remove_unused_columns = False

    os.makedirs(training_args.output_dir, exist_ok=True)

    train_dataset, eval_dataset = load_dpo_from_pairs(
        train_path=script_args.train_pairs_path,
        eval_path=script_args.eval_pairs_path,
        max_train_samples=script_args.max_train_samples,
        max_eval_samples=script_args.max_eval_samples,
    )
    print(f"[Data] Loaded pair JSONs from: {script_args.train_pairs_path}, {script_args.eval_pairs_path}")
    print(f"[Data] Train size: {len(train_dataset)}, Eval size: {len(eval_dataset)}")

    tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, use_fast=True)
    if tokenizer.pad_token is None:
        # Fall back to EOS so the collator has a padding token to work with.
        tokenizer.pad_token = tokenizer.eos_token

    torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name,
        torch_dtype=torch_dtype,
        device_map="auto",
    )

    dpo_trainer = CAMPTrainer(
        model=model,
        ref_model=None,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        margin_alpha=script_args.margin_alpha,
        conflict_lambda=script_args.conflict_lambda,
    )

    # The trainer builds its own collator; wrap it so the scores ride along.
    dpo_trainer.data_collator = _attach_preference_scores(dpo_trainer.data_collator)

    print("[Train] Starting DPO training...")
    dpo_trainer.train()
    # Persist the final weights; periodic checkpoints need not coincide with
    # the last step. Respect an explicit request not to save anything.
    if training_args.save_strategy != "no":
        dpo_trainer.save_model(training_args.output_dir)
    print("[Train] Done.")


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
