import math
import typing as T

from pydantic import BaseModel

from minimal.llm import LLMs

# Flat mapping from parameter name to a scalar value (str, int, float or bool).
ParamDict = T.Dict[str, str | int | float | bool]


def to_power_of_2(num: int | float) -> int:
    return int(2 ** math.ceil(math.log2(num)))


class Splitter(BaseModel):
    """Search-space bounds for splitting documents into chunks."""

    # Lower bound on chunk size. Per the original note it rounds up to 256;
    # to_power_of_2(129) == 256, so sizes are presumably snapped to powers of
    # two by the consumer — confirm.
    chunk_size_min: int = 129  # will round up to 256
    chunk_size_max: int = 4096  # upper bound on chunk size
    # Chunk overlap, presumably as a fraction of the chunk size, searched
    # from min to max in increments of `step` — confirm against the consumer.
    chunk_overlap_frac_min: float = 0.0
    chunk_overlap_frac_max: float = 0.75
    chunk_overlap_frac_step: float = 0.25
    # Candidate splitting strategies (identified by name).
    methods: T.List[str] = ["html", "recursive", "sentence", "token"]


# Candidate embedding models. Used as the default `embedding_models` of
# Retriever and FewShotRetriever below. NOTE(review): the "org/name" form
# suggests Hugging Face Hub model IDs — confirm.
DEFAULT_EMBEDDING_MODELS: T.List[str] = [
    "BAAI/bge-small-en-v1.5",
    "BAAI/bge-large-en-v1.5",
    "thenlper/gte-large",
    "mixedbread-ai/mxbai-embed-large-v1",
    "WhereIsAI/UAE-Large-V1",
    "avsolatorio/GIST-large-Embedding-v0",
    "w601sxs/b1ade-embed",
    "Labib11/MUG-B-1.6",
    "sentence-transformers/all-MiniLM-L12-v2",
    "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
    "BAAI/bge-base-en-v1.5",
    "FinLang/finance-embeddings-investopedia",
    "baconnier/Finance2_embedding_small_en-V1.5",
]


# Every LLM key registered in `minimal.llm.LLMs`.
ALL_LLMS = list(LLMs.keys())

# General-purpose candidate LLMs. Default for Reranker.llms and Hyde.llms.
DEFAULT_LLMS: T.List[str] = [
    "gpt-4o-mini",
    "gpt-4o-std",
    "gpt-35-turbo",
    "anthropic-sonnet-35",
    "anthropic-haiku-35",
    "llama-33-70B",
    "gemini-pro",
    "gemini-flash",
    "gemini-flash2",
    "mistral-large",
]

# Candidate response-synthesizer LLMs for the no-RAG and plain-RAG modes.
# The name is misspelled; it is kept for existing importers (see the
# correctly spelled alias below).
RESPONSE_SYNTHSIZER_LLMS: T.List[str] = [
    "gpt-4o-mini",
    "gpt-4o-std",
    "gpt-35-turbo",
    "anthropic-sonnet-35",
    "anthropic-haiku-35",
    "llama-33-70B",
    "gemini-pro",
    "gemini-flash",
    "gemini-flash2",
    "mistral-large",
]
# Correctly spelled alias (same list object) for new code.
RESPONSE_SYNTHESIZER_LLMS: T.List[str] = RESPONSE_SYNTHSIZER_LLMS

# LLMs used for every role in the agent-based RAG modes (ReAct, critique,
# sub-question, LATS). Presumably those supporting tool/function calling,
# per the name — confirm.
FUNCTION_CALLING_LLMS: T.List[str] = [
    "gpt-4o-mini",
    "gpt-4o-std",
    "gpt-35-turbo",
    "anthropic-sonnet-35",
    "anthropic-haiku-35",
    "llama-33-70B",
    "mistral-large",
]

# Small subset of low-cost models (per the name).
CHEAP_LLMS: T.List[str] = [
    "gpt-4o-mini",
    "anthropic-haiku-35",
    "gemini-flash2",
]

# Default LLM candidates for query decomposition (QueryDecomposition.llm_names).
NON_REASONING_LLMS: T.List[str] = [
    "gpt-4o-mini",
    "gpt-4o-std",
    "gpt-35-turbo",
    "anthropic-sonnet-35",
    "anthropic-haiku-35",
    "llama-33-70B",
    "gemini-pro",
    "gemini-flash",
    "gemini-flash2",
    "mistral-large",
]

# Prompt-template names available to the RAG modes. LATSRagAgent narrows
# this to a subset that omits "CoT".
ALL_TEMPLATE_NAMES: T.List[str] = [
    "default",
    "concise",
    "CoT",
    "finance-expert",
]


class TopK(BaseModel):
    """Range of candidate top-k values for a retriever or reranker."""

    kmin: int = 2  # lower bound of the top-k range
    kmax: int = 20  # upper bound of the top-k range
    # Presumably requests log-scale sampling between kmin and kmax — confirm
    # against the code that builds the distribution.
    log: bool = False
    step: int = 1  # presumably the increment between candidate values


class Hybrid(BaseModel):
    """Search range for the BM25 weight used by hybrid retrieval."""

    # Range (min, max, step) for the BM25 weight. NOTE(review): presumably the
    # dense retriever receives the complementary weight — confirm in the consumer.
    bm25_weight_min: float = 0.1
    bm25_weight_max: float = 0.9
    bm25_weight_step: float = 0.1


class QueryDecomposition(BaseModel):
    """Search-space options for query decomposition."""

    enabled: T.List[bool] = [True, False]  # whether the step is used at all
    # Candidate LLMs, presumably the ones that generate the sub-queries.
    llm_names: T.List[str] = NON_REASONING_LLMS
    # Range (min, max, step) for the number of generated queries.
    num_queries_min: int = 2
    num_queries_max: int = 20
    num_queries_step: int = 2


class FusionMode(BaseModel):
    """Candidate strategies for fusing results from several retrievals."""

    # NOTE(review): these names appear to match LlamaIndex's
    # QueryFusionRetriever fusion modes — confirm they are passed through there.
    fusion_modes: T.List[str] = [
        "simple",
        "reciprocal_rerank",
        "relative_score",
        "dist_based_score",
    ]


class Retriever(BaseModel):
    """Search-space options for one retriever configuration."""

    name: str  # identifier for this retriever (e.g. "rag" in SearchSpace)
    top_k: TopK = TopK()  # candidate range for the number of results
    # Retrieval strategies to choose from.
    methods: T.List[str] = ["dense", "sparse", "hybrid"]
    embedding_models: T.List[str] = DEFAULT_EMBEDDING_MODELS
    hybrid: Hybrid = Hybrid()  # BM25 weight range; presumably only for "hybrid"
    query_decomposition: QueryDecomposition = QueryDecomposition()
    fusion: FusionMode = FusionMode()


class FewShotRetriever(BaseModel):
    """Search-space options for retrieving few-shot examples."""

    enabled: T.List[bool] = [True, False]  # whether few-shot retrieval is used
    top_k: TopK = TopK()  # candidate range for the number of examples
    embedding_models: T.List[str] = DEFAULT_EMBEDDING_MODELS


class Reranker(BaseModel):
    """
    Search-space options for the optional LLM reranking step.

    Params (flat parameter names as listed by the original author; the code
    that emits them is not in this module):
        reranker_enabled
        reranker_llm_name
        reranker_top_k
    """

    # NOTE(review): named `enable`, whereas the other components use `enabled`.
    # It is part of the model's field set, so renaming would break callers.
    enable: T.List[bool] = [True, False]
    # Wider default than TopK() (kmax=128 instead of 20), with log=True.
    top_k: TopK = TopK(kmax=128, log=True)
    llms: T.List[str] = DEFAULT_LLMS  # candidate reranker LLMs


class Hyde(BaseModel):
    """Search-space options for the Hyde step (presumably HyDE, hypothetical
    document embeddings — confirm)."""

    enabled: T.List[bool] = [True, False]  # whether the step is used
    llms: T.List[str] = DEFAULT_LLMS  # candidate LLMs for the step


class AdditionalContext(BaseModel):
    """Search-space options for adding extra context nodes to retrieval."""

    # NOTE(review): order is [False, True], unlike [True, False] elsewhere;
    # confirm whether that is deliberate (e.g. preferred first choice).
    enabled: T.List[bool] = [False, True]
    # Range for the number of additional nodes.
    num_nodes_min: int = 2
    num_nodes_max: int = 20


class RAGMode(BaseModel):
    """Base class for the per-mode settings of each RAG mode."""

    # Prefix for this mode's flat parameter keys. The lookup helpers in this
    # module read keys such as "rag_template_name", i.e. prefix + field name.
    prefix: str


class NoRAG(RAGMode):
    """Settings for the "no_rag" mode."""

    prefix: str = "no_rag_"
    response_synthesizer_llms: T.List[str] = RESPONSE_SYNTHSIZER_LLMS
    template_names: T.List[str] = ALL_TEMPLATE_NAMES


class RAG(RAGMode):
    """Settings for the plain "rag" mode."""

    prefix: str = "rag_"
    response_synthesizer_llms: T.List[str] = RESPONSE_SYNTHSIZER_LLMS
    template_names: T.List[str] = ALL_TEMPLATE_NAMES


class ReactRAGAgent(RAGMode):
    """Settings for the "react_rag_agent" mode."""

    prefix: str = "react_rag_agent_"
    response_synthesizer_llms: T.List[str] = FUNCTION_CALLING_LLMS
    template_names: T.List[str] = ALL_TEMPLATE_NAMES
    subquestion_engine_llms: T.List[str] = FUNCTION_CALLING_LLMS
    subquestion_response_synthesizer_llms: T.List[str] = FUNCTION_CALLING_LLMS
    # Range for the agent's maximum iteration count (read back through the
    # "react_rag_agent_max_iterations" key).
    max_iterations_min: int = 10
    max_iterations_max: int = 11


class CritiqueRAGAgent(RAGMode):
    """Settings for the "critique_rag_agent" mode.

    Also picks LLMs for the critique and reflection agents.
    """

    prefix: str = "critique_rag_agent_"
    response_synthesizer_llms: T.List[str] = FUNCTION_CALLING_LLMS
    template_names: T.List[str] = ALL_TEMPLATE_NAMES
    subquestion_engine_llms: T.List[str] = FUNCTION_CALLING_LLMS
    subquestion_response_synthesizer_llms: T.List[str] = FUNCTION_CALLING_LLMS
    critique_agent_llms: T.List[str] = FUNCTION_CALLING_LLMS
    reflection_agent_llms: T.List[str] = FUNCTION_CALLING_LLMS
    # Range for the agent's maximum iteration count.
    max_iterations_min: int = 10
    max_iterations_max: int = 11


class SubQuestionRAGAgent(RAGMode):
    """Settings for the "sub_question_rag" mode."""

    # Matches the mode name "sub_question_rag" (not "..._agent_").
    prefix: str = "sub_question_rag_"
    response_synthesizer_llms: T.List[str] = FUNCTION_CALLING_LLMS
    template_names: T.List[str] = ALL_TEMPLATE_NAMES
    subquestion_engine_llms: T.List[str] = FUNCTION_CALLING_LLMS
    subquestion_response_synthesizer_llms: T.List[str] = FUNCTION_CALLING_LLMS


class LATSRagAgent(RAGMode):
    """Settings for the "lats_rag_agent" mode (LATS per the class name)."""

    prefix: str = "lats_rag_agent_"

    response_synthesizer_llms: T.List[str] = FUNCTION_CALLING_LLMS
    # Subset of ALL_TEMPLATE_NAMES: "CoT" is not offered for this mode.
    template_names: T.List[str] = [
        "default",
        "concise",
        "finance-expert",
    ]

    # Ranges (min, max, step) for the number of expansions and rollouts.
    num_expansions_min: int = 2
    num_expansions_max: int = 3
    num_expansions_step: int = 1
    max_rollouts_min: int = 2
    max_rollouts_max: int = 5
    max_rollouts_step: int = 1


class LATS(BaseModel):
    """Expansion/rollout bounds for LATS.

    NOTE(review): duplicates the bounds declared on LATSRagAgent and is not
    referenced elsewhere in this module; keep both in sync or remove one.
    """

    num_expansions_min: int = 2
    num_expansions_max: int = 3
    num_expansions_step: int = 1
    max_rollouts_min: int = 2
    max_rollouts_max: int = 5
    max_rollouts_step: int = 1


class Defaults(BaseModel):
    """Fixed default values attached to a SearchSpace."""

    # NOTE(review): presumably a fallback for modes without a sampled
    # max-iterations value — confirm with the consumer.
    max_iterations: int = 5


class SearchSpace(BaseModel):
    """Complete hyper-parameter search space for building a flow.

    Component-level options (retrievers, splitter, reranker, Hyde, additional
    context) sit alongside one settings object per RAG mode. ``rag_modes``
    lists the mode names, and each name is also an attribute of this model.
    """

    defaults: Defaults = Defaults()
    # NOTE(review): presumably parameters excluded from sampling — confirm.
    non_search_space_params: T.List[str] = ["enforce_full_evaluation", "retrievers"]
    rag_modes: T.List[str] = [
        "no_rag",
        "rag",
        "react_rag_agent",
        "critique_rag_agent",
        "sub_question_rag",
        "lats_rag_agent",
    ]

    few_shot_retriever: FewShotRetriever = FewShotRetriever()
    # The main retriever gets a wider top-k range (kmax=128, log=True).
    rag_retriever: Retriever = Retriever(name="rag", top_k=TopK(kmax=128, log=True))
    splitter: Splitter = Splitter()
    reranker: Reranker = Reranker()
    hyde: Hyde = Hyde()
    additional_context: AdditionalContext = AdditionalContext()
    no_rag: NoRAG = NoRAG()
    rag: RAG = RAG()
    react_rag_agent: ReactRAGAgent = ReactRAGAgent()
    critique_rag_agent: CritiqueRAGAgent = CritiqueRAGAgent()
    sub_question_rag: SubQuestionRAGAgent = SubQuestionRAGAgent()
    lats_rag_agent: LATSRagAgent = LATSRagAgent()

    def param_names(
        self, params: T.Dict[str, T.Any] | T.List[str] | None = None
    ) -> T.List[str]:
        """Return the keys of ``self.build_distributions(params=params)``.

        NOTE(review): ``build_distributions`` is not defined on this class in
        this module. Presumably a subclass or another definition provides it —
        confirm.
        """
        distributions = self.build_distributions(params=params)
        return list(distributions.keys())

    def get_response_synthesizer_llm_name(self, params: T.Dict[str, T.Any]) -> str:
        """Return the response-synthesizer LLM selected in ``params``.

        Each RAG mode stores its synthesizer under its own key,
        ``<rag_mode>_response_synthesizer_llm``. This returns the *value* of
        the key that belongs to ``params["rag_mode"]``.

        Raises:
            KeyError: If ``rag_mode`` or the mode-specific key is missing.
            ValueError: If ``rag_mode`` is not one of the known modes.
        """
        rag_mode = params["rag_mode"]
        if rag_mode not in (
            "no_rag",
            "rag",
            "react_rag_agent",
            "critique_rag_agent",
            "sub_question_rag",
            "lats_rag_agent",
        ):
            raise ValueError(f"Invalid RAG mode: {rag_mode}")
        return params[rag_mode + "_response_synthesizer_llm"]

    def is_few_shot(self, params: T.Dict) -> bool:
        """Return ``params["few_shot_enabled"]``, or ``False`` when absent.

        The stored value is returned as-is, without coercion to ``bool``.
        """
        enabled = params.get("few_shot_enabled", False)
        return enabled


def get_response_synthesizer_llm(params: T.Dict[str, T.Any], prefix: str = "") -> str:
    """Return the response-synthesizer LLM chosen for the active RAG mode.

    Every RAG mode has its own ``<mode>_response_synthesizer_llm`` entry.
    This reads the one matching ``params[prefix + "rag_mode"]``.

    Args:
        params: Flat parameter dictionary.
        prefix: String prepended to every key that is read.

    Returns:
        The value stored under the mode-specific key.

    Raises:
        KeyError: If the ``rag_mode`` key or the mode-specific key is missing.
        ValueError: If the RAG mode is not one of the known modes.
    """
    rag_mode = params[prefix + "rag_mode"]
    known_modes = (
        "no_rag",
        "rag",
        "react_rag_agent",
        "critique_rag_agent",
        "sub_question_rag",
        "lats_rag_agent",
    )
    if rag_mode not in known_modes:
        raise ValueError(f"Invalid RAG mode: {rag_mode}")
    return params[prefix + rag_mode + "_response_synthesizer_llm"]


def get_subquestion_engine_llm(
    params: T.Dict[str, T.Any], prefix: str = ""
) -> str | None:
    rag_mode = params[prefix + "rag_mode"]
    match rag_mode:
        case "no_rag":
            return None
        case "rag":
            return None
        case "react_rag_agent":
            return params[prefix + "react_rag_agent_subquestion_engine_llm"]
        case "critique_rag_agent":
            return params[prefix + "critique_rag_agent_subquestion_engine_llm"]
        case "sub_question_rag":
            return params[prefix + "sub_question_rag_subquestion_engine_llm"]
        case "lats_rag_agent":
            return None
        case _:
            raise ValueError(f"Invalid RAG mode: {rag_mode}")


def get_subquestion_response_synthesizer_llm(
    params: T.Dict[str, T.Any], prefix: str = ""
) -> str | None:
    rag_mode = params[prefix + "rag_mode"]
    match rag_mode:
        case "no_rag":
            return None
        case "rag":
            return None
        case "react_rag_agent":
            return params[
                prefix + "react_rag_agent_subquestion_response_synthesizer_llm"
            ]
        case "critique_rag_agent":
            return params[
                prefix + "critique_rag_agent_subquestion_response_synthesizer_llm"
            ]
        case "sub_question_rag":
            return params[
                prefix + "sub_question_rag_subquestion_response_synthesizer_llm"
            ]
        case "lats_rag_agent":
            return None
        case _:
            raise ValueError(f"Invalid RAG mode: {rag_mode}")


def get_critique_agent_llm(params: T.Dict[str, T.Any], prefix: str = "") -> str | None:
    rag_mode = params[prefix + "rag_mode"]
    match rag_mode:
        case "no_rag":
            return None
        case "rag":
            return None
        case "react_rag_agent":
            return None
        case "critique_rag_agent":
            return params[prefix + "critique_rag_agent_critique_agent_llm"]
        case "sub_question_rag":
            return None
        case "lats_rag_agent":
            return None
        case _:
            raise ValueError(f"Invalid RAG mode: {rag_mode}")


def get_reflection_agent_llm(
    params: T.Dict[str, T.Any], prefix: str = ""
) -> str | None:
    rag_mode = params[prefix + "rag_mode"]
    match rag_mode:
        case "no_rag":
            return None
        case "rag":
            return None
        case "react_rag_agent":
            return None
        case "critique_rag_agent":
            return params[prefix + "critique_rag_agent_reflection_agent_llm"]
        case "sub_question_rag":
            return None
        case "lats_rag_agent":
            return None
        case _:
            raise ValueError(f"Invalid RAG mode: {rag_mode}")


def get_template_name(params: T.Dict[str, T.Any], prefix: str = "") -> str:
    """Return the prompt-template name selected for the active RAG mode.

    Every RAG mode has its own ``<mode>_template_name`` entry.

    Args:
        params: Flat parameter dictionary.
        prefix: String prepended to every key that is read.

    Returns:
        The template name stored under the mode-specific key.

    Raises:
        KeyError: If the ``rag_mode`` key or the mode-specific key is missing.
        ValueError: If the RAG mode is not one of the known modes.
    """
    rag_mode = params[prefix + "rag_mode"]
    if rag_mode in (
        "no_rag",
        "rag",
        "react_rag_agent",
        "critique_rag_agent",
        "sub_question_rag",
        "lats_rag_agent",
    ):
        return params[prefix + rag_mode + "_template_name"]
    raise ValueError(f"Invalid RAG mode: {rag_mode}")


def get_max_iterations(params: T.Dict[str, T.Any], prefix: str = "") -> int | None:
    rag_mode = params[prefix + "rag_mode"]
    match rag_mode:
        case "no_rag":
            return None
        case "rag":
            return None
        case "react_rag_agent":
            return params[prefix + "react_rag_agent_max_iterations"]
        case "critique_rag_agent":
            return params[prefix + "critique_rag_agent_max_iterations"]
        case "sub_question_rag":
            return None
        case "lats_rag_agent":
            return None
        case _:
            raise ValueError(f"Invalid RAG mode: {rag_mode}")
