"""
Kaggle竞赛知识检索与提取系统
从Kaggle竞赛解决方案中提取可复用的技巧，并提供语义检索功能
"""

import json
import os
import re
from pathlib import Path
from typing import List, Dict, Optional, Any
from datetime import datetime
import traceback

import numpy as np
from schema import ExpertFeedbackItem, FeedbackSource
from knowledgeBase.knowledge import ExpertFeedbackStore
import logger


class KaggleKnowledgeRetriever:
    """
    从Kaggle竞赛中检索和提取技巧知识
    
    核心功能:
    1. 语义检索：基于任务描述找到相似的竞赛
    2. 技巧提取：从竞赛解决方案中提炼可复用的技术经验
    3. 知识存储：将提取的技巧存入知识库
    
    Args:
        task_desc: 目标任务描述
        embedding_model: 句子编码模型名称
        kaggle_corpus_path: Kaggle竞赛语料库路径
        feedback_store: 知识库存储实例
    """

    def __init__(
        self,
        task_desc: str,
        kaggle_corpus_path: str = "knowledgeBase/case_library.jsonl",
        feedback_store: Optional[ExpertFeedbackStore] = None,
    ):
        self.task_desc = task_desc
        self.kaggle_corpus_path = kaggle_corpus_path
        self.feedback_store = feedback_store or ExpertFeedbackStore()
        
        # 延迟加载模型以节省内存
        self._embedding_model = None
        self._corpus = None
        self._embedded_corpus = None
        self.index = None  # FAISS 索引

    @property
    def embedding_model(self):
        """延迟加载 Sentence Transformer 模型"""
        if self._embedding_model is None:
            try:
                from sentence_transformers import SentenceTransformer
                import torch
                device = "cuda" if torch.cuda.is_available() else "cpu"
                # 使用轻量级模型
                self._embedding_model = SentenceTransformer(
                    "BAAI/bge-m3"
                ).to(device)
                logger.info(f"已加载嵌入模型 (device: {device})")
            except ImportError:
                logger.error("请安装 sentence-transformers: pip install sentence-transformers")
                raise
        return self._embedding_model

    @property
    def corpus(self):
        """延迟加载语料库"""
        if self._corpus is None:
            self._corpus = self._load_corpus()
        return self._corpus

    def _load_corpus(self) -> List[Dict]:
        """从 JSONL 文件加载 Kaggle 竞赛语料库"""
        corpus_path = Path(self.kaggle_corpus_path)
        if not corpus_path.exists():
            logger.warning(f"语料库文件不存在: {corpus_path}")
            return []
        
        corpus = []
        with open(corpus_path, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    obj = json.loads(line)
                    corpus.append(obj)
                except Exception as e:
                    logger.warning(f"解析语料库行失败: {e}")
        
        logger.info(f"已加载 {len(corpus)} 个竞赛案例")
        return corpus

    def _build_index(self):
        """构建或加载 FAISS 向量索引"""
        if not self.corpus:
            return

        try:
            import faiss
            import torch
        except ImportError:
            logger.error("请安装 faiss-cpu 或 faiss-gpu: pip install faiss-cpu")
            return

        # 索引文件路径
        index_path = Path(self.kaggle_corpus_path).with_suffix('.faiss')

        # 尝试加载现有索引
        index_loaded = False
        if index_path.exists():
            try:
                logger.info(f"正在加载 FAISS 索引: {index_path}")
                self.index = faiss.read_index(str(index_path))
                index_loaded = True
            except Exception as e:
                logger.warning(f"加载索引失败: {e}，将重新构建")

        if index_loaded:
            # 增量更新逻辑
            if self.index.ntotal == len(self.corpus):
                logger.info(f"成功加载索引，包含 {self.index.ntotal} 个向量")
                return
            elif self.index.ntotal < len(self.corpus):
                logger.info(f"索引大小 ({self.index.ntotal}) 小于语料库 ({len(self.corpus)})，正在进行增量更新...")
                start_idx = self.index.ntotal
                new_items = self.corpus[start_idx:]
                
                # 为新项目计算嵌入
                texts = []
                for comp in new_items:
                    text = comp.get("description") or comp.get("introduction", "")
                    text += " tags: " + " ".join(comp.get("tags", []))
                    texts.append(text[:1000])
                
                if texts:
                    embeddings = self.embedding_model.encode(texts, convert_to_tensor=True, show_progress_bar=True)
                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
                    embeddings_np = embeddings.cpu().numpy()
                    self.index.add(embeddings_np)
                    
                    # 保存更新后的索引
                    try:
                        faiss.write_index(self.index, str(index_path))
                        logger.info(f"FAISS 索引已增量更新并保存，当前包含 {self.index.ntotal} 个向量")
                    except Exception as e:
                        logger.warning(f"保存更新后的索引失败: {e}")
                return
            else:
                # 这种情况通常发生在删除了部分案例后。
                # 由于 FAISS 索引与语料库列表是按位置强耦合的，ID 错位会导致检索结果错误。
                # 因此必须全量重新构建索引。
                logger.warning(f"索引大小 ({self.index.ntotal}) 大于语料库 ({len(self.corpus)})，检测到案例删除或数据不一致")
                logger.warning("为防止ID错位，将执行全量索引重新构建...")
                self.index = None # 指向 None，确保后续逻辑创建一个新的索引对象

        logger.info("正在构建 FAISS 索引...")
        
        # 1. 准备文本
        texts = []
        for comp in self.corpus:
            text = comp.get("description") or comp.get("introduction", "")
            text += " tags: " + " ".join(comp.get("tags", []))
            texts.append(text[:1000]) # 截断过长的文本
            
        # 2. 计算嵌入
        embeddings = self.embedding_model.encode(texts, convert_to_tensor=True, show_progress_bar=True)
        
        # 3. 归一化 (用于余弦相似度)
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        
        # 4. 转换为 numpy
        embeddings_np = embeddings.cpu().numpy()
        
        # 5. 创建索引
        dimension = embeddings_np.shape[1]
        self.index = faiss.IndexFlatIP(dimension)
        self.index.add(embeddings_np)
        
        # 保存索引
        try:
            faiss.write_index(self.index, str(index_path))
            logger.info(f"FAISS 索引已保存至: {index_path}")
        except Exception as e:
            logger.warning(f"保存索引失败: {e}")
            
        logger.info(f"FAISS 索引构建完成，包含 {self.index.ntotal} 个向量")

    def search_similar_competitions(self, top_k: int = 2, threshold: float = 0.65) -> List[Dict]:
        """
        基于任务描述搜索最相似的 Kaggle 竞赛 (使用 FAISS)
        
        Args:
            top_k: 返回的最相似竞赛数量
            threshold: 相似度阈值，低于此值的将被过滤
            
        Returns:
            相似竞赛列表，按相似度排序
        """
        if not self.corpus:
            logger.warning("语料库为空，无法检索")
            return []
            
        if self.index is None:
            self._build_index()
            
        if self.index is None:
            logger.error("索引构建失败，无法检索")
            return []
        
        logger.info(f"正在检索与任务相似竞赛案例...")
        
        # 1. 编码查询
        import torch
        query_embedding = self.embedding_model.encode(self.task_desc, convert_to_tensor=True)
        query_embedding = torch.nn.functional.normalize(query_embedding, p=2, dim=0)
        query_np = query_embedding.cpu().numpy().reshape(1, -1)
        
        # 2. 搜索 (多搜索一些以备过滤)
        search_k = max(top_k * 2, 10)
        D, I = self.index.search(query_np, search_k)
        
        # 3. 格式化结果并过滤
        results = []
        for i, idx in enumerate(I[0]):
            if idx == -1: continue
            similarity = float(D[0][i])
            
            if similarity < threshold:
                break
                
            competition = self.corpus[idx]
            
            results.append({
                "index": int(idx),
                "title": competition.get("title", f"Competition_{idx}"),
                "similarity": similarity,
                "competition": competition
            })
            
            if len(results) >= top_k:
                break
        
        
        return results

    def add_case(self, case_data: Dict, save_index: bool = True):
        """
        手动添加一个案例并更新索引
        注意：此前应确保 case_data 已经写入了 disk 的 jsonl 文件，
             或者在此方法后自行处理写入，否则下次加载会不一致。
        """
        # 确保基础索引已就绪
        if self.index is None:
            self._build_index()
            
        # 1. 更新内存语料库
        if self._corpus is not None:
             self._corpus.append(case_data)
             
        # 2. 计算 Embedding 并添加
        text = case_data.get("description") or case_data.get("introduction", "")
        text += " tags: " + " ".join(case_data.get("tags", []))
        text = text[:1000]
        
        import torch
        import faiss
        
        embedding = self.embedding_model.encode(text, convert_to_tensor=True)
        embedding = torch.nn.functional.normalize(embedding, p=2, dim=0)
        embedding_np = embedding.cpu().numpy().reshape(1, -1)
        
        self.index.add(embedding_np)
        logger.info(f"已手动添加案例至索引: {case_data.get('title')}, 当前索引大小: {self.index.ntotal}")
        
        if save_index:
            index_path = Path(self.kaggle_corpus_path).with_suffix('.faiss')
            try:
                faiss.write_index(self.index, str(index_path))
                logger.debug(f"索引已保存至: {index_path}")
            except Exception as e:
                logger.warning(f"保存更新后的索引失败: {e}")


# 便捷函数
def retrieve_kaggle_knowledge(
    task_description: str,
    top_k: int = 2,
    threshold: float = 0.65,
    corpus_path: str = "knowledgeBase/case_library.jsonl"
) -> List[Dict]:
    """
    从 Kaggle 竞赛中检索相似案例的便捷函数
    
    Args:
        task_description: 当前任务描述
        top_k: 检索前 K 个相似竞赛
        threshold: 相似度阈值
        corpus_path: Kaggle 语料库路径
    
    Returns:
        相似竞赛列表
    """
    logger.debug(f"开始检索 Kaggle 竞赛知识，任务描述: {task_description}")
    retriever = KaggleKnowledgeRetriever(
        task_desc=task_description,
        kaggle_corpus_path=corpus_path
    )
    related_cases = retriever.search_similar_competitions(top_k=top_k, threshold=threshold)
    logger.debug(f"检索任务描述: {task_description},\n检索案例结果：{format_kaggle_competitions(related_cases)}")
    return related_cases


def add_new_case_to_index(corpus_path: str = "knowledgeBase/case_library.jsonl"):
    """
    触发知识库索引的增量更新。
    只需实例化 Retriever 并尝试访问其索引，
    内部的 _build_index 方法会自动检测 JSONL 和 FAISS 的数量差异并进行增量更新。
    
    Args:
        corpus_path: 你的 case_library.jsonl 路径
    """
    # 实例化将自动触发 _build_index (如果访问 index 属性)
    # 但由于 lazy loading，我们需要显式调用一下
    try:
        retriever = KaggleKnowledgeRetriever(task_desc="update_index", kaggle_corpus_path=corpus_path)
        # 触发构建/加载/更新
        if retriever.index is not None:
            logger.info(f"知识库索引检查完毕，当前大小: {retriever.index.ntotal}")
    except Exception as e:
        logger.error(f"更新索引时出错: {e}")


# 对检索到的竞赛列表转化为字符串描述
def format_kaggle_competitions(competitions: List[Dict]) -> str:
    """
    将检索到的 Kaggle 竞赛列表格式化为字符串描述
    
    Args:
        competitions: 检索到的竞赛列表
    
    Returns:
        格式化后的字符串
    """
    if not competitions:
        return "未找到相关竞赛案例。"
    
    lines = []
    for i, comp in enumerate(competitions, 1):
        title = comp.get("competition", {}).get("title", f"Competition_{i}")
        solutions = comp.get("competition", {}).get("solutions", "无解决方案")
        desc = comp.get("competition", {}).get("description", "无描述")[:200] + "..."
        similarity = comp.get("similarity", 0.0)
        
        lines.append(f"{i}. {title}——相似度: {similarity:.4f}; 描述: {desc}; 解决方案: {solutions}.")
    
    return "\n".join(lines)

if __name__ == "__main__":
    # 测试代码：创建示例语料库
    import asyncio
    
    
    # 2. 测试检索
    async def test():
        task = "预测房价，数据包含房屋的各种特征如面积、位置、房间数量等。"
        results = retrieve_kaggle_knowledge(task, top_k=2)
        print(f"成功检索到 {len(results)} 个相似案例")
        for res in results:
            print(f"- {res['title']} (相似度: {res['similarity']:.4f})")
    
    asyncio.run(test())
