import typing as T

import Stemmer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.node_parser import (
    CodeSplitter,
    LangchainNodeParser,
    SentenceSplitter,
    TokenTextSplitter,
)
from llama_index.core.node_parser.interface import NodeParser
from llama_index.core.retrievers import BaseRetriever, QueryFusionRetriever
from llama_index.core.retrievers.fusion_retriever import FUSION_MODES
from llama_index.core.schema import TransformComponent
from llama_index.core.storage.docstore.types import BaseDocumentStore
from llama_index.core.storage.storage_context import StorageContext
from llama_index.retrievers.bm25 import BM25Retriever
from transformers import AutoTokenizer

from minimal.configuration import cfg
from minimal.huggingface_helper import get_embedding_model
from minimal.llm import get_llm, get_tokenizer
from minimal.logger import logger
from minimal.searchspace import ParamDict, SearchSpace


def build_splitter(params: T.Dict[str, T.Any]) -> NodeParser:
    """Create the node parser selected by ``params["splitter_method"]``.

    The chunk overlap is ``int(splitter_chunk_size * splitter_chunk_overlap_frac)``.

    Supported methods:

    * ``"html"``: ``CodeSplitter`` for HTML with ``max_chars`` set to four
      times the chunk size (the overlap is not passed to it).
    * ``"sentence"``: ``SentenceSplitter`` with the tokenizer returned by
      ``get_tokenizer`` for the response-synthesizer LLM name.
    * ``"token"``: ``TokenTextSplitter`` with that same tokenizer.
    * ``"recursive"``: LangChain ``RecursiveCharacterTextSplitter`` built from
      the Hugging Face tokenizer of ``params["rag_embedding_model"]``
      (``"BAAI/bge-small-en-v1.5"`` when the key is absent), wrapped in a
      ``LangchainNodeParser``.

    Args:
        params: Parameter mapping; must contain ``splitter_chunk_size``,
            ``splitter_chunk_overlap_frac`` and ``splitter_method``.

    Returns:
        The configured node parser.

    Raises:
        ValueError: If ``splitter_method`` is none of the methods above.
    """
    chunk_size = params["splitter_chunk_size"]
    chunk_overlap = int(chunk_size * params["splitter_chunk_overlap_frac"])
    # Resolved up front for every method, although only the "sentence" and
    # "token" splitters use it.
    synthesizer_llm_name = SearchSpace().get_response_synthesizer_llm_name(params)
    method = params["splitter_method"]

    if method == "html":
        # NOTE(review): the factor 4 is presumably a chars-per-token
        # heuristic - confirm.
        return CodeSplitter(language="html", max_chars=chunk_size * 4)

    if method == "sentence":
        return SentenceSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            tokenizer=get_tokenizer(synthesizer_llm_name),
        )

    if method == "token":
        return TokenTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            tokenizer=get_tokenizer(synthesizer_llm_name),
        )

    if method == "recursive":
        embedding_model_name = params.get(
            "rag_embedding_model", "BAAI/bge-small-en-v1.5"
        )
        hf_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
        recursive_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
            tokenizer=hf_tokenizer,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        return LangchainNodeParser(recursive_splitter)

    raise ValueError("Invalid splitter")


def get_or_build_dense_index(
    params: ParamDict,
    documents: T.List[Document],
    transforms: T.List[TransformComponent],
    embedding_model: BaseEmbedding,
    max_chunks: int = 0,
    use_cache: bool = True,
) -> T.Tuple[VectorStoreIndex, BaseDocumentStore]:
    """Build a dense vector index over ``documents``.

    Thin wrapper that delegates to ``_build_dense_index``.

    Args:
        params: Accepted but not read by this implementation.
        documents: Source documents to ingest.
        transforms: Transformations applied by the ingestion pipeline.
        embedding_model: Embedding model used by the vector index.
        max_chunks: Forwarded to ``_build_dense_index``.
        use_cache: Accepted but not read by this implementation; the index
            is always built.

    Returns:
        ``(index, docstore)`` as produced by ``_build_dense_index``.
    """
    # NOTE(review): `_build_dense_index` logs this same message again.
    logger.info("Building dense index")
    built = _build_dense_index(
        documents,
        transforms,
        embedding_model,
        max_chunks=max_chunks,
    )
    dense_index, dense_docstore = built
    return dense_index, dense_docstore


def _build_dense_index(
    documents: T.List[Document],
    transforms: T.List[TransformComponent],
    embedding_model: BaseEmbedding,
    max_chunks: int = 0,
) -> T.Tuple[VectorStoreIndex, BaseDocumentStore]:
    """Run the ingestion transforms and index the resulting nodes.

    Args:
        documents: Source documents to ingest.
        transforms: Transformations run by the ``IngestionPipeline``.
        embedding_model: Embedding model handed to ``VectorStoreIndex``.
        max_chunks: When truthy, only the first ``max_chunks`` nodes produced
            by the pipeline are indexed; ``0`` (the default) keeps them all.

    Returns:
        The built ``VectorStoreIndex`` and its ``docstore``.
    """
    logger.info("Building dense index")
    # The original message ended with a stray, unbalanced ")".
    logger.debug(f"Embedding model type is {type(embedding_model)}")

    pipeline = IngestionPipeline(transformations=transforms)
    nodes = pipeline.run(
        documents=documents,
        show_progress=cfg.logging.show_progress,
    )
    if max_chunks:
        # NOTE(review): a negative value would drop nodes from the end
        # (plain slice semantics) - confirm callers only pass >= 0.
        nodes = nodes[:max_chunks]
    index = VectorStoreIndex(
        nodes=nodes,
        embed_model=embedding_model,
        insert_batch_size=2048,
        show_progress=cfg.logging.show_progress,
    )
    return index, index.docstore


def _build_sparse_index(
    documents: T.List[Document], transforms: T.List[TransformComponent], top_k: int
) -> T.Tuple[BaseRetriever, BaseDocumentStore]:
    """Run the ingestion transforms and build a BM25 retriever over the nodes.

    Args:
        documents: Source documents to ingest.
        transforms: Transformations run by the ``IngestionPipeline``.
        top_k: ``similarity_top_k`` given to the BM25 retriever.

    Returns:
        The BM25 retriever (English stemmer and language) and the docstore
        holding the nodes it was built from.
    """
    logger.info("Building sparse index")

    ingestion = IngestionPipeline(transformations=transforms)
    parsed_nodes = ingestion.run(
        documents=documents,
        show_progress=cfg.logging.show_progress,
    )

    # Store the nodes in a fresh default docstore; the retriever is then built
    # from the docstore's contents rather than from the raw pipeline output.
    docstore = StorageContext.from_defaults().docstore
    docstore.add_documents(parsed_nodes)

    bm25_retriever = BM25Retriever.from_defaults(
        nodes=list(docstore.docs.values()),
        similarity_top_k=top_k,
        stemmer=Stemmer.Stemmer("english"),
        language="english",
    )
    return bm25_retriever, docstore


def build_rag_retriever(
    dataset, params: ParamDict
) -> T.Tuple[BaseRetriever, BaseDocumentStore]:
    """Build the retriever and backing docstore described by ``params``.

    Grounding documents come from ``dataset.iter_grounding_data()`` and are
    chunked with the splitter from ``build_splitter``. ``params["rag_method"]``
    selects what is built:

    * ``"sparse"``: a BM25 retriever.
    * ``"dense"``: a vector-index retriever using ``rag_embedding_model``.
    * ``"hybrid"``: both, combined by a ``QueryFusionRetriever`` weighted with
      ``rag_hybrid_bm25_weight`` for BM25 and its complement for dense.

    When ``rag_query_decomposition_enabled`` is falsy, the sparse and dense
    methods return their plain retriever directly. Otherwise (and always for
    hybrid) the retriever(s) are wrapped in a ``QueryFusionRetriever`` using
    ``rag_fusion_mode``; with decomposition enabled it is configured with
    ``rag_query_decomposition_llm_name`` and
    ``rag_query_decomposition_num_queries``.

    Args:
        dataset: Object exposing ``iter_grounding_data()``.
        params: Parameter mapping holding the ``rag_*`` and ``splitter_*``
            settings read here and by ``build_splitter``.

    Returns:
        ``(retriever, docstore)``. In hybrid mode the docstore is the sparse
        docstore with the dense index's nodes added to it.

    Raises:
        AssertionError: If ``rag_method`` is not ``"dense"``, ``"sparse"`` or
            ``"hybrid"``.
    """
    logger.info(f"Building retriever for {params=}")
    rag_method = params["rag_method"]
    top_k = int(params["rag_top_k"])
    query_decomp_enabled = params["rag_query_decomposition_enabled"]

    assert rag_method in [
        "dense",
        "sparse",
        "hybrid",
    ], f"RAG method `{rag_method}` not supported"

    logger.info("Loading grounding data documents")
    documents = list(dataset.iter_grounding_data())
    transforms = [build_splitter(params)]

    # Sparse (BM25) side
    sparse_retriever = sparse_docstore = None
    if rag_method in ("sparse", "hybrid"):
        sparse_retriever, sparse_docstore = _build_sparse_index(
            documents, transforms, top_k
        )
        if not query_decomp_enabled and rag_method == "sparse":
            return sparse_retriever, sparse_docstore

    # Dense (vector index) side
    dense_retriever = dense_docstore = None
    if rag_method in ("dense", "hybrid"):
        embedding_model_name = str(params["rag_embedding_model"])
        embedding_model, _ = get_embedding_model(
            embedding_model_name,
            total_chunks=0,
            device=cfg.resources.embedding_device,
        )
        assert embedding_model is not None

        dense_index, dense_docstore = get_or_build_dense_index(
            params,
            documents,
            transforms,
            embedding_model,
            max_chunks=0,  # 0 disables chunk truncation
        )
        dense_retriever = dense_index.as_retriever(
            embed_model=embedding_model, similarity_top_k=top_k
        )
        if not query_decomp_enabled and rag_method == "dense":
            return dense_retriever, dense_docstore

    # From here on a fusion retriever is needed: either hybrid mode, or a
    # single retriever wrapped for query decomposition.
    if rag_method == "hybrid":
        assert dense_docstore is not None and sparse_docstore is not None
        retrievers = [sparse_retriever, dense_retriever]
        hybrid_bm25_weight = float(params["rag_hybrid_bm25_weight"])
        retriever_weights = [hybrid_bm25_weight, 1 - hybrid_bm25_weight]
    else:
        # Exactly one side was built. Select it with an explicit None check
        # rather than relying on the truthiness of the retriever object.
        active_retriever = (
            dense_retriever if dense_retriever is not None else sparse_retriever
        )
        retrievers = [active_retriever]
        retriever_weights = [1]

    fusion_mode = FUSION_MODES(params["rag_fusion_mode"])

    if rag_method == "hybrid":
        # Merge the dense nodes into the sparse docstore so the returned store
        # covers the nodes of both retrievers.
        docstore = sparse_docstore
        docstore.add_documents(list(dense_docstore.docs.values()))
    else:
        docstore = dense_docstore if dense_docstore is not None else sparse_docstore

    if query_decomp_enabled:
        num_queries = params["rag_query_decomposition_num_queries"]
        fusion_llm = get_llm(str(params["rag_query_decomposition_llm_name"]))
    else:
        # Placeholder LLM: not used without query decomposition enabled.
        num_queries = 1
        fusion_llm = get_llm("gpt-4o-mini")

    fusion_retriever = QueryFusionRetriever(
        llm=fusion_llm,
        mode=fusion_mode,
        use_async=False,
        verbose=True,
        similarity_top_k=top_k,
        num_queries=num_queries,
        retriever_weights=retriever_weights,
        retrievers=retrievers,
    )
    return fusion_retriever, docstore
