import re
from typing import List

import re
from typing import List, Dict, Any

def format_reward(completions: List[List[Dict[str, Any]]], **kwargs) -> List[float]:
    """
    (已修复并增强) 奖励函数，用于严格检查模型的输出是否遵循指定的格式。

    格式要求: 
        1. |<think>|, |</think>|, |<answer>|, |</answer>| 四个标签必须且仅出现一次。
        2. 它们必须按以上顺序出现。
    奖励:
        - 1.0: 如果格式严格正确。
        - 0.0: 如果格式不正确。
    """
    # 顺序检查的正则表达式 (保持不变)
    order_pattern = re.compile(r"^.*?\|\<think\>\|.*?\|\</think\>\|.*?\|\<answer\>\|.*?\|\</answer\>\|.*?$", re.DOTALL)
    
    rewards = []
    for comp_item in completions:
        reward = 0.0
        try:
            content_str = comp_item[0]['content']
            
            # 【核心修正】第一步：进行严格的数量检查
            correct_counts = (
                content_str.count("|<think>|") == 1 and
                content_str.count("|</think>|") == 1 and
                content_str.count("|<answer>|") == 1 and
                content_str.count("|</answer>|") == 1
            )
            
            # 如果数量正确，再进行第二步：顺序检查
            if correct_counts:
                if order_pattern.fullmatch(content_str.strip()):
                    reward = 1.0

        except (IndexError, KeyError, TypeError):
            pass  # 数据结构不符，保持奖励为0
        
        rewards.append(reward)
            
    return rewards


def result_reward(completions: List[List[Dict[str, Any]]], solution: List[str], **kwargs) -> List[float]:
    """
    (已修复) 奖励函数，用于检查模型提取的答案是否与标准答案一致。

    逻辑:
    1. 从模型输出中提取 |<answer>| 和 |</answer>| 标签之间的内容。
    2. 将提取的内容与 'solution' (标准答案) 进行比较。
    奖励:
        - 1.0: 如果答案正确。
        - 0.0: 如果答案不正确或无法提取答案。
    """
    # 用于从文本中提取答案的正则表达式
    pattern = re.compile(r"\|\<answer\>\|(.*?)\|\</answer\>\|", re.DOTALL)
    
    rewards = []
    # 同时遍历模型输出和标准答案
    for comp_item, sol in zip(completions, solution):
        reward = 0.0 # 默认奖励为0
        
        try:
            # 核心改动：从嵌套结构中提取出真正的字符串内容
            content_str = comp_item[0]['content']

            # 对提取出的字符串执行原有的逻辑
            match = re.search(pattern, content_str)
            
            if match:
                # 提取捕获组1的内容，并去除首尾的空白字符
                extracted_answer = match.group(1).strip()
                
                # 与标准答案比较 (忽略大小写和首尾空格)
                if extracted_answer.lower() == sol.strip().lower():
                    reward = 1.0
        except (IndexError, KeyError, TypeError):
            # 如果数据结构不符合预期，则保持奖励为0
            pass
        
        rewards.append(reward)
        
    return rewards

def length_reward(completions: List[List[Dict[str, Any]]], target_length: int = 500, zero_length: int = 512, **kwargs) -> List[float]:
    """
    奖励函数，用于鼓励模型输出接近目标长度的文本，并在超过特定长度后迅速惩罚。

    逻辑:
    1. 计算输出内容中的词数 (通过空格分割近似为token数)。
    2. 奖励呈非对称的线性（三角形）分布：
        - 在 `target_length` (如500) 时，奖励为 1.0。
        - 在 0 到 `target_length` 之间，奖励从 0 线性增加到 1。
        - 在 `target_length` 到 `zero_length` (如512) 之间，奖励从 1 线性下降到 0。
        - 长度超过 `zero_length` 后，奖励恒为 0。
    
    参数:
        - completions: 模型的输出列表。
        - target_length: 奖励为1.0的期望长度，默认为 500。
        - zero_length: 奖励降为0.0的长度，默认为 512。

    返回:
        - 一个范围在 [0.0, 1.0] 之间的奖励分数列表。
    """
    if zero_length <= target_length:
        raise ValueError(f"zero_length ({zero_length}) must be greater than target_length ({target_length}).")

    rewards = []
    # 计算下降区间的长度
    decay_width = zero_length - target_length

    for comp_item in completions:
        reward = 0.0
        try:
            content_str = comp_item[0]['content']
            actual_length = len(content_str.split())

            if actual_length <= target_length:
                # 在 [0, target_length] 区间内，奖励线性增加
                reward = actual_length / target_length
            elif actual_length <= zero_length:
                # 【新逻辑】在 (target_length, zero_length] 区间内，奖励线性减少
                # reward = 1.0 - (超出的部分 / 下降区间的总宽度)
                reward = 1.0 - (actual_length - target_length) / decay_width
            else:
                # 超过 zero_length，奖励为0
                reward = 0.0
        
        except (IndexError, KeyError, TypeError):
            pass
        
        rewards.append(max(0.0, min(1.0, reward)))
        
    return rewards

# ================================================================
#  修正后的单元测试代码
# ================================================================
if __name__ == '__main__':
    print("--- 测试奖励函数 (输入为修正后的数据结构) ---")

    # 1. 原始的纯文本数据 和 标准答案 (保持不变)
    original_completions_text = [
        "|<think>|思考过程...|</think>|这是中间分析。|<answer>| B |</answer>|",# 格式正确，答案正确
        "|<think>|思考过程...|</think>| |<answer>|C|</answer>|",# 格式正确，答案错误
        "|<answer>| B |</answer>| |<think>|思考过程...|</think>|",# 格式错误 (顺序颠倒)，答案正确
        "这是一个没有标签的答案D",# 格式错误，无法提取答案
        "|<think>|思考过程...|</think>|<answer>|A|</answer>",
        "|<think>| |<think>| |<think>||</think>||<answer>||</answer>|"
    ]
    # 为了让第5个例子更有趣，我们假设它缺少了 </think> 标签

    solutions_data = ["B", "B", "B", "D", "A", "D"]

    # 2. 【核心修正】将原始文本数据包装成函数期望的格式
    #    使用列表推导式可以很方便地完成这个转换
    formatted_completions_data = [[{"role": "assistant", "content": text}] for text in original_completions_text]

    # 3. 使用修正后的数据调用奖励函数
    format_rewards = format_reward(formatted_completions_data)
    result_rewards = result_reward(formatted_completions_data, solution=solutions_data)

    # 4. 打印结果 (为了清晰，我们打印原始文本)
    print(f"标准答案: {solutions_data}\n")
    for i, content in enumerate(original_completions_text):
        print(f"模型输出 {i+1}: \"{content}\"")
        print(f"  - 格式奖励: {format_rewards[i]}")
        print(f"  - 结果奖励: {result_rewards[i]}\n")