import argparse
import copy
import math
import json
import logging
import os
import re
import shutil
import time
from typing import Any, Dict
import sys
import random
import numpy as np
import torch

import hydra
import shortuuid
from omegaconf import DictConfig, OmegaConf
from rich.progress import Progress, TaskID, TimeElapsedColumn
import wandb

from mctextworld.simulator import *
from mctextworld.action import ActionLibrary
from mctextworld.memories.decomposed_memory import DecomposedMemory
from mctextworld.memories.hypothesized_recipe_graph import HypothesizedRecipeGraph
from mctextworld.agent import Agent
from mctextworld.utils import get_logger, change_textworld_item_name, is_belief_correct, create_new_action_lib_json, get_action_lib_patch


def retrieve_first_waypoint(
    waypoint_generator: HypothesizedRecipeGraph,
    item: str,
    number: int = 1,
    cur_inventory: dict = None
) -> str:
    """Return the first waypoint on the compiled plan towards ``item``.

    Args:
        waypoint_generator: Object whose ``compile(item, number, inventory)``
            returns a 4-tuple; only the third element (the ordered item list)
            is used here.
        item: Target item name. It is lower-cased and spaces are replaced by
            underscores before being passed to ``compile``.
        number: Quantity of ``item`` forwarded to ``compile``.
        cur_inventory: Current inventory mapping. ``None`` (the default) is
            treated as an empty inventory. The mapping is never mutated.

    Returns:
        The first entry of the ordered item list returned by ``compile``.

    Raises:
        IndexError: If ``compile`` returns an empty ordered item list.
    """
    item = item.lower().replace(" ", "_")

    # Work on a private copy so the caller's inventory is left untouched.
    inventory = {} if cur_inventory is None else copy.deepcopy(cur_inventory)
    # The target itself is removed from the inventory handed to compile(), so
    # an already-owned copy of the target is not counted by it.
    inventory.pop(item, None)

    _, _, ordered_item, _ = waypoint_generator.compile(item, number, inventory)
    return ordered_item[0]


def feasibility_min_count_all(hypothesized_recipe_graph: HypothesizedRecipeGraph, logger, cfg):
    """Choose an intrinsic goal among all hypothesized items.

    Only the items with the smallest exploration count are considered. Each
    of them is scored as ``knowledge / level`` and the highest score wins;
    on a tie the item that comes first in the exploration-count mapping is
    kept.

    If the winner's exploration count is greater than 1 and
    ``cfg["prefix"]`` does not contain ``"wo_recipe_revision"``, the graph's
    ``update_hypothesis`` is called for it.

    Args:
        hypothesized_recipe_graph: Graph supplying counts, knowledge scores
            and levels for the hypothesized items.
        logger: Unused by this function.
        cfg: Mapping providing ``"prefix"``.

    Returns:
        The name of the selected item.
    """
    graph = hypothesized_recipe_graph
    counts = graph.get_exploration_count_all_hypothesized()
    knowledge_by_item = graph.calculate_knowledge_all_hypothesized()
    level_by_item = graph.calculate_level_all_hypothesized()

    fewest = min(counts.values())
    scores = {
        name: knowledge_by_item[name] / level_by_item[name]
        for name, count in counts.items()
        if count == fewest
    }

    # Strict comparison: the earliest item with the top score is retained.
    chosen, top_score = None, float('-inf')
    for name, score in scores.items():
        if score > top_score:
            chosen, top_score = name, score

    prefix = cfg["prefix"]
    should_revise = counts[chosen] > 1 and "wo_recipe_revision" not in prefix
    if should_revise:
        graph.update_hypothesis(chosen)
    return chosen


def feasibility_min_count_frontier(hypothesized_recipe_graph: HypothesizedRecipeGraph, logger, cfg):
    """Choose an intrinsic goal among the least-explored frontier items.

    The frontier items with the smallest exploration count are scored
    according to ``cfg['feasibility']``:

    * ``'w_knowledge'``: score is ``knowledge / level``; the highest score
      wins and, on a tie, the item listed first is kept.
    * ``'wo_knowledge'``: score is ``1 / level``; one of the top-scoring
      items is drawn with ``random.choice``.

    Afterwards the graph's ``update_hypothesis`` is called for the winner
    when its exploration count is greater than 1 and ``cfg["prefix"]``
    contains neither ``"wo_recipe_revision"`` nor ``"wo_guided_revision"``.

    Args:
        hypothesized_recipe_graph: Graph supplying frontier items, counts,
            knowledge scores and levels.
        logger: Logger used to report why no hypothesis update happened.
        cfg: Mapping providing ``'feasibility'`` and ``'prefix'``.

    Returns:
        The name of the selected frontier item.

    Raises:
        ValueError: If ``cfg['feasibility']`` is not one of the two supported
            modes (previously this surfaced as an ``UnboundLocalError``), or
            from ``min()`` if there are no frontier items.
    """
    mode = cfg['feasibility']
    if mode not in ('w_knowledge', 'wo_knowledge'):
        raise ValueError(
            f"Unsupported cfg['feasibility'] value: {mode!r}; "
            "expected 'w_knowledge' or 'wo_knowledge'"
        )

    frontier_item_names = hypothesized_recipe_graph.find_frontiers()
    exploration_count_dict = hypothesized_recipe_graph.get_exploration_count_all_hypothesized()
    knowledge_score_dict = hypothesized_recipe_graph.calculate_knowledge_all_hypothesized()
    level_dict = hypothesized_recipe_graph.calculate_level_all_hypothesized()

    # Restrict the three per-item tables to the frontier.
    frontier_counts = {name: exploration_count_dict[name] for name in frontier_item_names}
    frontier_knowledge = {name: knowledge_score_dict[name] for name in frontier_item_names}
    frontier_levels = {name: level_dict[name] for name in frontier_item_names}

    min_count = min(frontier_counts.values())
    candidates = [name for name in frontier_item_names if frontier_counts[name] == min_count]

    if mode == 'w_knowledge':
        scores = {name: frontier_knowledge[name] / frontier_levels[name] for name in candidates}
        # Strict comparison: the earliest item with the top score is retained.
        best_feasible_item, max_score = None, float('-inf')
        for name, score in scores.items():
            if score > max_score:
                best_feasible_item, max_score = name, score
    else:
        scores = {name: 1 / frontier_levels[name] for name in candidates}
        max_score = max(scores.values())
        top_items = [name for name, score in scores.items() if score == max_score]
        best_feasible_item = random.choice(top_items)

    prefix = cfg["prefix"]
    revision_disabled = "wo_recipe_revision" in prefix or "wo_guided_revision" in prefix
    if revision_disabled:
        logger.info(f"prefix: {prefix} is not supported. Hypothesis is not updated when selecting intrinsic goal")
    elif frontier_counts[best_feasible_item] > 1:
        hypothesized_recipe_graph.update_hypothesis(best_feasible_item)
    else:
        # Not a prefix problem: the exploration count is simply not above 1,
        # so the "not supported" message would be misleading here.
        logger.info(
            f"Exploration count of {best_feasible_item} is {frontier_counts[best_feasible_item]}. "
            "Hypothesis is not updated when selecting intrinsic goal"
        )
    return best_feasible_item


def random_min_count_all_hypothesis(hypothesized_recipe_graph: HypothesizedRecipeGraph, logger, cfg):
    """Draw a random item among the hypothesized items explored the least.

    Args:
        hypothesized_recipe_graph: Graph supplying the exploration counts of
            all hypothesized items.
        logger: Unused by this function.
        cfg: Unused by this function.

    Returns:
        One item, picked with ``random.choice``, whose exploration count
        equals the minimum count.
    """
    counts = hypothesized_recipe_graph.get_exploration_count_all_hypothesized()
    fewest = min(counts.values())
    least_explored = [name for name, count in counts.items() if count == fewest]
    return random.choice(least_explored)


def random_among_frontier_and_verified(hypothesized_recipe_graph: HypothesizedRecipeGraph, logger, cfg):
    """Draw a random admissible frontier item, falling back to frontier + verified.

    A frontier item is admissible when its frontier exploration count does
    not exceed ``cfg["memory"]["inadmissible_threshold"]``. When no frontier
    item is admissible, the pool becomes the de-duplicated concatenation of
    the frontier items and the graph's verified items.

    The fallback pool keeps first-occurrence order instead of going through a
    ``set``: the iteration order of a set of strings can differ between
    interpreter runs, which made the seeded ``random.choice`` below
    non-reproducible.

    Args:
        hypothesized_recipe_graph: Graph supplying frontier items, their
            exploration counts and the verified item names.
        logger: Logger receiving the candidate lists.
        cfg: Mapping providing ``["memory"]["inadmissible_threshold"]``.

    Returns:
        One item picked with ``random.choice`` from the pool.
    """
    frontier_item_names = hypothesized_recipe_graph.find_frontiers()
    frontier_exploration_count_dict = hypothesized_recipe_graph.get_exploration_count_all_frontiers()

    # c_0 in the paper "Embodied Decision Making using Language Guided World Modelling"
    inadmissible_threshold = cfg["memory"]["inadmissible_threshold"]
    admissible_item_names = [
        item for item in frontier_item_names
        if frontier_exploration_count_dict[item] <= inadmissible_threshold
    ]

    if not admissible_item_names:
        logger.info("No admissible items. Select from all frontiers + verified.")
        combined = frontier_item_names + hypothesized_recipe_graph.get_verified_item_names()
        # dict.fromkeys removes duplicates while preserving order.
        admissible_item_names = list(dict.fromkeys(combined))

    logger.info('\nIn random_among_frontier_and_verified()')
    logger.info(f'frontier_item_names: {frontier_item_names}')
    logger.info(f'admissible_item_names: {admissible_item_names}')

    return random.choice(admissible_item_names)

def ours_like_deckard(hypothesized_recipe_graph: HypothesizedRecipeGraph, logger, cfg):
    """Draw a random admissible frontier item and revise its hypothesis if needed.

    A frontier item is admissible when its exploration count does not exceed
    ``cfg["memory"]["inadmissible_threshold"]``. When none is admissible, the
    pool becomes the de-duplicated frontier.

    The fallback pool keeps first-occurrence order instead of going through a
    ``set``: the iteration order of a set of strings can differ between
    interpreter runs, which made the seeded ``random.choice`` below
    non-reproducible.

    If the chosen item's exploration count is greater than 1, the graph's
    ``update_hypothesis`` is called for it.

    Args:
        hypothesized_recipe_graph: Graph supplying frontier items and the
            exploration counts of hypothesized items.
        logger: Logger receiving the candidate lists.
        cfg: Mapping providing ``["memory"]["inadmissible_threshold"]``.

    Returns:
        The selected item name.
    """
    frontier_item_names = hypothesized_recipe_graph.find_frontiers()
    exploration_count_dict = hypothesized_recipe_graph.get_exploration_count_all_hypothesized()

    frontier_exploration_count_dict = {
        item_name: exploration_count_dict[item_name] for item_name in frontier_item_names
    }

    # c_0 in the paper "Embodied Decision Making using Language Guided World Modelling"
    inadmissible_threshold = cfg["memory"]["inadmissible_threshold"]
    admissible_item_names = [
        item for item in frontier_item_names
        if frontier_exploration_count_dict[item] <= inadmissible_threshold
    ]

    if not admissible_item_names:
        logger.info("No admissible items. Select from all frontiers.")
        # dict.fromkeys removes duplicates while preserving order.
        admissible_item_names = list(dict.fromkeys(frontier_item_names))

    logger.info('\nIn ours_like_deckard()')
    logger.info(f'frontier_item_names: {frontier_item_names}')
    logger.info(f'admissible_item_names: {admissible_item_names}')

    selected_int_goal = random.choice(admissible_item_names)

    if exploration_count_dict[selected_int_goal] > 1:
        hypothesized_recipe_graph.update_hypothesis(selected_int_goal)
    return selected_int_goal

def ours_like_deckard_min_random(hypothesized_recipe_graph: HypothesizedRecipeGraph, logger, cfg):
    """Draw a random least-explored hypothesized item and revise it if needed.

    The candidates are all hypothesized items whose exploration count equals
    the minimum. After one is drawn with ``random.choice``, the graph's
    ``update_hypothesis`` is called for it when its count is greater than 1.

    Args:
        hypothesized_recipe_graph: Graph supplying the exploration counts.
        logger: Logger receiving the candidate list.
        cfg: Unused by this function.

    Returns:
        The selected item name.
    """
    counts = hypothesized_recipe_graph.get_exploration_count_all_hypothesized()
    fewest = min(counts.values())
    candidates = [name for name in counts if counts[name] == fewest]

    logger.info('\nIn ours_like_deckard_min_random()')
    logger.info(f'candidate_item_names: {candidates}')

    choice = random.choice(candidates)
    already_explored = counts[choice] > 1
    if already_explored:
        hypothesized_recipe_graph.update_hypothesis(choice)
    return choice

def ours_goal_wo_feasibility(hypothesized_recipe_graph: HypothesizedRecipeGraph, logger, cfg):
    """Draw a random least-explored frontier item, ignoring feasibility scores.

    The candidates are the frontier items whose exploration count equals the
    minimum over the frontier. After one is drawn with ``random.choice``, the
    graph's ``update_hypothesis`` is called for it when its count is greater
    than 1.

    Knowledge scores and levels play no part in the selection, so this
    function no longer builds per-frontier copies of them and no longer
    fails when they lack an entry for a frontier item.

    Args:
        hypothesized_recipe_graph: Graph supplying frontier items and the
            exploration counts of hypothesized items.
        logger: Logger receiving the chosen goal.
        cfg: Unused by this function.

    Returns:
        The selected frontier item name.

    Raises:
        ValueError: From ``min()`` if there are no frontier items.
    """
    frontier_item_names = hypothesized_recipe_graph.find_frontiers()
    exploration_count_dict = hypothesized_recipe_graph.get_exploration_count_all_hypothesized()
    # NOTE(review): the results of these two calls are not used here. The
    # calls are kept in case they update state inside the graph; drop them
    # once confirmed to be side-effect free.
    hypothesized_recipe_graph.calculate_knowledge_all_hypothesized()
    hypothesized_recipe_graph.calculate_level_all_hypothesized()

    frontier_counts = {name: exploration_count_dict[name] for name in frontier_item_names}

    min_count = min(frontier_counts.values())
    least_explored = [name for name in frontier_item_names if frontier_counts[name] == min_count]

    chosen = random.choice(least_explored)
    logger.info(f"in ours_goal_wo_feasibility(). intrinsic goal: {chosen}")

    if frontier_counts[chosen] > 1:
        hypothesized_recipe_graph.update_hypothesis(chosen)
    return chosen


def ours_goal_wo_count(hypothesized_recipe_graph: HypothesizedRecipeGraph, logger, cfg):
    """Choose the frontier item with the best ``knowledge / level`` score.

    Unlike the count-based selectors, every frontier item competes here; the
    exploration count is not used to narrow the candidates. On a tie the
    frontier item listed first is kept. The graph's ``update_hypothesis`` is
    called for the winner when its exploration count is greater than 1.

    Args:
        hypothesized_recipe_graph: Graph supplying frontier items, counts,
            knowledge scores and levels.
        logger: Logger receiving the chosen goal.
        cfg: Unused by this function.

    Returns:
        The selected frontier item name.
    """
    graph = hypothesized_recipe_graph
    frontier = graph.find_frontiers()
    all_counts = graph.get_exploration_count_all_hypothesized()
    all_knowledge = graph.calculate_knowledge_all_hypothesized()
    all_levels = graph.calculate_level_all_hypothesized()

    frontier_counts = {name: all_counts[name] for name in frontier}
    scores = {name: all_knowledge[name] / all_levels[name] for name in frontier}

    # Strict comparison: the earliest item with the top score is retained.
    winner, top_score = None, float('-inf')
    for name, score in scores.items():
        if score > top_score:
            winner, top_score = name, score

    logger.info(f"in ours_goal_wo_count(). intrinsic goal: {winner}")

    if frontier_counts[winner] > 1:
        graph.update_hypothesis(winner)
    return winner


def select_int_goal(hypothesized_recipe_graph: HypothesizedRecipeGraph, logger, cfg):
    """Pick the next intrinsic goal with the strategy named in ``cfg["prefix"]``.

    The strategy is found by substring matching on the prefix, and the first
    matching branch wins. The branch order is therefore significant: more
    specific markers (e.g. ``"like_deckard_min_random"``) must stay ahead of
    the more general ones they contain (``"like_deckard"``, ``"deckard"``,
    ``"ours"``).

    The two uniform-random branches de-duplicate the hypothesized item names
    while preserving their order. Going through a ``set`` made the result of
    the seeded ``random.choice`` depend on the per-process hash ordering of
    strings, which broke reproducibility.

    Args:
        hypothesized_recipe_graph: Graph passed on to the selected strategy.
        logger: Logger receiving the chosen strategy.
        cfg: Mapping providing ``"prefix"``; also forwarded to the strategy.

    Returns:
        The selected intrinsic goal.

    Raises:
        NotImplementedError: If the prefix requests
            ``feasibility_divided_by_count``.
        SystemExit: With exit code 1 if no strategy matches the prefix.
    """
    prefix = cfg["prefix"]

    # TODO: make various intrinsic goal algorithms
    # NOTE(review): no branch below dispatches to ours_goal_wo_count(); a
    # prefix such as "ours_goal_wo_count" lands in the generic "ours" branch.
    # Confirm this is intended.
    if "feasibility_divided_by_count" in prefix:
        logger.info(f"prefix: {prefix}. Select intrinsic goal by feasibility_divided_by_count")
        raise NotImplementedError
    elif "ours_with_reflection" in prefix or "ours_wo_fail" in prefix:
        logger.info(f"prefix: {prefix}. Select intrinsic goal by ours_with_reflection or ours_wo_fail")
        int_goal = feasibility_min_count_frontier(hypothesized_recipe_graph, logger, cfg)
    elif "wo_feasibility" in prefix:
        logger.info(f"prefix: {prefix}. Select intrinsic goal by ours_goal_wo_feasibility")
        int_goal = ours_goal_wo_feasibility(hypothesized_recipe_graph, logger, cfg)
    elif "like_deckard_min_random" in prefix:
        logger.info(f"prefix: {prefix}. Ours but select intrinsic goal like deckard and in min random")
        int_goal = ours_like_deckard_min_random(hypothesized_recipe_graph, logger, cfg)
    elif "like_deckard" in prefix:
        logger.info(f"prefix: {prefix}. Ours but select intrinsic goal like deckard")
        int_goal = ours_like_deckard(hypothesized_recipe_graph, logger, cfg)
    elif "ours_goal_uniform_random" in prefix:
        logger.info(f"prefix: {prefix}. Select intrinsic goal by ours_goal_uniform_random")
        # Order-preserving de-duplication keeps the seeded draw reproducible.
        candidates = list(dict.fromkeys(hypothesized_recipe_graph.hypothesized_item_names))
        int_goal = random.choice(candidates)
    elif "feasibility_min_count_frontier" in prefix or "ours" in prefix:
        logger.info(f"prefix: {prefix}. Select intrinsic goal by feasibility_min_count_frontier")
        int_goal = feasibility_min_count_frontier(hypothesized_recipe_graph, logger, cfg)
    elif "feasibility_min_count_all" in prefix:
        logger.info(f"prefix: {prefix}. Select intrinsic goal by feasibility_min_count_all")
        int_goal = feasibility_min_count_all(hypothesized_recipe_graph, logger, cfg)
    elif "deckard" in prefix:
        logger.info(f"prefix: {prefix}. Select intrinsic goal by random_among_frontier")
        int_goal = random_among_frontier_and_verified(hypothesized_recipe_graph, logger, cfg)
    elif "adam" in prefix or "pure_llm" in prefix or "optimus" in prefix:
        logger.info(f"prefix: {prefix}. Select intrinsic goal randomly from all hypothesized items")
        # Order-preserving de-duplication keeps the seeded draw reproducible.
        candidates = list(dict.fromkeys(hypothesized_recipe_graph.hypothesized_item_names))
        int_goal = random.choice(candidates)
    elif "random_min_count_all_hypothesis" in prefix:
        logger.info(f"prefix: {prefix}. Select intrinsic goal by random_min_count_all_hypothesis")
        int_goal = random_min_count_all_hypothesis(hypothesized_recipe_graph, logger, cfg)
    else:
        logger.error(f"prefix: {prefix} is not supported.")
        sys.exit(1)

    return int_goal


def check_enough_all_crafting_resources(obs, hypothesized_recipe_graph, min_quantity=8):
    """Return the under-stocked crafting resource with the shortest path.

    Each entry of ``hypothesized_recipe_graph.crafting_resources`` is mapped
    to an inventory item name (``"coals"`` -> ``"coal"``, ``"planks"`` ->
    ``"oak_planks"``, ``"logs"`` -> ``"oak_log"``; other names are used as
    they are). A resource is under-stocked when that item is absent from
    ``obs['inventory']`` or held in a quantity below ``min_quantity``.

    Among the under-stocked resources, only the one whose
    ``_calculate_path(item_name)`` result is strictly shortest is kept; on a
    tie the resource listed first wins. The path is computed once per
    under-stocked resource.

    Args:
        obs: Observation mapping providing ``'inventory'``.
        hypothesized_recipe_graph: Graph providing ``crafting_resources`` and
            ``_calculate_path``.
        min_quantity: Quantity below which a resource counts as
            under-stocked. Defaults to 8, the previously hard-coded value.

    Returns:
        A list holding the selected resource name (as it appears in
        ``crafting_resources``), or an empty list when every resource is
        sufficiently stocked.
    """
    inventory_names = {"coals": "coal", "planks": "oak_planks", "logs": "oak_log"}
    inventory = obs['inventory']

    not_enough_resources = []
    minimum_path = math.inf

    for resource in hypothesized_recipe_graph.crafting_resources:
        item_name = inventory_names.get(resource, resource)

        under_stocked = item_name not in inventory or inventory[item_name] < min_quantity
        if not under_stocked:
            continue

        path_length = len(hypothesized_recipe_graph._calculate_path(item_name))
        if path_length < minimum_path:
            minimum_path = path_length
            not_enough_resources = [resource]

    return not_enough_resources


def get_verified_benchmark_items(verified_items, benchmark_goal_items):
    """Count the verified items that correspond to a benchmark goal item.

    Despite the name, this returns a count rather than the items. Before the
    membership test each verified name is normalised: any name containing
    ``"log"`` becomes ``"logs"``, otherwise any name containing ``"planks"``
    becomes ``"planks"``, and ``"coal"``/``"coals"`` become ``"coals"``.
    Every other name is compared unchanged.

    Args:
        verified_items: Iterable of verified item names.
        benchmark_goal_items: Container of benchmark goal item names.

    Returns:
        The number of entries of ``verified_items`` whose normalised name is
        in ``benchmark_goal_items``; repeated entries are counted each time.
    """
    def _benchmark_name(name):
        # The "log" check comes first, so it takes precedence over "planks".
        if "log" in name:
            return "logs"
        if "planks" in name:
            return "planks"
        if name in ("coal", "coals"):
            return "coals"
        return name

    return sum(
        1 for name in verified_items
        if _benchmark_name(name) in benchmark_goal_items
    )


def reflect_on_failure(item_name, language_action_str, inventory_before_action,
                       subgoal_memory: DecomposedMemory, agent: Agent,
                       plan_failure_threshold, logger):
    """Ask the agent for a failure reflection once an action has failed enough.

    A reflection is requested only when the recorded failure count for
    ``(item_name, language_action_str)`` has reached
    ``plan_failure_threshold`` and the history holds no reflection yet
    (missing, ``None`` or empty). The new reflection is logged and stored
    through ``subgoal_memory.save_reflection``.

    Args:
        item_name: Item the failed action was meant to obtain.
        language_action_str: The action that failed.
        inventory_before_action: Inventory passed to the agent and stored
            with the reflection.
        subgoal_memory: Memory providing ``get_history_of_action`` and
            ``save_reflection``.
        agent: Agent providing ``reflect_on_failure``.
        plan_failure_threshold: Failure count from which a reflection is
            requested.
        logger: Logger receiving the reflection details.

    Returns:
        A dict with keys ``"item_name"``, ``"inventory"``, ``"plan"`` and
        ``"failure_analysis"`` when a new reflection was produced, otherwise
        ``None``.
    """
    history = subgoal_memory.get_history_of_action(item_name, language_action_str)
    if history['failure'] < plan_failure_threshold:
        return None

    # An existing, non-empty reflection is reused; do not ask again.
    previous = history.get("reflection", dict())
    if previous is not None and len(previous) != 0:
        return None

    # Call LLM to get reflection
    analysis = agent.reflect_on_failure(item_name, language_action_str, inventory_before_action)
    for message in (
        f'''\nReflection on failure of {item_name}''',
        f'Action: {language_action_str}',
        f'Inventory: {inventory_before_action}',
        f'Reflection: {analysis}\n',
    ):
        logger.info(message)
    subgoal_memory.save_reflection(item_name, language_action_str, inventory_before_action, analysis)

    return {
        "item_name": item_name,
        "inventory": inventory_before_action,
        "plan": language_action_str,
        "failure_analysis": analysis,
    }

@hydra.main(version_base=None, config_path="conf", config_name="evaluate")
def main(cfg: DictConfig):
    maximum_step = cfg["maximum_step"]
    prefix = cfg["prefix"]
    hydra_path = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir

    # Open-ended exploration task
    # There is no goal, and the agent should explore recipes of items as many as possible.
    env = Env(
        task_name="open_ended_exploration",
        init_inv=dict(),
        task_obj=None,
        maximum_step=maximum_step,
        # action_lib=cfg["action_lib"],
    )

    obs, reward, done, info = env.reset()
    if 'new' in cfg["action_lib"] and 'level' in cfg["action_lib"]:
        perturbed_action_lib = create_new_action_lib_json(env, cfg["action_lib"])
        env.set_action_lib(perturbed_action_lib)
        obs, reward, done, info = env.reset()

    logger = get_logger(__name__)

    src_dir = cfg["memory"]["root_path"] + "_" + cfg["initial_memory"]
    dst_dir = os.path.join(hydra_path, "v1")
    logger.info(f"Copy {src_dir} to {dst_dir}")
    shutil.copytree(src_dir, dst_dir)
    cfg["memory"]["path"] = dst_dir

    # Set random seed
    seed = cfg["seed"]

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    if "guided_revision" in prefix:
        os.environ["WANDB_CONSOLE"] = "auto"

    is_dynamic_ground_truth = False
    if 0 < cfg["action_lib_change_step"] and cfg["action_lib_change_step"] < cfg["maximum_step"]:
        is_dynamic_ground_truth = True


    plan_failure_threshold = cfg["memory"]["plan_failure_threshold"]
    prefix = cfg.get("prefix")
    logger.info(f"wandb.run.dir: {wandb.run.dir}")
    logger.info(f"prefix: {prefix}")
    logger.info(f'cfg["action_lib"]: {cfg["action_lib"]}\n')

    subgoal_memory = DecomposedMemory(cfg, logger)
    agent = Agent(cfg, logger, subgoal_memory)
    hypothesized_recipe_graph = HypothesizedRecipeGraph(agent.plan_model, cfg, logger)

    experienced_item_names = []

    all_reflections = []

    int_goal = None
    int_goal_steps = 0
    wp = ""
    subgoal = None
    language_action_str = ""
    subgoal_done = False

    # threshold_failure_waypoint = 5
    # wp_trials = 0

    memory_root_path = cfg["memory"]["path"]
    done_steps = 0
    done_data = []

    inventory_before_action = copy.deepcopy(obs['inventory'])

    if 'data.json' in os.listdir(memory_root_path):
        with open(os.path.join(memory_root_path, 'data.json'), 'r') as f:
            done_data = json.load(f)
        for data in done_data:
            done_steps = max(done_steps, data['step'])

    for i in range(done_steps + 1, done_steps + maximum_step + 1):
        # rebuttal for dynamic target graph
        if i == cfg["action_lib_change_step"]:
            logger.info(f"Change action library at step {i}.")
            if cfg["base_action_lib"] != cfg["changed_action_lib"]:
                perturbed_action_lib = create_new_action_lib_json(env, cfg["changed_action_lib"])
                env.set_action_lib(perturbed_action_lib)

                experienced_item_names = list(set(experienced_item_names))
                action_lib_patch = get_action_lib_patch(perturbed_action_lib)
                hypothesized_recipe_graph.reset_recipe_revised_items()
                for action_name in action_lib_patch["replace"].keys():
                    item_name = action_name.replace("craft_", "").replace("smelt_", "").replace("mine_", "")
                    hypothesized_recipe_graph.move_verified_recipe_to_hypothesized(item_name)
                    # delete item name in experienced_item_names
                    if item_name in experienced_item_names:
                        experienced_item_names.remove(item_name)
                    subgoal_memory.reset_success_failure_history(item_name)

                for action_name in action_lib_patch["add"].keys():
                    item_name = action_name.replace("craft_", "").replace("smelt_", "").replace("mine_", "")
                    hypothesized_recipe_graph.move_verified_recipe_to_hypothesized(item_name)
                    # delete item name in experienced_item_names
                    if item_name in experienced_item_names:
                        experienced_item_names.remove(item_name)
                    subgoal_memory.reset_success_failure_history(item_name)

                hypothesized_recipe_graph.load_and_init_all_recipes()
                logger.info(f"Action library changed to {cfg['changed_action_lib']}.")

        if int_goal is None:
            not_enough_resources = check_enough_all_crafting_resources(obs, hypothesized_recipe_graph)
            if "adam" in prefix and len(not_enough_resources) > 0:
                logger.info(f"prefix with adam. prefix: {prefix}. Select intrinsic goal randomly from not_enough_resources")
                int_goal = random.choice(not_enough_resources)
            elif "wo_initial_hypothesis" in prefix and len(not_enough_resources) > 0:
                logger.info(f"prefix with wo_initial_hypothesis. prefix: {prefix}. Select intrinsic goal randomly from not_enough_resources")
                int_goal = random.choice(not_enough_resources)
            else:
                int_goal = select_int_goal(hypothesized_recipe_graph, logger, cfg)

            int_goal_steps = 0
            subgoal = None
            logger.info(f"New intrinsic goal: {int_goal}")
            exploration_count_dict = hypothesized_recipe_graph.get_exploration_count_all_hypothesized()
            cnt = exploration_count_dict.get(int_goal, 1)
            logger.info(f"Exploration count of the intrinsic goal {int_goal}: {cnt}")
            logger.info(f"Current all crafting resources {hypothesized_recipe_graph.crafting_resources}\n")

        if subgoal is None:
            wp = retrieve_first_waypoint(hypothesized_recipe_graph, int_goal, 1, obs['inventory'])
            if "wo_fail" in prefix:
                _, subgoal, language_action_str, error_message = agent.make_plan_wo_fail(wp)
            elif "with_reflection" in prefix:
                _, subgoal, language_action_str, error_message = agent.make_plan_with_reflection(wp)
            else:
                _, subgoal, language_action_str, error_message = agent.make_plan(wp)
            if error_message is not None:
                logger.error(f"Error during make_plan(). error message: {error_message}")
                # status = "cannot generate plan"
                break

            # logger.info(f"After make_plan()")
            # logger.info(f"[yellow]Waypoint: {wp}, Subgoal: {subgoal}[/yellow]")
        
        wp_log = copy.deepcopy(wp)
        subgoal_log = copy.deepcopy(subgoal)

        # logger.info(f"Step {i}, Intrinsic Goal: {int_goal}, Waypoint: {wp}, Subgoal: {subgoal}")
        # logger.info(f"Recipe from graph of waypoint {wp}: {str(hypothesized_recipe_graph.get_recipe(wp))}")
        # logger.info(f"Inventory before subgoal {subgoal}: {obs['inventory']}")

        obs, reward, done, info = env.step(subgoal)
        inventory_after_action = copy.deepcopy(obs['inventory'])

        # logger.info(f"Inventory after subgoal {subgoal}: {obs['inventory']}\n")

        if 'reach_maximum_step' in info and info['reach_maximum_step']:
            logger.info(f"Reach maximum step {maximum_step}.")
            break

        if info['action_success']:
            success_item_name = info['item_name']
            subgoal_done = True
            subgoal = None
            # wp_trials = 0
            subgoal_memory.save_success_failure(wp, language_action_str, True)

            success_item_name = change_textworld_item_name(success_item_name)

            if success_item_name not in experienced_item_names:
                logger.info(f"New item is experienced: {success_item_name}")
                experienced_item_names.append(success_item_name)
                recipe_data = {
                    "item_name": info['item_name'],
                    "output_qty": info['output_qty'],
                    "ingredients": info['ingredients'],
                    "required_pickaxe": info['required_pickaxe'],
                    "is_crafting": info['is_crafting'],
                }
                recipe_data = change_log_planks_coal_recipe(recipe_data)
                hypothesized_recipe_graph.save_verified_recipe_data(recipe_data, prefix)

            if success_item_name == int_goal:
                logger.info(f"Intrinsic goal {int_goal} reached!")
                int_goal = None
                int_goal_steps = 0

        else:
            logger.info(f"Subgoal {subgoal} failed.")

            # If int_goal is not a frontier or low-level controller fail.
            if "uniform" in prefix and wp != int_goal:
                revise_goal = int_goal if cfg["uniform_revise_goal"] == "intrinsic_goal" else wp
                subgoal_memory.save_success_failure(revise_goal, language_action_str, False)
                subgoal = None

                ret = reflect_on_failure(revise_goal, language_action_str, inventory_before_action, subgoal_memory, agent, plan_failure_threshold, logger)
                if ret is not None:
                    all_reflections.append(ret)

                revise_goal_total_failure_counts = abs(subgoal_memory.retrieve_total_failed_counts(revise_goal))

                if revise_goal_total_failure_counts >= plan_failure_threshold * 3:
                    logger.warning(f"{revise_goal} failed {revise_goal_total_failure_counts} times, so increment exploration count of {revise_goal}.")

                    inventory = copy.deepcopy(obs['inventory'])
                    inventory_for_hypothesis_update = {}
                    for k, v in inventory.items():
                        name = change_textworld_item_name(k)
                        inventory_for_hypothesis_update[name] = v
                    experienced_items_for_update = []
                    for k in experienced_item_names:
                        name = change_textworld_item_name(k)
                        experienced_items_for_update.append(name)

                    all_reflections = subgoal_memory.retrieve_all_reflections(revise_goal)
                    hypothesized_recipe_graph.increment_count(revise_goal, prefix, inventory_for_hypothesis_update, experienced_items_for_update, all_reflections)

                    # reset success failure history of the changed items from the subgoal_memory
                    recipe_revised_items = hypothesized_recipe_graph.get_recipe_revised_items()
                    logger.info(f"recipe_revised_items: {recipe_revised_items}. prefix: {prefix}")
                    for item in recipe_revised_items:
                        subgoal_memory.reset_success_failure_history(item)
                    # subgoal_memory.reset_success_failure_history(revise_goal)
                    hypothesized_recipe_graph.reset_recipe_revised_items()

                    int_goal = None
                    subgoal = None
                    # int_goal_trials = 0
                    int_goal_steps = 0

            else:
                subgoal_memory.save_success_failure(wp, language_action_str, False)
                subgoal = None

                ret = reflect_on_failure(wp, language_action_str, inventory_before_action, subgoal_memory, agent, plan_failure_threshold, logger)
                if ret is not None:
                    all_reflections.append(ret)

                wp_total_failure_counts = abs(subgoal_memory.retrieve_total_failed_counts(wp))

                # if we explore all actions ("mine", "craft", "smelt") enough but still fail,
                # then the hypothesized recipe for the intrinsic goal is wrong.
                # So, increment exploration count of the intrinsic goal and change the hypothesized recipe.
                # Then, change the intrinsic goal.
                if wp_total_failure_counts >= plan_failure_threshold * 3:
                    logger.warning(f"{wp} failed {wp_total_failure_counts} times, so increment exploration count of {wp}.")

                    inventory = copy.deepcopy(obs['inventory'])
                    inventory_for_hypothesis_update = {}
                    for k, v in inventory.items():
                        name = change_textworld_item_name(k)
                        inventory_for_hypothesis_update[name] = v
                    experienced_items_for_update = []
                    for k in experienced_item_names:
                        name = change_textworld_item_name(k)
                        experienced_items_for_update.append(name)

                    all_reflections = subgoal_memory.retrieve_all_reflections(wp)
                    hypothesized_recipe_graph.increment_count(wp, prefix, inventory_for_hypothesis_update, experienced_items_for_update, all_reflections)

                    # reset success failure history of the changed items from the subgoal_memory
                    recipe_revised_items = hypothesized_recipe_graph.get_recipe_revised_items()
                    logger.info(f"recipe_revised_items: {recipe_revised_items}. prefix: {prefix}")
                    for item in recipe_revised_items:
                        subgoal_memory.reset_success_failure_history(item)
                    subgoal_memory.reset_success_failure_history(wp)
                    hypothesized_recipe_graph.reset_recipe_revised_items()

                    int_goal = None
                    subgoal = None
                    # wp_trials = 0
                    int_goal_steps = 0

        int_goal_steps += 1
        if int_goal_steps >= cfg["int_goal_steps"]:
            logger.info(f"Int goal steps {int_goal_steps} reached. Change intrinsic goal.")
            int_goal = None
            int_goal_steps = 0

        _verified_items = list(set(copy.deepcopy(hypothesized_recipe_graph.verified_item_names)))
        _hypothesized_items = list(set(copy.deepcopy(hypothesized_recipe_graph.hypothesized_item_names)))
        _frontier_items = list(set(copy.deepcopy(hypothesized_recipe_graph.frontier_item_names)))
        _inadmissible_items = list(set(copy.deepcopy(hypothesized_recipe_graph.inadmissible_item_names)))

        # logger.info(f"Step {i}. _verified_items: {str(_verified_items)}")
        # logger.info(f"Step {i}. _hypothesized_items: {str(_hypothesized_items)}")
        # logger.info(f"Step {i}. _inadmissible_items: {str(_inadmissible_items)}")

        belief_correct_verified_benchmark_goals = []
        belief_incorrect_verified_benchmark_goals = []
        num_verified_benchmark_item = 0
        for v in _verified_items:
            if v in ["logs", "log", "oak_log", "planks", "oak_planks", "coal", "coals"]:
                continue
            if v in hypothesized_recipe_graph.goal_items and hypothesized_recipe_graph.graph[v]['is_verified']:
                oracle_action = env.action_lib.oracle_item_to_action_dict[v]
                flag = is_belief_correct(
                    copy.deepcopy(hypothesized_recipe_graph.graph[v]['ingredients']),
                    copy.deepcopy(env.action_lib.action_lib[oracle_action])
                )
                num_verified_benchmark_item += flag
                if flag:
                    belief_correct_verified_benchmark_goals.append(v)
                else:
                    belief_incorrect_verified_benchmark_goals.append(v)

        wandb.log({
            "action_success": info['action_success'],
            "belief_correct_verified_benchmark_goals": len(belief_correct_verified_benchmark_goals),
            "belief_incorrect_verified_benchmark_goals": len(belief_incorrect_verified_benchmark_goals),
            "num_verified": len(_verified_items),
            "num_hypothesized": len(_hypothesized_items),
            "num_frontier": len(_frontier_items),
            "num_inadmissible": len(_inadmissible_items),
            "step": i,
        })

        num_verified_benchmark_item = get_verified_benchmark_items(_verified_items, copy.deepcopy(hypothesized_recipe_graph.goal_items))

        tmp = subgoal_log.replace(f'_{wp_log}', '')
        op = "none"
        if "mine" in tmp or "dig" in tmp:
            op = "mine"
        elif "smelt" in tmp:
            op = "smelt"
        elif "craft" in tmp:
            op = "craft"

        curr_data = {
            "step": i,
            "inventory_before_action": copy.deepcopy(inventory_before_action),
            "inventory_after_action": copy.deepcopy(inventory_after_action),
            # "inventory": copy.deepcopy(obs['inventory']),
            "waypoint": copy.deepcopy(wp_log),
            "op": copy.deepcopy(op),
            "intrinsic_goal": copy.deepcopy(int_goal),
            "action": copy.deepcopy(subgoal_log),
            "action_success": copy.deepcopy(info['action_success']),
            "num_verified": len(_verified_items),
            "num_hypothesized": len(_hypothesized_items),
            "num_frontier": len(_frontier_items),
            "num_inadmissible": len(_inadmissible_items),
            "num_verified_benchmark_item": num_verified_benchmark_item,
            "verified_items": _verified_items,
            "hypothesized_items": _hypothesized_items,
            "frontier_items": _frontier_items,
            "inadmissible_items": _inadmissible_items,
            "experienced_item_names": copy.deepcopy(experienced_item_names),
            "belief_correct_verified_benchmark_goals": belief_correct_verified_benchmark_goals,
            "belief_incorrect_verified_benchmark_goals": belief_incorrect_verified_benchmark_goals
        }
        done_data.append(curr_data)

        inventory_before_action = copy.deepcopy(obs['inventory'])

        if done:
            print("\n\n\n\n")
            env.print_obs()
            print("Task Finished!")
            break

    uuid = shortuuid.uuid()[:4]
    if 'main_ablate_hypothesis.log' in os.listdir(hydra_path):
        log_src = os.path.join(hydra_path, 'main_ablate_hypothesis.log')
        log_dst = f'clean_output_{uuid}.log'
        shutil.copy(log_src, log_dst)
        wandb.save(log_dst)

    with open(f"{wandb.run.dir}/result.json", "w") as f:
        json.dump(done_data, f, indent=2)
        wandb.save(f"result.json")

    verified_all = {}
    verified_correct_all = {}
    verified_incorrect_all = {}
    verified_correct_goals = {}
    verified_incorrect_goals = {}

    verified_file_list = os.listdir(hypothesized_recipe_graph.verified_recipe_dir)

    for f in verified_file_list:
        if not f.endswith(".json"):
            continue

        with open(os.path.join(hypothesized_recipe_graph.verified_recipe_dir, f), "r") as file:
            data = json.load(file)
        item_name = data["item_name"]
        if item_name in ["logs", "log", "oak_log", "planks", "oak_planks", "coal", "coals"]:
            continue

        verified_all[item_name] = data

        oracle_action = env.action_lib.oracle_item_to_action_dict[item_name]
        flag = is_belief_correct(
            copy.deepcopy(hypothesized_recipe_graph.graph[item_name]['ingredients']),
            copy.deepcopy(env.action_lib.action_lib[oracle_action])
        )
        if flag:
            verified_correct_all[item_name] = data
            if item_name in hypothesized_recipe_graph.goal_items:
                verified_correct_goals[item_name] = data
        else:
            verified_incorrect_all[item_name] = data
            if item_name in hypothesized_recipe_graph.goal_items:
                verified_incorrect_goals[item_name] = data

    with open(f"{wandb.run.dir}/verified_all_recipes.json", "w") as out_file:
        json.dump(verified_all, out_file, indent=2)
        wandb.save(f"verified_all_recipes.json")
    with open(f"{wandb.run.dir}/verified_correct_all_recipes.json", "w") as out_file:
        json.dump(verified_correct_all, out_file, indent=2)
        wandb.save(f"verified_correct_all_recipes.json")
    with open(f"{wandb.run.dir}/verified_incorrect_all_recipes.json", "w") as out_file:
        json.dump(verified_incorrect_all, out_file, indent=2)
        wandb.save(f"verified_incorrect_all_recipes.json")
    with open(f"{wandb.run.dir}/verified_correct_goals_recipes.json", "w") as out_file:
        json.dump(verified_correct_goals, out_file, indent=2)
        wandb.save(f"verified_correct_goals_recipes.json")
    with open(f"{wandb.run.dir}/verified_incorrect_goals_recipes.json", "w") as out_file:
        json.dump(verified_incorrect_goals, out_file, indent=2)
        wandb.save(f"verified_incorrect_goals_recipes.json")

    if all_reflections:
        with open(f"{wandb.run.dir}/all_reflections.json", "w") as out_file:
            json.dump(all_reflections, out_file, indent=2)
            wandb.save(f"all_reflections.json")

    if os.path.isfile(env.perturbed_action_lib_dst_path) and env.perturbed_action_lib_dst_path.endswith(".json"):
        with open(env.perturbed_action_lib_dst_path, "r") as f:
            data = json.load(f)
        with open(f"{wandb.run.dir}/perturbed_action_lib.json", "w") as f:
            json.dump(data, f, indent=2)
            wandb.save(f"perturbed_action_lib.json")
        os.remove(env.perturbed_action_lib_dst_path)

    logger.info(f"wandb.run.dir: {wandb.run.dir}")
    wandb.finish()
    # if not ("level3" in cfg["action_lib"] or "level3" in cfg["changed_action_lib"]):
    #     shutil.rmtree(dst_dir)


# Script entry point: invoke main() only when this file is executed directly,
# so that importing the module does not start a run as a side effect.
if __name__ == "__main__":
    main()

