from typing import Any, Callable, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import Dataset, IterableDataset
from transformers import (
    BaseImageProcessor,
    FeatureExtractionMixin,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    TrainerCallback,
    is_wandb_available,
)
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import OptimizerNames
from transformers.generation import GenerationConfig
from trl.data_utils import is_conversational, maybe_apply_chat_template
from trl.models.utils import unwrap_model_for_generation
from trl.trainer.judges import BasePairwiseJudge
from trl.trainer.online_dpo_trainer import OnlineDPOTrainer
from trl.trainer.utils import (
    empty_cache,
    truncate_right,
)

from .DualModel import DualModel, switch_model, is_ddp_wrapped
from .stackelberg_gda_config import StackelbergPGConfig

# TODO: Add apex support
# if is_apex_available():
#     from apex import amp

# Placeholder: no wandb-specific setup happens here yet, so the branch below is
# currently a no-op.
if is_wandb_available():
    pass


class StackelbergPGTrainer(OnlineDPOTrainer):
    r"""
    Initialize StackelbergPGTrainer as a subclass of [`OnlineDPOTrainer`].

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForCausalLM`.
        ref_model (`PreTrainedModelWrapper`):
            Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
            reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
        reward_model (`transformers.PreTrainedModel`):
            The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
        judge (`BasePairwiseJudge`):
            The judge to use for pairwise comparison of model completions.
        args (`StackelbergPGConfig`):
            The StackelbergPG config arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        peft_config (`dict`):
            The peft config to use for training.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return
            a dictionary string to metric values.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
    """

    _tag_names = ["trl", "stackelberg-pg"]

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        ref_model: Union[PreTrainedModel, nn.Module] = None,
        reward_model: Union[PreTrainedModel, nn.Module, None] = None,
        judge: Optional[BasePairwiseJudge] = None,
        args: Optional[StackelbergPGConfig] = None,
        data_collator: Optional[Callable] = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[
                PreTrainedTokenizerBase,
                BaseImageProcessor,
                FeatureExtractionMixin,
                ProcessorMixin,
            ]
        ] = None,
        peft_config: Optional[dict] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[
            Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        ] = None,
    ) -> None:
        """
        Set up the parent trainer, then the Stackelberg-specific state.

        All arguments are forwarded unchanged to the parent constructor; see the
        class docstring for their meaning. `processing_class` is also passed as
        the parent's `reward_processing_class`.
        """
        parent_kwargs = {
            "model": model,
            "ref_model": ref_model,
            "reward_model": reward_model,
            "judge": judge,
            "args": args,
            "data_collator": data_collator,
            "train_dataset": train_dataset,
            "eval_dataset": eval_dataset,
            "processing_class": processing_class,
            # The reward-model path is not implemented in this trainer, so the
            # policy's processing class is simply reused here.
            "reward_processing_class": processing_class,
            "peft_config": peft_config,
            "compute_metrics": compute_metrics,
            "callbacks": callbacks,
            "optimizers": optimizers,
            "preprocess_logits_for_metrics": preprocess_logits_for_metrics,
        }
        super().__init__(**parent_kwargs)

        # Follower settings are cached from the config.
        config = self.args
        self._follower_weight = config.follower_weight
        self._follower_prompt = config.follower_prompt

        # Sampling settings from the config override the parent's generation config.
        self.generation_config.top_k = config.top_k
        self.generation_config.top_p = config.top_p

        # Replace the parent's statistics with the ones tracked by this trainer.
        stat_names = (
            "rewards/probabilities",
            "rewards/margins",
            "loss/loss_leader",
            "loss/loss_follower",
            "loss/score_leader",
            "loss/score_follower",
            "loss/kl_leader",
            "loss/kl_follower",
            "logps/leader",
            "logps/follower",
            "objective/entropy_leader",
            "objective/entropy_follower",
            "val/leader_contain_eos_token",
            "val/follower_contain_eos_token",
            "beta",
            "follower_weight",
        )
        if self.reward_model is not None:
            stat_names += ("rewards/chosen", "rewards/rejected")
        self.stats = {name: [] for name in stat_names}

    def _generate_completions(self, model, prompts):
        """
        Sample a leader completion for every prompt, then a follower completion
        conditioned on the prompt, the leader completion and the follower prompt.

        :param model: The (possibly wrapped) policy; it is unwrapped with
            `unwrap_model_for_generation` before sampling.
        :param prompts: Dict of the structure:
            prompts = {
                "input_ids": Tensor, Shape: (batch_size, input_length),
                "attention_mask": Tensor, Shape: (batch_size, input_length),
                "raw": str or List[Dict[str, str]],  # Raw prompt text (maybe conversational structure)
            }
        :return: Tuple `(leader_data, follower_data)`, each a dict as returned by
            `_process_completions`. The follower's "raw" entry is the list of
            `{"prompt": ...}` dicts that were tokenized for the follower.
        """
        generation_config = GenerationConfig(
            max_new_tokens=self.args.max_new_tokens,
            # Small offset keeps the temperature strictly positive.
            temperature=(self.args.generation_temperature + 1e-7),
            top_k=self.args.top_k,
            top_p=self.args.top_p,
            do_sample=True,
        )
        separate_follower = self.args.separate_follower_model

        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
            if separate_follower:
                # `not`, not `~`: bitwise inversion of a bool is always truthy.
                assert (
                    not unwrapped_model.use_follower
                ), "leader generation requires the leader model to be active"

            # Generate Leader Data
            leader_output = unwrapped_model.generate(
                input_ids=prompts["input_ids"],
                attention_mask=prompts["attention_mask"],
                generation_config=generation_config,
            )  # Tensor, Shape: (batch_size, input_length+output_length)
            leader_data = self._process_completions(leader_output, prompts)
            del leader_output

            # Generate Prompt for the Follower. The prompt format is decided once
            # from the first raw prompt, as it does not change within the loop.
            raw_prompts = prompts["raw"]
            conversational = len(raw_prompts) > 0 and is_conversational(
                {"prompt": raw_prompts[0]}
            )
            follower_prompts = []
            for prompt, leader_completion_text in zip(
                raw_prompts, leader_data["completion_text"]
            ):
                if conversational:
                    follower_prompt = (
                        prompt
                        + [{"role": "assistant", "content": leader_completion_text}]
                        + [{"role": "user", "content": self._follower_prompt}]
                    )
                else:
                    follower_prompt = (
                        prompt + leader_completion_text + self._follower_prompt
                    )
                # Both formats are wrapped in a {"prompt": ...} example: the chat
                # template helper and `tokenize_row` below operate on such dicts,
                # not on bare strings.
                follower_prompts.append({"prompt": follower_prompt})

            follower_inputs = [
                maybe_apply_chat_template(x, self.processing_class)
                for x in follower_prompts
            ]
            follower_inputs = [
                self.tokenize_row(
                    x, self.model.config.is_encoder_decoder, self.processing_class
                )
                for x in follower_inputs
            ]
            follower_inputs = self.data_collator(follower_inputs)
            follower_inputs = {
                key: value.to(self.model.device)
                for key, value in follower_inputs.items()
            }
            follower_prompts = {
                "input_ids": follower_inputs["prompt_input_ids"],
                "attention_mask": follower_inputs["prompt_attention_mask"],
                "raw": follower_prompts,
            }

            if separate_follower:
                switch_model(unwrapped_model, "follower")
                assert unwrapped_model.use_follower
            try:
                follower_output = unwrapped_model.generate(
                    input_ids=follower_prompts["input_ids"],
                    attention_mask=follower_prompts["attention_mask"],
                    generation_config=generation_config,
                )  # Tensor, Shape: (batch_size, input_length+output_length)
            finally:
                # Always hand the leader back, even if generation fails, so the
                # model is never left in follower mode.
                if separate_follower:
                    switch_model(unwrapped_model, "leader")
            if separate_follower:
                assert not unwrapped_model.use_follower

            follower_data = self._process_completions(follower_output, follower_prompts)
            del follower_output, follower_inputs, follower_prompts
            # Same device-agnostic helper that `training_step` uses.
            empty_cache()
        return leader_data, follower_data

    def _process_completions(self, output, prompts):
        """
        Split generated sequences into prompt and completion and post-process them.

        The completion part of `output` (everything after the prompt length) is
        passed through `truncate_right` with the EOS and pad token ids, decoded to
        stripped text, and re-attached to the prompt ids and attention mask.

        Returns a dict with "input_ids", "attention_mask", "raw" (copied from
        `prompts`), "completion_text" and "context_length" (the prompt length).
        """
        tokenizer = self.processing_class
        prompt_ids = prompts["input_ids"]
        prompt_length = prompt_ids.shape[1]

        generated_ids, generated_mask = truncate_right(
            output[:, prompt_length:],
            tokenizer.eos_token_id,
            tokenizer.pad_token_id,
        )
        decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        result = {
            "input_ids": torch.cat((prompt_ids, generated_ids), dim=1),
            "attention_mask": torch.cat(
                (prompts["attention_mask"], generated_mask), dim=1
            ),
            "raw": prompts["raw"],
            "completion_text": [text.strip() for text in decoded],
            "context_length": prompt_length,
        }
        return result

    def _compute_rewards(self, leader_data, follower_data):
        raise NotImplementedError

    def _compute_judge(self, leader_data, follower_data):
        """
        Ask the pairwise judge to compare leader and follower completions.

        The judge receives the leader's raw prompts and one
        (leader completion, follower completion) pair per prompt. When the judge
        has a `missing_eos_penalty`, it is also told, per pair, whether each
        completion contains the EOS token. The returned scores are wrapped in a
        tensor on the device of the leader's input ids.
        """
        raw_prompts = leader_data["raw"]
        leader_texts = leader_data["completion_text"]
        follower_texts = follower_data["completion_text"]

        if is_conversational({"prompt": raw_prompts[0]}):
            # Conversational prompts get completions as single assistant turns.
            def as_assistant_turn(texts):
                return [[{"role": "assistant", "content": text}] for text in texts]

            leader_texts = as_assistant_turn(leader_texts)
            follower_texts = as_assistant_turn(follower_texts)

        eos_flags = None
        if self.judge.missing_eos_penalty is not None:

            def has_eos(data):
                generated = data["input_ids"][:, data["context_length"] :]
                hits = generated == self.processing_class.eos_token_id
                return torch.any(hits, dim=-1).tolist()

            eos_flags = list(zip(has_eos(leader_data), has_eos(follower_data)))

        completion_pairs = list(zip(leader_texts, follower_texts))
        scores = self.judge.judge(
            raw_prompts,
            completion_pairs,
            contain_eos_tokens=eos_flags,
            return_scores=True,
        )
        return torch.tensor(scores, device=leader_data["input_ids"].device)

    def _compute_logprobs(
        self, model, model_data, with_grad=True, disable_adapters=False
    ):
        context_length = model_data["context_length"]
        # Compute logprobs for model completions under the model
        if with_grad and not disable_adapters:
            output = model(
                model_data["input_ids"], attention_mask=model_data["attention_mask"]
            )
        elif with_grad and disable_adapters:
            with model.disable_adapter():
                output = model(
                    model_data["input_ids"], attention_mask=model_data["attention_mask"]
                )
        elif ~with_grad and disable_adapters:
            with model.disable_adapter():
                with torch.no_grad():
                    output = model(
                        model_data["input_ids"],
                        attention_mask=model_data["attention_mask"],
                    )
        else:
            with torch.no_grad():
                output = model(
                    model_data["input_ids"], attention_mask=model_data["attention_mask"]
                )
        logits = output.logits[:, context_length - 1 : -1]
        logits /= self.args.temperature + 1e-7
        logprobs = F.log_softmax(logits, dim=-1)
        model_logprobs_model_data = torch.gather(
            logprobs, 2, model_data["input_ids"][:, context_length:].unsqueeze(-1)
        ).squeeze(-1)
        # Mask padding tokens
        model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
        model_logprobs_model_data = model_logprobs_model_data.masked_fill(
            model_padding_mask, 0.0
        )
        return model_logprobs_model_data

    def _compute_losses(self, model_logprobs, ref_logprobs, probability, beta=None):
        # reinforce score where score_baseline (default: 0.5) is a control variate
        if self.args.score_baseline is not None:
            probability = probability - self.args.score_baseline
        if self.args.rloo_baseline:
            prob_rloo = (probability.sum() - probability) / (probability.shape[0] - 1)
            probability = probability - prob_rloo

        score = probability * model_logprobs.sum(1)

        # kl divergence via reinforce
        with torch.no_grad():
            logr = ref_logprobs - model_logprobs
            if self.args.kl_estimator == "k1":
                log_ratio = -logr
            elif self.args.kl_estimator == "k2":
                raise NotImplementedError
            elif self.args.kl_estimator == "k3":
                log_ratio = (torch.exp(logr) - 1) - logr
            else:
                raise ValueError("kl_estimator must be one of k1, k2, or k3")
            kl_div_log = log_ratio.sum(1)
        kl_div_loss = (log_ratio * model_logprobs).sum(1)

        # final loss
        if beta is None:
            beta = self.beta
        loss = beta * kl_div_loss - score

        return loss.mean(), score, kl_div_log

    def _log_statistics(
        self,
        leader_data,
        follower_data,
        model_logprobs_leader_data,
        model_logprobs_follower_data,
        probability,
        loss_leader,
        loss_follower,
        score_leader,
        score_follower,
        kl_div_leader,
        kl_div_follower,
    ):
        # Helper function to gather and compute mean
        def gather_mean(tensor):
            return self.accelerator.gather_for_metrics(tensor).mean().item()

        # Log score
        self.stats["loss/score_leader"].append(gather_mean(score_leader))
        self.stats["loss/score_follower"].append(gather_mean(score_follower))
        # Log KL divergence
        self.stats["loss/kl_leader"].append(gather_mean(kl_div_leader))
        self.stats["loss/kl_follower"].append(gather_mean(kl_div_follower))
        # Log loss
        self.stats["loss/loss_leader"].append(gather_mean(loss_leader))
        self.stats["loss/loss_follower"].append(gather_mean(loss_follower))

        # Log logprobs
        model_logprobs_leader_data_sum = model_logprobs_leader_data.sum(1)
        model_logprobs_follower_data_sum = model_logprobs_follower_data.sum(1)

        self.stats["logps/leader"].append(gather_mean(model_logprobs_leader_data_sum))
        self.stats["logps/follower"].append(
            gather_mean(model_logprobs_follower_data_sum)
        )

        # Log rewards
        if self.reward_model is not None:
            # self.stats["rewards/chosen"].append(gather_mean(model_scores))
            # self.stats["rewards/rejected"].append(gather_mean(mixture_scores))
            raise NotImplementedError(
                "Logging with a reward model is not yet implemented!"
            )

        # Log probabilities
        self.stats["rewards/probabilities"].append(gather_mean(probability))

        # Calculate entropy for model data
        entropy_model_leader = -model_logprobs_leader_data.sum(1)
        entropy_model_follower = -model_logprobs_follower_data.sum(1)
        self.stats["objective/entropy_leader"].append(gather_mean(entropy_model_leader))
        self.stats["objective/entropy_follower"].append(
            gather_mean(entropy_model_follower)
        )

        # Calculate margins
        margin = model_logprobs_leader_data_sum - model_logprobs_follower_data_sum
        self.stats["rewards/margins"].append(gather_mean(margin))

        # # Calculate accuracy
        # accuracy = (margin > 0).float()
        # self.stats["rewards/accuracies"].append(gather_mean(accuracy))

        # Log EOS token statistics
        leader_eos = (
            leader_data["input_ids"][:, leader_data["context_length"] :]
            == self.processing_class.eos_token_id
        ).any(dim=1)
        follower_eos = (
            follower_data["input_ids"][:, follower_data["context_length"] :]
            == self.processing_class.eos_token_id
        ).any(dim=1)
        self.stats["val/leader_contain_eos_token"].append(
            gather_mean(leader_eos.float())
        )
        self.stats["val/follower_contain_eos_token"].append(
            gather_mean(follower_eos.float())
        )

        # Log beta and follower_weight
        self.stats["beta"].append(self.beta)
        self.stats["follower_weight"].append(self._follower_weight)

    def training_step(
        self,
        model: nn.Module,
        inputs: dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Run one training step for the leader and the follower.

        The step tokenizes the prompts, samples leader and follower completions,
        obtains the probability that the leader completion is preferred (from the
        judge; the reward-model path raises `NotImplementedError`), computes the
        reference logprobs, builds both losses and runs the backward pass(es).

        With `args.separate_follower_model` the leader and follower losses are
        back-propagated separately; otherwise a single combined loss is used,
        alternating between leader and follower updates according to
        `args.leader_update_frequency`.

        Args:
            model: The (possibly wrapped) policy model.
            inputs: Batch as a dict of columns; must contain a "prompt" entry.
            num_items_in_batch: Not used by this method.

        Returns:
            The detached total loss divided by `args.gradient_accumulation_steps`.
        """
        # NOTE(review): this method never calls `model.train()`; confirm the model
        # is put into training mode elsewhere.

        # The model may or may not be wrapped (e.g. by DDP); resolve the object
        # that carries `use_follower` without assuming a `.module` attribute.
        core_model = getattr(model, "module", model)

        def contains_eos(data):
            # Per-sample flag: does the completion contain the EOS token?
            # Tensor, Shape: (batch_size,)
            completion_ids = data["input_ids"][:, data["context_length"] :]
            return torch.any(
                completion_ids == self.processing_class.eos_token_id, dim=-1
            )

        # Apply chat template and tokenize the input
        batch_size = len(next(iter(inputs.values())))
        prompts = inputs["prompt"]
        inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
        inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
        inputs = [
            self.tokenize_row(
                x, self.model.config.is_encoder_decoder, self.processing_class
            )
            for x in inputs
        ]
        inputs = self.data_collator(inputs)

        # need the prompt_ only
        inputs = self._prepare_inputs(inputs)
        prompts = {
            "input_ids": inputs["prompt_input_ids"],
            "attention_mask": inputs["prompt_attention_mask"],
            "raw": prompts,
        }
        del inputs

        # Sample completions from both the leader and the follower
        leader_data, follower_data = self._generate_completions(model, prompts)

        # Compute rewards
        if self.reward_model is not None:
            model_scores, mixture_scores = self._compute_rewards(
                leader_data, follower_data
            )
            # probability of the model data vs the mixture data
            probability = F.sigmoid(model_scores - mixture_scores)
        else:
            model_scores, mixture_scores = None, None
            probability = self._compute_judge(leader_data, follower_data)

        # Learning prep
        kwargs = {}
        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            kwargs["learning_rate"] = self._get_learning_rate()

        # Compute reference logprobs
        ref_logprobs_follower_data, ref_logprobs_leader_data = (
            self._compute_ref_model_logprobs(follower_data, leader_data, model)
        )

        # Compute the loss for the Leader
        if self.args.separate_follower_model:
            switch_model(model, "leader")
            # `not`, not `~`: bitwise inversion of a bool is always truthy, which
            # would turn this check into a no-op.
            assert not core_model.use_follower
        model_logprobs_leader_data = self._compute_logprobs(
            model, leader_data, with_grad=True
        )
        leader_contain_eos = contains_eos(leader_data)
        loss_leader, score_leader, kl_div_leader = self._compute_losses(
            model_logprobs_leader_data,
            ref_logprobs_leader_data,
            probability
            - self.args.missing_eos_probability_penalty * (~leader_contain_eos).float(),
        )

        # Compute loss for the Follower
        if self.args.separate_follower_model:
            # Do a backward pass for the Leader first
            if self.args.n_gpu > 1:
                leader_loss = loss_leader.mean()
            else:
                leader_loss = loss_leader

            if self.use_apex:
                raise NotImplementedError("Apex support is not yet implemented.")
            else:
                self.accelerator.backward(leader_loss, **kwargs)

            # Create a separate computation graph for the follower to prevent DDP issues
            if is_ddp_wrapped(model):
                # Save original DDP state
                original_ddp_sync = model.require_backward_grad_sync
                # Temporarily disable gradient synchronization
                model.require_backward_grad_sync = False

            # Switch to follower adapter
            switch_model(model, "follower")
            assert core_model.use_follower
            model_logprobs_follower_data = self._compute_logprobs(
                model, follower_data, with_grad=True
            )
            # Restore DDP settings
            if is_ddp_wrapped(model):
                model.require_backward_grad_sync = original_ddp_sync

            follower_contain_eos = contains_eos(follower_data)

            # The follower is rewarded with the complementary probability and
            # uses its own KL coefficient.
            loss_follower, score_follower, kl_div_follower = self._compute_losses(
                model_logprobs_follower_data,
                ref_logprobs_follower_data,
                1.0
                - probability
                - self.args.missing_eos_probability_penalty
                * (~follower_contain_eos).float(),
                beta=self.args.follower_beta,
            )

            if self.args.n_gpu > 1:
                follower_loss = loss_follower.mean()
            else:
                follower_loss = loss_follower

            # Apply backward pass for follower model
            self.accelerator.backward(follower_loss, **kwargs)

            # Calculate total loss for reporting
            total_loss = (leader_loss + follower_loss) / 2.0

            # Switch back to leader adapter for next iteration
            switch_model(model, "leader")
        else:
            model_logprobs_follower_data = self._compute_logprobs(
                model, follower_data, with_grad=True
            )
            follower_contain_eos = contains_eos(follower_data)
            loss_follower, score_follower, kl_div_follower = self._compute_losses(
                model_logprobs_follower_data,
                ref_logprobs_follower_data,
                1.0
                - probability  # Negate the probability for the follower
                - self.args.missing_eos_probability_penalty
                * (~follower_contain_eos).float(),
            )

            assert (
                self.args.leader_update_frequency >= 1
            ), "leader_update_frequency must be >= 1"
            if self.args.leader_update_frequency == 1:
                # Update both players on every step.
                total_loss = loss_leader + self._follower_weight * loss_follower
            elif (self.state.global_step + 1) % self.args.leader_update_frequency == 0:
                # Leader-only update on every `leader_update_frequency`-th step.
                total_loss = loss_leader
            else:
                # Follower-only update on all other steps.
                total_loss = self._follower_weight * loss_follower

            if self.args.n_gpu > 1:
                total_loss = (
                    total_loss.mean()
                )  # mean() to average on multi-gpu parallel training
            if self.use_apex:
                raise NotImplementedError("Apex support is not yet implemented.")
                # with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                #     scaled_loss.backward()
            else:
                self.accelerator.backward(total_loss, **kwargs)

        # Training update preparation
        if (
            self.args.torch_empty_cache_steps is not None
            and self.state.global_step % self.args.torch_empty_cache_steps == 0
        ):
            empty_cache()

        # Log everything
        self._log_statistics(
            leader_data,
            follower_data,
            model_logprobs_leader_data,
            model_logprobs_follower_data,
            probability,
            loss_leader,
            loss_follower,
            score_leader,
            score_follower,
            kl_div_leader,
            kl_div_follower,
        )

        if self.accelerator.sync_gradients and self.args.max_clip_grad_norm > 0:
            self.accelerator.clip_grad_norm_(
                self.model.parameters(),
                self.args.max_clip_grad_norm,
            )

        return total_loss.detach() / self.args.gradient_accumulation_steps

    def _compute_ref_model_logprobs(self, follower_data, leader_data, model):
        """
        Compute the reference model's logprobabilities for the Leader and Follower completions
        Used to calculate the KL-divergence regularization
        """

        if self.args.standard_follower_kl_regularization:
            ref_model_data = follower_data
        else:
            ref_model_data = {
                key_name: torch.cat(
                    [
                        leader_data[key_name][:, : leader_data["context_length"]],
                        follower_data[key_name][:, follower_data["context_length"] :],
                    ],
                    dim=1,
                )
                for key_name in ["input_ids", "attention_mask"]
            }
            ref_model_data["context_length"] = leader_data["context_length"]
        if self.ref_model is None:
            with unwrap_model_for_generation(model, self.accelerator) as ref_model:
                if isinstance(ref_model, DualModel):
                    ref_logprobs_leader_data = self._compute_logprobs(
                        ref_model.leader,
                        leader_data,
                        with_grad=False,
                        disable_adapters=True,
                    )
                    ref_logprobs_follower_data = self._compute_logprobs(
                        ref_model.follower,
                        ref_model_data,
                        with_grad=False,
                        disable_adapters=True,
                    )
                else:
                    ref_logprobs_leader_data = self._compute_logprobs(
                        ref_model, leader_data, with_grad=False, disable_adapters=True
                    )
                    ref_logprobs_follower_data = self._compute_logprobs(
                        ref_model,
                        ref_model_data,
                        with_grad=False,
                        disable_adapters=True,
                    )
        else:
            ref_logprobs_leader_data = self._compute_logprobs(
                self.ref_model, leader_data, with_grad=False
            )
            ref_logprobs_follower_data = self._compute_logprobs(
                self.ref_model, ref_model_data, with_grad=False
            )
        return ref_logprobs_follower_data, ref_logprobs_leader_data

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        raise NotImplementedError
        # if not self.is_world_process_zero():
        #     return
        #
        # if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(
        #     self.model.config._name_or_path
        # ):
        #     base_model = self.model.config._name_or_path
        # else:
        #     base_model = None
        #
        # tags = tags or []
        # if isinstance(tags, str):
        #     tags = [tags]
        #
        # if hasattr(self.model.config, "unsloth_version"):
        #     tags.append("unsloth")
        #
        # # TODO: UPDATE CITATION
        # citation = textwrap.dedent("""""")
        #
        # model_card = generate_model_card(
        #     base_model=base_model,
        #     model_name=model_name,
        #     hub_model_id=self.hub_model_id,
        #     dataset_name=dataset_name,
        #     tags=tags,
        #     wandb_url=(
        #         wandb.run.get_url()
        #         if is_wandb_available() and wandb.run is not None
        #         else None
        #     ),
        #     comet_url=get_comet_experiment_url(),
        #     trainer_name="StackelbergPG",
        #     trainer_citation=citation,
        #     paper_title="NA",  # TODO: Update with the paper title
        #     paper_id="NA",
        # )
        #
        # model_card.save(os.path.join(self.args.output_dir, "README.md"))
