from typing import Dict, Tuple
import torch
import torch.nn as nn
import re

from mindspeed_rl.models.loss.loss_func_factory import LossFuncFactory
from mindspeed_rl.models.loss.base_loss_func import BaseLossFunc
from mindspeed_rl.utils.compute import compute_kl_penalty
from mindspeed_rl.utils.utils import generate_mask
import mindspeed_rl.utils.torch_functional as F
from mindspeed_rl.utils.utils import MsProbe


@LossFuncFactory.register_loss('ray_grpo', 'actor')
class GRPOActorLossFunc(BaseLossFunc, nn.Module):
    """GRPO actor loss with a think-mode-dependent entropy regularizer.

    Each response is labelled long path (``think_mode == 1``, "think") or short
    path (``think_mode == 0``, "no_think"). Every sample pays an entropy penalty
    ``beta * max(0, target_entropy - H)``, where ``H`` is its masked mean token
    entropy. The short path uses the fixed coefficient ``beta_S``. The long path
    uses the learnable ``beta_L = exp(log_beta_L)``.

    A separate loss involving ``log_beta_L`` is stored in
    ``self._cached_beta_loss``, to be consumed by an optimizer registered
    outside this class (``optimizer_beta``).
    """

    def __init__(self):
        super().__init__()
        self.clip_ratio = 0.2
        self.beta_S = 0.1  # fixed entropy coefficient for the short (no_think) path
        self.target_entropy_long = 0.2  # target entropy for the long (think) path
        self.target_entropy_short = 0.1  # target entropy for the short (no_think) path
        self.clip_higher_enable = False
        self.clip_ratio_low = 0.2
        self.clip_ratio_high = 0.2
        # None (no KL penalty), a numeric coefficient, or a controller exposing `.value`.
        self.kl_ctrl = None
        # Penalty type forwarded to compute_kl_penalty (e.g. "kl", "abs", "mse").
        self.kl_penalty = "kl"
        # Learnable long-path entropy coefficient, stored in log scale (initial beta_L = 0.1).
        self.log_beta_L = nn.Parameter(torch.log(torch.tensor(0.1)))
        self.optimizer_beta = None  # registered externally
        self._cached_beta_loss = None

    def add_loss_meta_info(self, meta_info: Dict):
        """Override loss hyper-parameters from ``meta_info``; ``None`` is a no-op.

        Recognized keys: clip_ratio, kl_ctrl, kl_penalty, entropy_coeff (sets
        ``beta_S``), clip_higher_enable, clip_ratio_low, clip_ratio_high,
        target_entropy_long, target_entropy_short. Other keys are ignored.
        """
        if meta_info is None:
            return
        if "clip_ratio" in meta_info:
            self.clip_ratio = float(meta_info["clip_ratio"])
        if "kl_ctrl" in meta_info:
            self.kl_ctrl = meta_info["kl_ctrl"]
        if "kl_penalty" in meta_info:
            self.kl_penalty = meta_info["kl_penalty"]
        if "entropy_coeff" in meta_info:
            # Cast like the other numeric fields so beta_S is always a plain float.
            self.beta_S = float(meta_info["entropy_coeff"])
        if "clip_higher_enable" in meta_info:
            self.clip_higher_enable = bool(meta_info["clip_higher_enable"])
        if "clip_ratio_low" in meta_info:
            self.clip_ratio_low = float(meta_info["clip_ratio_low"])
        if "clip_ratio_high" in meta_info:
            self.clip_ratio_high = float(meta_info["clip_ratio_high"])
        # Target entropies may also be updated dynamically.
        if "target_entropy_long" in meta_info:
            self.target_entropy_long = float(meta_info["target_entropy_long"])
        if "target_entropy_short" in meta_info:
            self.target_entropy_short = float(meta_info["target_entropy_short"])

    @staticmethod
    def _get_policy_loss_input(batch: Dict[str, torch.Tensor], **kwargs):
        """Collect the tensors needed by the policy loss.

        Returns:
            ``(response_mask, old_log_prob, advantages, ref_log_prob, think_mode)``.
            Missing optional entries are ``None``.

            ``think_mode`` is taken from ``batch['think_mode']`` when present.
            Otherwise, if a ``tokenizer`` kwarg is given, it is derived by
            decoding ``batch['responses']``: 1 when the response opens with the
            think block of the configured format, else 0. If neither source is
            available, ``think_mode`` is ``None`` and the caller decides how to
            fail.

        Raises:
            ValueError: if ``batch`` has no ``'responses'`` entry.
        """
        if 'responses' not in batch:
            raise ValueError("The responses is None")
        response_mask = generate_mask(batch['responses'], batch['response_length']).npu()
        old_log_prob = batch.get('old_log_prob')
        advantages = batch.get('advantages')
        ref_log_prob = batch.get('ref_log_prob')

        # Prefer a think_mode already present in the batch; decode only as a fallback.
        think_mode = batch.get('think_mode', None)
        tokenizer = kwargs.get('tokenizer')

        if think_mode is None and tokenizer is not None:
            ignore_token = kwargs.get('ignore_token', -100)
            # Replace ignored ids with EOS so the sequences can be decoded.
            str_responses = tokenizer.batch_decode(
                torch.where(batch["responses"] == ignore_token, tokenizer.eos_token_id, batch["responses"]),
                skip_special_tokens=True,
            )

            # Response-format convention; only the pattern for this mode is used.
            # A no_think block and an unrecognized format both map to 0, so only
            # the "think" pattern of each format needs to be tested.
            tool_generalized_mode = 4
            think_patterns = {
                1: re.compile(r'^<think>.*?</think>', re.DOTALL),
                2: re.compile(r'^\[think\].*?\[/think\]', re.DOTALL),
                3: re.compile(r'^<mode>\s*think\s*</mode>\s*<think>.*?</think>',
                              re.IGNORECASE | re.DOTALL),
                4: re.compile(r'^\[mode\]\s*think\s*\[/mode\]\s*\[think\].*?\[/think\]',
                              re.IGNORECASE | re.DOTALL),
            }
            think_regex = think_patterns.get(tool_generalized_mode)

            # strip() so that '^' anchors at the first non-blank character.
            tool_mode_list = [
                1 if think_regex is not None and think_regex.search(resp.strip()) else 0
                for resp in str_responses
            ]
            think_mode = torch.tensor(tool_mode_list, dtype=torch.long, device=batch['responses'].device)

        return response_mask, old_log_prob, advantages, ref_log_prob, think_mode

    def compute_loss(self, output: torch.Tensor,
                     batch: Dict[str, torch.Tensor],
                     forward_only=False,
                     non_loss_data=True,
                     **kwargs) -> Tuple[torch.Tensor, Dict]:
        """Compute the GRPO policy loss plus the think-mode entropy regularizer.

        With ``forward_only`` only the log-probs are returned. Otherwise returns
        ``(policy_loss, stats)``, and as a side effect refreshes
        ``self._cached_beta_loss``.

        Raises:
            ValueError: if ``think_mode`` is neither in ``batch`` nor derivable
                via a ``tokenizer`` kwarg.
        """
        if forward_only:
            log_probs, _ = super().compute_log_probs(output=output, batch=batch, **kwargs)
            return log_probs

        # 1. Always compute token entropy; it drives the entropy regularizer.
        log_probs, full_entropy = super().compute_log_probs(output=output, batch=batch, skip_entropy=False, **kwargs)

        # 2. Gather inputs.
        response_mask, old_log_prob, advantages, ref_log_prob, think_mode = self._get_policy_loss_input(batch=batch, **kwargs)
        if think_mode is None:
            raise ValueError("Missing 'think_mode' in batch. Use 1 for think, 0 for no_think.")
        # think_mode may come from the batch or be built on the responses' device.
        # Align it with the entropy tensor, and force an integer dtype: a bool tensor
        # would break `1 - think_mode`, and a long tensor cannot be averaged directly.
        think_mode = torch.as_tensor(think_mode).to(device=full_entropy.device, dtype=torch.long)

        if not self.clip_higher_enable:
            self.clip_ratio_low = self.clip_ratio
            self.clip_ratio_high = self.clip_ratio

        # 3. Policy loss with the think-mode entropy term.
        pg_loss, pg_clipfrac, ppo_kl, kl_loss, entropy_loss = self._compute_grpo_policy_loss(
            old_log_prob=old_log_prob,
            log_prob=log_probs,
            ref_log_prob=ref_log_prob,
            advantages=advantages,
            entropy=full_entropy,
            think_mode=think_mode,
            beta_S=self.beta_S,
            log_beta_L=self.log_beta_L,
            target_entropy_long=self.target_entropy_long,
            target_entropy_short=self.target_entropy_short,
            eos_mask=response_mask,
            cliprange=self.clip_ratio,
            clip_ratio_low=self.clip_ratio_low,
            clip_ratio_high=self.clip_ratio_high,
            kl_ctrl=self.kl_ctrl,
            kl_penalty=self.kl_penalty,
        )

        # 4. Adaptive beta loss (long-path samples only), cached for the external optimizer.
        long_mask = think_mode.float()  # 1 for long, 0 for short
        if long_mask.sum() == 0:
            self._cached_beta_loss = torch.tensor(0.0, device=pg_loss.device, requires_grad=True)
        else:
            per_sample_entropy = F.masked_mean(full_entropy, response_mask, axis=-1)
            target = torch.full_like(per_sample_entropy, self.target_entropy_long)
            entropy_diff = torch.relu(target - per_sample_entropy)
            beta_L = self.log_beta_L.exp()
            # NOTE(review): two points to confirm with the beta optimizer's author.
            # (a) entropy_diff is not detached, so this loss can also send gradients
            #     into the policy.
            # (b) Minimizing it lowers beta_L whenever entropy is below target.
            self._cached_beta_loss = (beta_L * entropy_diff * long_mask).sum() / long_mask.sum()

        # 5. Token-level mean entropy, for display only.
        display_entropy = (full_entropy * response_mask).sum() / response_mask.sum()

        use_dynamic_bsz = kwargs.get('use_dynamic_bsz', False)
        actual_micro_batch_size = kwargs.get('actual_micro_batch_size', None)
        if use_dynamic_bsz and not forward_only:
            policy_loss = pg_loss * (batch['responses'].size(0) / actual_micro_batch_size)
        else:
            policy_loss = pg_loss

        # Dump intermediate tensors for debugging.
        data_tobe_saved = {
            "old_log_prob": old_log_prob,
            "log_prob": log_probs,
            "ref_log_prob": ref_log_prob,
            "advantages": advantages,
            "loss": pg_loss,
            "kl_loss": kl_loss,
            "think_mode": think_mode,
            "beta_L": self.log_beta_L.exp(),
            "entropy_loss": entropy_loss,  # the entropy term actually added to pg_loss
        }
        MsProbe.save_data(data_tobe_saved)

        beta_loss_value = self._cached_beta_loss.detach().item()
        stats = {
            'actor/pg_loss': pg_loss.detach().item(),
            'actor/pg_clipfrac': pg_clipfrac.detach().item(),
            'actor/entropy_loss': entropy_loss.detach().item(),
            'actor/entropy': display_entropy.detach().item(),
            'actor/beta_L': self.log_beta_L.exp().detach().item(),
            'beta_loss': beta_loss_value,
            'beta_loss_cached': beta_loss_value,
            # Fraction of long-path (think) samples; cast because torch.mean rejects integer tensors.
            "think_mode": think_mode.float().mean().detach().item(),
        }

        return policy_loss, stats

    @staticmethod
    def _compute_grpo_policy_loss(old_log_prob, log_prob, ref_log_prob, advantages, entropy,
                                  think_mode, beta_S, log_beta_L,
                                  target_entropy_long, target_entropy_short,
                                  eos_mask, cliprange, clip_ratio_low, clip_ratio_high, kl_ctrl, kl_penalty):
        """Clipped GRPO objective + KL penalty + think-mode entropy penalty.

        The entropy penalty per sample is ``beta * max(0, target - H)``.
        ``beta`` and ``target`` are ``beta_S`` / ``target_entropy_short`` for
        ``think_mode == 0`` and ``exp(log_beta_L)`` / ``target_entropy_long``
        for ``think_mode == 1``. ``H`` is the masked mean token entropy.

        ``kl_ctrl`` may be None (KL coefficient 0), a number, or an object with
        a ``.value`` attribute. When ``ref_log_prob`` is None the KL term is 0.
        ``cliprange`` is unused; clipping uses ``clip_ratio_low`` and
        ``clip_ratio_high``.

        Returns:
            ``(pg_loss, pg_clipfrac, ppo_kl, kl_mean_loss, entropy_loss)``.
        """
        if old_log_prob is None:
            old_log_prob = log_prob.detach().clone()

        negative_approx_kl = log_prob - old_log_prob
        ratio = torch.exp(negative_approx_kl)
        ppo_kl = F.masked_mean(-negative_approx_kl, eos_mask)

        # Step 1: per-sample entropy coefficient.
        beta_L = log_beta_L.exp()
        is_long_path_expanded = think_mode.unsqueeze(-1)
        beta_coeff = (1 - is_long_path_expanded) * beta_S + is_long_path_expanded * beta_L

        # Step 2: sample-level entropy.
        per_sample_entropy = F.masked_mean(entropy, eos_mask, axis=-1)

        # Step 3: pick the target entropy by think_mode.
        target_entropy = torch.where(
            think_mode.bool(),
            torch.tensor(target_entropy_long, device=per_sample_entropy.device, dtype=per_sample_entropy.dtype),
            torch.tensor(target_entropy_short, device=per_sample_entropy.device, dtype=per_sample_entropy.dtype),
        )

        # Step 4: entropy_gap = max(0, target - H); only entropy below target is penalized.
        entropy_gap = torch.relu(target_entropy - per_sample_entropy)
        entropy_loss = (beta_coeff.squeeze(-1) * entropy_gap).mean()

        # Clipped GRPO policy loss.
        pg_losses = -advantages * ratio
        pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - clip_ratio_low, 1.0 + clip_ratio_high)
        pg_mean_loss = F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
        pg_mean_clipfrac = F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)

        # KL penalty against the reference policy.
        if ref_log_prob is None:
            kl_mean_loss = torch.zeros((), device=log_prob.device, dtype=log_prob.dtype)
        else:
            kl_losses = compute_kl_penalty(log_prob, ref_log_prob, kl_penalty)
            kl_mean_loss = F.masked_mean(kl_losses, eos_mask)

        if kl_ctrl is None:
            kl_coef = 0.0
        elif isinstance(kl_ctrl, (int, float)):
            kl_coef = float(kl_ctrl)
        else:
            kl_coef = kl_ctrl.value

        # Total: policy + KL + entropy penalty.
        pg_loss = pg_mean_loss + kl_mean_loss * kl_coef + entropy_loss

        return pg_loss, pg_mean_clipfrac, ppo_kl, kl_mean_loss, entropy_loss