import os

from hydra.utils import instantiate
from omegaconf import OmegaConf

import partnr
from partnr.llm.gemini import Gemini  # noqa: F401
from partnr.llm.hf_model import HFModel  # noqa: F401
from partnr.llm.llama import Llama  # noqa: F401
from partnr.llm.llama2 import Llama2  # noqa: F401
from partnr.llm.llama3 import Llama3  # noqa: F401
from partnr.llm.mixtral import Mixtral  # noqa: F401
from partnr.llm.openai import OpenAI  # noqa: F401
from partnr.llm.openai_chat import OpenAIChat  # noqa: F401


def instantiate_llm(llm_name, generation_params=None, **kwargs):
    """Build an LLM object from its YAML config under ``partnr/conf/llm``.

    The config is loaded from ``<partnr package dir>/conf/llm/<llm_name>.yaml``
    and overrides are applied. The ``llm`` node is then passed to Hydra's
    ``instantiate``. The resulting object is called with the full merged
    config, and the result of that call is returned.

    :param llm_name: Base name (without ``.yaml``) of the config file in
        ``conf/llm``.
    :param generation_params: Optional mapping merged over the config's
        ``generation_params`` section. ``None`` or an empty mapping leaves the
        config unchanged.
    :param kwargs: Top-level overrides merged over the whole config *after*
        ``generation_params``, so a ``generation_params`` key passed here wins.
        When ``llm_name == "llama"`` a ``host`` entry is required.
    :return: ``instantiate(llm_config.llm)(llm_config)``.
    :raises ValueError: If ``llm_name`` is ``"llama"`` and no ``host`` kwarg is
        provided.
    :raises FileNotFoundError: If the config file does not exist.
    """
    if generation_params is None:
        generation_params = {}

    # Validate required arguments before touching the filesystem.
    if (llm_name == "llama") and ("host" not in kwargs):
        raise ValueError(
            "You must provide a host to instantiate LLaMa."
        )

    # Get the path to the LLM config file (shipped inside the partnr package).
    partnr_dir_path = os.path.dirname(partnr.__file__)
    llm_config_path = f"{partnr_dir_path}/conf/llm/{llm_name}.yaml"
    # Explicit raise instead of ``assert``: assertions are removed under
    # ``python -O``, which would silently skip this check.
    if not os.path.exists(llm_config_path):
        raise FileNotFoundError(f"LLM config file not found at {llm_config_path}")

    # Load the LLM config file
    llm_config = OmegaConf.load(llm_config_path)

    # Merge generation overrides at the top level so this also works when the
    # YAML has no (or a null) ``generation_params`` section; reading
    # ``llm_config.generation_params`` directly would fail in that case.
    if generation_params:
        llm_config = OmegaConf.merge(
            llm_config,
            OmegaConf.create({"generation_params": generation_params}),
        )

    # Remaining kwargs override the whole config (applied last, so they win).
    if kwargs:
        llm_config = OmegaConf.merge(llm_config, OmegaConf.create(kwargs))

    # NOTE(review): ``instantiate`` on the ``llm`` node presumably yields a
    # class or partial that accepts the full config — confirm against the
    # YAML files (e.g. ``_partial_: true``).
    llm = instantiate(llm_config.llm)(llm_config)

    return llm
