import os
import torch
import torch.nn.functional as F
import inspect
from sklearn.model_selection import train_test_split

import warnings
from collections import defaultdict
from dataclasses import FrozenInstanceError, replace, dataclass, field, asdict
from datasets import load_from_disk
from datasets import concatenate_datasets
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import pandas as pd
import torch.nn as nn
from accelerate.utils import gather_object
from datasets import Dataset,load_dataset
from torch.utils.data import DataLoader

from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_pt_utils import nested_detach
from transformers.trainer_utils import EvalPrediction
import sys 
from trl.import_utils import is_peft_available
from trl import RewardConfig, RewardTrainer
from trl.trainer.utils import RewardDataCollatorWithPadding, compute_accuracy, print_rich_table
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
if is_peft_available():
    from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForSequenceClassification, AutoModelForCausalLM, AutoTokenizer,  HfArgumentParser, set_seed
import warnings
from accelerate import PartialState
from tqdm import tqdm
from evaluate import load
from trl import ModelConfig, RewardConfig, RewardTrainer, get_kbit_device_map, get_peft_config, get_quantization_config
import evaluate
import numpy as np
tqdm.pandas()


@dataclass
class ScriptArguments:
    """
    The arguments for the DPO training script.

    Every field carries a ``help`` entry in its metadata describing its purpose.
    The fields are grouped into data parameters, optimisation / training
    parameters, LoRA and dataset settings, logging / checkpointing settings,
    model loading options, and instrumentation / debugging switches.
    """

    # data parameters
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

    # training parameters
    model_name_or_path: Optional[str] = field(
        default="../sft/results/final_checkpoint",
        metadata={"help": "the location of the SFT model name or path"},
    )
    learning_rate: Optional[float] = field(default=5e-6, metadata={"help": "optimizer learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
    warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
    optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})

    per_device_train_batch_size: Optional[int] = field(default=6, metadata={"help": "train batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=4, metadata={"help": "the number of gradient accumulation steps"}
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "whether to use gradient checkpointing"}
    )

    gradient_checkpointing_use_reentrant: Optional[bool] = field(
        default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
    )

    # LoRA and dataset parameters
    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
    dataset: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset used for training and evaluation "})

    # sequence lengths, preprocessing and schedule
    max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
    max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
    dataset_num_proc: Optional[int] = field(default=24, metadata={"help": "the number of processes to preprocess and tokenize dataset"})
    max_steps: Optional[int] = field(default=1250, metadata={"help": "max number of training steps"})
    logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
    save_steps: Optional[int] = field(default=400, metadata={"help": "the saving frequency"})
    eval_steps: Optional[int] = field(default=20, metadata={"help": "the evaluation frequency"})

    # output and model loading
    output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
    log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
    load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 4bit"})
    model_dtype: Optional[str] = field(
        default="bfloat16", metadata={"help": "model_dtype[float16, bfloat16, float] for loading."}
    )

    # instrumentation
    sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
    report_to: Optional[str] = field(
        default="wandb",
        metadata={
            # Adjacent literals are concatenated verbatim, so each fragment carries
            # its own separating space.
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`, '
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`, `"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    # debug argument for distributed training
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See "
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    seed: Optional[int] = field(
        default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
    )

class OracleRewardTrainer(RewardTrainer):
    r"""
    OracleRewardTrainer is a subclass of the RewardTrainer that is specifically designed for training rewards with the generic RM (log-sigmoid loss) + LM head (NLL loss).
    Adapted from https://arxiv.org/pdf/2406.10216.
    """
    
    def __init__(
        self,
        script_args: Dict[str, Union[str, float, int, bool, Dict]],
        model: Optional[Union[PreTrainedModel, nn.Module]] = None,
        teacher_model:Optional[Union[PreTrainedModel, nn.Module]] = None, # load the teacher model for reward model distillation
        args: Optional[RewardConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        teacher_tokenizer:Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        max_length: Optional[int] = None,
        peft_config: Optional[Dict] = None,

        reward_scaling_factor_alpha : float = 0.01,  # New argument for scaling dense rewards,
        reward_kl_beta  = 0.1,  # KL regularization factor,
        is_encoder_decoder = None,
        beta: float = 0.1,
        label_smoothing: float = 0,
        loss_type: Optional[str] = None,
        truncation_mode: str = "keep_end",
        label_pad_token_id: int = -100,
    
        max_prompt_length: Optional[int] = None,
        max_target_length: Optional[int] = None,
        dataset_num_proc: Optional[int] = None,

        student_model_init_kwargs: Optional[Dict] = None,
        teacher_model_init_kwargs: Optional[Dict] = None,
        student_model_adapter_name: Optional[str] = None,
        teacher_model_adapter_name: Optional[str] = None,
        # reference_free: bool = False,
        force_use_ref_model: bool = False,
        is_vision_model: bool = False,
    ):
        """
        Initialize the trainer.

        The standard trainer arguments (``model``, ``args``, ``data_collator``,
        ``train_dataset``, ``eval_dataset``, ``tokenizer``, ``model_init``,
        ``compute_metrics``, ``callbacks``, ``optimizers``,
        ``preprocess_logits_for_metrics``, ``max_length``, ``peft_config``) are
        forwarded unchanged to ``RewardTrainer.__init__``.

        Args:
            script_args: Script-level configuration; stored as ``self.script_args``.
            teacher_model: Optional teacher model; stored as ``self.teacher_model``.
            teacher_tokenizer: Tokenizer paired with the teacher; stored as
                ``self.teacher_tokenizer``.
            reward_scaling_factor_alpha (`float`, defaults to `0.01`):
                Stored as ``self.reward_scaling_factor_alpha``; ``compute_loss``
                uses it as the weight of the NLL term (the ranking term is
                weighted by ``1 - alpha``).
            reward_kl_beta: Stored as ``self.reward_kl_beta``.
            beta, label_smoothing, loss_type, truncation_mode, label_pad_token_id:
                Stored on the instance under the same names.
            max_prompt_length, max_target_length: Fallback length limits used
                when ``args`` is ``None`` or does not provide a non-``None``
                attribute of the same name.
            is_encoder_decoder, is_vision_model: Accepted for signature
                compatibility; the instance flags are always set to ``False``.
            dataset_num_proc, student_model_init_kwargs, teacher_model_init_kwargs,
            student_model_adapter_name, teacher_model_adapter_name,
            force_use_ref_model: Accepted but not used by this initializer.
        """
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            max_length=max_length,
            peft_config=peft_config,
        )

        # Student / teacher handles.
        self.student_tokenizer = tokenizer
        self.tokenizer = tokenizer
        self.teacher_model = teacher_model
        self.teacher_tokenizer = teacher_tokenizer

        # Loss hyper-parameters.
        self.reward_scaling_factor_alpha = reward_scaling_factor_alpha
        self.reward_kl_beta = reward_kl_beta
        self.beta = beta
        self.label_smoothing = label_smoothing
        self.loss_type = loss_type

        # Tokenization settings.
        self.truncation_mode = truncation_mode
        self.label_pad_token_id = label_pad_token_id

        # Values carried by `args` keep precedence; the explicit keyword arguments
        # are only a fallback so that a missing `args` (or a config without these
        # attributes) no longer raises AttributeError here.
        args_max_prompt_length = getattr(args, "max_prompt_length", None)
        if args_max_prompt_length is None:
            args_max_prompt_length = max_prompt_length
        self.max_prompt_length = args_max_prompt_length

        args_max_target_length = getattr(args, "max_target_length", None)
        if args_max_target_length is None:
            args_max_target_length = max_target_length
        self.max_target_length = args_max_target_length

        # Only the decoder-only, text-only path is enabled; the corresponding
        # constructor arguments are intentionally not honoured.
        self.is_encoder_decoder = False #boolean is true if you want to use enc decoder model instead of decoder only 
        self.is_vision_model = False

        # Per-split ("train" / "eval") buffers of metric values awaiting the next log call.
        self._stored_metrics = defaultdict(lambda: defaultdict(list))
        self.script_args = script_args

        
    def build_tokenized_answer(self, prompt, answer):
        """
        Tokenize `prompt + answer` in one pass and split the result into prompt and answer parts.

        Tokenizing the concatenation is not guaranteed to equal the concatenation of the
        separate tokenizations: `enc(a + b) = enc(a) + enc(b)` may not hold because tokens
        can merge across the boundary. What is relied on here is
        `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
        Reference:
            https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257

        Returns a dict with `prompt_input_ids`, `prompt_attention_mask`, `input_ids` and
        `attention_mask` (the last two describe the answer), plus `prompt_pixel_values` /
        `prompt_pixel_attention_mask` when the tokenizer output contains the matching keys.

        Raises `ValueError` when the length sanity checks fail.
        """
        joint = self.tokenizer(prompt + answer, add_special_tokens=False)
        standalone_prompt_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]

        joint_ids = joint["input_ids"]
        joint_mask = joint["attention_mask"]

        # Stitch `enc(a)` and `enc(a + b)[len(enc(a)):]` back together and make sure the
        # result is as long as the joint tokenization.
        stitched_ids = np.concatenate([standalone_prompt_ids, joint_ids[len(standalone_prompt_ids) :]])
        if len(np.array(joint_ids)) != len(stitched_ids):
            raise ValueError("Prompt input ids and answer input ids should have the same length.")

        # Some tokenizers (e.g. the Llama-2 tokenizer) can merge the last prompt token with
        # the first answer token. When the stand-alone prompt ids are not a prefix of the
        # joint ids, move the split point one token to the left so the merged token is
        # attributed to the answer.
        split_at = len(standalone_prompt_ids)
        if standalone_prompt_ids != joint_ids[:split_at]:
            split_at -= 1

        prompt_ids = joint_ids[:split_at]
        prompt_mask = joint_mask[:split_at]
        if len(prompt_ids) != len(prompt_mask):
            raise ValueError("Prompt input ids and attention mask should have the same length.")

        tokenized = {
            "prompt_input_ids": prompt_ids,
            "prompt_attention_mask": prompt_mask,
            "input_ids": joint_ids[split_at:],
            "attention_mask": joint_mask[split_at:],
        }
        if "pixel_values" in joint:
            tokenized["prompt_pixel_values"] = joint["pixel_values"]
        if "pixel_attention_mask" in joint:
            tokenized["prompt_pixel_attention_mask"] = joint["pixel_attention_mask"]

        return tokenized


    def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
        """Tokenize a single row from a DPO specific dataset.

        At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
        in case the prompt + chosen or prompt + rejected responses is/are too long. First
        we truncate the prompt; if we're still too long, we truncate the chosen/rejected.

        We also create the labels for the chosen/rejected responses, which are of length equal to
        the sum of the length of the prompt and the chosen/rejected response, with
        label_pad_token_id for the prompt tokens.

        Args:
            feature: Mapping with the keys `"prompt"`, `"chosen"` and `"rejected"`.
            model: Only used on the encoder-decoder path, to derive decoder input ids
                from the labels when it exposes `prepare_decoder_input_ids_from_labels`.

        Returns:
            On the decoder-only path: `chosen_*` / `rejected_*` entries for `input_ids`,
            `attention_mask` and `labels`, plus the `prompt_*` entries.
            On the encoder-decoder path: `chosen_labels`, `rejected_labels`,
            `prompt_input_ids`, `prompt_attention_mask` and, when `model` supports it,
            `chosen_decoder_input_ids` / `rejected_decoder_input_ids`.

        Raises:
            ValueError: If prompt/chosen/rejected is not a `str`, if the chosen and rejected
                prompts differ by more than a single trailing token, or if
                `self.truncation_mode` is unknown.
        """
        batch = {}
        prompt = feature["prompt"]
        chosen = feature["chosen"]
        rejected = feature["rejected"]

        def add_bos_token_if_needed(
            bos_token_id: Optional[int],
            prompt_len_input_ids: int,
            prompt_tokens: Dict[str, List[int]],
            chosen_prompt_len_input_ids: int,
            chosen_tokens: Dict[str, List[int]],
            rejected_prompt_len_input_ids: int,
            rejected_tokens: Dict[str, List[int]],
        ):
            """Prepend BOS to each of the three prompt encodings unless it is already first."""
            if bos_token_id is not None:
                if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
                    prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
                    prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
                if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
                    chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
                    chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
                if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
                    rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
                    rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
            return prompt_tokens, chosen_tokens, rejected_tokens

        def add_eos_token_if_needed(
            eos_token_id: int, chosen_tokens: Dict[str, List[int]], rejected_tokens: Dict[str, List[int]]
        ):
            """Append EOS to both answers unless it is already the last token."""
            if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
                chosen_tokens["input_ids"].append(eos_token_id)
                chosen_tokens["attention_mask"].append(1)
            if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
                rejected_tokens["input_ids"].append(eos_token_id)
                rejected_tokens["attention_mask"].append(1)
            return chosen_tokens, rejected_tokens

        if not self.is_encoder_decoder:
            # Check issues below for more details
            #  1. https://github.com/huggingface/trl/issues/907
            #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
            #  3. https://github.com/LianjiaTech/BELLE/issues/337

            if not isinstance(prompt, str):
                raise ValueError(f"prompt should be an str but got {type(prompt)}")
            if self.is_vision_model:
                # NOTE(review): this branch relies on `self.processor`, which the
                # initializer of this class does not set; `is_vision_model` is always
                # False there, so the branch is presumably unreachable - confirm.
                images = feature.get("images")
                if "add_special_tokens" in inspect.signature(self.processor).parameters:
                    processor_kwargs = {"add_special_tokens": False}
                else:
                    processor_kwargs = {}
                prompt_tokens = self.processor(prompt, images=images, **processor_kwargs)
                prompt_tokens = {k: v[0] for k, v in prompt_tokens.items()}  # Unbatch, not done when using idefics
                if not isinstance(prompt_tokens["input_ids"], list):  # llava processor returns tensors
                    prompt_tokens["input_ids"] = prompt_tokens["input_ids"].tolist()
                    prompt_tokens["attention_mask"] = prompt_tokens["attention_mask"].tolist()
            else:
                prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)

            prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

            if not isinstance(chosen, str):
                raise ValueError(f"chosen should be an str but got {type(chosen)}")

            # `build_tokenized_answer` takes exactly (prompt, answer).
            chosen_tokens = self.build_tokenized_answer(prompt, chosen)

            if not isinstance(rejected, str):
                raise ValueError(f"rejected should be an str but got {type(rejected)}")
            rejected_tokens = self.build_tokenized_answer(prompt, rejected)

            # Last prompt token might get merged by tokenizer and
            # it should not be included for generation if that happens
            chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
            rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
            prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)

            for k, v in prompt_tokens.items():
                prompt_tokens[k] = v[:prompt_len_input_ids]

            # Make sure prompts only have one different token at most
            # and length only differs by 1 at most
            num_diff_tokens = sum(
                a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])
            )
            num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
            if num_diff_tokens > 1 or num_diff_len > 1:
                raise ValueError(
                    "Chosen and rejected prompt_input_ids might only differ on the "
                    "last token due to tokenizer merge ops."
                )

            # add BOS token to head of prompt. Avoid adding if it's already there
            prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
                self.tokenizer.bos_token_id,
                prompt_len_input_ids,
                prompt_tokens,
                chosen_prompt_len_input_ids,
                chosen_tokens,
                rejected_prompt_len_input_ids,
                rejected_tokens,
            )

            # add EOS token to end of answer. Avoid adding if it's already there
            chosen_tokens, rejected_tokens = add_eos_token_if_needed(
                self.tokenizer.eos_token_id, chosen_tokens, rejected_tokens
            )

            longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

            # if combined sequence is too long, truncate the prompt
            for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                    if self.truncation_mode == "keep_start":
                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
                            answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
                    elif self.truncation_mode == "keep_end":
                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
                            answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
                    else:
                        raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

            # if that's still too long, truncate the response
            for answer_tokens in [chosen_tokens, rejected_tokens]:
                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                    for k in ["input_ids", "attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]

            # Create labels: full sequence ids with the prompt positions masked out
            chosen_sequence_tokens = {
                k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
            }
            rejected_sequence_tokens = {
                k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
            }
            chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
            chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
                self.label_pad_token_id
            ] * len(chosen_tokens["prompt_input_ids"])
            rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
            rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
                self.label_pad_token_id
            ] * len(rejected_tokens["prompt_input_ids"])

            for k, toks in {
                "chosen_": chosen_sequence_tokens,
                "rejected_": rejected_sequence_tokens,
                "": prompt_tokens,
            }.items():
                for type_key, tokens in toks.items():
                    if type_key == "token_type_ids":
                        continue
                    batch[f"{k}{type_key}"] = tokens

        else:
            chosen_tokens = self.tokenizer(
                chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True
            )
            rejected_tokens = self.tokenizer(
                rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True
            )
            prompt_tokens = self.tokenizer(
                prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
            )

            batch["chosen_labels"] = chosen_tokens["input_ids"]
            batch["rejected_labels"] = rejected_tokens["input_ids"]
            batch["prompt_input_ids"] = prompt_tokens["input_ids"]
            batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

            if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
                batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                    labels=torch.tensor(batch["rejected_labels"])
                )
                batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                    labels=torch.tensor(batch["chosen_labels"])
                )

        return batch
    
    def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
        """Append every value in `metrics` to the buffer kept for the `train_eval` split."""
        for name in metrics:
            # Look the split bucket up per entry so that an empty `metrics` dict
            # leaves `self._stored_metrics` untouched.
            split_buffer = self._stored_metrics[train_eval]
            split_buffer[name].append(metrics[name])
            
    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
        """
        Log `logs` on the various objects watching training, including stored metrics.

        The metric values buffered for the matching split are averaged, written into
        `logs`, and the buffer for that split is then cleared.

        Args:
            logs (`Dict[str, float]`):
                The values to log. Mutated in place with the averaged stored metrics.
            start_time (`float`, *optional*):
                Extra argument that a base `log` implementation may be called with.
                It is forwarded only when provided, so the call also works with a
                base implementation that accepts `logs` alone.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]
        if start_time is None:
            return super().log(logs)
        return super().log(logs, start_time)
    
    def cross_entropy_loss(self, logits, labels):
        """
        Return the mean cross-entropy between `logits` and `labels`.

        When `self.is_encoder_decoder` is false, the logits are cast to float and both
        tensors are shifted by one position so that the logits at position `i` are
        scored against the label at position `i + 1`. Logits and labels are then
        flattened and passed to cross-entropy with its default settings.
        """
        if not self.is_encoder_decoder:
            # Shift so that tokens < n predict n
            logits = logits.float()[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()

        vocab_dim = logits.shape[-1]
        flat_logits = logits.view(-1, vocab_dim)
        # Enable model parallelism: labels must live on the logits' device
        flat_labels = labels.view(-1).to(flat_logits.device)
        return nn.functional.cross_entropy(flat_logits, flat_labels)

 

    def forward_with_lm_head(self, model, concatenated_batch, model_kwargs):
        """
        Custom forward pass returning vocab logits, reward logits and hidden states
        for both the chosen and the rejected sequences.

        Args:
            model: The model to run. A `DataParallel` / `DistributedDataParallel`
                wrapper is unwrapped to its `.module`. The unwrapped model must expose
                `base_model` and `classification_head`.
            concatenated_batch: Mapping providing `input_ids_chosen`,
                `attention_mask_chosen`, `input_ids_rejected` and
                `attention_mask_rejected`.
            model_kwargs: Extra keyword arguments forwarded to every model call.

        Returns:
            A 6-tuple `(chosen_vocab_logits, chosen_reward_logits,
            chosen_last_hidden_state, rejected_vocab_logits, rejected_reward_logits,
            rejected_last_hidden_state)`.
        """
        # Handle DataParallel and DistributedDataParallel models
        if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            model = model.module  # Access the underlying model in both cases

        # Extract inputs from the concatenated batch up front so a missing key
        # fails before any forward pass is run.
        chosen_input_ids = concatenated_batch["input_ids_chosen"]
        chosen_attention_mask = concatenated_batch["attention_mask_chosen"]
        rejected_input_ids = concatenated_batch["input_ids_rejected"]
        rejected_attention_mask = concatenated_batch["attention_mask_rejected"]

        chosen_vocab_logits, chosen_reward_logits, chosen_last_hidden_state = self._forward_single_branch(
            model, chosen_input_ids, chosen_attention_mask, model_kwargs
        )
        rejected_vocab_logits, rejected_reward_logits, rejected_last_hidden_state = self._forward_single_branch(
            model, rejected_input_ids, rejected_attention_mask, model_kwargs
        )

        return chosen_vocab_logits, chosen_reward_logits, chosen_last_hidden_state, rejected_vocab_logits, rejected_reward_logits, rejected_last_hidden_state

    def _forward_single_branch(self, model, input_ids, attention_mask, model_kwargs):
        """Run one sequence batch through the backbone, the full model and the reward head.

        Returns `(vocab_logits, reward_logits, last_hidden_state)`.
        """
        # Backbone pass: provides the hidden states used for pooling.
        backbone_outputs = model.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            use_cache=False,
            **model_kwargs
        )
        # Full-model pass: provides the `.logits` returned as vocab logits.
        lm_outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            use_cache=False,
            **model_kwargs
        )

        vocab_logits = lm_outputs.logits
        last_hidden_state = backbone_outputs.last_hidden_state

        # Mean pooling over the sequence dimension, counting only non-masked tokens.
        mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        masked_hidden_state = last_hidden_state * mask
        # Clamp the token count so an all-zero mask yields a zero vector instead of
        # NaN; rows with at least one valid token are unaffected.
        token_counts = mask.sum(dim=1).clamp(min=1e-9)
        pooled_output = masked_hidden_state.sum(dim=1) / token_counts

        # Compute the reward logits with the model's classification head, matching
        # the head's parameter dtype first.
        pooled_output = pooled_output.to(dtype=model.classification_head.weight.dtype)
        reward_logits = model.classification_head(pooled_output)

        return vocab_logits, reward_logits, last_hidden_state


        # return vocab_logits, last_hidden_state

    def get_batch_loss_metrics_new(self, batch_metrics):
        """
        Build the dictionary of `train_`-prefixed metrics for one batch.

        Args:
            batch_metrics: A 6-element sequence. Elements 0 and 1 are the chosen and
                rejected rewards, element 3 is the NLL loss that gets reported,
                element 4 the total loss and element 5 the ranking-only loss.
                Element 2 is not used.
                NOTE(review): elements 2 and 3 were both bound to the NLL name, so
                only element 3 ever reached the metrics - confirm with the caller
                which one is intended.

        Returns:
            Dict mapping metric names to CPU tensors detached from the autograd graph.
        """
        rewards_chosen, rewards_rejected, _, chosen_nll_loss, loss, loss_logits_no_nll = batch_metrics

        # Detach everything: these values are only reported, and holding on to
        # non-detached tensors would keep the whole autograd graph alive.
        rewards_chosen = rewards_chosen.detach()
        rewards_rejected = rewards_rejected.detach()
        reward_accuracies = (rewards_chosen > rewards_rejected).float()

        prefix = "train_"
        metrics = {
            f"{prefix}rewards/chosen": rewards_chosen.mean().cpu(),
            f"{prefix}rewards/rejected": rewards_rejected.mean().cpu(),
            f"{prefix}rewards/accuracies": reward_accuracies.mean().cpu(),
            f"{prefix}rewards/margins": (rewards_chosen - rewards_rejected).mean().cpu(),
            f"{prefix}nll_loss": chosen_nll_loss.detach().mean().cpu(),
            f"{prefix}_logit_total_loss": loss.detach().cpu(),
            f"{prefix}_logit_loss": loss_logits_no_nll.detach().cpu(),
        }
        return metrics
        
    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
        train_eval= None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        metrics = {}
        model_kwargs = {}
        
        if not self.use_reward_data_collator:
            warnings.warn(
                "The current compute_loss is implemented for RewardDataCollatorWithPadding,"
                " if you are using a custom data collator make sure you know what you are doing or"
                " implement your own compute_loss method."
            )

        # vocab_logits, logits_rewards, last_hidden_state = self.forward_with_lm_head(model, inputs)
        chosen_vocab_logits, chosen_reward_logits, chosen_last_hidden_state, rejected_vocab_logits, rejected_reward_logits, rejected_last_hidden_state = self.forward_with_lm_head(model, inputs, model_kwargs)

        rewards_chosen = chosen_reward_logits
        rewards_rejected = rejected_reward_logits

        chosen_labels = inputs["input_ids_chosen"].clone()
        attention_mask = inputs["attention_mask_chosen"]
        self.label_pad_token_id = self.teacher_tokenizer.pad_token_id
        self.label_pad_token_id = -100 # for NLL loss computation to ignore labels, 
        chosen_labels = torch.where(inputs["attention_mask_chosen"] == 1, chosen_labels, self.label_pad_token_id)
        chosen_nll_loss = self.cross_entropy_loss(chosen_vocab_logits, chosen_labels)
        
        mean_pooled_chosen = chosen_last_hidden_state.mean(dim=1)  # Shape: (batch_size, seq_length, hidden_size)
        mean_pooled_rejected = rejected_last_hidden_state.mean(dim=1)  # Shape: (batch_size, seq_length, hidden_size)
        cosine_similarity = F.cosine_similarity(mean_pooled_chosen, mean_pooled_rejected, dim=-1)  # Shape: (batch_size, seq_length)
        batch_cosine_distance = 1 - cosine_similarity
        average_cosine_distance_mean = 1 - cosine_similarity.mean(dim=-1)  # Shape: (batch_size)



        if "margin" in inputs:
            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected  - inputs["margin"]).mean() * (1 - self.reward_scaling_factor_alpha) + self.reward_scaling_factor_alpha * chosen_nll_loss.mean()
            loss_logits_no_nll  = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected  - inputs["margin"]).mean()* (1 - self.reward_scaling_factor_alpha)
        else:

            loss_logits_no_nll  = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected ).mean()
            
            # loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected ).mean()*(1-reward_scaling_factor_alpha) + reward_scaling_factor_alpha*chosen_nll_loss.mean()
            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean() * (1 - self.reward_scaling_factor_alpha) + self.reward_scaling_factor_alpha * chosen_nll_loss.mean()
           
         

        if train_eval:
            prefix = "eval_" 
            metrics[f"{prefix}rewards/chosen"] = rewards_chosen.mean().cpu()
            metrics[f"{prefix}rewards/rejected"] = rewards_rejected.mean().cpu()

            reward_accuracies = (rewards_chosen > rewards_rejected).float()

            metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
            metrics[f"{prefix}rewards/margins"] = (rewards_chosen - rewards_rejected).mean().cpu()

            metrics[f"{prefix}nll_loss"] = chosen_nll_loss.detach().mean().cpu()
            metrics[f"{prefix}_logit_total_loss"] = loss.cpu()
            metrics[f"{prefix}_logit_loss"] = loss_logits_no_nll.cpu()

                        # Printing the computed metrics
            print(f"{prefix}rewards/chosen: {metrics[f'{prefix}rewards/chosen']}")
            print(f"{prefix}rewards/rejected: {metrics[f'{prefix}rewards/rejected']}")

            reward_accuracies = (rewards_chosen > rewards_rejected).float()
            print(f"{prefix}rewards/accuracies: {reward_accuracies.mean().cpu()}")

            print(f"{prefix}rewards/margins: {metrics[f'{prefix}rewards/margins']}")

            # Printing the NLL and total/logits losses
            print(f"{prefix}nll_loss: {metrics[f'{prefix}nll_loss']}")
            print(f"{prefix}_logit_total_loss: {metrics[f'{prefix}_logit_total_loss']}")
            print(f"{prefix}_logit_loss: {metrics[f'{prefix}_logit_loss']}")


            
            self.store_metrics(metrics, train_eval="eval")
#             print("eval metrics", metrics)
        else:
            batch_metrics = (rewards_chosen, rewards_rejected, chosen_nll_loss, chosen_nll_loss, loss, loss_logits_no_nll)
            metrics = self.get_batch_loss_metrics_new(batch_metrics)
            self.store_metrics(metrics, train_eval="eval")
#             print("train metrics", metrics)
        
        
        if return_outputs:
            return loss, {
                "rewards_chosen": rewards_chosen,
                "rewards_rejected": rewards_rejected}, {
            "nll_loss_chosen": chosen_nll_loss.mean().item(),
            "batch_cosine_distance": batch_cosine_distance.mean().item()
        }
        return loss


    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Run a single evaluation step and return ``(loss, logits, labels)``.

        Args:
            model: The model handed to ``compute_loss``.
            inputs: One batch; it is first passed through ``_prepare_inputs``.
            prediction_loss_only: When true, only the loss is returned and the
                other two tuple slots are ``None``.
            ignore_keys: Keys of the reward dict returned by ``compute_loss``
                to leave out of the logits. Defaults to the model config's
                ``keys_to_ignore_at_inference`` when a config is present,
                otherwise to an empty list.

        Returns:
            ``(loss, logits, labels)``. ``logits`` holds one row per example
            with the softmax over the stacked rewards; ``labels`` is all
            zeros, i.e. it points at the first stacked entry
            (``rewards_chosen``).
        """
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        with torch.no_grad():
            # compute_loss also returns a dict of scalar diagnostics as third
            # element; nothing in this step consumes it, so it is discarded.
            loss, logits_dict, _ = self.compute_loss(model, inputs, return_outputs=True, train_eval="eval")

        if prediction_loss_only:
            return (loss, None, None)

        loss = loss.detach()
        logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
        logits = nested_detach(logits)
        # Stack the reward tensors, average over dim 2, softmax across the
        # stacked axis and transpose so that rows index examples.
        # NOTE(review): assumes each reward tensor is (batch, 1) -- TODO confirm.
        logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T

        labels = torch.zeros(logits.shape[0])
        labels = self._prepare_inputs(labels)

        return loss, logits, labels


    def evaluate(self, *args, **kwargs):
        """Delegate to the parent ``evaluate`` without the ``num_print_samples`` option.

        ``num_print_samples`` is accepted as a keyword argument but removed
        before the call is forwarded; every other argument is passed through
        unchanged and the parent's return value is returned as-is.
        """
        # Sample visualisation (visualize_samples) is currently switched off,
        # so the popped value is intentionally not used.
        kwargs.pop("num_print_samples", 4)
        return super().evaluate(*args, **kwargs)
    
    
    def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
        """Append every metric value to ``self._stored_metrics`` under ``train_eval``.

        Args:
            metrics: Mapping from metric name to the value to record.
            train_eval: Bucket in ``self._stored_metrics`` that receives the values.
        """
        for metric_name, metric_value in metrics.items():
            # Look the bucket up per item so nothing is touched for empty input.
            history = self._stored_metrics[train_eval]
            history[metric_name].append(metric_value)

    def sanitize_logit_values(self, logits):
        """Return a copy of the nested list ``logits`` with every NaN entry replaced by ``0``.

        Args:
            logits: A list of rows, each row being a list of numbers.

        Returns:
            A new list of lists with the same layout; non-NaN values are kept as they are.
        """
        cleaned_rows = []
        for row in logits:
            cleaned_row = []
            for value in row:
                cleaned_row.append(0 if np.isnan(value) else value)
            cleaned_rows.append(cleaned_row)
        return cleaned_rows

    def visualize_samples(self, num_print_samples: int):
        """
        Visualize the reward model logits prediction

        Iterates over the evaluation dataloader, decodes the chosen and rejected
        sequences, collects the (NaN-sanitized, rounded) logits returned by
        ``prediction_step`` and prints them as a table on process 0. When
        ``"wandb"`` is listed in ``self.args.report_to`` and a run is active,
        the full table is also logged to wandb.

        Args:
            num_print_samples (`int`, defaults to `4`):
                The number of samples to print. Set to `-1` to print all samples.
        """
        eval_dataloader = self.get_eval_dataloader()
        table = defaultdict(list)

        for inputs in eval_dataloader:
            _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
            chosen_text = self.tokenizer.batch_decode(inputs["input_ids_chosen"], skip_special_tokens=True)
            rejected_text = self.tokenizer.batch_decode(inputs["input_ids_rejected"], skip_special_tokens=True)

            # Gather text and logits
            table["chosen_text"].extend(gather_object(chosen_text))
            table["rejected_text"].extend(gather_object(rejected_text))

            # Sanitize logits by replacing NaN values with 0 before rounding
            sanitized_logits = self.sanitize_logit_values(logits.tolist())
            table["logits"].extend(
                gather_object([[round(inner_item, 4) for inner_item in item] for item in sanitized_logits])
            )

            # Stop as soon as enough samples were collected (never for a negative count)
            if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
                break

        # Convert table to a dataframe
        df = pd.DataFrame(table)

        # Print and log results
        if self.accelerator.process_index == 0:
            # A negative count means "everything"; slicing with df[:-1] would
            # silently drop the last row instead.
            printable = df if num_print_samples < 0 else df[:num_print_samples]
            print_rich_table(printable)
            if "wandb" in self.args.report_to:
                import wandb
                if wandb.run is not None:
                    wandb.log({"completions": wandb.Table(dataframe=df)})


def initialize_lm_head_from_causal_model(base_model_name, model):
    """Attach an LM head to ``model`` whose weights are copied from a causal LM.

    A bias-free linear layer sized from ``model.base_model.config``
    (``hidden_size`` -> ``vocab_size``) is created, its weight is replaced by a
    clone of the ``lm_head`` weight of the causal LM loaded from
    ``base_model_name``, and the layer is stored as ``model.lm_head``.

    Args:
        base_model_name: Checkpoint name or path given to
            ``AutoModelForCausalLM.from_pretrained``.
        model: Model that receives the new ``lm_head`` attribute; it must
            expose ``base_model.config``.

    Returns:
        The same ``model`` object, modified in place.
    """
    base_model = model.base_model
    lm_head = nn.Linear(base_model.config.hidden_size, base_model.config.vocab_size, bias=False)
    # Only the head weights are needed; no tokenizer is loaded because nothing
    # here tokenizes text.
    causal_model = AutoModelForCausalLM.from_pretrained(base_model_name)
    lm_head.weight.data = causal_model.lm_head.weight.data.clone()
    model.lm_head = lm_head
    return model


if __name__ == "__main__":
    # ---- Command-line arguments ------------------------------------------
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    set_seed(script_args.seed)
    # Plain-dict view of the parsed arguments (handed to the trainer below).
    script_args_dict = asdict(script_args)

    # ---- Trainer configuration -------------------------------------------
    config = RewardConfig(
        output_dir=script_args.output_dir,
        per_device_train_batch_size=12,
        num_train_epochs=1,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        learning_rate=1e-5,
        report_to="wandb",
        remove_unused_columns=False,
        optim="adamw_torch",
        logging_steps=50,
        evaluation_strategy="steps",
        save_strategy="epoch",
        eval_steps=200,
        max_length=1024,
    )

    # Extra length limits attached to the config object after construction.
    config.max_prompt_length = 128
    config.max_target_length = 1024
    print("config max length", config.max_prompt_length, config.max_target_length)

    # Student checkpoint (alternatives tried: facebook/opt-1.3b, facebook/opt-350m).
    student_model_config = ModelConfig(model_name_or_path="facebook/opt-350m")
    teacher_model_config = ModelConfig(model_name_or_path="microsoft/Phi-3-mini-4k-instruct")

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    config.gradient_checkpointing_kwargs = {"use_reentrant": False}

    ################
    # Model & Tokenizer
    ################
    # "auto" / None are kept verbatim; any other value is taken to be the name
    # of a torch dtype attribute.
    requested_dtype = student_model_config.torch_dtype
    if requested_dtype in ["auto", None]:
        torch_dtype = requested_dtype
    else:
        torch_dtype = getattr(torch, requested_dtype)

    student_quantization_config = get_quantization_config(student_model_config)
    teacher_quantization_config = get_quantization_config(teacher_model_config)

    # A k-bit device map is requested only when a quantization config exists.
    model_kwargs = {
        "revision": student_model_config.model_revision,
        "device_map": get_kbit_device_map() if student_quantization_config is not None else None,
        "quantization_config": student_quantization_config,
    }

    teacher_model_kwargs = {
        "revision": teacher_model_config.model_revision,
        "device_map": get_kbit_device_map() if teacher_quantization_config is not None else None,
        "quantization_config": teacher_quantization_config,
    }

    teacher_tokenizer = AutoTokenizer.from_pretrained(
        teacher_model_config.model_name_or_path,
        trust_remote_code=teacher_model_config.trust_remote_code,
        use_fast=True,
    )

    # An explicit --model_dtype of float16 / bfloat16 overrides the dtype
    # derived above; any other value leaves it untouched.
    dtype_overrides = {"float16": torch.float16, "bfloat16": torch.bfloat16}
    torch_dtype = dtype_overrides.get(script_args.model_dtype, torch_dtype)

    def load_and_initialize_rm_llm(script_args):
        """Load the student and teacher causal LMs, each with a scalar reward head.

        Both checkpoints come from the enclosing ``student_model_config`` /
        ``teacher_model_config`` and are loaded with the enclosing
        ``torch_dtype``. The teacher tokenizer is then configured: maximum
        length 6000, ``unk`` token used for padding, right-side padding.

        Args:
            script_args: Parsed script arguments; only ``ignore_bias_buffers``
                is read here.

        Returns:
            ``(student_model, student_tokenizer, teacher_model, teacher_tokenizer)``.
        """

        def initialize_model_with_classification_head(base_model_name, torch_dtype, model_type = "student"):
            """
            Start with a causal language model (CLM) and add a classification head (for scalar outputs).

            Any ``model_type`` other than ``"student"`` is loaded with
            ``attn_implementation="flash_attention_2"``. Returns ``(model, tokenizer)``.
            """
            if model_type == "student":
                model = AutoModelForCausalLM.from_pretrained(
                    base_model_name, trust_remote_code=True, torch_dtype=torch_dtype
                )
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    base_model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                    attn_implementation="flash_attention_2",
                )
            tokenizer = AutoTokenizer.from_pretrained(base_model_name)

            # Unwrap DDP or DataParallel layers if applicable
            while isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
                model = model.module

            if model_type == "teacher":
                print("teacher hidden", model.config.hidden_size)

            # The head is identical for every model type (hidden_size -> 1 reward
            # scalar). Building it unconditionally avoids an unbound local when
            # model_type is neither "student" nor "teacher".
            classification_head = nn.Linear(model.config.hidden_size, 1)

            # Attach the head and cast it to the requested precision.
            model.classification_head = classification_head.to(torch_dtype)

            return model, tokenizer

        student_model, student_tokenizer = initialize_model_with_classification_head(
            student_model_config.model_name_or_path, torch_dtype, model_type="student"
        )
        teacher_model, teacher_tokenizer = initialize_model_with_classification_head(
            teacher_model_config.model_name_or_path, torch_dtype, model_type="teacher"
        )

        if script_args.ignore_bias_buffers:
            # torch distributed hack: have DDP skip boolean buffers. The
            # attribute is set on both loaded models (the previous code
            # referenced an undefined name `model` and raised NameError).
            for loaded_model in (student_model, teacher_model):
                loaded_model._ddp_params_and_buffers_to_ignore = [
                    name for name, buffer in loaded_model.named_buffers() if buffer.dtype == torch.bool
                ]

        teacher_tokenizer.model_max_length = 6000
        # use unk rather than eos token to prevent endless generation
        teacher_tokenizer.pad_token = teacher_tokenizer.unk_token
        teacher_tokenizer.pad_token_id = teacher_tokenizer.convert_tokens_to_ids(teacher_tokenizer.pad_token)
        teacher_tokenizer.padding_side = 'right'
        return student_model, student_tokenizer, teacher_model, teacher_tokenizer


    student_model, student_tokenizer, teacher_model, teacher_tokenizer = load_and_initialize_rm_llm(script_args)
 
    def initialize_lm_head_from_causal_model(base_model_name, model, device):
        """
        Fits an LM head from the original causal model weights into the RM to generalize/regularize the RM training process
        Returns the RM with both the reward head (scalar scores) + LM head (initialized with the casual model counterpart)

        Args:
            base_model_name: Checkpoint name or path of the causal LM whose
                ``lm_head`` weights are copied.
            model: Reward model that receives the ``lm_head`` attribute; it
                must expose ``base_model.config``.
            device: Device on which the new head (and its weights) is placed.

        Returns:
            The same ``model`` object, modified in place.
        """
        base_model = model.base_model
        # Initialize and move the LM head to the device
        lm_head = nn.Linear(base_model.config.hidden_size, base_model.config.vocab_size, bias=False).to(device)
        # Only the head weights are needed on `device`, so the donor model is
        # not moved there as a whole, and no tokenizer is loaded.
        causal_model = AutoModelForCausalLM.from_pretrained(base_model_name)
        lm_head.weight.data = causal_model.lm_head.weight.data.clone().to(device)
        model.lm_head = lm_head
        return model

    # The same advice applies to both configs. Two separate call sites are kept
    # so that each config produces its own warning.
    seq_cls_warning = (
        "You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs"
        " Make sure to pass --lora_task_type SEQ_CLS when using this script."
    )

    student_uses_seq_cls = student_model_config.lora_task_type == "SEQ_CLS"
    if not student_uses_seq_cls:
        warnings.warn(seq_cls_warning)

    teacher_uses_seq_cls = teacher_model_config.lora_task_type == "SEQ_CLS"
    if not teacher_uses_seq_cls:
        warnings.warn(seq_cls_warning)

    ################
    # Dataset
    ################

    def get_tldr_data(
        tokenizer,
        split,
        sanity_check: bool = False,
        cache_dir: Optional[str] = None,
        num_proc=24,
    ):
        """Build the tokenized train / eval preference datasets.

        Every split listed in ``data_paths`` is loaded from disk, all of them
        are concatenated, and the result is re-split 95% / 5% with a fixed
        seed. ``chosen`` and ``rejected`` are rendered with the tokenizer's
        chat template, and finally ``preprocess_function`` (looked up in the
        enclosing scope at call time) tokenizes both datasets.

        Args:
            tokenizer: Tokenizer whose ``apply_chat_template`` renders the
                ``chosen`` / ``rejected`` conversations to text.
            split: Unused; kept so existing callers keep working.
            sanity_check: When true, train on a shuffled subset of at most
                1000 examples.
            cache_dir: Unused; kept so existing callers keep working.
            num_proc: Number of worker processes for every ``map`` call.

        Returns:
            ``(train_dataset, eval_dataset)``.
        """
        data_paths = {
            'high_conf_high_edit': {
                'train': "##",
                'validation': "##",
            },

            'low_conf_low_edit': {
                'train':"##",
                'validation': "##",
            },
        }

        # Load all datasets
        loaded_datasets = {}
        for name, paths in data_paths.items():
            print(f"Loading {name} train split...")
            train_data = load_from_disk(paths['train'])
            print(f"Loading {name} validation split...")
            validation_data = load_from_disk(paths['validation'])
            loaded_datasets[name] = {
                'train': train_data,
                'validation': validation_data
            }

        # Concatenate all train and validation splits into a single dataset.
        # The list is derived from what was actually loaded, so it cannot
        # reference a category that is missing from data_paths.
        all_train_datasets = concatenate_datasets([
            loaded_datasets[name][part]
            for part in ('train', 'validation')
            for name in data_paths
        ])

        dataset_indices = np.arange(len(all_train_datasets))

        # Split the indices into train and validation sets (5% for validation)
        train_indices, validation_indices = train_test_split(
            dataset_indices,
            test_size=0.05,
            shuffle=True,
            random_state=42
        )

        # Create train and validation datasets using the indices
        train_dataset = all_train_datasets.select(train_indices)
        validation_dataset = all_train_datasets.select(validation_indices)
        print(f"Combined training dataset size: {len(train_dataset)}")
        print(f"Combined validation dataset size: {len(validation_dataset)}")

        print(train_dataset)
        print(validation_dataset)
        if sanity_check:
            # Cap at the dataset size so small datasets do not raise an index error.
            sample_size = min(1000, len(train_dataset))
            train_dataset = train_dataset.shuffle().select(range(sample_size))
        original_columns = train_dataset.column_names
        selected_columns = ['prompt', 'chosen', 'rejected']
        removable_columns = [x for x in original_columns if x not in selected_columns]

        def return_prompt_and_responses(row):
            """Render the chosen / rejected conversations to plain text."""
            row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
            row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
            return row

        train_dataset = train_dataset.map(
            return_prompt_and_responses,
            num_proc=num_proc,
            remove_columns=removable_columns,
        )

        validation_dataset = validation_dataset.map(
            return_prompt_and_responses,
            num_proc=num_proc,
            remove_columns=removable_columns,
        )

        with PartialState().local_main_process_first():
            train_dataset = train_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=num_proc,
            )

            eval_dataset = validation_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=num_proc,
            )

        print("size of train and eval dataset", train_dataset, eval_dataset)
        return train_dataset, eval_dataset
            
    
    def preprocess_function(examples):
        """Tokenize a batch of chosen / rejected texts for reward-model training.

        Each text is tokenized with the enclosing ``teacher_tokenizer`` and
        truncated to ``script_args.max_length`` tokens. Exactly one output row
        is produced per input pair, so the returned columns have the same
        length as the input batch (required when used with a batched ``map``
        that keeps the original columns).

        Args:
            examples: Batch dict with parallel ``"chosen"`` and ``"rejected"``
                lists of strings.

        Returns:
            Dict with ``input_ids_chosen``, ``attention_mask_chosen``,
            ``input_ids_rejected`` and ``attention_mask_rejected`` lists.
        """
        max_length = script_args.max_length
        new_examples = {
            "input_ids_chosen": [],
            "attention_mask_chosen": [],
            "input_ids_rejected": [],
            "attention_mask_rejected": [],
        }

        for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
            # Tokenize the chosen and rejected examples
            tokenized_chosen = teacher_tokenizer(chosen)
            tokenized_rejected = teacher_tokenizer(rejected)

            # Truncate to max_length and append once per pair.
            new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"][:max_length])
            new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"][:max_length])
            new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"][:max_length])
            new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"][:max_length])

        return new_examples


        # 2. Load the Ultrafeedback binarized and cleaned dataset 
    # The dataset choice is pinned here, replacing whatever value
    # script_args.dataset held before.
    script_args.dataset = "ultrafeedback_binarized"
    if script_args.dataset == "ultrafeedback_binarized":
        print(f"processing {script_args.dataset} dataset", teacher_tokenizer)
        dataset_pair = get_tldr_data(
            teacher_tokenizer,
            split="train",
            sanity_check=False,
            cache_dir=None,
            num_proc=24,
        )
        train_dataset, eval_dataset = dataset_pair

        print("size of train and eval dataset", train_dataset, eval_dataset)


    ################
    # Training
    ################
    
    # Define the metric that we'll use for validation.
    accuracy = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        """Return the accuracy of "chosen is preferred" over the evaluation set.

        Args:
            eval_pred: Pair ``(predictions, labels)``; the labels are ignored
                because the expected class is always index 0 (chosen).

        Returns:
            The result of ``accuracy.compute`` for the predicted class indices
            against an all-zero reference.
        """
        predictions, _ = eval_pred
        # prediction_step transposes its logits so that rows index examples and
        # columns hold the (chosen, rejected) scores; the preferred class is
        # therefore the argmax along axis 1, giving one index per example.
        # NOTE(review): assumes predictions is (n_examples, 2) -- TODO confirm.
        predictions = np.argmax(predictions, axis=1)
        labels = np.zeros(predictions.shape)
        return accuracy.compute(predictions=predictions, references=labels)
 
    # Collect the constructor arguments first so the trainer call stays short.
    # NOTE(review): the teacher model/tokenizer fill both the `model`/`tokenizer`
    # and the `teacher_*` slots, while the PEFT config is derived from the
    # student config -- confirm this is intended.
    # `compute_metrics=compute_metrics` is deliberately not passed at the moment.
    trainer_kwargs = {
        "script_args": script_args_dict,
        "model": teacher_model,
        "teacher_model": teacher_model,
        "tokenizer": teacher_tokenizer,
        "teacher_tokenizer": teacher_tokenizer,
        "args": config,
        "train_dataset": train_dataset,
        "eval_dataset": eval_dataset,
        "peft_config": get_peft_config(student_model_config),
    }
    trainer = OracleRewardTrainer(**trainer_kwargs)

    # Train, persist the result, then run and report a final evaluation.
    trainer.train()
    trainer.save_model(config.output_dir)
    # trainer.push_to_hub()
    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    print(metrics)