from typing import Dict, Tuple

import torch

from mindspeed_rl.models.loss.loss_func_factory import LossFuncFactory
from mindspeed_rl.models.loss.base_loss_func import BaseLossFunc
from mindspeed_rl.utils.compute import compute_kl_penalty
from mindspeed_rl.utils.utils import generate_mask
import mindspeed_rl.utils.torch_functional as F
from mindspeed_rl.utils.utils import MsProbe


@LossFuncFactory.register_loss('ray_grpo', 'actor')
class GRPOActorLossFunc(BaseLossFunc):
    """GRPO actor loss.

    Combines a PPO-style clipped policy-gradient term, a weighted KL penalty
    computed from ``log_prob`` and ``ref_log_prob``, and an optional term that
    is switched on when the masked mean entropy is at or below
    ``target_entropy``.
    """

    def __init__(self):
        super().__init__()
        self.clip_ratio = 0.2
        # Default coefficient (beta) of the entropy term.
        self.entropy_coeff = 1e-4
        # Entropy level at or below which the entropy term is activated.
        # ``None`` disables the target-entropy mechanism.
        self.target_entropy = 0.01

    def add_loss_meta_info(self, meta_info: Dict):
        """Override loss hyper-parameters with the values found in ``meta_info``.

        Recognised keys: ``clip_ratio``, ``kl_ctrl``, ``entropy_coeff``,
        ``kl_penalty`` and ``target_entropy``. Missing keys leave the current
        value untouched; a ``None`` ``meta_info`` is a no-op.
        """
        if meta_info is None:
            return
        if "clip_ratio" in meta_info:
            self.clip_ratio = float(meta_info["clip_ratio"])
        if "kl_ctrl" in meta_info:
            self.kl_ctrl = meta_info["kl_ctrl"]
        if "entropy_coeff" in meta_info:
            # Cast like clip_ratio: config loaders may hand over values such as
            # "1e-4" as strings, which would break the ``> 0`` checks below.
            self.entropy_coeff = float(meta_info["entropy_coeff"])
        if "kl_penalty" in meta_info:
            self.kl_penalty = meta_info["kl_penalty"]
        if "target_entropy" in meta_info:
            target_entropy = meta_info["target_entropy"]
            # None is kept as-is so the mechanism can be turned off explicitly.
            self.target_entropy = None if target_entropy is None else float(target_entropy)

    @staticmethod
    def _get_policy_loss_input(batch: Dict[str, torch.Tensor]):
        """Extract the tensors needed by the policy loss from ``batch``.

        Returns:
            ``(response_mask, old_log_prob, advantages, ref_log_prob)``. The
            mask is built by ``generate_mask`` and moved with ``.npu()``; the
            other three are ``None`` when their key is absent from ``batch``.

        Raises:
            ValueError: if ``batch`` has no ``'responses'`` entry.
        """
        if 'responses' not in batch:
            raise ValueError("The responses is None")
        response_mask = generate_mask(batch['responses'], batch['response_length']).npu()
        old_log_prob = batch.get('old_log_prob')
        advantages = batch.get('advantages')
        ref_log_prob = batch.get('ref_log_prob')
        return response_mask, old_log_prob, advantages, ref_log_prob

    def compute_loss(self, output: torch.Tensor,
                     batch: Dict[str, torch.Tensor],
                     forward_only=False,
                     non_loss_data=True,
                     **kwargs) -> Tuple[torch.Tensor, Dict]:
        """Compute the GRPO actor loss for one micro-batch.

        Args:
            output: model output forwarded to ``compute_log_probs``.
            batch: must contain ``'responses'`` and ``'response_length'``; may
                contain ``'old_log_prob'``, ``'advantages'``, ``'ref_log_prob'``.
            forward_only: when true, only the log-probabilities are returned.
            non_loss_data: accepted for interface compatibility; not read here.
            **kwargs: forwarded to ``compute_log_probs``. ``use_dynamic_bsz``
                and ``actual_micro_batch_size`` are also read from it.

        Returns:
            ``log_probs`` alone when ``forward_only`` is true, otherwise
            ``(policy_loss, stats)`` where ``stats`` maps metric names to
            Python scalars.
        """
        if forward_only:
            log_probs, _ = super().compute_log_probs(output=output, batch=batch, **kwargs)
            return log_probs

        # Per-token log-probabilities and entropies.
        log_probs, full_entropy = super().compute_log_probs(
            output=output, batch=batch, skip_entropy=False, **kwargs
        )

        response_mask, old_log_prob, advantages, ref_log_prob = self._get_policy_loss_input(batch=batch)

        # Masked mean entropy over response tokens (scalar).
        current_entropy = (full_entropy * response_mask).sum() / response_mask.sum()

        # Indicator I{entropy <= target}: 1.0 when at/below the target, else 0.0.
        # A None target disables the mechanism instead of raising.
        if self.target_entropy is None:
            indicator = torch.zeros_like(current_entropy, dtype=torch.float32)
        else:
            target = torch.tensor(
                self.target_entropy, dtype=current_entropy.dtype, device=current_entropy.device
            )
            indicator = (current_entropy <= target).float()
        below_target = bool(indicator.item())

        use_target_entropy = (
            self.target_entropy is not None and self.entropy_coeff > 0 and below_target
        )
        if use_target_entropy:
            # Entropy term = beta * I{entropy <= target}.
            # NOTE(review): this term is built from a comparison result, so it
            # carries no gradient and only shifts the reported loss value.
            # If an entropy bonus was intended, it likely needs to depend on
            # current_entropy itself -- confirm the intended formula.
            entropy_loss_term = self.entropy_coeff * indicator
        else:
            entropy_loss_term = torch.zeros_like(current_entropy)

        pg_loss, pg_clipfrac, ppo_kl, kl_loss = self._compute_grpo_policy_loss(
            old_log_prob=old_log_prob,
            log_prob=log_probs,
            ref_log_prob=ref_log_prob,
            advantages=advantages,
            full_entropy=full_entropy,
            eos_mask=response_mask,
            cliprange=self.clip_ratio,
            kl_ctrl=self.kl_ctrl,
            kl_penalty=self.kl_penalty,
            entropy_coeff=self.entropy_coeff,
            entropy_loss_term=entropy_loss_term,
            use_target_entropy=use_target_entropy
        )

        # Dynamic batch size: rescale by (micro-batch rows / actual micro-batch size).
        # forward_only is always False at this point, so it is not re-checked.
        use_dynamic_bsz = kwargs.get('use_dynamic_bsz', False)
        actual_micro_batch_size = kwargs.get('actual_micro_batch_size', None)
        if use_dynamic_bsz:
            policy_loss = pg_loss * (batch['responses'].size(0) / actual_micro_batch_size)
        else:
            policy_loss = pg_loss

        # Dump intermediate tensors through the probe.
        data_tobe_saved = {
            "old_log_prob": old_log_prob,
            "log_prob": log_probs,
            "ref_log_prob": ref_log_prob,
            "advantages": advantages,
            "full_entropy": full_entropy,
            "response_mask": response_mask,
            "loss": pg_loss,
            "kl_loss": kl_loss,
            "current_entropy": current_entropy,
            "target_entropy": self.target_entropy,
            "entropy_indicator": indicator if use_target_entropy else None
        }
        MsProbe.save_data(data_tobe_saved)

        # Scalar metrics for logging.
        stats = {
            'actor/pg_loss': pg_loss.detach().item(),
            'actor/pg_clipfrac': pg_clipfrac.detach().item(),
            'actor/ppo_kl': ppo_kl.detach().item(),
            'actor/kl_loss': kl_loss.detach().item(),
            'actor/entropy': current_entropy.detach().item(),
            'actor/entropy_loss_term': entropy_loss_term.detach().item(),
            'actor/entropy_indicator': indicator.detach().item() if use_target_entropy else 0,
            'actor/entropy_below_target': below_target
        }

        return policy_loss, stats

    @staticmethod
    def _compute_grpo_policy_loss(old_log_prob, log_prob, ref_log_prob, advantages,
                                  full_entropy, eos_mask, cliprange, kl_ctrl,
                                  kl_penalty, entropy_coeff, entropy_loss_term,
                                  use_target_entropy=False):
        """Compute the clipped policy loss plus the weighted KL penalty.

        Args:
            old_log_prob: log-probs of the behaviour policy; when ``None`` a
                detached copy of ``log_prob`` is used (ratio is then 1).
            log_prob: log-probs of the current policy.
            ref_log_prob: log-probs of the reference policy.
            advantages: advantage estimates, same layout as ``log_prob``.
            full_entropy: per-token entropy. Accepted for interface
                compatibility; not used by this function.
            eos_mask: mask selecting the tokens that contribute to the means.
            cliprange: PPO clipping range for the probability ratio.
            kl_ctrl: object exposing ``.value``, or a plain number used as the
                KL weight.
            kl_penalty: penalty type forwarded to ``compute_kl_penalty``.
            entropy_coeff: entropy coefficient; the entropy term is only added
                when it is positive.
            entropy_loss_term: precomputed scalar added to the loss as-is.
            use_target_entropy: whether ``entropy_loss_term`` should be added.

        Returns:
            ``(total_loss, pg_mean_clipfrac, ppo_kl, kl_mean_loss)``, all scalars.
        """
        if old_log_prob is None:
            old_log_prob = log_prob.detach().clone()

        # These tensors feed the loss, so they must be built with autograd
        # enabled; wrapping them in no_grad would leave the policy-gradient
        # term without any gradient.
        negative_approx_kl = log_prob - old_log_prob
        ratio = torch.exp(negative_approx_kl)
        pg_losses = -advantages * ratio
        pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)

        # Clipped surrogate objective (pessimistic maximum of both variants).
        pg_mean_loss = F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)

        # Diagnostics only: no gradient needed.
        with torch.no_grad():
            ppo_kl = F.masked_mean(-negative_approx_kl, eos_mask)
            pg_mean_clipfrac = F.masked_mean((pg_losses2 > pg_losses).float(), eos_mask)

        # KL regularisation term.
        kl_losses = compute_kl_penalty(log_prob, ref_log_prob, kl_penalty)
        kl_mean_loss = F.masked_mean(kl_losses, eos_mask)

        # kl_ctrl may be a controller object or a bare number.
        kl_weight = kl_ctrl.value if hasattr(kl_ctrl, 'value') else kl_ctrl

        total_loss = pg_mean_loss + kl_mean_loss * kl_weight
        if use_target_entropy and entropy_coeff > 0:
            # Target-entropy mechanism active: add the precomputed term.
            total_loss = total_loss + entropy_loss_term

        return total_loss, pg_mean_clipfrac, ppo_kl, kl_mean_loss
