import os, sys

# load data
import json
import argparse
import json, os
from tqdm import tqdm

from Baselines.LLM_Planner import LLM_Planner
from Baselines.MLDT import MLDT
from Baselines.PlanAndAct import Plan_And_Act
from Baselines.PlannerActor import PlannerActor
from Baselines.ProgPrompt import ProgPrompt
from Baselines.ReAct import ReAct
from Baselines.vlm import VLM
from SafeMind import SafeMind

from utils.skill_set import skill_set, skill_set_string
from eval_method import multitask_evaluation

def search_id_in_txt_file(file_path, search_id):
    """Return True if ``str(search_id)`` appears as a substring of any line in a file.

    Lines are read as-is, including their trailing newline, so a needle ending
    in ``"\\n"`` only matches lines that end with it. Scanning stops at the
    first matching line. A status message is printed either way.

    Args:
        file_path: Path of the text file to scan. It must exist; ``open`` raises
            otherwise.
        search_id: Value to look for; converted with ``str`` before matching.

    Returns:
        True if a line contains the needle, otherwise False.
    """
    needle = str(search_id)
    with open(file_path, "r") as handle:
        # any() short-circuits, so reading stops at the first hit.
        found = any(needle in line for line in handle)
    if not found:
        print(f"ID: {search_id} not found. Start evaluating...")
        return False
    print(f"Found ID: {search_id} in line")
    return True

def _build_agent(agent_type, vlm_model):
    """Instantiate the agent named by ``agent_type``.

    Raises:
        ValueError: if ``agent_type`` is not a supported agent name.
    """
    factories = {
        "SafeMind": SafeMind,
        "PlannerActor": PlannerActor,
        "ReAct": ReAct,
        "VLM": lambda: VLM(model=vlm_model),
        "MLDT": MLDT,
        "ProgPrompt": ProgPrompt,
        "Plan_And_Act": Plan_And_Act,
        "LLM_Planner": LLM_Planner,
    }
    try:
        factory = factories[agent_type]
    except KeyError:
        # Fail fast instead of hitting an unbound `agent` deep inside the loop.
        raise ValueError(
            f"Unknown agent_type {agent_type!r}; expected one of {sorted(factories)}"
        ) from None
    return factory()


def _run_agent(agent, agent_type, scene, objects, image, instruction):
    """Feed one instruction to ``agent`` and return the low-level actions it produces."""
    # ProgPrompt is handed the `skill_set` object itself; every other agent gets
    # `skill_set_string`. This mirrors the original per-agent calls.
    skills = skill_set if agent_type == "ProgPrompt" else skill_set_string
    agent.update_state(scene=scene, objects=objects, image_path=image,
                       instruction=instruction, skill_set_string=skills)

    if agent_type == "SafeMind":
        high_level_plan = agent.planner()
        print("high_level_plan:", high_level_plan)
        # Any result other than the literal "NULL" triggers a replan.
        if agent.check_plan_damage() != "NULL":
            high_level_plan = agent.replanner()
        actions = agent.actor(high_level_plan)
        print("low_level_plan:", actions)
        # The agent's own judgement decides whether to redo the planner, the actor, or neither.
        replan = agent.judge_safe_requirement().lower()
        if replan == 'planner':
            high_level_plan = agent.planner()
            actions = agent.actor(high_level_plan)
        elif replan == 'actor':
            actions = agent.actor(high_level_plan)
        return actions

    if agent_type == "PlannerActor":
        high_level_plan = agent.planner()
        print("high_level_plan:", high_level_plan)
        actions = agent.actor(high_level_plan)
    elif agent_type == "ReAct":
        # Token counts are returned as well but are not used by this script.
        actions, _prompt_tokens, _completion_tokens, _total_tokens = agent.generate_react_actions()
    elif agent_type == "MLDT":
        actions = agent.mldt_generate_plan()
    elif agent_type in ("ProgPrompt", "LLM_Planner"):
        actions = agent.run()
    elif agent_type == "Plan_And_Act":
        actions = agent.plan_and_act()
    elif agent_type == "VLM":
        actions = agent.actor()
    else:
        raise ValueError(f"Unknown agent_type {agent_type!r}")
    print("low_level_plan:", actions)
    return actions


def Order_Fix_test(data, agent_type, img_root, output_path=None, vlm_model="",
                   guidelines_path="../data_analysis/knowledge/guidelines_final_ids_3.txt"):
    """Run ``agent_type`` on every multitask entry in ``data`` and evaluate each plan.

    For every entry in ``data['multitask']``, each instruction is planned by the
    agent and then scored with ``multitask_evaluation``. All results for one image
    are appended to the result file as a single JSON object followed by ``",\\n"``.
    The file is therefore a comma-separated sequence of objects, not a single
    JSON document.

    Entries whose quoted image path already appears in the result file are
    skipped, so an interrupted run can be resumed.

    Args:
        data: Dataset dict. Each item of ``data['multitask']`` must have the keys
            ``'image'``, ``'scene'``, ``'objects'``, ``'observation'``,
            ``'inherent safety hazards'`` and ``'instructions'``.
        agent_type: One of SafeMind, PlannerActor, ReAct, VLM, MLDT, ProgPrompt,
            Plan_And_Act, LLM_Planner.
        img_root: Root directory; an image path is built as
            ``img_root/image/multitask/<entry['image']>``.
        output_path: Directory under which ``<agent_type>/multitask.json`` (or
            ``<agent_type>/<vlm_model>_multitask.json`` for VLM) is written. If
            None, results are neither resumed from nor persisted.
        vlm_model: Model name passed to ``VLM``. It is also used in the VLM
            result file name.
        guidelines_path: Text file listing instructions to skip. An instruction
            is skipped when ``instruction + "\\n"`` occurs in one of its lines.

    Raises:
        ValueError: if ``agent_type`` is not supported.
    """
    agent = _build_agent(agent_type, vlm_model)

    result_path = None
    if output_path is not None:
        file_name = f"{vlm_model}_multitask.json" if agent_type == "VLM" else "multitask.json"
        result_path = os.path.join(output_path, agent_type, file_name)
        # The append below fails if the per-agent directory does not exist yet.
        os.makedirs(os.path.dirname(result_path), exist_ok=True)

    for d in tqdm(data['multitask']):
        image = os.path.join(img_root, "image/multitask", d['image'])
        # Resume support. On the first run the result file does not exist yet,
        # so it must not be opened for searching.
        if (result_path is not None and os.path.exists(result_path)
                and search_id_in_txt_file(result_path, f'"{image}"')):
            continue

        scene = d['scene']
        objects = d['objects']
        observation = d['observation']
        hazard = d['inherent safety hazards']
        results = {image: []}
        for instruction in d["instructions"]:
            print("-----------task:", instruction, "----------")
            if search_id_in_txt_file(guidelines_path, instruction + "\n"):
                print("pass")
                continue
            actions = _run_agent(agent, agent_type, scene, objects, image, instruction)
            eval_result = multitask_evaluation(task=instruction, observation=observation, plans=actions,
                                               hazard=hazard, skill_set_string=skill_set_string)
            print("***********eval result**********")
            print("eval_result:", eval_result)
            results[image].append({
                "instruction": instruction,
                "output": actions,
                "evaluation": eval_result,
            })
        if result_path is not None:
            with open(result_path, "a") as f:
                json.dump(results, f)
                f.write(',')
                f.write('\n')

if __name__ == '__main__':
    # Command-line arguments.
    parser = argparse.ArgumentParser(
        description="Run a planning agent on the multitask split and evaluate its plans.")
    parser.add_argument("--data_root", type=str, default='Dataset',
                        help="dataset root containing data.json and image/multitask/")
    parser.add_argument("--output_dir", type=str, default='Result',
                        help="results are written under <output_dir>/<agent_type>/")
    parser.add_argument("--type", type=str, default='multitask',
                        help="task split (currently unused; only 'multitask' is evaluated)")
    # Restrict to known agents so a typo fails here instead of deep in the run.
    parser.add_argument("--agent_type", type=str, default='SafeMind',
                        choices=["LLM_Planner", "Plan_And_Act", "ProgPrompt", "MLDT",
                                 "VLM", "ReAct", "SafeMind", "PlannerActor"],
                        help="agent to evaluate")
    args = parser.parse_args()

    # Use a context manager so the handle is closed. RFC 8259 mandates UTF-8 for JSON text.
    with open(os.path.join(args.data_root, "data.json"), 'r', encoding="utf-8") as f:
        val_data = json.load(f)

    Order_Fix_test(val_data, args.agent_type, args.data_root, args.output_dir)




