#!/usr/bin/env python3
"""
Hugging Face MCP Server
Hugging Face Hub API integration using standard MCP protocol
"""

import os
import asyncio
import json
from typing import Any, Dict, List, Optional, Sequence
from huggingface_hub import HfApi, list_models, list_datasets, model_info, dataset_info
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import Tool, TextContent

# MCP server instance, registered under the name "huggingface-mcp".
server = Server("huggingface-mcp")

# Hugging Face Hub API client. The token is read from the HUGGINGFACE_TOKEN
# environment variable; os.getenv() yields None when the variable is unset.
hf_api = HfApi(token=os.getenv("HUGGINGFACE_TOKEN"))

@server.list_tools()
async def list_tools() -> List[Tool]:
    """Return the catalogue of tools exposed by this server.

    Each entry carries the tool name, a short description and a JSON-schema
    style ``inputSchema`` describing its arguments.
    """

    def object_schema(properties: Dict[str, Any], required: Optional[List[str]] = None) -> Dict[str, Any]:
        # Build an "object" schema; "required" is only emitted when given.
        schema: Dict[str, Any] = {"type": "object", "properties": properties}
        if required is not None:
            schema["required"] = required
        return schema

    def limit_property(default: int) -> Dict[str, Any]:
        return {"type": "integer", "description": "Result limit", "default": default}

    def search_properties() -> Dict[str, Any]:
        # Fresh dicts on every call so the two search tools share no state.
        return {
            "query": {"type": "string", "description": "Search keywords"},
            "limit": limit_property(10),
            "task": {"type": "string", "description": "Task type filter"},
        }

    catalogue = (
        (
            "search_models",
            "Search Hugging Face models",
            object_schema(search_properties(), ["query"]),
        ),
        (
            "get_model_info",
            "Get model details",
            object_schema(
                {"model_id": {"type": "string", "description": "Model ID"}},
                ["model_id"],
            ),
        ),
        (
            "search_datasets",
            "Search Hugging Face datasets",
            object_schema(search_properties(), ["query"]),
        ),
        (
            "get_dataset_info",
            "Get dataset details",
            object_schema(
                {"dataset_id": {"type": "string", "description": "Dataset ID"}},
                ["dataset_id"],
            ),
        ),
        (
            "get_trending_models",
            "Get trending models list",
            object_schema({"limit": limit_property(20)}),
        ),
    )

    return [
        Tool(name=tool_name, description=summary, inputSchema=schema)
        for tool_name, summary, schema in catalogue
    ]

@server.call_tool()
async def call_tool(name: str, arguments: Dict[str, Any]) -> Sequence[TextContent]:
    """Dispatch a tool invocation and return its result as JSON text.

    Args:
        name: Name of the tool to run (one of those returned by ``list_tools``).
        arguments: Tool arguments keyed by parameter name. ``None`` is treated
            as an empty mapping.

    Returns:
        A single-element list holding a ``TextContent`` whose text is the
        JSON-encoded result dict. Failures are reported in-band as
        ``{"status": "error", "message": ...}`` rather than raised.
    """

    def as_content(payload: Dict[str, Any]) -> List[TextContent]:
        # ensure_ascii=False keeps non-ASCII text readable in the output.
        return [TextContent(type="text", text=json.dumps(payload, indent=2, ensure_ascii=False))]

    arguments = arguments or {}

    try:
        if name == "search_models":
            result = await search_models_impl(
                arguments["query"],
                arguments.get("limit", 10),
                arguments.get("task"),
            )
        elif name == "get_model_info":
            result = await get_model_info_impl(arguments["model_id"])
        elif name == "search_datasets":
            result = await search_datasets_impl(
                arguments["query"],
                arguments.get("limit", 10),
                arguments.get("task"),
            )
        elif name == "get_dataset_info":
            result = await get_dataset_info_impl(arguments["dataset_id"])
        elif name == "get_trending_models":
            result = await get_trending_models_impl(arguments.get("limit", 20))
        else:
            result = {"status": "error", "message": f"未知工具: {name}"}

        return as_content(result)

    except KeyError as exc:
        # str(KeyError) is just the quoted key (e.g. "'query'"), which tells the
        # caller nothing; name the missing argument explicitly instead.
        missing = exc.args[0] if exc.args else exc
        return as_content({"status": "error", "message": f"Missing required argument: {missing}"})
    except Exception as e:
        # Top-level boundary: report any other failure in-band.
        return as_content({"status": "error", "message": str(e)})

async def search_models_impl(query: str, limit: int = 10, task: Optional[str] = None) -> Dict[str, Any]:
    """Search Hugging Face models, sorted by download count.

    Args:
        query: Search string forwarded as ``search``.
        limit: Maximum number of models to request.
        task: Optional task filter forwarded as ``task``.

    Returns:
        ``{"status": "success", "models": [...], "count": n}`` on success, or
        ``{"status": "error", "message": ...}`` if any exception is raised.
    """

    def fetch() -> List[Any]:
        # Go through the shared client so the configured token is actually
        # used, and drain the result here so every Hub request happens in the
        # worker thread rather than on the event loop.
        return list(
            hf_api.list_models(
                search=query,
                limit=limit,
                task=task,
                sort="downloads",
                direction=-1,
            )
        )

    try:
        models = await asyncio.to_thread(fetch)

        results = [
            {
                "id": model.id,
                "author": model.author,
                "downloads": model.downloads,
                "likes": model.likes,
                # The attribute may exist with value None; report "unknown" then too.
                "task": getattr(model, "pipeline_tag", None) or "unknown",
                "tags": model.tags[:5] if model.tags else [],
                "created_at": str(model.created_at) if model.created_at else None,
            }
            for model in models
        ]

        return {
            "status": "success",
            "models": results,
            "count": len(results),
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}

async def get_model_info_impl(model_id: str) -> Dict[str, Any]:
    """Fetch detailed information about a single model.

    Args:
        model_id: Hub identifier of the model.

    Returns:
        ``{"status": "success", "model": {...}}`` on success, or
        ``{"status": "error", "message": ...}`` if any exception is raised.
    """
    try:
        # Use the shared client (so the configured token applies) and run the
        # blocking request in a worker thread to keep the event loop free.
        info = await asyncio.to_thread(hf_api.model_info, model_id)

        # card_data can be present but None; calling .get() on it directly
        # would raise AttributeError and fail the whole lookup.
        card = getattr(info, "card_data", None)
        license_name = (card.get("license") if card is not None else None) or "unknown"

        return {
            "status": "success",
            "model": {
                "id": info.id,
                "author": info.author,
                "downloads": info.downloads,
                "likes": info.likes,
                # Attributes may exist with value None; fall back to "unknown".
                "task": getattr(info, "pipeline_tag", None) or "unknown",
                "tags": info.tags[:10] if info.tags else [],
                "library_name": getattr(info, "library_name", None) or "unknown",
                "license": license_name,
                "created_at": str(info.created_at) if info.created_at else None,
                "last_modified": str(info.last_modified) if info.last_modified else None,
                # Only the number of repository files is reported.
                "siblings": len(info.siblings) if info.siblings else 0,
            },
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}

async def search_datasets_impl(query: str, limit: int = 10, task: Optional[str] = None) -> Dict[str, Any]:
    """Search Hugging Face datasets, sorted by download count.

    Args:
        query: Search string forwarded as ``search``.
        limit: Maximum number of datasets to request.
        task: Optional task filter. A bare task name (e.g.
            ``"text-classification"``) is sent as the Hub tag
            ``task_categories:<task>``; a value that already contains ``:`` is
            sent unchanged as a raw tag filter. Empty/blank means no filter.

    Returns:
        ``{"status": "success", "datasets": [...], "count": n}`` on success, or
        ``{"status": "error", "message": ...}`` if any exception is raised.
    """
    # Filter on the server instead of after the fact: filtering client-side
    # happened after ``limit`` was applied (so fewer than ``limit`` results came
    # back) and compared the bare task name against namespaced tags.
    wanted = task.strip() if task else ""
    hub_filter: Optional[str] = None
    if wanted:
        hub_filter = wanted if ":" in wanted else f"task_categories:{wanted.lower()}"

    def fetch() -> List[Any]:
        # Go through the shared client so the configured token is used, and
        # drain the result here so every Hub request runs in the worker thread.
        return list(
            hf_api.list_datasets(
                search=query,
                filter=hub_filter,
                limit=limit,
                sort="downloads",
                direction=-1,
            )
        )

    try:
        datasets = await asyncio.to_thread(fetch)

        results = [
            {
                "id": dataset.id,
                "author": dataset.author,
                "downloads": dataset.downloads,
                "likes": dataset.likes,
                "tags": dataset.tags[:5] if dataset.tags else [],
                "created_at": str(dataset.created_at) if dataset.created_at else None,
            }
            for dataset in datasets
        ]

        return {
            "status": "success",
            "datasets": results,
            "count": len(results),
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}

async def get_dataset_info_impl(dataset_id: str) -> Dict[str, Any]:
    """Fetch detailed information about a single dataset.

    Args:
        dataset_id: Hub identifier of the dataset.

    Returns:
        ``{"status": "success", "dataset": {...}}`` on success, or
        ``{"status": "error", "message": ...}`` if any exception is raised.
    """
    try:
        # Use the shared client (so the configured token applies) and run the
        # blocking request in a worker thread to keep the event loop free.
        info = await asyncio.to_thread(hf_api.dataset_info, dataset_id)

        return {
            "status": "success",
            "dataset": {
                "id": info.id,
                "author": info.author,
                "downloads": info.downloads,
                "likes": info.likes,
                "tags": info.tags[:10] if info.tags else [],
                "created_at": str(info.created_at) if info.created_at else None,
                "last_modified": str(info.last_modified) if info.last_modified else None,
                # Only the number of repository files is reported.
                "siblings": len(info.siblings) if info.siblings else 0,
            },
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}

async def get_trending_models_impl(limit: int = 20) -> Dict[str, Any]:
    """List popular models, sorted by download count.

    NOTE(review): despite the name, the ordering key is ``downloads``, not a
    trending metric — confirm whether that is the intended definition.

    Args:
        limit: Maximum number of models to request.

    Returns:
        ``{"status": "success", "trending_models": [...], "count": n}`` on
        success, or ``{"status": "error", "message": ...}`` if any exception
        is raised.
    """

    def fetch() -> List[Any]:
        # Go through the shared client so the configured token is used, and
        # drain the result here so every Hub request runs in the worker thread.
        return list(hf_api.list_models(limit=limit, sort="downloads", direction=-1))

    try:
        models = await asyncio.to_thread(fetch)

        results = [
            {
                "id": model.id,
                "author": model.author,
                "downloads": model.downloads,
                "likes": model.likes,
                # The attribute may exist with value None; report "unknown" then too.
                "task": getattr(model, "pipeline_tag", None) or "unknown",
                "tags": model.tags[:3] if model.tags else [],
            }
            for model in models
        ]

        return {
            "status": "success",
            "trending_models": results,
            "count": len(results),
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}

async def main():
    """Run the MCP server over the stdio transport."""
    async with stdio_server() as streams:
        incoming, outgoing = streams
        init_options = server.create_initialization_options()
        await server.run(incoming, outgoing, init_options)

# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())