# fmt: off

# ===============================================
# LINK: System Prompt parts
# ===============================================
from .prompts import prompts_per_env


def safe_format(string_template: str, fill_with: str = "", **kwargs) -> str:
    """
    Render a template with ``str.format_map`` while tolerating missing keys.

    Any placeholder whose name is not among the keyword arguments is
    substituted with ``fill_with`` instead of raising ``KeyError``.

    Args:
        string_template (str): The string template with ``{name}`` placeholders.
        fill_with (str): Value substituted for placeholders that have no
            matching keyword argument. Defaults to an empty string.
        **kwargs: Key-value pairs for formatting.

    Returns:
        str: The formatted string.
    """

    class _FallbackMapping(dict):
        # dict subscripting falls back to __missing__ for absent keys, which
        # is the lookup path format_map uses to resolve placeholder names.
        def __missing__(self, name):
            return fill_with

    substitutions = _FallbackMapping(kwargs)
    return string_template.format_map(substitutions)


def get_query_prompt_eval(
    env: str,
    eval_criteria: str = "",
    cot_part: str = "",
    add_retrieval_step: bool = False,
) -> str:
    """
    Assemble the evaluation query prompt for one environment.

    The environment's ``response_format`` template is filled with the
    optional retrieval step and chain-of-thought text, and the result is then
    inserted, together with the evaluation criteria, into the environment's
    ``eval_prompt_template``.

    Args:
        env (str): Key into ``prompts_per_env``.
        eval_criteria (str): Key into the environment's ``eval_criterias``;
            an empty string leaves the criteria blank.
        cot_part (str): Key into the environment's ``cot_parts``; an empty
            string leaves the CoT section blank.
        add_retrieval_step (bool): When true, the ``k_1p`` retrieval query is
            placed in the response format's ``k_retrieval_step`` slot.

    Returns:
        str: The stripped evaluation prompt.
    """
    env_prompts = prompts_per_env[env]
    cot_by_name = env_prompts["cot_parts"]
    criteria_by_name = env_prompts["eval_criterias"]
    one_pass_k_step = env_prompts["k_retrieval_queries"]["k_1p"]
    response_template = env_prompts["response_format"]
    prompt_template = env_prompts["eval_prompt_template"]

    # Each optional section collapses to "" when it was not requested.
    cot_text = cot_by_name[cot_part].strip() if cot_part else ""
    criteria_text = criteria_by_name[eval_criteria].strip() if eval_criteria else ""

    # If one-pass, but retrieving K, the K retrieval step goes before the CoT.
    retrieval_text = one_pass_k_step.strip() if add_retrieval_step else ""

    filled_response_format = response_template.format(
        k_retrieval_step=retrieval_text,
        cot_parts=cot_text,
    ).strip()

    rendered = safe_format(
        prompt_template,
        eval_criteria=criteria_text,
        response_format=filled_response_format,
    )
    return rendered.strip()


def get_sys_prompt_verifier(
    env: str,
    trace_info: str = "",
    add_summ_info: bool = False,
    add_k_info: bool = False,
    two_pass: bool = False,
    prompt_config: dict = None,
    config: dict = None,
) -> str:
    """
    Build the verifier (critic) system prompt for one environment.

    The base template is chosen from the environment's ``sys_prompts_critic``
    table and filled with the optional trace, summary, k-info, rule and
    annotation sections. Sections that are not requested are filled with an
    empty string.

    Args:
        env (str): Key into ``prompts_per_env``.
        trace_info (str): Key into the environment's ``trace_infos``; an
            empty string omits the section.
        add_summ_info (bool): Include the environment's ``summ_info`` text.
        add_k_info (bool): Include the environment's ``k_info`` text. It is
            only inserted when ``two_pass`` is also true.
        two_pass (bool): Whether the prompt is built for a two-pass run.
        prompt_config (dict): Prompt arguments; its optional ``"rule"`` entry
            names the prompt key to use as the rule text. ``None`` is treated
            as an empty dict.
        config (dict): Full config; ``meta_data.sys_prompt`` selects the base
            template (default ``"base"``) and
            ``additional_config.annotate_actions_on_image`` toggles the
            annotation section. ``None`` is treated as an empty dict.

    Returns:
        str: The stripped system prompt.
    """
    # Both dicts default to None in the signature; normalise them so callers
    # that omit them get the defaults instead of an AttributeError on `.get`.
    if prompt_config is None:
        prompt_config = {}
    if config is None:
        config = {}

    prompts = prompts_per_env[env]
    trace_infos = prompts["trace_infos"]
    summ_info = prompts["summ_info"]
    k_info = prompts["k_info"]

    # TODO: Add rules dynamically
    rule_key = prompt_config.get("rule", "")
    rule = prompts[rule_key].strip() if rule_key else ""

    key_sys_prompt = config.get("meta_data", {}).get("sys_prompt", "base")  # TODO: clean this up
    sys_prompt_critic_base = prompts["sys_prompts_critic"][key_sys_prompt]

    _trace_info = trace_infos[trace_info].strip() if trace_info else ""
    _summ_info = summ_info if add_summ_info else ""

    # k-info is only part of the system prompt in the two-pass setting; the
    # rule text is included regardless of these flags.
    _k_info = k_info if (add_k_info and two_pass) else ""

    if config.get("additional_config", {}).get("annotate_actions_on_image", False):
        annotation_info = prompts["annotation_info"]
    else:
        annotation_info = ""

    sys_prompt = safe_format(
        sys_prompt_critic_base,
        trace_info=_trace_info,
        summ_info=_summ_info,
        k_info=_k_info,
        rule=rule,
        annotation_info=annotation_info,
    )
    return sys_prompt.strip()

def get_prompts_first_pass(
    env: str,
    config,
) -> tuple[str, str]:
    """
    Return the retrieval query and system prompt used for the first pass.

    The query name is read from ``config["prompt_args"]["k_config"]["query"]``.
    When that name contains ``"expert"``, the system prompt comes from the
    environment's ``sys_prompts_expert`` table; otherwise the verifier system
    prompt is built in two-pass mode.

    Args:
        env (str): Key into ``prompts_per_env``.
        config: Run configuration; must provide ``prompt_args`` with
            ``k_config.query`` and, for non-expert queries, ``trace_info``,
            ``add_summ_info`` and ``add_expectation_info``.

    Returns:
        tuple[str, str]: ``(k_retrieval_query, sys_prompt)``.
    """
    prompt_config = config["prompt_args"]
    query: str = prompt_config["k_config"]["query"]

    env_prompts = prompts_per_env[env]
    # The retrieval query is selected the same way for both kinds of query.
    k_retrieval_query = env_prompts["k_retrieval_queries"][query]

    if "expert" in query:
        # Looked up only here so that non-expert query names are not required
        # to have an entry in the expert prompt table.
        _sys_prompt = env_prompts["sys_prompts_expert"][query]
    else:
        _sys_prompt = get_sys_prompt_verifier(
            env=env,
            trace_info=prompt_config["trace_info"],
            add_summ_info=prompt_config["add_summ_info"],
            add_k_info=prompt_config["add_expectation_info"],
            two_pass=True,
            # The verifier builder calls `.get` on both of these, so they must
            # be passed explicitly; their signature defaults are None.
            prompt_config=prompt_config,
            config=config,
        )
    return k_retrieval_query, _sys_prompt

def get_verifier_prompts(
    config: dict,
) -> dict[str, str]:
    """
    Build the verifier's evaluation prompt and system prompt from a config.

    A truthy ``prompt_args["k_config"]`` selects two-pass mode. The retrieval
    step is added to the evaluation prompt only when expectation info is
    requested and the run is not two-pass.

    Args:
        config (dict): Run configuration with ``env`` and ``prompt_args``
            (which must contain ``k_config``; the other entries are optional).

    Returns:
        dict[str, str]: ``{"eval_prompt": ..., "sys_prompt": ...}``.
    """
    env_name = config["env"]
    args = config["prompt_args"]

    # Presence of a (truthy) k_config is what marks the run as two-pass.
    is_two_pass = bool(args["k_config"])

    wants_expectation = args.get("add_expectation_info", False)
    criteria_key = args.get("eval_criteria", "")
    cot_key = args.get("cot_part", "")
    trace_key = args.get("trace_info", "")
    wants_summary = args.get("add_summ_info", False)

    retrieval_in_eval: bool = bool(wants_expectation and not is_two_pass)

    return {
        "eval_prompt": get_query_prompt_eval(
            env=env_name,
            eval_criteria=criteria_key,
            cot_part=cot_key,
            add_retrieval_step=retrieval_in_eval,
        ),
        "sys_prompt": get_sys_prompt_verifier(
            env=env_name,
            trace_info=trace_key,
            add_summ_info=wants_summary,
            add_k_info=wants_expectation,
            two_pass=is_two_pass,
            prompt_config=args,
            config=config,
        ),
    }


def get_k_injection(config: dict) -> str:
    """
    Look up the k-injection prompt selected by the config.

    Args:
        config (dict): Run configuration; ``env`` selects the environment's
            prompt table and ``prompt_args.k_config.query`` selects the entry
            in that table's ``k_injections``.

    Returns:
        str: The matching ``k_injections`` entry.
    """
    injections_for_env = prompts_per_env[config["env"]]["k_injections"]
    query_name = config["prompt_args"]["k_config"]["query"]
    return injections_for_env[query_name]
