import logging
import numpy as np
import re
from rouge import Rouge
from difflib import SequenceMatcher

class RolloutRewardAgent:
    """
    Rollout reward computation for sub-plans that passed filtering.

    Works on planning_results.json-style examples: every step of the selected plan
    (``weekly_plan`` / ``floor_plan`` / ``block_plan`` or a ``filtered_*`` variant)
    is scored against the example's ``checks_once`` / ``checks_range`` /
    ``checks_periodic`` constraints plus keyword relevance to ``prompt``.
    ``_get_reward`` falls back to answer matching (ROUGE-L or case-insensitive
    containment) for plan keys it does not recognise.

    All methods are static; the class only serves as a namespace.
    """

    # Plan fields that may hold the steps to score, in selection-priority order.
    _PLAN_KEYS = (
        "weekly_plan", "floor_plan", "block_plan",
        "filtered_weekly_plan", "filtered_floor_plan", "filtered_block_plan",
    )

    # ------------------------------------------------------------------
    # Generic helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _to_text(value):
        """Coerce a step field to text: None -> "", lists/tuples are space-joined."""
        if value is None:
            return ""
        if isinstance(value, str):
            return value
        if isinstance(value, (list, tuple)):
            return " ".join(str(item) for item in value)
        return str(value)

    @staticmethod
    def _parse_int(key):
        """Return ``int(key)``, or None when ``key`` is not an integer literal."""
        try:
            return int(key)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _parse_range(key):
        """Parse a check key into inclusive ``(start, end)`` bounds, or None.

        A plain integer key such as ``"5"`` yields ``(5, 5)``; otherwise the first
        ``"<a>-<b>"`` digit pair found in the key (e.g. ``"20-26"``) is used.
        """
        value = RolloutRewardAgent._parse_int(key)
        if value is not None:
            return value, value
        match = re.search(r'(\d+)-(\d+)', str(key))
        if match:
            return int(match.group(1)), int(match.group(2))
        return None

    @staticmethod
    def _in_range(key, number):
        """Return True when ``number`` lies inside the bounds parsed from ``key``."""
        bounds = RolloutRewardAgent._parse_range(key)
        return bounds is not None and bounds[0] <= number <= bounds[1]

    @staticmethod
    def _checks(example, name):
        """Return ``example[name]`` when it is a dict of checks, else an empty dict."""
        checks = example.get(name)
        return checks if isinstance(checks, dict) else {}

    @staticmethod
    def _first_soft_match(checks, key_matches, text):
        """Return the first positive soft-match score among applicable checks.

        A check applies when ``key_matches(key)`` is true; its expected value is
        soft-matched against ``text``. Returns 0.0 when no applicable check matches.
        """
        for key, expected in checks.items():
            if key_matches(key):
                score = RolloutRewardAgent._soft_match(text, expected)
                if score > 0:
                    return score
        return 0.0

    # ------------------------------------------------------------------
    # Scoring primitives
    # ------------------------------------------------------------------
    @staticmethod
    def _rougel_score(prediction, ground_truth):
        """Return the ROUGE-L F score of ``prediction`` against ``ground_truth``.

        A ``ValueError`` from ``Rouge.get_scores`` (presumably raised for empty
        inputs) is treated as a score of 0.0.
        """
        rouge = Rouge()
        try:
            scores = rouge.get_scores(prediction, ground_truth, avg=True)
        except ValueError:
            return 0.0
        return scores["rouge-l"]["f"]

    @staticmethod
    def _acc_score(prediction, ground_truth):
        """Return 1.0 if ``ground_truth`` occurs in ``prediction`` (case-insensitive), else 0.0.

        ``prediction`` may be a list, whose items are joined with spaces.
        Empty prediction or ground truth scores 0.0.
        """
        if not prediction or not ground_truth:
            return 0.0
        if isinstance(prediction, list):
            prediction = " ".join(str(item) for item in prediction)
        # Case-fold BOTH sides; lowering only the ground truth missed e.g.
        # "paris" inside "The answer is PARIS".
        return 1.0 if str(ground_truth).lower() in str(prediction).lower() else 0.0

    @staticmethod
    def _soft_match(text, pattern, threshold=0.7):
        """Fuzzy match ``pattern`` against ``text`` (both compared lower-cased).

        Returns 1.0 when ``pattern`` is a substring of ``text``; otherwise the
        ``difflib.SequenceMatcher`` ratio if it reaches ``threshold``, else 0.0.
        A list ``text`` is space-joined first; empty inputs score 0.0.
        """
        if not text or not pattern:
            return 0.0

        if isinstance(text, list):
            text = " ".join(str(t) for t in text)

        text = str(text).lower().strip()
        pattern = str(pattern).lower().strip()

        if pattern in text:
            return 1.0

        similarity = SequenceMatcher(None, text, pattern).ratio()
        return similarity if similarity >= threshold else 0.0

    @staticmethod
    def _extract_keywords(text):
        """Return the unique lower-cased words of ``text`` longer than 3 characters.

        Surrounding punctuation is stripped *before* the length filter so that
        punctuation-only tokens such as "(....)" never yield an empty keyword
        (the empty string is a substring of every text and would match anything).
        Order of the returned list is unspecified.
        """
        if not text:
            return []
        keywords = set()
        for word in str(text).lower().split():
            word = word.strip('.,;:?!()[]{}')
            if len(word) > 3:
                keywords.add(word)
        return list(keywords)

    @staticmethod
    def _extract_content(step):
        """Flatten a plan step into a single text string.

        Strings are returned unchanged. For dict steps the first present field among
        ``content``, ``events``, ``dishes``, ``purpose`` and ``use`` supplies the text;
        any ``week_id`` / ``floor_id`` / ``block_id`` values are then prepended.
        """
        if isinstance(step, str):
            return step

        content = ""
        if "content" in step:
            content = step["content"]
        elif "events" in step:
            events = step["events"]
            if isinstance(events, list):
                if all(isinstance(e, dict) for e in events):
                    content = " ".join(
                        f"{e.get('event_name', '')} {e.get('event', '')} {e.get('details', '')}"
                        for e in events
                    )
                else:
                    content = " ".join(str(e) for e in events)
            else:
                content = str(events)
        elif "dishes" in step:
            dishes = step["dishes"]
            if isinstance(dishes, list):
                content = " ".join(str(d) for d in dishes)
            else:
                content = str(dishes)
        elif "purpose" in step:
            content = step["purpose"]
        elif "use" in step:
            content = step["use"]

        # Normalise None / list / numeric payloads so callers can rely on a str.
        content = RolloutRewardAgent._to_text(content)

        for id_field in ("week_id", "floor_id", "block_id"):
            if id_field in step:
                content = f"{step[id_field]} {content}"
        return content

    @staticmethod
    def _calculate_relevance(content, prompt):
        """Return keyword relevance of ``content`` to ``prompt`` in [0, 1].

        Full score is reached once about a third of the prompt's keywords occur in
        the (lower-cased) content; empty inputs or a keyword-less prompt score 0.0.
        """
        if not content or not prompt:
            return 0.0

        prompt_keywords = RolloutRewardAgent._extract_keywords(prompt)
        if not prompt_keywords:
            return 0.0

        content = RolloutRewardAgent._to_text(content).lower()
        matches = sum(1 for keyword in prompt_keywords if keyword in content)
        return min(1.0, matches / max(len(prompt_keywords) / 3, 1))

    @staticmethod
    def _check_date_match(date_str, target_date):
        """Return True when the first MM-DD (or MM/DD) dates of both strings agree.

        Components are compared numerically, so "03-05" matches "3/5".
        Returns False if either string contains no such date.
        """
        pattern = r'(\d{1,2})[-/](\d{1,2})'
        target = re.search(pattern, str(target_date))
        found = re.search(pattern, str(date_str))
        if not found or not target:
            return False
        return (
            (int(found.group(1)), int(found.group(2)))
            == (int(target.group(1)), int(target.group(2)))
        )

    @staticmethod
    def _check_week_match(week_str, target_week):
        """Return True if ``week_str`` refers to ``target_week``.

        Handles ranges such as "Week 10-12" (inclusive) and single weeks such as
        "Week 10". The "Week" prefix is matched case-sensitively.
        """
        week_str = str(week_str)
        week_range_match = re.search(r'Week\s+(\d+)[-\s]+(\d+)', week_str)
        if week_range_match:
            start_week = int(week_range_match.group(1))
            end_week = int(week_range_match.group(2))
            return start_week <= target_week <= end_week

        week_match = re.search(r'Week\s+(\d+)', week_str)
        if week_match:
            return int(week_match.group(1)) == target_week

        return False

    # ------------------------------------------------------------------
    # Per-plan-type reward functions
    # ------------------------------------------------------------------
    @staticmethod
    def _get_floor_reward(floor, example):
        """Score one ``floor_plan`` step.

        The floor number is read from ``floor_id`` ("Floor <n>"). Each of
        ``checks_once`` (key equal to the floor number), ``checks_range`` (integer
        or "a-b" key containing the floor) and ``checks_periodic`` (integer key equal
        to the floor) contributes its first positive soft match against ``purpose``;
        prompt relevance adds up to 0.5. The total is capped at 1.0. Steps without
        ``floor_id``/``purpose`` or a parsable floor number score 0.0.
        """
        agent = RolloutRewardAgent
        floor_id = floor.get("floor_id", "")
        purpose = floor.get("purpose", "")
        if not floor_id or not purpose:
            return 0.0

        floor_num_match = re.search(r'Floor\s+(\d+)', str(floor_id))
        if not floor_num_match:
            return 0.0
        floor_num = int(floor_num_match.group(1))

        reward = 0.0
        reward += agent._first_soft_match(
            agent._checks(example, "checks_once"),
            lambda key: str(key) == str(floor_num),
            purpose,
        )
        reward += agent._first_soft_match(
            agent._checks(example, "checks_range"),
            lambda key: agent._in_range(key, floor_num),
            purpose,
        )
        # NOTE(review): "periodic" keys are compared for exact equality only;
        # no modulo/period logic is applied.
        reward += agent._first_soft_match(
            agent._checks(example, "checks_periodic"),
            lambda key: agent._parse_int(key) == floor_num,
            purpose,
        )

        prompt = example.get("prompt", "")
        if prompt and purpose:
            reward += agent._calculate_relevance(purpose, prompt) * 0.5

        return min(reward, 1.0)

    @staticmethod
    def _get_block_reward(block, example, grid_size=10):
        """Score one ``block_plan`` step.

        ``block_id`` must contain "(x, y)" coordinates; blocks are numbered
        row-major and 1-based on a ``grid_size``-wide grid
        (``number = y * grid_size + x + 1``).

        * ``checks_once``: integer keys are converted to coordinates with the same
          numbering; "(x, y)" keys are compared directly.
        * ``checks_range``: "a-b" keys containing the block number.
        * ``checks_periodic``: integer keys equal to the block number.

        Each category adds its first positive soft match against ``use``; prompt
        relevance adds up to 0.5; the total is capped at 1.0. Steps without
        ``block_id``/``use`` or parsable coordinates score 0.0.
        """
        agent = RolloutRewardAgent
        block_id = block.get("block_id", "")
        use = block.get("use", "")
        if not block_id or not use:
            return 0.0

        coords_match = re.search(r'\((\d+),\s*(\d+)\)', str(block_id))
        if not coords_match:
            return 0.0
        x = int(coords_match.group(1))
        y = int(coords_match.group(2))
        # Defined up front: range/periodic checks need it even when checks_once
        # is empty or has no numeric keys.
        block_num = y * grid_size + x + 1

        def once_matches(key):
            number = agent._parse_int(key)
            if number is not None:
                return (x, y) == ((number - 1) % grid_size, (number - 1) // grid_size)
            coords = re.search(r'\((\d+),\s*(\d+)\)', str(key))
            return bool(coords) and (x, y) == (int(coords.group(1)), int(coords.group(2)))

        def range_matches(key):
            parts = str(key).split('-')
            if len(parts) != 2:
                return False
            try:
                start_block, end_block = int(parts[0]), int(parts[1])
            except ValueError:
                return False
            return start_block <= block_num <= end_block

        reward = 0.0
        reward += agent._first_soft_match(
            agent._checks(example, "checks_once"), once_matches, use
        )
        reward += agent._first_soft_match(
            agent._checks(example, "checks_range"), range_matches, use
        )
        reward += agent._first_soft_match(
            agent._checks(example, "checks_periodic"),
            lambda key: agent._parse_int(key) == block_num,
            use,
        )

        prompt = example.get("prompt", "")
        if prompt and use:
            reward += agent._calculate_relevance(use, prompt) * 0.5

        return min(reward, 1.0)

    @staticmethod
    def _get_weekly_menu_reward(menu_item, example):
        """Score one ``weekly_plan`` step of a menu-type example.

        Content is the step's ``dishes`` plus ``events`` (lower-cased). The week
        number comes from "Week <n>" in ``week_id``.

        * ``checks_once`` (integer keys): soft match when the week equals the key;
          otherwise 0.8 when ``week_id`` and the expectation name the same MM-DD
          date and any expectation keyword is in the content. First hit only.
        * ``checks_range`` / ``checks_periodic``: first positive soft match for an
          integer / "a-b" key covering the week (periodic: exact week only).
        * Prompt relevance adds up to 0.5. The total is capped at 1.0.
        """
        agent = RolloutRewardAgent
        reward = 0.0
        week_id = agent._to_text(menu_item.get("week_id", ""))
        dishes = agent._to_text(menu_item.get("dishes", ""))
        events = agent._to_text(menu_item.get("events", ""))
        content = f"{dishes} {events}".lower()

        week_match = re.search(r'Week\s+(\d+)', week_id)
        week_num = int(week_match.group(1)) if week_match else None

        for key, expected in agent._checks(example, "checks_once").items():
            check_week_num = agent._parse_int(key)
            if check_week_num is None:
                continue

            if week_num is not None and week_num == check_week_num:
                match_score = agent._soft_match(content, expected)
                if match_score > 0:
                    reward += match_score
                    break

            # Same calendar date (compared numerically, so "03-05" == "3-5")
            # plus at least one keyword of the expectation in the menu.
            if agent._check_date_match(week_id, str(expected)):
                keywords = agent._extract_keywords(expected)
                if any(keyword in content for keyword in keywords):
                    reward += 0.8
                    break

        if week_num is not None:
            reward += agent._first_soft_match(
                agent._checks(example, "checks_range"),
                lambda key: agent._in_range(key, week_num),
                content,
            )
            reward += agent._first_soft_match(
                agent._checks(example, "checks_periodic"),
                lambda key: agent._parse_int(key) == week_num,
                content,
            )

        prompt = example.get("prompt", "")
        if prompt and content:
            reward += agent._calculate_relevance(content, prompt) * 0.5

        return min(reward, 1.0)

    @staticmethod
    def _get_weekly_diary_reward(week_item, example):
        """Score one ``weekly_plan`` step of a diary-style example.

        Content is ``week_id`` plus the flattened ``events`` (lower-cased).

        * ``checks_once`` (integer keys): soft match when the week equals the key;
          otherwise 0.8 when both expectation and content mention a birthday of the
          same family member. First hit only.
        * ``checks_range``: 0.7 when the week is covered and any expectation keyword
          is in the content. First hit only.
        * ``checks_periodic`` (exact week): up to 0.8 in proportion to the share of
          expectation keywords found. First hit only.
        * Prompt relevance adds up to 0.5. The total is capped at 1.0.
        """
        agent = RolloutRewardAgent
        reward = 0.0
        week_id = agent._to_text(week_item.get("week_id", ""))

        events = week_item.get("events", "")
        if isinstance(events, list):
            if all(isinstance(e, dict) for e in events):
                events_text = " ".join(
                    f"{e.get('event_name', '')} {e.get('event', '')} {e.get('details', '')}"
                    for e in events
                )
            else:
                events_text = " ".join(str(e) for e in events)
        elif isinstance(events, dict):
            events_text = " ".join(f"{k}: {v}" for k, v in events.items())
        else:
            events_text = str(events)

        content = f"{week_id} {events_text}".lower()

        week_match = re.search(r'Week\s+(\d+)', week_id)
        week_num = int(week_match.group(1)) if week_match else None

        family_members = ("husband", "wife", "child", "father", "mother")
        for key, expected in agent._checks(example, "checks_once").items():
            check_week_num = agent._parse_int(key)
            if check_week_num is None:
                continue

            if week_num is not None and week_num == check_week_num:
                match_score = agent._soft_match(content, expected)
                if match_score > 0:
                    reward += match_score
                    break

            expected_lower = str(expected).lower()
            if "birthday" in expected_lower and "birthday" in content:
                if any(m in expected_lower and m in content for m in family_members):
                    # Stop after the first birthday bonus so several birthday
                    # checks cannot stack.
                    reward += 0.8
                    break

        if week_num is not None:
            for key, expected in agent._checks(example, "checks_range").items():
                if agent._in_range(key, week_num):
                    keywords = agent._extract_keywords(expected)
                    if keywords and any(keyword in content for keyword in keywords):
                        reward += 0.7
                        break

            for key, expected in agent._checks(example, "checks_periodic").items():
                if agent._parse_int(key) == week_num:
                    keywords = agent._extract_keywords(expected)
                    keyword_matches = sum(1 for keyword in keywords if keyword in content)
                    if keyword_matches:
                        reward += min(0.8, keyword_matches / len(keywords) * 0.8)
                        break

        prompt = example.get("prompt", "")
        if prompt and content:
            reward += agent._calculate_relevance(content, prompt) * 0.5

        return min(reward, 1.0)

    @staticmethod
    def _get_reward(step, example, plan_key):
        """Dispatch ``step`` to the reward function matching ``plan_key``.

        ``weekly_plan`` keys use the menu scorer when ``example["type"]`` starts
        with "menu" (case-insensitive), the diary scorer otherwise. Unrecognised
        keys fall back to matching the step text against ``example["answer"]``
        with ROUGE-L (``reward_type == "rouge"``) or containment accuracy.
        """
        agent = RolloutRewardAgent
        plan_key = plan_key or ""
        if "floor_plan" in plan_key:
            return agent._get_floor_reward(step, example)
        if "block_plan" in plan_key:
            return agent._get_block_reward(step, example)
        if "weekly_plan" in plan_key:
            plan_type = str(example.get("type") or "").lower()
            if plan_type.startswith("menu"):
                return agent._get_weekly_menu_reward(step, example)
            return agent._get_weekly_diary_reward(step, example)

        gt = example.get("answer", "")
        text = agent._extract_content(step)
        if example.get("reward_type") == "rouge":
            return agent._rougel_score(text, gt)
        return agent._acc_score(text, gt)

    @staticmethod
    def _apply_fallback_rewards(plan, example):
        """Heuristic rewards used when every regular reward of a plan is zero.

        Each step scores from content length (up to 0.5), non-empty structural
        fields (0.1 each) and prompt keyword coverage (up to 0.3); the sum is mapped
        to ``0.1 + 0.8 * total`` and capped at 0.9. Steps without content get 0.1.
        If the spread across steps is below 0.2 the rewards are min-max stretched
        onto roughly [0.1, 0.9] (identical rewards all collapse to 0.1).

        Returns a list of floats aligned with ``plan``.
        """
        structure_fields = ('events', 'dishes', 'purpose', 'use', 'week_id', 'floor_id', 'block_id')
        # Prompt keywords do not depend on the step: extract them once.
        prompt_keywords = RolloutRewardAgent._extract_keywords(example.get('prompt', ''))

        rewards = []
        for step in plan:
            content = RolloutRewardAgent._extract_content(step)
            if not content:
                rewards.append(0.1)
                continue

            length_reward = min(len(content) / 500, 0.5)

            structure_reward = 0
            if isinstance(step, dict):
                structure_reward = sum(
                    0.1 for field in structure_fields if field in step and step[field]
                )

            lowered = content.lower()
            relevance_score = (
                sum(1 for word in prompt_keywords if word in lowered)
                / max(len(prompt_keywords), 1)
            )
            relevance_reward = relevance_score * 0.3

            total_reward = length_reward + structure_reward + relevance_reward
            rewards.append(min(0.9, 0.1 + (total_reward * 0.8)))

        if len(rewards) > 1:
            min_r, max_r = min(rewards), max(rewards)
            if max_r - min_r < 0.2:
                # Stretch a too-narrow distribution so the steps stay distinguishable.
                rewards = [0.1 + (r - min_r) / (max_r - min_r + 1e-6) * 0.8 for r in rewards]

        return rewards

    @staticmethod
    def _select_plan(example):
        """Return ``(plan_key, steps)`` for the first non-empty plan, else ``(None, None)``.

        Keys are tried in ``_PLAN_KEYS`` order. A key qualifies when it holds a
        non-empty list; ``filtered_*`` keys also qualify when they hold a dict
        whose ``original_plan`` is a non-empty list.
        """
        for key in RolloutRewardAgent._PLAN_KEYS:
            value = example.get(key)
            if isinstance(value, list) and value:
                return key, value
            if key.startswith("filtered_") and isinstance(value, dict):
                original = value.get("original_plan")
                if isinstance(original, list) and original:
                    return key, original
        return None, None

    @staticmethod
    def rollout_rewards(example):
        """Score every step of a valid example's plan and annotate the example in place.

        Skips (returns unchanged) examples that are neither ``is_valid`` nor have
        ``filter_response == "YES"`` (case-insensitive), and examples with no
        scorable plan. Otherwise each step copy gains ``reward`` and ``advantage``
        (reward standardised by the plan's mean/std), the plan is written back under
        its key (into ``original_plan`` for dict-shaped ``filtered_*`` plans), and
        ``max_reward_steps``, ``min_reward_steps`` and ``reward_stats`` are added.
        When every regular reward is zero, heuristic fallback rewards are used and
        ``reward_stats["using_fallback"]`` is True.

        Returns the (mutated) ``example``.
        """
        agent = RolloutRewardAgent
        example_id = example.get("id", "")
        filter_response = str(example.get("filter_response") or "")
        if not (example.get("is_valid") or filter_response.upper() == "YES"):
            logging.info("Example %s skipped: not valid.", example_id)
            return example

        plan_key, plan = agent._select_plan(example)
        if plan_key is None:
            logging.warning("Example %s has no valid plan to rollout.", example_id)
            return example

        rewards = []
        processed_plan = []
        for step in plan:
            if isinstance(step, dict):
                # Work on a copy so the caller's step objects are not mutated.
                step_dict = dict(step)
            else:
                try:
                    step_dict = {"content": step if isinstance(step, str) else str(step)}
                except Exception:
                    logging.warning("Skipping unconvertible step in example %s.", example_id)
                    continue

            reward = agent._get_reward(step_dict, example, plan_key)
            step_dict["reward"] = reward
            rewards.append(reward)
            processed_plan.append(step_dict)

        if not rewards:
            return example

        # Record this BEFORE replacing rewards: fallback rewards are always > 0,
        # so testing the final rewards for zero would never report the fallback.
        used_fallback = all(r == 0 for r in rewards)
        if used_fallback:
            logging.info("All rewards are zero for %s, applying fallback rewards.", example_id)
            rewards = agent._apply_fallback_rewards(processed_plan, example)
            for step_dict, r in zip(processed_plan, rewards):
                step_dict["reward"] = r

        max_reward = max(rewards)
        min_reward = min(rewards)
        example["max_reward_steps"] = [s for s in processed_plan if s.get("reward") == max_reward]
        example["min_reward_steps"] = [s for s in processed_plan if s.get("reward") == min_reward]

        # Standardised advantages; a (near-)zero std is replaced by 1.0 to avoid division blow-up.
        rewards_np = np.array(rewards, dtype=np.float32)
        mean = rewards_np.mean()
        std = rewards_np.std()
        std = std if std > 1e-6 else 1.0
        for s in processed_plan:
            s["advantage"] = float((s.get("reward", 0) - mean) / std)

        container = example[plan_key]
        if isinstance(container, dict):
            container["original_plan"] = processed_plan
        else:
            example[plan_key] = processed_plan

        example["reward_stats"] = {
            "mean": float(mean),
            "std": float(std),
            "max": float(max_reward),
            "min": float(min_reward),
            "positive_rewards": sum(1 for r in rewards if r > 0),
            "total_steps": len(rewards),
            "using_fallback": used_fallback,
        }

        return example

    @staticmethod
    async def async_rollout_rewards(model, example, semaphore):
        """Async-compatible wrapper around ``rollout_rewards``.

        ``model`` and ``semaphore`` are accepted for interface compatibility but
        are not used; the computation runs synchronously inside the coroutine.
        """
        return RolloutRewardAgent.rollout_rewards(example)