import json
import os
import re
import pickle
import numpy as np
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Union

# --- 数据结构定义 ---

@dataclass
class SubExperience:
    """One fine-grained supervisor experience item that can be retrieved on its own."""
    # record_index of the SupervisorWorkflowInstance this item was parsed from.
    parent_record_index: int
    # Key under 'supervisor_experience' it came from, e.g. 'strategic_heuristic', 'failure_pattern'.
    experience_type: str
    text: str
    # Embedding of `text`. Hidden from repr for readability and excluded from
    # __eq__: comparing numpy arrays inside the generated __eq__ raises
    # "truth value of an array is ambiguous".
    embedding: np.ndarray = field(repr=False, compare=False)

@dataclass
class SupervisorWorkflowInstance:
    """A full knowledge-base record together with its multi-dimensional supervisor experience."""
    # Position of the record in the source JSON list.
    record_index: int
    question: str
    true_answer: str
    # experience type -> items; values may be a list of strings or a single
    # free-form string that the manager parses into items.
    supervisor_experience: Dict[str, Union[List[str], str]]
    # Question embedding used for stage-1 (global task) retrieval. Hidden from
    # repr and excluded from __eq__, because comparing numpy arrays inside the
    # generated __eq__ raises "truth value of an array is ambiguous".
    question_embedding: Optional[np.ndarray] = field(default=None, repr=False, compare=False)

# --- 核心检索管理器 ---

class SupervisorKBManager:
    """Retrieval index over the supervisor experience knowledge base.

    Every JSON record becomes a ``SupervisorWorkflowInstance`` keyed by its
    list position and embedded by its question. Every parsed experience item
    becomes a ``SubExperience`` embedded by its own text. The built index is
    pickled to ``<json_file_path>.cache.pkl`` so later runs can skip
    re-encoding.
    """

    # Embedding model used when none is given. Caches written before the model
    # name was recorded in them are assumed to have been built with it.
    DEFAULT_MODEL_NAME = 'BAAI/bge-base-en-v1.5'

    # Numbered-list markers: "1.", "12." or "(3)". A marker must start the
    # string or follow whitespace. "N." must also be followed by whitespace or
    # the end of the string, so decimals ("9.1%") and years ("2011.") are not
    # mistaken for separators. A one/two-digit number that ends a sentence
    # ("... step 5. Then") is still split; this is an accepted trade-off.
    _NUMBERED_MARKER_RE = re.compile(r'(?:^|(?<=\s))(?:\(\d{1,2}\)|\d{1,2}\.(?=\s|$))')
    # Any line that starts (after optional indentation) with a "-" bullet.
    _DASH_BULLET_RE = re.compile(r'^\s*-', re.MULTILINE)

    def __init__(self, json_file_path: str, model_name: str = DEFAULT_MODEL_NAME):
        """Load the index from cache, or build it from JSON and cache it.

        Args:
            json_file_path: Path to the supervisor knowledge-base JSON file.
            model_name: SentenceTransformer model used for every embedding.
                A cache built with a different model is ignored and rebuilt.
        """
        print("Initializing SupervisorKBManager...")
        self.model_name = model_name
        self.workflows: Dict[int, SupervisorWorkflowInstance] = {}
        self.sub_experiences: List[SubExperience] = []
        cache_path = json_file_path + ".cache.pkl"

        # Both paths need the model: to build the index, and at search time
        # to encode the query texts.
        print(f"Loading embedding model ({model_name})...")
        self.embedding_model = SentenceTransformer(model_name)

        if self._cache_is_fresh(cache_path, json_file_path):
            print(f"Cache found at '{cache_path}'. Loading pre-computed embeddings...")
            if self._load_from_cache(cache_path):
                print("Successfully loaded from cache.")
                return
            print("Cache is unusable; rebuilding.")

        print("No usable cache found. Building embeddings from scratch... (This may take a few minutes)")
        self._load_data_and_build_indices(json_file_path)
        self._save_to_cache(cache_path)

    @staticmethod
    def _cache_is_fresh(cache_path: str, json_file_path: str) -> bool:
        """Return True if the cache exists and is not older than the JSON file.

        If the JSON file itself is missing, an existing cache is still used.
        """
        if not os.path.exists(cache_path):
            return False
        if not os.path.exists(json_file_path):
            return True
        return os.path.getmtime(cache_path) >= os.path.getmtime(json_file_path)

    def _load_from_cache(self, cache_path: str) -> bool:
        """Populate ``workflows`` / ``sub_experiences`` from the pickle cache.

        Returns True on success. Returns False and leaves the manager
        unchanged in any of these cases:
          * the file cannot be read or unpickled (corrupt file, or classes
            pickled under another module path such as ``__main__``);
          * the expected keys are missing;
          * the cache was built with a model other than ``self.model_name``.

        NOTE: ``pickle.load`` can execute arbitrary code. This cache must only
        ever be a file this class wrote itself, never externally supplied data.
        """
        try:
            with open(cache_path, 'rb') as f:
                cached_data = pickle.load(f)
            workflows = cached_data['workflows']
            sub_experiences = cached_data['sub_experiences']
            cached_model = cached_data.get('model_name', self.DEFAULT_MODEL_NAME)
        except (OSError, EOFError, pickle.UnpicklingError, AttributeError,
                ImportError, IndexError, KeyError, TypeError, ValueError) as err:
            print(f"Warning: failed to read cache '{cache_path}': {err!r}")
            return False

        if cached_model != self.model_name:
            print(f"Cache was built with '{cached_model}', not '{self.model_name}'.")
            return False

        self.workflows = workflows
        self.sub_experiences = sub_experiences
        return True

    def _save_to_cache(self, cache_path: str) -> None:
        """Pickle the built index to ``cache_path`` (best effort, atomic).

        The data is written to a temporary file and then renamed with
        ``os.replace``, so an interrupted run never leaves a truncated cache.
        A write failure only prints a warning, because the cache is an
        optimization.
        """
        data_to_cache = {
            'model_name': self.model_name,
            'workflows': self.workflows,
            'sub_experiences': self.sub_experiences,
        }
        tmp_path = cache_path + '.tmp'
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(data_to_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
            os.replace(tmp_path, cache_path)
        except OSError as err:
            print(f"Warning: could not write cache '{cache_path}': {err!r}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # Nothing left to clean up.
            return
        print(f"Embeddings built and saved to cache at '{cache_path}'.")

    def _parse_experience_value(self, value: Union[List[str], str]) -> List[str]:
        """Normalize one ``supervisor_experience`` value into a list of clean strings.

        Accepted shapes:
          * a list: each item is stringified and stripped; empty items are dropped;
          * a numbered string ("1. a 2. b", "(1) a (2) b"): split on the markers;
          * a bulleted string ("•" bullets or dash-bulleted lines): split on the bullets;
          * any other string: split on newlines.

        Text before the first marker (e.g. "Common pitfalls include:") is kept
        as its own item. Any other type, including None, yields an empty list.
        """
        if isinstance(value, list):
            return [str(item).strip() for item in value if str(item).strip()]

        if not isinstance(value, str):
            return []

        if self._NUMBERED_MARKER_RE.search(value):
            pieces = self._NUMBERED_MARKER_RE.split(value)
            return [piece.strip() for piece in pieces if piece.strip()]

        if '•' in value or self._DASH_BULLET_RE.search(value):
            # Put every "•" on its own line as a dash bullet, then strip the
            # leading dash(es) from each line. Drop lines that end up empty,
            # such as a bare "-".
            lines = value.replace('•', '\n-').split('\n')
            cleaned = (line.strip().lstrip('-').strip() for line in lines)
            return [text for text in cleaned if text]

        return [line.strip() for line in value.split('\n') if line.strip()]

    def _encode_batch(self, texts: List[str]):
        """Embed ``texts`` with a progress bar; skip the model call for an empty list."""
        if not texts:
            return []
        return self.embedding_model.encode(texts, show_progress_bar=True, convert_to_numpy=True)

    def _load_data_and_build_indices(self, json_file_path: str) -> None:
        """Read the KB JSON, parse each record's experiences, and embed everything.

        Expects a JSON list of records with optional 'question', 'true_answer'
        and 'supervisor_experience' keys. Missing or null values fall back to
        '' / '' / {}. A record's position in the list becomes its
        ``record_index``.
        """
        print(f"Loading and processing knowledge base from: {json_file_path}")
        with open(json_file_path, 'r', encoding='utf-8') as f:
            raw_data = json.load(f)

        workflows: Dict[int, SupervisorWorkflowInstance] = {}
        pending = []  # (parent_record_index, experience_type, text), in embedding order

        for i, record in enumerate(tqdm(raw_data, desc="Preparing data")):
            question = record.get('question')
            true_answer = record.get('true_answer')
            instance = SupervisorWorkflowInstance(
                record_index=i,
                question='' if question is None else question,
                true_answer='' if true_answer is None else true_answer,
                supervisor_experience=record.get('supervisor_experience') or {},
            )
            workflows[i] = instance

            for exp_type, raw_value in instance.supervisor_experience.items():
                for exp_text in self._parse_experience_value(raw_value):
                    pending.append((i, exp_type, exp_text))

        print("Generating embeddings for all questions (Stage 1)...")
        question_embeddings = self._encode_batch([wf.question for wf in workflows.values()])

        print("Generating embeddings for all sub-experiences (Stage 3)...")
        sub_exp_embeddings = self._encode_batch([text for _, _, text in pending])

        for instance, embedding in zip(workflows.values(), question_embeddings):
            instance.question_embedding = embedding

        self.workflows = workflows
        self.sub_experiences = [
            SubExperience(
                parent_record_index=parent_index,
                experience_type=exp_type,
                text=text,
                embedding=embedding,
            )
            for (parent_index, exp_type, text), embedding in zip(pending, sub_exp_embeddings)
        ]

        print(f"Indexing complete. Loaded {len(self.workflows)} workflows and {len(self.sub_experiences)} sub-experiences.")

    @staticmethod
    def _top_k_indices(scores: np.ndarray, k: int) -> np.ndarray:
        """Return the indices of the ``k`` highest scores, best first.

        ``k <= 0`` yields an empty result. (The old ``argsort()[-k:]`` form
        returned *everything* for ``k == 0``.) ``k`` larger than
        ``len(scores)`` returns all indices. Ties keep ascending index order.
        """
        if k <= 0:
            return np.empty(0, dtype=int)
        return np.argsort(-np.asarray(scores), kind='stable')[:k]

    def search_experiences(
        self,
        global_task: str,
        local_task: str,
        dynamic_context: str,  # step task
        target_experience_type: str,
        top_k_workflows: int = 3,
        top_n_points: int = 3,
        w_local_task: float = 0.3,
        w_dynamic_context: float = 0.7,
    ) -> List[Dict]:
        """Run four-stage hybrid retrieval of the experiences most relevant to a step.

        1. Rank workflows by cosine similarity between their question and
           ``global_task``; keep the top ``top_k_workflows``.
        2. Keep only those workflows' sub-experiences whose type equals
           ``target_experience_type``.
        3. Score each candidate against ``local_task`` (goal relevance) and
           ``dynamic_context`` (situation relevance).
        4. Rank by ``w_local_task * goal + w_dynamic_context * situation`` and
           keep the top ``top_n_points``.

        Returns:
            A list of dicts with keys 'experience', 'score', 'type' and
            'source_question', best first. The list is empty when
            ``target_experience_type`` is falsy, the index is empty, a top-k
            limit is <= 0, or no candidate has the requested type.
        """
        print("\n--- Starting Hybrid 4-Stage Experience Retrieval ---")

        # Exit early so no query is embedded for a request that cannot succeed.
        if not target_experience_type or not self.workflows:
            return []
        if top_k_workflows <= 0 or top_n_points <= 0:
            return []

        # --- Stage 1: top-K workflows by similarity to the global task ---
        global_task_embedding = self.embedding_model.encode(global_task, convert_to_numpy=True)
        all_workflow_instances = list(self.workflows.values())
        question_embeddings = np.array([wf.question_embedding for wf in all_workflow_instances])
        similarities_stage1 = cosine_similarity([global_task_embedding], question_embeddings)[0]
        retrieved_record_indices = {
            all_workflow_instances[i].record_index
            for i in self._top_k_indices(similarities_stage1, top_k_workflows)
        }

        # --- Stage 2: filter candidates by experience type ---
        candidate_sub_experiences = [
            exp for exp in self.sub_experiences
            if exp.parent_record_index in retrieved_record_indices
            and exp.experience_type == target_experience_type
        ]
        if not candidate_sub_experiences:
            return []

        # --- Stage 3: similarity to the goal (local task) and the situation (dynamic context) ---
        local_task_embedding = self.embedding_model.encode(local_task, convert_to_numpy=True)
        dynamic_context_embedding = self.embedding_model.encode(dynamic_context, convert_to_numpy=True)
        candidate_embeddings = np.array([exp.embedding for exp in candidate_sub_experiences])
        similarities_goal = cosine_similarity([local_task_embedding], candidate_embeddings)[0]
        similarities_situation = cosine_similarity([dynamic_context_embedding], candidate_embeddings)[0]

        # --- Stage 4: weighted hybrid ranking ---
        hybrid_scores = w_local_task * similarities_goal + w_dynamic_context * similarities_situation

        final_results = []
        for i in self._top_k_indices(hybrid_scores, top_n_points):
            exp = candidate_sub_experiences[i]
            final_results.append({
                'experience': exp.text,
                'score': float(hybrid_scores[i]),  # the hybrid score is the final score
                'type': exp.experience_type,
                'source_question': self.workflows[exp.parent_record_index].question,
            })

        print(f"--- [KB-Retrieval] Retrieved top {len(final_results)} sub-experiences based on hybrid ranking.")
        print("--- Retrieval Complete ---")
        return final_results


# --- 使用示例 ---
if __name__ == '__main__':
    # Knowledge-base JSON path. Set SUPERVISOR_KB_PATH to point elsewhere.
    KB_FILE_PATH = os.environ.get(
        "SUPERVISOR_KB_PATH",
        "/home/ofo/project_workflow_auto_generation/smolagents/examples/open_deep_research/data/supervisor_database.json",
    )

    # Create a small demo knowledge base if the file does not exist yet.
    if not os.path.exists(KB_FILE_PATH):
        print(f"'{KB_FILE_PATH}' not found. Creating a dummy file for demonstration.")
        dummy_data = [
            {
                "question": "In a country, 14.699 million people participated in a survey... What was the unemployment rate for people with disabilities?",
                "true_answer": "9.1%",
                "supervisor_experience": {
                  "strategic_heuristic": ["Lock the scope before searching: identify the exact country and year from the question."],
                  "failure_pattern": ["Inverse-rate fallacy: Confusing employment rate with unemployment rate, and calculating one from the other incorrectly without knowing the labor force participation rate."],
                  "corrective_action": ["Impose a pre-search checklist: require the agent to list all numerical values and their definitions from the prompt before starting calculations."],
                  "verification_checkpoint": ["Trigger secondary verification for any final percentage calculation that involves mixing data from different years (2013 vs 2011 in this case)."]
                }
            },
            {
                "question": "Who is the Silicon Valley figure associated with AI who was removed from a board position and is now starting a new venture?",
                "true_answer": "Sam Altman",
                "supervisor_experience": {
                  "strategic_heuristic": ["When dealing with recent events, prioritize news sources from the last few months to get the most current information."],
                  "failure_pattern": ["Using outdated search results that describe a person's old role without capturing recent, major changes like a departure or new company launch."],
                  "corrective_action": ["If an agent's search results seem outdated, a good intervention is to inject a time-filter into its next search query, e.g., 'search for X after date Y'."],
                  "verification_checkpoint": ["Claims about a person starting a 'new venture' are often based on rumors. This should trigger verification to find an official announcement or a report from a top-tier news outlet."]
                }
            }
        ]
        # The parent directory may not exist yet on a fresh checkout.
        os.makedirs(os.path.dirname(KB_FILE_PATH) or ".", exist_ok=True)
        with open(KB_FILE_PATH, 'w', encoding='utf-8') as f:
            json.dump(dummy_data, f, indent=2, ensure_ascii=False)

    # 1. Initialize the manager. It loads the cache, or builds the index and caches it.
    kb_manager = SupervisorKBManager(json_file_path=KB_FILE_PATH)

    # 2. Simulate one supervision step.
    current_global_task = "What was the unemployment rate for disabled people in a specific country in 2011, given some unrelated 2013 survey data?"
    current_local_task = "The agent is trying to calculate the unemployment rate (9.1%) from the employment rate (24%)."
    current_dynamic_context = "The agent's last step derived the unemployment rate as 100% minus the 24% employment rate."
    # Must match a key under 'supervisor_experience' in the knowledge base.
    current_supervision_type = "failure_pattern"

    # 3. Run the four-stage hybrid retrieval.
    retrieved_experiences = kb_manager.search_experiences(
        global_task=current_global_task,
        local_task=current_local_task,
        dynamic_context=current_dynamic_context,
        target_experience_type=current_supervision_type,
    )

    # 4. Print the results.
    print("\n\n--- Top Retrieved Experiences for Supervisor ---")
    if retrieved_experiences:
        for exp in retrieved_experiences:
            print(f"Score: {exp['score']:.4f}")
            print(f"Type: {exp['type']}")
            print(f"Experience: {exp['experience']}")
            print(f"Source Question: {exp['source_question'][:80]}...")
            print("-" * 20)
    else:
        print("No relevant experiences found.")