#!/usr/bin/env python3
"""
Discord MCP Server
Provides tools for sending and reading messages in Discord channels.
"""

import asyncio
import json
import os
from typing import Any, Dict, List, Optional
import logging

# Discord.py imports
try:
    import discord
    from discord.ext import commands
except ImportError:
    print("❌ Error: discord.py not installed. Install with: pip install discord.py")
    exit(1)

# FastMCP imports
try:
    from fastmcp import FastMCP
except ImportError:
    print("❌ Error: fastmcp not installed. Install with: pip install fastmcp")
    exit(1)

# --- Logging ----------------------------------------------------------------
# Root logger at INFO; everything in this module logs through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- MCP application --------------------------------------------------------
# The FastMCP app that the @mcp.tool() functions below register with.
mcp = FastMCP("Discord MCP Server")

# --- Shared Discord state ---------------------------------------------------
# The bot is created lazily by get_discord_client(); bot_ready is set to True
# by DiscordBot.on_ready once the bot has logged in.
discord_client: Optional["DiscordBot"] = None
bot_ready: bool = False

class DiscordBot(commands.Bot):
    """Discord bot backing the MCP tools.

    Starts from the default intents, explicitly enables message content,
    guilds and guild messages, and flips the module-level ``bot_ready``
    flag when the ready event fires.
    """

    def __init__(self):
        wanted = discord.Intents.default()
        # message_content lets the bot see message text (read_discord_messages
        # returns it); guilds / guild_messages are set explicitly as well.
        wanted.message_content = wanted.guilds = wanted.guild_messages = True
        super().__init__(intents=wanted, command_prefix='!')

    async def on_ready(self):
        """Record readiness for get_discord_client() and log connection info."""
        global bot_ready
        bot_ready = True
        logger.info(f'Discord bot logged in as {self.user}')
        logger.info(f'Bot is in {len(self.guilds)} guilds')

async def get_discord_client():
    """Return the shared Discord bot, starting it on first use.

    The first call creates a ``DiscordBot``, schedules ``start(token)`` as a
    background task, and waits for ``on_ready`` to set ``bot_ready``.
    *Every* call waits for readiness, including concurrent calls that arrive
    while start-up is still in progress. As a result, no tool ever receives a
    client that has not finished logging in.

    Returns:
        The shared ``DiscordBot`` instance, once ``bot_ready`` is True.

    Raises:
        ValueError: If the ``DISCORD_TOKEN`` environment variable is not set.
        RuntimeError: If the background ``start()`` task finishes before the
            bot becomes ready, whether it fails, is cancelled or returns. The
            shared client is discarded in that case, so a later call retries.
    """
    global discord_client, bot_ready

    if discord_client is None:
        token = os.getenv('DISCORD_TOKEN')
        if not token:
            raise ValueError("DISCORD_TOKEN environment variable is required")

        # A freshly created client has not signalled readiness yet.
        bot_ready = False
        discord_client = DiscordBot()

        # Keep a strong reference to the start-up task. asyncio only holds
        # weak references to tasks, so an unreferenced task may be
        # garbage-collected mid-execution.
        discord_client._mcp_start_task = asyncio.create_task(
            discord_client.start(token)
        )

    client = discord_client
    start_task = getattr(client, "_mcp_start_task", None)

    # Wait on every call, not only on the call that created the client.
    while not bot_ready:
        if start_task is not None and start_task.done():
            # start() ended before on_ready fired (e.g. login failure). Report
            # it instead of polling forever, and drop the dead client so a
            # later call can try again.
            if discord_client is client:
                discord_client = None
            try:
                await client.close()
            except Exception as close_error:  # best-effort cleanup only
                logger.debug(f"Ignoring error while closing Discord client: {close_error}")

            if start_task.cancelled():
                raise RuntimeError("Discord client start-up was cancelled")
            start_error = start_task.exception()
            if start_error is not None:
                raise RuntimeError(
                    f"Discord client failed to start: {start_error}"
                ) from start_error
            raise RuntimeError("Discord client stopped before becoming ready")

        await asyncio.sleep(0.1)

    return client

def find_channel(bot, channel_name: str, guild_name: Optional[str] = None):
    """Find a channel that messages can be sent to / read from.

    For each guild, ``channel_name`` is first matched against channel names.
    If that fails and ``channel_name`` parses as an integer, it is matched as
    a channel ID. Only candidates with a callable ``send`` attribute are
    accepted. Categories and other non-messageable containers are skipped
    rather than returned to callers that immediately call ``send()`` or
    ``history()`` on the result.

    Args:
        bot: The connected Discord client (``bot.guilds`` and
            ``bot.get_channel`` are used).
        channel_name: Channel name, or the channel's numeric ID as a string.
        guild_name: If given, only guilds with exactly this name are
            searched. This applies to ID lookups as well as name lookups.

    Returns:
        The first matching channel, or ``None`` when nothing matches.
    """
    def _can_send(candidate) -> bool:
        # Duck-typed check: messageable channels expose send()/history(),
        # category channels do not.
        return callable(getattr(candidate, "send", None))

    # The ID lookup does not depend on the guild, so resolve it once.
    try:
        by_id = bot.get_channel(int(channel_name))
    except (TypeError, ValueError):
        by_id = None
    by_id_guild = getattr(getattr(by_id, "guild", None), "id", None)

    for guild in bot.guilds:
        if guild_name and guild.name != guild_name:
            continue

        # 1) Match by name within this guild.
        for candidate in guild.channels:
            if candidate.name == channel_name and _can_send(candidate):
                return candidate

        # 2) Match by ID, but only when that channel belongs to this guild,
        #    so the guild_name filter is honoured.
        if by_id is not None and by_id_guild == guild.id and _can_send(by_id):
            return by_id

    return None

@mcp.tool()
async def send_discord_message(
    channel: str,
    message: str,
    guild: Optional[str] = None
) -> Dict[str, Any]:
    """
    Send a message to a Discord channel.

    Args:
        channel: Channel name or ID.
        message: Message content to send.
        guild: Server (guild) name; optional, useful when the bot is in
            several servers.

    Returns:
        A dict with "success". On success it also contains the sent
        message's ID, channel/guild details, timestamp and content; on
        failure it contains an "error" string.
    """
    try:
        bot = await get_discord_client()

        destination = find_channel(bot, channel, guild)
        if not destination:
            return {
                "success": False,
                "error": f"Channel '{channel}' not found"
            }

        sent = await destination.send(message)

        payload = {
            "success": True,
            "message_id": str(sent.id),
            "channel_id": str(destination.id),
            "channel_name": destination.name,
            "guild_name": destination.guild.name,
            "timestamp": sent.created_at.isoformat(),
            "content": message,
        }
        return payload

    except Exception as e:
        # Tool boundary: report the failure to the MCP client as data.
        logger.error(f"Error sending Discord message: {e}")
        return {"success": False, "error": str(e)}

@mcp.tool()
async def read_discord_messages(
    channel: str,
    limit: int = 10,
    guild: Optional[str] = None
) -> Dict[str, Any]:
    """
    Read recent messages from a Discord channel.

    Args:
        channel: Channel name or ID.
        limit: Number of messages to fetch (default 10, capped at 100).
        guild: Server (guild) name; optional, useful when the bot is in
            several servers.

    Returns:
        A dict with "success". On success it also contains channel/guild
        details, "message_count" and a "messages" list; on failure it
        contains an "error" string.
    """
    try:
        bot = await get_discord_client()

        # Never request more than 100 messages.
        capped = min(limit, 100)

        source = find_channel(bot, channel, guild)
        if not source:
            return {
                "success": False,
                "error": f"Channel '{channel}' not found"
            }

        messages = [
            {
                "id": str(msg.id),
                "author": msg.author.display_name,
                "author_id": str(msg.author.id),
                "content": msg.content,
                "timestamp": msg.created_at.isoformat(),
                "attachments": [attachment.url for attachment in msg.attachments],
                "embeds": len(msg.embeds),
                "reactions": [f"{r.emoji}:{r.count}" for r in msg.reactions],
            }
            async for msg in source.history(limit=capped)
        ]

        return {
            "success": True,
            "channel_id": str(source.id),
            "channel_name": source.name,
            "guild_name": source.guild.name,
            "message_count": len(messages),
            "messages": messages,
        }

    except Exception as e:
        # Tool boundary: report the failure to the MCP client as data.
        logger.error(f"Error reading Discord messages: {e}")
        return {"success": False, "error": str(e)}

@mcp.tool()
async def list_discord_guilds() -> Dict[str, Any]:
    """
    List every Discord server (guild) the bot is a member of.

    Returns:
        A dict with "success". On success it also contains "guild_count"
        and a "guilds" list, where each entry lists its text channels; on
        failure it contains an "error" string.
    """
    try:
        bot = await get_discord_client()

        guilds = [
            {
                "id": str(g.id),
                "name": g.name,
                "member_count": g.member_count,
                # Only plain text channels are reported.
                "text_channels": [
                    {"id": str(ch.id), "name": ch.name, "type": "text"}
                    for ch in g.channels
                    if isinstance(ch, discord.TextChannel)
                ],
            }
            for g in bot.guilds
        ]

        return {
            "success": True,
            "guild_count": len(guilds),
            "guilds": guilds,
        }

    except Exception as e:
        # Tool boundary: report the failure to the MCP client as data.
        logger.error(f"Error listing Discord guilds: {e}")
        return {"success": False, "error": str(e)}

@mcp.tool()
async def add_discord_reaction(
    channel: str,
    message_id: str,
    emoji: str,
    guild: Optional[str] = None
) -> Dict[str, Any]:
    """
    Add a reaction to a Discord message.

    Args:
        channel: Channel name or ID.
        message_id: ID of the message to react to.
        emoji: Emoji to react with.
        guild: Server (guild) name; optional.

    Returns:
        A dict with "success". On success it also echoes the message ID,
        emoji and channel name; on failure it contains an "error" string.
    """
    try:
        bot = await get_discord_client()

        where = find_channel(bot, channel, guild)
        if not where:
            return {"success": False, "error": f"Channel '{channel}' not found"}

        # A non-numeric message_id raises ValueError here, which falls through
        # to the generic handler below.
        try:
            target = await where.fetch_message(int(message_id))
        except discord.NotFound:
            return {
                "success": False,
                "error": f"Message with ID '{message_id}' not found"
            }

        await target.add_reaction(emoji)

        return {
            "success": True,
            "message_id": message_id,
            "emoji": emoji,
            "channel_name": where.name,
        }

    except Exception as e:
        # Tool boundary: report the failure to the MCP client as data.
        logger.error(f"Error adding Discord reaction: {e}")
        return {"success": False, "error": str(e)}

if __name__ == "__main__":
    import sys

    # Refuse to start without a bot token.
    if not os.getenv('DISCORD_TOKEN'):
        # Diagnostics go to stderr. mcp.run() presumably uses FastMCP's
        # default stdio transport, where stdout carries protocol messages.
        print("❌ Error: DISCORD_TOKEN environment variable is required", file=sys.stderr)
        print("Please set your Discord bot token:", file=sys.stderr)
        print("export DISCORD_TOKEN='your_discord_bot_token_here'", file=sys.stderr)
        # Use sys.exit: the builtin exit() is injected by the `site` module
        # for interactive use and is not guaranteed to exist in programs.
        sys.exit(1)

    # Start the MCP server.
    mcp.run()