#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MCP客户端库 - 连接到远程MCP数据库服务器
替换本地数据库操作
"""

import json
import logging
import requests
from typing import List, Dict, Any, Optional
import numpy as np
import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)

class MCPClientError(Exception):
    """Raised when the MCP server answers a request with a non-200 status.

    Derives from ``Exception`` so existing ``except Exception`` handlers
    around client calls keep catching it.
    """


class MCPDatabaseClient:
    """MCP database client.

    Sends JSON requests to a remote MCP database server over a single
    ``requests.Session``. The session can be released with :meth:`close`
    or by using the client as a context manager.
    """

    def __init__(
        self,
        server_url: str = "http://localhost:8000",
        *,
        timeout: Optional[float] = None,
        health_check_timeout: Optional[float] = 10.0,
    ) -> None:
        """Create the HTTP session and run a best-effort health check.

        Args:
            server_url: Base URL of the server; trailing ``/`` characters
                are stripped.
            timeout: Timeout in seconds applied to every API call made by
                the other methods. ``None`` (the default) waits
                indefinitely, matching the previous behaviour.
            health_check_timeout: Timeout in seconds for the ``/health``
                probe made here, so an unreachable host cannot block
                construction forever. ``None`` disables it.
        """
        self.server_url = server_url.rstrip("/")
        self.timeout = timeout
        self.session = requests.Session()

        # Verify connectivity. Failures are only logged, never raised, so the
        # client object is still constructed when the server is unreachable.
        try:
            response = self.session.get(
                f"{self.server_url}/health", timeout=health_check_timeout
            )
        except requests.RequestException as e:
            logger.error("Failed to connect to MCP server: %s", e)
            return
        if response.status_code == 200:
            logger.info("Connected to MCP server: %s", self.server_url)
        else:
            logger.warning("MCP server health check failed: %s", response.status_code)

    def __enter__(self) -> "MCPDatabaseClient":
        """Return the client itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the session on leaving the ``with`` block; exceptions propagate."""
        self.close()

    def close(self) -> None:
        """Close the underlying HTTP session and release pooled connections."""
        self.session.close()

    def _request(self, method: str, endpoint: str, action: str, **kwargs: Any) -> Dict[str, Any]:
        """Send ``method`` to ``<server_url>/<endpoint>`` and return the decoded JSON body.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            endpoint: Path below the server URL, without a leading ``/``.
            action: Human-readable action used in the error message.
            **kwargs: Forwarded to ``requests.Session.request`` (e.g. ``json=``).

        Raises:
            MCPClientError: if the status code is not 200; the message is
                ``"Failed to <action>: <response body>"``.
        """
        response = self.session.request(
            method, f"{self.server_url}/{endpoint}", timeout=self.timeout, **kwargs
        )
        if response.status_code != 200:
            raise MCPClientError(f"Failed to {action}: {response.text}")
        return response.json()

    def build_index(self, file_path: str, model_path: str, index_name: str, batch_size: int = 500) -> Dict[str, Any]:
        """Build a FAISS index on the server.

        Returns:
            The server's full JSON response body.

        Raises:
            MCPClientError: on a non-200 response.
        """
        data = {
            "file_path": file_path,
            "model_path": model_path,
            "index_name": index_name,
            "batch_size": batch_size,
        }
        result = self._request("POST", "build_index", "build index", json=data)
        logger.info("Index built successfully: %s", index_name)
        return result

    def load_documents(self, file_path: str, key_name: str = "context") -> List[str]:
        """Load documents via the server and return its ``documents`` field.

        Raises:
            MCPClientError: on a non-200 response.
        """
        data = {
            "file_path": file_path,
            "key_name": key_name,
        }
        return self._request("POST", "load_documents", "load documents", json=data)["documents"]

    def search(self, query_embedding: List[float], k: int, index_name: str) -> Dict[str, List]:
        """Search ``index_name`` with ``query_embedding``; ``k`` is forwarded as-is.

        Returns:
            The server's ``results`` field.

        Raises:
            MCPClientError: on a non-200 response.
        """
        data = {
            "query_embedding": query_embedding,
            "k": k,
            "index_name": index_name,
        }
        return self._request("POST", "search", "search", json=data)["results"]

    def generate_embeddings(self, texts: List[str], model_path: str = "/hub/huggingface/models/bert/bert-base-uncased") -> np.ndarray:
        """Generate text embeddings on the server.

        Returns:
            The server's ``embeddings`` field converted with ``np.array``.

        Raises:
            MCPClientError: on a non-200 response.
        """
        data = {
            "texts": texts,
            "model_path": model_path,
        }
        result = self._request("POST", "generate_embeddings", "generate embeddings", json=data)
        return np.array(result["embeddings"])

    def list_indexes(self) -> List[str]:
        """Return the server's ``indexes`` field (the available index names).

        Raises:
            MCPClientError: on a non-200 response.
        """
        return self._request("GET", "indexes", "list indexes")["indexes"]

# Compatibility functions - drop-in replacements for the original function calls.
def load_json_file(file_path: str, key_name: str, client: Optional["MCPDatabaseClient"] = None) -> List[str]:
    """Compatibility function: load the ``key_name`` values from a JSON file.

    Without a client, the file is read locally (backward-compatible path).
    With a client, the call is delegated to ``client.load_documents``, which
    sends ``file_path`` and ``key_name`` to the server.

    Args:
        file_path: Path to a UTF-8 encoded JSON file. The local path iterates
            the top-level JSON value — presumably a list of objects; TODO
            confirm against the data producers.
        key_name: Key to extract from each object.
        client: Optional MCP client to delegate to.

    Returns:
        Locally: the ``key_name`` value of every top-level object that has
        that key, in file order (non-object entries are skipped). Remotely:
        whatever ``client.load_documents`` returns.
    """
    if client is not None:
        # Use the MCP client.
        return client.load_documents(file_path, key_name)

    # No client: load locally (backward compatibility).
    with open(file_path, "r", encoding="utf-8") as file:
        data = json.load(file)
    # Only dicts support key lookup. For a string entry, ``key_name in item``
    # is a substring test and ``item[key_name]`` would raise TypeError.
    return [item[key_name] for item in data if isinstance(item, dict) and key_name in item]

def mean_pooling(model_output, attention_mask):
    """Compatibility function: masked mean pooling over dimension 1.

    Averages the token embeddings in ``model_output[0]`` along dim 1, counting
    only positions where ``attention_mask`` is non-zero.

    Args:
        model_output: Indexable whose element 0 is the token-embedding tensor
            (presumably a Hugging Face model output of shape
            (batch, seq_len, hidden) — TODO confirm against callers).
        attention_mask: Mask with the leading dims of ``model_output[0]``. After
            ``unsqueeze(-1)``, it must expand to ``model_output[0].size()``.

    Returns:
        Sum of masked embeddings over dim 1 divided by the (clamped) mask count.
        Rows whose mask is all zero yield zeros rather than NaN.
    """
    token_embeddings = model_output[0]
    # Cast the (often integer or bool) mask to float before summing. The clamp
    # below then operates on a float count that can actually hold 1e-9, so an
    # all-padding row gives 0 / 1e-9 = 0 instead of risking 0 / 0.
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

def build_index_remote(file_path: str, model_path: str, index_name: str,
                       client: MCPDatabaseClient, batch_size: int = 500) -> Dict[str, Any]:
    """Build an index on the remote server through ``client``.

    A thin delegating wrapper: all arguments are passed positionally to
    ``client.build_index`` and its result is returned unchanged.
    """
    remote_build = client.build_index
    outcome = remote_build(file_path, model_path, index_name, batch_size)
    return outcome

def search_remote(query_text: str, model_path: str, index_name: str,
                  client: MCPDatabaseClient, k: int = 3) -> Dict[str, List]:
    """Embed ``query_text`` remotely, then search ``index_name`` with it.

    Makes two client calls in order. First, ``generate_embeddings`` runs on a
    one-element batch. Second, ``search`` runs with the first row of that
    result converted to a plain Python list; ``k`` is forwarded unchanged.

    Returns:
        Whatever ``client.search`` returns.
    """
    batch = [query_text]
    first_row = client.generate_embeddings(batch, model_path)[0]
    vector = first_row.tolist()
    return client.search(vector, k, index_name)

# 全局客户端实例（可配置）
_global_client = None

def configure_mcp_client(server_url: str):
    """Create a client for ``server_url`` and store it as the module-wide client.

    Any previously stored client is replaced (it is not closed here). The new
    client is also returned for direct use.
    """
    global _global_client
    fresh = MCPDatabaseClient(server_url)
    _global_client = fresh
    return fresh

def get_mcp_client() -> Optional[MCPDatabaseClient]:
    """Return the module-wide MCP client, or ``None`` if none is configured.

    No client is created here; this only reads what
    ``configure_mcp_client`` last stored.
    """
    current = _global_client
    return current
