import logging
import asyncio
from llms.llms import async_call_llm
from json_repair import repair_json
import json
import re

class GroupingAgent:
    """
    Group generated content by task position into step-level groups.

    Each group corresponds to one concrete context state (for example a single
    floor, or a single week of tasks). Every generated version found for that
    state is collected in the group's ``versions`` list, and the resulting list
    of groups is stored on the example under ``'step_groups'``.
    """

    # Maximum length, in characters, of a stored summary (ellipsis included).
    # NOTE(review): the summary prompt asks for roughly 30-50 words, which can
    # easily exceed 200 characters -- confirm this cap is intended.
    _SUMMARY_MAX_CHARS = 200

    @staticmethod
    async def async_create_groups(model, example, semaphore):
        """Dispatch the example to the grouping routine matching its type.

        Args:
            model: Passed through unchanged to the LLM call helper.
            example: Dict describing one sample. ``example["type"]`` selects
                the routine ("Week", "Floor", "Menu Week" or "Block").
            semaphore: Async context manager limiting concurrent LLM calls.

        Returns:
            The example returned by the selected routine, or the unchanged
            ``example`` when the type is not one of the four known values.

        Raises:
            KeyError: If ``example`` has no ``"type"`` key.
        """
        example_type = example["type"]
        if example_type == "Week":
            example = await GroupingAgent.async_group_weeks(model, example, semaphore)
        elif example_type == "Floor":
            example = await GroupingAgent.async_group_floors(model, example, semaphore)
        elif example_type == "Menu Week":
            example = await GroupingAgent.async_group_menu_weeks(model, example, semaphore)
        elif example_type == "Block":
            example = await GroupingAgent.async_group_blocks(model, example, semaphore)
        else:
            # Previously a silent no-op; surface it so unknown types are visible.
            logging.warning("Unknown example type %r; no grouping performed", example_type)
        return example

    @staticmethod
    async def async_group_weeks(model, example, semaphore):
        """Group diary content by week.

        Reads ``example['weekly_plan']``; entries are keyed by ``'week_id'``,
        the group context comes from ``'events'`` and each ``'diary_entry'``
        becomes one version. Returns the (mutated) ``example``.
        """
        logging.info("Grouping weekly contents")
        return await GroupingAgent._async_group_plan(
            model, example, semaphore,
            plan_key='weekly_plan',
            id_key='week_id',
            context_key='events',
            content_key='diary_entry',
            missing_message="No weekly plan found",
        )

    @staticmethod
    async def async_group_floors(model, example, semaphore):
        """Group floor content by floor.

        Reads ``example['floor_plan']``; entries are keyed by ``'floor_id'``,
        the group context comes from ``'purpose'`` and each ``'plan'`` becomes
        one version. Returns the (mutated) ``example``.
        """
        logging.info("Grouping floor contents")
        return await GroupingAgent._async_group_plan(
            model, example, semaphore,
            plan_key='floor_plan',
            id_key='floor_id',
            context_key='purpose',
            content_key='plan',
            missing_message="No floor plan found",
        )

    @staticmethod
    async def async_group_menu_weeks(model, example, semaphore):
        """Group menu content by week.

        Reads ``example['weekly_plan']``; entries are keyed by ``'week_id'``,
        the group context comes from ``'dishes'`` and each ``'week_menu'``
        becomes one version. Returns the (mutated) ``example``.
        """
        logging.info("Grouping menu contents")
        return await GroupingAgent._async_group_plan(
            model, example, semaphore,
            plan_key='weekly_plan',
            id_key='week_id',
            context_key='dishes',
            content_key='week_menu',
            missing_message="No weekly menu plan found",
        )

    @staticmethod
    async def async_group_blocks(model, example, semaphore):
        """Group city-block content by block.

        Reads ``example['block_plan']``; entries are keyed by ``'block_id'``,
        the group context comes from ``'use'`` and each ``'plan'`` becomes one
        version. Returns the (mutated) ``example``.
        """
        logging.info("Grouping block contents")
        return await GroupingAgent._async_group_plan(
            model, example, semaphore,
            plan_key='block_plan',
            id_key='block_id',
            context_key='use',
            content_key='plan',
            missing_message="No block plan found",
        )

    @staticmethod
    def _has_usable_id(value):
        """Return True when ``value`` can serve as a group id.

        Missing/empty values (``None``, ``''``, ``False``, empty containers)
        are rejected, but the numbers ``0`` and ``0.0`` are accepted so that
        e.g. a ground floor numbered 0 is not silently dropped.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, float)):
            return True
        return bool(value)

    @staticmethod
    async def _async_group_plan(model, example, semaphore, *, plan_key, id_key,
                                context_key, content_key, missing_message):
        """Shared implementation behind the four ``async_group_*`` methods.

        Args:
            model: Passed through to ``_generate_group_summary``.
            example: Dict holding the plan list under ``plan_key`` and the
                user prompt under ``'prompt'``. It is mutated in place.
            semaphore: Async context manager limiting concurrent LLM calls.
            plan_key: Key of the list of plan entries inside ``example``.
            id_key: Key of the grouping id inside each plan entry.
            context_key: Key of the descriptive context inside each entry.
            content_key: Key of the generated content inside each entry.
            missing_message: Error logged when the plan is absent or empty.

        Returns:
            ``example``. Unless the plan was absent or empty it now carries
            ``'step_groups'``: a list of dicts with ``id_key``, ``context_key``,
            ``'versions'`` and ``'summary'``, in first-seen id order.

        Raises:
            KeyError: If at least one group was built and ``example`` has no
                ``'prompt'`` key.
        """
        entries = example.get(plan_key, [])
        if not entries:
            logging.error(missing_message)
            return example

        # Insertion-ordered mapping: id -> group.
        groups = {}
        for entry in entries:
            entry_id = entry.get(id_key)
            if not GroupingAgent._has_usable_id(entry_id):
                continue

            # The context is taken from the first entry seen for a given id.
            if entry_id not in groups:
                groups[entry_id] = {
                    id_key: entry_id,
                    context_key: entry.get(context_key, ''),
                    'versions': [],
                }

            # Every entry that carries content contributes one version.
            if content_key in entry:
                groups[entry_id]['versions'].append({
                    'content': entry[content_key],
                    'metadata': {'source': 'initial_generation'},
                })

        step_groups = list(groups.values())
        example['step_groups'] = step_groups

        # One summary request per group, awaited together; gather() returns
        # results in the order the awaitables were passed, so zip() pairs
        # each summary with its own group.
        summaries = await asyncio.gather(*[
            GroupingAgent._generate_group_summary(
                model, group, example['prompt'], semaphore
            )
            for group in step_groups
        ])
        for group, summary in zip(step_groups, summaries):
            group['summary'] = summary

        return example

    @staticmethod
    async def _generate_group_summary(model, group, user_prompt, semaphore):
        """Ask the LLM for a short descriptive summary of one group.

        Args:
            model: Passed through unchanged to ``async_call_llm``.
            group: Group dict; its kind is detected from the presence of
                ``'week_id'``, ``'floor_id'`` or ``'block_id'`` (in that order).
            user_prompt: The user's requirements, embedded in the LLM prompt.
            semaphore: Async context manager held for the duration of the
                LLM call.

        Returns:
            The stripped LLM response, truncated to ``_SUMMARY_MAX_CHARS``
            characters (ending in "...") when longer. Returns
            ``"No summary available"`` when the group has none of the known
            id keys, and ``"Summary for <id>"`` when the LLM call or the
            post-processing of its response raises.
        """
        # Work out what kind of group this is, its id and its context text.
        if 'week_id' in group:
            group_type = "week"
            group_id = group['week_id']
            # Week groups carry either diary 'events' or menu 'dishes'.
            context = group.get('events', '') if 'events' in group else group.get('dishes', '')
        elif 'floor_id' in group:
            group_type = "floor"
            group_id = group['floor_id']
            context = group.get('purpose', '')
        elif 'block_id' in group:
            group_type = "block"
            group_id = group['block_id']
            context = group.get('use', '')
        else:
            return "No summary available"

        # Render every content version as a numbered paragraph.
        versions_content = "".join(
            f"Version {number}: {version.get('content', '')}\n\n"
            for number, version in enumerate(group.get('versions', []), start=1)
        )

        prompt = f"""
You are an expert analyst. Create a concise summary (around 30-50 words) for the following {group_type} content:

{group_type.title()} ID: {group_id}
{group_type.title()} Context: {context}
User Requirements: {user_prompt}

Content Versions:
{versions_content}

Create a concise summary that captures the essence of what this {group_type} represents based on its context and content versions.
Return only the summary text, without any prefixes or additional explanations.
"""

        limit = GroupingAgent._SUMMARY_MAX_CHARS
        try:
            async with semaphore:
                response = await async_call_llm(model, prompt)
            # Keep only the summary text itself.
            summary = response.strip()
            # Truncate over-long summaries, reserving room for the ellipsis.
            if len(summary) > limit:
                summary = summary[:limit - 3] + "..."
            return summary
        except Exception as e:
            # Best effort: a failed summary must not abort the whole grouping.
            logging.error("Error generating summary: %s", e)
            return f"Summary for {group_id}"