#!/usr/bin/python3
"""
Computer Science arXiv Abstracts Knowledge Base.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
from datasets import load_dataset  # type: ignore
from llama_index.core import Document
from llama_index.core.embeddings import BaseEmbedding
from pathlib import Path
from typing import Any, Dict, Final, List, Union

from .base import RAGKnowledgeBase


class arXivKnowledgeBase(RAGKnowledgeBase):
    """
    Knowledge base built from arXiv paper abstracts.

    Abstracts are read from the `mteb/raw_arxiv` dataset via
    `datasets.load_dataset`. Only records whose `categories` field
    contains `cs.LG` are kept (see `get_knowledge`).
    """

    def __init__(
        self,
        embedder: BaseEmbedding,
        top_k: int,
        cache_dir: Union[Path, str] = (
            Path.home() / ".cache" / "leon" / "arxiv"
        ),
        **kwargs: Any
    ):
        """
        Args:
            embedder: the embedding function to use for retrieval.
            top_k: the number of documents to retrieve per query.
            cache_dir: a local directory to cache the knowledge locally;
                forwarded unchanged to the base class initializer.
            **kwargs: accepted for call-site compatibility and discarded.
        """
        del kwargs
        # NOTE(review): assigned before the base initializer runs, presumably
        # because the base class may call `get_knowledge()` during
        # construction -- confirm before reordering.
        self.knowledge_source: Final[str] = "mteb/raw_arxiv"
        super().__init__(
            embedder=embedder, top_k=top_k, cache_dir=cache_dir
        )

    def get_knowledge(self) -> List[Document]:
        """
        Returns a list of the knowledge from the knowledge base.

        Loads the `train` split of `self.knowledge_source`, keeps the
        records whose `categories` field contains `cs.LG`, and wraps each
        remaining record's `abstract` field in a `Document`.

        Input:
            None.
        Returns:
            A list of the knowledge from the knowledge base, one `Document`
            per retained abstract.
        """
        category = "cs.LG"
        knowledge = load_dataset(self.knowledge_source, split="train")
        knowledge = knowledge.filter(lambda k: category in k["categories"])
        return [Document(text=k["abstract"]) for k in knowledge]

    @classmethod
    def knowledge_description(cls, *args: Any, **kwargs: Any) -> str:
        """
        Returns a description of the knowledge.
        Input:
            Any positional or keyword arguments are accepted and ignored.
        Returns:
            The description of the knowledge base.
        """
        del args, kwargs
        return "Retrieves relevant abstracts of computer science arXiv papers."

    @classmethod
    def query_description(cls) -> str:
        """
        Returns a description of the expected query type.
        Input:
            None.
        Returns:
            The description of the expected query type.
        """
        return "The topic to retrieve abstracts about."
