# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with Hugging Face.
"""

import os
import statistics
import uuid
from collections import defaultdict
from copy import deepcopy
from pprint import pprint

import math
import numpy as np
import torch
from omegaconf import OmegaConf, open_dict
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

from verl import DataProto
from verl.single_controller.ray import RayWorkerGroup
from verl.trainer.ppo import core_algos
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.trainer.ppo.ray_trainer import Role, WorkerType, ResourcePoolManager, reduce_metrics, _compute_response_info, \
    _timer
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from . import prime_core_algos
from .prime_core_algos import compute_return_abs_accuracy, compute_return_smoothness
from verl.trainer.ppo.core_algos import compute_reinforce_plus_plus_outcome_advantage


def _get_response_mask(data: DataProto):
    """Return the columns of ``attention_mask`` that cover the response tokens.

    The response occupies the last ``data.batch['responses'].size(-1)`` columns.
    """
    response_length = data.batch['responses'].size(-1)
    return data.batch['attention_mask'][:, -response_length:]


def _store_advantages(data: DataProto, advantages, returns):
    """Write ``advantages`` and ``returns`` into ``data.batch`` and return ``data``."""
    data.batch['advantages'] = advantages
    data.batch['returns'] = returns
    return data


def compute_advantage(data: DataProto, adv_estimator, config, dpo_acc=0.5):
    """Compute advantages and returns for ``data`` in place and return it.

    Args:
        data: batch holding at least 'responses' and 'attention_mask'. Individual
            estimators read more keys (``reinforce_plus_plus`` reads 'prompts' and 'acc').
        adv_estimator: estimator name. Supported: 'rloo', 'reinforce_plus_plus',
            'prime', 'prime_value', 'prime_value_ce', 'prime_middle_ce',
            'prime_middle_ce_linear', 'simple_upv', 'single_upv', 'adaptive_upv'.
        config: full trainer config. Except for 'reinforce_plus_plus', the estimator helpers
            receive ``config`` and ``config.actor_rollout_ref.rollout.n``.
        dpo_acc: forwarded to ``compute_prime_advantage_return`` (only for 'prime').

    Returns:
        ``data`` with ``batch['advantages']`` and ``batch['returns']`` set. Every estimator
        except 'rloo' and 'reinforce_plus_plus' also stores the metrics returned by its
        helper in ``data.meta_info['adv_metrics']``.

    Raises:
        NotImplementedError: if ``adv_estimator`` is not supported.
    """
    if adv_estimator == 'rloo':
        response_mask = _get_response_mask(data)
        advantages, returns = prime_core_algos.compute_rloo_advantage_return(data, response_mask,
                                                                             config.actor_rollout_ref.rollout.n, config)
        return _store_advantages(data, advantages, returns)

    if adv_estimator == 'reinforce_plus_plus':
        response_mask = _get_response_mask(data)
        # Outcome-only reward: put ``acc`` on the last counted response token of each row.
        # NOTE(review): indexing at ``valid_response_length - 1`` presumes right-padded responses.
        reward_tensor = torch.zeros_like(response_mask, dtype=torch.float32)
        prompt_length = data.batch['prompts'].shape[-1]
        valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(-1)
        row_index = torch.arange(0, valid_response_length.shape[0], dtype=torch.long,
                                 device=valid_response_length.device)
        reward_tensor[row_index, valid_response_length - 1] = data.batch['acc']
        advantages, returns = compute_reinforce_plus_plus_outcome_advantage(reward_tensor, response_mask,
                                                                            torch.tensor(1.0))
        return _store_advantages(data, advantages, returns)

    # The remaining estimators return (advantages, returns, metrics). The lambdas defer both
    # the helper lookup and the config access until the estimator is actually selected.
    metric_estimators = {
        'prime': lambda mask: prime_core_algos.compute_prime_advantage_return(
            data, mask, config.actor_rollout_ref.rollout.n, config, dpo_acc),
        'prime_value': lambda mask: prime_core_algos.compute_prime_value_advantage_return(
            data, mask, config.actor_rollout_ref.rollout.n, config),
        'prime_value_ce': lambda mask: prime_core_algos.compute_reasonable_prime_value_advantage_return(
            data, mask, config.actor_rollout_ref.rollout.n, config),
        # There are at least two types of value models: one uses the sigmoid directly as the
        # value with a special Bellman equation, the other converts it to a probability.
        'prime_middle_ce': lambda mask: prime_core_algos.compute_middle_prime_advantage_return(
            data, mask, config.actor_rollout_ref.rollout.n, config, linear=False),
        'prime_middle_ce_linear': lambda mask: prime_core_algos.compute_middle_prime_advantage_return(
            data, mask, config.actor_rollout_ref.rollout.n, config, linear=True),
        'simple_upv': lambda mask: prime_core_algos.compute_simple_upv(
            data, mask, config.actor_rollout_ref.rollout.n, config, linear=True),
        # Does not require a reward model: the value is computed from pi_old and pi_ref,
        # and pi_ref is updated during training.
        'single_upv': lambda mask: prime_core_algos.compute_single_upv(
            data, mask, config.actor_rollout_ref.rollout.n, config, linear=True),
        'adaptive_upv': lambda mask: prime_core_algos.compute_adaptive_upv(
            data, mask, config.actor_rollout_ref.rollout.n, config, linear=True),
    }
    estimator = metric_estimators.get(adv_estimator)
    if estimator is None:
        raise NotImplementedError(f'Unsupported adv_estimator: {adv_estimator!r}')

    advantages, returns, metrics = estimator(_get_response_mask(data))
    _store_advantages(data, advantages, returns)
    data.meta_info['adv_metrics'] = metrics
    return data


def compute_data_metrics(batch, use_critic=True):
    """Summarize advantages, returns, optional values, and sequence lengths of a batch.

    Advantage, return and value statistics include only the response positions whose
    ``attention_mask`` entry is set.

    Args:
        batch: DataProto whose ``batch`` holds 'advantages', 'returns', 'responses',
            'attention_mask' and, when ``use_critic`` is True, 'values'. Per-sample prompt
            and response lengths come from ``_compute_response_info(batch)``.
        use_critic: also report value statistics and ``critic/vf_explained_var``.

    Returns:
        dict mapping metric names to Python numbers (obtained via ``.item()``).
    """
    tensors = batch.batch
    adv = tensors['advantages']
    ret = tensors['returns']

    resp_width = tensors['responses'].shape[-1]
    attn = tensors['attention_mask']
    # The prompt part is every attention_mask column before the last ``resp_width`` ones.
    prompt_width = attn[:, :-resp_width].bool().size(-1)
    resp_mask = attn[:, -resp_width:].bool()

    info = _compute_response_info(batch)
    prompt_len = info['prompt_length']
    resp_len = info['response_length']

    masked_adv = torch.masked_select(adv, resp_mask)
    masked_ret = torch.masked_select(ret, resp_mask)

    if use_critic:
        masked_val = torch.masked_select(tensors['values'], resp_mask)
        residual_var = torch.var(masked_ret - masked_val)
        ret_var = torch.var(masked_ret)

    def as_py(t):
        return t.detach().item()

    metrics = {
        'critic/advantages/mean': as_py(torch.mean(masked_adv)),
        'critic/advantages/max': as_py(torch.max(masked_adv)),
        'critic/advantages/min': as_py(torch.min(masked_adv)),
        'critic/returns/mean': as_py(torch.mean(masked_ret)),
        'critic/returns/max': as_py(torch.max(masked_ret)),
        'critic/returns/min': as_py(torch.min(masked_ret)),
    }
    if use_critic:
        metrics['critic/values/mean'] = as_py(torch.mean(masked_val))
        metrics['critic/values/max'] = as_py(torch.max(masked_val))
        metrics['critic/values/min'] = as_py(torch.min(masked_val))
        # Explained variance of the value function (epsilon guards a zero return variance).
        metrics['critic/vf_explained_var'] = as_py(1.0 - residual_var / (ret_var + 1e-5))

    metrics.update({
        'response_length/mean': as_py(torch.mean(resp_len)),
        'response_length/max': as_py(torch.max(resp_len)),
        'response_length/min': as_py(torch.min(resp_len)),
        'response_length/clip_ratio': as_py(torch.mean(torch.eq(resp_len, resp_width).float())),
        'prompt_length/mean': as_py(torch.mean(prompt_len)),
        'prompt_length/max': as_py(torch.max(prompt_len)),
        'prompt_length/min': as_py(torch.min(prompt_len)),
        'prompt_length/clip_ratio': as_py(torch.mean(torch.eq(prompt_len, prompt_width).float())),
    })
    return metrics


def compute_timing_metrics(batch, timing_raw):
    """Build wall-clock and per-token timing metrics for one training step.

    Args:
        batch: DataProto passed to ``_compute_response_info`` to obtain per-sample prompt
            and response lengths.
        timing_raw: mapping from section name to elapsed seconds.

    Returns:
        dict containing ``timing_s/<name>`` for every entry of ``timing_raw``, followed by
        ``timing_per_token_ms/<name>`` for sections with a known token count. 'gen' is
        normalized by response tokens; 'ref', 'values', 'adv', 'update_critic' and
        'update_actor' are normalized by prompt + response tokens.
    """
    info = _compute_response_info(batch)
    prompt_tokens = torch.sum(info['prompt_length']).item()
    response_tokens = torch.sum(info['response_length']).item()
    total_tokens = prompt_tokens + response_tokens

    tokens_per_section = {'gen': response_tokens}
    for section in ['ref', 'values', 'adv', 'update_critic', 'update_actor']:
        tokens_per_section[section] = total_tokens

    result = {f'timing_s/{name}': seconds for name, seconds in timing_raw.items()}
    for name in set(tokens_per_section.keys()) & set(timing_raw.keys()):
        result[f'timing_per_token_ms/{name}'] = timing_raw[name] * 1000 / tokens_per_section[name]
    return result


class RayPRIMETrainer(RayPPOTrainer):
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    # TODO: support each role have individual ray_worker_group_cls,
    # i.e., support different backend of different role
    def __init__(self,
                 config,
                 tokenizer,
                 role_worker_mapping: dict[Role, WorkerType],
                 resource_pool_manager: ResourcePoolManager,
                 ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
                 reward_fn=None,
                 val_reward_fn=None):
        """Initialize the PRIME trainer on top of ``RayPPOTrainer``.

        All arguments are forwarded unchanged to ``RayPPOTrainer.__init__``. After that:
        the critic is disabled, and the entropy-coefficient state is set up from
        ``actor_rollout_ref.actor``.

        - If ``entropy_type == 'Adaptive'``, the current and effective coefficients start at
          0.0, and ``target_entropy`` is read from ``entropy_coeff[0]``.
        - Otherwise, both coefficients equal ``entropy_coeff``.
        """
        super().__init__(config,
                         tokenizer,
                         role_worker_mapping,
                         resource_pool_manager,
                         ray_worker_group_cls,
                         reward_fn=reward_fn,
                         val_reward_fn=val_reward_fn)

        # Disable the critic regardless of what the base class decided.
        self.use_critic = False

        actor_cfg = self.config.actor_rollout_ref.actor
        self.entropy_coeff = actor_cfg.entropy_coeff
        if actor_cfg.get('entropy_type', None) != 'Adaptive':
            self.current_entropy_coeff = actor_cfg.entropy_coeff
            self.effective_entropy_coeff = actor_cfg.entropy_coeff
        else:
            # Adaptive mode: coefficients start at zero; entropy_coeff[0] is the target entropy.
            self.current_entropy_coeff = 0.0
            self.effective_entropy_coeff = 0.0
            self.target_entropy = actor_cfg['entropy_coeff'][0]

    def _validate_config(self):
        """Validate the trainer config.

        Currently this only delegates to ``RayPPOTrainer._validate_config``.
        """
        super()._validate_config()
        # TODO: Additional PRIME-specific config checks can be added here.

    def _create_dataloader(self):
        """Build the train/val dataloaders and record the total number of training steps.

        Sets ``self.train_dataset``, ``self.train_dataloader``, ``self.val_dataset``,
        ``self.val_dataloader`` and ``self.total_training_steps``. Also writes
        ``total_training_steps`` into the actor and critic ``optim`` configs.
        """
        from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
        data_cfg = self.config.data

        def make_dataset(parquet_files):
            # Train and val datasets share every option except the source files.
            return RLHFDataset(parquet_files=parquet_files,
                               tokenizer=self.tokenizer,
                               prompt_key=data_cfg.prompt_key,
                               max_prompt_length=data_cfg.max_prompt_length,
                               filter_prompts=True,
                               return_raw_chat=data_cfg.get('return_raw_chat', False),
                               truncation=data_cfg.truncation)

        # TODO: we have to make sure the batch size is divisible by the dp size
        self.train_dataset = make_dataset(data_cfg.train_files)

        # An explicit (seeded) sampler is used for better checkpoint resume.
        if not data_cfg.shuffle:
            sampler = SequentialSampler(data_source=self.train_dataset)
        else:
            seeded_generator = torch.Generator()
            seeded_generator.manual_seed(data_cfg.get('seed', 1))
            sampler = RandomSampler(data_source=self.train_dataset, generator=seeded_generator)

        self.train_dataloader = DataLoader(dataset=self.train_dataset,
                                           batch_size=data_cfg.train_batch_size,
                                           drop_last=True,
                                           collate_fn=collate_fn,
                                           sampler=sampler)

        self.val_dataset = make_dataset(data_cfg.val_files)
        # The whole validation set is served as one batch.
        self.val_dataloader = DataLoader(dataset=self.val_dataset,
                                         batch_size=len(self.val_dataset),
                                         shuffle=True,
                                         drop_last=True,
                                         collate_fn=collate_fn)

        assert len(self.train_dataloader) >= 1
        assert len(self.val_dataloader) >= 1

        print(f'Size of train dataloader: {len(self.train_dataloader)}')
        print(f'Size of val dataloader: {len(self.val_dataloader)}')

        # Steps default to batches-per-epoch * epochs; trainer.total_training_steps overrides.
        trainer_cfg = self.config.trainer
        total_training_steps = len(self.train_dataloader) * trainer_cfg.total_epochs
        if trainer_cfg.total_training_steps is not None:
            total_training_steps = trainer_cfg.total_training_steps

        self.total_training_steps = total_training_steps
        print(f'Total training steps: {self.total_training_steps}')

        # Inject total_training_steps into the actor/critic optim configs (hacky: the
        # struct flag is set, then temporarily lifted to add the new keys).
        OmegaConf.set_struct(self.config, True)
        with open_dict(self.config):
            self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
            self.config.critic.optim.total_training_steps = total_training_steps

    def _save_checkpoint(self):
        """Save the actor, optional reward model, train dataloader and step tracker.

        Local layout under ``trainer.default_local_dir``::

            global_step_{N}/actor                actor checkpoint
            global_step_{N}/reward               reward-model checkpoint (only if ``self.use_rm``)
            global_step_{N}/data.pt              train dataloader, pickled with ``dill``
            latest_checkpointed_iteration.txt    step number of the last completed save

        When ``trainer.default_hdfs_dir`` is set, the matching remote path is passed to the
        worker groups' ``save_checkpoint``; otherwise ``None`` is passed.
        """
        trainer_cfg = self.config.trainer
        step_dir_name = f'global_step_{self.global_steps}'
        local_global_step_folder = os.path.join(trainer_cfg.default_local_dir, step_dir_name)

        def remote_path(component):
            if trainer_cfg.default_hdfs_dir is None:
                return None
            return os.path.join(trainer_cfg.default_hdfs_dir, step_dir_name, component)

        self.actor_rollout_wg.save_checkpoint(os.path.join(local_global_step_folder, 'actor'),
                                              remote_path('actor'),
                                              self.global_steps,
                                              remove_previous_ckpt=trainer_cfg.remove_previous_ckpt_in_save)

        if self.use_rm:
            self.rm_wg.save_checkpoint(os.path.join(local_global_step_folder, 'reward'),
                                       remote_path('reward'),
                                       self.global_steps,
                                       remove_previous_ckpt=trainer_cfg.remove_previous_ckpt_in_save)

        # The worker groups write their own sub-folders (possibly in another process/node),
        # so make sure the step folder exists here before saving the dataloader into it.
        os.makedirs(local_global_step_folder, exist_ok=True)
        dataloader_local_path = os.path.join(local_global_step_folder, 'data.pt')
        import dill
        torch.save(self.train_dataloader, dataloader_local_path, pickle_module=dill)

        # Latest checkpointed iteration tracker. Write to a temp file and atomically rename
        # it, so readers never observe a truncated or partially written value.
        tracker_path = os.path.join(trainer_cfg.default_local_dir, 'latest_checkpointed_iteration.txt')
        tmp_tracker_path = tracker_path + '.tmp'
        with open(tmp_tracker_path, 'w') as f:
            f.write(str(self.global_steps))
        os.replace(tmp_tracker_path, tracker_path)

    def _load_checkpoint(self):
        """Restore the actor, optional reward model and train dataloader from a local checkpoint.

        Behavior by ``trainer.resume_mode``:

        - ``'disable'``: do nothing and return 0.
        - ``'auto'``: resume from the latest checkpoint found under ``trainer.default_local_dir``.
          If none exists, print 'Training from scratch' and return 0.
        - Anything else: resume from an explicit ``global_step_*`` folder. The folder is
          ``trainer.resume_from_path`` when that is a non-empty string. Otherwise it falls back
          to the legacy behavior: a truthy ``resume_from_path`` plus an existing latest
          checkpoint resumes from the latest one, and in all other cases the path is read
          from ``resume_mode`` itself.

        Relative paths are resolved against the current working directory.
        ``self.global_steps`` is set from the folder name.

        Returns:
            0 when nothing is loaded; ``None`` after a successful resume.

        Raises:
            NotImplementedError: if ``trainer.default_hdfs_dir`` is set (loading from HDFS
                is not implemented).
        """
        trainer_cfg = self.config.trainer
        if trainer_cfg.resume_mode == 'disable':
            return 0

        if trainer_cfg.default_hdfs_dir is not None:
            # The exception used to be constructed but not raised, which left
            # ``global_step_folder`` unbound and crashed later with UnboundLocalError.
            raise NotImplementedError('load from hdfs is not implemented yet')

        checkpoint_folder = trainer_cfg.default_local_dir  # TODO: check path
        if not os.path.isabs(checkpoint_folder):
            checkpoint_folder = os.path.join(os.getcwd(), checkpoint_folder)
        global_step_folder = find_latest_ckpt_path(checkpoint_folder)  # None if no latest

        def resolve_explicit_folder(folder):
            # Validate a user-supplied checkpoint folder and make it absolute.
            assert isinstance(folder, str), "resume ckpt must be str type"
            assert 'global_step_' in folder, "resume ckpt must specify the global_steps"
            return folder if os.path.isabs(folder) else os.path.join(os.getcwd(), folder)

        if trainer_cfg.resume_mode == 'auto':
            if global_step_folder is None:
                print('Training from scratch')
                return 0
        else:
            resume_from_path = trainer_cfg.resume_from_path
            if isinstance(resume_from_path, str) and resume_from_path:
                # An explicit checkpoint path takes precedence over the latest one on disk.
                global_step_folder = resolve_explicit_folder(resume_from_path)
            elif not (resume_from_path and global_step_folder is not None):
                # Legacy behavior: the checkpoint folder is passed through ``resume_mode``.
                global_step_folder = resolve_explicit_folder(trainer_cfg.resume_mode)

        print(f'Load from checkpoint folder: {global_step_folder}')
        # Strip a trailing separator so '.../global_step_100/' still parses as step 100.
        self.global_steps = int(global_step_folder.rstrip(os.sep).split('global_step_')[-1])

        print(f'Setting global step to {self.global_steps}')
        print(f'Resuming from {global_step_folder}')

        actor_path = os.path.join(global_step_folder, 'actor')
        reward_path = os.path.join(global_step_folder, 'reward')
        self.actor_rollout_wg.load_checkpoint(actor_path,
                                              del_local_after_load=trainer_cfg.del_local_ckpt_after_load)
        if self.use_rm:
            self.rm_wg.load_checkpoint(reward_path, del_local_after_load=trainer_cfg.del_local_ckpt_after_load)

        # TODO: loading the dataloader from remote storage is not implemented yet.
        # NOTE(review): a previous comment here said the dataloader state is "no longer loaded"
        # because of unresolved bugs, yet it is loaded below -- confirm which is intended.
        # weights_only=False unpickles arbitrary objects: only load checkpoints you trust.
        dataloader_local_path = os.path.join(global_step_folder, 'data.pt')
        self.train_dataloader = torch.load(dataloader_local_path, weights_only=False)
        if isinstance(self.train_dataloader.dataset, RLHFDataset):
            self.train_dataloader.dataset.resume_dataset_state()

    def fit(self):
        """
        The training loop of PPO.
        The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.
        The light-weight advantage computation is done on the driver process.
        """
        from verl.utils.tracking import Tracking
        from omegaconf import OmegaConf

        logger = Tracking(project_name=self.config.trainer.project_name,
                          experiment_name=self.config.trainer.experiment_name,
                          default_backend=self.config.trainer.logger,
                          config=OmegaConf.to_container(self.config, resolve=True))

        self.global_steps = 0

        # load checkpoint before doing anything
        self._load_checkpoint()

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
            val_metrics = self._validate()
            pprint(f'Initial validation metrics: {val_metrics}')
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get('val_only', False):
                return

        # we start from step 1
        self.global_steps += 1
        old_batches = [] # Implement simplest sample replay, just use one old batch for now
        for epoch in range(self.config.trainer.total_epochs):
            pending_batch_list = []
            for batch_dict in self.train_dataloader:
                batch: DataProto = DataProto.from_single_dict(batch_dict)
                pending_batch_list.append(batch)
                if len(pending_batch_list) < self.config.data.oversample_factor:
                    continue
                batch = DataProto.concat(pending_batch_list)
                pending_batch_list = []

                metrics = {}
                timing_raw = {}

                # change beta according to config
                if self.config.reward_model.model.get('beta_double', None) is not None:
                    if self.global_steps==1:
                        self.beta=self.config.reward_model.model.get('beta_train',0.05)
                    real_beta = self.beta * math.pow(2, self.global_steps/self.config.reward_model.model.get('beta_double', None))
                    print('beta: '+str(real_beta))
                    self.config.reward_model.model.beta_train = real_beta

                # pop those keys for generation
                if 'multi_modal_inputs' in batch.non_tensor_batch.keys():
                    gen_batch = batch.pop(
                        batch_keys=['input_ids', 'attention_mask', 'position_ids'],
                        non_tensor_batch_keys=['raw_prompt_ids', 'multi_modal_data', 'multi_modal_inputs', 'raw_prompt',
                                               'data_source', 'reward_model'],
                    )
                else:
                    gen_batch = batch.pop(
                        batch_keys=['input_ids', 'attention_mask', 'position_ids'],
                        non_tensor_batch_keys=['raw_prompt_ids', 'raw_prompt', 'data_source', 'reward_model'],
                    )

                with _timer('step', timing_raw):
                    # generate a batch
                    with _timer('gen', timing_raw):
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

                    if self.config.algorithm.adv_estimator == 'remax':
                        with _timer('gen_max', timing_raw):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info['do_sample'] = False
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                            batch = batch.union(gen_baseline_output)
                            reward_baseline_tensor = self.reward_fn(batch)
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))

                            batch.batch['reward_baselines'] = reward_baseline_tensor

                            del gen_baseline_batch, gen_baseline_output

                    batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
                                                             dtype=object)
                    # repeat to align with repeated responses in rollout
                    batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    batch = batch.union(gen_batch_output)

                    # balance the number of valid tokens on each dp rank.
                    # Note that this breaks the order of data inside the batch.
                    # Please take care when you implement group based adv computation such as GRPO and rloo
                    # self._balance_batch(batch, metrics=metrics)

                    # compute global_valid tokens
                    batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()

                    metrics_ratio = self.count_prefix_ratio(batch)
                    print(metrics_ratio)
                    metrics.update(metrics_ratio)

                    # verify
                    with _timer('verify', timing_raw):
                        n_samples = self.config.actor_rollout_ref.rollout.n
                        scores = self.reward_fn.verify(batch, n_samples=n_samples)
                        metrics['acc'] = statistics.mean(scores)
                        metrics.update(self.metric_sources(batch))

                    # filter the batch. 1/oversample_factor samples will be kept. If there is a filter, prompts passing it will be prioritized.
                    self.penalize(batch)

                    if self.config.trainer.filter_batch_for_rm:  # is this is False, we do filtration exactly before the last compute_rm_score
                        batch = self.filter_and_downsample(scores, batch)
                    metrics['oversample_factor'] = self.config.data.oversample_factor
                    batch.meta_info['n'] = self.config.actor_rollout_ref.rollout.n
                    n_samples = self.config.actor_rollout_ref.rollout.n

                    batch.meta_info['avg_response_length'] = batch.batch[
                        'attention_mask'][:, -batch.batch['responses'].shape[-1]:].sum(dim=-1).float().mean().item()

                    # recompute old_log_probs
                    with _timer('old_log_prob', timing_raw):
                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                        batch = batch.union(old_log_prob)

                    if self.use_reference_policy:
                        # compute reference log_prob
                        with _timer('ref', timing_raw):
                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                            batch = batch.union(ref_log_prob)

                    with _timer('adv', timing_raw):
                        dpo_acc = 0.5
                        if self.use_rm:
                            update_style = self.config.reward_model.model.get('update', 'none')
                            if update_style == 'none':  # only run forward
                                reward_output = self.rm_wg.compute_rm_score(batch)

                            elif update_style == 'after':  # update and directly return the reward
                                reward_output = self.rm_wg.update_rm(batch)
                                dpo_acc = reward_output.meta_info['metrics']['reward_model/dpo_acc_continual_before']
                            elif update_style == 'before':  # update reward model, and then run forward
                                reward_output = self.rm_wg.update_rm(batch)
                                if 'metrics' in reward_output.meta_info.keys():
                                    reward_output_metrics = reduce_metrics(reward_output.meta_info['metrics'])
                                    metrics.update(reward_output_metrics)

                                reward_output = self.rm_wg.compute_rm_score(batch)
                                dpo_acc = reward_output.meta_info['metrics']['reward_model/dpo_acc_continual']
                            elif update_style == 'reverse':  # run forward to calculate statistics, then update reward model
                                reward_output = self.rm_wg.compute_rm_score(batch)
                                # broadcast q and acc tensor to each result
                                bc_td = DataProto.from_dict(
                                    tensors={
                                        'Q_bc':
                                            reward_output.batch['q'].sum(dim=-1).view(-1, n_samples).unsqueeze(
                                                1).expand(-1, n_samples, -1).reshape(-1, n_samples),
                                        'acc_bc':
                                            batch.batch['acc'].view(-1, n_samples).unsqueeze(1).expand(
                                                -1, n_samples, -1).reshape(-1, n_samples)
                                    })
                                batch = batch.union(bc_td)
                                reward_output = self.rm_wg.update_rm(batch)
                                dpo_acc = reward_output.meta_info['metrics']['reward_model/dpo_acc_continual_before']
                            else:
                                raise NotImplementedError
                            batch = batch.union(reward_output)
                            if not self.config.trainer.filter_batch_for_rm:
                                batch = self.filter_and_downsample(scores, batch)
                            if 'metrics' in reward_output.meta_info.keys():
                                reward_output_metrics = reduce_metrics(reward_output.meta_info['metrics'])
                                metrics.update(reward_output_metrics)

                        # compute advantages, executed on the driver process
                        batch = compute_advantage(batch,
                                                  adv_estimator=self.config.algorithm.adv_estimator,
                                                  config=self.config,
                                                  dpo_acc=dpo_acc)
                        if 'adv_metrics'in batch.meta_info:
                            metrics.update(batch.meta_info['adv_metrics'])

                        # compute return accuracy here to see if the reward model is working fine
                        metrics.update({
                            'reward_model/return_acc':
                                compute_return_abs_accuracy(batch.batch['returns'], batch.batch['acc']).item()
                        })

                        # check the smoothness of the return to see if the value model is working fine
                        metrics.update(
                            {'reward_model/td0_loss': compute_return_smoothness(batch.batch['returns']).item()})

                    if self.config.algorithm.adv_estimator == 'single_upv':
                        print('extracting pi_old state dict')
                        fsdp_state = self.actor_rollout_wg.get_actor_state_dict()

                    # update actor. if warmup is toggled on, skip this step
                    # because the batch might be changed, use another variable pointing at the current batch
                    batch_ = batch
                    if self.config.algorithm.get('warmup',False) == False:
                        with _timer('update_actor', timing_raw):
                            # ppo epoch logic is placed here
                            ppo_epoch=0
                            while True:
                                if self.config.actor_rollout_ref.actor.ppo_epochs>=1 and ppo_epoch>=self.config.actor_rollout_ref.actor.ppo_epochs:
                                    break
                                ppo_epoch+=1
                                batch.meta_info['entropy_coeff'] = self.effective_entropy_coeff

                                # if self.config.actor_rollout_ref.actor.get('shuffle', False): # warning: this is an in-place operation!
                                #     batch = batch.reorder(torch.randperm(len(batch)))

                                # If ppo_epoch is not 1, extract old batch for training
                                # An additional strategy: based on entropy, don't switch samples under high entropy, try to exploit quickly. Switch samples under low entropy
                                if ppo_epoch>1 and len(old_batches)>0 and actor_output_metrics['actor/entropy_loss']<self.target_entropy:
                                    batch = old_batches[-((ppo_epoch-1)%len(old_batches))-1]

                                actor_output = self.actor_rollout_wg.update_actor(batch)
                                actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
                                # Set entropy_coef based on entropy
                                if self.config.actor_rollout_ref.actor.get('entropy_type',None) == 'Adaptive':
                                    cur_entropy = actor_output_metrics['actor/entropy_loss']
                                    if cur_entropy<self.target_entropy:
                                        self.current_entropy_coeff += self.config.actor_rollout_ref.actor['entropy_coeff'][1]
                                        self.effective_entropy_coeff = self.current_entropy_coeff
                                        self.current_entropy_coeff = max(min(self.current_entropy_coeff, self.config.actor_rollout_ref.actor['entropy_coeff'][2]), 0)
                                        # self.current_entropy_coeff = max(min(self.current_entropy_coeff, 1),0)

                                        # self.config.actor_rollout_ref.actor.entropy_coeff = self.current_entropy_coeff
                                    else:
                                        self.current_entropy_coeff -= self.config.actor_rollout_ref.actor['entropy_coeff'][1]
                                        self.current_entropy_coeff = max(min(self.current_entropy_coeff, self.config.actor_rollout_ref.actor['entropy_coeff'][2]), 0)
                                        self.effective_entropy_coeff = 0
                                        # self.config.actor_rollout_ref.actor.entropy_coeff=0

                                # If ppo_epoch is not an integer, only allow exit when ppo_kl>=this value.
                                if self.config.actor_rollout_ref.actor.ppo_epochs<1 and actor_output_metrics['actor/ppo_kl_exact']>=self.config.actor_rollout_ref.actor.ppo_epochs:
                                    break
                                # Allow at most 4 epochs
                                if ppo_epoch>=self.config.actor_rollout_ref.actor.ppo_epochs_max:
                                    break
                        metrics.update(actor_output_metrics)
                        metrics['ppo_epoch']=ppo_epoch
                    batch = batch_

                    old_batches.append(batch)
                    old_batches = old_batches[-8:]

                    # Some methods need to update ref model
                    if self.config.algorithm.adv_estimator == 'single_upv':
                        print('updating reference model to pi_old')
                        self.ref_policy_wg.load_ref_state_dict(fsdp_state)

                        # validate
                    if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
                        self.global_steps % self.config.trainer.test_freq == 0:
                        with _timer('testing', timing_raw):
                            val_metrics: dict = self._validate()
                        metrics.update(val_metrics)

                    if self.config.trainer.save_freq > 0 and \
                            self.global_steps % self.config.trainer.save_freq == 0:
                        with _timer('save_checkpoint', timing_raw):
                            self._save_checkpoint()

                # collect metrics
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))

                # TODO: make a canonical logger that supports various backend
                logger.log(data=metrics, step=self.global_steps)

                self.global_steps += 1

                if self.global_steps >= self.total_training_steps:

                    # perform validation after training
                    if self.val_reward_fn is not None:
                        val_metrics = self._validate()
                        pprint(f'Final validation metrics: {val_metrics}')
                        logger.log(data=val_metrics, step=self.global_steps)
                    if self.config.trainer.save_freq > 0 and \
                            (self.global_steps - 1) % self.config.trainer.save_freq != 0:
                        with _timer('save_checkpoint', timing_raw):
                            self._save_checkpoint()
                    return

    def count_prefix_ratio(self, batch):
        """Measure how much of the responses is a prefix shared with a sibling sample.

        Rows of ``batch`` are walked in consecutive groups of
        ``self.config.actor_rollout_ref.rollout.n`` samples (presumably the samples
        generated for one prompt). For every pair of responses inside a group, the
        longest common token prefix is marked on both responses. Marked positions that
        are padding (``attention_mask == 0`` over the response part) are then dropped.

        Args:
            batch: object exposing ``batch['responses']`` (2-D, row per sample),
                ``batch['attention_mask']`` whose last ``responses.size(1)`` columns
                cover the response, and ``len(batch)``.

        Returns:
            ``{'prefix_ratio': marked_valid_tokens / valid_response_tokens}``; the
            ratio is ``0.0`` when there are no valid response tokens at all.
        """
        n_samples = self.config.actor_rollout_ref.rollout.n
        responses = batch.batch['responses']
        response_len = responses.size(1)
        attention_mask = batch.batch['attention_mask'][:, -response_len:]
        prefix_mask = torch.zeros_like(attention_mask, dtype=torch.bool)
        num_rows = len(batch)
        for start in range(0, num_rows, n_samples):
            # The last group may be short when the batch size is not a multiple of n.
            end = min(start + n_samples, num_rows)
            for a in range(start, end):
                for b in range(a + 1, end):
                    # cumprod of the equality mask stays 1 only up to the first mismatch,
                    # so its sum is the length of the common prefix.
                    prefix_len = int((responses[a] == responses[b]).cumprod(dim=0).sum())
                    prefix_mask[[a, b], :prefix_len] = True
        # Never count padding positions as shared prefix.
        prefix_mask[attention_mask == 0] = False
        total = attention_mask.sum().item()
        if total == 0:
            return {'prefix_ratio': 0.0}
        return {
            'prefix_ratio': prefix_mask.sum().item() / total
        }

    def metric_sources(self, batch):
        """Average the per-sample ``acc`` values separately for each data source.

        Args:
            batch: object exposing ``non_tensor_batch['data_source']`` (one string per
                sample) and ``batch['acc']`` (one value per sample).

        Returns:
            A ``defaultdict(list)`` mapping ``'train_acc/<source>'`` to the mean
            accuracy of the samples from that source, in first-seen order.
        """
        data_sources = batch.non_tensor_batch['data_source']
        accuracies = batch.batch['acc'].cpu().tolist()
        grouped = defaultdict(list)
        for source, value in zip(data_sources, accuracies):
            grouped['train_acc/' + source].append(value)
        return defaultdict(list, {name: statistics.mean(values) for name, values in grouped.items()})

    def penalize(self, batch):
        """Apply the configured heuristic penalties to ``batch.batch['acc']`` in place.

        Each strategy listed in ``self.config.data.penalty`` is applied in order:

        * ``'instruction_error'``: halve ``acc`` of samples whose ``ability`` is
          ``'math'`` and whose decoded response contains ``'print('``.
        * ``'multi_language'``: accepted but currently a no-op.
        * ``'repetition'``: halve ``acc`` of samples whose valid response tokens
          contain fewer than 30% distinct 5-grams.

        A ``None`` penalty list is normalised to ``[]`` on the config.

        Args:
            batch: object exposing ``batch['responses']``, ``batch['attention_mask']``
                (last ``responses.shape[-1]`` columns cover the response) and
                ``batch['acc']``, plus ``non_tensor_batch['ability']`` when
                ``'instruction_error'`` is configured.

        Returns:
            The same ``batch`` object, with ``acc`` modified in place.

        Raises:
            NotImplementedError: for an unknown strategy name.
        """
        if self.config.data.penalty is None:
            self.config.data.penalty = []
        response_ids = batch.batch['responses']
        # Cast to bool so indexing below *selects* valid tokens; an integer 0/1 mask
        # would instead be treated as gather indices (token 0 / token 1 repeated).
        response_mask = batch.batch['attention_mask'][:, -response_ids.shape[-1]:].bool()
        ngram = 5
        min_uniqueness = 0.3
        for strategy in self.config.data.penalty:
            if strategy == 'instruction_error':
                # Decode / read abilities only when this strategy actually needs them.
                sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)
                abilities = batch.non_tensor_batch['ability']
                for i, (ability, text) in enumerate(zip(abilities, sequences_str)):
                    # `in` also catches a response that starts with 'print(' (find() == 0).
                    if ability == 'math' and 'print(' in text:
                        batch.batch['acc'][i] /= 2

            elif strategy == 'multi_language':
                pass
            elif strategy == 'repetition':
                for i in range(len(batch)):
                    tokens = response_ids[i][response_mask[i]].cpu().tolist()
                    num_ngrams = len(tokens) - ngram + 1
                    if num_ngrams < 1:
                        # Too short to contain a single n-gram; leave untouched.
                        continue
                    unique_ngrams = {tuple(tokens[j:j + ngram]) for j in range(num_ngrams)}
                    if len(unique_ngrams) / num_ngrams < min_uniqueness:
                        batch.batch['acc'][i] /= 2
            else:
                raise NotImplementedError(f'Unknown penalty strategy: {strategy}')
        return batch