import inspect
import random
import warnings
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import torch
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import torch.profiler
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available, tqdm
from datasets import Dataset
from huggingface_hub.utils._deprecation import _deprecate_arguments
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM,
    DataCollator,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    is_apex_available,
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.training_args import OptimizerNames
from trl import DPOTrainer
from trl.data_utils import (
    apply_chat_template,
    is_conversational,
    maybe_apply_chat_template,
)
from trl.models.utils import unwrap_model_for_generation
from trl.trainer.dpo_config import (
    DPOConfig,
    FDivergenceConstants,
    FDivergenceType,
)
from trl.trainer.online_dpo_trainer import OnlineDPOTrainer
from trl.trainer.utils import (
    cap_exp,
    empty_cache,
    get_reward,
    pad_to_length,
    truncate_right,
)

if is_apex_available():
    from apex import amp


def compute_p_oracle(
    chosen_scores: torch.Tensor,
    rejected_scores: torch.Tensor,
    margin_scale: float = 1.0,
    label_type: Literal["oracle", "binary", "conditioned", "3level"] = "oracle",
    soft_threshold: float = 0.1,
) -> torch.Tensor:
    """Turn paired reward scores into a preference label for the chosen side.

    The base quantity is the Bradley-Terry style probability
    ``p = sigmoid(margin_scale * (chosen_scores - rejected_scores))``.

    Args:
        chosen_scores: Scores of the chosen completions.
        rejected_scores: Scores of the rejected completions; must broadcast
            against ``chosen_scores``.
        margin_scale: Multiplier applied to the score difference before the
            sigmoid.
        label_type: How ``p`` is turned into a label:

            - ``"oracle"``: return ``p`` unchanged.
            - ``"binary"``: sample a hard 0/1 label from Bernoulli(p).
            - ``"conditioned"``: return 1 where the raw (unscaled) score
              difference exceeds ``soft_threshold``, otherwise ``p``.
            - ``"3level"``: sample a label in {0, 0.5, 1} whose expectation
              is ``p``: entries with ``p < 0.5`` draw from {0, 0.5},
              entries with ``p >= 0.5`` draw from {0.5, 1}.
        soft_threshold: Threshold on the raw score difference used by
            ``"conditioned"``.

    Returns:
        A tensor with the broadcast shape of the two score tensors.

    Raises:
        ValueError: If ``label_type`` is not one of the supported values.
    """
    diff = chosen_scores - rejected_scores
    p_oracle = torch.sigmoid(margin_scale * diff)

    if label_type == "oracle":
        return p_oracle
    if label_type == "binary":
        return torch.bernoulli(p_oracle)
    if label_type == "conditioned":
        return torch.where(
            diff > soft_threshold, torch.ones_like(p_oracle), p_oracle
        )
    if label_type == "3level":
        # torch.bernoulli only accepts probabilities in [0, 1]. Each branch's
        # probability is only meaningful on its own half of the range, so it
        # is clamped before sampling and torch.where keeps the valid branch.
        #   p <  0.5: 0.5 * Bernoulli(2p)           -> E = p
        #   p >= 0.5: 0.5 * (1 + Bernoulli(2p - 1)) -> E = p
        low_prob = (2 * p_oracle).clamp(0.0, 1.0)
        high_prob = (2 * p_oracle - 1).clamp(0.0, 1.0)
        sample_low = 0.5 * torch.bernoulli(low_prob)
        sample_high = 0.5 * (1 + torch.bernoulli(high_prob))
        return torch.where(p_oracle < 0.5, sample_low, sample_high)
    raise ValueError(f"Invalid label_type: {label_type}")


class SoftDPOTrainer(DPOTrainer):
    """DPO trainer that can weight the sigmoid loss with oracle soft labels.

    With ``loss_type="soft_sigmoid"`` the per-example loss is
    ``-p * logsigmoid(beta * logits) - (1 - p) * logsigmoid(-beta * logits)``,
    where ``p`` is read from the batch's ``p_oracle`` entry. To use it, make
    sure the dataset contains a ``p_oracle`` column. The other loss types
    ignore ``p_oracle``.

    NOTE: do not give extra columns names containing ``_label``, to avoid
    conflicts with the chosen/rejected concatenation step
    (``concatenated_inputs`` pads keys containing ``labels`` as labels).

    How the loss computation is called:
    - compute_loss() / prediction_step() ->
        - get_batch_loss_metrics() ->
            - concatenated_forward()
            - dpo_loss()
    """

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        beta: float = 0.1,
        label_smoothing: float = 0,
        loss_type: Optional[str] = None,
        args: Optional[DPOConfig] = None,
        data_collator: Optional[DataCollator] = None,
        label_pad_token_id: int = -100,
        padding_value: Optional[int] = None,
        truncation_mode: str = "keep_end",
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[
            torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR
        ] = (None, None),
        preprocess_logits_for_metrics: Optional[
            Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        ] = None,
        max_length: Optional[int] = None,
        max_prompt_length: Optional[int] = None,
        max_target_length: Optional[int] = None,
        peft_config: Optional[Dict] = None,
        is_encoder_decoder: Optional[bool] = None,
        disable_dropout: bool = True,
        generate_during_eval: bool = False,
        compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
        precompute_ref_log_probs: bool = False,
        dataset_num_proc: Optional[int] = None,
        model_init_kwargs: Optional[Dict] = None,
        ref_model_init_kwargs: Optional[Dict] = None,
        model_adapter_name: Optional[str] = None,
        ref_adapter_name: Optional[str] = None,
        reference_free: bool = False,
        force_use_ref_model: bool = False,
    ):
        """Forward every argument unchanged to ``DPOTrainer.__init__``."""
        super().__init__(
            model=model,
            ref_model=ref_model,
            beta=beta,
            label_smoothing=label_smoothing,
            loss_type=loss_type,
            args=args,
            data_collator=data_collator,
            label_pad_token_id=label_pad_token_id,
            padding_value=padding_value,
            truncation_mode=truncation_mode,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            max_length=max_length,
            max_prompt_length=max_prompt_length,
            max_target_length=max_target_length,
            peft_config=peft_config,
            is_encoder_decoder=is_encoder_decoder,
            disable_dropout=disable_dropout,
            generate_during_eval=generate_during_eval,
            compute_metrics=compute_metrics,
            precompute_ref_log_probs=precompute_ref_log_probs,
            dataset_num_proc=dataset_num_proc,
            model_init_kwargs=model_init_kwargs,
            ref_model_init_kwargs=ref_model_init_kwargs,
            model_adapter_name=model_adapter_name,
            ref_adapter_name=ref_adapter_name,
            reference_free=reference_free,
            force_use_ref_model=force_use_ref_model,
        )

    @staticmethod
    def get_batch_logps(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
            label_pad_token_id: The label pad token id.
            is_encoder_decoder: Whether the model is an encoder-decoder model.

        Returns:
            A Tuple of two tensors of shape ((batch_size,), (batch_size,)): the sum of log probabilities of the
            non-masked labels, and the number of non-masked tokens. The input ``labels`` tensor is not modified.

        Raises:
            ValueError: If the batch/sequence dims of ``logits`` and ``labels`` differ.
        """
        if logits.shape[:-1] != labels.shape:
            raise ValueError(
                "Logits (batch and sequence length dim) and labels must have the same shape."
            )

        if not is_encoder_decoder:
            # Shift so that the logits at position t score the token at t + 1.
            labels = labels[:, 1:]
            logits = logits[:, :-1, :]
        loss_mask = labels != label_pad_token_id

        # Replace pad ids with a dummy index (0) so gather stays in range; the
        # mask zeroes their contribution. masked_fill is out-of-place, so the
        # caller's labels (reused for the NLL loss) are never overwritten.
        safe_labels = labels.masked_fill(~loss_mask, 0)

        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=2, index=safe_labels.unsqueeze(2)
        ).squeeze(2)

        return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)

    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
        num_items_in_batch=None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Compute the training loss for one batch and record its metrics.

        Args:
            model: The policy model.
            inputs: A collated preference batch.
            return_outputs: If true, also return the metrics dict.
            num_items_in_batch: Accepted for compatibility with ``Trainer``
                versions that pass it; not used here.

        Returns:
            The scalar loss, or ``(loss, metrics)`` when ``return_outputs``.
        """
        # Reference torch.amp explicitly: the module-level name ``amp`` is
        # rebound to ``apex.amp`` when apex is installed.
        compute_loss_context_manager = (
            torch.amp.autocast("cuda")
            if self._peft_has_been_casted_to_bf16
            else nullcontext()
        )

        with compute_loss_context_manager:
            loss, metrics = self.get_batch_loss_metrics(
                model, inputs, train_eval="train"
            )

        # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
        loss = loss.to(self.args.device)
        # force log the metrics
        self.store_metrics(metrics, train_eval="train")

        if return_outputs:
            return (loss, metrics)
        return loss

    @staticmethod
    def concatenated_inputs(
        batch: Dict[str, Union[List, torch.LongTensor]],
        is_encoder_decoder: bool = False,
        is_vision_model: bool = False,
        label_pad_token_id: int = -100,
        padding_value: int = 0,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.LongTensor]:
        """Concatenate the chosen and rejected inputs into a single tensor.

        Chosen rows come first, rejected rows second, both padded to a common
        sequence length.

        Args:
            batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
            is_encoder_decoder: Whether the model is an encoder-decoder model.
            is_vision_model: Whether to also duplicate the prompt pixel values.
            label_pad_token_id: The label pad token id.
            padding_value: The padding value to use for the concatenated inputs_ids.
            device: The device for the concatenated inputs.

        Returns:
            A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
        """
        concatenated_batch = {}

        if is_encoder_decoder:
            max_length = max(
                batch["chosen_labels"].shape[1],
                batch["rejected_labels"].shape[1],
            )
        else:
            max_length = max(
                batch["chosen_input_ids"].shape[1],
                batch["rejected_input_ids"].shape[1],
            )

        # NOTE(review): a chosen/rejected tensor key matching none of the
        # branches below reuses the previous key's pad_value (or leaves it
        # unbound on the first such key) — confirm such keys never occur.
        for k in batch:
            if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
                if "labels" in k or is_encoder_decoder:
                    pad_value = label_pad_token_id
                elif k.endswith("_input_ids"):
                    pad_value = padding_value
                elif k.endswith("_attention_mask"):
                    pad_value = 0
                concatenated_key = k.replace("chosen", "concatenated")
                concatenated_batch[concatenated_key] = pad_to_length(
                    batch[k], max_length, pad_value=pad_value
                )
        for k in batch:
            if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
                if "labels" in k or is_encoder_decoder:
                    pad_value = label_pad_token_id
                elif k.endswith("_input_ids"):
                    pad_value = padding_value
                elif k.endswith("_attention_mask"):
                    pad_value = 0
                concatenated_key = k.replace("rejected", "concatenated")
                concatenated_batch[concatenated_key] = torch.cat(
                    (
                        concatenated_batch[concatenated_key],
                        pad_to_length(
                            batch[k], max_length, pad_value=pad_value
                        ),
                    ),
                    dim=0,
                ).to(device=device)

        if is_encoder_decoder:
            concatenated_batch["concatenated_input_ids"] = (
                batch["prompt_input_ids"].repeat(2, 1).to(device=device)
            )
            concatenated_batch["concatenated_attention_mask"] = (
                batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
            )

        if is_vision_model:
            concatenated_batch["pixel_values"] = torch.cat(
                [batch["prompt_pixel_values"], batch["prompt_pixel_values"]],
                dim=0,
            )
            if "prompt_pixel_attention_mask" in batch:
                concatenated_batch["pixel_attention_mask"] = torch.cat(
                    [
                        batch["prompt_pixel_attention_mask"],
                        batch["prompt_pixel_attention_mask"],
                    ],
                    dim=0,
                )
        return concatenated_batch

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.Tensor, ...]:
        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.

        Returns:
            ``(chosen_logps, rejected_logps, chosen_logits, rejected_logits,
            nll_loss, p_oracles)``, with ``outputs.aux_loss`` appended as a
            7th item when ``self.aux_loss_enabled``. ``p_oracles`` is None when
            the batch has no ``p_oracle`` entry.
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            is_vision_model=self.is_vision_model,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )
        len_chosen = batch["chosen_labels"].shape[0]

        model_kwargs = {}

        if self.is_encoder_decoder:
            model_kwargs["labels"] = concatenated_batch["concatenated_labels"]
            model_kwargs["decoder_input_ids"] = concatenated_batch.pop(
                "concatenated_decoder_input_ids", None
            )

        if self.is_vision_model:
            model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
            if "pixel_attention_mask" in concatenated_batch:
                model_kwargs["pixel_attention_mask"] = concatenated_batch[
                    "pixel_attention_mask"
                ]

        if self.aux_loss_enabled:
            model_kwargs["output_router_logits"] = True

        outputs = model(
            concatenated_batch["concatenated_input_ids"],
            attention_mask=concatenated_batch["concatenated_attention_mask"],
            use_cache=False,
            **model_kwargs,
        )
        all_logits = outputs.logits

        if (
            all_logits.shape[:2]
            != concatenated_batch["concatenated_labels"].shape[:2]
        ):
            # for llava, the model returns logits for the entire sequence, including the image tokens (placed before the text tokens)
            seq_len = concatenated_batch["concatenated_labels"].shape[1]
            all_logits = all_logits[:, -seq_len:]

        all_logps, size_completion = self.get_batch_logps(
            all_logits,
            concatenated_batch["concatenated_labels"],
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        def cross_entropy_loss(logits, labels):
            if not self.is_encoder_decoder:
                # Shift so that tokens < n predict n
                logits = logits[..., :-1, :].contiguous()
                labels = labels[..., 1:].contiguous()
            # Ignore padded label positions whatever pad id is configured.
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id)
            # Flatten the tokens
            logits = logits.reshape(-1, logits.shape[-1])
            labels = labels.reshape(-1)
            # Enable model parallelism
            labels = labels.to(logits.device)
            return loss_fct(logits, labels)

        labels = concatenated_batch["concatenated_labels"].clone()
        nll_loss = cross_entropy_loss(
            all_logits[:len_chosen], labels[:len_chosen]
        )

        if self.loss_type == "ipo":
            all_logps = all_logps / size_completion

        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]

        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

        if "p_oracle" in batch:
            # Depending on the collator, p_oracle may arrive as a tensor or as
            # a Python list; torch.as_tensor accepts both.
            p_oracles = torch.as_tensor(
                batch["p_oracle"], device=chosen_logps.device
            )
        else:
            p_oracles = None

        if self.aux_loss_enabled:
            return (
                chosen_logps,
                rejected_logps,
                chosen_logits,
                rejected_logits,
                nll_loss,
                p_oracles,
                outputs.aux_loss,
            )

        return (
            chosen_logps,
            rejected_logps,
            chosen_logits,
            rejected_logits,
            nll_loss,
            p_oracles,
        )

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the DPO loss and other metrics for the given batch of inputs for train or test.

        Args:
            model: The policy model.
            batch: A collated preference batch; may contain precomputed
                ``reference_chosen_logps`` / ``reference_rejected_logps`` and a
                ``p_oracle`` entry.
            train_eval: Selects the metric-name prefix (``"eval_"`` for eval).

        Returns:
            ``(loss, metrics)``: the batch-mean loss (plus the weighted router
            auxiliary loss when enabled) and a dict of detached CPU metrics.
        """
        metrics = {}

        forward_output = self.concatenated_forward(model, batch)
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_nll_loss,
            p_oracles,
        ) = forward_output[:6]
        if self.aux_loss_enabled:
            aux_loss = forward_output[6]

        # Use reference log-probs from the batch when they were precomputed (or
        # supplied for RPO); otherwise run the reference model.
        if (
            "reference_chosen_logps" in batch
            and "reference_rejected_logps" in batch
            and (
                self.precompute_ref_log_probs
                or self.args.rpo_alpha is not None
            )
        ):
            reference_chosen_logps = batch["reference_chosen_logps"]
            reference_rejected_logps = batch["reference_rejected_logps"]
        else:
            with torch.no_grad():
                if self.ref_model is None:
                    # No separate reference model: evaluate self.model under
                    # DPOTrainer's null_ref_context().
                    with self.null_ref_context():
                        ref_output = self.concatenated_forward(
                            self.model, batch
                        )
                else:
                    ref_output = self.concatenated_forward(
                        self.ref_model, batch
                    )
            # concatenated_forward returns 6 items, or 7 when the aux loss is
            # enabled; only the reference log-probs are needed here.
            reference_chosen_logps, reference_rejected_logps = ref_output[:2]

        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
            p_oracles,
        )
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        if self.args.rpo_alpha is not None:
            # RPO loss from V3 of the paper:
            losses = losses + policy_nll_loss * self.args.rpo_alpha

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
        metrics[f"{prefix}rewards/margins"] = (
            (chosen_rewards - rejected_rewards).mean().cpu()
        )
        metrics[f"{prefix}logps/rejected"] = (
            policy_rejected_logps.detach().mean().cpu()
        )
        metrics[f"{prefix}logps/chosen"] = (
            policy_chosen_logps.detach().mean().cpu()
        )
        metrics[f"{prefix}logits/rejected"] = (
            policy_rejected_logits.detach().mean().cpu()
        )
        metrics[f"{prefix}logits/chosen"] = (
            policy_chosen_logits.detach().mean().cpu()
        )

        # The remaining tensors only feed metrics: detach up front so no
        # autograd graph is built for them.
        policy_logprobs = torch.cat(
            (policy_chosen_logps, policy_rejected_logps), dim=0
        ).detach()
        ref_logprobs = torch.cat(
            (reference_chosen_logps, reference_rejected_logps), dim=0
        ).to(policy_logprobs.device)
        kl = policy_logprobs - ref_logprobs
        metrics[f"{prefix}obj/kl"] = kl.mean().detach().cpu()

        non_score_reward = -self.beta * kl
        metrics[f"{prefix}obj/non_score_reward"] = (
            non_score_reward.mean().detach().cpu()
        )

        scores = torch.cat((chosen_rewards, rejected_rewards), dim=0)
        metrics[f"{prefix}obj/scores"] = scores.mean().detach().cpu()

        rlhf_reward = scores + non_score_reward
        metrics[f"{prefix}obj/rlhf_reward"] = rlhf_reward.mean().detach().cpu()

        # Negative mean sequence log-prob, logged under the "entropy" name.
        mean_entropy = -policy_logprobs.mean()
        metrics[f"{prefix}obj/mean_entropy"] = mean_entropy.detach().cpu()

        if self.args.rpo_alpha is not None:
            metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()

        if self.aux_loss_enabled:
            return (
                losses.mean()
                + getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss,
                metrics,
            )

        return losses.mean(), metrics

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        p_oracle: Optional[torch.FloatTensor] = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the DPO loss for a batch of policy and reference model log probabilities.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
            p_oracle: Soft preference label for the chosen response; required
                by (and only used by) ``loss_type="soft_sigmoid"``.

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
            The losses tensor contains the DPO loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.

        Raises:
            ValueError: For an unknown ``loss_type``, or ``soft_sigmoid``
                without ``p_oracle``.
        """
        chosen_logratios = policy_chosen_logps.to(self.accelerator.device) - (
            not self.reference_free
        ) * reference_chosen_logps.to(self.accelerator.device)
        rejected_logratios = policy_rejected_logps.to(
            self.accelerator.device
        ) - (not self.reference_free) * reference_rejected_logps.to(
            self.accelerator.device
        )

        if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value:
            # The alpha-divergence formula: (1 - u^-alpha) / alpha
            # The divergence difference between the chosen and rejected sample is:
            #     (1 - u[w]^-alpha) / alpha - (1 - u[l]^-alpha) / alpha
            #        = (u[l]^-alpha - u[w]^-alpha) / alpha
            # where u[w] and u[l] are the policy/reference probability ratios
            # for the chosen and rejected samples, respectively.
            alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT
            if (
                self.f_divergence_params
                and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY
                in self.f_divergence_params
            ):
                alpha_coef = float(
                    self.f_divergence_params[
                        FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY
                    ]
                )
            logits = (
                cap_exp(rejected_logratios * -alpha_coef)
                - cap_exp(chosen_logratios * -alpha_coef)
            ) / alpha_coef
        else:
            pi_logratios = policy_chosen_logps - policy_rejected_logps
            if self.reference_free:
                ref_logratios = torch.tensor(
                    [0], dtype=pi_logratios.dtype, device=pi_logratios.device
                )
            else:
                ref_logratios = (
                    reference_chosen_logps - reference_rejected_logps
                )

            pi_logratios = pi_logratios.to(self.accelerator.device)
            ref_logratios = ref_logratios.to(self.accelerator.device)
            logits = pi_logratios - ref_logratios

            if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value:
                # The js-divergence formula: log(2 * u / (1 + u))
                # The divergence difference between the chosen and rejected sample is:
                #     log(2 * u[w] / (1 + u[w])) - log(2 * u[l] / (1 + u[l]))
                #       = log(u[w]) - log(u[l]) - (log(1 + u[w]) - log(1 + u[l]))
                # where u[w] and u[l] are the policy/reference probability ratios
                # for the chosen and rejected samples, respectively.
                logits -= F.softplus(chosen_logratios) - F.softplus(
                    rejected_logratios
                )

        # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
        # calculates a conservative DPO loss.
        if self.loss_type == "sigmoid":
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        elif self.loss_type == "soft_sigmoid":
            if p_oracle is None:
                raise ValueError(
                    "p_oracle must be provided for soft_sigmoid loss type"
                )
            # Binary cross-entropy between the oracle label p and the model's
            # implied preference probability sigmoid(beta * logits).
            p_oracle = p_oracle.to(device=logits.device)
            losses = -F.logsigmoid(
                self.beta * logits
            ) * p_oracle - F.logsigmoid(-self.beta * logits) * (1 - p_oracle)
        elif self.loss_type == "robust":
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                + F.logsigmoid(-self.beta * logits) * self.label_smoothing
            ) / (1 - 2 * self.label_smoothing)
        elif self.loss_type == "exo_pair":
            # eqn (16) of the EXO paper: https://huggingface.co/papers/2402.00856
            import math

            if self.label_smoothing == 0:
                self.label_smoothing = 1e-3
            losses = (self.beta * logits).sigmoid() * (
                F.logsigmoid(self.beta * logits)
                - math.log(1 - self.label_smoothing)
            ) + (-self.beta * logits).sigmoid() * (
                F.logsigmoid(-self.beta * logits)
                - math.log(self.label_smoothing)
            )
        elif self.loss_type == "hinge":
            losses = torch.relu(1 - self.beta * logits)
        elif self.loss_type == "ipo":
            # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
            losses = (logits - 1 / (2 * self.beta)) ** 2
        elif self.loss_type == "bco_pair":
            chosen_logratios = policy_chosen_logps - reference_chosen_logps
            rejected_logratios = (
                policy_rejected_logps - reference_rejected_logps
            )

            chosen_rewards = self.beta * chosen_logratios
            rejected_rewards = self.beta * rejected_logratios
            rewards = (
                torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
            )
            self.running.update(rewards)
            delta = self.running.mean

            losses = -F.logsigmoid(
                (self.beta * chosen_logratios) - delta
            ) - F.logsigmoid(-(self.beta * rejected_logratios - delta))
        elif self.loss_type == "sppo_hard":
            # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach, estimated using the PairRM score. The probability calculation is conducted outside of the trainer class. The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is set to 1 for the winner and 0 for the loser.
            a = policy_chosen_logps - reference_chosen_logps
            b = policy_rejected_logps - reference_rejected_logps

            losses = (a - 0.5 / self.beta) ** 2 + (b + 0.5 / self.beta) ** 2
        elif self.loss_type == "nca_pair":
            chosen_rewards = (
                policy_chosen_logps - reference_chosen_logps
            ) * self.beta
            rejected_rewards = (
                policy_rejected_logps - reference_rejected_logps
            ) * self.beta
            losses = (
                -F.logsigmoid(chosen_rewards)
                - 0.5 * F.logsigmoid(-chosen_rewards)
                - 0.5 * F.logsigmoid(-rejected_rewards)
            )
        elif self.loss_type == "aot_pair":
            chosen_logratios = policy_chosen_logps - reference_chosen_logps
            rejected_logratios = (
                policy_rejected_logps - reference_rejected_logps
            )

            chosen_logratios_sorted, _ = torch.sort(chosen_logratios, dim=0)
            rejected_logratios_sorted, _ = torch.sort(rejected_logratios, dim=0)

            delta = chosen_logratios_sorted - rejected_logratios_sorted

            losses = (
                -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * delta) * self.label_smoothing
            )

        elif self.loss_type == "aot":
            pi_logratios = policy_chosen_logps - policy_rejected_logps
            ref_logratios = reference_chosen_logps - reference_rejected_logps

            pi_logratios_sorted, _ = torch.sort(pi_logratios, dim=0)
            ref_logratios_sorted, _ = torch.sort(ref_logratios, dim=0)

            delta = pi_logratios_sorted - ref_logratios_sorted

            losses = (
                -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * delta) * self.label_smoothing
            )

        elif self.loss_type == "apo_zero":
            # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
            # Use this loss when you believe the chosen outputs are better than your model's default output

            losses_chosen = 1 - F.sigmoid(
                self.beta * chosen_logratios
            )  # Increase chosen likelihood
            losses_rejected = F.sigmoid(
                self.beta * rejected_logratios
            )  # Decrease rejected likelihood

            losses = losses_chosen + losses_rejected

        elif self.loss_type == "apo_down":
            # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
            # Use this loss when you believe the chosen outputs are worse than your model's default output

            losses_chosen = F.sigmoid(
                self.beta * chosen_logratios
            )  # Decrease chosen likelihood
            losses_rejected = 1 - F.sigmoid(
                self.beta * (chosen_logratios - rejected_logratios)
            )  # Decrease rejected likelihood more

            losses = losses_chosen + losses_rejected

        else:
            raise ValueError(
                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'soft_sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down']"
            )

        chosen_rewards = (
            self.beta
            * (
                policy_chosen_logps.to(self.accelerator.device)
                - reference_chosen_logps.to(self.accelerator.device)
            ).detach()
        )
        rejected_rewards = (
            self.beta
            * (
                policy_rejected_logps.to(self.accelerator.device)
                - reference_rejected_logps.to(self.accelerator.device)
            ).detach()
        )

        return losses, chosen_rewards, rejected_rewards

    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ):
        """Evaluate one batch: compute the loss and record eval metrics.

        Returns:
            ``(loss, None, None)`` when ``prediction_loss_only``; otherwise
            ``(loss, logits, labels)`` where ``logits`` holds the mean chosen /
            rejected logits (minus keys in ``ignore_keys``) and ``labels`` is a
            zero tensor of matching length.
        """
        if ignore_keys is None:
            if hasattr(model, "config"):
                ignore_keys = getattr(
                    model.config, "keys_to_ignore_at_inference", []
                )
            else:
                ignore_keys = []

        # Reference torch.amp explicitly: the module-level name ``amp`` is
        # rebound to ``apex.amp`` when apex is installed.
        prediction_context_manager = (
            torch.amp.autocast("cuda")
            if self._peft_has_been_casted_to_bf16
            else nullcontext()
        )

        with torch.no_grad(), prediction_context_manager:
            loss, metrics = self.get_batch_loss_metrics(
                model, inputs, train_eval="eval"
            )

        # force log the metrics
        self.store_metrics(metrics, train_eval="eval")

        if prediction_loss_only:
            return (loss.detach(), None, None)

        # logits for the chosen and rejected samples from model
        logits_dict = {
            "eval_logits/chosen": metrics["eval_logits/chosen"],
            "eval_logits/rejected": metrics["eval_logits/rejected"],
        }
        logits = tuple(
            v.unsqueeze(dim=0)
            for k, v in logits_dict.items()
            if k not in ignore_keys
        )
        logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

        return (loss.detach(), logits, labels)


class OnlineSoftDPOTrainer(OnlineDPOTrainer):
    """Online DPO trainer with a soft-label loss and optional step profiling.

    Every training step samples two completions per prompt from the policy,
    scores both with ``self.reward_model`` (re-tokenized through
    ``reward_processing_class``), orders each pair into chosen / rejected by
    reward, and applies the loss selected by ``self.args.loss_type``:

    * ``"sigmoid"``      -- ``-logsigmoid(beta * logits)``;
    * ``"ipo"``          -- ``(logits - 1 / (2 * beta)) ** 2``;
    * ``"soft_sigmoid"`` -- binary cross-entropy of ``sigmoid(beta * logits)``
      against the soft target
      ``p = compute_soft_label(chosen_scores, rejected_scores)``;

    where ``logits`` is the policy chosen-vs-rejected log-ratio minus the
    reference model's.

    Args:
        *args: Forwarded to ``OnlineDPOTrainer.__init__``.
        reward_processing_class: Tokenizer used to build reward-model inputs.
        compute_soft_label: Callable ``(chosen_scores, rejected_scores) ->
            target probabilities``; only used by ``"soft_sigmoid"``.
        using_profiler: If True, each step runs under
            ``torch.profiler.profile`` and prints a per-operator table.
        **kwargs: Forwarded to ``OnlineDPOTrainer.__init__``.
    """

    def __init__(
        self,
        *args,
        reward_processing_class,
        compute_soft_label=compute_p_oracle,
        using_profiler=False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.compute_soft_label = compute_soft_label
        self.reward_processing_class = reward_processing_class
        self.using_profiler = using_profiler

    def training_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch=None,
    ) -> torch.Tensor:
        """Run one online-DPO step, under the profiler if requested.

        ``num_items_in_batch`` is accepted and ignored so this override stays
        call-compatible with ``Trainer`` versions that pass it; the loss here
        is always a mean over prompt pairs.
        """
        if self.using_profiler:
            return self._training_step_with_profiler(model, inputs)
        return self._training_step(model, inputs)

    @staticmethod
    def _no_record(name: str):
        """No-op stand-in for ``torch.profiler.record_function``."""
        return nullcontext()

    @staticmethod
    def _completion_logprobs(
        model,
        prompt_completion_ids: torch.Tensor,
        prompt_completion_mask: torch.Tensor,
        completion_ids: torch.Tensor,
        context_length: int,
    ) -> torch.Tensor:
        """Return per-token log-probs of ``completion_ids`` under ``model``.

        The result has one entry per completion token (same shape as
        ``completion_ids``); padding positions are NOT masked here.
        """
        output = model(
            prompt_completion_ids, attention_mask=prompt_completion_mask
        )
        # There is 1 offset, because the model predicts the next token.
        logits = output.logits[:, context_length - 1 : -1]
        all_logprobs = F.log_softmax(logits, dim=-1)
        # Take the completion tokens' logprob.
        return torch.take_along_dim(
            all_logprobs, completion_ids.unsqueeze(-1), dim=2
        ).squeeze(-1)

    def _tokenize_for_reward(self, prompts, completion_ids, device):
        """Decode policy completions and re-tokenize them for the reward model.

        Args:
            prompts: The batch's raw prompts (plain text or chat messages),
                one per pair; duplicated here to line up with the two
                completions sampled per prompt.
            completion_ids: Completion token ids from the policy tokenizer,
                ordered as the ``repeat(2, 1)`` batch (both halves share
                prompts).
            device: Device the returned tensors are moved to.

        Returns:
            ``(prompts_ids, completions_ids)`` from
            ``reward_processing_class``: prompts left-padded, completions
            right-padded.
        """
        completions = self.tokenizer.batch_decode(
            completion_ids, skip_special_tokens=True
        )
        prompts = 2 * prompts
        if is_conversational({"prompt": prompts[0]}):
            examples = [
                apply_chat_template(
                    {
                        "prompt": prompt,
                        "completion": [
                            {"role": "assistant", "content": completion}
                        ],
                    },
                    self.reward_processing_class,
                )
                for prompt, completion in zip(prompts, completions)
            ]
            prompts = [example["prompt"] for example in examples]
            completions = [example["completion"] for example in examples]

        prompts_ids = self.reward_processing_class(
            prompts, padding=True, return_tensors="pt", padding_side="left"
        )["input_ids"].to(device)
        completions_ids = self.reward_processing_class(
            completions, padding=True, return_tensors="pt", padding_side="right"
        )["input_ids"].to(device)
        return prompts_ids, completions_ids

    def _pairwise_losses(
        self,
        logits: torch.Tensor,
        chosen_scores: torch.Tensor,
        rejected_scores: torch.Tensor,
    ) -> torch.Tensor:
        """Per-pair loss for ``self.args.loss_type``.

        Raises:
            NotImplementedError: If the configured loss type is unknown.
        """
        loss_type = self.args.loss_type
        if loss_type == "sigmoid":
            return -F.logsigmoid(self.beta * logits)
        if loss_type == "ipo":
            return (logits - 1 / (2 * self.beta)) ** 2
        if loss_type == "soft_sigmoid":
            p_oracle = self.compute_soft_label(chosen_scores, rejected_scores)
            return -F.logsigmoid(self.beta * logits) * p_oracle - F.logsigmoid(
                -self.beta * logits
            ) * (1 - p_oracle)
        raise NotImplementedError(f"invalid loss type {loss_type}")

    def _training_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        record_function: Optional[Callable[[str], Any]] = None,
    ) -> torch.Tensor:
        """Sample, score, compute the loss, log stats and backpropagate.

        Args:
            model: The policy being trained.
            inputs: Collated batch; read keys are ``"prompt"``,
                ``"prompt_input_ids"`` and ``"prompt_attention_mask"``.
            record_function: Optional factory ``name -> context manager``
                used to label the phases of the step (e.g.
                ``torch.profiler.record_function``). Defaults to a no-op.

        Returns:
            The detached loss divided by ``gradient_accumulation_steps``.
        """
        section = (
            record_function if record_function is not None else self._no_record
        )
        model.train()

        # Sample 2 completions per prompt: rows i and i + num_examples of the
        # repeated batch share the same prompt.
        inputs = self._prepare_inputs(inputs)
        prompts = inputs["prompt"]
        num_examples, context_length = inputs["prompt_input_ids"].shape
        prompt_ids = inputs["prompt_input_ids"].repeat(2, 1)
        prompt_mask = inputs["prompt_attention_mask"].repeat(2, 1)

        with section("generate"):
            with unwrap_model_for_generation(
                model, self.accelerator
            ) as unwrapped_model:
                output = unwrapped_model.generate(
                    input_ids=prompt_ids,
                    attention_mask=prompt_mask,
                    generation_config=self.generation_config,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
            del inputs

        with section("truncate_right"):
            completion_ids, completion_mask = truncate_right(
                output[:, context_length:],
                self.tokenizer.eos_token_id,
                self.tokenizer.pad_token_id,
            )
            del output
            contain_eos_token = torch.any(
                completion_ids == self.tokenizer.eos_token_id, dim=-1
            )
            prompt_completion_ids = torch.cat(
                (prompt_ids, completion_ids), dim=1
            )
            prompt_completion_mask = torch.cat(
                (prompt_mask, completion_mask), dim=1
            )

        with section("get_logprobs"):
            logprobs = self._completion_logprobs(
                model,
                prompt_completion_ids,
                prompt_completion_mask,
                completion_ids,
                context_length,
            )

        with section("get_ref_logprobs"):
            with torch.no_grad():
                if self.ref_model is not None:
                    ref_logprobs = self._completion_logprobs(
                        self.ref_model,
                        prompt_completion_ids,
                        prompt_completion_mask,
                        completion_ids,
                        context_length,
                    )
                else:  # peft case: we just need to disable the adapter
                    with self.model.disable_adapter():
                        ref_logprobs = self._completion_logprobs(
                            self.model,
                            prompt_completion_ids,
                            prompt_completion_mask,
                            completion_ids,
                            context_length,
                        )

        with section("get_reward_completions"):
            # The reward model may use a different tokenizer, so completions
            # are decoded and re-tokenized with reward_processing_class.
            prompts_ids, completions_ids = self._tokenize_for_reward(
                prompts, completion_ids, prompt_completion_ids.device
            )
            reward_context_length = prompts_ids.shape[1]

        with section("get_reward_scores"):
            reward_input_ids = torch.cat((prompts_ids, completions_ids), dim=1)
            with torch.inference_mode():
                _, scores, _ = get_reward(
                    self.reward_model,
                    reward_input_ids,
                    self.reward_processing_class.pad_token_id,
                    reward_context_length,
                )
                # Completions without an EOS token receive a lower score.
                if self.args.missing_eos_penalty is not None:
                    scores[~contain_eos_token] -= self.args.missing_eos_penalty

            # Both halves share prompts; the higher-scored completion of each
            # pair is "chosen" (ties go to the first half).
            first_half, second_half = scores.split(num_examples)
            num_examples_range = torch.arange(
                num_examples, device=scores.device
            )
            first_wins = first_half >= second_half
            chosen_indices = num_examples_range + (~first_wins * num_examples)
            rejected_indices = num_examples_range + (first_wins * num_examples)

        with section("get_reward_logprobs_sum"):
            # First half = chosen rows, second half = rejected rows.
            cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0)
            # Positions where completion_mask is 0 are treated as padding.
            cr_keep = completion_mask.bool()[cr_indices]
            cr_logprobs_sum = (logprobs[cr_indices] * cr_keep).sum(1)
            cr_ref_logprobs_sum = (ref_logprobs[cr_indices] * cr_keep).sum(1)

            chosen_logprobs_sum, rejected_logprobs_sum = torch.split(
                cr_logprobs_sum, num_examples
            )
            chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(
                cr_ref_logprobs_sum, num_examples
            )
            pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum
            ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum
            logits = pi_logratios - ref_logratios

        with section("get_loss"):
            chosen_scores = scores[chosen_indices]
            rejected_scores = scores[rejected_indices]
            loss = self._pairwise_losses(
                logits, chosen_scores, rejected_scores
            ).mean()

        with section("get_stats"):
            # Log everything (gather calls are collective: keep order fixed).
            self.stats["val/contain_eos_token"].append(
                contain_eos_token.float().mean().item()
            )
            self.stats["logps/chosen"].append(
                self.accelerator.gather(chosen_logprobs_sum).mean().item()
            )
            self.stats["logps/rejected"].append(
                self.accelerator.gather(rejected_logprobs_sum).mean().item()
            )
            self.stats["objective/scores"].append(
                self.accelerator.gather(scores.mean()).mean().item()
            )
            # Note: kl / entropy below are summed over all completion
            # positions, padding included.
            kl = logprobs - ref_logprobs
            mean_kl = kl.sum(1).mean()
            self.stats["objective/kl"].append(
                self.accelerator.gather(mean_kl).mean().item()
            )
            non_score_reward = (-self.beta * kl).sum(1)
            mean_non_score_reward = non_score_reward.mean()
            self.stats["objective/non_score_reward"].append(
                self.accelerator.gather(mean_non_score_reward).mean().item()
            )
            rlhf_reward = scores + non_score_reward
            self.stats["objective/rlhf_reward"].append(
                self.accelerator.gather(rlhf_reward).mean().item()
            )
            mean_entropy = -logprobs.sum(1).mean()
            self.stats["objective/entropy"].append(
                self.accelerator.gather(mean_entropy).mean().item()
            )
            scores_margin = chosen_scores - rejected_scores
            self.stats["objective/scores_margin"].append(
                self.accelerator.gather(scores_margin.mean()).mean().item()
            )
            chosen_rewards = self.beta * (
                chosen_logprobs_sum - chosen_ref_logprobs_sum
            )
            gathered_chosen_rewards = self.accelerator.gather(chosen_rewards)
            self.stats["rewards/chosen"].append(
                gathered_chosen_rewards.mean().item()
            )
            rejected_rewards = self.beta * (
                rejected_logprobs_sum - rejected_ref_logprobs_sum
            )
            gathered_rejected_rewards = self.accelerator.gather(
                rejected_rewards
            )
            self.stats["rewards/rejected"].append(
                gathered_rejected_rewards.mean().item()
            )
            margin = gathered_chosen_rewards - gathered_rejected_rewards
            self.stats["rewards/margins"].append(margin.mean().item())
            accuracy = margin > 0
            self.stats["rewards/accuracies"].append(
                accuracy.float().mean().item()
            )
            self.stats["beta"].append(self.beta)

            if (
                self.args.torch_empty_cache_steps is not None
                and self.state.global_step % self.args.torch_empty_cache_steps
                == 0
            ):
                empty_cache()

            kwargs = {}
            # For LOMO optimizers you need to explicitly use the learning rate
            if self.args.optim in [
                OptimizerNames.LOMO,
                OptimizerNames.ADALOMO,
            ]:
                kwargs["learning_rate"] = self._get_learning_rate()

            if self.args.n_gpu > 1:
                # mean() to average on multi-gpu parallel training
                loss = loss.mean()

            if self.use_apex:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                self.accelerator.backward(loss, **kwargs)

        return loss.detach() / self.args.gradient_accumulation_steps

    def _training_step_with_profiler(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> torch.Tensor:
        """Run ``_training_step`` under the PyTorch profiler and print a table.

        Each phase of the step is labelled with
        ``torch.profiler.record_function``; the report is sorted by
        ``self_cuda_time_total``.
        """
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            record_shapes=True,
        ) as prof:
            result = self._training_step(
                model, inputs, record_function=torch.profiler.record_function
            )
        # Results are only available once the profiler has stopped, i.e.
        # after the `with` block exits; querying them inside it fails or
        # yields incomplete data.
        print(
            prof.key_averages().table(
                sort_by="self_cuda_time_total", row_limit=-1
            )
        )
        return result
