import copy
import json
import logging
import os

from omegaconf import DictConfig, OmegaConf

from ours.util import (
    render_subgoal,
)

from ours.memories import DecomposedMemory
from ours.memories import KnowledgeGraph as OracleGraph
from ours.qwen_vl_planning import PlanningModel as QwenVLPlanningModel


def call_planner_with_retry(
    plan_model: QwenVLPlanningModel,
    wp: str,
    wp_num: int,
    similar_wp_sg_dict: dict,
    failed_sg_list: list,
    logger: logging.Logger,
    max_retries: int = 3,
):
    """Ask the planner to decompose ``wp`` into subgoals, retrying on failure.

    Each attempt calls ``plan_model.decomposed_plan`` and validates the raw
    subgoal string with ``render_subgoal``. The first attempt whose output
    renders without error is returned; an attempt that raises, or whose
    output fails to render, is retried until ``max_retries`` is exhausted.

    Args:
        plan_model: Planner that produces the raw subgoal string.
        wp: Waypoint to plan for.
        wp_num: Required quantity of ``wp``; forwarded to ``render_subgoal``.
        similar_wp_sg_dict: Successful plans for similar waypoints, passed to
            the planner (may be empty).
        failed_sg_list: Previously failed subgoals for ``wp``, passed to the
            planner (may be empty).
        logger: Logger for progress and failure messages.
        max_retries: Maximum number of planner calls before giving up.

    Returns:
        ``(subgoal, language_action_str, None)`` on success, otherwise
        ``([], "", "max_tries_get_decomposed_plan")``.
    """
    for attempts in range(1, max_retries + 1):
        logger.info(f"Attempt: {attempts}, Just before get_decomposed_plan: ")
        logger.info(f"waypoint: {wp}")
        logger.info(f"similar_wp_sg_dict: {json.dumps(similar_wp_sg_dict)}")
        logger.info(f"failed_sg_list: {str(failed_sg_list)}")
        logger.info("Starting get_decomposed_plan ...\n")

        try:
            sg_str, prompt = plan_model.decomposed_plan(
                waypoint=wp,
                image_path=None,
                similar_wp_sg_dict=similar_wp_sg_dict,
                failed_sg_list_for_wp=failed_sg_list,
            )

            logger.info(f'prompt before render_subgoal at attempt {attempts}')
            logger.info(f"{prompt}\n")
            logger.info(f'sg_str before render_subgoal at attempt {attempts}')
            logger.info(f"{sg_str}\n")

            # Rendering stays inside the try: malformed LLM output that makes
            # render_subgoal raise is treated as a retryable failure. The deep
            # copy guards against render_subgoal mutating its argument.
            subgoal, language_action_str, render_error = render_subgoal(
                copy.deepcopy(sg_str), wp_num
            )
        except Exception as e:
            # Retry, but surface the failure (with traceback) instead of
            # burying it at INFO level.
            logger.warning(f"Error in get_decomposed_plan: {e}", exc_info=True)
            continue

        if render_error is None:
            # Reuse the validated render instead of rendering sg_str twice.
            return subgoal, language_action_str, None

        logger.warning(f"get_decomposed_plan at attempt {attempts} failed. Error message: {render_error}")

    logger.error("Max retries reached. Could not fetch get_decomposed_plan.")
    return [], "", "max_tries_get_decomposed_plan"



def call_planner_with_reflection(
    plan_model: QwenVLPlanningModel,
    wp: str,
    wp_num: int,
    similar_wp_sg_dict: dict,
    reflection_list: list,
    logger: logging.Logger,
    max_retries: int = 3,
):
    """Ask the planner for a reflection-informed plan for ``wp``, retrying on failure.

    Each attempt calls ``plan_model.plan_with_reflection`` and validates the
    raw subgoal string with ``render_subgoal``. The first attempt whose output
    renders without error is returned; an attempt that raises, or whose
    output fails to render, is retried until ``max_retries`` is exhausted.

    Args:
        plan_model: Planner that produces the raw subgoal string.
        wp: Waypoint to plan for.
        wp_num: Required quantity of ``wp``; forwarded to ``render_subgoal``.
        similar_wp_sg_dict: Successful plans for similar waypoints, passed to
            the planner (may be empty).
        reflection_list: Reflections for ``wp``, passed to the planner.
        logger: Logger for progress and failure messages.
        max_retries: Maximum number of planner calls before giving up.

    Returns:
        ``(subgoal, language_action_str, None)`` on success, otherwise
        ``([], "", "max_tries_get_decomposed_plan")``.
    """
    for attempts in range(1, max_retries + 1):
        logger.info(f"Attempt: {attempts}, Just before get_decomposed_plan: ")
        logger.info(f"waypoint: {wp}")
        logger.info(f"similar_wp_sg_dict: {json.dumps(similar_wp_sg_dict)}")
        logger.info(f"reflection_list: {str(reflection_list)}")
        logger.info("Starting get_decomposed_plan ...\n")

        try:
            sg_str, prompt = plan_model.plan_with_reflection(
                waypoint=wp,
                image_path=None,
                similar_wp_sg_dict=similar_wp_sg_dict,
                reflection_list=reflection_list,
            )

            logger.info(f'prompt before render_subgoal at attempt {attempts}')
            logger.info(f"{prompt}\n")
            logger.info(f'sg_str before render_subgoal at attempt {attempts}')
            logger.info(f"{sg_str}\n")

            # Rendering stays inside the try: malformed LLM output that makes
            # render_subgoal raise is treated as a retryable failure. The deep
            # copy guards against render_subgoal mutating its argument.
            subgoal, language_action_str, render_error = render_subgoal(
                copy.deepcopy(sg_str), wp_num
            )
        except Exception as e:
            # Retry, but surface the failure (with traceback) instead of
            # burying it at INFO level.
            logger.warning(f"Error in get_decomposed_plan: {e}", exc_info=True)
            continue

        if render_error is None:
            # Reuse the validated render instead of rendering sg_str twice.
            return subgoal, language_action_str, None

        logger.warning(f"get_decomposed_plan at attempt {attempts} failed. Error message: {render_error}")

    logger.error("Max retries reached. Could not fetch get_decomposed_plan.")
    return [], "", "max_tries_get_decomposed_plan"



def retrieve_waypoints(
    waypoint_generator: OracleGraph,
    item: str,
    number: int = 1,
    cur_inventory: "dict | None" = None,
) -> str:
    """Return the pretty-printed waypoint plan for obtaining ``number`` of ``item``.

    ``item`` is normalised before lookup: lower-cased, spaces replaced with
    underscores, and ``"logs"`` rewritten to ``"log"``. The inventory passed to
    ``waypoint_generator.compile`` is a deep copy of ``cur_inventory`` with any
    entry for ``item`` removed. The caller's dict is therefore never modified.

    Args:
        waypoint_generator: Graph whose ``compile`` builds the plan.
        item: Target item name, e.g. ``"Oak Logs"``.
        number: Quantity of ``item`` to plan for.
        cur_inventory: Current inventory. ``None`` (the default) means empty.

    Returns:
        The first element of the 4-tuple returned by
        ``waypoint_generator.compile``.
    """
    item = item.lower().replace(" ", "_").replace("logs", "log")

    # Work on a private copy so neither the pop below nor compile() can touch
    # the caller's inventory. The target item itself is excluded from it.
    # NOTE(review): presumably so units already held do not shorten the plan.
    inventory = {} if cur_inventory is None else copy.deepcopy(cur_inventory)
    inventory.pop(item, None)

    pretty_result, _, _, _ = waypoint_generator.compile(item, number, inventory)
    return pretty_result


def subgoal_ours_with_full_memory(
    plan_model: QwenVLPlanningModel,
    original_final_goal: str,
    inventory: dict,
    action_memory: DecomposedMemory,
    waypoint_generator: OracleGraph,
    topK: int,
    logger: logging.Logger,
):
    """Produce the next subgoal using both success and failure memory.

    The first waypoint for ``original_final_goal`` is parsed from the output of
    ``retrieve_waypoints``. If ``action_memory`` holds a successful plan for
    it, that plan is rendered and returned. Otherwise the planner is called
    through ``call_planner_with_retry``. It receives up to ``topK`` similar
    succeeded waypoints and every failed subgoal recorded for this waypoint.

    Returns:
        ``(wp, subgoal, language_action_str, error_message)``. ``error_message``
        is ``None`` for a replayed plan, otherwise whatever
        ``call_planner_with_retry`` reported.
    """
    waypoint_listing = retrieve_waypoints(waypoint_generator, original_final_goal, 1, inventory)
    logger.info(f"Retrieved waypoint list: {waypoint_listing}")

    # Line 0 is the 'craft 1 <goal> summary:' header; line 1 holds the first
    # waypoint. Only the segment between its first and second '.' is parsed.
    entry = waypoint_listing.splitlines()[1].split('.')[1]

    raw_name = entry.split(':')[0].strip()
    # Any waypoint name containing 'log' is normalised to 'logs'.
    wp = 'logs' if 'log' in raw_name else raw_name
    wp_num = int(entry.split('need')[1].strip())

    found_success, stored_plan = action_memory.is_succeeded_waypoint(wp)

    logger.info("In subgoal_ours_with_full_memory")
    logger.info(f"waypoint: {wp}, waypoint_num: {wp_num}")
    logger.info(f"is_succeeded: {found_success!s}")

    if found_success:
        rendered, action_text, _ = render_subgoal(stored_plan, wp_num)
        return wp, rendered, action_text, None

    logger.info(f"No success experience for waypoint: {wp}, so, call planner to generate a plan.")

    similar_successes = action_memory.retrieve_similar_succeeded_waypoints(wp, topK)
    past_failures = action_memory.retrieve_failed_subgoals(wp)  # may be []

    planned, action_text, planner_error = call_planner_with_retry(
        plan_model, wp, wp_num, similar_successes, past_failures, logger
    )
    return wp, planned, action_text, planner_error


def subgoal_only_with_reuse_no_positive_reference_with_reflection(
    plan_model: QwenVLPlanningModel,
    original_final_goal: str,
    inventory: dict,
    action_memory: DecomposedMemory,
    waypoint_generator: OracleGraph,
    topK: int,
    logger: logging.Logger,
):
    """Produce the next subgoal, reusing successes and planning with reflections.

    The first waypoint for ``original_final_goal`` is parsed from the output of
    ``retrieve_waypoints``. If ``action_memory`` already holds a successful
    plan for it, that plan is rendered and returned. Otherwise the planner is
    called through ``call_planner_with_reflection``. It receives no
    similar-waypoint examples (no positive reference) and every reflection
    stored for this waypoint.

    Args:
        topK: Unused; accepted so every strategy shares one signature.

    Returns:
        ``(wp, subgoal, language_action_str, error_message)``. ``error_message``
        is ``None`` for a replayed plan, otherwise whatever
        ``call_planner_with_reflection`` reported.
    """
    wp_list_str = retrieve_waypoints(waypoint_generator, original_final_goal, 1, inventory)
    logger.info(f"Retrieved waypoint list: {wp_list_str}")
    first_wp_str = wp_list_str.splitlines()[1]  # 0th line is 'craft 1 <goal> summary:'

    wp = first_wp_str.split('.')[1].split(':')[0].strip()
    if 'log' in wp:
        wp = 'logs'

    wp_num = int(first_wp_str.split('.')[1].split('need')[1].strip())

    is_succeeded, sg_str = action_memory.is_succeeded_waypoint(wp)

    # Log the actual function name (was copy-pasted from the non-reflection variant).
    logger.info("In subgoal_only_with_reuse_no_positive_reference_with_reflection")
    logger.info(f"waypoint: {wp}, waypoint_num: {wp_num}")
    logger.info(f"is_succeeded: {str(is_succeeded)}")

    # Reuse: a waypoint with a recorded success replays that plan directly.
    if is_succeeded:
        subgoal, language_action_str, _ = render_subgoal(sg_str, wp_num)
        return wp, subgoal, language_action_str, None

    logger.info(f"No success experience for waypoint: {wp}, so, call planner to generate a plan.")

    # No positive reference: similar succeeded waypoints are deliberately not
    # passed. The reflections stored for this waypoint are the only memory
    # forwarded to the planner.
    reflection_list = action_memory.retrieve_all_reflections(wp)

    subgoal, language_action_str, error_message = call_planner_with_reflection(
        plan_model, wp, wp_num, {}, reflection_list, logger
    )
    return wp, subgoal, language_action_str, error_message


# DECKARD, ADAM will use this function
def subgoal_only_with_reuse_no_positive_reference(
    plan_model: QwenVLPlanningModel,
    original_final_goal: str,
    inventory: dict,
    action_memory: DecomposedMemory,
    waypoint_generator: OracleGraph,
    topK: int,
    logger: logging.Logger,
):
    """Produce the next subgoal, reusing successes but planning without memory context.

    The first waypoint for ``original_final_goal`` is parsed from the output of
    ``retrieve_waypoints``. A waypoint with a recorded success replays that
    plan. Any other waypoint is planned by ``call_planner_with_retry`` with no
    similar-success examples and no failed-subgoal history.

    ``topK`` is accepted for signature parity with the other strategies but is
    not used.

    Returns:
        ``(wp, subgoal, language_action_str, error_message)``.
    """
    listing = retrieve_waypoints(waypoint_generator, original_final_goal, 1, inventory)
    logger.info(f"Retrieved waypoint list: {listing}")

    # Skip the 'craft 1 <goal> summary:' header line and parse the segment of
    # the first waypoint line between its first and second '.'.
    entry_body = listing.splitlines()[1].split('.')[1]
    name = entry_body.split(':')[0].strip()
    wp = 'logs' if 'log' in name else name
    wp_num = int(entry_body.split('need')[1].strip())

    has_success, remembered_sg = action_memory.is_succeeded_waypoint(wp)

    logger.info("In subgoal_only_with_reuse_no_positive_reference")
    logger.info(f"waypoint: {wp}, waypoint_num: {wp_num}")
    logger.info(f"is_succeeded: {has_success!s}")

    if not has_success:
        logger.info(f"No success experience for waypoint: {wp}, so, call planner to generate a plan.")
        # Deliberately empty context: neither similar successes nor failures.
        planned, action_text, planner_error = call_planner_with_retry(
            plan_model, wp, wp_num, {}, [], logger
        )
        return wp, planned, action_text, planner_error

    # A recorded success is replayed directly.
    replayed, action_text, _ = render_subgoal(remembered_sg, wp_num)
    return wp, replayed, action_text, None


def subgoal_wo_fail(
    plan_model: QwenVLPlanningModel,
    original_final_goal: str,
    inventory: dict,
    action_memory: DecomposedMemory,
    waypoint_generator: OracleGraph,
    topK: int,
    logger: logging.Logger,
):
    """Produce the next subgoal using success memory only (failure memory ablated).

    The first waypoint for ``original_final_goal`` is parsed from the output of
    ``retrieve_waypoints``. A waypoint with a recorded success replays that
    plan. Otherwise ``call_planner_with_retry`` is called with up to ``topK``
    similar succeeded waypoints and an empty failed-subgoal list.

    Returns:
        ``(wp, subgoal, language_action_str, error_message)``.
    """
    lines = retrieve_waypoints(waypoint_generator, original_final_goal, 1, inventory)
    logger.info(f"Retrieved waypoint list: {lines}")

    # lines[0] is the 'craft 1 <goal> summary:' header.
    segment = lines.splitlines()[1].split('.')[1]
    candidate = segment.split(':')[0].strip()
    if 'log' in candidate:
        candidate = 'logs'
    wp = candidate
    wp_num = int(segment.split('need')[1].strip())

    succeeded_before, saved_plan = action_memory.is_succeeded_waypoint(wp)

    logger.info("In subgoal_wo_fail")
    logger.info(f"waypoint: {wp}, waypoint_num: {wp_num}")
    logger.info(f"is_succeeded: {succeeded_before!s}")

    if succeeded_before:
        out_subgoal, out_text, _ = render_subgoal(saved_plan, wp_num)
        return wp, out_subgoal, out_text, None

    logger.info(f"No success experience for waypoint: {wp}, so, call planner to generate a plan.")

    out_subgoal, out_text, out_error = call_planner_with_retry(
        plan_model,
        wp,
        wp_num,
        action_memory.retrieve_similar_succeeded_waypoints(wp, topK),
        [],  # failure memory intentionally not consulted
        logger,
    )
    return wp, out_subgoal, out_text, out_error


def subgoal_wo_succ(
    plan_model: QwenVLPlanningModel,
    original_final_goal: str,
    inventory: dict,
    action_memory: DecomposedMemory,
    waypoint_generator: OracleGraph,
    topK: int,
    logger: logging.Logger,
):
    """Produce the next subgoal using failure memory only (success memory ablated).

    The first waypoint for ``original_final_goal`` is parsed from the output of
    ``retrieve_waypoints``. Stored successes are never replayed. The planner is
    always called, with no similar-success examples and the failed subgoals
    recorded for this waypoint. ``topK`` is unused.

    Returns:
        ``(wp, subgoal, language_action_str, error_message)``.
    """
    summary = retrieve_waypoints(waypoint_generator, original_final_goal, 1, inventory)
    logger.info(f"Retrieved waypoint list: {summary}")

    # Skip the 'craft 1 <goal> summary:' header; parse the first waypoint line.
    head = summary.splitlines()[1].split('.')[1]
    label = head.split(':')[0].strip()
    wp = 'logs' if 'log' in label else label
    wp_num = int(head.split('need')[1].strip())

    logger.info("In subgoal_wo_succ")
    logger.info(f"waypoint: {wp}, waypoint_num: {wp_num}")

    known_failures = action_memory.retrieve_failed_subgoals(wp)  # may be []

    result_subgoal, result_text, result_error = call_planner_with_retry(
        plan_model, wp, wp_num, {}, known_failures, logger
    )
    return wp, result_subgoal, result_text, result_error


def subgoal_wo_succ_fail_memory(
    plan_model: QwenVLPlanningModel,
    original_final_goal: str,
    inventory: dict,
    action_memory: DecomposedMemory,
    waypoint_generator: OracleGraph,
    topK: int,
    logger: logging.Logger,
):
    """Produce the next subgoal with no memory at all (pure planner baseline).

    The first waypoint for ``original_final_goal`` is parsed from the output of
    ``retrieve_waypoints``. The planner is then called with empty
    similar-success and failed-subgoal context. ``action_memory`` and ``topK``
    are accepted for signature parity but not used.

    Returns:
        ``(wp, subgoal, language_action_str, error_message)``.
    """
    text = retrieve_waypoints(waypoint_generator, original_final_goal, 1, inventory)
    logger.info(f"Retrieved waypoint list: {text}")

    # Index 0 is the 'craft 1 <goal> summary:' header; index 1 is the first waypoint.
    part = text.splitlines()[1].split('.')[1]
    target = part.split(':')[0].strip()
    wp = target if 'log' not in target else 'logs'
    wp_num = int(part.split('need')[1].strip())

    logger.info("In subgoal_wo_succ_fail_memory")
    logger.info(f"waypoint: {wp}, waypoint_num: {wp_num}")

    *plan, planner_error = call_planner_with_retry(
        plan_model, wp, wp_num, {}, [], logger
    )
    subgoal, language_action_str = plan
    return wp, subgoal, language_action_str, planner_error


def make_one_subgoal(
    plan_model: QwenVLPlanningModel,
    original_final_goal: str,
    inventory: dict,
    action_memory: DecomposedMemory,
    waypoint_generator: OracleGraph,
    topK: int,
    logger: logging.Logger,
    cfg: DictConfig,
):
    """Dispatch to the subgoal strategy selected by ``cfg["prefix"]``.

    The prefix is matched by substring against the keyword groups below, in
    order, and the first group with any keyword present wins. More specific
    keywords must therefore precede the keywords they contain (e.g.
    ``"wo_succ_fail"`` before ``"wo_succ"``).

    Returns:
        The ``(wp, subgoal, language_action_str, error_message)`` tuple from
        the chosen strategy.

    Raises:
        ValueError: if no keyword group matches the prefix.
    """
    prefix = cfg["prefix"]
    logger.info(f"[yellow]In make_one_subgoal(), prefix: {prefix}[/yellow]")

    # Ordered (keywords, strategy) pairs; evaluation order is significant.
    strategy_table = (
        (("wo_succ_fail", "pure_llm"), subgoal_wo_succ_fail_memory),
        (("wo_succ",), subgoal_wo_succ),
        (("wo_fail",), subgoal_wo_fail),
        (("deckard", "adam"), subgoal_only_with_reuse_no_positive_reference),
        (("self_correction",), subgoal_only_with_reuse_no_positive_reference_with_reflection),
        (("ours", "full"), subgoal_ours_with_full_memory),
    )

    for keywords, strategy in strategy_table:
        if not any(keyword in prefix for keyword in keywords):
            continue
        wp, subgoal, language_action_str, error_message = strategy(
            plan_model,
            original_final_goal,
            inventory,
            action_memory,
            waypoint_generator,
            topK,
            logger,
        )
        return wp, subgoal, language_action_str, error_message

    logger.error(f"Unknown prefix: {prefix}")
    raise ValueError(f"Unknown prefix: {prefix}")

