# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Classes and functions related to training of neural networks.

"""
import copy
import contextlib
from functools import partial
import glob
import math
import os
import random
import re
import shutil
import sys
import time
import warnings
from collections import OrderedDict
from packaging import version
import torch
from torch import nn
from typing import Any, Dict, List, Optional, Tuple, Union, Callable

from accelerate import __version__ as accelerate_version
from accelerate import PartialState
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from accelerate.utils.other import is_compiled_module
from collections import defaultdict
from datasets import Dataset, IterableDataset

import deepspeed
from deepspeed import comm as dist
from deepspeed.git_version_info import version as deepspeed_version
from deepspeed.runtime.utils import see_memory_usage
from deepspeed.utils import log_dist

from peft import PeftModel
from trl import (
    PreTrainedModelWrapper, 
    # PPOConfig, 
    # PPOTrainer,
    GRPOConfig,
    GRPOTrainer,
    apply_chat_template,
    is_conversational,
    maybe_apply_chat_template,
)
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.import_utils import is_vllm_available
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from trl.trainer.utils import pad
from unittest.mock import patch

import transformers
from transformers import (
    Trainer,
    PreTrainedTokenizerBase,
    AutoModelForCausalLM,
    TrainerCallback,
    AutoTokenizer,
    GenerationConfig,
    )
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
if version.parse(transformers.__version__) < version.parse('4.48'):
    from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint
else:
    from transformers.integrations import deepspeed_init, deepspeed_load_checkpoint
from transformers.generation.logits_process import LogitsProcessorList
from transformers.integrations import hp_params
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.modeling_utils import unwrap_model, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.trainer import TRAINER_STATE_NAME, OPTIMIZER_NAME, SCHEDULER_NAME, SCALER_NAME, TRAINING_ARGS_NAME
from transformers.trainer_callback import TrainerState
from transformers.trainer_pt_utils import (
    get_model_param_count,
    nested_concat,
    nested_detach,
)
from transformers.trainer_utils import HPSearchBackend, TrainOutput, has_length, speed_metrics
from transformers.training_args import ParallelMode
from transformers.utils import (
    can_return_loss,
    is_sagemaker_mp_enabled,
    is_torch_tpu_available,
    is_peft_available,
    is_safetensors_available,
    is_accelerate_available,
    is_torch_xla_available,
    is_apex_available,
    WEIGHTS_NAME,
    SAFE_WEIGHTS_NAME,
)

from utils.rl_utils import select_diverse_subset

if is_sagemaker_mp_enabled():
    from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
    import smdistributed.modelparallel.torch as smp

if is_safetensors_available():
    import safetensors.torch

if is_accelerate_available():
    from accelerate import skip_first_batches

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

if is_apex_available():
    from apex import amp

if is_vllm_available():
    from vllm import LLM, SamplingParams


import logging

logger = logging.getLogger(__name__)


def count_parameters(model: nn.Module):
    """Log a per-parameter listing and a size summary of ``model``.

    Two INFO records are emitted through the module-level ``logger``:

    * one line per parameter (name, shape, device, dtype and whether it is
      trainable or fixed), and
    * a summary with the number and fraction of trainable / non-trainable
      parameter elements and the estimated storage in MB.

    Storage is computed as ``numel() * element_size()`` for every parameter,
    so every dtype is accounted for.

    Args:
        model: The module whose ``named_parameters()`` are inspected.

    Returns:
        None. The function only logs.
    """
    num_trainable_param = 0
    num_untrainable_param = 0
    storage_bytes = 0
    details = []
    for name, p in model.named_parameters():
        n = p.numel()
        if p.requires_grad:
            num_trainable_param += n
            details.append('-{}--{}--{}--{}--trainable\n'.format(name, p.shape, p.device, p.dtype))
        else:
            num_untrainable_param += n
            details.append('-{}--{}--{}--{}--fixed\n'.format(name, p.shape, p.device, p.dtype))
        # element_size() is the per-element byte width of the tensor's dtype,
        # which avoids maintaining a dtype -> width table by hand.
        storage_bytes += n * p.element_size()
    logger.info('Model details:\n{}'.format(''.join(details)))

    num_all_param = num_trainable_param + num_untrainable_param
    # Guard the ratios: a model with no parameter elements would otherwise
    # raise ZeroDivisionError here.
    if num_all_param:
        trainable_ratio = num_trainable_param / num_all_param
        untrainable_ratio = num_untrainable_param / num_all_param
    else:
        trainable_ratio = 0.0
        untrainable_ratio = 0.0
    storage = storage_bytes / 1024. / 1024.
    logger.info(
        """Model summary \n -num. trainable param: {} ({}) \n -num. non-trainable param: {} ({}) \n -estimated memory: {:.3f} MB\n
        """.format(
            num_trainable_param, trainable_ratio,
            num_untrainable_param, untrainable_ratio, storage)
        )


class RenameCKPTFiles(object):
    """Temporarily move trainer-state files of a checkpoint out of the way.

    For a checkpoint directory, the trainer state, optimizer, scheduler and
    scaler files can be renamed to a ``copy_``-prefixed sibling and later
    renamed back. When constructed with ``None`` every path is ``None`` and
    all operations are no-ops.
    """

    def __init__(self, model_name_or_path):
        if model_name_or_path is None:
            self.trainer_state_name = self._trainer_state_name = None
            self.optimizer_name = self._optimizer_name = None
            self.scheduler_name = self._scheduler_name = None
            self.scaler_name = self._scaler_name = None
            return

        def original_and_copy(file_name):
            # (path of the regular file, path of its 'copy_' sibling)
            return (
                os.path.join(model_name_or_path, file_name),
                os.path.join(model_name_or_path, 'copy_' + file_name),
            )

        self.trainer_state_name, self._trainer_state_name = original_and_copy(TRAINER_STATE_NAME)
        self.optimizer_name, self._optimizer_name = original_and_copy(OPTIMIZER_NAME)
        self.scheduler_name, self._scheduler_name = original_and_copy(SCHEDULER_NAME)
        self.scaler_name, self._scaler_name = original_and_copy(SCALER_NAME)

    def _pairs(self):
        """Return the (regular path, copy path) pairs in processing order."""
        return (
            (self.trainer_state_name, self._trainer_state_name),
            (self.optimizer_name, self._optimizer_name),
            (self.scheduler_name, self._scheduler_name),
            (self.scaler_name, self._scaler_name),
        )

    def rename_one_file(self, source_file, target_file):
        """Rename ``source_file`` to ``target_file`` if the source is a file.

        A ``FileNotFoundError`` from ``os.rename`` is tolerated (and reported
        with ``print``) when the target already exists, i.e. another process
        did the rename first; otherwise a ``FileNotFoundError`` is raised.
        """
        if source_file is None or not os.path.isfile(source_file):
            return
        try:
            os.rename(source_file, target_file)
        except FileNotFoundError:
            if not os.path.isfile(target_file):
                raise FileNotFoundError('RenameCKPTFiles: {} cannot be renamed due to unexpected reason. To be debugged.'.format(source_file))
            print('RenameCKPTFiles: {} cannot be renamed as it has been renamed by other process.'.format(source_file))

    def rename_files(self):
        """Move every regular file to its ``copy_`` name."""
        for regular, backup in self._pairs():
            self.rename_one_file(regular, backup)

    def restore_file_names(self):
        """Move every ``copy_`` file back to its regular name."""
        for regular, backup in self._pairs():
            self.rename_one_file(backup, regular)


def _zero3_consolidated_16bit_trainable_state_dict(self):
    """Consolidate a ZeRO-3 partitioned model into a state dict of trainable weights.

    This function monkey-patches
    deepspeed.runtime.engine.DeepSpeedEngine._zero3_consolidated_16bit_state_dict,
    so ``self`` is expected to be the DeepSpeed engine. Parameters with
    ``requires_grad == False`` are skipped; persistent buffers are always kept.

    Returns:
        An ``OrderedDict`` mapping key -> CPU tensor on rank 0, ``None`` on
        every other rank.

    Raises:
        ValueError: if ``self.zero_optimization_partition_weights()`` is falsy.
    """
    if not self.zero_optimization_partition_weights():
        raise ValueError("this function requires ZeRO-3 mode")

    # Only rank 0 collects tensors; the other ranks just take part in the gathers.
    state_dict = OrderedDict() if dist.get_rank() == 0 else None
    # Maps param.ds_id -> first key stored for it, so shared weights reuse one tensor.
    shared_params = {}

    def get_layer_state_dict(module, prefix=""):
        # gather one layer at a time to be memory-efficient
        # must use modifier_rank=0 to release GPU memory after each layer gathered
        #see_memory_usage("before GatheredParameters", force=True)
        with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
            if dist.get_rank() == 0:
                # handle params
                for name, param in module.named_parameters(recurse=False):
                    # Frozen (or missing) parameters are deliberately left out.
                    if (param is None) or (not param.requires_grad):
                        continue
                    # For prefixes containing 'lora' and ending in "default.", every
                    # "default." occurrence is removed from the prefix.
                    # NOTE(review): this rebinds the local `prefix`, so on rank 0 the
                    # stripped prefix is also used for the remaining parameters, the
                    # buffers and the child prefixes of this call — confirm intended.
                    if 'lora' in prefix and prefix.endswith("default."):
                        prefix = prefix.replace("default.", "")
                    key = prefix + name
                    # can't rely on param.data_ptr() as it will be reused as weights gets
                    # gathered and reduced, but param.ds_id is unique across all zero weights
                    # (and shared params will have the same param.ds_id)
                    if param.ds_id in shared_params:
                        # shared weights
                        #print(f"`{key}` is shared with `{shared_params[param.ds_id]}`")
                        state_dict[key] = state_dict[shared_params[param.ds_id]]
                    else:
                        state_dict[key] = param.detach().cpu()
                        shared_params[param.ds_id] = key
                    #print(f"param {param.ds_id} {param.shape} {key} ")

                # now buffers - not sure if need to take care of potentially shared weights here
                # Buffers are stored whenever they are persistent, independent of requires_grad.
                for name, buf in module.named_buffers(recurse=False):
                    if (buf is not None and name not in module._non_persistent_buffers_set):
                        state_dict[prefix + name] = buf.detach().cpu()
        #see_memory_usage("after GatheredParameters", force=True)

        # Recurse outside the gather context so only one module's own parameters
        # are gathered at a time.
        for name, child in module.named_children():
            if child is not None:
                get_layer_state_dict(child, prefix + name + ".")

    # Prepare for checkpoint save by ensuring all parameters are partitioned
    self.optimizer.checkpoint_event_prologue()

    see_memory_usage("before get_layer_state_dict", force=False)
    get_layer_state_dict(self.module, prefix="")
    see_memory_usage("after get_layer_state_dict", force=False)
    self.optimizer.checkpoint_event_epilogue()

    return state_dict


def _save_trainable_checkpoint(self, save_dir, tag, client_state={}):
    """Save a DeepSpeed module checkpoint without any frozen-parameter entries.

    This function monkey-patches
    deepspeed.runtime.engine.DeepSpeedEngine._save_checkpoint, so ``self`` is
    expected to be the DeepSpeed engine.

    Args:
        save_dir: Directory the checkpoint belongs to.
        tag: Checkpoint tag; joined with ``save_dir`` for the pipeline hack below.
        client_state: Extra entries merged into the saved state (read only).

    The state is handed to ``self.checkpoint_engine.save`` only when
    ``self.save_non_zero_checkpoint`` is truthy.
    """
    checkpoint_file = self._get_ckpt_name(save_dir, tag)

    optimizer_state_kept_by_zero = self.zero_optimization() or self.bfloat16_enabled()

    # The code this replaces set the flag from
    # self.zero_optimization_partition_gradients(); here it is pinned to False
    # so the two frozen_param_* entries below are always None.
    include_frozen = False

    # A hack to save the checkpointing directory. Pipeline parallelism overrides
    # module_state_dict() and uses this path to save the model; the path is
    # exposed through self._curr_ckpt_path only for the duration of that call.
    self._curr_ckpt_path = os.path.join(save_dir, tag)
    module_state = self.module_state_dict()
    self._curr_ckpt_path = None

    state = {
        "module": module_state,
        "buffer_names": self._get_buffer_names(),
        "optimizer": (
            self.optimizer.state_dict()
            if self.optimizer and not optimizer_state_kept_by_zero
            else None
        ),
        "param_shapes": (
            self._get_zero_param_shapes()
            if self.optimizer and optimizer_state_kept_by_zero
            else None
        ),
        "frozen_param_shapes": (
            self._get_zero_frozen_param_attributes(self._get_param_shape_func)
            if include_frozen
            else None
        ),
        "shared_params": (
            self._get_shared_params()
            if self.optimizer and optimizer_state_kept_by_zero
            else None
        ),
        "frozen_param_fragments": (
            self._get_zero_frozen_param_attributes(self._get_param_fragment_func)
            if include_frozen
            else None
        ),
        "lr_scheduler": (
            self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None
        ),
        "data_sampler": (
            self.training_dataloader.data_sampler.state_dict()
            if (self.training_dataloader is not None and self.curriculum_learning_enabled())
            else None
        ),
        "random_ltd": (
            self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None
        ),
        "sparse_tensor_module_names": self.sparse_tensor_module_names,
        "skipped_steps": self.skipped_steps,
        "global_steps": self.global_steps,
        "global_samples": self.global_samples,
        "dp_world_size": self.dp_world_size,
        "mp_world_size": self.mp_world_size,
        "ds_config": self.config,
        "ds_version": deepspeed_version,
    }
    # Caller-provided entries win over the engine's own on key collisions.
    state.update(client_state)

    if not self.save_non_zero_checkpoint:
        return
    log_dist(message=f'Saving model checkpoint: {checkpoint_file}', ranks=[0, 1])
    self.checkpoint_engine.save(state, checkpoint_file)


def get_state_dict(self, model, unwrap=True):
    """Return the state dict of ``model`` without up-casting it to float32.

    This function monkey-patches
    accelerate.accelerator.Accelerator.get_state_dict, so ``self`` is expected
    to be the ``Accelerator``. The tensors are returned in whatever precision
    the model holds them.

    Args:
        model (`torch.nn.Module`):
            A PyTorch model sent through [`Accelerator.prepare`].
        unwrap (`bool`, *optional*, defaults to `True`):
            Outside DeepSpeed, whether to take the state dict of the unwrapped
            model (`True`) or of `model` as given (`False`).

    Returns:
        `dict`: The state dictionary of the model, potentially without full precision.

    Raises:
        ValueError: Under DeepSpeed ZeRO stage 3 when
            `model.zero_gather_16bit_weights_on_model_save()` is falsy.

    Example:

    ```python
    >>> import torch
    >>> from accelerate import Accelerator

    >>> accelerator = Accelerator()
    >>> net = torch.nn.Linear(2, 2)
    >>> net = accelerator.prepare(net)
    >>> state_dict = accelerator.get_state_dict(net)
    ```
    """
    runs_deepspeed = self.distributed_type == "DEEPSPEED"
    if not runs_deepspeed:
        target = self.unwrap_model(model) if unwrap else model
        return target.state_dict()

    is_stage3 = self.deepspeed_config["zero_optimization"]["stage"] == 3
    if not is_stage3:
        from deepspeed.checkpoint.utils import clone_tensors_for_torch_save

        return clone_tensors_for_torch_save(self.unwrap_model(model).state_dict())

    if not model.zero_gather_16bit_weights_on_model_save():
        raise ValueError(
            "Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. "
            "To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or "
            "set `zero3_save_16bit_model` to True when using `accelerate config`. "
            "To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights."
        )

    # The float16 -> float32 conversion of the patched method is intentionally
    # omitted: the weights are returned as they are.
    return model._zero3_consolidated_16bit_state_dict()


class GenerateEvalTrainer(Trainer):
    """Trainer whose evaluation step can run ``model.generate`` instead of a forward pass.

    When ``evaluation_method == 'autoregressive'``, `prediction_step` calls
    `compute_generate_loss`; for any other value it calls the regular
    ``compute_loss``.
    """

    def __init__(self, *args, evaluation_method, generation_config=None, **kwargs):
        super(GenerateEvalTrainer, self).__init__(*args, **kwargs)

        self.evaluation_method = evaluation_method
        self.generation_config = generation_config

    def compute_generate_loss(self, model, inputs, generation_config, return_outputs=False):
        """Generate continuations and return them in the slot normally used for logits.

        Args:
            model: Model exposing ``generate`` and ``device``.
            inputs: Batch mapping; ``"labels"`` is removed from it (if present)
                before the remaining entries are passed to ``model.generate``.
                Must contain ``"input_ids"``.
            generation_config: Generation settings to use. ``None`` falls back
                to ``self.generation_config``.
            return_outputs: If true, return ``(loss, outputs)`` instead of ``loss``.

        Returns:
            ``loss`` or ``(loss, outputs)`` where ``outputs`` is a
            ``CausalLMOutputWithPast`` whose ``logits`` field holds the generated
            token ids (not real logits), left-padded with ``-100`` to a common
            length. When labels were supplied the loss is a constant ``0.0``
            placeholder, since generation does not produce a loss.

        Raises:
            ValueError: if no labels were supplied and the generation output is
                a dict without a ``"loss"`` entry.
        """
        labels = inputs.pop("labels", None)
        if generation_config is None:
            generation_config = self.generation_config

        outputs = model.generate(
            **inputs,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            logits_processor=LogitsProcessorList(),
        )
        all_ids = outputs["sequences"]
        question_ids = inputs["input_ids"]
        # Drop the prompt part of every returned sequence.
        net_ids = [aid[len(qid):] for aid, qid in zip(all_ids, question_ids)]
        # NOTE(review): token id 0 is removed wherever it occurs — presumably it
        # is the pad id of the tokenizer in use; confirm.
        net_ids = [nid[nid != 0] for nid in net_ids]
        # default=0 keeps an empty batch from raising inside max().
        max_len = max((len(nid) for nid in net_ids), default=0)
        net_ids = [[-100] * (max_len - len(nid)) + nid.tolist() for nid in net_ids]
        logits = torch.tensor(net_ids, device=all_ids.device)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = torch.tensor(0.0, device=model.device)
        elif isinstance(outputs, dict):
            if "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            loss = outputs["loss"]
        else:
            # Tuple-style outputs carry the loss first, as in Trainer.compute_loss.
            loss = outputs[0]

        outputs = CausalLMOutputWithPast(
            loss=loss,
            logits=logits
        )

        return (loss, outputs) if return_outputs else loss

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        has_labels = bool(self.label_names) and all(inputs.get(k) is not None for k in self.label_names)
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
        # is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = bool(len(self.label_names) == 0 and return_loss)

        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels or loss_without_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels or loss_without_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
            else:
                if has_labels or loss_without_labels:
                    with self.compute_loss_context_manager():
                        # 'autoregressive' evaluates on generated sequences; any other
                        # value uses the regular loss computation.
                        if self.evaluation_method == 'autoregressive':
                            loss, outputs = self.compute_generate_loss(model, inputs, self.generation_config, return_outputs=True)
                        else:
                            loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                    loss = loss.mean().detach()

                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        logits = outputs[1:]
                else:
                    loss = None
                    with self.compute_loss_context_manager():
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                    else:
                        logits = outputs
                    # TODO: this needs to be fixed and made cleaner later.
                    if self.args.past_index >= 0:
                        self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)


class Zero3SaveTrainable16bitModelTrainer(GenerateEvalTrainer):
    """This trainer monkey-patches the saving process of deepspeed zero stage 3 to save 
    only the trainable parameters.
    """
    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """Install the version-dependent save patches, then delegate to the parent.

        The patches are only applied when ``self.deepspeed`` is truthy. Which
        attribute is replaced depends on the installed transformers,
        accelerate and deepspeed versions.

        Args:
            output_dir: Forwarded unchanged to the parent ``save_model``.
            _internal_call: Forwarded unchanged to the parent ``save_model``.
        """
        if self.deepspeed:  # save only trainable weights
            if version.parse(transformers.__version__) < version.parse('4.29'):
                self.deepspeed._zero3_consolidated_16bit_state_dict = lambda: _zero3_consolidated_16bit_trainable_state_dict(self.deepspeed)
            elif version.parse(accelerate_version) < version.parse("0.22.0"):
                self.accelerator.get_state_dict = lambda model, unwrap=True: get_state_dict(self.accelerator, model, unwrap=unwrap)
            if version.parse(deepspeed.__version__) < version.parse('0.11'):
                self.deepspeed._save_checkpoint = lambda save_dir, tag, client_state={}, exclude_frozen_parameters=True: _save_trainable_checkpoint(self.deepspeed, save_dir, tag, client_state)
            else:
                original_save_checkpoint = self.deepspeed.save_checkpoint
                # save_model runs once per checkpoint. Without this guard every call
                # would wrap the previous wrapper, growing the call chain by one
                # frame per save until the recursion limit is reached.
                if not getattr(original_save_checkpoint, "_excludes_frozen_by_default", False):

                    def deepspeed_save_checkpoint(
                        save_dir,
                        tag=None,
                        client_state=None,
                        save_latest=True,
                        exclude_frozen_parameters=True
                    ):
                        # Same call as the engine's method; only the default of
                        # exclude_frozen_parameters is flipped to True.
                        return original_save_checkpoint(
                            save_dir,
                            tag,
                            {} if client_state is None else client_state,
                            save_latest,
                            exclude_frozen_parameters
                            )

                    deepspeed_save_checkpoint._excludes_frozen_by_default = True
                    self.deepspeed.save_checkpoint = deepspeed_save_checkpoint

        super().save_model(output_dir, _internal_call)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Write model weights, tokenizer and training arguments to ``output_dir``.

        Args:
            output_dir: Target directory; defaults to ``self.args.output_dir``.
                It is created if missing.
            state_dict: Optional state dict to save instead of the model's own.
        """
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        # PeftModel is only treated as supported when peft is installed.
        supported_classes = (PreTrainedModel, PreTrainedModelWrapper) \
            if not is_peft_available() else (PreTrainedModel, PreTrainedModelWrapper, PeftModel)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if isinstance(self.model, supported_classes):
            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )
        else:
            if state_dict is None:
                state_dict = self.model.state_dict()

            unwrapped = unwrap_model(self.model)
            if isinstance(unwrapped, supported_classes):
                unwrapped.save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME))
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))



# A reward function is one of three things:
#   - a string: a model ID, which is loaded as a pretrained model;
#   - a `PreTrainedModel` instance;
#   - a callable that takes a list of prompts and a list of completions and
#     returns a list of float rewards.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]

   
class MyGRPOTrainer(GRPOTrainer, Zero3SaveTrainable16bitModelTrainer):
    """ Combine Huggingface Trainer and TRL GRPOTrainer for Reinforced Fine-tuning of LLM model. 
    
    Here we make the following assumptions: 
    - The model has gone through any necessary warm-up process and can respond according to some required format.
    - There exists some environment that can evaluate query and response efficiently. The environment may encapsulate 
    a reward model or ruleset that assigns scalar reward scores.
    Then, this class focuses only on a standard RL process, consisting of three steps:
    - Rollout (handled by Zero3SaveTrainable16bitModelTrainer): 
        batch-sample responses from the latest model for the queries in the training set.
    - Evaluation: 
        calling the external ENV to output reward scores for the responses from rollout. The ENV is considered a blackbox.
    - Optimization (handled by GRPOTrainer.step): 
        forward process: call the model to re-calculate logprobs and estimate values, do all kinds of reward scaling and clipping,
            estimate the advantages, calculate the loss and log the stats.
        backward process: accelerator.backward(loss)

    Difference to GRPOTrainer:
    - remove peft_config argument, and handle peft model outside the trainer
    """
    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: GRPOConfig = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        compute_metrics = None,
        evaluation_method: str = None,
        generation_config: GenerationConfig = None,
        # peft_config: Optional["PeftConfig"] = None,
        **kwargs
    ):
        """Set up the policy/reference models, reward functions, sampling settings and (optionally) vLLM.

        Args:
            model: Either a model ID string, which is loaded with `AutoModelForCausalLM.from_pretrained` using
                `args.model_init_kwargs`, or an already instantiated model.
            reward_funcs: One reward function or a list of them. Every entry must be callable.
            args: Training configuration. When `None`, a `GRPOConfig` named `"<model name>-GRPO"` is created.
            train_dataset: Forwarded to `Zero3SaveTrainable16bitModelTrainer.__init__`.
            eval_dataset: Forwarded to `Zero3SaveTrainable16bitModelTrainer.__init__`.
            processing_class: Tokenizer for the policy. When `None`, it is loaded from
                `model.config._name_or_path` with left padding.
            reward_processing_classes: Tokenizer(s) paired with the reward functions. Only entries paired with a
                `PreTrainedModel` reward function are filled in / adjusted here.
            callbacks: Forwarded to the parent trainer.
            compute_metrics: Forwarded to the parent trainer.
            evaluation_method: Forwarded to the parent trainer.
            generation_config: Required. Supplies `max_seq_length` (used as the maximum prompt length),
                `max_new_tokens` (maximum completion length) and `temperature`; `top_p` and `top_k` are read as
                well when vLLM is used.
            **kwargs: Forwarded to the parent trainer.

        Raises:
            ValueError: If `generation_config` is missing, `torch_dtype` in `model_init_kwargs` is invalid, the
                number of reward weights or reward processing classes does not match the number of reward
                functions, the global batch size is not divisible by `args.target_generations`, or the requested
                vLLM device is not available.
            RuntimeError: If any reward function is not callable.
            ImportError: If `args.use_vllm` is set but vLLM is not installed.
        """
        # Fail fast with a clear message: the sampling settings below are read from `generation_config`
        # unconditionally, so `None` could only end in an opaque AttributeError after the model was loaded.
        if generation_config is None:
            raise ValueError(
                "`generation_config` must be provided: it supplies `max_seq_length`, `max_new_tokens` and "
                "`temperature` for generation."
            )

        # Args
        if args is None:
            model_name = model if isinstance(model, str) else model.config._name_or_path
            model_name = model_name.split("/")[-1]
            args = GRPOConfig(f"{model_name}-GRPO")

        # Models
        # Trained model
        model_init_kwargs = args.model_init_kwargs or {}
        if isinstance(model, str):
            model_id = model
            torch_dtype = model_init_kwargs.get("torch_dtype")
            if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
                pass  # torch_dtype is already a torch.dtype or "auto" or None
            elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
                torch_dtype = getattr(torch, torch_dtype)
                model_init_kwargs["torch_dtype"] = torch_dtype
            else:
                raise ValueError(
                    "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                    f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
                )
            # Disable caching if gradient checkpointing is enabled (not supported)
            model_init_kwargs["use_cache"] = (
                False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
            )
            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
        else:
            model_id = model.config._name_or_path

        # Enable gradient checkpointing if requested
        if args.gradient_checkpointing:
            model = self._enable_gradient_checkpointing(model, args)

        # Reference model
        self.beta = args.kl_coef
        if self.beta == 0.0:
            # If beta is 0.0, the reference model is not needed
            self.ref_model = None
        elif is_deepspeed_zero3_enabled():
            # Under ZeRO-3 a separate copy is loaded from the same model ID
            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
        elif is_peft_model(model):
            # If PEFT is used, the reference model is not needed since the adapter can be disabled
            # to revert to the initial model.
            self.ref_model = None
        else:
            # If PEFT configuration is not provided, create a reference model based on the initial model.
            self.ref_model = create_reference_model(model)

        # Processing class
        if processing_class is None:
            processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")

        # Reward functions
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        # we handle the model loading outside the trainer; the reward model doesn't have to be a classification model
        # Raise explicitly instead of using `assert`, so the check is not stripped under `python -O`.
        non_callable_funcs = [reward_func for reward_func in reward_funcs if not callable(reward_func)]
        if non_callable_funcs:
            raise RuntimeError("Expect reward func to be callable; got {}".format(non_callable_funcs))
        self.reward_funcs = reward_funcs

        # Reward weights
        if args.reward_weights is not None:
            if len(args.reward_weights) != len(reward_funcs):
                raise ValueError(
                    f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
                    f"functions ({len(reward_funcs)})"
                )
            self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
        else:
            self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)

        # Reward processing class
        if reward_processing_classes is None:
            reward_processing_classes = [None] * len(reward_funcs)
        elif not isinstance(reward_processing_classes, list):
            reward_processing_classes = [reward_processing_classes]
        else:
            if len(reward_processing_classes) != len(reward_funcs):
                raise ValueError("The number of reward processing classes must match the number of reward functions.")

        for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
            if isinstance(reward_func, PreTrainedModel):
                if reward_processing_class is None:
                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
                if reward_processing_class.pad_token_id is None:
                    reward_processing_class.pad_token = reward_processing_class.eos_token
                # The reward model computes the reward for the latest non-padded token in the input sequence.
                # So it's important to set the pad token ID to the padding token ID of the processing class.
                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
                reward_processing_classes[i] = reward_processing_class
        self.reward_processing_classes = reward_processing_classes

        # Data collator
        def data_collator(features):  # No data collation is needed in GRPO
            return features

        # Training arguments
        self.max_prompt_length = generation_config.max_seq_length
        self.max_completion_length = generation_config.max_new_tokens  # = |o_i| in the GRPO paper
        self.num_generations = args.num_generations  # = G in the GRPO paper
        self.target_generations = args.target_generations
        self.min_correct_generations = args.min_correct_generations
        self.min_incorrect_generations = args.min_incorrect_generations
        self.p_low = args.p_low
        self.p_high = args.p_high
        self.generation_config = generation_config
        self.use_vllm = args.use_vllm

        self.temperature = self.generation_config.temperature

        # Multi-step
        self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
        self.epsilon = args.epsilon
        # Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle.
        self._step = 0
        # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
        # `_get_train_sampler` and `_prepare_inputs`.
        self._buffered_inputs = [None] * args.gradient_accumulation_steps

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
        # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
        # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
        # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
        # This acts as a flag to indicate that the warning has already been issued.
        model.warnings_issued["estimate_tokens"] = True

        # Initialize the metrics
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
        self.log_completions = args.log_completions

        # Only the second base class is initialised explicitly; `GRPOTrainer.__init__` is not called here.
        Zero3SaveTrainable16bitModelTrainer.__init__(
            self,
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            compute_metrics=compute_metrics,
            evaluation_method=evaluation_method,
            generation_config=generation_config,
            **kwargs
        )

        # Check that per_device_train/eval_batch_size * num processes is divisible by `self.target_generations`
        # (values below 2 are never accepted because the candidate list starts at 2).
        num_processes = self.accelerator.num_processes
        global_batch_size = args.per_device_train_batch_size * num_processes
        possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
        if self.target_generations not in possible_values:
            raise ValueError(
                f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
                f"divisible by the number of generations per prompt ({self.target_generations}). Given the current train "
                f"batch size, the valid values for the number of generations are: {possible_values}."
            )
        if self.args.eval_strategy != "no":
            global_batch_size = args.per_device_eval_batch_size * num_processes
            possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
            if self.target_generations not in possible_values:
                raise ValueError(
                    f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
                    f"divisible by the number of generations per prompt ({self.target_generations}). Given the current "
                    f"eval batch size, the valid values for the number of generations are: {possible_values}."
                )

        # Ensure each process receives a unique seed to prevent duplicate completions when generating with
        # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
        # it's safer to set it in all cases.
        set_seed(args.seed, device_specific=True)

        if self.use_vllm:
            if not is_vllm_available():
                raise ImportError(
                    "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
                    "`pip install vllm` to use it."
                )

            if self.accelerator.is_main_process:
                vllm_device = self.args.vllm_device
                device_type = PartialState().default_device.type
                device_module = getattr(torch, device_type)
                if vllm_device == "auto":
                    if device_module.device_count() == 1:
                        vllm_device = f"{device_type}:0"  # particular case when training with only 1 device: share it
                    else:
                        vllm_device = f"{device_type}:{self.accelerator.num_processes}"  # take the next GPU idx
                # Check that the requested device is available
                if (
                    vllm_device.split(":")[0] == f"{device_type}"
                    and int(vllm_device.split(":")[1]) >= device_module.device_count()
                ):
                    raise ValueError(
                        f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
                        "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
                        "value lower than the number of GPUs available on your machine—typically, reducing it by one "
                        f"is sufficient. In your case: `--num_processes {device_module.device_count() - 1}`."
                    )
                # Check that the requested device is not also used for training
                if vllm_device in {f"{device_type}:{idx}" for idx in range(self.accelerator.num_processes)}:
                    warnings.warn(
                        f"The requested device {vllm_device} is also being used for training. For higher throughput "
                        "and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. "
                        "If this is intentional, you may ignore this warning but should adjust "
                        "`vllm_gpu_memory_utilization` accordingly."
                    )
                # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
                # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
                # setting (profiling_patch).
                world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
                profiling_patch = patch(
                    "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None
                )

                # For Ascend NPU (torch-npu), collective communication requires the establishment of a communication
                # group, and different processes must hold the same group number. However, multiple process groups will
                # be created internally within vLLM. This will cause the group id of the communication group on rank 0
                # to be different from that of other ranks, causing backward to hang on because the communication
                # domain cannot be established. So we need to patch it to make sure the group id of different ranks in
                # the training phase are the same.
                @contextlib.contextmanager
                def new_group_context():
                    new_group = torch.distributed.new_group
                    try:
                        # `partial` is the name this module imports (`from functools import partial`).
                        torch.distributed.new_group = partial(new_group, use_local_synchronization=True)
                        # NOTE(review): only `torch.distributed.new_group` is restored on exit; this
                        # `mem_get_info` patch stays in place afterwards — confirm that is intended.
                        torch.npu.mem_get_info = partial(torch.npu.mem_get_info, device=vllm_device)
                        yield
                    finally:
                        torch.distributed.new_group = new_group

                new_group_patch = new_group_context() if device_type == "npu" else contextlib.nullcontext()
                with world_size_patch, profiling_patch, new_group_patch:
                    self.llm = LLM(
                        model=model.name_or_path,
                        device=vllm_device,
                        gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
                        dtype=self.args.vllm_dtype,
                        # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
                        # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
                        # This is particularly useful here because we generate completions from the same prompts.
                        enable_prefix_caching=getattr(self.args, "vllm_enable_prefix_caching", True),
                        max_model_len=self.args.vllm_max_model_len,
                    )
                    if is_peft_model(model):  # move peft weights to vllm model
                        self._move_model_to_vllm()

                # Guided decoding, if enabled
                if getattr(args, "vllm_guided_decoding_regex", None) is not None:
                    guided_decoding = GuidedDecodingParams(backend="outlines", regex=args.vllm_guided_decoding_regex)
                else:
                    guided_decoding = None

                # Sampling parameters
                self.sampling_params = SamplingParams(
                    temperature=self.generation_config.temperature,
                    max_tokens=self.max_completion_length,
                    guided_decoding=guided_decoding,
                    top_p=self.generation_config.top_p,
                    top_k=self.generation_config.top_k,
                    n=self.num_generations,
                )

            self._last_loaded_step = 0  # tag to avoid useless loading during grad accumulation

            # When using vLLM, the main process is responsible for loading the model weights. This can cause process
            # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
            # synchronize all processes after vLLM has been fully initialized.
            self.accelerator.wait_for_everyone()

        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
        # self.model_accepts_loss_kwargs to False to enable scaling.
        self.model_accepts_loss_kwargs = False

        # Add tags to the model
        self.model.add_model_tags(self._tag_names)

        if self.ref_model is not None:
            if self.is_deepspeed_enabled:
                self.ref_model = self._prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        if getattr(args, "sync_ref_model", False):
            from trl.callbacks import SyncRefModelCallback

            self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))

        # Reward functions that are models are prepared by accelerate in evaluation mode
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

    def _prepare_deepspeed(self, model, accelerator):
        """Wrap ``model`` in a DeepSpeed engine and switch it to evaluation mode.

        The DeepSpeed configuration is deep-copied from ``accelerator.state.deepspeed_plugin`` so the plugin's
        own configuration is left untouched.

        Args:
            model: Module handed to ``deepspeed.initialize``; its ``config`` is inspected for ``hidden_sizes``
                or ``hidden_size``.
            accelerator: Accelerator whose DeepSpeed plugin provides the base configuration.

        Returns:
            The first item returned by ``deepspeed.initialize``, after ``eval()`` has been called on it.
        """
        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
        ds_config = copy.deepcopy(accelerator.state.deepspeed_plugin.deepspeed_config)
        zero_stage = ds_config["zero_optimization"]["stage"]

        # Drop the scheduler section so that deepspeed does not create a scheduler
        ds_config.pop("scheduler", None)

        if model is not None:
            # Prefer the largest entry of `hidden_sizes` when the config has a non-empty one
            multi_widths = getattr(model.config, "hidden_sizes", None)
            if multi_widths:
                width = max(multi_widths)
            else:
                width = getattr(model.config, "hidden_size", None)

            if zero_stage == 3 and width is not None:
                # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache
                # @ step 0: expected module 1, but got module 0`
                # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
                # NOTE(review): these dotted keys are stored at the top level of the dict, not inside the nested
                # "zero_optimization" section — confirm DeepSpeed picks them up in this form.
                ds_config["zero_optimization.reduce_bucket_size"] = width * width
                ds_config["zero_optimization.stage3_param_persistence_threshold"] = 10 * width
                ds_config["zero_optimization.stage3_prefetch_bucket_size"] = 0.9 * width * width

        # If ZeRO-3 is used, we shard both the active and reference model.
        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO
        # disabled (stage 0)
        if zero_stage != 3:
            ds_config["zero_optimization"]["stage"] = 0

        engine, *_ = deepspeed.initialize(model=model, config=ds_config)
        engine.eval()
        return engine

    def _sep_prompt_completion(self, chat):
        if isinstance(chat, list) and len(chat) > 0 and isinstance(chat[0], dict) and 'role' in chat[0]:  # lazy check List[dict]
            num_completion = 0
            for utterrance in chat[::-1]:
                if utterrance['role'].lower() in {"assistant"}:
                    num_completion += 1
                else:
                    break
            if num_completion > 0:  # set assistant message as both completion and label
                return {"prompt": chat[:-num_completion], "completion": chat[-num_completion:], "label": chat[-num_completion:]}
            else:
                return {"prompt": chat, "completion": "", "label": ""}
        elif isinstance(chat, dict) and 'messages' in chat:
            return self._sep_prompt_completion(chat['messages'])
        else:
            raise TypeError(f"Expect input to be of chat format; got {chat}")
        
    @profiling_decorator
    def _move_model_to_vllm(self):
        """Push the current policy weights into the vLLM engine.

        Every process that calls this method unwraps the model and, for a PEFT model, merges and afterwards
        unmerges the adapter; only the main process — the one that created `self.llm` in `__init__` — loads
        the resulting weights into vLLM.

        NOTE(review): the main-process guard used to wrap the whole body. With DeepSpeed ZeRO-3 the parameter
        gather performed inside `unwrap_model_for_generation` is presumably a collective operation, so
        entering it on rank 0 only (while the other ranks move on) would stall; merging/unmerging the adapter
        on a single rank would also let that rank's weights drift from the others. Confirm against the
        DeepSpeed / TRL documentation.
        """
        with unwrap_model_for_generation(
            self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
        ) as unwrapped_model:
            if is_compiled_module(unwrapped_model):
                unwrapped_model = unwrapped_model._orig_mod

            uses_peft = is_peft_model(unwrapped_model)
            if uses_peft:
                unwrapped_model.merge_adapter()
                state_dict = unwrapped_model.state_dict()
                # Remove base_model and base_layer prefixes
                state_dict = {
                    k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items()
                }
                # Remove values with adapter prefix (example: "_lora")
                adapter_prefix = getattr(unwrapped_model, "prefix", "lora")
                state_dict = {k: v for k, v in state_dict.items() if adapter_prefix not in k}
                # When module to save, remove its prefix and discard the original module
                state_dict = {
                    k.replace("modules_to_save.default.", ""): v
                    for k, v in state_dict.items()
                    if "original_module" not in k
                }
            else:
                state_dict = unwrapped_model.state_dict()

            # Only the main process holds the vLLM engine
            if self.accelerator.is_main_process:
                llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                llm_model.load_weights(state_dict.items())

            # Unmerge the adapter to restore the model to its original state.
            # This must be done after loading weights to ensure they correspond to the merged state.
            if uses_peft:
                unwrapped_model.unmerge_adapter()

    @profiling_decorator
    @torch.no_grad
    def _prepare_inputs(self, inputs: List[dict]) -> List[dict]:
        """Normalise a raw batch and attach freshly generated or buffered completions.

        Args:
            inputs (List[dict]): a non-empty list of examples. Each example either already has a "prompt" key
                or is a conversation of the form
                {
                    "messages": [{'role': 'user', 'content': '...'}, {'role': 'assistant', 'content': '...'}]
                }
                which is first split by `_sep_prompt_completion` into
                {
                    "prompt": [{"role": "user", "content": "What color is the sky?"}],
                    "completion": [{"role": "assistant", "content": "It is blue."}]
                }

        Returns:
            Whatever `_generate_and_score_completions` produced for this batch, either computed now or taken
            from the buffer filled on an earlier call (presumably prompt_ids, prompt_mask, completion_ids,
            completion_mask, ref_per_token_logps and advantages — TODO confirm against that method).
        """
        batch_is_list_of_dicts = isinstance(inputs, list) and len(inputs) > 0 and isinstance(inputs[0], dict)
        assert batch_is_list_of_dicts, RuntimeError(f"Expect inputs to be List[dict]; got {inputs}")

        # Only the first example decides whether the batch is in "messages" form
        if 'messages' in inputs[0]:
            inputs = list(map(self._sep_prompt_completion, inputs))
        has_prompt_key = "prompt" in inputs[0]  # lazy check
        assert has_prompt_key, RuntimeError(
            f"Expect each of inputs to be dict with keys prompt and completion; got {inputs}")

        # TODO check necessarity
        self.model.eval()
        if self.ref_model is not None:
            self.ref_model.eval()

        if self.control.should_evaluate:
            # In evaluation, we don't reuse completions across multiple updates, so we don't need to buffer inputs.
            return self._generate_and_score_completions(inputs)

        # Training: regenerate only when the global step is a multiple of `num_iterations`; otherwise replay
        # the batch stored for this position of the gradient-accumulation cycle.
        if self.state.global_step % self.num_iterations == 0:
            inputs = self._generate_and_score_completions(inputs)
            slot = self._step % self.args.gradient_accumulation_steps
            self._buffered_inputs[slot] = inputs
        else:
            slot = self._step % self.args.gradient_accumulation_steps
            inputs = self._buffered_inputs[slot]
        self._step += 1
        return inputs

    def _generate_and_score_completions(
        self, inputs: dict[str, Union[torch.Tensor, Any]]
    ) -> dict[str, Union[torch.Tensor, Any]]:
        device = self.accelerator.device
        prompts = [example["prompt"] for example in inputs]  # what date is it today
        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]  # [bos]user\nwhat date is it today
        prompt_inputs = self.processing_class(
            prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
        )
        prompt_inputs = super(GRPOTrainer, self)._prepare_inputs(prompt_inputs)
        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

        if self.max_prompt_length is not None:
            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
            prompt_mask = prompt_mask[:, -self.max_prompt_length :]

        # Generate completions using either vLLM or regular generation
        if self.args.use_vllm:
            # First, have main process load weights if needed
            if self.state.global_step != self._last_loaded_step:
                self._move_model_to_vllm()
                self._last_loaded_step = self.state.global_step

            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
            all_prompts_text = gather_object(prompts_text)
            if self.accelerator.is_main_process:

                # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
                # num_generations outputs for each one. This is faster than generating outputs for each duplicate
                # prompt individually.
                ordered_set_of_prompts = list(dict.fromkeys(all_prompts_text))
                with profiling_context(self, "vLLM.generate"):
                    all_outputs = self.llm.generate(
                        ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
                    )
                completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
                
            else:
                completion_ids = [None] * len(all_prompts_text)
            # Broadcast the completions from the main process to all processes, ensuring each process receives its
            # corresponding slice.
            completion_ids = broadcast_object_list(completion_ids, from_process=0)
            process_slice = slice(
                self.accelerator.process_index * len(prompts),
                (self.accelerator.process_index + 1) * len(prompts),
            )
            completion_ids = completion_ids[process_slice]

            # Pad the completions, and concatenate them with the prompts
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        else:
            # Regular generation path
            with torch.inference_mode():
                with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
                    prompt_completion_ids = unwrapped_model.generate(
                        input_ids=prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
                    )

            # Compute prompt length and extract completion ids
            prompt_length = prompt_ids.size(1)
            prompt_ids = prompt_completion_ids[:, :prompt_length]
            completion_ids = prompt_completion_ids[:, prompt_length:]

        labels = [example.get("label", "") for example in inputs]
        if self.num_generations > self.target_generations: 
        
            # Collect data for all processes
            all_labels = gather_object(labels)
            all_prompts = gather_object(prompts)
            all_prompts_text = gather_object(prompts_text)
            all_prompt_ids = gather_object(prompt_ids)
            all_prompt_mask = gather_object(prompt_mask)
            all_completion_ids = gather_object(completion_ids)

            n = int(len(prompts) * self.target_generations / self.num_generations)
            selected_all_n = int(len(all_labels) * self.target_generations / self.num_generations)
        
            if self.accelerator.is_main_process:
            
                # select subset with high diversity
                all_completions = self.processing_class.batch_decode(all_completion_ids, skip_special_tokens=True)
                _, selected_indices, initial_metrics, subset_metrics = select_diverse_subset(
                    completions=all_completions,
                    labels=all_labels,
                    n_generations=self.num_generations,
                    target_n_generations=self.target_generations,
                    min_correct=self.min_correct_generations,
                    min_incorrect=self.min_incorrect_generations,
                    tokenizer=self.processing_class
                )
                mode = "eval" if self.control.should_evaluate else "train"
                self._metrics[mode]["init_avg_switch_div"].append(initial_metrics['avg_switch_div'])
                self._metrics[mode]["init_avg_process_div"].append(initial_metrics['avg_process_div'])
                self._metrics[mode]["init_avg_length_div"].append(initial_metrics['avg_length_div'])
                self._metrics[mode]["init_avg_accuracy"].append(initial_metrics['avg_accuracy'])
                
                self._metrics[mode]["subset_avg_switch_div"].append(subset_metrics['avg_switch_div'])
                self._metrics[mode]["subset_avg_process_div"].append(subset_metrics['avg_process_div'])
                self._metrics[mode]["subset_avg_length_div"].append(subset_metrics['avg_length_div'])
                self._metrics[mode]["subset_avg_accuracy"].append(subset_metrics['avg_accuracy'])
                
                selected_labels = [all_labels[i] for i in selected_indices]
                selected_prompts = [all_prompts[i] for i in selected_indices]
                selected_prompts_text = [all_prompts_text[i] for i in selected_indices]
                selected_prompt_ids = [all_prompt_ids[i] for i in selected_indices]
                selected_completion_ids = [all_completion_ids[i] for i in selected_indices] 
                selected_prompt_mask = [all_prompt_mask[i] for i in selected_indices]
                
            else:
                
                selected_labels = [None] * selected_all_n
                selected_prompts = [None] * selected_all_n
                selected_prompts_text = [None] * selected_all_n
                selected_prompt_ids = [None] * selected_all_n
                selected_completion_ids = [None] * selected_all_n
                selected_prompt_mask = [None] * selected_all_n

            # Broadcasts the selected data to all processes
            labels = broadcast_object_list(selected_labels, from_process=0)
            prompts = broadcast_object_list(selected_prompts, from_process=0)
            prompts_text = broadcast_object_list(selected_prompts_text, from_process=0)
            prompt_ids = broadcast_object_list(selected_prompt_ids, from_process=0)
            prompt_mask = broadcast_object_list(selected_prompt_mask, from_process=0)
            completion_ids = broadcast_object_list(selected_completion_ids, from_process=0)

        
            # Each process takes its own part
            process_slice = slice(
                self.accelerator.process_index * n,
                (self.accelerator.process_index + 1) * n,
            )
            labels = labels[process_slice]
            prompts = prompts[process_slice]
            prompts_text = prompts_text[process_slice]
            prompt_ids = prompt_ids[process_slice]
            prompt_mask = prompt_mask[process_slice]
            completion_ids = completion_ids[process_slice]
            
            # Pad the completions, and concatenate them with the prompts
            prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids]
            prompt_ids = pad(prompt_ids, padding_value=self.processing_class.pad_token_id, padding_side='left')
            
            prompt_mask = [torch.tensor(mask, device=device) for mask in prompt_mask]
            prompt_mask = pad(prompt_mask, padding_value=0, padding_side='left') 
            
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
            
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        
        
        # Mask everything after the first EOS token
        is_eos = completion_ids == self.processing_class.eos_token_id
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

        # Concatenate prompt_mask with completion_mask for logit computation
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        with torch.no_grad():
            # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's
            # computation here, and use per_token_logps.detach() instead.
            if self.num_iterations > 1:
                old_per_token_logps = self._get_per_token_logps(
                    self.model, prompt_completion_ids, attention_mask, logits_to_keep
                )
            else:
                old_per_token_logps = None

            if self.beta == 0.0:
                ref_per_token_logps = None
            elif self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
                )
            else:
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    ref_per_token_logps = self._get_per_token_logps(
                        self.model, prompt_completion_ids, attention_mask, logits_to_keep
                    )

        # Decode the generated completions
        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        if is_conversational(inputs[0]):
            completions = []
            for prompt, completion in zip(prompts, completions_text):
                bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
                completions.append([{"role": "assistant", "content": bootstrap + completion}])
        else:
            completions = completions_text
        
        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        for i, (reward_func, reward_processing_class) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes)
        ):
            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                reward_func_name = f"reward {reward_func.config._name_or_path.split('/')[-1]}"
            else:
                reward_func_name = reward_func.func.__name__ if isinstance(reward_func, partial) else reward_func.__name__
            with profiling_context(self, reward_func_name):
                if isinstance(
                    reward_func, nn.Module
                ):  # Module instead of PretrainedModel for compat with compiled models
                    if is_conversational(inputs[0]):
                        messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                        texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                    else:
                        texts = [p + c for p, c in zip(prompts, completions)]
                    reward_inputs = reward_processing_class(
                        texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                    )
                    reward_inputs = super(GRPOTrainer, self)._prepare_inputs(reward_inputs)
                    with torch.inference_mode():
                        rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
                else:
                    # Repeat all input columns (but "prompt" and "completion") to match the number of generations
                    keys = [key for key in inputs[0] if key not in ["prompt", "completion", 'label']]
                    reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
                    output_reward_func = reward_func(prompt=prompts, completion=completions, label=labels, **reward_kwargs)
                    rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
        # completions may be distributed across processes
        rewards_per_func = gather(rewards_per_func)

        # Apply weights to each reward function's output and sum
        rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)

        # Compute grouped-wise rewards
        mean_grouped_rewards = rewards.view(-1, min(self.num_generations, self.target_generations)).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, min(self.num_generations, self.target_generations)).std(dim=1)

        # Normalize the rewards to compute the advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(min(self.num_generations, self.target_generations), dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(min(self.num_generations, self.target_generations), dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

        # Slice to keep only the local part of the data
        process_slice = slice(
            self.accelerator.process_index * len(prompts),
            (self.accelerator.process_index + 1) * len(prompts),
        )
        advantages = advantages[process_slice]

        # Log the metrics
        mode = "eval" if self.control.should_evaluate else "train"

        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics[mode]["completion_length"].append(completion_length)

        reward_per_func = rewards_per_func.mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.func.__name__ if isinstance(reward_func, partial) else reward_func.__name__
            self._metrics[mode][f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        self._metrics[mode]["reward"].append(rewards.mean().item())
        self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())

        if self.log_completions and self.state.global_step % self.args.logging_steps == 0:
            prompts_to_log = gather_object(prompts_text)
            completions_to_log = gather_object(completions_text)
            rewards_to_log = rewards.tolist()

            if self.accelerator.is_main_process:
                if is_rich_available():
                    print_prompt_completions_sample(
                        prompts_to_log,
                        completions_to_log,
                        rewards_to_log,
                        self.state.global_step,
                    )
                if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
                    import pandas as pd

                    # For logging
                    table = {
                        "step": [str(self.state.global_step)] * len(rewards),
                        "prompt": prompts_to_log,
                        "completion": completions_to_log,
                        "reward": rewards.tolist(),
                    }
                    df = pd.DataFrame(table)
                    wandb.log({"completions": wandb.Table(dataframe=df)})

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "old_per_token_logps": old_per_token_logps,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
        }

    @profiling_decorator
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the clipped policy-gradient loss, plus a KL penalty when ``self.beta`` is non-zero.

        Args:
            model: Model passed to ``self._get_per_token_logps``; it is switched to
                training mode via ``model.train()`` before the forward pass.
            inputs: Mapping providing ``"prompt_ids"``, ``"prompt_mask"``,
                ``"completion_ids"``, ``"completion_mask"`` and ``"advantages"``.
                ``"ref_per_token_logps"`` is read only when ``self.beta != 0.0`` and
                ``"old_per_token_logps"`` only when ``self.num_iterations > 1``.
                (Presumably ``advantages`` holds one value per sequence, since it is
                unsqueezed on dim 1 before being combined with per-token values.)
            return_outputs: Must be falsy; returning outputs is not supported.
            num_items_in_batch: Accepted for interface compatibility; not read here.

        Returns:
            The per-token loss summed over positions kept by ``completion_mask`` and
            divided by ``completion_mask.sum()``.

        Raises:
            ValueError: If ``return_outputs`` is truthy.

        Side effects:
            Appends ``"clip_ratio"`` (and ``"kl"`` when ``self.beta != 0.0``) to
            ``self._metrics["eval"]`` or ``self._metrics["train"]``, selected by
            ``self.control.should_evaluate``.
        """
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")

        model.train()

        prompt_ids, prompt_mask, completion_ids, completion_mask = (
            inputs[key] for key in ("prompt_ids", "prompt_mask", "completion_ids", "completion_mask")
        )

        def masked_mean(per_token_values):
            # Average over the positions selected by the completion mask only.
            return (per_token_values * completion_mask).sum() / completion_mask.sum()

        # Per-token log probabilities of the completion under the current model;
        # only the logits of the completion tokens are needed.
        per_token_logps = self._get_per_token_logps(
            model,
            torch.cat([prompt_ids, completion_ids], dim=1),
            torch.cat([prompt_mask, completion_mask], dim=1),
            completion_ids.size(1),
        )

        # Per-token KL term between the model and the reference log probabilities.
        use_kl = self.beta != 0.0
        if use_kl:
            ref_per_token_logps = inputs["ref_per_token_logps"]
            per_token_kl = (
                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
            )

        advantages = inputs["advantages"]

        # With num_iterations == 1, old_per_token_logps == per_token_logps, so its computation is skipped
        # (old_per_token_logps is gated on self.num_iterations > 1 in the scoring step) and the detached
        # current log probabilities are used instead.
        if self.num_iterations > 1:
            old_per_token_logps = inputs["old_per_token_logps"]
        else:
            old_per_token_logps = per_token_logps.detach()

        # Clipped surrogate objective: take the element-wise minimum of the raw
        # and the clipped ratio-weighted advantages, then negate it.
        ratio = torch.exp(per_token_logps - old_per_token_logps)
        clipped_ratio = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon)
        unclipped_term = ratio * advantages.unsqueeze(1)
        clipped_term = clipped_ratio * advantages.unsqueeze(1)
        per_token_loss = -torch.min(unclipped_term, clipped_term)
        if use_kl:
            per_token_loss = per_token_loss + self.beta * per_token_kl
        loss = masked_mean(per_token_loss)

        # Metric logging ("kl" first, then "clip_ratio").
        mode = "eval" if self.control.should_evaluate else "train"
        metrics = self._metrics[mode]

        if use_kl:
            mean_kl = masked_mean(per_token_kl)
            metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        # Fraction of masked-in tokens where the unclipped term is strictly below the clipped term.
        clip_ratio = masked_mean((unclipped_term < clipped_term).float())
        metrics["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
        return loss