import ast
import typing as T
from textwrap import dedent

from llama_index.core import VectorStoreIndex
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore, TextNode

from minimal.configuration import cfg
from minimal.flows import (
    CritiqueAgentFlow,
    Flow,
    Flows,
    LATSAgentFlow,
    RAGFlow,
    ReActAgentFlow,
    SubQuestionRAGFlow,
)
from minimal.huggingface_helper import get_embedding_model
from minimal.llm import get_llm
from minimal.logger import logger
from minimal.retrievers.build import build_rag_retriever
from minimal.searchspace import SearchSpace, get_template_name
from minimal.templates import get_template


def get_flow_name(rag_mode: str):
    match rag_mode:
        case "no_rag":
            return Flows.GENERATOR_FLOW.value.__name__
        case "rag":
            return Flows.RAG_FLOW.value.__name__
        case "sub_question_rag":
            return Flows.LLAMA_INDEX_SUB_QUESTION_FLOW.value.__name__
        case "react_rag_agent":
            return Flows.LLAMA_INDEX_REACT_AGENT_FLOW.value.__name__
        case "critique_rag_agent":
            return Flows.LLAMA_INDEX_CRITIQUE_AGENT_FLOW.value.__name__
        case "lats_rag_agent":
            return Flows.LLAMA_INDEX_LATS_RAG_AGENT.value.__name__
        case _:
            raise RuntimeError("Cannot identify flow")


def _get_examples(example_retriever: BaseRetriever, query_str: str):
    """Retrieve few-shot examples for ``query_str`` and render them as text.

    Each retrieved node's text is parsed with ``ast.literal_eval`` and is expected
    to be a dict with ``"query"`` and ``"response"`` keys. Parsed examples are
    rendered as ``"Question: ...\\nAnswer: ..."``. A node whose text cannot be
    parsed that way is logged and included verbatim instead.

    Args:
        example_retriever: Retriever over the few-shot example nodes.
        query_str: Text used to look up similar examples.

    Returns:
        The rendered examples joined by blank lines ("" if nothing is retrieved).
    """
    retrieved_nodes: T.List[NodeWithScore] = example_retriever.retrieve(query_str)
    result_strs = []
    for node in retrieved_nodes:
        try:
            raw_dict = ast.literal_eval(node.text)
            query = raw_dict["query"]
            response = raw_dict["response"]
        except (SyntaxError, ValueError, TypeError, KeyError) as exc:
            # literal_eval raises ValueError for non-literal input; indexing a
            # non-dict or a dict missing a key raises TypeError / KeyError.
            logger.warning("Converting example to dictionary failed: %s", exc)
            result_strs.append(node.text)
        else:
            # Format directly instead of dedenting an indented f-string: dedent runs
            # after interpolation, so multi-line questions/answers would defeat it.
            result_strs.append(f"Question: {query}\nAnswer: {response}")
    return "\n\n".join(result_strs)


def _get_example_retriever(params, search_space: SearchSpace, dataset, embedding_model):
    """Build a callable that returns formatted few-shot examples for a query.

    Every example from ``dataset.iter_examples(partition="train")`` is stored as a
    ``TextNode`` whose text is the ``repr`` of
    ``{"query": pair.question, "response": pair.answer}``. These nodes are indexed
    in a ``VectorStoreIndex`` with ``embedding_model``, and the index's retriever
    is wrapped with ``_get_examples``.

    Args:
        params: Sampled configuration; ``params["few_shot_top_k"]`` is the number
            of examples retrieved per query.
        search_space: Not used in this function.
        dataset: Source of training examples, each exposing ``question`` and
            ``answer``.
        embedding_model: Embedding model used to index the examples; must be truthy.

    Returns:
        ``get_qa_examples(query_str, **kwargs)`` returning the rendered examples;
        extra keyword arguments are ignored.
    """
    assert embedding_model, "No embedding model for dynamic few-shot prompting"
    logger.info("Building few-shot retriever")
    dataset_iter = dataset.iter_examples(partition="train")
    logger.info("Getting few-shot examples from dataset")
    # repr() produces a Python literal that ast.literal_eval (used by _get_examples)
    # always round-trips. A hand-built '''...''' literal breaks on text containing
    # triple quotes or ending in a quote, and lets backslash sequences in the text
    # be re-interpreted as escapes.
    few_shot_nodes = [
        TextNode(text=repr({"query": pair.question, "response": pair.answer}))
        for pair in dataset_iter
    ]

    logger.info("Building few-shot retriever index")
    few_shot_index = VectorStoreIndex(nodes=few_shot_nodes, embed_model=embedding_model)
    logger.info("Built few-shot retriever index")
    few_shot_retriever = few_shot_index.as_retriever(
        similarity_top_k=params["few_shot_top_k"], similarity_threshold=None
    )

    def get_qa_examples(query_str, **kwargs):
        # **kwargs is accepted and ignored; presumably callers pass extra template
        # variables — confirm at call sites.
        return _get_examples(few_shot_retriever, query_str)

    return get_qa_examples


def build_flow(params: T.Dict, search_space: SearchSpace, dataset) -> Flow:
    response_synthesizer_llm_name = search_space.get_response_synthesizer_llm_name(
        params
    )
    response_synthesizer_llm = get_llm(response_synthesizer_llm_name)

    get_qa_examples = None
    is_few_shot = search_space.is_few_shot(params)
    if is_few_shot:
        few_shot_embedding_model_name = params["few_shot_embedding_model"]
        few_shot_embedding_model, _ = get_embedding_model(
            few_shot_embedding_model_name,
            device=cfg.resources.embedding_device,
        )
        get_qa_examples = _get_example_retriever(
            params, search_space, dataset, few_shot_embedding_model
        )

    do_rag = params["rag_mode"] != "no_rag"
    template_name = get_template_name(params)
    template = get_template(
        template_name, with_context=do_rag, with_few_shot_prompt=is_few_shot
    )
    enforce_full_evaluation = params.get("enforce_full_evaluation", False)

    flow: T.Any

    if not do_rag:
        flow = Flow(
            response_synthesizer_llm=response_synthesizer_llm,
            template=template,
            get_examples=get_qa_examples,
            params=params,
            enforce_full_evaluation=enforce_full_evaluation,
        )
    else:
        hyde_llm = reranker_llm = reranker_top_k = None
        if params.get("hyde_enabled"):
            hyde_llm = get_llm(params["hyde_llm_name"])
        if params.get("reranker_enabled"):
            reranker_llm = get_llm(params["reranker_llm_name"])
            reranker_top_k = params["reranker_top_k"]
        if params.get("additional_context_enabled"):
            additional_context_num_nodes = params["additional_context_num_nodes"]
        else:
            additional_context_num_nodes = 0

        rag_retriever, rag_docstore = build_rag_retriever(dataset, params)

        match params["rag_mode"]:
            case "rag":
                flow = RAGFlow(
                    retriever=rag_retriever,
                    response_synthesizer_llm=response_synthesizer_llm,
                    docstore=rag_docstore,
                    template=template,
                    get_examples=get_qa_examples,
                    hyde_llm=hyde_llm,
                    reranker_llm=reranker_llm,
                    reranker_top_k=reranker_top_k,
                    additional_context_num_nodes=additional_context_num_nodes,
                    enforce_full_evaluation=enforce_full_evaluation,
                    params=params,
                )
            case "react_rag_agent":
                subquestion_engine_llm = get_llm(
                    params["react_rag_agent_subquestion_engine_llm"]
                )
                subquestion_response_synthesizer_llm = get_llm(
                    params["react_rag_agent_subquestion_response_synthesizer_llm"]
                )
                max_iterations = params.get(
                    "react_rag_agent_max_iterations",
                    search_space.defaults.max_iterations,
                )
                flow = ReActAgentFlow(
                    retriever=rag_retriever,
                    response_synthesizer_llm=response_synthesizer_llm,
                    subquestion_response_synthesizer_llm=subquestion_response_synthesizer_llm,
                    subquestion_engine_llm=subquestion_engine_llm,
                    max_iterations=max_iterations,
                    docstore=rag_docstore,
                    template=template,
                    get_examples=get_qa_examples,
                    hyde_llm=hyde_llm,
                    reranker_llm=reranker_llm,
                    reranker_top_k=reranker_top_k,
                    additional_context_num_nodes=additional_context_num_nodes,
                    dataset_name=dataset.name,
                    dataset_description=dataset.description,
                    enforce_full_evaluation=enforce_full_evaluation,
                    params=params,
                )
            case "critique_rag_agent":
                subquestion_engine_llm = get_llm(
                    params["critique_rag_agent_subquestion_engine_llm"]
                )
                subquestion_response_synthesizer_llm = get_llm(
                    params["critique_rag_agent_subquestion_response_synthesizer_llm"]
                )
                critique_agent_llm = get_llm(
                    params["critique_rag_agent_critique_agent_llm"]
                )
                reflection_agent_llm = get_llm(
                    params["critique_rag_agent_reflection_agent_llm"]
                )
                max_iterations = params.get(
                    "critique_rag_agent_max_iterations",
                    search_space.defaults.max_iterations,
                )
                flow = CritiqueAgentFlow(
                    response_synthesizer_llm=response_synthesizer_llm,
                    subquestion_engine_llm=subquestion_engine_llm,
                    subquestion_response_synthesizer_llm=subquestion_response_synthesizer_llm,
                    critique_agent_llm=critique_agent_llm,
                    reflection_agent_llm=reflection_agent_llm,
                    max_iterations=max_iterations,
                    retriever=rag_retriever,
                    docstore=rag_docstore,
                    template=template,
                    get_examples=get_qa_examples,
                    hyde_llm=hyde_llm,
                    reranker_llm=reranker_llm,
                    reranker_top_k=reranker_top_k,
                    additional_context_num_nodes=additional_context_num_nodes,
                    dataset_name=dataset.name,
                    dataset_description=dataset.description,
                    enforce_full_evaluation=enforce_full_evaluation,
                    params=params,
                )
            case "sub_question_rag":
                subquestion_engine_llm = get_llm(
                    params["sub_question_rag_subquestion_engine_llm"]
                )
                subquestion_response_synthesizer_llm = get_llm(
                    params["sub_question_rag_subquestion_response_synthesizer_llm"]
                )
                flow = SubQuestionRAGFlow(
                    response_synthesizer_llm=response_synthesizer_llm,
                    subquestion_engine_llm=subquestion_engine_llm,
                    subquestion_response_synthesizer_llm=subquestion_response_synthesizer_llm,
                    retriever=rag_retriever,
                    docstore=rag_docstore,
                    template=template,
                    get_examples=get_qa_examples,
                    hyde_llm=hyde_llm,
                    reranker_llm=reranker_llm,
                    reranker_top_k=reranker_top_k,
                    additional_context_num_nodes=additional_context_num_nodes,
                    dataset_name=dataset.name,
                    dataset_description=dataset.description,
                    enforce_full_evaluation=enforce_full_evaluation,
                    params=params,
                )
            case "lats_rag_agent":
                flow = LATSAgentFlow(
                    retriever=rag_retriever,
                    response_synthesizer_llm=response_synthesizer_llm,
                    docstore=rag_docstore,
                    template=template,
                    get_examples=get_qa_examples,
                    hyde_llm=hyde_llm,
                    reranker_llm=reranker_llm,
                    reranker_top_k=reranker_top_k,
                    additional_context_num_nodes=additional_context_num_nodes,
                    dataset_name=dataset.name,
                    dataset_description=dataset.description,
                    num_expansions=params["lats_rag_agent_num_expansions"],
                    max_rollouts=params["lats_rag_agent_max_rollouts"],
                    enforce_full_evaluation=enforce_full_evaluation,
                    params=params,
                )
            case _:
                raise ValueError(f"Invalid rag_mode: {params['rag_mode']}")

    return flow
