from RAG import RAG, LLM, rag_chain, get_scene_from_instruction
from plan import schedule, print_arm
from DAG import parse_input, build_nodes, modify_nodes, problem1, problem2, problem3


def load_file(path):
    """Return the entire contents of the text file at *path*, decoded as UTF-8."""
    with open(path, encoding="utf-8") as handle:
        contents = handle.read()
    return contents


def load_environment(scene):
    """Return the short-term-memory environment description for *scene*.

    Only the four scene names listed below are recognised (exact, case-
    sensitive match). Any other value prints "Invalid scene name" and
    yields ``None``.
    """
    scene_files = {
        "kitchen scene": "./memory/short_term_memory/env_kitchen.txt",
        "office scene": "./memory/short_term_memory/env_office.txt",
        "agricultural greenhouse scene": "./memory/short_term_memory/env_agricultural_greenhouse.txt",
        "factory scene": "./memory/short_term_memory/env_factory.txt",
    }
    env_file = scene_files.get(scene)
    if env_file is None:
        print("Invalid scene name")
        return None
    return load_file(env_file)


def format_problems():
    """Render the DAG validation problem lists as prompt text.

    Reads the module-level ``problem1``, ``problem2`` and ``problem3``
    sequences. Each non-empty one becomes a line "<description>: a, b, ...".
    Returns the lines joined by newlines, or ``None`` when all three are empty.
    """
    labelled = (
        ("Depends on another object's place node", problem1),
        ("Does not depend on the tool usage node but directly "
         "depends on the pick node", problem2),
        ("Depends on another object's tool usage node", problem3),
    )
    lines = [
        f"{label}: {', '.join(str(item) for item in items)}"
        for label, items in labelled
        if items
    ]
    if not lines:
        return None
    return "\n".join(lines)


def retry_dag_correction(llm, base_response, max_attempts=2):
    """Ask the LLM to repair a DAG response that failed validation.

    Each attempt:
    1. Formats the currently reported problems with ``format_problems``.
    2. Sends them, with the most recent DAG text, through the
       ``DAG_prompt_correction.txt`` template.
    3. Re-parses the reply with ``parse_input`` / ``build_nodes``.
    4. Re-validates it with ``modify_nodes``.

    Args:
        llm: object exposing ``call(prompt)``, which returns the model's reply.
        base_response: DAG text from the first LLM call, which failed validation.
        max_attempts: maximum number of correction calls to make.

    Returns:
        The first corrected reply for which ``modify_nodes`` returns 0.
        Returns ``base_response`` unchanged when no problems are reported
        or every attempt fails.
    """
    # Labels for the overall LLM call count. base_response is the first call's
    # reply, so correction attempt 0 is the "second" call.
    ordinals = ("second", "third", "fourth", "fifth", "sixth")
    latest_response = base_response

    for attempt in range(max_attempts):
        # NOTE(review): format_problems reads problem1..3 imported from DAG.
        # They only reflect the latest modify_nodes run if DAG mutates those
        # lists in place rather than rebinding them. Confirm this.
        problems_section = format_problems()
        if not problems_section:
            break

        print(problems_section)
        # Show the LLM the DAG text the problems were reported for. After the
        # first attempt, that is the previous correction, not base_response.
        modify_prompt = load_file("./prompts/DAG_prompt_correction.txt").format(
            response=latest_response,
            problems_section=problems_section,
        )
        print(f"try again_{attempt + 1}")
        corrected_response = llm.call(modify_prompt)
        print(corrected_response)

        nodes_data = parse_input(corrected_response)
        nodes = build_nodes(nodes_data)
        res = modify_nodes(nodes)

        if res == 0:
            label = ordinals[attempt] if attempt < len(ordinals) else f"#{attempt + 2}"
            print(f"pass {label} call")
            return corrected_response

        latest_response = corrected_response

    return base_response


def main(instruction="Complete the kitchen scene package A", api_path="YOUR_API_KEY_PATH"):
    """Run the two-stage pipeline: LLM-driven DAG planning, then dual-arm scheduling.

    Args:
        instruction: natural-language task. The target scene is inferred from
            it via ``get_scene_from_instruction``.
        api_path: passed to ``LLM``. The default is a placeholder that must be
            replaced (or overridden by the caller) before a real run.

    Returns ``None`` early, without scheduling, when the inferred scene has no
    environment file (``load_environment`` returns ``None``).
    """
    # === Stage 1: LLM-Driven DAG Planning ===
    print("Stage 1 begin...")
    rag = RAG(document_path="./memory/long_term_memory/knowledge_set.txt", model_name="moka-ai/m3e-base")
    llm = LLM(api_path)
    rag_prompt_template = load_file("./prompts/RAG_prompt.txt")

    # Validate the scene before rag_chain. rag_chain is handed the LLM, so it
    # presumably issues a model request that would be wasted if we abort here.
    target_scene = get_scene_from_instruction(instruction, rag)
    target_env = load_environment(target_scene)
    if target_env is None:
        return

    rag_output = rag_chain(instruction, rag, llm, prompt_template=rag_prompt_template)

    dag_prompt = load_file("./prompts/DAG_prompt_first_call.txt").format(
        instruction=instruction,
        target_env=target_env,
        rag_output=rag_output,
    )
    response = llm.call(dag_prompt)

    # A modify_nodes result of 0 means the DAG passed its checks (the same
    # convention retry_dag_correction uses). Anything else triggers correction.
    nodes = build_nodes(parse_input(response))
    res = modify_nodes(nodes)
    dag_output = response
    if res != 0:
        dag_output = retry_dag_correction(llm, response)

    print("Stage 1 done.")

    # === Stage 2: Dual-Arm Parallel Scheduling ===
    print("Stage 2 begin...")
    # Rebuild nodes from the final DAG text, which may be a corrected reply.
    nodes = build_nodes(parse_input(dag_output))

    result = schedule(nodes)
    print(f"Total execution time: {result['total']} seconds\n")
    print_arm("Left arm schedule table:", result['left'])
    print_arm("Right arm schedule table:", result['right'])
    print("Stage 2 end. All finished.")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()