# Copyright (c) 2025, HUAWEI CORPORATION.  All rights reserved.
from typing import Dict, Tuple

import torch

from mindspeed_rl.models.loss.loss_func_factory import LossFuncFactory
from mindspeed_rl.models.loss.base_loss_func import BaseLossFunc
from mindspeed_rl.utils.compute import compute_kl_penalty
from mindspeed_rl.utils.utils import generate_mask
import mindspeed_rl.utils.torch_functional as F
from mindspeed_rl.utils.utils import MsProbe


@LossFuncFactory.register_loss('ray_grpo', 'actor')
class GRPOActorLossFunc(BaseLossFunc):
    """GRPO actor loss with an optional per-sample target-entropy penalty.

    The loss is the clipped GRPO surrogate plus a KL term against the reference
    policy (scaled by ``kl_ctrl.value``). When ``target_entropy > 0`` and
    ``entropy_coeff > 0``, the scalar
    ``mean(entropy_coeff * relu(target_entropy - per_sample_entropy))`` is added,
    penalizing samples whose mean token entropy drops below ``target_entropy``.
    """

    def __init__(self):
        super().__init__()
        self.clip_ratio = 0.2
        self.entropy_coeff = 0.0
        # Lower bound on per-sample mean entropy; samples below it are penalized.
        self.target_entropy = 0.01
        # NOTE(review): kl_ctrl and kl_penalty have no defaults here; they must be
        # supplied via add_loss_meta_info before compute_loss is called.

    def add_loss_meta_info(self, meta_info: Dict):
        """Override loss hyper-parameters from ``meta_info``.

        Recognized keys: ``clip_ratio``, ``kl_ctrl``, ``entropy_coeff``,
        ``kl_penalty`` and ``target_entropy``. Absent keys keep their current
        value; a ``None`` meta_info is ignored.
        """
        if meta_info is None:
            return
        if "clip_ratio" in meta_info:
            self.clip_ratio = float(meta_info["clip_ratio"])
        if "kl_ctrl" in meta_info:
            self.kl_ctrl = meta_info["kl_ctrl"]
        if "entropy_coeff" in meta_info:
            self.entropy_coeff = float(meta_info["entropy_coeff"])
        if "kl_penalty" in meta_info:
            self.kl_penalty = meta_info["kl_penalty"]
        if "target_entropy" in meta_info:
            self.target_entropy = float(meta_info["target_entropy"])

    @staticmethod
    def _get_policy_loss_input(batch: Dict[str, torch.Tensor]):
        """Extract the response mask and optional policy tensors from ``batch``.

        :raises ValueError: if ``batch`` has no ``responses`` entry.
        :return: ``(response_mask, old_log_prob, advantages, ref_log_prob)``; the
            last three are ``None`` when missing from ``batch``.
        """
        if 'responses' not in batch:
            raise ValueError("The responses is None")
        response_mask = generate_mask(batch['responses'], batch['response_length']).npu()
        old_log_prob = batch['old_log_prob'] if 'old_log_prob' in batch else None
        advantages = batch['advantages'] if 'advantages' in batch else None
        ref_log_prob = batch['ref_log_prob'] if 'ref_log_prob' in batch else None
        return response_mask, old_log_prob, advantages, ref_log_prob

    def _compute_entropy_terms(self, entropy, log_probs, response_mask):
        """Compute per-sample entropy, batch mean entropy and the entropy penalty.

        :param entropy: token-level entropy from ``compute_log_probs`` (may be
            ``None`` when entropy computation was skipped — see note below).
        :param log_probs: token log-probs; only used as a shape/device template.
        :param response_mask: mask of valid response tokens.
        :return: ``(per_sample_entropy, current_entropy, entropy_loss_term)``:
            masked mean entropy per sample (reduced over the last axis), masked
            mean entropy over all valid tokens, and the scalar penalty
            ``mean(entropy_coeff * relu(target_entropy - per_sample_entropy))``.
        """
        if entropy is None:
            # NOTE(review): compute_log_probs is called with skip_entropy=True when
            # entropy_coeff == 0; assuming it may then return None, substitute zeros
            # so the statistics can still be computed (entropy stats then read 0).
            entropy = torch.zeros_like(log_probs)
        per_sample_entropy = F.masked_mean(entropy, response_mask, axis=-1)
        target = torch.tensor(self.target_entropy, dtype=per_sample_entropy.dtype,
                              device=per_sample_entropy.device)
        current_entropy = (entropy * response_mask).sum() / response_mask.sum()
        # relu keeps only samples whose entropy is below the target.
        per_sample_entropy_loss = self.entropy_coeff * torch.relu(target - per_sample_entropy)
        entropy_loss_term = per_sample_entropy_loss.mean()
        return per_sample_entropy, current_entropy, entropy_loss_term

    def compute_loss(self, output: torch.Tensor,
                     batch: Dict[str, torch.Tensor],
                     forward_only=False,
                     non_loss_data=True,
                     **kwargs) -> Tuple[torch.Tensor, Dict]:
        """Compute the GRPO actor loss for one micro-batch.

        :param output: model output logits.
        :param batch: input data; must contain ``responses`` and
            ``response_length``; ``old_log_prob``, ``advantages`` and
            ``ref_log_prob`` are read when present.
        :param forward_only: if True, only compute and return log-probs.
        :param non_loss_data: not used by this implementation.
        :param kwargs: forwarded to ``compute_log_probs``; ``use_dynamic_bsz``
            and ``actual_micro_batch_size`` are also read to rescale the loss.
        :return: log-probs when ``forward_only`` is True, otherwise a
            ``(policy_loss, stats)`` tuple.
        """
        if forward_only:
            log_probs, _ = super().compute_log_probs(output=output, batch=batch, **kwargs)
            return log_probs
        log_probs, entropy = super().compute_log_probs(
            output=output, batch=batch, skip_entropy=(self.entropy_coeff == 0), **kwargs)

        response_mask, old_log_prob, advantages, ref_log_prob = self._get_policy_loss_input(batch=batch)

        per_sample_entropy, current_entropy, entropy_loss_term = self._compute_entropy_terms(
            entropy, log_probs, response_mask)
        use_target_entropy = self.target_entropy > 0

        pg_loss, pg_clipfrac, ppo_kl, kl_loss, policy_entropy_ratio = self._compute_grpo_policy_loss_per_sample(
            old_log_prob=old_log_prob,
            log_prob=log_probs,
            ref_log_prob=ref_log_prob,
            advantages=advantages,
            entropy=entropy,
            eos_mask=response_mask,
            cliprange=self.clip_ratio,
            kl_ctrl=self.kl_ctrl,
            kl_penalty=self.kl_penalty,
            entropy_coeff=self.entropy_coeff,
            per_sample_entropy=per_sample_entropy,
            entropy_loss_term=entropy_loss_term,
            use_target_entropy=use_target_entropy)

        # forward_only has already returned above, so only the flag matters here.
        if kwargs.get('use_dynamic_bsz', False):
            actual_micro_batch_size = kwargs.get('actual_micro_batch_size', None)
            policy_loss = pg_loss * (batch['responses'].size(0) / actual_micro_batch_size)
        else:
            policy_loss = pg_loss

        data_tobe_saved = {
            "old_log_prob": old_log_prob,
            "log_prob": log_probs,
            "ref_log_prob": ref_log_prob,
            "advantages": advantages,
            "loss": pg_loss,
            "kl_loss": kl_loss,
            "per_sample_entropy": per_sample_entropy,
            "current_entropy": current_entropy,
            "target_entropy": self.target_entropy,
        }
        MsProbe.save_data(data_tobe_saved)

        stats = {
            'actor/pg_loss': abs(pg_loss.detach().item()),
            'actor/pg_clipfrac': pg_clipfrac.detach().item(),
            'actor/ppo_kl': ppo_kl.detach().item(),
            'actor/kl_loss': kl_loss.detach().item(),
            'actor/per_sample_entropy_min': per_sample_entropy.min().detach().item(),
            'actor/per_sample_entropy_max': per_sample_entropy.max().detach().item(),
            'actor/per_sample_entropy_mean': per_sample_entropy.mean().detach().item(),
            # policy_entropy_ratio is a plain Python float; float() also accepts a 0-d tensor.
            'actor/policy_entropy_ratio': float(policy_entropy_ratio),
            'actor/entropy_loss_term': entropy_loss_term.detach().item(),
            'actor/entropy_below_target': (current_entropy <= self.target_entropy).detach().item()
        }
        return policy_loss, stats

    @staticmethod
    def _compute_grpo_policy_loss_per_sample(old_log_prob, log_prob, ref_log_prob, advantages, entropy, eos_mask, cliprange, kl_ctrl, kl_penalty, entropy_coeff, per_sample_entropy, entropy_loss_term, use_target_entropy):
        """Compute the clipped GRPO loss, averaged per sample then over samples.

        Args:
            old_log_prob: `(torch.Tensor)` shape (bs, response_length), or None to
                use a detached copy of ``log_prob`` (ratio == 1).
            log_prob: `(torch.Tensor)` shape (bs, response_length)
            ref_log_prob: `(torch.Tensor)` shape (bs, response_length)
            advantages: `(torch.Tensor)` shape (bs, response_length)
            entropy: token-level entropy; not used in this computation.
            eos_mask: `(torch.Tensor)` shape (bs, response_length), valid-token mask.
            cliprange: (float) clip range for the importance ratio.
            kl_ctrl: object whose ``.value`` scales the KL term.
            kl_penalty: KL estimator selector passed to ``compute_kl_penalty``.
            entropy_coeff: (float) entropy penalty coefficient.
            per_sample_entropy: `(torch.Tensor)` shape (bs,) per-sample entropy.
            entropy_loss_term: scalar tensor, the precomputed penalty
                ``mean(entropy_coeff * relu(target_entropy - per_sample_entropy))``.
            use_target_entropy: bool, whether the target-entropy penalty is enabled.

        Returns:
            pg_loss: scalar tensor, total loss (surrogate + KL [+ entropy penalty]).
            pg_clipfrac: scalar tensor, fraction of tokens where the clipped
                surrogate was selected.
            ppo_kl: scalar tensor, masked mean of ``old_log_prob - log_prob``.
            kl_mean_loss: scalar tensor, mean over samples of the per-sample KL.
            policy_entropy_ratio: (float)
                ``|entropy_coeff * mean(per_sample_entropy) / mean surrogate loss|``,
                or 0.0 when the surrogate loss is exactly zero.
        """
        if old_log_prob is None:
            old_log_prob = log_prob.detach().clone()
        negative_approx_kl = log_prob - old_log_prob
        ratio = torch.exp(negative_approx_kl)
        ppo_kl = F.masked_mean(-negative_approx_kl, eos_mask)

        pg_losses = -advantages * ratio
        pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)

        # Per-sample surrogate loss: token-level max, averaged over valid tokens.
        per_sample_pg_losses = F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask, axis=-1)
        pg_mean_clipfrac = F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)

        # Per-sample KL against the reference policy.
        kl_losses = compute_kl_penalty(log_prob, ref_log_prob, kl_penalty)
        per_sample_kl_losses = F.masked_mean(kl_losses, eos_mask, axis=-1)
        kl_mean_loss = per_sample_kl_losses.mean()

        # Monitoring-only ratio of the entropy scale to the surrogate loss.
        pg_mean_loss_value = per_sample_pg_losses.mean().detach().item()
        if pg_mean_loss_value != 0:
            policy_entropy_ratio = abs(
                entropy_coeff * per_sample_entropy.mean().detach().item() / pg_mean_loss_value)
        else:
            policy_entropy_ratio = 0.0

        per_sample_pg_loss = per_sample_pg_losses + per_sample_kl_losses * kl_ctrl.value
        if use_target_entropy and entropy_coeff > 0:
            # Add the target-entropy penalty only when the mechanism is active.
            per_sample_pg_loss = per_sample_pg_loss + entropy_loss_term

        pg_loss = per_sample_pg_loss.mean()

        return pg_loss, pg_mean_clipfrac, ppo_kl, kl_mean_loss, policy_entropy_ratio