# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Single Process Actor
"""

import logging
import os

import torch
import pdb
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty, calculate_f_divergence_loss
from verl.trainer.ppo.ray_trainer import TrainingType
from verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available
from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_
from verl.utils.profiler import GPUMemoryLogger
from verl.utils.py_functional import append_to_dict
from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch
from verl.utils.torch_functional import logprobs_from_logits
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs
from verl.workers.actor import BasePPOActor
from verl.workers.config import ActorConfig

# Pick the flash-attention padding helpers that match the available accelerator.
# NOTE: when neither flag is truthy, index_first_axis / pad_input / rearrange / unpad_input are never
# bound, so the remove-padding path of DataParallelPPOActor._forward_micro_batch would raise NameError.
if is_cuda_available:
    from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
elif is_npu_available:
    from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input


# Public API of this module.
__all__ = ["DataParallelPPOActor"]

# The logger is named after this file's path; its level is read from the
# VERL_LOGGING_LEVEL environment variable and falls back to "WARN".
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class DataParallelPPOActor(BasePPOActor):
    """FSDP DataParallel PPO Actor or Ref worker

    Args:
        config (ActorConfig): Actor config
        actor_module (nn.Module): Actor or ref module
        actor_optimizer (torch.optim.Optimizer, optional): Actor optimizer. Defaults to None.
    """

    def __init__(self, config: ActorConfig, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None):
        """Wrap ``actor_module`` as a trainable actor, or as the reference policy.

        Args:
            config (ActorConfig): Actor config, forwarded to the base class.
            actor_module (nn.Module): The module used for forward passes.
            actor_optimizer (torch.optim.Optimizer, optional): Optimizer for the
                module. ``None`` means this instance acts as the reference policy.
        """
        super().__init__(config)
        self.actor_module = actor_module
        self.actor_optimizer = actor_optimizer
        # Only used to label the startup messages below.
        role = "Actor" if actor_optimizer is not None else "Ref"

        def announce(message):
            # Print from rank 0 only, so the message appears once per job.
            if torch.distributed.get_rank() == 0:
                print(message)

        # Optional feature flags; both default to False when missing from the config.
        self.use_remove_padding = self.config.get("use_remove_padding", False)
        announce(f"{role} use_remove_padding={self.use_remove_padding}")
        self.use_fused_kernels = self.config.get("use_fused_kernels", False)
        announce(f"{role} use_fused_kernels={self.use_fused_kernels}")

        # Ulysses sequence parallelism is active whenever more than one rank shares a sequence.
        sp_size = self.config.ulysses_sequence_parallel_size
        self.ulysses_sequence_parallel_size = sp_size
        self.use_ulysses_sp = sp_size > 1

        # Choose the entropy implementation, then optionally wrap it with torch.compile
        # (compilation is on unless "use_torch_compile" is explicitly disabled).
        entropy_fn = (
            verl_F.entropy_from_logits_with_chunking
            if self.config.entropy_from_logits_with_chunking
            else verl_F.entropy_from_logits
        )
        if self.config.get("use_torch_compile", True):
            entropy_fn = torch.compile(entropy_fn, dynamic=True)
        self.compute_entropy_from_logits = entropy_fn

        self.device_name = get_device_name()

    def _forward_micro_batch(
        self, micro_batch, temperature, calculate_entropy: bool = False, return_logits: bool = False
    ) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor | None]:
        """Run one forward pass of ``self.actor_module`` and extract response-token statistics.

        The forward pass runs under bfloat16 autocast. Two code paths exist:

        * ``self.use_remove_padding``: padding tokens are stripped with ``unpad_input``, the inputs are
          optionally padded/sliced for Ulysses sequence parallelism, and the per-token results are gathered
          and padded back to ``(bsz, seqlen)`` before the response slice is taken.
        * otherwise: the padded batch is fed to the module directly.

        Args:
            micro_batch: mapping that provides ``responses``, ``input_ids``, ``attention_mask`` and
                ``position_ids``, plus optionally ``multi_modal_inputs`` (a sequence of per-sample dicts).
            temperature: divisor applied in place to the logits; with fused kernels it is forwarded to the
                module as a keyword argument instead.
            calculate_entropy: whether to compute the per-token entropy.
            return_logits: whether to return the response logits. Only honoured on the path that uses
                neither remove-padding nor fused kernels.

        Returns:
            entropy: (bs, response_len), or None when ``calculate_entropy`` is False. Exception: on the
                non-remove-padding fused-kernel path ``output.entropy`` is returned regardless of the flag.
            log_probs: (bs, response_len)
            logits: (bs, response_len, vocab_size) if ``return_logits`` is True and the path supports it,
                else None. The returned logits have already been divided by ``temperature``.
        """
        response_length = micro_batch["responses"].size(-1)
        # Collate the per-sample multi-modal dicts into one dict of batched inputs.
        multi_modal_inputs = {}
        if "multi_modal_inputs" in micro_batch.keys():
            if "image_bound" in micro_batch["multi_modal_inputs"][0]:  # minicpm-o logic
                # keep one list entry per sample instead of concatenating
                for key in micro_batch["multi_modal_inputs"][0].keys():
                    multi_modal_inputs[key] = [inputs[key] for inputs in micro_batch["multi_modal_inputs"]]
            else:
                for key in micro_batch["multi_modal_inputs"][0].keys():
                    multi_modal_inputs[key] = torch.cat(
                        [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
                    )

        output_logits = None  # Initialize logits output

        with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):
            input_ids = micro_batch["input_ids"]
            batch_size, seqlen = input_ids.shape
            attention_mask = micro_batch["attention_mask"]
            position_ids = micro_batch["position_ids"]
            entropy = None
            if position_ids.dim() == 3:  # qwen2vl mrope
                position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)

            if self.use_remove_padding:
                input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(
                    input_ids.unsqueeze(-1), attention_mask
                )  # input_ids_rmpad (total_nnz, ...)
                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

                # unpad the position_ids to align the rotary
                if position_ids.dim() == 3:
                    position_ids_rmpad = (
                        index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
                        .transpose(0, 1)
                        .unsqueeze(1)
                    )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
                else:
                    position_ids_rmpad = index_first_axis(
                        rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
                    ).transpose(0, 1)

                if "image_bound" in multi_modal_inputs:
                    from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo

                    multi_modal_inputs = process_multi_modal_inputs_for_minicpmo(
                        input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs
                    )

                # for compute the log_prob: shifting left by one makes position i hold token i+1,
                # i.e. the label that the logits at position i are scored against
                input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)

                # pad and slice the inputs if sp > 1
                if self.use_ulysses_sp:
                    is_vlm_model = "multi_modal_inputs" in micro_batch.keys()
                    if is_vlm_model:
                        # vlm model's inputs will be sliced after embedding
                        input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(
                            input_ids_rmpad,
                            position_ids_rmpad=position_ids_rmpad,
                            sp_size=self.ulysses_sequence_parallel_size,
                        )
                    else:
                        input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
                            input_ids_rmpad,
                            position_ids_rmpad=position_ids_rmpad,
                            sp_size=self.ulysses_sequence_parallel_size,
                        )
                    # the labels are always padded and sliced here, for both branches above
                    input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
                        input_ids_rmpad_rolled,
                        position_ids_rmpad=None,
                        sp_size=self.ulysses_sequence_parallel_size,
                    )

                input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # ((total_nnz / sp) + pad)

                # only pass input_ids and position_ids to enable flash_attn_varlen
                extra_args = {}
                if self.use_fused_kernels:
                    # with fused kernels the module receives the temperature and exposes
                    # log_probs / entropy on its output (read below)
                    extra_args["temperature"] = temperature
                    extra_args["return_dict"] = True

                output = self.actor_module(
                    input_ids=input_ids_rmpad,
                    attention_mask=None,
                    position_ids=position_ids_rmpad,
                    **multi_modal_inputs,
                    use_cache=False,
                    **extra_args,
                )  # prevent model thinks we are generating

                if self.use_fused_kernels:
                    log_probs = output.log_probs.squeeze(0)  # (total_nnz,)
                    entropy_rmpad = output.entropy.squeeze(0)  # (total_nnz,)

                else:
                    logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)
                    logits_rmpad.div_(temperature)  # in-place temperature scaling

                    # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
                    # NOTE(review): inplace_backward is switched off when entropy is requested, presumably
                    # because the entropy computation below reuses the same logits tensor - confirm against
                    # logprobs_from_logits.
                    inplace_backward = True
                    if calculate_entropy:
                        inplace_backward = False
                    log_probs = logprobs_from_logits(
                        logits=logits_rmpad,
                        labels=input_ids_rmpad_rolled,
                        inplace_backward=inplace_backward,
                    )

                    # compute entropy
                    if calculate_entropy:
                        if not self.config.entropy_checkpointing:
                            entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad)  # ((total_nnz / sp) + pad)
                        else:
                            entropy_rmpad = torch.utils.checkpoint.checkpoint(
                                self.compute_entropy_from_logits, logits_rmpad
                            )

                # gather log_prob if sp > 1
                if self.use_ulysses_sp:
                    # gather and unpad for the ulysses sp
                    log_probs = gather_outputs_and_unpad(
                        log_probs,
                        gather_dim=0,
                        unpad_dim=0,
                        padding_size=pad_size,
                    )
                    if calculate_entropy:
                        entropy_rmpad = gather_outputs_and_unpad(
                            entropy_rmpad,
                            gather_dim=0,
                            unpad_dim=0,
                            padding_size=pad_size,
                        )
                # pad back to (bsz, seqlen)
                if calculate_entropy:
                    full_entropy = pad_input(
                        hidden_states=entropy_rmpad.unsqueeze(-1),
                        indices=indices,
                        batch=batch_size,
                        seqlen=seqlen,
                    )
                full_log_probs = pad_input(
                    hidden_states=log_probs.unsqueeze(-1),
                    indices=indices,
                    batch=batch_size,
                    seqlen=seqlen,
                )

                # only return response part: values at position i refer to token i+1 (see the roll above),
                # so the response tokens live in the window shifted one step to the left
                if calculate_entropy:
                    entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)
                log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)

                # Note: logits are not easily recoverable in rmpad mode, return None
                # APO will fall back to standard loss if logits unavailable
                if return_logits:
                    output_logits = None  # rmpad mode doesn't support returning logits efficiently

            else:  # not using rmpad and no ulysses sp
                extra_args = {}
                if self.use_fused_kernels:
                    extra_args["temperature"] = temperature
                    extra_args["return_dict"] = True

                output = self.actor_module(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    **multi_modal_inputs,
                    use_cache=False,
                    **extra_args,
                )  # prevent model thinks we are generating

                if self.use_fused_kernels:
                    log_probs = output.log_probs[:, -response_length - 1 : -1]
                    # entropy is taken from the fused output even when calculate_entropy is False
                    entropy = output.entropy[:, -response_length - 1 : -1]  # (bsz, response_length)
                    if return_logits:
                        output_logits = None  # fused kernels don't return raw logits

                else:
                    logits = output.logits

                    logits.div_(temperature)  # in-place temperature scaling
                    logits = logits[:, -response_length - 1 : -1, :]  # (bsz, response_length, vocab_size)
                    log_probs = logprobs_from_logits(logits, micro_batch["responses"])
                    if calculate_entropy:
                        # this path calls verl_F.entropy_from_logits directly rather than
                        # self.compute_entropy_from_logits
                        if not self.config.entropy_checkpointing:
                            entropy = verl_F.entropy_from_logits(logits)  # (bsz, response_length)
                        else:
                            entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits)

                    # Store logits for APO if requested
                    if return_logits:
                        output_logits = logits  # (bsz, response_length, vocab_size)

            return entropy, log_probs, output_logits

    def _optimizer_step(self):
        """Clip gradients and apply one optimizer update.

        The clipping routine is chosen by the wrapper type of ``self.actor_module``
        (FSDP, FSDPModule, or a plain module). When the resulting gradient norm is
        not finite the update is skipped and the gradients are zeroed instead.

        Returns:
            The gradient norm reported by the clipping routine.
        """
        max_norm = self.config.grad_clip
        assert max_norm is not None

        module = self.actor_module
        if isinstance(module, FSDP):
            grad_norm = module.clip_grad_norm_(max_norm=max_norm)
        elif isinstance(module, FSDPModule):
            grad_norm = fsdp2_clip_grad_norm_(module.parameters(), max_norm=max_norm)
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(module.parameters(), max_norm=max_norm)

        # Normal case: a finite norm means the clipped gradients are safe to apply.
        if torch.isfinite(grad_norm):
            self.actor_optimizer.step()
            return grad_norm

        # Non-finite norm: drop this update rather than corrupting the weights.
        print(f"WARN: rank {torch.distributed.get_rank()} grad_norm is not finite: {grad_norm}")
        self.actor_optimizer.zero_grad()
        return grad_norm

    @GPUMemoryLogger(role="dp actor", logger=logger)
    def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Compute the log probability of the responses given input_ids, attention_mask and position_ids

        The module is switched to eval mode and every micro-batch is evaluated under ``torch.no_grad()``.

        Args:
            data (DataProto): a DataProto containing keys

                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
                concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.

                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``responses``:  tensor of shape [batch_size, response_length]. torch.int64.

                ``data.meta_info`` must provide ``micro_batch_size``, ``temperature`` and ``use_dynamic_bsz``,
                plus ``max_token_len`` when ``use_dynamic_bsz`` is true.
            calculate_entropy (bool): whether to also return the per-token entropy.

        Returns:
            tuple[torch.Tensor, torch.Tensor | None]: ``(log_probs, entropys)``. ``log_probs`` is the
            per-micro-batch results concatenated along dim 0; ``entropys`` is built the same way, or is
            ``None`` when ``calculate_entropy`` is False.
        """
        # set to eval
        self.actor_module.eval()

        micro_batch_size = data.meta_info["micro_batch_size"]
        temperature = data.meta_info["temperature"]  # temperature must be in the data.meta_info to avoid silent error
        use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]
        has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch
        select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
        non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []

        data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)

        if use_dynamic_bsz:
            # the token budget scales with the sequence-parallel size
            max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
            micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len)
        else:
            micro_batches = data.split(micro_batch_size)

        log_probs_lst = []
        entropy_lst = []
        for micro_batch in micro_batches:
            micro_batch = micro_batch.to(get_device_id())
            model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
            with torch.no_grad():
                entropy, log_probs, _ = self._forward_micro_batch(
                    model_inputs, temperature=temperature, calculate_entropy=calculate_entropy, return_logits=False
                )
            log_probs_lst.append(log_probs)
            if calculate_entropy:
                entropy_lst.append(entropy)

        log_probs = torch.concat(log_probs_lst, dim=0)
        entropys = torch.concat(entropy_lst, dim=0) if calculate_entropy else None

        if use_dynamic_bsz:
            # dynamic batching regrouped the samples; pass the results back through
            # restore_dynamic_batch with the index list returned by prepare_dynamic_batch
            log_probs = restore_dynamic_batch(log_probs, batch_idx_list)
            if calculate_entropy:
                entropys = restore_dynamic_batch(entropys, batch_idx_list)

        return log_probs, entropys

    @GPUMemoryLogger(role="dp actor", logger=logger)
    def compute_log_prob_and_topk_for_apo(
        self, data: DataProto, apo_topk: int = 5, apo_exclude_sampled: bool = True
    ) -> DataProto:
        """Compute log probability and Top-K token info for APO (memory-optimized).

        This method computes:
        - ref_log_prob: log probabilities for sampled tokens
        - ref_topk_indices: indices of Top-K tokens from this model
        - ref_topk_weights: Top-K probabilities renormalized to sum to (approximately) one
        - ref_topk_log_probs: log probabilities of Top-K tokens (for apo_ratio Method II)

        Only (bs, response_length, K) tensors are returned instead of the full
        (bs, response_length, vocab_size) logits.

        The Top-K tensors require raw logits from ``_forward_micro_batch``. That method returns no logits
        when remove-padding or fused kernels are enabled; in that case only ``ref_log_prob`` is returned
        and a warning is logged.

        Args:
            data (DataProto): a DataProto containing keys
                ``input_ids``, ``attention_mask``, ``position_ids``, ``responses``.
                ``data.meta_info`` must provide ``micro_batch_size``, ``temperature`` and
                ``use_dynamic_bsz``, plus ``max_token_len`` when ``use_dynamic_bsz`` is true.
            apo_topk (int): Number of Top-K tokens to return
            apo_exclude_sampled (bool): Whether to exclude sampled token from Top-K

        Returns:
            DataProto with tensors:
                - ref_log_prob: (batch_size, response_length)
                - ref_topk_indices: (batch_size, response_length, K)   [only when logits are available]
                - ref_topk_weights: (batch_size, response_length, K)   [only when logits are available]
                - ref_topk_log_probs: (batch_size, response_length, K) [only when logits are available]
        """
        # set to eval
        self.actor_module.eval()

        micro_batch_size = data.meta_info["micro_batch_size"]
        temperature = data.meta_info["temperature"]
        use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]
        has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch
        select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
        non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []

        data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)

        if use_dynamic_bsz:
            max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
            micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len)
        else:
            micro_batches = data.split(micro_batch_size)

        log_probs_lst = []
        topk_indices_lst = []
        topk_weights_lst = []
        topk_log_probs_lst = []  # For apo_ratio: log pi_ref(y_k)

        for micro_batch in micro_batches:
            micro_batch = micro_batch.to(get_device_id())
            model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
            responses = model_inputs["responses"]  # (micro_bs, response_length)

            with torch.no_grad():
                _, log_probs, logits = self._forward_micro_batch(
                    model_inputs, temperature=temperature, calculate_entropy=False, return_logits=True
                )
                log_probs_lst.append(log_probs)

                if logits is None:
                    # This forward path cannot provide logits, so no Top-K for this micro-batch.
                    continue

                # logits: (micro_bs, response_length, vocab_size)
                probs = torch.softmax(logits, dim=-1)
                log_probs_full = torch.log_softmax(logits, dim=-1)  # For ref_topk_log_probs

                if apo_exclude_sampled:
                    # Zero the sampled token's probability so it cannot enter the Top-K.
                    # `probs` is a fresh tensor that is only consumed by topk below, so it is
                    # masked in place instead of cloning another vocab-sized tensor.
                    probs.scatter_(-1, responses.unsqueeze(-1), 0.0)
                topk_probs, topk_indices = torch.topk(probs, k=apo_topk, dim=-1)

                # Normalize to get weights (epsilon guards against an all-zero Top-K)
                topk_weights = topk_probs / (topk_probs.sum(dim=-1, keepdim=True) + 1e-8)

                # Gather log probs for Top-K tokens (needed for apo_ratio)
                # topk_log_probs: (micro_bs, response_length, K)
                topk_log_probs = torch.gather(log_probs_full, dim=-1, index=topk_indices)

                topk_indices_lst.append(topk_indices)
                topk_weights_lst.append(topk_weights)
                topk_log_probs_lst.append(topk_log_probs)

        def merge(chunks):
            # Concatenate the per-micro-batch results; with dynamic batching, also pass them
            # through restore_dynamic_batch using the index list from prepare_dynamic_batch.
            merged = torch.concat(chunks, dim=0)
            if use_dynamic_bsz:
                merged = restore_dynamic_batch(merged, batch_idx_list)
            return merged

        result = {"ref_log_prob": merge(log_probs_lst)}

        if topk_indices_lst:
            result["ref_topk_indices"] = merge(topk_indices_lst)
            result["ref_topk_weights"] = merge(topk_weights_lst)
            result["ref_topk_log_probs"] = merge(topk_log_probs_lst)  # log pi_ref(y_k) for apo_ratio
        else:
            # Make the degraded output visible instead of silently dropping the Top-K tensors.
            logger.warning(
                "compute_log_prob_and_topk_for_apo: no logits available "
                "(use_remove_padding=%s, use_fused_kernels=%s); returning ref_log_prob without Top-K tensors",
                self.use_remove_padding,
                self.use_fused_kernels,
            )

        return DataProto.from_dict(tensors=result)

    @GPUMemoryLogger(role="dp actor", logger=logger)
    def update_policy(self, data: DataProto):
        # make sure we are in training mode
        self.actor_module.train()

        temperature = data.meta_info["temperature"]  # temperature must be in the data.meta_info to avoid silent error

        select_keys = [
            "responses",
            "response_mask",
            "input_ids",
            "attention_mask",
            "position_ids",
            "old_log_probs",
            "type",
            "advantages",
        ]
        if self.config.use_kl_loss or self.config.use_kl_loss_on_wrong:
            select_keys.append("ref_log_prob")
        if self.config.tis_imp_ratio_cap > 0:
            assert "rollout_log_probs" in data.batch.keys(), (
                "Truncated Importance Sampling (TIS) requires to configure "
                "`actor_rollout_ref.rollout.calculate_log_probs=True` "
                "and is not currently supported in Server mode (agent loop)."
            )
            select_keys.append("rollout_log_probs")
        
        # Check if APO is enabled and needs ref Top-K info
        loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
        use_apo = loss_mode == "apo_ratio"  # Only apo_ratio supported in public release
        if use_apo:
            # Use memory-optimized Top-K format
            if "ref_topk_indices" in data.batch.keys():
                select_keys.append("ref_topk_indices")
            if "ref_topk_weights" in data.batch.keys():
                select_keys.append("ref_topk_weights")
            # apo_ratio also needs ref_topk_log_probs for importance ratio computation
            if loss_mode == "apo_ratio" and "ref_topk_log_probs" in data.batch.keys():
                select_keys.append("ref_topk_log_probs")

        has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
        non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []

        data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)

        # Split to make minibatch iterator for updating the actor
        # See PPO paper for details. https://arxiv.org/abs/1707.06347
        mini_batches = data.split(self.config.ppo_mini_batch_size)

        on_policy = len(mini_batches) == 1 and self.config.ppo_epochs == 1

        metrics = {}
        for _ in range(self.config.ppo_epochs):
            for batch_idx, mini_batch in enumerate(mini_batches):
                if self.config.use_dynamic_bsz:
                    max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                    micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len)
                else:
                    self.gradient_accumulation = (
                        self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
                    )
                    micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)

                self.actor_optimizer.zero_grad()

                for micro_batch in micro_batches:
                    micro_batch = micro_batch.to(get_device_id())
                    micro_batch_metrics = {}
                    model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
                    response_mask = model_inputs["response_mask"]
                    old_log_prob = model_inputs["old_log_probs"]
                    data_type = model_inputs['type'] if 'type' in model_inputs else None
                    rollout_log_probs = model_inputs["rollout_log_probs"] if self.config.tis_imp_ratio_cap > 0 else None
                    advantages = model_inputs["advantages"]

                    entropy_coeff = self.config.entropy_coeff
                    loss_agg_mode = self.config.loss_agg_mode

                    if self.config.use_dynamic_bsz:
                        loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size
                    else:
                        loss_scale_factor = 1 / self.gradient_accumulation

                    # all return: (bsz, response_length) for entropy and log_prob
                    # logits: (bsz, response_length, vocab_size) if return_logits=True
                    calculate_entropy = False
                    if entropy_coeff != 0:
                        calculate_entropy = True
                    
                    # For APO, we need to get logits
                    return_logits = use_apo
                    entropy, log_prob, logits = self._forward_micro_batch(
                        model_inputs, temperature=temperature, calculate_entropy=calculate_entropy,
                        return_logits=return_logits
                    )
                    
                    # Get ref Top-K info for APO (memory-optimized format)
                    ref_topk_indices = model_inputs.get("ref_topk_indices", None) if use_apo else None
                    ref_topk_weights = model_inputs.get("ref_topk_weights", None) if use_apo else None
                    # ref_topk_log_probs needed for apo_ratio (Method II) to compute importance ratio
                    ref_topk_log_probs = model_inputs.get("ref_topk_log_probs", None) if use_apo else None

                    loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
                    sft_loss_mode = self.config.get("sft_loss_mode", "forward")

                    micro_batch_metrics[f"actor/{sft_loss_mode}"] = 0.0

                    # Handle mixed batch types: separate SFT and RL data
                    if data_type is not None:
                        # Create masks for different data types
                        sft_mask = (data_type == TrainingType.SFT_TYPE.value)
                        rl_mask = (data_type == TrainingType.RL_TYPE.value)
                        
                        # Initialize loss components
                        device = log_prob.device
                        pg_loss_total = torch.tensor(0.0, device=device)
                        pg_clipfrac = torch.tensor(0.0, device=device)
                        ppo_kl = torch.tensor(0.0, device=device)
                        pg_clipfrac_lower = torch.tensor(0.0, device=device)
                        
                        # Process SFT data if present
                        if torch.any(sft_mask):
                            # Create SFT response mask by zeroing out RL data
                            sft_response_mask = response_mask.clone()
                            sft_response_mask[rl_mask] = 0  # Zero out RL data
                            
                            # Check if SFT mask has any valid tokens
                            
                            if torch.any(sft_response_mask):
                                use_js = True
                                
                                if use_js:  
                                    sft_pg_loss = calculate_f_divergence_loss(
                                        sft_loss_mode="js_low_var", 
                                        old_log_prob=old_log_prob,
                                        log_prob=log_prob,
                                        response_mask=sft_response_mask,
                                        loss_agg_mode=loss_agg_mode,
                                        config=self.config)
                                    micro_batch_metrics[f"actor/js"] = sft_pg_loss.detach().item()
                                    pg_loss_total += sft_pg_loss
                                else:
                                    # Compute MSE loss between log_prob and old_log_prob
                                    mse_loss = torch.nn.functional.mse_loss(log_prob, old_log_prob, reduction='none')
                                    sft_pg_loss = agg_loss(loss_mat=mse_loss, loss_mask=sft_response_mask, loss_agg_mode=loss_agg_mode)
                                    micro_batch_metrics[f"actor/{sft_loss_mode}"] = sft_pg_loss.detach().item()
                                    pg_loss_total += sft_pg_loss
                            else:
                                # No valid SFT tokens, skip SFT loss
                                micro_batch_metrics[f"actor/js"] = 0.0
                        
                        # Process RL data if present
                        if torch.any(rl_mask):
                            # Create RL response mask by zeroing out SFT data
                            rl_response_mask = response_mask.clone()
                            rl_response_mask[sft_mask] = 0  # Zero out SFT data
                            
                            # Check if RL mask has any valid tokens
                            if torch.any(rl_response_mask):
                                # Compute RL policy loss
                                # Build kwargs for policy loss function
                                policy_loss_kwargs = {
                                    "old_log_prob": old_log_prob,
                                    "log_prob": log_prob,
                                    "advantages": advantages,
                                    "response_mask": rl_response_mask,
                                    "loss_agg_mode": loss_agg_mode,
                                    "config": self.config,
                                    "rollout_log_probs": rollout_log_probs if rollout_log_probs is not None else None,
                                }
                                # Add APO-specific kwargs if using APO
                                if use_apo:
                                    policy_loss_kwargs["logits"] = logits
                                    policy_loss_kwargs["ref_topk_indices"] = ref_topk_indices
                                    policy_loss_kwargs["ref_topk_weights"] = ref_topk_weights
                                    policy_loss_kwargs["responses"] = model_inputs["responses"]
                                    # apo_ratio needs ref_topk_log_probs for ratio computation
                                    if loss_mode == "apo_ratio":
                                        policy_loss_kwargs["ref_topk_log_probs"] = model_inputs.get("ref_topk_log_probs", None)
                                
                                rl_pg_loss, rl_pg_clipfrac, rl_ppo_kl, rl_pg_clipfrac_lower = get_policy_loss_fn(loss_mode)(
                                    **policy_loss_kwargs
                                )
                                
                                # Accumulate RL losses
                                pg_loss_total += rl_pg_loss
                                pg_clipfrac += rl_pg_clipfrac
                                ppo_kl += rl_ppo_kl
                                pg_clipfrac_lower += rl_pg_clipfrac_lower
                            else:
                                # No valid RL tokens: skip the RL loss; the RL accumulators are left unchanged
                                pass
                        
                        # Final loss is the sum of both types
                        pg_loss = pg_loss_total
                        
                    else:
                        if on_policy:
                            old_log_prob = log_prob.detach()
                        else:
                            old_log_prob = model_inputs["old_log_probs"]

                        # vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla
                        # gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg
                        # clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov
                        # apo -> verl.trainer.ppo.core_algos.compute_policy_loss_apo (Anchored Policy Optimization)
                        policy_loss_fn = get_policy_loss_fn(loss_mode)
                        
                        # Build kwargs for policy loss function
                        policy_loss_kwargs = {
                            "old_log_prob": old_log_prob,
                            "log_prob": log_prob,
                            "advantages": advantages,
                            "response_mask": response_mask,
                            "loss_agg_mode": loss_agg_mode,
                            "config": self.config,
                            "rollout_log_probs": rollout_log_probs,
                        }
                        # Add APO-specific kwargs if using APO
                        if use_apo:
                            policy_loss_kwargs["logits"] = logits
                            policy_loss_kwargs["ref_topk_indices"] = ref_topk_indices
                            policy_loss_kwargs["ref_topk_weights"] = ref_topk_weights
                            policy_loss_kwargs["responses"] = model_inputs["responses"]
                            # apo_ratio needs ref_topk_log_probs for ratio computation
                            if loss_mode == "apo_ratio":
                                policy_loss_kwargs["ref_topk_log_probs"] = model_inputs.get("ref_topk_log_probs", None)
                        
                        pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(**policy_loss_kwargs)

                    if entropy_coeff != 0:
                        entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

                        # compute policy loss
                        policy_loss = pg_loss - entropy_loss * entropy_coeff
                    else:
                        policy_loss = pg_loss

                    if self.config.use_kl_loss:
                        ref_log_prob = model_inputs["ref_log_prob"]
                        # compute kl loss
                        kld = kl_penalty(
                            logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type
                        )
                        kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

                        policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
                        micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item() * loss_scale_factor
                        micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef
                    
                    # KL loss only on wrong samples (negative advantage) - for APO ablation
                    if self.config.use_kl_loss_on_wrong:
                        ref_log_prob = model_inputs["ref_log_prob"]
                        # compute kl loss
                        kld = kl_penalty(
                            logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type
                        )
                        # Create mask for negative advantage samples (wrong samples)
                        negative_adv_mask = (advantages < 0).float() * response_mask
                        # Apply KL loss only on wrong samples
                        kl_loss_on_wrong = agg_loss(loss_mat=kld, loss_mask=negative_adv_mask, loss_agg_mode=loss_agg_mode)
                        
                        policy_loss = policy_loss + kl_loss_on_wrong * self.config.kl_loss_coef
                        micro_batch_metrics["actor/kl_loss_on_wrong"] = kl_loss_on_wrong.detach().item() * loss_scale_factor
                        micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef

                    if self.config.use_dynamic_bsz:
                        # relative to the dynamic bsz
                        loss = policy_loss * loss_scale_factor
                    else:
                        loss = policy_loss * loss_scale_factor
                    loss.backward()

                    micro_batch_metrics.update(
                        {
                            "actor/pg_loss": pg_loss.detach().item() * loss_scale_factor,
                            "actor/pg_clipfrac": pg_clipfrac.detach().item(),
                            "actor/ppo_kl": ppo_kl.detach().item(),
                            "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
                        }
                    )
                    append_to_dict(metrics, micro_batch_metrics)

                grad_norm = self._optimizer_step()
                mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()}
                append_to_dict(metrics, mini_batch_metrics)
        self.actor_optimizer.zero_grad()
        return metrics
