import gc
import inspect
import logging
import os
import pickle
import random
import re
import sys
import warnings
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import dataclass, field
from functools import partial, wraps
from itertools import cycle
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import bitsandbytes as bnb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import PartialState
from datasets import load_dataset, Dataset, concatenate_datasets
from evaluate import load
from peft import AutoPeftModelForCausalLM, LoraConfig, prepare_model_for_kbit_training
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollator,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)
from transformers.trainer_utils import EvalLoopOutput
from trl import (
    ORPOConfig,
    ORPOTrainer,
    DPOConfig,
    DPOTrainer,
    SFTConfig,
    SFTTrainer,
    RewardConfig,
    RewardTrainer,
    ModelConfig,
)
from typing_extensions import Literal
# Module-level alias for bitsandbytes' 8-bit Adam optimizer class (not referenced in the visible part of this file).
optim_8bit = bnb.optim.Adam8bit
 

## All code below is adapted from the Hugging Face TRL library.

@dataclass
class ScriptArguments:
    """Command-line arguments for the DRDO training script.

    Each field documents itself through ``metadata["help"]``, which is meant to be shown as
    command-line help text. Field order, types and defaults form this class's interface
    (positional construction depends on the order), so new fields should be appended.
    """

    # --- loss / model selection ---
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
    # The two model-path defaults look like placeholders; supply real paths when running.
    student_model_name_or_path: Optional[str] = field(
        default="##student_policy_path##",
        metadata={"help": "the location of the SFT model name or the student model or path"},
    )
    teacher_model_name_or_path: Optional[str] = field(
        default="##teacher or oracle model##",
        metadata={"help": "the location of the SFT model name or path or the teacher model "},
    )
    # NOTE: a leftover note in the original script says "facebook/opt-1.3b" was used previously.
    trainer_teacher_rm: Optional[bool] = field(
        default=True, metadata={"help": "whether to use the trainer RM as teacher for DRDO training"}
    )

    # --- optimisation ---
    learning_rate: Optional[float] = field(default=5e-6, metadata={"help": "optimizer learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
    warmup_steps: Optional[int] = field(default=10, metadata={"help": "the number of warmup steps"})
    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
    optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})
    loss_type: Optional[str] = field(default="sigmoid", metadata={"help": "the loss type you want to test your policy on"})

    # --- batching / memory ---
    per_device_train_batch_size: Optional[int] = field(default=6, metadata={"help": "train batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=4, metadata={"help": "eval batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=8, metadata={"help": "the number of gradient accumulation steps"}
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "whether to use gradient checkpointing"}
    )
    gradient_checkpointing_use_reentrant: Optional[bool] = field(
        default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
    )

    # --- LoRA ---
    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})

    # --- data ---
    dataset: Optional[str] = field(default="ultrafeedback_binarized", metadata={"help": "the dataset used for training and evaluation "})
    max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
    max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})

    # --- schedule, logging and checkpointing ---
    max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
    logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
    save_steps: Optional[int] = field(default=200, metadata={"help": "the saving frequency"})
    save_strategy: Optional[str] = field(default="no", metadata={"help": "whether to save intermediate steps during training"})
    eval_steps: Optional[int] = field(default=200, metadata={"help": "the evaluation frequency"})

    output_dir: Optional[str] = field(default="./results_falcon", metadata={"help": "the output directory"})
    log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})

    # --- model loading ---
    load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 4bit"})
    model_dtype: Optional[str] = field(
        default="bfloat16", metadata={"help": "model_dtype[float16, bfloat16, float] for loading."}
    )
    sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
    report_to: Optional[str] = field(
        default="wandb",
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )

    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            # The trailing space matters: the two literals are concatenated implicitly,
            # which previously produced "Seehttps://...".
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See "
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    seed: Optional[int] = field(
        default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
    )


@dataclass
class DROTrainer(DPOTrainer):
    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        beta: float = 0.1,
        label_smoothing: float = 0,
        loss_type: Optional[Literal["sigmoid", "DRDO", "robust_fisch"]] = None,
        args: Optional[DROConfig] = None,
        data_collator: Optional[DataCollator] = None,
        label_pad_token_id: int = -100,
        padding_value: Optional[int] = None,
        truncation_mode: str = "keep_end",
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        max_length: Optional[int] = None,
        max_prompt_length: Optional[int] = None,
        max_target_length: Optional[int] = None,
        peft_config: Optional[Dict] = None,
        is_encoder_decoder: Optional[bool] = None,
        disable_dropout: bool = True,
        generate_during_eval: bool = False,
        compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
        precompute_ref_log_probs: bool = False,
        dataset_num_proc: Optional[int] = None,
        model_init_kwargs: Optional[Dict] = None,
        ref_model_init_kwargs: Optional[Dict] = None,
        model_adapter_name: Optional[str] = None,
        ref_adapter_name: Optional[str] = None,
        reference_free: bool = False,
        force_use_ref_model: bool = False,
        gamma: Optional[float] = None,
    ):
        """Create a DPOTrainer variant that optimizes the DRDO family of losses (see ``dro_loss``).

        All arguments except ``gamma`` are forwarded unchanged to ``DPOTrainer.__init__``.
        Several attributes are then re-read from ``args``, so ``args`` must be provided and must
        expose ``reward_scaling_factor_alpha``, ``ull_scaling_factor``, ``generate_during_eval``,
        ``max_length``, ``max_prompt_length``, ``output_dir`` and ``beta``. These values override
        the keyword arguments of the same name.

        Args:
            loss_type: Loss selected in ``dro_loss``. When omitted, any value already set on the
                trainer during ``DPOTrainer.__init__`` is kept.
            gamma: Focusing exponent of the focal factor ``(1 - p_w) ** gamma`` used by the
                ``"DRDO"`` loss. When omitted, falls back first to a module-level ``gamma``
                global (what earlier versions of this file read) and then to ``args.gamma``.
                If none of these is set, ``self.gamma`` is ``None`` and only the DRDO loss
                cannot be used.

        NOTE(review): the class is decorated with ``@dataclass``. Because ``__init__`` is
        defined here, dataclass keeps it. However, it still generates ``__eq__`` (all instances
        compare equal if the class declares no annotated fields) and sets ``__hash__`` to None,
        making the trainer unhashable. Confirm the decorator is wanted.
        """
        super().__init__(
            model=model,
            ref_model=ref_model,
            beta=beta,
            label_smoothing=label_smoothing,
            loss_type=loss_type,
            args=args,
            data_collator=data_collator,
            label_pad_token_id=label_pad_token_id,
            padding_value=padding_value,
            truncation_mode=truncation_mode,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            max_length=max_length,
            max_prompt_length=max_prompt_length,
            max_target_length=max_target_length,
            peft_config=peft_config,
            is_encoder_decoder=is_encoder_decoder,
            disable_dropout=disable_dropout,
            generate_during_eval=generate_during_eval,
            compute_metrics=compute_metrics,
            precompute_ref_log_probs=precompute_ref_log_probs,
            dataset_num_proc=dataset_num_proc,
            model_init_kwargs=model_init_kwargs,
            ref_model_init_kwargs=ref_model_init_kwargs,
            model_adapter_name=model_adapter_name,
            ref_adapter_name=ref_adapter_name,
            reference_free=reference_free,
            force_use_ref_model=force_use_ref_model,
        )

        # Disable the f-divergence variants: dro_loss only takes the alpha/JS branches when this is set.
        self.f_divergence_type = None
        self.reward_scaling_factor_alpha = args.reward_scaling_factor_alpha
        self.ull_scaling_factor = args.ull_scaling_factor
        # Only override when a loss type was passed explicitly. Assigning None unconditionally
        # would discard any value DPOTrainer may have taken from ``args``.
        if loss_type is not None:
            self.loss_type = loss_type
        elif not hasattr(self, "loss_type"):
            self.loss_type = None
        self.generate_during_eval = args.generate_during_eval
        self._stored_metrics = defaultdict(lambda: defaultdict(list))
        self.rouge_scores_list = []  # Persistent state to accumulate ROUGE scores
        # NOTE(review): rouge_scorer, load_metric and SentenceTransformer are not among the
        # top-of-file imports. They must be in scope elsewhere or construction fails with NameError.
        self.scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)  # Initialize RougeScorer
        self.bleu_metric = load_metric("bleu")
        self.meteor_metric = load_metric("meteor")
        self.sentence_transformer_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

        self.max_length = args.max_length
        self.max_prompt_length = args.max_prompt_length
        self.output_dir = args.output_dir
        # Keep the start of over-long prompts (e.g. OpenAI summarize posts) so the policy still sees
        # the beginning of the post. This was previously written with `==` and therefore a no-op.
        # It intentionally overrides the `truncation_mode` argument.
        # NOTE(review): if DPOTrainer.__init__ already tokenized the datasets, this only affects
        # later calls to tokenize_row. Confirm.
        self.truncation_mode = "keep_start"
        self.beta = args.beta
        self.is_vision_model = False

        if gamma is None:
            # Backward compatibility: earlier versions read a module-level `gamma` global here.
            gamma = globals().get("gamma")
        if gamma is None:
            gamma = getattr(args, "gamma", None)
        self.gamma = gamma

    def odds_ratio_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        eps: float = 1e-7,
    ) -> torch.FloatTensor:
        """Return the per-example term ``log p_w - log(1 - p_l + eps)`` used by DRDO-style losses.

        Here ``p_w = exp(policy_chosen_logps)`` and ``p_l = exp(policy_rejected_logps)``. The
        computation is adapted from ORPO's odds-ratio derivation (https://arxiv.org/abs/2403.07691).
        It is not ORPO's full odds-ratio loss: only the chosen log-probability and the rejected
        "unlikelihood" ``log(1 - p_l)`` appear.

        If the inputs are sequence-level sums of token log-probabilities (as produced by
        ``get_batch_logps``), ``p_l`` is usually tiny, so the result is close to
        ``policy_chosen_logps``.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
            eps: Small constant added inside ``log1p``. The ``log1p`` argument is also clamped to at
                least ``-1 + eps`` so the result stays finite when ``p_l -> 1``. See
                https://github.com/huggingface/trl/issues/1473 for the NaN issue this avoids.

        Returns:
            Tensor with the same shape as the inputs.
        """
        # log1p(-p_l + eps) == log(1 - p_l + eps). The clamp keeps the argument strictly above -1,
        # where log1p would otherwise return -inf.
        safe_rejected_p = torch.clamp(-torch.exp(policy_rejected_logps) + eps, min=-1 + eps)
        return policy_chosen_logps - torch.log1p(safe_rejected_p)


    def dro_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        policy_chosen_rewards: torch.FloatTensor,
        policy_rejected_rewards: torch.FloatTensor,
        reference_chosen_rewards: torch.FloatTensor,
        reference_rejected_rewards: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, ...]:
        """Compute the per-example loss selected by ``self.loss_type``.

        Supported loss types:
            * ``"dpo_original"`` (alias ``"sigmoid"``, the name TRL's DPOTrainer uses): the
              standard DPO sigmoid loss with label smoothing.
            * ``"DPO_P"``: DPO-Positive baseline. The penalty
              ``lambda_ * relu(ref_logp_w - policy_logp_w)`` is subtracted from the DPO logits
              inside the sigmoid, so a chosen log-prob falling below the reference increases the
              loss. ``lambda_`` defaults to 50 and can be overridden via ``self.lambda_``.
            * ``"DRDO"``: ``reward_scaling_factor_alpha * (teacher_margin - student_margin) ** 2``
              plus ``ull_scaling_factor * -(1 - sigmoid(logp_w - logp_l)) ** gamma * odds_ratio_loss``.
            * ``"e-DPO"``: ``(teacher_margin - beta * logits) ** 2``.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
            policy_chosen_rewards: Scalar rewards of the chosen responses from the student.
            policy_rejected_rewards: Scalar rewards of the rejected responses from the student.
            reference_chosen_rewards: Scalar rewards of the chosen responses from the teacher/oracle.
                Used as the distillation target by "DRDO" and "e-DPO".
            reference_rejected_rewards: Scalar rewards of the rejected responses from the teacher/oracle.
                Used as the distillation target by "DRDO" and "e-DPO".

        Returns:
            8-tuple ``(losses, chosen_rewards, rejected_rewards, policy_chosen_rewards,
            policy_rejected_rewards, reference_chosen_rewards, reference_rejected_rewards,
            ull_logits)``:
            - ``chosen_rewards`` / ``rejected_rewards`` are the detached implicit DPO rewards
              ``beta * (policy_logps - reference_logps)``.
            - The four reward inputs are returned unchanged.
            - ``ull_logits`` is the auxiliary log-odds term (zeros for plain DPO).

        Raises:
            ValueError: For an unknown loss type, or for "DRDO" when ``self.gamma`` is not set.
        """
        logging.getLogger(__name__).debug("loss type in DRO loss function: %s", self.loss_type)
        device = self.accelerator.device
        use_reference = not self.reference_free

        chosen_logratios = policy_chosen_logps.to(device) - use_reference * reference_chosen_logps.to(device)
        rejected_logratios = policy_rejected_logps.to(device) - use_reference * reference_rejected_logps.to(device)

        # Check for a falsy divergence type first, so the TRL f-divergence helpers are only looked up
        # when a divergence is actually configured (__init__ sets it to None).
        f_divergence_type = self.f_divergence_type
        if f_divergence_type and f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value:
            alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT
            if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params:
                alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY])
            logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef
        else:
            pi_logratios = (policy_chosen_logps - policy_rejected_logps).to(device)
            if self.reference_free:
                ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=device)
            else:
                ref_logratios = (reference_chosen_logps - reference_rejected_logps).to(device)
            logits = pi_logratios - ref_logratios
            if f_divergence_type and f_divergence_type == FDivergenceType.JS_DIVERGENCE.value:
                logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios)

        # Auxiliary log-odds term shared by the DPO_P, DRDO and e-DPO branches.
        log_odds_unlikelihood = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)

        loss_type = self.loss_type
        if loss_type in ("dpo_original", "sigmoid"):
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
            # Plain DPO has no unlikelihood term. Zeros keep the return tuple complete
            # (previously `ull_logits` was unbound here).
            ull_logits = torch.zeros_like(losses)

        elif loss_type == "DPO_P":  # DPO-Positive baseline, not reported in the paper
            dpop_lambda = getattr(self, "lambda_", 50)
            dpop_penalty = F.relu(reference_chosen_logps.to(device) - policy_chosen_logps.to(device))
            # The penalty belongs inside the sigmoid. Subtracting it from the loss instead would
            # reward the chosen likelihood dropping below the reference.
            dpop_logits = logits - dpop_lambda * dpop_penalty
            losses = (
                -F.logsigmoid(self.beta * dpop_logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * dpop_logits) * self.label_smoothing
            )
            ull_logits = self.beta * log_odds_unlikelihood

        elif loss_type == "DRDO":  # DRDO loss as reported in the paper
            gamma = getattr(self, "gamma", None)
            if gamma is None:
                raise ValueError("The 'DRDO' loss requires `gamma` (the focal exponent) to be set on the trainer.")
            # Squared gap between the teacher (R*) and student (r-hat) reward margins.
            teacher_margin = reference_chosen_rewards - reference_rejected_rewards
            student_margin = policy_chosen_rewards - policy_rejected_rewards
            direct_reward_loss_scaled = self.reward_scaling_factor_alpha * (teacher_margin - student_margin) ** 2
            prob_w = torch.sigmoid(policy_chosen_logps - policy_rejected_logps)  # probability of the chosen response
            focal_modulating_factor = (1 - prob_w) ** gamma
            focal_softened_loss = -focal_modulating_factor * log_odds_unlikelihood
            # ull_scaling_factor plays the role of the focal term's alpha in the paper
            # (the name `alpha` is already used by reward_scaling_factor_alpha).
            ull_logits = self.ull_scaling_factor * focal_softened_loss
            losses = ull_logits + direct_reward_loss_scaled

        elif loss_type == "e-DPO":
            # L_distill / e-DPO against precomputed teacher rewards (averaged over several RMs upstream).
            losses = (reference_chosen_rewards - reference_rejected_rewards - (self.beta * logits)) ** 2
            # Reported but not added to the loss.
            ull_logits = self.ull_scaling_factor * log_odds_unlikelihood

        else:
            raise ValueError(
                f"Unknown loss type: {self.loss_type}. Should be one of "
                "['dpo_original' (alias 'sigmoid'), 'DPO_P', 'DRDO', 'e-DPO']"
            )

        chosen_rewards = (
            self.beta * (policy_chosen_logps.to(device) - reference_chosen_logps.to(device)).detach()
        )
        rejected_rewards = (
            self.beta * (policy_rejected_logps.to(device) - reference_rejected_logps.to(device)).detach()
        )

        return (
            losses,
            chosen_rewards,
            rejected_rewards,
            policy_chosen_rewards,
            policy_rejected_rewards,
            reference_chosen_rewards,
            reference_rejected_rewards,
            ull_logits,
        )

    @staticmethod
    def get_batch_logps(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
            label_pad_token_id: The label pad token id.
            is_encoder_decoder: Whether the model is an encoder-decoder model.

        Returns:
            A tuple ``(summed_logps, num_tokens)`` of two ``(batch_size,)`` tensors: the sum of
            the log probabilities of the non-masked labels, and the number of non-masked tokens.
            The caller's ``labels`` tensor is never modified.

        Raises:
            ValueError: If the batch/sequence dimensions of ``logits`` and ``labels`` differ.
        """
        if logits.shape[:-1] != labels.shape:
            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

        if not is_encoder_decoder:
            # Decoder-only: the logit at position t predicts token t + 1, so shift labels left and
            # drop the final logit.
            labels = labels[:, 1:]
            logits = logits[:, :-1, :]
        loss_mask = labels != label_pad_token_id

        # Padding positions need a valid vocabulary index for gather(). torch.where builds a new
        # tensor, so the caller's labels are not overwritten (the in-place write previously
        # mutated them on the encoder-decoder path). Those positions are zeroed by loss_mask below.
        safe_labels = torch.where(loss_mask, labels, torch.zeros_like(labels))

        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=safe_labels.unsqueeze(2)).squeeze(2)

        return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)

    def build_tokenized_answer(self, prompt, answer, images=None):
        """Tokenize ``prompt + answer`` once and split the ids into a prompt part and an answer part.

        Llama-style tokenizers do NOT satisfy ``enc(a + b) = enc(a) + enc(b)``: the last prompt
        token can merge with the first answer token. The method therefore:
        - tokenizes the full string and the prompt separately;
        - splits the full ids at ``len(enc(prompt))``;
        - moves the split one token earlier when the prompt's ids are not a prefix of the full ids
          (i.e. the boundary token merged).

        Reference:
            https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257

        Args:
            prompt: Prompt text.
            answer: Completion text appended to ``prompt``.
            images: Optional images. Only passed to ``self.processor`` when ``self.is_vision_model``.

        Returns:
            A dict with:
            - ``prompt_input_ids`` and ``prompt_attention_mask``: the prompt part.
            - ``input_ids`` and ``attention_mask``: the answer part.
            - ``prompt_pixel_values`` / ``prompt_pixel_attention_mask``: only when the processor
              output contains those keys.

        Raises:
            NotImplementedError: If a vision-model answer contains an ``<image>`` token.
            ValueError: If the concatenated prompt+answer ids do not have the full tokenization's
                length, or the prompt ids and prompt attention mask lengths differ.
        """
        if self.is_vision_model:
            if answer.count("<image>") > 0:
                raise NotImplementedError("Answer contains <image> token, which is not supported yet.")
            if "add_special_tokens" in inspect.signature(self.processor).parameters:
                processor_kwargs = {"add_special_tokens": False}
            else:
                processor_kwargs = {}
            full_tokenized = self.processor(prompt + answer, images=images, **processor_kwargs)
            full_tokenized = {k: v[0] for k, v in full_tokenized.items()}  # Unbatch, not done when using idefics
            if not isinstance(full_tokenized["input_ids"], list):  # llava processor returns tensors
                full_tokenized["input_ids"] = full_tokenized["input_ids"].tolist()
                full_tokenized["attention_mask"] = full_tokenized["attention_mask"].tolist()
            prompt_input_ids = self.processor(prompt, images=images, **processor_kwargs)["input_ids"][0]
            if not isinstance(prompt_input_ids, list):  # llava processor returns tensors
                prompt_input_ids = prompt_input_ids.tolist()
        else:
            full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
            prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]

        # Provisional split used only for the length check below; both values are recomputed
        # after the merge check.
        answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
        answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]

        # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
        full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])

        # Prepare input tokens for token by token comparison
        full_input_ids = np.array(full_tokenized["input_ids"])

        # The lengths can only differ if the standalone prompt tokenizes longer than prompt+answer.
        if len(full_input_ids) != len(full_concat_input_ids):
            raise ValueError("Prompt input ids and answer input ids should have the same length.")


        response_token_ids_start_idx = len(prompt_input_ids)

        # If tokenized prompt is different than both prompt+answer, then it means the
        # last token has changed due to merging.
        if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
            response_token_ids_start_idx -= 1

        # Final split: both parts are slices of the single full tokenization.
        prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
        prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]

        if len(prompt_input_ids) != len(prompt_attention_mask):
            raise ValueError("Prompt input ids and attention mask should have the same length.")

        answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
        answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]

        return_dict = dict(
            prompt_input_ids=prompt_input_ids,
            prompt_attention_mask=prompt_attention_mask,
            input_ids=answer_input_ids,
            attention_mask=answer_attention_mask,
        )
        # Vision processors may also return pixel data; expose it under prompt_* keys.
        if "pixel_values" in full_tokenized:
            return_dict["prompt_pixel_values"] = full_tokenized["pixel_values"]
        if "pixel_attention_mask" in full_tokenized:
            return_dict["prompt_pixel_attention_mask"] = full_tokenized["pixel_attention_mask"]

        return return_dict

    def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
        """Tokenize a single row from a DPO specific dataset.

        At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
        in case the prompt + chosen or prompt + rejected responses is/are too long. First
            we truncate the prompt; if we're still too long, we truncate the chosen/rejected.

        We also create the labels for the chosen/rejected responses, which are of length equal to
            the sum of the length of the prompt and the chosen/rejected response, with
            label_pad_token_id  for the prompt tokens.
        """
        batch = {}
        prompt = feature["prompt"]
        chosen = feature["chosen"]
        rejected = feature["rejected"]
        images = feature.get("images")

        if not self.is_encoder_decoder:
            # Check issues below for more details
            #  1. https://github.com/huggingface/trl/issues/907
            #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
            #  3. https://github.com/LianjiaTech/BELLE/issues/337

            if not isinstance(prompt, str):
                raise ValueError(f"prompt should be an str but got {type(prompt)}")
            if self.is_vision_model:
                if "add_special_tokens" in inspect.signature(self.processor).parameters:
                    processor_kwargs = {"add_special_tokens": False}
                else:
                    processor_kwargs = {}
                prompt_tokens = self.processor(prompt, images=images, **processor_kwargs)
                prompt_tokens = {k: v[0] for k, v in prompt_tokens.items()}  # Unbatch, not done when using idefics
                if not isinstance(prompt_tokens["input_ids"], list):  # llava processor returns tensors
                    prompt_tokens["input_ids"] = prompt_tokens["input_ids"].tolist()
                    prompt_tokens["attention_mask"] = prompt_tokens["attention_mask"].tolist()
            else:
                prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)

            prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

            if not isinstance(chosen, str):
                raise ValueError(f"chosen should be an str but got {type(chosen)}")

            chosen_tokens = self.build_tokenized_answer(prompt, chosen, images)

            if not isinstance(rejected, str):
                raise ValueError(f"rejected should be an str but got {type(rejected)}")
            rejected_tokens = self.build_tokenized_answer(prompt, rejected, images)

            # Last prompt token might get merged by tokenizer and
            # it should not be included for generation if that happens
            prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])

            chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
            rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
            prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)

            for k, v in prompt_tokens.items():
                prompt_tokens[k] = v[:prompt_len_input_ids]

            # Make sure prompts only have one different token at most an
            # and length only differs by 1 at most
            num_diff_tokens = sum(
                [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
            )
            num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
            if num_diff_tokens > 1 or num_diff_len > 1:
                raise ValueError(
                    "Chosen and rejected prompt_input_ids might only differ on the "
                    "last token due to tokenizer merge ops."
                )

            # add BOS token to head of prompt. Avoid adding if it's already there
            bos_token_id = self.tokenizer.bos_token_id
            if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
                prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
                prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
            if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
                chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
                chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
            if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
                rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
                rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]

            # add EOS token to end of answer. Avoid adding if it's already there
            eos_token_id = self.tokenizer.eos_token_id
            if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
                chosen_tokens["input_ids"].append(eos_token_id)
                chosen_tokens["attention_mask"].append(1)
            if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
                rejected_tokens["input_ids"].append(eos_token_id)
                rejected_tokens["attention_mask"].append(1)

            longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

            # if combined sequence is too long, truncate the prompt
            for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                    if self.truncation_mode == "keep_start":
                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
                            answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
                    elif self.truncation_mode == "keep_end":
                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
                            answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
                    else:
                        raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

            # if that's still too long, truncate the response
            for answer_tokens in [chosen_tokens, rejected_tokens]:
                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                    for k in ["input_ids", "attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]

            # Create labels
            chosen_sequence_tokens = {
                k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
            }
            rejected_sequence_tokens = {
                k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
            }
            chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
            chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
                self.label_pad_token_id
            ] * len(chosen_tokens["prompt_input_ids"])
            rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
            rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
                self.label_pad_token_id
            ] * len(rejected_tokens["prompt_input_ids"])

            for k, toks in {
                "chosen_": chosen_sequence_tokens,
                "rejected_": rejected_sequence_tokens,
                "": prompt_tokens,
            }.items():
                for type_key, tokens in toks.items():
                    if type_key == "token_type_ids":
                        continue
                    batch[f"{k}{type_key}"] = tokens

        else:
            chosen_tokens = self.tokenizer(
                chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True
            )
            rejected_tokens = self.tokenizer(
                rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True
            )
            prompt_tokens = self.tokenizer(
                prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
            )

            batch["chosen_labels"] = chosen_tokens["input_ids"]
            batch["rejected_labels"] = rejected_tokens["input_ids"]
            batch["prompt_input_ids"] = prompt_tokens["input_ids"]
            batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

            if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
                batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                    labels=torch.tensor(batch["rejected_labels"])
                )
                batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                    labels=torch.tensor(batch["chosen_labels"])
                )

        return batch


    def forward_with_lm_head(self, model, concatenated_batch, model_kwargs):
        """Score each sequence with ``model.classification_head`` applied at its last real token.

        Runs ``model.base_model`` on the concatenated chosen/rejected inputs. It then takes, per row,
        the hidden state of the last position whose attention mask is 1. Padded sequences are
        therefore scored at their final real token rather than at a padding position. This holds
        for left or right padding.

        Args:
            model: model exposing ``base_model`` and ``classification_head``; a
                ``DataParallel`` / ``DistributedDataParallel`` wrapper is unwrapped first.
            concatenated_batch: dict with ``concatenated_input_ids`` and
                ``concatenated_attention_mask`` (presumably shaped ``(batch, seq_len)`` — TODO confirm).
            model_kwargs: extra keyword arguments forwarded unchanged to ``base_model``.

        Returns:
            ``(head_output, last_hidden_state)``: the classification-head output and the selected
            hidden vector for each row.
        """
        # Handle DataParallel and DistributedDataParallel models
        if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            model = model.module

        input_ids = concatenated_batch["concatenated_input_ids"]
        attention_mask = concatenated_batch["concatenated_attention_mask"]

        outputs = model.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            use_cache=False,
            **model_kwargs,
        )
        hidden_states = outputs.last_hidden_state

        # Distance of the last mask==1 position from the end of the row (argmax returns the first
        # maximum of the flipped mask). Indexing from the end keeps this valid for left/right padding
        # and if the hidden states carry extra leading positions. All-zero rows fall back to the
        # final position.
        tokens_from_end = attention_mask.flip(-1).long().argmax(dim=-1).to(hidden_states.device)
        last_token_idx = hidden_states.shape[1] - 1 - tokens_from_end
        row_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        last_hidden_state = hidden_states[row_idx, last_token_idx]

        # Per-sequence scores from the model's classification head
        vocab_logits = model.classification_head(last_hidden_state)
        return vocab_logits, last_hidden_state


    def concatenated_forward(
        self,
        model: nn.Module,
        batch: Dict[str, Union[List, torch.LongTensor]],
        which_model: Optional[str] = None,
    ) -> Tuple[torch.FloatTensor, ...]:
        """Run ``model`` once on the chosen and rejected inputs concatenated along the batch axis.

        Concatenation avoids two separate forward passes, which is faster for FSDP. Besides the
        usual DPO quantities, the per-sequence outputs of ``forward_with_lm_head`` are split into
        chosen / rejected "direct rewards".

        Args:
            model: the policy or the teacher/reference model.
            batch: collated batch with ``chosen_*`` / ``rejected_*`` entries.
            which_model: free-form tag passed by callers ("policy" / "teacher"); not used here.

        Returns:
            ``(chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss,
            chosen_direct_rewards, rejected_direct_rewards)``. When ``self.aux_loss_enabled``, the
            model's ``aux_loss`` is appended as an eighth element.
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            is_vision_model=self.is_vision_model,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )
        len_chosen = batch["chosen_labels"].shape[0]

        model_kwargs = {}
        if self.is_encoder_decoder:
            model_kwargs["labels"] = concatenated_batch["concatenated_labels"]
            model_kwargs["decoder_input_ids"] = concatenated_batch.pop("concatenated_decoder_input_ids", None)
        if self.is_vision_model:
            model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
            if "pixel_attention_mask" in concatenated_batch:
                model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
        if self.aux_loss_enabled:
            model_kwargs["output_router_logits"] = True

        # Language-model pass: token logits for the log-probs and the NLL term.
        outputs = model(
            concatenated_batch["concatenated_input_ids"],
            attention_mask=concatenated_batch["concatenated_attention_mask"],
            use_cache=False,
            **model_kwargs,
        )
        # Separate base_model + classification_head pass producing one score per sequence.
        all_rewards, _ = self.forward_with_lm_head(model, concatenated_batch, model_kwargs)

        all_logits = outputs.logits
        if all_logits.shape[:2] != concatenated_batch["concatenated_labels"].shape[:2]:
            # for llava, the model returns logits for the entire sequence, including the image tokens (placed before the text tokens)
            seq_len = concatenated_batch["concatenated_labels"].shape[1]
            all_logits = all_logits[:, -seq_len:]

        all_logps, size_completion = self.get_batch_logps(
            all_logits,
            concatenated_batch["concatenated_labels"],
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        def cross_entropy_loss(logits, labels):
            if not self.is_encoder_decoder:
                # Shift so that tokens < n predict n
                logits = logits[..., :-1, :].contiguous()
                labels = labels[..., 1:].contiguous()
            # Ignore exactly the id used to mask prompt / padding labels.
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id)
            logits = logits.view(-1, logits.shape[-1])
            labels = labels.view(-1).to(logits.device)  # enable model parallelism
            return loss_fct(logits, labels)

        labels = concatenated_batch["concatenated_labels"].clone()
        nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

        if self.loss_type == "ipo":
            all_logps = all_logps / size_completion

        result = (
            all_logps[:len_chosen],
            all_logps[len_chosen:],
            all_logits[:len_chosen],
            all_logits[len_chosen:],
            nll_loss,
            all_rewards[:len_chosen],
            all_rewards[len_chosen:],
        )
        if self.aux_loss_enabled:
            # Keep the seven standard outputs in place; the router aux loss goes last.
            return result + (outputs.aux_loss,)
        return result

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the ``self.dro_loss`` loss and the logging metrics for one batch.

        The policy (student) and the teacher/reference model each go through
        ``concatenated_forward``. Their log-probs and direct rewards are passed to
        ``self.dro_loss``. When ``self.args.rpo_alpha`` is set, the loss becomes
        ``rpo_alpha * loss + policy NLL``.

        Args:
            model: the policy being trained or evaluated.
            batch: collated batch; may carry precomputed ``reference_chosen_logps`` /
                ``reference_rejected_logps``.
            train_eval: selects the metric-name prefix (``"eval_"`` for evaluation).

        Returns:
            ``(mean loss, metrics)``, where ``metrics`` maps names to detached CPU tensors.
        """
        metrics = {}

        forward_output = self.concatenated_forward(model, batch, which_model="policy")
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_nll_loss,
            policy_chosen_rewards,
            policy_rejected_rewards,
        ) = forward_output[:7]
        if self.aux_loss_enabled:
            # concatenated_forward appends the router aux loss after the seven standard outputs.
            aux_loss = forward_output[7]

        # The teacher pass is always needed: its direct rewards cannot come from the batch.
        with torch.no_grad():
            if self.ref_model is None:
                # No separate teacher: run the policy itself inside null_ref_context().
                with self.null_ref_context():
                    reference_output = self.concatenated_forward(self.model, batch, which_model="teacher")
            else:
                reference_output = self.concatenated_forward(self.ref_model, batch, which_model="teacher")
        (
            reference_chosen_logps,
            reference_rejected_logps,
            _,
            _,
            _,
            reference_chosen_rewards,
            reference_rejected_rewards,
        ) = reference_output[:7]

        # Prefer precomputed reference log-probs when the batch provides them.
        if (
            "reference_chosen_logps" in batch
            and "reference_rejected_logps" in batch
            and self.args.rpo_alpha is not None
        ):
            reference_chosen_logps = batch["reference_chosen_logps"]
            reference_rejected_logps = batch["reference_rejected_logps"]

        (
            losses,
            chosen_rewards,
            rejected_rewards,
            policy_chosen_direct_rewards,
            policy_rejected_direct_rewards,
            reference_chosen_direct_rewards,
            reference_rejected_direct_rewards,
            ull_logits,
        ) = self.dro_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
            policy_chosen_rewards,
            policy_rejected_rewards,
            reference_chosen_rewards,
            reference_rejected_rewards,
        )

        if self.args.rpo_alpha is not None:
            losses = losses * self.args.rpo_alpha + policy_nll_loss

        # Everything below is logging only. Detach so the stored metrics do not keep the
        # autograd graph alive until the next log step.
        chosen_rewards = chosen_rewards.detach()
        rejected_rewards = rejected_rewards.detach()
        student_chosen = policy_chosen_direct_rewards.detach()
        student_rejected = policy_rejected_direct_rewards.detach()
        teacher_chosen = reference_chosen_direct_rewards.detach()
        teacher_rejected = reference_rejected_direct_rewards.detach()
        nll = policy_nll_loss.detach()

        reward_accuracies = (chosen_rewards > rejected_rewards).float()
        direct_reward_accuracies = (student_chosen > student_rejected).float()
        direct_teacher_reward_accuracies = (teacher_chosen > teacher_rejected).float()
        student_margins = student_chosen - student_rejected
        teacher_margins = teacher_chosen - teacher_rejected
        student_teacher_reward_diff_scaled = self.reward_scaling_factor_alpha * (teacher_margins - student_margins) ** 2

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
        # direct (distilled) rewards of the student and the teacher
        metrics[f"{prefix}policy_nll_loss"] = nll.mean().cpu()
        metrics[f"{prefix}direct_reward_diff_scaled"] = student_teacher_reward_diff_scaled.mean().cpu()
        metrics[f"{prefix}ull_logits"] = ull_logits.detach().mean().cpu()
        metrics[f"{prefix}directrewards/chosen"] = student_chosen.mean().cpu()
        metrics[f"{prefix}directrewards/rejected"] = student_rejected.mean().cpu()
        metrics[f"{prefix}directrewards_student/accuracies"] = direct_reward_accuracies.mean().cpu()
        metrics[f"{prefix}directrewards_teacher/accuracies"] = direct_teacher_reward_accuracies.mean().cpu()
        metrics[f"{prefix}directrewards_student/margins"] = student_margins.mean().cpu()
        metrics[f"{prefix}directrewards_teacher/margins"] = teacher_margins.mean().cpu()

        print(f"{prefix}policy_nll_loss: {nll.mean().cpu().item()}")
        print(f"{prefix}directrewards/chosen: {student_chosen.mean().cpu().item()}")
        print(f"{prefix}directrewards/rejected: {student_rejected.mean().cpu().item()}")
        print(f"{prefix}directrewards_student/accuracies: {direct_reward_accuracies.mean().cpu().item()}")
        print(f"{prefix}directrewards_teacher/accuracies: {direct_teacher_reward_accuracies.mean().cpu().item()}")
        print(f"{prefix}directrewards_student/margins: {student_margins.mean().cpu().item()}")
        print(f"{prefix}directrewards_teacher/margins: {teacher_margins.mean().cpu().item()}")

        if self.args.rpo_alpha is not None:
            metrics[f"{prefix}nll_loss"] = nll.mean().cpu()

        if self.aux_loss_enabled:
            return losses.mean() + getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss, metrics

        return losses.mean(), metrics

    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
        num_items_in_batch=None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Compute the training loss for one batch and record its metrics under ``"train"``.

        Args:
            model: the policy being trained.
            inputs: collated batch, forwarded to ``get_batch_loss_metrics``.
            return_outputs: if True, return ``(loss, metrics)`` instead of only ``loss``.
            num_items_in_batch: accepted so ``Trainer`` versions that pass this keyword can call
                this override; it is not used here.

        Returns:
            The loss (moved to ``self.args.device``), or ``(loss, metrics)`` when ``return_outputs``.
        """
        if not self.use_dpo_data_collator:
            warnings.warn(
                "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
            )

        # Autocast only when PEFT weights were cast to bf16.
        compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

        with compute_loss_context_manager():
            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

        # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
        loss = loss.to(self.args.device)
        # force log the metrics
        self.store_metrics(metrics, train_eval="train")

        if return_outputs:
            return (loss, metrics)
        return loss

    def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[List[str], List[str]]:
        """Sample completions from the policy and the reference model for the batch's prompts.

        All generation calls share one set of decoding settings, so policy and reference outputs
        are comparable: sampling with temperature 0.7, top-k 50, top-p 0.92, repetition penalty
        1.5, and no repeated 3-grams. If ``batch["reference_output"]`` is present it is used
        instead of generating a reference.

        Args:
            model: the policy to sample from.
            batch: collated batch with ``prompt_input_ids`` and ``prompt_attention_mask``.

        Returns:
            ``(policy_texts, reference_texts)``: outputs right-padded to ``self.max_length`` and
            decoded with special tokens skipped.
        """

        def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor:
            # Right-pad ``tensor`` along ``dim`` up to ``length`` (no-op when already long enough).
            if tensor.size(dim) >= length:
                return tensor
            pad_size = list(tensor.shape)
            pad_size[dim] = length - tensor.size(dim)
            return torch.cat(
                [tensor, pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device)],
                dim=dim,
            )

        generation_kwargs = dict(
            input_ids=batch["prompt_input_ids"],
            attention_mask=batch["prompt_attention_mask"],
            max_length=self.max_length,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.92,
            repetition_penalty=1.5,
            no_repeat_ngram_size=3,
            pad_token_id=self.tokenizer.pad_token_id,
        )

        # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
        # the torch cuda amp context manager as some hidden states are silently casted to full precision.
        generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast

        with generate_context_manager():
            policy_output = model.generate(**generation_kwargs)

            # if reference_output in batch use that otherwise use the reference model
            if "reference_output" in batch:
                reference_output = batch["reference_output"]
            elif self.ref_model is None:
                with self.null_ref_context():
                    reference_output = self.model.generate(**generation_kwargs)
            else:
                reference_output = self.ref_model.generate(**generation_kwargs)

        policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
        policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)

        reference_output = pad_to_length(reference_output, self.max_length, self.tokenizer.pad_token_id)
        reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True)

        return policy_output_decoded, reference_output_decoded

    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ):
        """Evaluate one batch for ``Trainer``'s prediction loop.

        Computes the loss and metrics without gradients and stores the metrics under ``"eval"``.
        Unless ``prediction_loss_only`` is set, it also returns pseudo-logits: the mean chosen and
        rejected policy logits that are not listed in ``ignore_keys``. These come with all-zero
        labels.

        Returns:
            ``(loss, None, None)`` or ``(loss, logits, labels)``.
        """
        if not self.use_dpo_data_collator:
            warnings.warn(
                "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
            )

        if ignore_keys is None:
            ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) if hasattr(model, "config") else []

        amp_context = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
        with torch.no_grad(), amp_context():
            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

        # record the eval metrics for later logging
        self.store_metrics(metrics, train_eval="eval")

        detached_loss = loss.detach()
        if prediction_loss_only:
            return (detached_loss, None, None)

        candidates = {
            "eval_logits/chosen": metrics["eval_logits/chosen"],
            "eval_logits/rejected": metrics["eval_logits/rejected"],
        }
        kept = []
        for name, value in candidates.items():
            if name in ignore_keys:
                continue
            kept.append(value.unsqueeze(dim=0))

        device = self.accelerator.device
        logits = torch.stack(tuple(kept)).mean(axis=1).to(device)
        labels = torch.zeros(logits.shape[0], device=device)
        return (detached_loss, logits, labels)

    def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
        """Append every metric value to its history list in ``self._stored_metrics[train_eval]``."""
        if not metrics:
            return None
        split_history = self._stored_metrics[train_eval]
        for metric_name in metrics:
            split_history[metric_name].append(metrics[metric_name])

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """Optionally log generation-quality metrics, then run the standard ``Trainer`` evaluation loop.

        Generation runs when ``self.generate_during_eval`` is set and the global step is either a
        multiple of a quarter of ``args.max_steps`` or equal to ``max_steps``. In that case every
        eval example is generated with both the policy and the reference model, in eval-batch-sized
        chunks, and compared against the chosen response. The following are logged via
        ``self.log``:

        - average token counts
        - ROUGE, BLEU, METEOR, and matched-pair embedding cosine similarity
        - the raw generations

        Afterwards ``Trainer.evaluation_loop`` is called unchanged and its output returned.
        """
        # max(..., 1) avoids a modulo-by-zero for 0 <= max_steps < 4; with max_steps <= 0
        # (e.g. -1) every evaluation triggers generation.
        quarter_steps = max(self.args.max_steps // 4, 1)
        batch_policy_token_count_list = []
        batch_chosen_token_count_list = []
        batch_rejected_token_count_list = []
        batch_reference_token_count_list = []

        run_generation = self.generate_during_eval and (
            self.state.global_step % quarter_steps == 0 or self.state.global_step == self.args.max_steps
        )
        if run_generation:
            # Generation covers the whole eval dataset.
            num_samples = len(dataloader.dataset)
            selected_dataset = dataloader.dataset.select(list(range(num_samples)))
            print("size of selected eval dataset for automatic metrics:", len(selected_dataset))
            batch_size = self.args.eval_batch_size

            batch_rouge_scores_list = []
            batch_bleu_scores_list = []
            batch_meteor_scores_list = []
            batch_semantic_similarity_list = []

            batch_rouge_scores_list_ref = []
            batch_bleu_scores_list_ref = []
            batch_meteor_scores_list_ref = []
            batch_semantic_similarity_list_ref = []

            for i in range(0, len(selected_dataset), batch_size):
                batch_indices = list(range(i, min(i + batch_size, len(selected_dataset))))
                batch = self.data_collator(selected_dataset.select(batch_indices))
                batch = self._prepare_inputs(batch)

                print("getting batch generation inference", i)

                policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, batch)
                # Drop the echoed prompt by character count (assumes decoding reproduces the prompt text).
                policy_output_decoded = [pol[len(prompt):] for prompt, pol in zip(batch["prompt"], policy_output_decoded)]
                ref_output_decoded = [ref[len(prompt):] for prompt, ref in zip(batch["prompt"], ref_output_decoded)]

                # Map ignored label positions to the pad id so the labels can be decoded.
                chosen_labels = batch["chosen_labels"].cpu().numpy()
                rejected_labels = batch["rejected_labels"].cpu().numpy()
                chosen_labels = np.where(chosen_labels != self.label_pad_token_id, chosen_labels, self.tokenizer.pad_token_id)
                rejected_labels = np.where(rejected_labels != self.label_pad_token_id, rejected_labels, self.tokenizer.pad_token_id)
                chosen_output_decoded = self.tokenizer.batch_decode(chosen_labels, skip_special_tokens=True)
                rejected_output_decoded = self.tokenizer.batch_decode(rejected_labels, skip_special_tokens=True)

                batch_policy_token_count_list.append(np.mean([len(self.tokenizer.tokenize(p)) for p in policy_output_decoded]))
                batch_reference_token_count_list.append(np.mean([len(self.tokenizer.tokenize(p)) for p in ref_output_decoded]))
                batch_chosen_token_count_list.append(np.mean([len(self.tokenizer.tokenize(c)) for c in chosen_output_decoded]))
                batch_rejected_token_count_list.append(np.mean([len(self.tokenizer.tokenize(r)) for r in rejected_output_decoded]))

                # ROUGE(-Lsum) expects a newline after each sentence.
                decoded_preds = [self._sentences_per_line(pred) for pred in policy_output_decoded]
                decoded_ref_preds = [self._sentences_per_line(pred) for pred in ref_output_decoded]
                decoded_chosen_labels = [self._sentences_per_line(label) for label in chosen_output_decoded]
                target_text = "\n".join(decoded_chosen_labels)

                # ROUGE F-measures (x100, rounded to 4 digits) for policy and reference predictions
                rouge_scores = self.scorer.score(target_text, "\n".join(decoded_preds))
                batch_rouge_scores_list.append({k: round(v.fmeasure * 100, 4) for k, v in rouge_scores.items()})
                rouge_scores_ref = self.scorer.score(target_text, "\n".join(decoded_ref_preds))
                batch_rouge_scores_list_ref.append({k: round(v.fmeasure * 100, 4) for k, v in rouge_scores_ref.items()})

                # BLEU for policy, then reference predictions
                self.bleu_metric.add_batch(
                    predictions=[pred.split() for pred in decoded_preds],
                    references=[[label.split()] for label in decoded_chosen_labels],
                )
                batch_bleu_scores_list.append(self.bleu_metric.compute())
                self.bleu_metric.add_batch(
                    predictions=[pred.split() for pred in decoded_ref_preds],
                    references=[[label.split()] for label in decoded_chosen_labels],
                )
                batch_bleu_scores_list_ref.append(self.bleu_metric.compute())

                # METEOR for policy, then reference predictions
                for pred, ref in zip(decoded_preds, decoded_chosen_labels):
                    self.meteor_metric.add(prediction=pred, reference=ref)
                batch_meteor_scores_list.append(self.meteor_metric.compute())
                for pred, ref in zip(decoded_ref_preds, decoded_chosen_labels):
                    self.meteor_metric.add(prediction=pred, reference=ref)
                batch_meteor_scores_list_ref.append(self.meteor_metric.compute())

                # Matched-pair cosine similarity: prediction k vs. its own chosen response k
                # (the diagonal of the pairwise similarity matrix).
                embeddings1 = self.sentence_transformer_model.encode(decoded_preds, convert_to_tensor=True)
                embeddings2 = self.sentence_transformer_model.encode(decoded_chosen_labels, convert_to_tensor=True)
                batch_semantic_similarity_list.append(
                    util.pytorch_cos_sim(embeddings1, embeddings2).diagonal().mean().item()
                )
                embeddings1_ref = self.sentence_transformer_model.encode(decoded_ref_preds, convert_to_tensor=True)
                batch_semantic_similarity_list_ref.append(
                    util.pytorch_cos_sim(embeddings1_ref, embeddings2).diagonal().mean().item()
                )

                df_log = pd.DataFrame({
                    "Prompt": batch["prompt"],
                    "Policy": list(policy_output_decoded),
                    "Reference": list(ref_output_decoded),
                    "Chosen Output": chosen_output_decoded,
                    "Rejected Output": rejected_output_decoded,
                })
                # Log each column of the DataFrame using `self.log`
                for column, values in df_log.to_dict("list").items():
                    self.log({f"{column}_batch_{i//batch_size + 1}": values})

            if batch_policy_token_count_list:
                self.log({
                    "avg_policy_tokens": round(np.mean(batch_policy_token_count_list), 2),
                    "avg_reference_tokens": round(np.mean(batch_reference_token_count_list), 2),
                    "avg_chosen_tokens": round(np.mean(batch_chosen_token_count_list), 2),
                    "avg_rejected_tokens": round(np.mean(batch_rejected_token_count_list), 2),
                })

            # Dataset-level averages of the per-batch scores (policy and reference)
            if batch_rouge_scores_list:
                self.log({"overall_rouge_scores": self._average_metric_dicts(batch_rouge_scores_list)})
            if batch_rouge_scores_list_ref:
                self.log({"overall_rouge_scores_ref": self._average_metric_dicts(batch_rouge_scores_list_ref)})
            if batch_bleu_scores_list:
                self.log({"overall_bleu_scores": self._average_metric_dicts(batch_bleu_scores_list)})
            if batch_bleu_scores_list_ref:
                self.log({"overall_bleu_scores_ref": self._average_metric_dicts(batch_bleu_scores_list_ref)})
            if batch_meteor_scores_list:
                self.log({"overall_meteor_scores": self._average_metric_dicts(batch_meteor_scores_list)})
            if batch_meteor_scores_list_ref:
                self.log({"overall_meteor_scores_ref": self._average_metric_dicts(batch_meteor_scores_list_ref)})
            if batch_semantic_similarity_list:
                overall_semantic_similarity = round(sum(batch_semantic_similarity_list) / len(batch_semantic_similarity_list), 2)
                self.log({"overall_semantic_similarity": overall_semantic_similarity})
            if batch_semantic_similarity_list_ref:
                overall_semantic_similarity_ref = round(sum(batch_semantic_similarity_list_ref) / len(batch_semantic_similarity_list_ref), 2)
                self.log({"overall_semantic_similarity_ref": overall_semantic_similarity_ref})

        # Base evaluation
        return super().evaluation_loop(
            dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
        )

    @staticmethod
    def _sentences_per_line(text: str) -> str:
        """Strip ``text`` and put a newline after each sentence-ending ``.``, ``!`` or ``?``.

        Joining a string directly with newlines would split it into single characters, so the
        text is split on whitespace that follows sentence punctuation instead.
        """
        return re.sub(r"(?<=[.!?])\s+", "\n", text.strip())

    @staticmethod
    def _average_metric_dicts(score_dicts: List[Dict[str, Any]], ndigits: int = 2) -> Dict[str, Any]:
        """Average a non-empty list of per-batch metric dicts key by key (keys from the first dict).

        Scalar values are averaged directly. Non-scalar values (e.g. a list of BLEU n-gram
        precisions) are first reduced with ``np.mean``. Results are rounded to ``ndigits``.
        """
        return {
            key: round(
                sum(
                    scores[key] if isinstance(scores[key], (int, float)) else np.mean(scores[key])
                    for scores in score_dicts
                ) / len(score_dicts),
                ndigits,
            )
            for key in score_dicts[0].keys()
        }

def plot_all_metrics(data, output_folder):
    """Plot every logged metric against training step and save the figures as PNGs.

    For each experiment in ``data``, ``experiment_data[0]`` is read as a list of
    log-entry dicts. Entries without a ``'step'`` key are ignored. Numeric
    (``int``/``float``) top-level values become metrics named by their key. Values
    that are dicts are flattened into ``"<key>/<sub_key>"`` metrics.

    Two kinds of files are written to ``output_folder``, which is created if missing:
    * one ``<metric>_plot.png`` per metric, with ``/`` replaced by ``_``;
    * a combined two-column grid, ``final_overall_metrics_plot.png``, which is also
      shown via ``plt.show()``.

    Each experiment gets one fixed color/marker, so a run looks the same in every figure.

    Args:
        data: Mapping from experiment key to a sequence whose first element is that
            experiment's list of log entries. Experiments with an empty sequence are
            skipped.
        output_folder: Directory the PNG files are written to.
    """
    os.makedirs(output_folder, exist_ok=True)

    # metric name -> experiment key -> {'steps': [...], 'values': [...]}
    metrics_dict = defaultdict(dict)

    def _append_point(metric_name, experiment_key, step, value):
        # Append one (step, value) point to the series of this metric/experiment.
        series = metrics_dict[metric_name].setdefault(
            experiment_key, {'steps': [], 'values': []}
        )
        series['steps'].append(step)
        series['values'].append(value)

    for experiment_key, experiment_data in data.items():
        if not experiment_data:
            continue  # no log history recorded for this experiment
        for entry in experiment_data[0]:
            step = entry.get('step')
            if step is None:
                continue
            for key, value in entry.items():
                if isinstance(value, dict):  # flatten nested dicts into "key/sub_key"
                    for sub_key, sub_value in value.items():
                        _append_point(f"{key}/{sub_key}", experiment_key, step, sub_value)
                elif key != 'step' and isinstance(value, (int, float)):
                    _append_point(key, experiment_key, step, value)

    # `plt.get_cmap` is used because `matplotlib.cm.get_cmap` (plt.cm.get_cmap) was
    # removed in Matplotlib 3.9.
    colors = plt.get_cmap('tab20').colors  # 20 distinct colors
    markers = ['o', 's', 'D', '^', 'v', '>', '<', 'p', 'h', 'H', '*', 'x', 'X', 'd', '|', '_', '+', 'P', '1', '2']  # 20 distinct markers
    # A fixed style per experiment keeps legends consistent across all figures.
    styles = {
        key: (colors[i % len(colors)], markers[i % len(markers)])
        for i, key in enumerate(data)
    }

    def _draw_series(ax, experiments):
        # Draw one line per experiment for a single metric on `ax`.
        for experiment_key, metric_data in experiments.items():
            color, marker = styles[experiment_key]
            ax.plot(
                metric_data['steps'],
                metric_data['values'],
                marker=marker,
                label=str(experiment_key),
                color=color,
            )

    # One standalone figure per metric.
    for metric_name, experiments in metrics_dict.items():
        fig, ax = plt.subplots(figsize=(10, 6))
        _draw_series(ax, experiments)
        ax.set_xlabel('Step')
        ax.set_ylabel(metric_name)
        ax.grid(True)
        ax.legend(loc='best', fontsize='small')
        fig.tight_layout()
        fig.savefig(os.path.join(output_folder, f'{metric_name.replace("/", "_")}_plot.png'), dpi=300)
        plt.close(fig)

    num_metrics = len(metrics_dict)
    if num_metrics == 0:
        return  # nothing to combine; plt.subplots(0, ...) would raise

    # Combined grid of all metrics.
    num_cols = 2
    num_rows = (num_metrics + num_cols - 1) // num_cols  # ceiling division
    # squeeze=False keeps `axs` 2-D even with a single row, so indexing is uniform.
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(15, 5 * num_rows), squeeze=False)
    fig.suptitle('Metrics vs. Steps')
    flat_axes = axs.ravel()  # row-major: same placement as row = i // cols, col = i % cols

    for ax, (metric_name, experiments) in zip(flat_axes, metrics_dict.items()):
        _draw_series(ax, experiments)
        ax.set_xlabel('Step')
        ax.set_ylabel(metric_name)
        ax.grid(True)
        ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2), fontsize='small', ncol=2)

    # Hide the grid cells left over when the metric count does not fill the last row.
    for ax in flat_axes[num_metrics:]:
        ax.axis('off')

    fig.tight_layout(rect=[0, 0, 1, 0.96])
    fig.savefig(os.path.join(output_folder, 'final_overall_metrics_plot.png'), dpi=300)
    plt.show()
    plt.close(fig)  # release the combined figure once shown/saved


class EvaluateFirstStepCallback(TrainerCallback):
    """Trainer callback that raises the evaluation flag as soon as training begins.

    It mutates the ``control`` object it receives in place and returns nothing.

    NOTE(review): the Hugging Face callback handler appears to clear
    ``control.should_evaluate`` at the start of every step. If so, this flag has
    no effect on its own. The driver also calls ``trainer.evaluate()`` before
    ``train()``. Confirm whether this callback is still needed.
    """

    def on_train_begin(self, args, state, control, **_unused_kwargs):
        """Request an evaluation by setting ``control.should_evaluate``."""
        control.should_evaluate = True
        return None

def get_ultrafeedback_binarized_paired(
    split,
    sanity_check: bool = False,
    cache_dir: Optional[str] = None,
    num_proc=24,
    tokenizer=None,
):
    """Load a paired-preference dataset and render ``chosen``/``rejected`` with a chat template.

    Args:
        split: Split name passed to ``load_dataset``.
        sanity_check: If True, keep a random (unseeded) subset of at most 1000 rows.
        cache_dir: Cache directory forwarded to ``load_dataset``. The original
            implementation accepted this argument but never used it.
        num_proc: Number of worker processes used by ``Dataset.map``.
        tokenizer: Object providing ``apply_chat_template``. If omitted, the
            module-level ``tokenizer`` (defined by the ``__main__`` block) is used.
            That global is what the original implementation always relied on.

    Returns:
        The mapped dataset. Only ``prompt``, ``chosen`` and ``rejected`` are kept.
        ``chosen``/``rejected`` become template-rendered strings.

    Raises:
        ValueError: If no tokenizer is passed and no module-level ``tokenizer`` exists.
    """
    if tokenizer is None:
        tokenizer = globals().get("tokenizer")
        if tokenizer is None:
            raise ValueError(
                "get_ultrafeedback_binarized_paired needs a tokenizer: pass `tokenizer=` "
                "or define a module-level `tokenizer`."
            )

    # Load dataset from the hub. NOTE(review): "##" looks like a redacted placeholder id.
    dataset = load_dataset("##", split=split, cache_dir=cache_dir)
    print(dataset)
    if sanity_check:
        # Guard against splits with fewer rows than the sanity subset size.
        dataset = dataset.shuffle().select(range(min(1000, len(dataset))))

    original_columns = dataset.column_names
    selected_columns = ['prompt', 'chosen', 'rejected']
    removable_columns = [x for x in original_columns if x not in selected_columns]

    def return_prompt_and_responses(row):
        # Render each side of the pair as text (no tokenization, no generation prompt).
        row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
        row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
        return row

    return dataset.map(
        return_prompt_and_responses,
        num_proc=num_proc,
        remove_columns=removable_columns,
    )

def get_ultrafeedback_binarized_paired_new(dataset, tokenizer, split='train', num_proc=24):
    """Render the ``chosen`` and ``rejected`` conversations of ``dataset`` as chat-template text.

    Every column except ``prompt``, ``chosen`` and ``rejected`` is dropped. Each
    conversation is rendered with ``tokenizer.apply_chat_template(...,
    tokenize=False, add_generation_prompt=True)``, so the result is text, not token ids.

    NOTE(review): ``add_generation_prompt=True`` asks the template to append its
    assistant-turn prefix after the last message. Confirm this is intended for
    complete chosen/rejected conversations.

    Args:
        dataset: Hugging Face ``Dataset`` (anything with ``column_names`` and ``map``).
        tokenizer: Object exposing ``apply_chat_template``.
        split: Split label; not used by this function.
        num_proc: Worker processes forwarded to ``dataset.map`` (default 24).

    Returns:
        Whatever ``dataset.map`` returns for the rendered rows.
    """
    kept_columns = {'prompt', 'chosen', 'rejected'}
    drop_columns = [name for name in dataset.column_names if name not in kept_columns]

    def _render_pair(example):
        # Replace each side of the preference pair with its rendered chat text.
        for side in ('chosen', 'rejected'):
            example[side] = tokenizer.apply_chat_template(
                example[side], tokenize=False, add_generation_prompt=True
            )
        return example

    return dataset.map(_render_pair, num_proc=num_proc, remove_columns=drop_columns)

if __name__ == "__main__":

    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]
    # Set seed for reproducibility
    set_seed(script_args.seed)

    # Map the requested precision onto a torch dtype. Unknown values fail fast here,
    # instead of surfacing later as a NameError on `torch_dtype`.
    _dtype_by_name = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    if script_args.model_dtype not in _dtype_by_name:
        raise ValueError(
            f"Unsupported model_dtype {script_args.model_dtype!r}; "
            f"expected one of {sorted(_dtype_by_name)}"
        )
    torch_dtype = _dtype_by_name[script_args.model_dtype]

    # Initialize the accelerator
    accelerator = Accelerator()
    tokenizer = AutoTokenizer.from_pretrained(script_args.student_model_name_or_path)
    print("tokenizer when first loadoing", tokenizer)
    tokenizer.model_max_length = 6000  # use this for phi 3 instruct models

    # Pad with unk rather than eos to prevent endless generation. Fall back to eos
    # when the tokenizer has no unk token. The fallback must come AFTER the unk
    # assignment; previously a second unk assignment overwrote it and left
    # pad_token as None.
    tokenizer.pad_token = tokenizer.unk_token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    tokenizer.padding_side = 'right'

    if tokenizer.chat_template is None:
        # Minimal "role: content" template for tokenizers that ship without one.
        tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"

    def load_and_initialize_rm_llm(script_args, s_model):
        """Load the student and teacher causal LMs, each carrying a scalar ``classification_head``.

        Both models and their heads use the script-level ``torch_dtype``.

        Args:
            script_args: Parsed ``ScriptArguments``. Reads ``teacher_model_name_or_path``,
                ``trainer_teacher_rm`` and ``ignore_bias_buffers``.
                ``student_model_name_or_path`` is only printed.
            s_model: Name or path of the student model to load.

        Returns:
            ``(student_model, student_tokenizer, teacher_model, teacher_tokenizer)``.
        """

        def initialize_model_with_classification_head(base_model_name, torch_dtype, model_type = "student"):
            """
            Start with a causal language model (CLM) and add a classification head (for scalar outputs).

            The head is a freshly initialised ``nn.Linear(in_features, 1)`` cast to ``torch_dtype``.
            For ``model_type="student"``, ``in_features`` is ``model.lm_head.in_features``.
            For ``model_type="teacher"``, it is ``model.config.hidden_size``.

            Returns:
                ``(model, tokenizer)`` loaded from ``base_model_name``.

            Raises:
                ValueError: If ``model_type`` is neither ``"student"`` nor ``"teacher"``
                    (previously this left the head unbound and crashed later).
            """
            if model_type not in ("student", "teacher"):
                raise ValueError(f"model_type must be 'student' or 'teacher', got {model_type!r}")

            model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch_dtype, trust_remote_code=True, attn_implementation="flash_attention_2")
            tokenizer = AutoTokenizer.from_pretrained(base_model_name)
            # Unwrap any (Distributed)DataParallel wrappers to reach the bare model.
            while isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
                model = model.module
            if model_type == "student":
                lm_head_input_size = model.lm_head.in_features
                print("Input size of lm_head:", lm_head_input_size)
                classification_head = nn.Linear(lm_head_input_size, 1)
            else:  # teacher
                print("teacher hidden", model.config.hidden_size)
                classification_head = nn.Linear(model.config.hidden_size, 1)
            model.classification_head = classification_head.to(torch_dtype)
            return model, tokenizer

        def load_teacher_model_and_tokenizer(script_args, torch_dtype):
            """Load a teacher reward model whose trained ``classification_head`` is stored in safetensors.

            ``teacher_model_name_or_path`` must be a local directory containing
            ``model.safetensors.index.json``. Its ``weight_map`` is used to locate the
            shard(s) holding ``classification_head.weight`` and ``classification_head.bias``.
            Those tensors are read on CPU and installed as the model's head, cast to
            ``torch_dtype``.

            Returns:
                ``(model, tokenizer)``.
            """
            import json  # local import: reads the safetensors index file

            # Load the main model
            model = AutoModelForCausalLM.from_pretrained(
                script_args.teacher_model_name_or_path,
                torch_dtype=torch_dtype,
                trust_remote_code=True
            )
            tokenizer = AutoTokenizer.from_pretrained(script_args.teacher_model_name_or_path)
            print("Running training with script_args.trainer_teacher_rm", script_args.trainer_teacher_rm)
            model_dir = script_args.teacher_model_name_or_path
            index_file_path = os.path.join(model_dir, "model.safetensors.index.json")

            with open(index_file_path, "r") as f:
                index_data = json.load(f)
            safetensors_path = os.path.join(model_dir, index_data['weight_map']["classification_head.weight"])
            safetensors_bias_path = os.path.join(model_dir, index_data['weight_map']["classification_head.bias"])
            with safe_open(safetensors_path, framework="pt", device="cpu") as f:
                classification_weight = f.get_tensor("classification_head.weight")

            with safe_open(safetensors_bias_path, framework="pt", device="cpu") as f:
                classification_bias = f.get_tensor("classification_head.bias")

            if not hasattr(model, "classification_head"):
                # nn.Linear(in_features, out_features); weight is stored as (out, in).
                model.classification_head = torch.nn.Linear(
                    classification_weight.shape[1],
                    classification_weight.shape[0]
                )
            model.classification_head.weight = torch.nn.Parameter(classification_weight)
            model.classification_head.bias = torch.nn.Parameter(classification_bias)

            model.classification_head = model.classification_head.to(torch_dtype)

            return model, tokenizer

        print("Student model path:", script_args.student_model_name_or_path)
        print("Teacher model path:", script_args.teacher_model_name_or_path)
        student_model, student_tokenizer = initialize_model_with_classification_head(s_model, torch_dtype, model_type = "student")

        if script_args.trainer_teacher_rm:
            # Teacher head comes from a previously trained reward model checkpoint.
            teacher_model, teacher_tokenizer = load_teacher_model_and_tokenizer(script_args, torch_dtype)
        else:
            # Teacher gets a freshly initialised head.
            teacher_model, teacher_tokenizer = initialize_model_with_classification_head(script_args.teacher_model_name_or_path, torch_dtype, model_type = "teacher")

        if script_args.ignore_bias_buffers:
            # torch distributed hack: list bool buffers under the private attribute
            # DistributedDataParallel consults for params/buffers to skip. Applied to
            # the trained (student) model; previously this referenced an undefined `model`.
            student_model._ddp_params_and_buffers_to_ignore = [
                name for name, buffer in student_model.named_buffers() if buffer.dtype == torch.bool
            ]
        return student_model, student_tokenizer, teacher_model, teacher_tokenizer

    experiment_logs = defaultdict(list)

    # NOTE(review): s_model, lr, loss_type, beta, kl_beta, delta, train_data and
    # validation_data are not defined anywhere in this __main__ block. They
    # presumably come from module-level globals elsewhere in the file (or a
    # removed hyper-parameter sweep loop). Confirm before running.

    ######## MAIN DRDO TRAINING STARTS HERE AFTER LOADING THE STUDENT AND TEACHER (ORACLE) MODELS
    student_model, student_tokenizer, teacher_model, teacher_tokenizer = load_and_initialize_rm_llm(script_args, s_model)
    script_args.learning_rate = lr
    train_dataset, eval_dataset = get_dataset(script_args, train_data, validation_data, type = loss_type)

    # 4. initialize training arguments:
    training_args = DROConfig(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        per_device_eval_batch_size=script_args.per_device_eval_batch_size,
        num_train_epochs = 1,
        max_steps=script_args.max_steps,
        logging_steps=script_args.logging_steps,
        save_steps=script_args.save_steps,
        save_strategy=script_args.save_strategy,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        gradient_checkpointing=script_args.gradient_checkpointing,
        learning_rate=script_args.learning_rate,
        evaluation_strategy="steps",
        eval_steps=script_args.eval_steps,
        logging_strategy='steps',
        logging_first_step = True,
        output_dir=script_args.output_dir,
        report_to=script_args.report_to,
        lr_scheduler_type=script_args.lr_scheduler_type,
        warmup_steps=script_args.warmup_steps,
        optim=script_args.optimizer_type,
        bf16=True,
        remove_unused_columns=False,
        run_name="drdo_tldr",
        gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant),
        seed=script_args.seed,
        generate_during_eval = True,
        loss_type = loss_type
    )
    training_args.beta = beta
    # Deliberately overrides the `generate_during_eval=True` passed to the constructor above.
    training_args.generate_during_eval = False
    training_args.loss_type = loss_type  #checking all loss functions now, sigmoid, bcm robust, sigmoid is still working with direct rewards being scaled,
    training_args.reward_scaling_factor_alpha = delta
    training_args.ull_scaling_factor = 0.01 #is the alpha parameter in DRDO loss in the log-unlikelihood focal component

    print("printing  all hyperparamters: loss, reward alpha, KL beta and LR",training_args.loss_type, training_args.reward_scaling_factor_alpha,
    training_args.beta,training_args.learning_rate, script_args.learning_rate)
    print("Generate during training ",training_args.generate_during_eval )
    # Built but currently unused: `peft_config=` is commented out in the trainer call below.
    peft_config = LoraConfig(
        r=script_args.lora_r,
        lora_alpha=script_args.lora_alpha,
        lora_dropout=script_args.lora_dropout,
        target_modules=[
            "q_proj",
            "v_proj",
            "k_proj",
            "out_proj",
            "fc_in",
            "fc_out",
            "wte",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )
    drdo_trainer = DROTrainer(
        model = student_model,
        ref_model=teacher_model,
        args=training_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        # peft_config=peft_config,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,
    )
    drdo_trainer.loss_type = loss_type

    drdo_trainer.add_callback(EvaluateFirstStepCallback())
    print(f"Getting experiment: {s_model}, {loss_type}, {kl_beta}, {delta}, {lr}")
    # Explicit baseline evaluation before any training step.
    drdo_trainer.evaluate()
    drdo_trainer.train()
    output_dir = os.path.join(script_args.output_dir, script_args.experiment_run_name)
    os.makedirs(output_dir, exist_ok=True)
    # Save the model and tokenizer (was `dro_trainer`, an undefined name -> NameError).
    drdo_trainer.model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    # Log keys intended for plotting. NOTE(review): not consumed within this block.
    metrics_to_plot = [
        'eval_rewards/chosen',
        'eval_rewards/rejected',
        'rewards/accuracies',
        'eval_rewards/accuracies',
        'rewards/margins',
        'eval_rewards/margins',
        'logps/chosen',
        'eval_logps/chosen',
        'logps/rejected',
        'eval_logps/rejected',
        'logits/chosen',
        'eval_logits/chosen',
        'logits/rejected',
        'eval_logits/rejected',
        'train_loss',
        'nll_loss',
        'log_odds_loss',
        'simpo_loss',
        'ull_loss',
        'eval_total_loss',
        'eval_nll_loss',
        'eval_log_odds_loss',
        'eval_simpo_loss',
        'eval_ull_loss',
        'eval_rouge1',
        'eval_rouge2',
        'eval_rougeL','eval_bleu'
    ]
 
        


