import os
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field

from murmur.domains.slack.utils import SLACK_DB_PATH
from murmur.environment.db import DB


class Message(BaseModel):
    # A single Slack message: who sent it, where it was addressed, and its text.
    # The same model is stored in both per-user and per-channel inboxes on
    # SlackWorkspace, so `recipient` may name either a user or a channel.
    sender: str = Field(..., description="Sender of the message")
    recipient: str = Field(..., description="Recipient of the message (either user or a channel)")
    body: str = Field(..., description="Body of the message")


class SlackWorkspace(BaseModel):
    # Snapshot of a Slack workspace: its members, channels, channel membership,
    # and the messages delivered to each user and each channel.
    # All fields are required (`...`), so the model has no implicit defaults.
    users: List[str] = Field(..., description="List of users in the slack workspace")
    channels: List[str] = Field(..., description="List of channels in the slack workspace")
    # Keyed by user; presumably the keys are entries of `users` and the values
    # are entries of `channels` -- not validated here, TODO confirm with loader.
    user_channels: Dict[str, List[str]] = Field(..., description="Channels each user is a member of")
    # Messages received by each user, keyed by user.
    user_inbox: Dict[str, List[Message]] = Field(..., description="Inbox of each user")
    # Messages posted to each channel, keyed by channel.
    channel_inbox: Dict[str, List[Message]] = Field(..., description="Inbox of each channel")


class WebContent(BaseModel):
    # Simulated web: page bodies addressable by URL, plus a log of the URLs
    # that have been requested.
    # Maps URL -> page content.
    web_content: Dict[str, str] = Field(..., description="Content of web pages indexed by URL")
    # URLs that have been requested, kept as a list (so duplicates and order are
    # preserved by the type; whether callers rely on that is not visible here).
    web_requests: List[str] = Field(..., description="List of URLs that have been requested")


class SlackDB(DB):
    """Database of slack workspace and web content."""

    # Users, channels, memberships and inboxes.
    slack: SlackWorkspace = Field(description="Slack workspace data")
    # Indexed web pages and the log of requested URLs.
    web: WebContent = Field(description="Web content and request history")

    def get_statistics(self) -> Dict[str, Any]:
        """Summarise the size of the stored workspace and web data.

        Returns:
            A dict with four integer counts:
            ``num_users`` -- entries in ``slack.users``;
            ``num_channels`` -- entries in ``slack.channels``;
            ``num_messages`` -- messages across every user inbox plus every
            channel inbox;
            ``num_web_pages`` -- URLs in ``web.web_content``.
        """
        workspace = self.slack
        # Dict values are evaluated top to bottom, matching the original order
        # in which each count was computed.
        return {
            "num_users": len(workspace.users),
            "num_channels": len(workspace.channels),
            "num_messages": sum(
                len(inbox)
                for inbox_map in (workspace.user_inbox, workspace.channel_inbox)
                for inbox in inbox_map.values()
            ),
            "num_web_pages": len(self.web.web_content),
        }


def get_db(path: Optional[Union[str, os.PathLike]] = None) -> SlackDB:
    """Load the Slack database via ``SlackDB.load``.

    Args:
        path: Location of the database file. Defaults to ``SLACK_DB_PATH``
            when omitted, so existing ``get_db()`` callers are unaffected.

    Returns:
        The ``SlackDB`` produced by ``SlackDB.load``.
    """
    # Resolve the default at call time (not as a default-argument value) so the
    # module-level SLACK_DB_PATH is looked up on every call, exactly as before.
    db_path = SLACK_DB_PATH if path is None else path
    return SlackDB.load(str(db_path))


if __name__ == "__main__":
    # Script entry point: load the default database and print its size summary.
    stats = get_db().get_statistics()
    print(stats)
