# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import sys
import time
import copy
import random
import inspect
import warnings
import numpy as np
from types import MethodType
from typing import TYPE_CHECKING, Any, Optional

import torch
import accelerate
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, gather_object, is_deepspeed_available
from accelerate.utils import DistributedDataParallelKwargs
from tqdm import tqdm
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.optimization import get_scheduler
from transformers.trainer import DEFAULT_CALLBACKS
from transformers.trainer_callback import CallbackHandler
from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from trl import PPOConfig, PPOTrainer
from trl.trainer import BaseTrainer
from trl.core import PPODecorators, logprobs_from_logits, set_seed, stats_to_np, stack_dicts
from trl.models.utils import unwrap_model_for_generation
from trl.import_utils import is_torch_greater_2_0
from trl.trainer import RunningMoments
from typing_extensions import override

from ...extras import logging
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
from .generate_get_midden_status import generate_with_intermediates

import torch.nn.functional as F
from datasets import Dataset
import lmppl

if TYPE_CHECKING:
    from transformers import (
        DataCollatorWithPadding,
        PreTrainedTokenizer,
        ProcessorMixin,
        Seq2SeqTrainingArguments,
        TrainerCallback,
        AutoModel
    )
    from trl_util.modeling_value_head import AutoModelForCausalLMWithValueHead

    from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments

if is_deepspeed_available():
    import deepspeed

from transformers.integrations.deepspeed import (
    set_hf_deepspeed_config,
    unset_hf_deepspeed_config,
)


# Module-level logger from the project's logging wrapper; its ``*_rank0`` helpers
# (e.g. ``info_rank0``, ``warning_rank0``) are used throughout the trainer below.
logger = logging.get_logger(__name__)

class CustomPPOTrainer(PPOTrainer, Trainer):
    r"""Inherit PPOTrainer."""

    def __init__(
        self,
        model_args: "ModelArguments",
        training_args: "Seq2SeqTrainingArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
        callbacks: Optional[list["TrainerCallback"]],
        model: "AutoModelForCausalLMWithValueHead",
        # reward_model: Optional["AutoModelForCausalLMWithValueHead"],
        # ref_model: Optional["AutoModelForCausalLMWithValueHead"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        data_collator: "DataCollatorWithPadding",
        train_dataset: Optional["Dataset"] = None,
        eval_dataset: Optional["Dataset"] = None,
    ) -> None:
        r"""Set up the PPO trainer.

        ``PPOTrainer.__init__`` is deliberately not called. Its body is inlined below (between
        the ``######`` markers) without any reference-model or reward-model setup. Afterwards
        the Hugging Face ``Trainer`` bookkeeping is created (state, control, callback handler).
        An ``lmppl`` perplexity scorer is also loaded when
        ``finetuning_args.model_path_or_name_calculating_ppl`` is set.

        Args:
            model_args: model arguments (``model_name_or_path``, ``upcast_layernorm``, ...).
            training_args: HF training arguments; batch sizes, LR, steps and logging come from here.
            finetuning_args: PPO / diffusion-sampling options (``ppo_*``, ``sample_*``,
                ``steps_per_group``, ...).
            generating_args: generation options copied into ``self.generation_config``.
            callbacks: extra callbacks appended to ``DEFAULT_CALLBACKS``.
            model: policy model with a value head.
            tokenizer: must be a ``PreTrainedTokenizer`` or ``PreTrainedTokenizerFast``.
            processor: if given, registered through ``SaveProcessorCallback``.
            data_collator: used to build the training dataloader.
            train_dataset: ``torch.utils.data.Dataset`` or ``datasets.Dataset``; its length is
                needed when ``training_args.max_steps <= 0``.
            eval_dataset: must be ``None``.

        Raises:
            NotImplementedError: if ``eval_dataset`` is given.
            ValueError: for an unsupported tokenizer, dataset or LR-scheduler type, or when
                ``push_to_hub_if_best_kwargs`` lacks ``repo_id``.
        """
        if eval_dataset is not None:
            raise NotImplementedError("PPOTrainer does not support eval dataset yet.")

        backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
        ppo_config = PPOConfig(
            model_name=model_args.model_name_or_path,
            learning_rate=training_args.learning_rate,
            mini_batch_size=training_args.per_device_train_batch_size,
            batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
            gradient_accumulation_steps=training_args.gradient_accumulation_steps,
            ppo_epochs=finetuning_args.ppo_epochs,
            max_grad_norm=training_args.max_grad_norm,
            seed=training_args.seed,
            optimize_device_cache=True,
            target=finetuning_args.ppo_target,
            use_score_scaling=finetuning_args.ppo_score_norm,
            use_score_norm=finetuning_args.ppo_score_norm,
            whiten_rewards=finetuning_args.ppo_whiten_rewards,
            accelerator_kwargs={"step_scheduler_with_optimizer": False},
            log_with=training_args.report_to[0] if training_args.report_to else None,
            project_kwargs={"logging_dir": training_args.logging_dir},
        )

        # Add deepspeed config
        if training_args.deepspeed_plugin is not None:
            ppo_config.accelerator_kwargs["kwargs_handlers"] = [
                DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
            ]
            ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
            if ppo_config.log_with is not None:
                logger.warning_rank0("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
                ppo_config.log_with = None

        # Create optimizer and scheduler
        if training_args.max_steps > 0:
            num_training_steps = training_args.max_steps
        else:
            total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
            num_training_steps = training_args.num_train_epochs * math.ceil(
                len(train_dataset) / total_train_batch_size
            )

        optimizer = self.create_optimizer(model, training_args, finetuning_args)
        scheduler = self.create_scheduler(training_args, num_training_steps, optimizer)

        ######################
        # Inlined equivalent of:
        # PPOTrainer.__init__(
        #     self,
        #     config=ppo_config,
        #     model=model,
        #     # ref_model=ref_model,
        #     tokenizer=tokenizer,
        #     dataset=train_dataset,
        #     optimizer=optimizer,
        #     data_collator=data_collator,
        #     lr_scheduler=scheduler,
        # )
        ######################
        config = ppo_config
        self.config = config
        set_seed(config.seed)
        self.accelerator = Accelerator(
            log_with=config.log_with,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            project_config=ProjectConfiguration(**config.project_kwargs),
            **config.accelerator_kwargs,
        )
        # Same group arithmetic as get_inputs_and_midden_status, so the logged count matches
        # the number of groups that are actually sampled for model updates.
        num_groups = finetuning_args.sample_step // finetuning_args.steps_per_group
        num_update_groups = int(num_groups * finetuning_args.sample_ratio_of_groups_to_update_model)
        logger.info_rank0(
            f"Total dlm steps: {finetuning_args.sample_step}\n"
            f"Block length: {finetuning_args.sample_block_length}\n"
            f"Number of steps in each group: {finetuning_args.steps_per_group}\n"
            f"Number of groups: {num_groups}\n"
            f"Number of groups to update model: {num_update_groups}\n"
            f"Number of groups to accumulate grad: {finetuning_args.num_of_groups_to_accumulate_grad}\n"
        )

        config.world_size = self.accelerator.num_processes
        config.global_backward_batch_size = config.backward_batch_size * config.world_size
        config.global_batch_size = config.batch_size * config.world_size
        self.model = model
        # Materialized as a list: a bare ``filter`` iterator is exhausted by its first consumer,
        # so every later call that receives ``self.model_params`` would see no parameters.
        self.model_params = [p for p in self.model.parameters() if p.requires_grad]
        self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
        self.is_peft_model = getattr(self.model, "is_peft_model", False)
        config.is_encoder_decoder = self.is_encoder_decoder
        config.is_peft_model = self.is_peft_model

        is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
        self.accelerator.init_trackers(
            config.tracker_project_name,
            config=dict(trl_ppo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
            init_kwargs=config.tracker_kwargs,
        )
        self.is_using_text_environment = getattr(config, "use_text_environment", False)

        if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
            raise ValueError(
                "tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast"
            )
        self.tokenizer = tokenizer

        dataset = train_dataset
        if dataset is not None and not isinstance(dataset, (torch.utils.data.Dataset, Dataset)):
            raise ValueError("dataset must be a torch.utils.data.Dataset or datasets.Dataset")
        elif dataset is None:
            warnings.warn(
                "No dataset is provided. Make sure to set config.batch_size to the correct value before training.",
                UserWarning,
            )
        self.dataset = dataset
        self._signature_columns = None
        if self.dataset is not None:
            self.dataloader = self.prepare_dataloader(self.dataset, data_collator)
        elif self.dataset is None and self.accelerator.num_processes > 1:
            warnings.warn(
                "No dataset is provided. In a multi-GPU setting, this will lead to an error. You should"
                " prepare your dataloader yourself with `dataloader = ppo_trainer.accelerator.prepare(dataloader)`"
                " and using `torch.utils.data.DataLoader`, or pass a dataset to the `PPOTrainer`. Please "
                " refer to the documentation for more details.",
                UserWarning,
            )
            self.dataloader = None
        else:
            self.dataloader = None

        # Step 3: Initialize optimizer and data collator
        self.data_collator = None
        if optimizer is None:
            # ``Adam`` was referenced unqualified (never imported); use the torch class explicitly.
            self.optimizer = torch.optim.Adam(
                filter(lambda p: p.requires_grad, self.model.parameters()),
                lr=self.config.learning_rate,
            )
        else:
            self.optimizer = optimizer

        self.lr_scheduler = scheduler
        if self.lr_scheduler is not None:
            lr_scheduler_class = (
                torch.optim.lr_scheduler._LRScheduler
                if not is_torch_greater_2_0()
                else torch.optim.lr_scheduler.LRScheduler
            )

            if not isinstance(self.lr_scheduler, lr_scheduler_class):
                raise ValueError(
                    "lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler.LRScheduler (for torch >= 2.0)"
                )

        # Safety checkers for DS integration
        is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
            self.accelerator.state, "deepspeed_plugin"
        )

        if config.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

            if hasattr(self.model, "enable_input_require_grads"):
                self.model.enable_input_require_grads()
            else:
                # For backward compatibility with older versions of transformers
                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                self.model.pretrained_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        (
            self.model,
            self.optimizer,
            self.data_collator,
            self.dataloader,
            self.lr_scheduler,
        ) = self.accelerator.prepare(
            self.model,
            self.optimizer,
            self.data_collator,
            self.dataloader,
            self.lr_scheduler,
        )

        self.is_distributed = self.accelerator.num_processes > 1

        # init the current step
        self.current_step = 0

        # init variables for pushing model to hub
        if config.push_to_hub_if_best_kwargs:
            if "repo_id" not in config.push_to_hub_if_best_kwargs:
                raise ValueError("You have to specify repo_id in order to push the model to the hub!")
            self.push_to_hub_kwargs = config.push_to_hub_if_best_kwargs
            self.compare_step = 0
            self.highest_reward = torch.tensor(-float("inf"))

        # post process for PP
        if not getattr(self.model, "is_sequential_parallel", False):
            self.current_device = self.accelerator.device
        else:
            # These helpers are not imported at module level in this file; import them here.
            from trl.import_utils import is_npu_available, is_xpu_available

            if is_xpu_available():
                self.current_device = torch.device("xpu:0")
            elif is_npu_available():
                self.current_device = torch.device("npu:0")
            else:
                self.current_device = torch.device("cuda:0")

        PPODecorators.optimize_device_cache = self.config.optimize_device_cache

        self.running = RunningMoments(self.accelerator)
        ######################

        self.args = training_args
        self.model_args = model_args
        self.finetuning_args = finetuning_args
        # self.reward_model = reward_model
        self.current_device = get_current_device()  # patch for deepspeed training

        if self.finetuning_args.model_path_or_name_calculating_ppl:
            # The global HF DeepSpeed config is cleared while the scorer is loaded and restored
            # afterwards (presumably so lmppl loads a plain, non-ZeRO-partitioned model -- confirm).
            # The DeepSpeed plugin is only consulted when DeepSpeed is actually in use.
            if is_deepspeed_used:
                hf_deepspeed_config = self.accelerator.state.deepspeed_plugin.hf_ds_config
                unset_hf_deepspeed_config()
            try:
                self.ppl_scorer = lmppl.LM(
                    self.finetuning_args.model_path_or_name_calculating_ppl,
                    device_map=self.model.device,
                    low_cpu_mem_usage=True,
                )
            finally:
                if is_deepspeed_used:
                    set_hf_deepspeed_config(hf_deepspeed_config)
        else:
            self.ppl_scorer = None

        self.generation_config = GenerationConfig(
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
            **generating_args.to_dict(),
        )

        self.state = TrainerState()
        self.control = TrainerControl()
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
        callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
        self.callback_handler = CallbackHandler(
            callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
        )
        if self.args.max_steps > 0:
            logger.info_rank0("max_steps is given, it will override any value given in num_train_epochs")

        self.amp_context = torch.autocast(self.current_device.type)
        warnings.simplefilter("ignore")  # remove gc warnings on ref model (note: silences ALL warnings globally)

        # if finetuning_args.reward_model_type == "full":
        #     if self.is_deepspeed_enabled:
        #         if not (
        #             getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
        #             or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
        #         ):  # quantized models are already set on the correct device
        #             self.reward_model = self._prepare_deepspeed(self.reward_model)
        #     else:
        #         self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)

        # self.add_callback(FixValueHeadModelCallback)

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

        self.accelerator.wait_for_everyone()

    def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
        r"""Implement training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.

        Each step takes the next batch from ``self.dataloader``, restarting the iterator when it is
        exhausted. It runs ``self.step(batch)`` and buffers the returned stats dict.

        Every ``args.logging_steps`` steps, the local main process averages the buffered stats
        (NaN-aware) and logs them; the buffer is then cleared on every rank. Every
        ``args.save_steps`` steps a checkpoint is saved under ``output_dir``.

        Args:
            resume_from_checkpoint: must be ``None``; resuming is not supported yet.

        Raises:
            ValueError: if ``resume_from_checkpoint`` is given.
        """
        if resume_from_checkpoint is not None:
            raise ValueError("`resume_from_checkpoint` will be supported in the future version.")

        total_train_batch_size = (
            self.args.per_device_train_batch_size
            * self.args.gradient_accumulation_steps
            * self.finetuning_args.ppo_buffer_size
            * self.args.world_size
        )
        if self.args.max_steps > 0:
            num_examples = total_train_batch_size * self.args.max_steps
            num_train_epochs = sys.maxsize
            max_steps = self.args.max_steps
            steps_in_epoch = self.args.max_steps
        else:
            len_dataloader = len(self.dataloader)
            num_examples = len(self.dataset)
            num_train_epochs = self.args.num_train_epochs
            max_steps = math.ceil(num_train_epochs * len_dataloader)
            steps_in_epoch = len_dataloader

        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        logger.info_rank0("***** Running training *****")
        logger.info_rank0(f"  Num examples = {num_examples:,}")
        logger.info_rank0(f"  Num Epochs = {num_train_epochs:,}")
        logger.info_rank0(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
        logger.info_rank0(
            f"  Total train batch size (w. parallel, buffer, distributed & accumulation) = {total_train_batch_size:,}"
        )
        logger.info_rank0(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
        logger.info_rank0(f"  Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
        logger.info_rank0(f"  Total training steps = {max_steps:,}")
        logger.info_rank0(f"  Number of trainable parameters = {count_parameters(self.model)[0]:,}")

        dataiter = iter(self.dataloader)
        meters = []
        self.callback_handler.on_train_begin(self.args, self.state, self.control)

        for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
            try:
                batch = next(dataiter)
            except StopIteration:
                dataiter = iter(self.dataloader)
                batch = next(dataiter)

            # Get inputs
            self.tokenizer.padding_side = "right"  # change padding side
            # NOTE(review): nothing in this loop fills these lists, so the decoded query/response and
            # the rewards handed to log_stats below are always empty.
            queries, responses, rewards = [], [], []

            self.model.train()

            stats = self.step(batch)

            self.tokenizer.padding_side = "left"  # restore padding side
            meters.append(stats)

            if self.config.log_with is not None:
                try:
                    batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
                    batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
                    self.log_stats(stats, batch, rewards)
                except Exception:
                    logger.warning_rank0("Failed to save stats due to unknown errors.")

            self.state.global_step += 1
            self.callback_handler.on_step_end(self.args, self.state, self.control)

            if (step + 1) % self.args.logging_steps == 0:
                if self.is_local_process_zero():
                    # "loss" is always reported (NaN if no buffered stats contain it); the other keys
                    # are the union of keys across all buffered stats, in first-seen order, so a key
                    # that only appears in a later step no longer raises KeyError.
                    average_meters = {"loss": []}
                    for item in meters:
                        for k, v in item.items():
                            average_meters.setdefault(k, []).append(v)

                    logs = {k: round(np.nanmean(v), 4) for k, v in average_meters.items()}
                    logs["learning_rate"] = stats["reinforce/learning_rate"]
                    logs["epoch"] = round(step / steps_in_epoch, 2)

                    tqdm.write(str(logs))
                    logs["step"] = step
                    self.state.log_history.append(logs)
                    self.callback_handler.on_log(self.args, self.state, self.control, logs)

                # Clear the buffer on every rank: resetting it only on the local main process let the
                # other ranks keep every step's stats for the whole run.
                meters = []

            if (step + 1) % self.args.save_steps == 0:  # save checkpoint
                self.save_model(
                    os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
                )
                self.callback_handler.on_save(self.args, self.state, self.control)

            if self.control.should_epoch_stop or self.control.should_training_stop:
                break

        self.callback_handler.on_train_end(self.args, self.state, self.control)

    @override
    def create_optimizer(
        self,
        model: "AutoModelForCausalLMWithValueHead",
        training_args: "Seq2SeqTrainingArguments",
        finetuning_args: "FinetuningArguments",
    ) -> "torch.optim.Optimizer":
        r"""Create the optimizer for the trainable parameters of ``model``.

        If ``create_custom_optimizer`` returns an optimizer, it is used as-is. Otherwise the
        optimizer class and kwargs configured by ``training_args`` are used with two parameter
        groups:

        - parameters named by ``get_decay_parameter_names`` get ``training_args.weight_decay``;
        - all other trainable parameters get an explicit ``weight_decay=0.0``.

        Parameters with ``requires_grad=False`` are left out.
        """
        optimizer = create_custom_optimizer(model, training_args, finetuning_args)
        if optimizer is None:
            decay_params, nodecay_params = [], []
            decay_param_names = self.get_decay_parameter_names(model)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    if name in decay_param_names:
                        decay_params.append(param)
                    else:
                        nodecay_params.append(param)

            optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
            param_groups = [
                # Explicit 0.0 (as HF Trainer does): without it this group inherits the optimizer's
                # default weight decay (torch.optim.AdamW defaults to 1e-2), so the parameters that
                # were excluded from decay would still be decayed.
                dict(params=nodecay_params, weight_decay=0.0),
                dict(params=decay_params, weight_decay=training_args.weight_decay),
            ]
            optimizer = optim_class(param_groups, **optim_kwargs)

        return optimizer

    @override
    def create_scheduler(
        self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        r"""Build the learning-rate scheduler attached to ``optimizer``.

        ``create_custom_scheduler`` is called first and its return value is not used. A
        ``transformers`` scheduler of type ``training_args.lr_scheduler_type`` is then created,
        with warmup steps derived from ``num_training_steps``, and returned.
        """
        create_custom_scheduler(training_args, num_training_steps, optimizer)
        warmup_steps = training_args.get_warmup_steps(num_training_steps)
        return get_scheduler(
            training_args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_training_steps,
        )

    @torch.no_grad()
    def get_inputs(self, batch: dict[str, "torch.Tensor"]) -> tuple[list["torch.Tensor"], list["torch.Tensor"]]:
        r"""Generate the policy's responses for a batch of prompts.

        Returns:
            ``(queries, responses)``: per-sample CPU tensors.

            - Each query has its leading padding stripped.
            - Each response is cut after its last non-pad token. An all-padding response keeps one
              token. When ``eos_token_id == pad_token_id``, one extra position is kept so the
              EOS token survives.
        """
        pad_id = self.tokenizer.pad_token_id
        if batch["input_ids"].size(0) == 1:  # handle llama2 ppo with gradient accumulation > 1
            # Single-sample batch: drop the leading padding from every field at once.
            first_token = (batch["input_ids"][0] != pad_id).nonzero()[0].item()
            batch.update({key: value[:, first_token:] for key, value in batch.items()})

        upcast = self.model_args.upcast_layernorm
        with unwrap_model_for_generation(self.model, self.accelerator):
            policy = self.accelerator.unwrap_model(self.model)
            if upcast:
                saved_layernorm = dump_layernorm(policy)

            generated = policy.generate(
                generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
            )
            if upcast:
                restore_layernorm(policy, saved_layernorm)

        prompt_ids = batch["input_ids"]
        query = prompt_ids.detach().cpu()
        response = generated[:, prompt_ids.size(-1) :].detach().cpu()
        keep_eos = self.tokenizer.eos_token_id == pad_id

        queries, responses = [], []
        for q_row, r_row in zip(query, response):
            q_start = (q_row != pad_id).nonzero()[0].item()
            non_pad = (r_row != pad_id).nonzero()
            if len(non_pad) == 0:  # allow empty response
                r_len = 1
            else:
                # One past the last real token, plus one more to include eos when eos == pad.
                r_len = non_pad[-1].item() + (2 if keep_eos else 1)

            queries.append(q_row[q_start:])  # remove padding from left
            responses.append(r_row[:r_len])  # remove padding from right

        return queries, responses

    def get_inputs_and_midden_status(
        self,
        batch: dict[str, "torch.Tensor"],
        steps: int,
        gen_length: int,
        block_length: int,
        temperature=0.0,
    ) -> Any:
        r"""Run ``generate_with_intermediates`` on a batch of prompts and return its result.

        The ``steps`` denoising steps are split into ``steps // finetuning_args.steps_per_group``
        groups. ``int(num_groups * finetuning_args.sample_ratio_of_groups_to_update_model)`` of
        them are drawn at random (sorted ascending) and passed as
        ``selected_groups_to_update_model``.

        Each rank draws its own sample. The draws are then collected with ``gather_object``, and
        every rank adopts the one from the process with index 0, so all ranks use the same groups.

        Args:
            batch: must contain ``"input_ids"``; ``"ground_truth"`` is forwarded when present.
            steps: total number of generation steps.
            gen_length: forwarded to ``generate_with_intermediates``.
            block_length: forwarded to ``generate_with_intermediates``.
            temperature: sampling temperature, forwarded to ``generate_with_intermediates``.

        Returns:
            Whatever ``generate_with_intermediates`` returns (named ``stats`` here) -- not a
            ``(queries, responses)`` tuple.
        """
        num_groups = steps // self.finetuning_args.steps_per_group
        num_selected = int(num_groups * self.finetuning_args.sample_ratio_of_groups_to_update_model)
        selected_groups = sorted(random.sample(list(range(num_groups)), num_selected))

        # Synchronize the random choice across ranks: keep the selection made by process 0.
        gathered = gather_object([{self.accelerator.process_index: selected_groups}])
        for entry in gathered:
            if 0 in entry:
                selected_groups = entry[0]

        stats = generate_with_intermediates(
            model=self.model,
            prompt=batch["input_ids"],
            tokenizer=self.tokenizer,
            accelerator=self.accelerator,
            config=self.config,
            optimizer=self.optimizer,
            model_params=self.model_params,
            ground_truth=batch["ground_truth"] if "ground_truth" in batch.keys() else None,
            num_steps_per_group=self.finetuning_args.steps_per_group,
            sampling_size=self.finetuning_args.sampling_num,
            steps=steps,
            gen_length=gen_length,
            block_length=block_length,
            temperature=temperature,
            cfg_scale=0.,
            remasking='low_confidence',
            steps_per_group=self.finetuning_args.steps_per_group,
            sample_ratio_calculating_correlation_inside_response=self.finetuning_args.sample_ratio_calculating_correlation_inside_response,
            selected_groups_to_update_model=selected_groups,
            num_of_groups_to_accumulate_grad=self.finetuning_args.num_of_groups_to_accumulate_grad,
            ppl_scorer=self.ppl_scorer,
            rejection_sampling=self.finetuning_args.rejection_sampling,
            reward_list=self.finetuning_args.reward_list,
        )

        return stats

    @torch.no_grad()
    def get_rewards(
        self,
        queries: list["torch.Tensor"],
        responses: list["torch.Tensor"],
    ) -> list["torch.Tensor"]:
        r"""Compute scores using the configured reward source.

        ``finetuning_args.reward_model_type`` selects the source:

        - ``"api"``: each query+response is decoded and scored by ``get_rewards_from_server``.
        - ``"lora"``: the policy's adapter is temporarily switched to ``"reward"`` and
          ``self.model`` scores the batch; the ``"default"`` adapter is restored afterwards,
          even if the forward pass raises.
        - anything else: ``self.reward_model`` scores the batch.

        For the model-based paths, the score is the last output of the model (the value head) at
        index ``attention_mask.sum(-1) - 1``. That is the last real token, assuming right padding.
        It is returned as detached fp32.

        NOTE(review): ``self.reward_model`` is not assigned in ``__init__`` (that argument is
        commented out), so the ``"api"`` and non-LoRA paths fail unless it is set elsewhere. The
        returned tensor is not explicitly moved to CPU.
        """
        reward_type = self.finetuning_args.reward_model_type
        if reward_type == "api":
            token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False)
            return get_rewards_from_server(self.reward_model, messages)

        batch: dict[str, torch.Tensor] = self.prepare_model_inputs(queries, responses)
        unwrapped_model = self.accelerator.unwrap_model(self.model)
        use_lora_reward = reward_type == "lora"

        if use_lora_reward:
            replace_model(unwrapped_model, target="reward")
            reward_model = self.model
        else:
            reward_model = self.reward_model

        try:
            with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context:  # support bf16
                values = reward_model(**batch, return_dict=True, use_cache=False)[-1]
        finally:
            # Always switch back to the policy adapter so a failed reward pass cannot leave the
            # policy running with the reward adapter active.
            if use_lora_reward:
                replace_model(unwrapped_model, target="default")

        rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
        return rewards.float().detach()  # use fp32 type

    @torch.no_grad()
    def get_rewards_prm(
        self,
        midden_status: list["torch.Tensor"],
    ) -> list["torch.Tensor"]:
        r"""Score intermediate (query, response) states with a process reward model -- not implemented yet.

        The previous draft had two problems:

        - It split ``midden_status`` into queries (even indices) and responses (odd indices), then
          called ``self.reward_model(**responses, ...)``. That call could never succeed:
          ``responses`` is a list, not a mapping, so ``**`` unpacking raises ``TypeError``.
        - ``self.reward_model`` is never assigned in ``__init__``.

        The method now fails fast with an explicit error until the PRM input construction is
        written.

        Args:
            midden_status: intended to hold interleaved query/response tensors
                ``[x_0, y_0, x_1, y_1, ...]`` -- TODO confirm with the producer.

        Raises:
            NotImplementedError: always.
        """
        # Planned implementation (translated from the original TODO notes):
        #   queries = midden_status[::2]; responses = midden_status[1::2]
        #   TODO: build the reward-model inputs -- write a function that constructs the PRM batch.
        #   Reshape the generation output into [(x_0, y_0), (x_1, y_1), ..., (x_T, y_T)], then
        #   compute a reward and a shaping reward for every pair:
        #   r_1: for each pair, score (x, y_t) with the trained PRM. Every pair must be scored.
        #   r_2: for each pair, compute the internal correlation reward (as previously discussed).
        #   r_3: for each pair, also score it with the outcome reward model (with clipping).
        #   r_4: for the final output (x_T, y_T), decide whether to compute an advantage function.
        #   Finally turn the raw values into rewards and return them as detached fp32.
        raise NotImplementedError(
            "get_rewards_prm is not implemented yet: the process-reward-model inputs are still a TODO."
        )

    @torch.no_grad()
    def get_rewards_inside(self, x, r_t, mask, sample_percent=100, batch_size=16, mask_id=126336):
        r"""Score a response by the model's confidence in its own tokens under single-token masking.

        For every response position selected by ``mask``:

        1. a copy of ``cat(x, r_t)`` is made;
        2. that one position is replaced by ``mask_id``;
        3. the model's softmax probability of the original token at that position is read.

        Copies are processed ``batch_size`` at a time, with one forward pass per chunk. The
        reward is the mean of the collected probabilities, computed in float64.

        Args:
            x: prompt token ids for one sample; singleton dims are squeezed away (a 1-D result is
                assumed -- TODO confirm).
            r_t: response token ids for one sample; squeezed the same way.
            mask: selector over ``r_t``; positions where ``mask == True`` are scored.
            sample_percent: currently unused.
            batch_size: number of masked copies per forward pass.
            mask_id: token id written into the masked position. NOTE(review): 126336 is presumably
                the diffusion LM's [MASK] token id -- confirm against the tokenizer.

        Returns:
            The mean confidence as a Python float, or ``float("nan")`` when ``mask`` selects no
            position.
        """
        x = x.squeeze()
        r_t = r_t.squeeze()
        mask = mask.squeeze()
        full_seq = torch.cat((x, r_t), -1)

        # Absolute positions of the selected response tokens inside ``full_seq``. ``view(-1)``
        # keeps this 1-D even when exactly one token is selected (``squeeze()`` produced a 0-d
        # tensor there, which cannot be indexed).
        positions = mask.eq(True).nonzero(as_tuple=False).view(-1).to(full_seq.device) + x.shape[0]
        if positions.numel() == 0:
            # Nothing to score: return the mean of an empty set instead of crashing in stack().
            return float("nan")

        confidences = []
        for chunk in positions.split(batch_size):
            num = chunk.numel()
            rows = torch.arange(num, device=full_seq.device)
            masked = full_seq.unsqueeze(0).repeat(num, 1)
            masked[rows, chunk] = mask_id  # row i masks exactly one position, chunk[i]

            # A single forward pass per chunk (it previously ran once per row inside the masking
            # loop, and only the last result was used).
            logits, _, _ = self.model(masked, return_dict=True, use_cache=False)

            # Softmax only the (row, masked position) logits that are read, rather than the whole
            # (num, seq, vocab) tensor in float64.
            picked = logits[rows.to(logits.device), chunk.to(logits.device)]
            probs = F.softmax(picked.to(torch.float64), dim=-1)
            targets = full_seq[chunk].to(probs.device).unsqueeze(-1)
            confidences.append(probs.gather(-1, targets).squeeze(-1).cpu())

        return torch.cat(confidences).mean().item()


    @override
    @PPODecorators.empty_device_cache()
    def batched_forward_pass(
        self,
        model: "AutoModelForCausalLMWithValueHead",
        queries: "torch.Tensor",
        responses: "torch.Tensor",
        model_inputs: dict[str, Any],
        return_logits: bool = False,
        response_masks: Optional["torch.Tensor"] = None,
    ) -> tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
        r"""Calculate model outputs in multiple batches.

        Inputs are processed in chunks of ``self.config.mini_batch_size``. Each chunk gets
        one model forward pass. Next-token log-probabilities are computed, and a mask is
        built that keeps only the positions predicting response tokens.

        Args:
            model: model called as ``model(**inputs, return_dict=True, use_cache=False)``,
                returning ``(logits, _, values)``.
            queries: query token sequences, one per sample.
            responses: response token sequences, one per sample.
            model_inputs: dict of batched model inputs; must contain ``input_ids`` and
                ``attention_mask``.
            return_logits: if True, also return the (trimmed) logits.
            response_masks: optional per-response-token masks. Each row is applied to the
                response positions ``[start, end)`` of its sample.

        Returns:
            ``(logprobs, logits_or_None, values, masks)`` concatenated over all chunks.
            ``logits``, ``values`` and ``masks`` drop their last position so that they
            align with ``logprobs``.

        Subclass and override to inject custom behavior.
        """
        bs = len(queries)
        fbs = self.config.mini_batch_size
        all_logprobs = []
        all_logits = []
        all_masks = []
        all_values = []

        for i in range(math.ceil(bs / fbs)):
            input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
            query_batch = queries[i * fbs : (i + 1) * fbs]
            response_batch = responses[i * fbs : (i + 1) * fbs]
            response_masks_batch = None
            if response_masks is not None:
                response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
            input_ids = input_kwargs["input_ids"]
            attention_mask = input_kwargs["attention_mask"]

            with self.amp_context:  # support bf16
                logits, _, values = model(**input_kwargs, return_dict=True, use_cache=False)

            # logprobs[:, t] scores input_ids[:, t + 1], hence the one-position shift below.
            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
            masks = torch.zeros_like(attention_mask)
            masks[:, :-1] = attention_mask[:, 1:]

            for j in range(len(query_batch)):
                start = len(query_batch[j]) - 1
                if attention_mask[j, 0] == 0:  # offset left padding
                    start += attention_mask[j, :].nonzero()[0].item()
                end = start + len(response_batch[j])

                # Keep only the positions that predict response tokens.
                masks[j, :start] = 0
                masks[j, end:] = 0
                if response_masks_batch is not None:
                    # masks[j, start:end] lines up one-to-one with the response tokens, so the
                    # row's response mask is applied directly. The previous code overwrote
                    # `response_masks_batch` and then indexed a scalar, which always raised.
                    masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][: end - start]

            if return_logits:
                all_logits.append(logits)
            else:
                del logits

            all_values.append(values)
            all_logprobs.append(logprobs)
            all_masks.append(masks)

        return (
            torch.cat(all_logprobs),
            torch.cat(all_logits)[:, :-1] if return_logits else None,
            torch.cat(all_values)[:, :-1],
            torch.cat(all_masks)[:, :-1],
        )

    @override
    def save_model(self, output_dir: Optional[str] = None) -> None:
        r"""Write the model weights to ``output_dir`` (``self.args.output_dir`` when omitted).

        Without FSDP/DeepSpeed, the unwrapped model's state dict is saved on ranks where
        ``self.args.should_save`` is true.

        With FSDP/DeepSpeed, the full state dict is gathered through the accelerator; this
        gather must run on every rank. If it raises ``ValueError``, a warning is logged
        about ``stage3_gather_16bit_weights_on_model_save=false``. An empty placeholder
        is then saved and its dummy weight files removed, and ``self.model.save_checkpoint``
        writes the checkpoint instead.

        Subclass and override to inject custom behavior.
        """
        save_dir = self.args.output_dir if output_dir is None else output_dir

        sharded_engine = self.is_fsdp_enabled or self.is_deepspeed_enabled
        if not sharded_engine:
            if self.args.should_save:
                bare_model = self.accelerator.unwrap_model(self.model)
                self._save(save_dir, state_dict=bare_model.state_dict())
            return

        try:
            # The gather is collective: every rank has to reach this call.
            gathered_state = self.accelerator.get_state_dict(self.model)
            if self.args.should_save:
                self._save(save_dir, state_dict=gathered_state)
        except ValueError:
            logger.warning_rank0(
                " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
                " use zero_to_fp32.py to recover weights"
            )
            if self.args.should_save:
                self._save(save_dir, state_dict={})
            # Drop the placeholder weight files written by the empty save above.
            remove_dummy_checkpoint(self.args.should_save, save_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
            self.model.save_checkpoint(save_dir)

    def get_logits(
        self,
        midden_status_all,
        midden_status_mask,
    ):
        r"""Return, per intermediate step, the summed log-probability of the selected labels.

        For step ``k``, ``midden_status_all[k]`` is unpacked as ``(x, y)`` and
        ``midden_status_mask[k]`` selects positions. The model is run on ``x``. A
        log-softmax over the last dimension is taken at the selected positions, and the
        log-probabilities of the corresponding tokens of ``y`` are summed.

        Despite the name, the result is a list with one summed log-probability tensor per
        step, not raw logits.
        """
        summed_logprobs = []
        for step_idx, (inputs, labels) in enumerate(midden_status_all):
            step_mask = midden_status_mask[step_idx]
            step_logits, _, _ = self.model(inputs, return_dict=True, use_cache=False)
            token_logprobs = torch.nn.functional.log_softmax(step_logits[step_mask], dim=-1)
            target_ids = labels[step_mask].unsqueeze(-1)
            picked = token_logprobs.gather(-1, target_ids).squeeze(1)
            summed_logprobs.append(picked.sum(-1))
        return summed_logprobs

    # https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py

    @PPODecorators.empty_device_cache()
    def train_minibatch(
        self,
        batch
    ):
        r"""Put the model in training mode and process ``batch``.

        Delegates to ``self.get_inputs_and_midden_status``, passing the sampling settings
        from ``self.finetuning_args`` positionally in this order: ``sample_step``,
        ``sample_generate_length``, ``sample_block_length``, ``sample_temperature``.

        Returns:
            Whatever ``get_inputs_and_midden_status`` returns. ``step`` passes it to
            ``stack_dicts``, so it is expected to be a dict of statistics.
        """
        # NOTE(review): no zero_grad()/backward()/optimizer.step() happens in this method;
        # confirm that get_inputs_and_midden_status performs the parameter update.
        self.model.train()

        return self.get_inputs_and_midden_status(
            batch,
            self.finetuning_args.sample_step,
            self.finetuning_args.sample_generate_length,
            self.finetuning_args.sample_block_length,
            self.finetuning_args.sample_temperature,
        )

    @PPODecorators.empty_device_cache()
    def step(
        self,
        batch
    ):
        r"""Run one training step on ``batch`` and return the collected statistics.

        ``train_minibatch`` is called inside ``self.accelerator.accumulate(self.model)``.
        Its stats are stacked, gathered across processes when ``self.is_distributed``,
        and converted to numpy. The step receives two extra entries:
        ``reinforce/learning_rate`` (the current learning rate) and
        ``time/reinforce/optimize_step`` (the elapsed time). The learning rate is recorded
        *before* ``self.lr_scheduler`` (if any) is advanced.

        Returns:
            The statistics dict.
        """
        timing = {}
        t0 = time.time()

        all_stats = []
        with self.accelerator.accumulate(self.model):
            all_stats.append(self.train_minibatch(batch))

        timing["time/reinforce/optimize_step"] = time.time() - t0

        stats = stack_dicts(all_stats)

        # Gather/Reduce stats from all processes
        if self.is_distributed:
            stats = self.gather_stats(stats)
        stats = stats_to_np(stats)
        stats["reinforce/learning_rate"] = self.optimizer.param_groups[0]["lr"]

        stats.update(timing)

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return stats

    # Adapted from transformers.Trainer._set_signature_columns_if_needed
    def _set_signature_columns_if_needed(self):
        if self._signature_columns is None:
            # Inspect model forward signature to keep only the arguments it accepts.
            signature = inspect.signature(self.model.forward)
            self._signature_columns = list(signature.parameters.keys())
            # label => sentiment | we need query and response for logging purpose
            # answer for calculating accuracy_reward
            self._signature_columns += ["label", "query", "response", "ground_truth"]
