
import argparse

import json
import os

from tqdm import tqdm

from Baselines.LLM_Planner import LLM_Planner
from Baselines.MLDT import MLDT
from Baselines.PlanAndAct import Plan_And_Act
from Baselines.ProgPrompt import ProgPrompt
from Baselines.ReAct import ReAct
from Baselines.vlm import VLM
from SafeMind import SafeMind
from Baselines.PlannerActor import PlannerActor
from utils.skill_set import skill_set, skill_set_string
from utils.eval import safe_task_evaluation, unsafe_task_evaluation

# Text file consulted for every instruction pair: when the unsafe instruction
# followed by a newline occurs in any of its lines, the pair is skipped.
GUIDELINES_IDS_PATH = "../data_analysis/knowledge/guidelines_final_ids_2.txt"

# Agent name (as accepted by --agent_type / test_each_mss) -> agent class.
AGENT_CLASSES = {
    "SafeMind": SafeMind,
    "PlannerActor": PlannerActor,
    "ReAct": ReAct,
    "MLDT": MLDT,
    "Plan_And_Act": Plan_And_Act,
    "ProgPrompt": ProgPrompt,
    "LLM_Planner": LLM_Planner,
    "VLM": VLM,
}


def search_id_in_txt_file(file_path, search_id):
    """Return True if ``str(search_id)`` occurs as a substring of any line of *file_path*.

    Prints whether the id was found. Raises FileNotFoundError if *file_path*
    does not exist.
    """
    needle = str(search_id)
    with open(file_path, "r") as f:
        for line in f:
            if needle in line:
                print(f"Found ID: {search_id} in line")
                return True
    print(f"ID: {search_id} not found.Start evaluating...")
    return False


def _build_agent(agent_type, vlm_model=""):
    """Instantiate the agent named *agent_type*; raise ValueError if the name is unknown."""
    if agent_type not in AGENT_CLASSES:
        raise ValueError(
            f"Unknown agent_type {agent_type!r}; expected one of {sorted(AGENT_CLASSES)}"
        )
    if agent_type == "VLM":
        return VLM(model=vlm_model)
    return AGENT_CLASSES[agent_type]()


def _run_safemind(agent):
    """SafeMind pipeline: plan, optionally re-plan, act, then optionally redo one stage."""
    high_level_plan = agent.planner()
    print("high_level_plan:", high_level_plan)
    # Any check_plan_damage() result other than "NULL" triggers the replanner.
    if agent.check_plan_damage() != "NULL":
        high_level_plan = agent.replanner()
    output = agent.actor(high_level_plan)
    print("low_level_plan:", output)
    if output != ['warning']:
        # judge_safe_requirement() (case-insensitive) selects what to redo:
        # 'planner' redoes both stages, 'actor' only the low-level stage,
        # anything else keeps the current output.
        replan = agent.judge_safe_requirement().lower()
        if replan == 'planner':
            high_level_plan = agent.planner()
            output = agent.actor(high_level_plan)
        elif replan == 'actor':
            output = agent.actor(high_level_plan)
    return output


def _run_planner_actor(agent):
    """PlannerActor pipeline: one planner() call followed by one actor() call."""
    high_level_plan = agent.planner()
    print("high_level_plan:", high_level_plan)
    output = agent.actor(high_level_plan)
    print("low_level_plan:", output)
    return output


def _run_react(agent):
    """ReAct pipeline: keep only the actions returned by generate_react_actions()."""
    # The remaining three return values (token counts) are not used here.
    output, _prompt_tokens, _completion_tokens, _total_tokens = agent.generate_react_actions()
    print("low_level_plan:", output)
    return output


def _single_call(method_name):
    """Build a pipeline that calls ``agent.<method_name>()`` once and prints its result."""
    def run(agent):
        output = getattr(agent, method_name)()
        print("low_level_plan:", output)
        return output
    return run


# Agent name -> callable(agent) returning the low-level plan for the state
# most recently loaded via agent.update_state().
_PIPELINES = {
    "SafeMind": _run_safemind,
    "PlannerActor": _run_planner_actor,
    "ReAct": _run_react,
    "MLDT": _single_call("mldt_generate_plan"),
    "Plan_And_Act": _single_call("plan_and_act"),
    "ProgPrompt": _single_call("run"),
    "LLM_Planner": _single_call("run"),
    "VLM": _single_call("actor"),
}


def _run_task(agent, agent_type, scene, objects, image_path, instruction, observations):
    """Load one (scene, image, instruction) into *agent* and return its low-level plan."""
    print("-----------task:", instruction, observations, "----------")
    # ProgPrompt is passed skill_set; every other agent gets skill_set_string.
    skills = skill_set if agent_type == "ProgPrompt" else skill_set_string
    agent.update_state(scene=scene, objects=objects, image_path=image_path,
                       instruction=instruction, skill_set_string=skills)
    return _PIPELINES[agent_type](agent)


def _result_path(output_path, agent_type, vlm_model):
    """Return the results-file path for *agent_type*, or None when *output_path* is None."""
    if output_path is None:
        return None
    if agent_type == "VLM":
        return os.path.join(output_path, agent_type, vlm_model + "_output.json")
    return os.path.join(output_path, agent_type + "_output.json")


def test_each_mss(data, agent_type, img_root, output_path=None, vlm_model=""):
    """Run *agent_type* on every safe/unsafe instruction pair in *data* and evaluate it.

    Args:
        data: parsed dataset. ``data['embodied']`` is iterated; each item is read
            with the keys 'safe', 'unsafe', 'safe_objects', 'unsafe_objects',
            'observation_safe', 'observation_unsafe', 'scene',
            'Corrective_measures', 'safe_instructions' and 'unsafe_instructions'.
        agent_type: one of the keys of AGENT_CLASSES.
        img_root: directory whose ``image`` sub-folder holds the scene images.
        output_path: directory for the results file. None disables both
            resuming and writing results.
        vlm_model: model name passed to the VLM agent; also used in its
            results filename.

    Results are appended as one JSON object per image followed by ",\\n", so
    the file as a whole is not a single JSON document. Images whose safe-image
    path already appears in the results file are skipped, which lets an
    interrupted run resume.

    Raises:
        ValueError: if *agent_type* is not a known agent.
    """
    agent = _build_agent(agent_type, vlm_model)
    result_path = _result_path(output_path, agent_type, vlm_model)
    if result_path is not None:
        result_dir = os.path.dirname(result_path)
        if result_dir:
            os.makedirs(result_dir, exist_ok=True)

    for item in tqdm(data['embodied']):
        safe_image = os.path.join(img_root, "image", item['safe'])
        # Search for the path exactly as json.dump encodes it (quoted, with
        # backslashes and non-ASCII characters escaped) so resuming works for
        # any path, and tolerate a results file that does not exist yet.
        if (result_path is not None and os.path.exists(result_path)
                and search_id_in_txt_file(result_path, json.dumps(safe_image))):
            continue

        unsafe_image = os.path.join(img_root, "image", item['unsafe'])
        safe_objects = item['safe_objects']
        unsafe_objects = item['unsafe_objects']
        safe_observations = item['observation_safe']
        unsafe_observations = item['observation_unsafe']
        scene = item['scene']
        corrective_measures = item['Corrective_measures']
        results = {safe_image: []}

        for safe_instr, unsafe_instr in zip(item["safe_instructions"], item["unsafe_instructions"]):
            if search_id_in_txt_file(GUIDELINES_IDS_PATH, unsafe_instr + "\n"):
                print("pass")
                continue

            safe_output = _run_task(agent, agent_type, scene, safe_objects,
                                    safe_image, safe_instr, safe_observations)
            unsafe_output = _run_task(agent, agent_type, scene, unsafe_objects,
                                      unsafe_image, unsafe_instr, unsafe_observations)

            safe_eval_result = safe_task_evaluation(
                safe_instr=safe_instr, skill_set=skill_set_string,
                plans=safe_output, observation=safe_observations)
            unsafe_eval_result = unsafe_task_evaluation(
                task=unsafe_instr, plans=unsafe_output,
                observation=unsafe_observations,
                Corrective_measures=corrective_measures)
            print("***********eval result**********")
            print("safe_eval_result:", safe_eval_result)
            print("unsafe_eval_result:", unsafe_eval_result)
            results[safe_image].append({
                "safe_instruction": safe_instr,
                "unsafe_instruction": unsafe_instr,
                "safe_img": safe_image,
                "unsafe_img": unsafe_image,
                "safe_output": safe_output,
                "unsafe_output": unsafe_output,
                "safe_task_execution_evaluation": safe_eval_result,
                "unsafe_task_evaluation": unsafe_eval_result,
            })

        if result_path is not None:
            with open(result_path, "a") as f:
                json.dump(results, f)
                f.write(',\n')


def main(argv=None):
    """Parse arguments, load ``<data_root>/combined.json`` and run test_each_mss on it."""
    parser = argparse.ArgumentParser(
        description="Evaluate an embodied agent on safe/unsafe instruction pairs.")
    parser.add_argument("--agent_type", type=str, default='SafeMind',
                        choices=list(AGENT_CLASSES))
    parser.add_argument("--data_root", type=str, default='dataset',
                        help="directory containing combined.json and the image/ folder")
    parser.add_argument("--output_dir", type=str, default='result')
    parser.add_argument("--vlm_model", type=str, default="",
                        help="model name passed to the VLM agent")
    args = parser.parse_args(argv)

    with open(os.path.join(args.data_root, "combined.json"), "r", encoding="utf-8") as f:
        val_data = json.load(f)
    test_each_mss(val_data, args.agent_type, args.data_root, args.output_dir,
                  vlm_model=args.vlm_model)


# Placed after every definition so that json, os and test_each_mss are bound
# by the time the script body runs.
if __name__ == '__main__':
    main()

