#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MCP服务器 - 提供数据库和嵌入向量服务
处理FAISS索引、文档检索和嵌入生成
"""

import asyncio
import json
import logging
import os
import sys
from typing import Any, Dict, List, Optional, Sequence

import faiss
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer
import uvicorn

# Make the project root (two directories above this file) importable.
_module_dir = os.path.dirname(__file__)
sys.path.append(os.path.abspath(os.path.join(_module_dir, "../..")))

# Module-level logger; the root logger is configured at INFO level.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Request data models
class EmbeddingRequest(BaseModel):
    """Request body for the ``/generate_embeddings`` endpoint."""

    # Texts to embed; they are tokenized together as one batch.
    texts: List[str]
    # Path passed to the tokenizer/model ``from_pretrained`` loaders.
    model_path: str = "/hub/huggingface/models/bert/bert-base-uncased"

class SearchRequest(BaseModel):
    """Request body for the ``/search`` endpoint."""

    # Query vector; it is searched against the index as a single row.
    query_embedding: List[float]
    # Number of nearest neighbours to return (defaults to 3).
    k: int = 3
    # Name of the index; resolved to ``<storage_dir>/<index_name>.faiss``
    # when the index is not already cached in memory.
    index_name: str

class DocumentRequest(BaseModel):
    """Request body for the ``/load_documents`` endpoint."""

    # Path of the JSON file to read (opened as UTF-8).
    file_path: str
    # Key whose value is extracted from each item of the JSON data.
    key_name: str = "context"

class IndexBuildRequest(BaseModel):
    """Request body for the ``/build_index`` endpoint."""

    # Path of the JSON file holding the documents to index.
    file_path: str
    # Path passed to the tokenizer/model ``from_pretrained`` loaders.
    model_path: str = "/hub/huggingface/models/bert/bert-base-uncased"
    # Name under which the index is cached and saved as ``<index_name>.faiss``.
    index_name: str
    # Number of documents embedded per batch while building the index.
    batch_size: int = 500

class MCPDatabaseServer:
    """FastAPI service for FAISS index build/search, document loading and embeddings.

    Instance state:
        storage_dir: directory where ``<index_name>.faiss`` files are stored.
        indexes: index name -> FAISS index held in memory.
        documents: JSON file path -> list of documents extracted from it.
        models: model path -> ``(tokenizer, model)`` tuple.

    NOTE(review): ``file_path`` and ``model_path`` come straight from request
    bodies and are opened as given; restrict them if the server is reachable
    by untrusted clients.
    NOTE(review): the private coroutines do blocking file/model/FAISS work
    inline, so a long request blocks the event loop while it runs.
    """

    def __init__(self, storage_dir: str = "/tmp/mcp_storage"):
        """Create the storage directory, the caches and the FastAPI app.

        Args:
            storage_dir: directory used to persist FAISS index files.
        """
        self.storage_dir = storage_dir
        self.indexes = {}  # index name -> loaded FAISS index
        self.documents = {}  # file path -> extracted documents
        self.models = {}  # model path -> (tokenizer, model)

        # Make sure the storage directory exists.
        os.makedirs(storage_dir, exist_ok=True)

        # Build the FastAPI application and register its routes.
        self.app = FastAPI(title="MCP Database Server", version="1.0.0")
        self._setup_routes()

    def _setup_routes(self):
        """Register the HTTP endpoints on ``self.app``.

        Every POST handler converts any exception into an HTTP 500 whose
        ``detail`` is the exception message.
        """

        @self.app.post("/build_index")
        async def build_index(request: IndexBuildRequest):
            """Build a FAISS index."""
            try:
                result = await self._build_index(
                    request.file_path,
                    request.model_path,
                    request.index_name,
                    request.batch_size
                )
                return {"status": "success", "message": f"Index {request.index_name} built successfully", "dimension": result}
            except Exception as e:
                # logger.exception keeps the same message and adds the traceback.
                logger.exception(f"Error building index: {e}")
                raise HTTPException(status_code=500, detail=str(e))

        @self.app.post("/load_documents")
        async def load_documents(request: DocumentRequest):
            """Load documents."""
            try:
                documents = await self._load_documents(request.file_path, request.key_name)
                return {"status": "success", "count": len(documents), "documents": documents}
            except Exception as e:
                logger.exception(f"Error loading documents: {e}")
                raise HTTPException(status_code=500, detail=str(e))

        @self.app.post("/search")
        async def search(request: SearchRequest):
            """Search an index."""
            try:
                results = await self._search_index(
                    request.query_embedding,
                    request.k,
                    request.index_name
                )
                return {"status": "success", "results": results}
            except Exception as e:
                logger.exception(f"Error searching: {e}")
                raise HTTPException(status_code=500, detail=str(e))

        @self.app.post("/generate_embeddings")
        async def generate_embeddings(request: EmbeddingRequest):
            """Generate text embeddings."""
            try:
                embeddings = await self._generate_embeddings(request.texts, request.model_path)
                return {"status": "success", "embeddings": embeddings.tolist()}
            except Exception as e:
                logger.exception(f"Error generating embeddings: {e}")
                raise HTTPException(status_code=500, detail=str(e))

        @self.app.get("/health")
        async def health_check():
            """Health check."""
            return {"status": "healthy", "server": "MCP Database Server"}

        @self.app.get("/indexes")
        async def list_indexes():
            """List the indexes currently held in memory."""
            return {"status": "success", "indexes": list(self.indexes.keys())}

    def _index_path(self, index_name: str) -> str:
        """Return the on-disk path for ``index_name`` inside ``storage_dir``.

        Raises:
            ValueError: if the name is empty or contains a path separator or
                NUL byte, which could otherwise escape the storage directory.
        """
        if not index_name or any(ch in index_name for ch in ("/", "\\", "\0")):
            raise ValueError(f"Invalid index name: {index_name!r}")
        return os.path.join(self.storage_dir, f"{index_name}.faiss")

    async def _load_model(self, model_path: str):
        """Return the cached ``(tokenizer, model)`` for ``model_path``, loading it on first use.

        The model is moved to the GPU when CUDA is available.
        """
        cached = self.models.get(model_path)
        if cached is None:
            logger.info(f"Loading model: {model_path}")
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModel.from_pretrained(model_path)
            if torch.cuda.is_available():
                model = model.cuda()
            # The model is only used for inference here.
            model.eval()
            cached = (tokenizer, model)
            self.models[model_path] = cached
        return cached

    async def _generate_embeddings(self, texts: List[str], model_path: str):
        """Embed ``texts`` with mean pooling and L2 normalisation.

        Args:
            texts: non-empty list of strings, tokenized as one padded batch.
            model_path: path of the model to load (or reuse from the cache).

        Returns:
            A NumPy array with one L2-normalised row per input text.

        Raises:
            ValueError: if ``texts`` is empty.
        """
        if not texts:
            raise ValueError("texts must be a non-empty list")

        tokenizer, model = await self._load_model(model_path)

        encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        # Put the inputs on whichever device the model actually lives on,
        # rather than re-deriving it from CUDA availability.
        device = next(model.parameters()).device
        encoded_input = {name: tensor.to(device) for name, tensor in encoded_input.items()}

        with torch.no_grad():
            model_output = model(**encoded_input)

        # Mean pooling over tokens, then unit-length rows.
        embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"])
        embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings.cpu().numpy()

    def _mean_pooling(self, model_output, attention_mask):
        """Average token embeddings, ignoring positions where the mask is 0.

        Args:
            model_output: model output whose first element holds the
                per-token embeddings.
            attention_mask: mask with one entry per token, broadcast over the
                embedding dimension.

        Returns:
            One pooled vector per sequence; an all-zero mask gives a zero
            vector because the divisor is clamped to ``1e-9``.
        """
        token_embeddings = model_output[0]
        # Cast to float so the sum and clamp are done in floating point even
        # when the mask is an integer tensor.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    async def _load_documents(self, file_path: str, key_name: str):
        """Read a JSON file and collect ``item[key_name]`` from its items.

        The file must contain a JSON array. Items that are not objects, or
        that lack ``key_name``, are skipped. The result is cached in
        ``self.documents`` under ``file_path``.

        Raises:
            ValueError: if the top-level JSON value is not an array.
        """
        with open(file_path, "r", encoding="utf-8") as file:
            data = json.load(file)
        if not isinstance(data, list):
            raise ValueError(f"Expected a JSON array in {file_path}, got {type(data).__name__}")

        # The isinstance guard matters: for a string item, ``key_name in item``
        # would be a substring test and the lookup would raise TypeError.
        documents = [item[key_name] for item in data if isinstance(item, dict) and key_name in item]

        # Cache the documents.
        self.documents[file_path] = documents
        return documents

    async def _build_index(self, file_path: str, model_path: str, index_name: str, batch_size: int = 500,
                           key_name: str = "context"):
        """Embed the documents of a JSON file and store them in a flat L2 FAISS index.

        Args:
            file_path: JSON file containing the documents.
            model_path: model used to embed the documents.
            index_name: name used for the in-memory cache and the
                ``<index_name>.faiss`` file in ``storage_dir``.
            batch_size: number of documents embedded per batch.
            key_name: key extracted from each JSON item.

        Returns:
            The embedding dimension of the index.

        Raises:
            ValueError: on an invalid index name, a non-positive batch size,
                or when the file yields no documents.
        """
        # Validate cheap inputs before any expensive work.
        index_path = self._index_path(index_name)
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        documents = await self._load_documents(file_path, key_name)
        if not documents:
            # Without this check ``index`` would stay None and write_index would fail.
            raise ValueError(f"No documents with key {key_name!r} found in {file_path}")

        index = None
        dimension = None

        for start in range(0, len(documents), batch_size):
            batch_documents = documents[start:start + batch_size]
            embeddings = await self._generate_embeddings(batch_documents, model_path)

            if index is None:
                # The dimension is only known once the first batch is embedded.
                dimension = int(embeddings.shape[1])
                index = faiss.IndexFlatL2(dimension)

            index.add(embeddings)

        # Persist the index, then cache it.
        faiss.write_index(index, index_path)
        self.indexes[index_name] = index

        logger.info(f"FAISS index stored: {index_path}")
        return dimension

    async def _search_index(self, query_embedding: List[float], k: int, index_name: str):
        """Return the ``k`` nearest neighbours of ``query_embedding`` in an index.

        The index is taken from the in-memory cache, or loaded from
        ``storage_dir`` and cached when only the file exists.

        Returns:
            ``{"distances": [...], "indices": [...]}`` for the single query.
            FAISS pads with index ``-1`` when the index holds fewer than ``k``
            vectors.

        Raises:
            ValueError: if ``k`` is not positive, the index name is invalid,
                the index does not exist, or the query dimension does not
                match the index.
        """
        if k <= 0:
            raise ValueError(f"k must be a positive integer, got {k}")

        index = self.indexes.get(index_name)
        if index is None:
            # Try loading from disk.
            index_path = self._index_path(index_name)
            if not os.path.exists(index_path):
                raise ValueError(f"Index {index_name} not found")
            index = faiss.read_index(index_path)
            self.indexes[index_name] = index

        # FAISS expects a float32 matrix with one row per query.
        query_emb = torch.tensor([query_embedding], dtype=torch.float32).numpy()
        if query_emb.shape[1] != index.d:
            raise ValueError(
                f"Query dimension {query_emb.shape[1]} does not match index dimension {index.d}"
            )
        distances, indices = index.search(query_emb, k)

        return {
            "distances": distances[0].tolist(),
            "indices": indices[0].tolist()
        }

    def run(self, host: str = "0.0.0.0", port: int = 8000):
        """Serve ``self.app`` with uvicorn on ``host``:``port``."""
        logger.info(f"Starting MCP Database Server on {host}:{port}")
        uvicorn.run(self.app, host=host, port=port)

if __name__ == "__main__":
    # Command-line entry point: parse the bind address and storage directory,
    # then start the server.
    from argparse import ArgumentParser

    cli = ArgumentParser(description="MCP Database Server")
    option_table = (
        ("--host", {"default": "0.0.0.0", "help": "Host to bind to"}),
        ("--port", {"type": int, "default": 8000, "help": "Port to bind to"}),
        ("--storage-dir", {"default": "/tmp/mcp_storage", "help": "Storage directory for indexes"}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    options = cli.parse_args()

    MCPDatabaseServer(storage_dir=options.storage_dir).run(host=options.host, port=options.port)
