# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import math
import os
import textwrap
import time
from collections import defaultdict
from pathlib import Path
from typing import Callable, Optional, Union

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from accelerate import Accelerator
from accelerate.utils import broadcast, gather_object
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
    BaseImageProcessor,
    DataCollatorWithPadding,
    FeatureExtractionMixin,
    GenerationConfig,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    Trainer,
    TrainerCallback,
    TrainerControl,
    is_wandb_available,
)
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from transformers.utils import is_rich_available

from ..models.utils import unwrap_model_for_generation
from ..trainer.utils import (
    OnlineTrainerState,
    batch_generation,
    disable_dropout_in_model,
    exact_div,
    first_true_indices,
    forward,
    _get_reward as get_reward,
    prepare_deepspeed,
    print_rich_table,
    selective_log_softmax,
    truncate_response,
    batch_generation_with_hiddenstates,
    compute_BT_loss,
    _compute_BT_loss,
    sort_by_score
)

from .rloo_config import RLOOConfig
from .utils import empty_cache, generate_model_card, get_comet_experiment_url, log_table_to_comet_experiment, set_reward_model_time_step


import torch.optim as optim

import torch.distributed as dist

if is_wandb_available():
    import wandb

# Fill value used with `torch.masked_fill` to overwrite log-probabilities at padded
# response positions (indices past each sequence's computed length) in `train`.
INVALID_LOGPROB = 1.0


class RLOOTrainerFeedback_v2(Trainer):
    _tag_names = ["trl", "rloo"]
    def __init__(
        self,
        config: RLOOConfig,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ],
        rm_processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ],
        policy: nn.Module,
        ref_policy: nn.Module,
        reward_model: Union[nn.Module, Callable[[list[str]], list[float]]],
        train_dataset: Dataset,
        data_collator: Optional[DataCollatorWithPadding] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        # less commonly used
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        callbacks: Optional[list[TrainerCallback]] = None,
    ) -> None:
        """Set up an RLOO trainer whose reward model can also be updated online.

        Args:
            config: Training configuration. Derived fields (``world_size``, ``local_batch_size``,
                ``micro_batch_size``, ``batch_size``, ``mini_batch_size``,
                ``local_mini_batch_size``, ``num_total_batches``, ``run_name``, and possibly
                ``total_episodes`` / ``stop_token_id``) are written back onto this object.
            processing_class: Policy tokenizer/processor. Also used for the default collator and
                as the source of the reward model's ``config.pad_token_id``.
            rm_processing_class: Tokenizer/processor used when re-tokenizing text for the
                reward model.
            policy: Model being optimized.
            ref_policy: Reference model for the KL term; must be a different object from
                ``policy``.
            reward_model: Either an ``nn.Module`` or a callable mapping decoded texts to scores.
            train_dataset: Prompt dataset (shuffled, ``drop_last=True``).
            data_collator: Batch collator; defaults to ``DataCollatorWithPadding(processing_class)``.
            eval_dataset: Dataset wrapped into the evaluation dataloader.
            optimizers: ``(optimizer, lr_scheduler)``; ``None`` entries are filled in by
                ``create_optimizer_and_scheduler``.
            callbacks: Extra callbacks appended to the default ones.

        Raises:
            ValueError: If ``ref_policy`` is the same object as ``policy``.
            Errors from ``exact_div`` if the batch sizes are not divisible as required.
        """
        if ref_policy is policy:
            raise ValueError(
                "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
                "same as `policy`, you must pass a copy of it, or `None` if you use peft."
            )

        self.args = config
        args = config
        self.processing_class = processing_class
        self.rm_processing_class = rm_processing_class
        self.policy = policy

        # Define the collator if not provided
        if data_collator is None:
            data_collator = DataCollatorWithPadding(self.processing_class)

        # Disable `eos_token_id` and `pad_token_id` so generation emits tokens without
        # truncation / padding. `self.model` (assigned below) is this same object, so setting
        # it once here covers both names.
        self.policy.generation_config.eos_token_id = None
        self.policy.generation_config.pad_token_id = None

        self.ref_policy = ref_policy
        self.reward_model = reward_model
        # Set the reward model's pad token id from the policy tokenizer. A callable reward model
        # has no `config`, so only touch it when one exists.
        # NOTE(review): `train` passes `rm_processing_class.pad_token_id` to `get_reward`;
        # confirm the policy pad id is intended here when the two tokenizers differ.
        rm_config = getattr(self.reward_model, "config", None)
        if rm_config is not None:
            rm_config.pad_token_id = processing_class.pad_token_id

        self.train_dataset = train_dataset
        self.train_dataset_len = len(train_dataset)
        self.data_collator = data_collator
        self.eval_dataset = eval_dataset
        self.optimizer, self.lr_scheduler = optimizers
        self.optimizer_cls_and_kwargs = None  # needed for transformers >= 4.47

        # Dedicated optimizer for online reward-model updates. A plain callable reward model
        # has no parameters, so it gets no optimizer.
        self.rm_optimizer = (
            optim.Adam(self.reward_model.parameters(), lr=config.rm_lr)
            if isinstance(self.reward_model, nn.Module)
            else None
        )
        # record global eval generations
        self.global_table = defaultdict(list)

        #########
        # calculate various batch sizes
        #########
        if args.total_episodes is None:  # allow the users to define episodes in terms of epochs.
            args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
        accelerator = Accelerator()
        accelerator.gradient_accumulation_steps = args.gradient_accumulation_steps
        self.accelerator = accelerator
        args.world_size = accelerator.num_processes
        # Per-device samples drawn per update (micro batch * accumulation steps * mini batches).
        args.local_batch_size = (
            args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
        )
        args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
        args.batch_size = int(args.local_batch_size * args.world_size)  # the global batch size
        args.mini_batch_size = exact_div(
            args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
        )
        # The local batch is split into `num_mini_batches` local mini batches.
        args.local_mini_batch_size = exact_div(
            args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
        )
        args.num_total_batches = math.ceil(
            args.total_episodes / args.batch_size
        )  # we may train for more than `total_episodes`
        time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
        time_int = broadcast(time_tensor, 0).item()  # avoid different timestamps across processes
        args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
        self.local_seed = args.seed + accelerator.process_index * 100003  # Prime
        if args.num_sample_generations > 0:
            # Frequency (in batches) derived from `num_sample_generations`; previously this was
            # immediately overwritten by a hard-coded constant, making the option ineffective.
            self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
        self.local_dataloader_batch_size = exact_div(
            args.local_batch_size, args.rloo_k, "`local_batch_size` must be a multiple of rloo_k"
        )  # RLOO logic: needed because RLOO repeats the same prompt args.rloo_k times

        #########
        # setup model, optimizer, and others
        #########
        for module in [policy, ref_policy, reward_model]:
            if isinstance(module, nn.Module):
                disable_dropout_in_model(module)
        if args.stop_token and args.stop_token == "eos":
            args.stop_token_id = self.processing_class.eos_token_id
        self.model = policy

        self.create_optimizer_and_scheduler(
            num_training_steps=args.num_total_batches
        )  # note that we are calling `self.lr_scheduler.step()` manually only at the batch level

        #########
        ### trainer specifics
        #########
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
        self.control = TrainerControl()
        self.state = OnlineTrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
            stateful_callbacks=[
                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
            ],
        )

        self.current_flos = 0
        self.hp_search_backend = None
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
        # Create distant repo and output directory if needed
        self.hub_model_id = None
        if self.args.push_to_hub:
            self.init_hf_repo()
        if self.args.should_save:
            os.makedirs(self.args.output_dir, exist_ok=True)
        self.backup_model = None

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        #########
        ### setup dataloader
        #########
        self.dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.local_dataloader_batch_size,
            shuffle=True,
            collate_fn=self.data_collator,
            drop_last=True,  # needed; otherwise the last batch will be of ragged shape
        )
        # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
        # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
        torch.manual_seed(args.seed)
        self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
        torch.manual_seed(self.local_seed)  # reset the local seed again

        self.eval_dataloader = DataLoader(
            self.eval_dataset,
            batch_size=args.per_device_eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=True,
        )  # no need to shuffle eval dataset
        self.eval_dataloader = accelerator.prepare(self.eval_dataloader)

        if self.is_deepspeed_enabled:
            if isinstance(self.reward_model, nn.Module):
                self.reward_model = prepare_deepspeed(
                    self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
                )
            self.ref_policy = prepare_deepspeed(
                self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
            )
            self.deepspeed = self.model
        else:
            self.ref_policy = self.ref_policy.to(self.accelerator.device)
            if isinstance(self.reward_model, nn.Module):
                self.reward_model = self.reward_model.to(self.accelerator.device)

    def get_train_dataloader(self) -> DataLoader:
        """Hand back the training DataLoader that was built and prepared in ``__init__``."""
        prepared_train_loader = self.dataloader
        return prepared_train_loader

    def get_eval_dataloader(self) -> DataLoader:
        """Hand back the evaluation DataLoader that was built and prepared in ``__init__``."""
        prepared_eval_loader = self.eval_dataloader
        return prepared_eval_loader

    def train(self):
        args = self.args
        accelerator = self.accelerator
        optimizer = self.optimizer
        model = self.model
        self.model_wrapped = self.model
        ref_policy = self.ref_policy
        reward_model = self.reward_model
        processing_class = self.processing_class
        dataloader = self.dataloader
        device = accelerator.device
        rm_optimizer=self.rm_optimizer
        rm_processing_class=self.rm_processing_class

        def repeat_generator():
            while True:
                yield from dataloader

        iter_dataloader = iter(repeat_generator())

        generation_config = GenerationConfig(
            max_new_tokens=args.response_length,
            temperature=(args.temperature + 1e-7),
            top_k=0.0,
            top_p=1.0,
            do_sample=True,
        )
        # fix: ensure the generation config of model is right
        # model.generation_config.pad_token_id=processing_class.pad_token_id


        # fix, add pad token id
        # generation_config.eos_token_id = processing_class.eos_token_id

        # wandb metrics
        accelerator.print("===training policy===")
        start_time = time.time()
        stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
        approxkl_stats = torch.zeros(stats_shape, device=device)
        pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
        pg_loss_stats = torch.zeros(stats_shape, device=device)
        vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
        entropy_stats = torch.zeros(stats_shape, device=device)
        ratio_stats = torch.zeros(stats_shape, device=device)
        model.train()

        # trainer state initialization
        self.state.global_step = 0
        self.state.episode = 0
        self.state.max_steps = (args.num_total_batches * args.num_mini_batches) // 2
        self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
        # Compute absolute values for logging, eval, and save if given as ratio
        if args.logging_steps is not None:
            if args.logging_steps < 1:
                self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
            else:
                self.state.logging_steps = args.logging_steps
        if args.eval_steps is not None:
            if args.eval_steps < 1:
                self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
            else:
                self.state.eval_steps = args.eval_steps
        if args.save_steps is not None:
            if args.save_steps < 1:
                self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
            else:
                self.state.save_steps = args.save_steps
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control) # control the save, log, and eval step
        
        # draw one batch prompt from dataset
        for update in range(1, args.num_total_batches + 1): 
            accelerator.print(f"===current batch: {update}, total_batches :{args.num_total_batches}===")
            self.state.episode += 1 * args.batch_size 
            data = next(iter_dataloader) 
            with torch.no_grad():
                queries = data["input_ids"].to(device) #[local_batch_size/rloo_K,seq_len]
                queries = queries.repeat(args.rloo_k, 1) 
                context_length = queries.shape[1]
                responses = []
                postprocessed_responses = []
                logprobs = []
                ref_logprobs = []
                scores = []
                sequence_lengths = []
                postprocessed_query_responses = []
                rm_context_length=context_length  # padding or truncate to the same length

                accelerator.print(f"===batch generation at batch {update}===")
                # Generate responses and compute logprobs 
                with unwrap_model_for_generation(
                    self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
                ) as unwrapped_model:
                    if not args.rm_with_feedback:
                        query_responses, logitss = batch_generation(
                            unwrapped_model,
                            queries,
                            args.local_rollout_forward_batch_size,
                            processing_class.pad_token_id,
                            generation_config,
                        ) # policy model forward [batch,len_x+len_y] [batch,len(y),vocab_size] 
                        hiddenstates=None
                    else:
                        query_responses, logitss, hiddenstates, _ = batch_generation_with_hiddenstates(
                        unwrapped_model,
                        queries,
                        args.local_rollout_forward_batch_size,
                        processing_class.pad_token_id,
                        generation_config,
                        enable_lqh=args.lqh
                        ) # policy model forward [batch,len_x+len_y] [batch,len(y),vocab_size]  [batch,dimension]
                        hiddenstates=hiddenstates.detach()
            

                accelerator.print(f"===annotate rewards at batch {update}===")
                # Process responses in batches
                for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): 
                    query = queries[i : i + args.local_rollout_forward_batch_size] #[2,482]
                    query_response = query_responses[i : i + args.local_rollout_forward_batch_size] 
                    # TODO devide into local batches
                    if args.rm_with_feedback:
                        hiddenstate=hiddenstates[i : i + args.local_rollout_forward_batch_size]
                    else:
                        hiddenstate=None

                    response = query_response[:, context_length:] #[2,53] 
                    logits = logitss[i : i + args.local_rollout_forward_batch_size] #[2,53,vocab_size]
                    logprob = selective_log_softmax(logits, response) #[2,53]
                    del logits
                    empty_cache()
                    # ref model forward
                    ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
                    ref_logits = ref_output.logits[:, context_length - 1 : -1] #[2,53,vocab_size]
                    ref_logits /= args.temperature + 1e-7
                    ref_logprob = selective_log_softmax(ref_logits, response) #[2,53]
                    del ref_output, ref_logits 
                    empty_cache()

                    # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
                    postprocessed_response = response
                    if args.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                        postprocessed_response = truncate_response(
                            args.stop_token_id, processing_class.pad_token_id, response
                        ) 

                    # TODO: decode the query and the response enforce the query padding to the 
                    decoded_query=processing_class.batch_decode(query, skip_special_tokens=True)
                    decoded_response=processing_class.batch_decode(postprocessed_response, skip_special_tokens=True)
                    # force keeping the same shape as before
                    query_ids=rm_processing_class(decoded_query, padding="max_length", 
                                                  truncation=True,return_tensors="pt", 
                                                  padding_side="left", max_length=context_length)["input_ids"].to(queries.device)
                    response_ids=rm_processing_class(decoded_response, padding="max_length", 
                                                     truncation=True, return_tensors="pt", 
                                                     padding_side="right", max_length=args.response_length)["input_ids"].to(queries.device)
                    _query=query_ids # should be 16,len(x)
                    _postprocessed_response=response_ids # should be 16,len(y)

                    del decoded_query, decoded_response

                    # Response Processing 2. run reward model on the truncated responses
                    # TODO  use the rm tokenized query response
                    # postprocessed_query_response = torch.cat((query, postprocessed_response), 1) #[2,x_len+truncated_y_len]
                    postprocessed_query_response = torch.cat((_query, _postprocessed_response), 1) #[2,x_len+truncated_y_len]

                    # TODO change to rm pad token
                    # sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1 
                    sequence_length = first_true_indices(postprocessed_response == rm_processing_class.pad_token_id) - 1 

                    if isinstance(reward_model, nn.Module):
                        # TODO
                        if args.dynamic_fw:
                            set_reward_model_time_step(reward_model, (update-1)/args.num_total_batches)
                            
                        _, score, _ , _, _ = get_reward(
                            reward_model, postprocessed_query_response, rm_processing_class.pad_token_id, context_length, hiddenstate
                        ) 
                    else:
                        score = torch.tensor(
                            reward_model(
                                processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
                            ),
                            dtype=torch.float,
                        ).to(device)

                    # Store batch results
                    responses.append(response) # from policy tokenize 
                    postprocessed_responses.append(postprocessed_response) # from policy tokenzie
                    logprobs.append(logprob)
                    ref_logprobs.append(ref_logprob)
                    sequence_lengths.append(sequence_length) # from the rm tokenize
                    scores.append(score)
                    postprocessed_query_responses.append(postprocessed_query_response) # from rm tokenize (we add)


                # Concatenate all batched results, process the current batch data before used for training
                responses = torch.cat(responses, 0)
                postprocessed_responses = torch.cat(postprocessed_responses, 0)
                logprobs = torch.cat(logprobs, 0)
                ref_logprobs = torch.cat(ref_logprobs, 0)
                sequence_lengths = torch.cat(sequence_lengths, 0)

                scores = torch.cat(scores, 0)
                postprocessed_query_responses=torch.cat(postprocessed_query_responses,0)

                # clear the history hiddenstates, prepare for training
                if args.rm_with_feedback:
                    feedback_hiddenstates=torch.zeros(hiddenstates.shape[0],hiddenstates.shape[1]+1,hiddenstates.shape[2],dtype=hiddenstates.dtype,device=hiddenstates.device)

                # del (logprob, ref_logprob, hiddenstates, score)
                # TODO: drop all the inner loop variants
                del (response, postprocessed_response, logprob, ref_logprob, sequence_length, score, postprocessed_query_response, hiddenstate, hiddenstates)

                empty_cache()
                gc.collect()

                # Response Processing 3. filter response. Ensure that the sample contains stop_token_id
                # responses not passing that filter will receive a low (fixed) score
                # only query humans on responses that pass that filter
                contain_eos_token = torch.any(postprocessed_responses == processing_class.eos_token_id, dim=-1) 
                if args.missing_eos_penalty is not None:
                    scores[~contain_eos_token] -= self.args.missing_eos_penalty #[B]
                # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")

                # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw  (屏蔽padding 部分的对数概率)
                response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
                padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
                logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
                ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)

                # 4. compute rewards
                # Compute KL divergence
                kl = logprobs - ref_logprobs # token level kl = 该token的logp做差 [batch,seq_len] 

                # Normalize rewards reinforce++
                if args.normalize_reward:
                    scores = (scores - scores.mean()) / (scores.std() + 1e-8)
                    scores = torch.clamp(scores, -args.reward_clip_range, args.reward_clip_range)


                # Compute total reward with KL penalty
                if args.token_level_kl: # reinforce++
                    # Token-level KL penalty: apply KL penalty per token
                    kl_reward = -args.kl_coef * kl

                    # Get the index of the last non-padded token for each sequence
                    eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True)
                    last_reward = torch.zeros_like(kl)
                    # Ensure scores has correct shape and type
                    scores_shaped = scores.reshape(-1, 1).to(kl.dtype)
                    last_reward.scatter_(dim=1, index=eos_indices, src=scores_shaped)

                    # Combine KL reward and last reward
                    non_score_reward = kl_reward.sum(1)  # Keep this for logging
                    reward = last_reward + kl_reward
                    rlhf_reward = reward.sum(1)  # Sum across sequence length
                else:
                    # Sequence-level KL penalty: sum KL across tokens first
                    sequence_kl = kl.sum(1) # add token level kl,to get seq level kl [B,S]->[B]
                    non_score_reward = -args.kl_coef * sequence_kl
                    rlhf_reward = non_score_reward + scores # seq level reward 

                # vectorized RLOO advantages implementation
                # Group related computation, B->K,B//K,->B
                rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)  # [K,B//K]  think K as group
                baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1)  # get other seq's avg reward as baseline
                advantages = rlhf_reward - baseline # final reward for each seq
                advantages = advantages.flatten() #[K,B//K]->[B]

                # Normalize advantages reinforce++
                if args.normalize_advantage:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                empty_cache()

            accelerator.print(f"===train policy at batch {update}===")
            # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
            for ppo_epoch_idx in range(args.num_ppo_epochs): # for each epoch, train the same batch generated before, with a shuffled order
                b_inds = np.random.permutation(args.local_batch_size)
                minibatch_idx = 0
                for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): #  local_mini_batch_size
                    mini_batch_end = mini_batch_start + args.local_mini_batch_size
                    mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
                    gradient_accumulation_idx = 0
                    for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): # allocate minibatch for each device 
                        with accelerator.accumulate(model): 
                            micro_batch_end = micro_batch_start + args.per_device_train_batch_size 
                            micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] # 当前device 被分到的batch

                            # Get batch data mb->minibatch
                            mb_advantage = advantages[micro_batch_inds] #[B]
                            mb_responses = responses[micro_batch_inds] #[B,y_len]
                            mb_query_responses = query_responses[micro_batch_inds] # [B,x_len]
                            mb_logprobs = logprobs[micro_batch_inds] #[B,y_len]

                            # Forward pass
                            output = forward(model, mb_query_responses, processing_class.pad_token_id)
                            logits = output.logits[:, context_length - 1 : -1]
                            logits /= args.temperature + 1e-7 # higher temperature, lower logits, higher uncertainty
                            
                            # only update at the last epoch
                            if ppo_epoch_idx == args.num_ppo_epochs - 1 and  args.rm_with_feedback:
                                # get new outputs 
                                nw_mb_hiddenstates=output.hidden_states[-1].detach().clone()
                                # update the hiddenstates
                                feedback_hiddenstates[micro_batch_inds] = nw_mb_hiddenstates

                            # Compute new logprobs
                            new_logprobs = selective_log_softmax(logits, mb_responses) #  [B,len_y]
                            new_logprobs = torch.masked_fill(
                                new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
                            )

                            # Compute probability ratios
                            new_ratio = (new_logprobs - mb_logprobs).exp() # token level ratio  
                            new_logprobs = new_logprobs.sum(1) # new model seq level logprobs 
                            mb_logprobs = mb_logprobs.sum(1) # old model seq level logprobs
                            logprobs_diff = new_logprobs - mb_logprobs
                            ratio = torch.exp(logprobs_diff)

                            # PPO clipped loss
                            pg_losses = -mb_advantage * ratio
                            pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
                            pg_loss_max = torch.max(pg_losses, pg_losses2)
                            pg_loss = pg_loss_max.mean()

                            # Final loss
                            loss = pg_loss

                            # Optimization step
                            accelerator.backward(loss) 
                            optimizer.step() 
                            optimizer.zero_grad()
                            
                            with torch.no_grad():
                                pg_clipfrac = (pg_losses2 > pg_losses).float().mean() 
                                prob_dist = torch.nn.functional.softmax(logits, dim=-1) # [B,len(y),vocab] 
                                entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) # [B] 
                                approxkl = 0.5 * (logprobs_diff**2).mean() # logprobs_diff = new_logprobs - mb_logprobs
                                approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
                                pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                    pg_clipfrac
                                )
                                pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                                entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() # tokenlevel_entropy-> seqlevel_entropy
                                ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()
                        gradient_accumulation_idx += 1 
                    minibatch_idx += 1 

                    # del everything and empty cache
                    # fmt: off
                    del (
                        output, logits, new_logprobs, logprobs_diff, ratio, pg_losses,
                        pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
                        mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
                    )
                    # fmt: on
                    empty_cache()

            
            rm_loss=None
            if args.rm_with_feedback:
                accelerator.print(f"===train reward model at batch {update}===")
                # train RM  with the current batch
                # query_responses [batch, x+y]  hiddenstates [batch, x+y, D]
                # TODO construct preference pair
                # reshape to groups
                scores=scores.reshape(args.rloo_k,-1)
                feedback_hiddenstates=feedback_hiddenstates.reshape(args.rloo_k,-1,feedback_hiddenstates.shape[-2],feedback_hiddenstates.shape[-1])  #
                # query_responses=query_responses.reshape(args.rloo_k,-1,query_responses.shape[-1]) 
                query_responses=postprocessed_query_responses.reshape(args.rloo_k,-1,postprocessed_query_responses.shape[-1])
                sequence_lengths=sequence_lengths.reshape(args.rloo_k,-1)
                attention_mask = query_responses != rm_processing_class.pad_token_id
                # down sort in each group
                feedback_hiddenstates = sort_by_score(scores, feedback_hiddenstates)
                query_responses=sort_by_score(scores, query_responses)
                attention_mask=sort_by_score(scores, attention_mask)
                sequence_lengths=sort_by_score(scores, sequence_lengths)


                # construct the preference pair select at the first dim
                # shape as batch_size//k, ...
                chosen_feedback_hiddenstates=feedback_hiddenstates[0,:,:,:]
                chosen_query_responses=query_responses[0,:,:]
                chosen_attention_mask=attention_mask[0,:,:]
                chosen_sequence_lengths=sequence_lengths[0,:]

                # shape as batch_size//k, ...
                rejected_feedback_hiddenstates=feedback_hiddenstates[-1,:,:,:]
                rejected_query_responses=query_responses[-1,:,:]
                rejected_attention_mask=attention_mask[-1,:,:]
                rejected_sequence_lengths=sequence_lengths[-1,:]

                # shape as batch_size//k, k, ...
                if args.enable_le:
                    group_query_responses=query_responses.reshape(-1, args.rloo_k, query_responses.shape[-1])
                    group_feedback_hiddenstates=feedback_hiddenstates.reshape(-1, args.rloo_k, feedback_hiddenstates.shape[-2], feedback_hiddenstates.shape[-1])
                    
                # TODO freeze the params except MLP and Score head
                for param in reward_model.parameters():
                    param.requires_grad = False

                if hasattr(reward_model, 'model'):
                    lm_backbone=reward_model.model
                elif hasattr(reward_model, "gpt_neox"):
                    lm_backbone=reward_model.gpt_neox
                
                modules_to_freeze=["mlp","attention","adapter","mean_net", "var_net"]

                for module in modules_to_freeze:
                    target_attr=getattr(lm_backbone, module, None)
                    if target_attr:
                        for param in target_attr.parameters() :
                            param.requires_grad=True
                
                for param in reward_model.score.parameters():
                    param.requires_grad = True
                    
                # TODO compute loss
                for i in range(0, chosen_feedback_hiddenstates.shape[0], args.local_rollout_forward_batch_size):

                    inputs = {
                        "input_ids_chosen": chosen_query_responses[i : i + args.local_rollout_forward_batch_size],        
                        "attention_mask_chosen": chosen_attention_mask[i : i + args.local_rollout_forward_batch_size],   
                        "sequence_lengths_chosen": chosen_sequence_lengths[i : i + args.local_rollout_forward_batch_size],   
                        "input_ids_rejected": rejected_query_responses[i : i + args.local_rollout_forward_batch_size],      
                        "attention_mask_rejected": rejected_attention_mask[i : i + args.local_rollout_forward_batch_size], 
                        "sequence_lengths_rejected": rejected_sequence_lengths[i : i + args.local_rollout_forward_batch_size],
                        "chosen_feedback": chosen_feedback_hiddenstates[i : i + args.local_rollout_forward_batch_size],
                        "rejected_feedback": rejected_feedback_hiddenstates[i : i + args.local_rollout_forward_batch_size]
                    }

                    if args.enable_le:
                        # for group entropy loss, construct as  batch_size//k, k,
                        inputs["group_input_ids"]=group_query_responses[i : i + args.local_rollout_forward_batch_size]
                        inputs["group_feedback"]=group_feedback_hiddenstates[i : i + args.local_rollout_forward_batch_size]

                    # loss=compute_BT_loss(reward_model,inputs) 
                    rm_loss=_compute_BT_loss(reward_model,
                                            inputs,  
                                            pad_token_id=rm_processing_class.pad_token_id,
                                            context_length=rm_context_length,
                                            enable_lm=args.enable_lm,
                                            enable_le=args.enable_le,
                                            le_weight=args.le_weight)
                    accelerator.backward(rm_loss)
                    rm_optimizer.step()
                    rm_optimizer.zero_grad()

                # close the params
                for module in modules_to_freeze:
                    target_attr=getattr(lm_backbone, module, None)
                    if target_attr:
                        for param in target_attr.parameters() :
                            param.requires_grad=False

                for param in reward_model.score.parameters():
                    param.requires_grad = False

                # free the space
                del (
                    # chosen
                    chosen_feedback_hiddenstates,
                    chosen_query_responses,
                    chosen_attention_mask,
                    # reject 
                    rejected_feedback_hiddenstates,
                    rejected_query_responses,
                    rejected_attention_mask,
                    # global
                    group_query_responses,
                    group_feedback_hiddenstates,
                    feedback_hiddenstates,
                    query_responses,
                    attention_mask
                )
                del inputs

                empty_cache()


            # Compute metrics
            with torch.no_grad():
                mean_kl = kl.sum(1).mean() # [B,len(y)]-> [B]->[]
                mean_entropy = (-logprobs).sum(1).mean() # [B,len(y)]-> [B]->[]
                mean_non_score_reward = non_score_reward.mean() # [B]->[]
                eps = int(self.state.episode / (time.time() - start_time))
                metrics = {} # wandb metric, each update, report once
                metrics["eps"] = eps
                metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
                metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item() 
                metrics["objective/non_score_reward"] = (
                    self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
                )
                metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item() # [batch] gather-> [batch*8]->mean->[]
                metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item() # [batch] gather-> [batch*8]->mean->[]
                metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item() # gather [a,b,c]->[a*8,b,c]->[]  
                metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item() 

                metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item() 



                metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
                metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item() 
                metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item() 



                metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item() 
                metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item() 
                metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
                metrics["episode"] = self.state.episode

                # add rm_loss
                if rm_loss:
                    metrics["rm_loss"] = self.accelerator.gather_for_metrics(rm_loss).mean().item()


                self.state.epoch = self.state.episode / (args.rloo_k * self.train_dataset_len)  # used by self.log
                self.log(metrics)
                
            del kl, mean_kl, mean_entropy, scores

            self.lr_scheduler.step()


            self.state.global_step += 1 
            self.control = self.callback_handler.on_step_end(args, self.state, self.control)  
            
            if self.control.should_save: 
                self._save_checkpoint(model, trial=None)
                self.control = self.callback_handler.on_save(self.args, self.state, self.control)
            


            # eval 
            if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
                accelerator.print(f"===eval policy at batch {update}===")
                self.generate_completions(sampling=True)


            empty_cache()
            gc.collect()


        # HF trainer specifics
        self.control = self.callback_handler.on_train_end(args, self.state, self.control) 
        if self.control.should_save:
            self._save_checkpoint(model, trial=None, metrics=None)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control) 



    def generate_completions(self, sampling: bool = False):
        """Generate eval completions with the policy, score them with the reward model, and report.

        For each eval batch the policy generates a response. The decoded query/response text is
        re-encoded with the reward-model tokenizer and scored. Query, response, step, score and raw
        score rows are appended to ``self.global_table``. Score summary statistics are computed over
        every evaluated batch. On the main process a preview of this eval's rows is printed (when
        ``rich`` is available) and the metrics are logged when ``"wandb"`` is in ``args.report_to``.

        Args:
            sampling (`bool`, *optional*, defaults to `False`):
                If `True`, stop after the first eval batch.
        """
        args = self.args
        processing_class = self.processing_class
        rm_processing_class = self.rm_processing_class
        generation_config = GenerationConfig(
            max_new_tokens=self.args.response_length,
            temperature=(0.01 + 1e-7),
            top_k=0.0,
            top_p=1.0,
            do_sample=True,
        )
        # Gathered score arrays from every evaluated batch. Summary metrics are computed over all
        # of them, rather than being overwritten by each batch.
        all_scores = []
        all_raw_scores = []

        cur_step = self.state.episode // args.batch_size

        with unwrap_model_for_generation(
            self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
        ) as unwrapped_model:
            for batch in self.eval_dataloader:  # each device generates one batch; results are gathered
                query = batch["input_ids"]
                with torch.no_grad():
                    context_length = query.shape[1]
                    if not args.rm_with_feedback:
                        query_response, _ = batch_generation(
                            unwrapped_model,
                            query,
                            query.shape[0],
                            processing_class.pad_token_id,
                            generation_config,
                        )
                        hiddenstates = None
                    else:
                        query_response, _, hiddenstates, _ = batch_generation_with_hiddenstates(
                            unwrapped_model,
                            query,
                            query.shape[0],
                            processing_class.pad_token_id,
                            generation_config,
                            enable_lqh=args.lqh,
                        )

                    response = query_response[:, context_length:]
                    postprocessed_response = response
                    if args.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                        # replace tokens after the stop token with <pad>
                        postprocessed_response = truncate_response(
                            args.stop_token_id, processing_class.pad_token_id, response
                        )

                    # Decode once with the policy tokenizer. The text feeds both the table and the
                    # re-encoding for the reward model below.
                    decoded_query = processing_class.batch_decode(query, skip_special_tokens=True)
                    decoded_response = processing_class.batch_decode(postprocessed_response, skip_special_tokens=True)

                    gathered_query = gather_object(decoded_query)
                    self.global_table["query"].extend(gathered_query)

                    gathered_response = gather_object(
                        processing_class.batch_decode(
                            postprocessed_response, skip_special_tokens=False, clean_up_tokenization_spaces=False
                        )
                    )
                    self.global_table["model response"].extend(gathered_response)
                    self.global_table["step"].extend([cur_step] * len(gathered_response))

                    # Re-encode with the reward-model tokenizer. The query is left-padded to the
                    # policy's context length and the response right-padded to response_length.
                    rm_query_ids = rm_processing_class(
                        decoded_query,
                        padding="max_length",
                        truncation=True,
                        return_tensors="pt",
                        padding_side="left",
                        max_length=context_length,
                    )["input_ids"].to(query.device)
                    rm_response_ids = rm_processing_class(
                        decoded_response,
                        padding="max_length",
                        truncation=True,
                        return_tensors="pt",
                        padding_side="right",
                        max_length=args.response_length,
                    )["input_ids"].to(query.device)
                    postprocessed_query_response = torch.cat((rm_query_ids, rm_response_ids), 1)

                    if isinstance(self.reward_model, nn.Module):
                        _, score, raw_score, _, _ = get_reward(
                            self.reward_model,
                            postprocessed_query_response,
                            rm_processing_class.pad_token_id,
                            context_length,
                            hiddenstates,
                            return_raw_score=True,
                        )
                    else:
                        # The ids above come from the RM tokenizer, so decode them with it.
                        score = torch.tensor(
                            self.reward_model(
                                rm_processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
                            ),
                            dtype=torch.float,
                        ).to(postprocessed_query_response.device)
                        # A callable reward provides no separate raw score; report the score itself
                        # (previously `raw_score` was left undefined here -> NameError).
                        raw_score = score

                    gathered_score = self.accelerator.gather_for_metrics(score).float().cpu().numpy()
                    gathered_raw_score = self.accelerator.gather_for_metrics(raw_score).float().cpu().numpy()
                    self.global_table["score"].extend(gathered_score)
                    self.global_table["raw_score"].extend(gathered_raw_score)
                    all_scores.append(gathered_score)
                    all_raw_scores.append(gathered_raw_score)

                    # release per-batch tensors before the next generation
                    del query, query_response, response, postprocessed_response, hiddenstates
                    del rm_query_ids, rm_response_ids, postprocessed_query_response
                    del score, raw_score
                    empty_cache()

                if sampling:
                    break

        metrics = {}
        if all_scores:
            scores = np.concatenate([s.ravel() for s in all_scores])
            raw_scores = np.concatenate([s.ravel() for s in all_raw_scores])
            metrics["eval/scores"] = float(scores.mean())
            metrics["eval/raw_scores"] = float(raw_scores.mean())
            metrics["eval/scores_std"] = float(scores.std())
            metrics["eval/raw_scores_std"] = float(raw_scores.std())

        # use the main process to report the eval results
        if self.accelerator.is_main_process:
            if is_rich_available():
                df = pd.DataFrame(self.global_table)
                # global_table accumulates across eval calls; preview this eval's rows rather than
                # always showing the rows from the very first eval.
                if "step" in df.columns:
                    df = df[df["step"] == cur_step]
                print_rich_table(df.head(5))

            if "wandb" in args.report_to:
                self.log(metrics)


    # Ensure the model card is saved along with the checkpoint
    def _save_checkpoint(self, model, trial, metrics=None):
        """Refresh the model card in ``args.output_dir``, then save a checkpoint via the parent class.

        The model card name is the hub repo name when ``args.hub_model_id`` is set, otherwise the
        last component of ``args.output_dir``.

        Args:
            model: The model being checkpointed; forwarded to the parent ``_save_checkpoint``.
            trial: Hyperparameter-search trial (or ``None``); forwarded unchanged.
            metrics (`dict` or `None`, *optional*, defaults to `None`):
                Accepted so that callers passing ``metrics=None`` (e.g. the end-of-training save in
                ``train``) do not raise ``TypeError``. It is forwarded to the parent only when not
                ``None``, since some ``transformers`` versions no longer take this argument.
        """
        if self.args.hub_model_id is None:
            model_name = Path(self.args.output_dir).name
        else:
            model_name = self.args.hub_model_id.split("/")[-1]
        self.create_model_card(model_name=model_name)
        if metrics is None:
            super()._save_checkpoint(model, trial)
        else:
            super()._save_checkpoint(model, trial, metrics)

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Only the world-process-zero process writes the card; every other process returns immediately.
        The card is saved as ``README.md`` in ``args.output_dir``.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card. The trainer's own tag names (and
                ``"unsloth"`` for Unsloth models) are added to them.
        """
        if not self.is_world_process_zero():
            return

        # Only report a base model when _name_or_path is not a local directory.
        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        # Normalize tags into a fresh set: a list (as allowed by the annotation) previously crashed
        # on `.add`/`.update`, and a caller-provided set was mutated in place.
        if not tags:
            tags = set()
        elif isinstance(tags, str):
            tags = {tags}
        else:
            tags = set(tags)

        if hasattr(self.model.config, "unsloth_version"):
            tags.add("unsloth")

        tags.update(self._tag_names)

        citation = textwrap.dedent("""\
        @inproceedings{ahmadian2024back,
            title        = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
            author       = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker},
            year         = 2024,
            booktitle    = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024},
            publisher    = {Association for Computational Linguistics},
            pages        = {12248--12267},
            editor       = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar},
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="RLOO",
            trainer_citation=citation,
            paper_title="Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs",
            paper_id="2402.14740",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
