import json
import os
import copy

from mctextworld.action import ActionLibrary
from mctextworld.memories import DecomposedMemory, HypothesizedRecipeGraph
# from mctextworld.models.deepseek_vl_planning import PlanningModel as DeepSeekPlanningModel
from mctextworld.models.qwen_2_5_planning import PlanningModel as QwenPlanningModel

class Agent:
    """Planner front-end that turns a waypoint into a TextWorld action name.

    Every ``make_plan*`` entry point returns a 4-tuple
    ``(wp, action, language_action_str, error_message)``. On success
    ``error_message`` is ``None``; when the planner keeps failing, ``action``
    and ``language_action_str`` are ``""`` and ``error_message`` is an error
    code such as ``"max_tries_get_decomposed_plan"``.

    The variants differ only in which memories condition the planner:
        make_plan                      -- success memory, similar successes, failed subgoals
        make_plan_with_reflection      -- success memory, similar successes, stored reflections
        make_plan_wo_succ              -- failed subgoals only (no success memory)
        make_plan_wo_fail              -- success memory, similar successes (no failures)
        make_plan_wo_succ_fail_memory  -- no memory at all
    """

    # How many times the planner is queried before a waypoint is given up on.
    MAX_PLANNER_RETRIES = 3

    def __init__(self, cfg, logger, wp_to_sg_memory: DecomposedMemory):
        """Store memories/config and load the Qwen2.5-VL planning model.

        Args:
            cfg: Config mapping; reads ``cfg["memory"]["topK"]``, ``cfg["prefix"]``
                and ``cfg["device_id"]``.
            logger: Logger used for progress and error reporting.
            wp_to_sg_memory: Waypoint -> subgoal memory consulted before planning.
        """
        self.wp_to_sg_memory = wp_to_sg_memory
        self.logger = logger
        self.topK = cfg["memory"]["topK"]
        self.prefix = cfg["prefix"]
        # Alternative backend; it was used together with models/img_for_deepseek.jpg:
        # self.plan_model = DeepSeekPlanningModel("deepseek-ai/deepseek-vl-7b-chat")
        self.plan_model = QwenPlanningModel("Qwen/Qwen2.5-VL-7B-Instruct", device_id=cfg["device_id"])

    def reflect_on_failure(self, item_name, language_action_str, inventory_before_action):
        """Ask the planner why an action failed and return its analysis text.

        Returns:
            The ``"analysis"`` field of the parsed reflection, or ``""`` when the
            model output cannot be parsed or has no ``"analysis"`` field.
        """
        raw_reflection = self.plan_model.reflect_on_failure(
            item_name, language_action_str, inventory_before_action
        )
        reflection, error_message = render_reflection(raw_reflection)
        if error_message is not None:
            self.logger.warning(f"render_reflection failed. Error message: {error_message}")
            return ""

        # A parsed object without "analysis" used to raise KeyError; treat it
        # like any other unusable reflection instead.
        analysis = reflection.get("analysis")
        if analysis is None:
            self.logger.warning("render_reflection failed. Error message: missing 'analysis' field")
            return ""
        return analysis

    def _plan_from_success_memory(self, wp):
        """Return a plan tuple reused from a past success on ``wp``, or None on a miss."""
        is_succeeded, sg_str = self.wp_to_sg_memory.is_succeeded_waypoint(wp)
        if not is_succeeded:
            self.logger.info(f"No success experience for waypoint: {wp}, so, call planner to generate a plan.")
            return None

        # sg_str looks like '{"task": "...", "goal": ["...", 1]}'
        sg, language_action_str, _ = render_subgoal(sg_str, 1)
        return wp, convert_subgoal_to_textworld(sg), language_action_str, None

    @staticmethod
    def _finish_plan(wp, sg, language_action_str, error_message):
        """Turn a ``call_planner*`` result into the public 4-tuple plan."""
        if error_message is not None:
            return wp, "", "", error_message
        return wp, convert_subgoal_to_textworld(sg), language_action_str, None

    def make_plan_with_reflection(self, wp):
        """Plan for ``wp`` from success memory, else via the planner given stored reflections."""
        remembered = self._plan_from_success_memory(wp)
        if remembered is not None:
            return remembered

        similar_wp_sg_dict = self.wp_to_sg_memory.retrieve_similar_succeeded_waypoints(wp, self.topK)
        reflection_list = self.wp_to_sg_memory.retrieve_all_reflections(wp)
        result = self.call_planner_with_reflection_with_retry(wp, 1, similar_wp_sg_dict, reflection_list)
        return self._finish_plan(wp, *result)

    def make_plan(self, wp):
        """Plan for ``wp`` from success memory, else via the planner given similar successes and failures."""
        self.logger.info(f"In make_plan. prefix: {self.prefix}")
        remembered = self._plan_from_success_memory(wp)
        if remembered is not None:
            return remembered

        similar_wp_sg_dict = self.wp_to_sg_memory.retrieve_similar_succeeded_waypoints(wp, self.topK)
        failed_sg_list = self.wp_to_sg_memory.retrieve_failed_subgoals(wp)  # may be an empty list
        result = self.call_planner_with_retry(wp, 1, similar_wp_sg_dict, failed_sg_list)
        return self._finish_plan(wp, *result)

    def make_plan_wo_succ(self, wp):
        """Ablation: always call the planner, conditioned only on failed subgoals."""
        self.logger.info(f"In make_plan_wo_succ. prefix: {self.prefix}")
        self.logger.info("make_plan_wo_succ. so, call planner to generate a plan.")
        failed_sg_list = self.wp_to_sg_memory.retrieve_failed_subgoals(wp)  # may be an empty list
        result = self.call_planner_with_retry(wp, 1, dict(), failed_sg_list)
        return self._finish_plan(wp, *result)

    def make_plan_wo_fail(self, wp):
        """Ablation: use success memory and similar successes, but no failure history."""
        self.logger.info(f"In make_plan_wo_fail. prefix: {self.prefix}")
        remembered = self._plan_from_success_memory(wp)
        if remembered is not None:
            return remembered

        similar_wp_sg_dict = self.wp_to_sg_memory.retrieve_similar_succeeded_waypoints(wp, self.topK)
        result = self.call_planner_with_retry(wp, 1, similar_wp_sg_dict, [])
        return self._finish_plan(wp, *result)

    def make_plan_wo_succ_fail_memory(self, wp):
        """Ablation: always call the planner with no memory context at all."""
        self.logger.info(f"In make_plan_wo_succ_fail_memory. prefix: {self.prefix}")
        self.logger.info("make_plan_wo_succ_fail_memory. so, call planner to generate a plan.")
        result = self.call_planner_with_retry(wp, 1, dict(), [])
        return self._finish_plan(wp, *result)

    def _query_planner_with_retry(self, query_fn, wp_num, name):
        """Call ``query_fn`` until its output parses as a subgoal.

        Args:
            query_fn: Zero-argument callable returning ``(raw_text, prompt)``.
            wp_num: Goal count written into the parsed subgoal.
            name: Label used in log messages and in the ``max_tries_<name>`` error code.

        Returns:
            ``(subgoal_dict, language_action_str, None)`` on success, or
            ``("", "", "max_tries_<name>")`` after ``MAX_PLANNER_RETRIES`` failures.
        """
        for attempt in range(1, self.MAX_PLANNER_RETRIES + 1):
            try:
                sg_str, _ = query_fn()
                subgoal, language_action_str, render_error = render_subgoal(sg_str, wp_num)
            except Exception as e:  # retry boundary: any model/parsing failure is retried
                self.logger.warning(f"Error in {name}: {e}")
                continue

            if render_error is None:
                return subgoal, language_action_str, None
            self.logger.warning(f"{name} at attempt {attempt} failed. Error message: {render_error}")

        self.logger.error(f"Max retries reached. Could not fetch {name}.")
        return "", "", f"max_tries_{name}"

    def call_planner_with_retry(self, wp, wp_num, similar_wp_sg_dict, failed_sg_list):
        """Query ``plan_model.decomposed_plan`` with retries; see ``_query_planner_with_retry``."""
        return self._query_planner_with_retry(
            lambda: self.plan_model.decomposed_plan(
                waypoint=wp,
                images=None,
                similar_wp_sg_dict=similar_wp_sg_dict,
                failed_sg_list_for_wp=failed_sg_list,
            ),
            wp_num,
            "get_decomposed_plan",
        )

    def call_planner_with_reflection_with_retry(self, wp, wp_num, similar_wp_sg_dict, reflection_list):
        """Query ``plan_model.plan_with_reflection`` with retries; see ``_query_planner_with_retry``."""
        return self._query_planner_with_retry(
            lambda: self.plan_model.plan_with_reflection(
                waypoint=wp,
                images=None,
                similar_wp_sg_dict=similar_wp_sg_dict,
                reflection_list=reflection_list,
            ),
            wp_num,
            "plan_with_reflection",
        )


def convert_subgoal_to_textworld(subgoal):
    """Translate an Optimus-1 style subgoal dict into a TextWorld action name.

    Accepted subgoals look like:
        {"task": "craft wooden pickaxe", "goal": ["wooden_pickaxe", 1]}
        {"task": "mine cobblestone", "goal": ["cobblestone", 1]}

    Args:
        subgoal: Mapping with a ``"task"`` string and a ``"goal"`` sequence
            whose first element is the target item name.

    Returns:
        An action string: ``"<verb>_<item>"`` with spaces replaced by
        underscores, a few special-cased names, or the task itself
        (underscored) when no known verb appears in it.
    """
    task = subgoal['task']
    item = subgoal['goal'][0]

    # Optimus-1 calls every log type just "log", so logs and planks are
    # mapped to their oak variants here.
    if "log" in item and any(verb in task for verb in ("mine", "chop")):
        return "mine_oak_log"
    if "planks" in item and "craft" in task:
        return "craft_oak_planks"
    if 'mine' in task and item in ('coal', 'coals'):
        return 'mine_coal'

    if 'craft' in task:
        verb = 'craft'
    elif 'smelt' in task:
        verb = 'smelt'
    elif any(word in task for word in ('mine', 'break', 'gather', 'collect', 'dig')):
        verb = 'mine'
    else:
        return task.replace(" ", "_")
    return f"{verb}_{item}".replace(" ", "_")


def render_subgoal(subgoal: str, wp_num: int = 1) -> tuple:
    """Extract a subgoal JSON object from raw planner output.

    The function takes the text after the last ``<task planning>`` tag, or the
    whole text when the tag is absent. A trailing ``:`` on the tag and ``**``
    bold markers are tolerated. If a ```` ```json ```` fence is present, only
    its contents are kept. Doubled braces ``{{``/``}}`` are collapsed, and the
    span from the first ``{`` to the last ``}`` is decoded.

    Args:
        subgoal: Raw text produced by the planning model (or stored in memory).
        wp_num: Value written into ``goal[1]`` of the parsed subgoal.

    Returns:
        ``(subgoal_dict, task_str, None)`` on success, or
        ``(None, None, error_message)`` when no well-formed subgoal
        (an object with ``"task"`` and a ``"goal"`` list of >= 2 items) is found.
    """
    cleaned = subgoal.replace("<task planning>:", "<task planning>").replace("**", "")
    body = cleaned.split("<task planning>")[-1].strip()
    if "```json" in body:
        body = body.split("```json")[1].strip().split("```")[0].strip()
    if "{{" in body:
        body = body.replace("{{", "{").replace("}}", "}")

    start = body.find("{")
    end = body.rfind("}")
    if start == -1 or end < start:
        # Without this guard a missing "{" would make the slice start at -1.
        return None, None, "no JSON object found in planner output"

    try:
        parsed = json.loads(body[start:end + 1])
    except json.JSONDecodeError as e:
        return None, None, str(e)

    # Valid JSON with the wrong shape used to raise KeyError/TypeError/IndexError;
    # report it through the same error channel as a decode failure.
    goal = parsed.get("goal")
    if not isinstance(goal, list) or len(goal) < 2:
        return None, None, f"malformed subgoal 'goal' field: {goal!r}"
    if "task" not in parsed:
        return None, None, "malformed subgoal: missing 'task' field"

    goal[1] = wp_num
    return parsed, parsed["task"], None

def render_reflection(reflection: str) -> tuple:
    """Extract the failure-analysis JSON object from raw reflection output.

    The function takes the text after the last ``<failure_analysis>`` tag, or
    the whole text when the tag is absent. A trailing ``:`` on the tag and
    ``**`` bold markers are tolerated. If a ```` ```json ```` fence is present,
    only its contents are kept. Doubled braces ``{{``/``}}`` are collapsed,
    and the span from the first ``{`` to the last ``}`` is decoded.

    Args:
        reflection: Raw text produced by the planning model.

    Returns:
        ``(reflection_dict, None)`` on success, or ``(None, error_message)``
        when no JSON object can be decoded.
    """
    cleaned = reflection.replace("<failure_analysis>:", "<failure_analysis>").replace("**", "")
    body = cleaned.split("<failure_analysis>")[-1].strip()
    if "```json" in body:
        body = body.split("```json")[1].strip().split("```")[0].strip()
    if "{{" in body:
        body = body.replace("{{", "{").replace("}}", "}")

    start = body.find("{")
    end = body.rfind("}")
    if start == -1 or end < start:
        # Without this guard a missing "{" would make the slice start at -1.
        return None, "no JSON object found in reflection output"

    try:
        return json.loads(body[start:end + 1]), None
    except json.JSONDecodeError as e:
        return None, str(e)


# def render_plan(plan: str, wp_num: int = 1):
#     plan = plan.replace("<task planning>:", "<task planning>").replace("**", "")
#     plan = plan.replace("<replan>:", "<replan>").replace("<replan>", "<task planning>")
#     sep_str = "<task planning>"

#     temp = plan.split(sep_str)[-1].strip()
#     if "```json" in temp:
#         temp = temp.split("```json")[1].strip().split("```")[0].strip()

#     if "{{" in temp:
#         temp = temp.replace("{{", "{").replace("}}", "}")

#     r = temp.rfind("}")
#     temp = temp[: r + 1]

#     try:
#         temp = json.loads(temp)
#     except json.JSONDecodeError as e:
#         return [], str(e)

#     sub_plans = [
#         temp[step]
#         for step in temp.keys()
#         if "open" not in temp[step]["task"]
#         and "place" not in temp[step]["task"]
#         and "access" not in temp[step]["task"]
#     ]

#     for p in sub_plans:
#         p["task"] = p["task"].replace("punch", "chop").replace("collect", "chop").replace("gather", "chop")
#         p['goal'][1] = wp_num

#     return sub_plans, None
