from procoder.functional import *
from procoder.functional import indent4 as indent
from procoder.prompt import *
from toolemu.prompts.generator.tool_gen.shared import *
from toolemu.utils import get_num_tokens, make_colorful

# General task description wrapped in a NamedBlock titled "Task Description";
# it is used as the first section of ORIG_TOOL_GEN_PROMPT.
# NOTE(review): TOOL_GEN_GENERAL_TASK_DESC is not defined in this file;
# presumably it comes from the star import of `tool_gen.shared` -- confirm.
STD_TOOL_GEN_GENERAL_TASK_DESC = NamedBlock(
    "Task Description",
    TOOL_GEN_GENERAL_TASK_DESC,
)

# Tool-generation prompt as assembled before remove_risky_req is applied to it
# (see STD_TOOL_GEN_PROMPT). It is a Collection of five sections, configured
# with "\n\n" as the separator and `sharp2_indexing` as the indexing method.
# NOTE(review): the TOOL_GEN_* / TOOL_FORMAT sections are not defined in this
# file; presumably they come from the star import of `tool_gen.shared` --
# confirm.
ORIG_TOOL_GEN_PROMPT = (
    Collection(
        STD_TOOL_GEN_GENERAL_TASK_DESC,
        TOOL_GEN_REQUIREMENTS,
        TOOL_GEN_BLACKLIST,
        TOOL_FORMAT,
        TOOL_GEN_TASK_INSTRUCTION,
    )
    .set_sep("\n\n")
    .set_indexing_method(sharp2_indexing)
    # .enable_name()
)


def remove_risky_req(prompt):
    """Build a variant of `prompt` without its risk-related submodules.

    The "brainstorm_toolkit_step" submodule is looked up on the incoming
    prompt first. The listed risk submodules are then passed to
    `removed_submodules`, and finally the brainstorm step, with the literal
    text ", {risky_requirement}" replaced by an empty string, is put back
    in place of the "brainstorm_toolkit_step" submodule of the pruned prompt.

    Args:
        prompt: The prompt module to transform.

    Returns:
        The result of `replaced_submodule` on the pruned prompt.
    """

    def _drop_risky_placeholder(text):
        # Remove the reference to the risky requirement from the step text.
        return text.replace(", {risky_requirement}", "")

    # Looked up on the prompt as passed in, i.e. before anything is pruned.
    original_step = find_submodule(prompt, "brainstorm_toolkit_step")

    risk_submodule_names = [
        "risky_requirement",
        "toolkit_risks",
        "potential_risks",
        "brainstorm_tool_risks",
        "assess_risks",
        "no_unauthorized_access_risk",
    ]
    pruned_prompt = removed_submodules(prompt, risk_submodule_names)

    cleaned_step = replace_prompt(original_step, _drop_risky_placeholder)
    return replaced_submodule(
        pruned_prompt, "brainstorm_toolkit_step", cleaned_step
    )


# Standard prompt: ORIG_TOOL_GEN_PROMPT passed through remove_risky_req.
STD_TOOL_GEN_PROMPT = remove_risky_req(ORIG_TOOL_GEN_PROMPT)


# Two variants derived from the standard prompt.
# NOTE(review): apply_gen_fixed_toolkit_prompt and
# apply_gen_brainstormed_toolkit_prompt are not defined in this file;
# presumably "fixed" targets a caller-specified toolkit and "brainstormed"
# lets the model propose one -- confirm against `tool_gen.shared`.
STD_TOOL_GEN_FIXED_PROMPT = apply_gen_fixed_toolkit_prompt(STD_TOOL_GEN_PROMPT)
STD_TOOL_GEN_BRAINSTORMED_PROMPT = apply_gen_brainstormed_toolkit_prompt(
    STD_TOOL_GEN_PROMPT
)

# "Thoughts" counterparts: each variant above passed through
# remove_spec_format_instruction.
# NOTE(review): presumably this strips the output-format instruction so that
# only the reasoning is requested -- confirm against the helper's definition.
STD_TOOL_GEN_FIXED_THOUGHTS_PROMPT = remove_spec_format_instruction(
    STD_TOOL_GEN_FIXED_PROMPT
)
STD_TOOL_GEN_BRAINSTORMED_THOUGHTS_PROMPT = remove_spec_format_instruction(
    STD_TOOL_GEN_BRAINSTORMED_PROMPT
)


if __name__ == "__main__":
    # Manual preview: running this module as a script prints the formatted
    # system and human prompts for one pair of the variants defined above.
    from toolemu.utils import JSON_TYPES_STR

    def print_prompt(prompt):
        """Format `prompt` together with the system message and print both.

        Every input except ``json_types`` is passed as None. After the two
        colored prompts, the value of `get_num_tokens` for the human-side
        prompt is printed.
        """
        placeholder_values = dict(
            existing_tools=None,
            existing_domains=None,
            json_types=JSON_TYPES_STR,
            examples=None,
            toolkit_name=None,
            toolkit_desc=None,
        )
        formatted = format_multiple_prompts(
            [TOOL_GEN_SYSTEM_MESSAGE, prompt],
            placeholder_values,
            include_brackets=[False, True],
        )
        system_text, human_text = formatted
        print(make_colorful("system", system_text))
        print(make_colorful("human", human_text))
        print("\n\n>>>>Token lengths:", get_num_tokens(human_text))

    # Flip to False to preview the brainstormed-toolkit variants instead.
    gen_fixed_toolkit = True
    gen_prompt, gen_thought_prompt = (
        (STD_TOOL_GEN_FIXED_PROMPT, STD_TOOL_GEN_FIXED_THOUGHTS_PROMPT)
        if gen_fixed_toolkit
        else (
            STD_TOOL_GEN_BRAINSTORMED_PROMPT,
            STD_TOOL_GEN_BRAINSTORMED_THOUGHTS_PROMPT,
        )
    )

    previews = (
        ("\n\nSTD Tool Gen Prompt:", gen_prompt),
        ("\n\nSTD Tool Gen Thoughts Prompt:", gen_thought_prompt),
    )
    for heading, selected_prompt in previews:
        print(heading)
        print_prompt(selected_prompt)
