from ours.memories.hypothesized_recipe_graph import HypothesizedRecipeGraph
import logging
from omegaconf import DictConfig, OmegaConf
import random
import copy


def modified_deckard(hypothesized_recipe_graph: HypothesizedRecipeGraph, logger: logging.Logger, cfg: DictConfig):
    """Select an intermediate goal with a DECKARD-style admissible-frontier rule.

    A frontier item is admissible while its exploration count is at most
    ``cfg["memory"]["inadmissible_threshold"]``. That threshold is c_0 in the paper
    "Embodied Decision Making using Language Guided World Modelling".
    Admissible items are shuffled and handed to ``select_non_conflicting_goal``.
    If no frontier item is admissible, the candidates are all frontier items plus
    all verified items, deduplicated.

    Hypothesized recipes are never revised by this baseline.

    Args:
        hypothesized_recipe_graph: Recipe graph providing frontiers, exploration
            counts, verified items and conflict-free goal selection.
        logger: Logger for progress messages.
        cfg: Config; reads ``cfg["memory"]["inadmissible_threshold"]``.

    Returns:
        The value returned by ``select_non_conflicting_goal`` for the shuffled
        candidate list.
    """
    logger.info("In modified_deckard()")
    hypothesized_recipe_graph.load_and_init_all_recipes()
    frontier_item_names = hypothesized_recipe_graph.find_frontiers()
    exploration_count_dict = hypothesized_recipe_graph.get_exploration_count_all_hypothesized()

    # c_0 in the paper "Embodied Decision Making using Language Guided World Modelling"
    inadmissible_threshold = cfg["memory"]["inadmissible_threshold"]
    admissible_item_names = [
        item for item in frontier_item_names
        if exploration_count_dict[item] <= inadmissible_threshold
    ]

    if not admissible_item_names:
        logger.info("No admissible items. Select from all frontiers + verified.")
        # Deduplicate with dict.fromkeys instead of set(). Set iteration order
        # for strings depends on PYTHONHASHSEED, so the shuffle below would not
        # be reproducible across runs even with a seeded `random`.
        admissible_item_names = list(dict.fromkeys(
            [*frontier_item_names, *hypothesized_recipe_graph.get_verified_item_names()]
        ))

    random.shuffle(admissible_item_names)
    selected_int_goal = hypothesized_recipe_graph.select_non_conflicting_goal(admissible_item_names)

    # No revision of hypothesized recipes in deckard
    return selected_int_goal


def modified_adam(hypothesized_recipe_graph: HypothesizedRecipeGraph, logger: logging.Logger, cfg: DictConfig):
    """Select an intermediate goal ADAM-style, without a predefined curriculum.

    Original ADAM follows a pre-defined exploration curriculum. No such curriculum
    is assumed here, so the goal is drawn in random order from all distinct
    hypothesized item names. Hypothesized recipes are never revised by this
    baseline.

    Args:
        hypothesized_recipe_graph: Recipe graph exposing ``hypothesized_item_names``
            and conflict-free goal selection.
        logger: Logger for progress messages.
        cfg: Unused. Accepted so all strategies share one signature.

    Returns:
        The value returned by ``select_non_conflicting_goal`` for the shuffled
        candidate list.
    """
    logger.info("In modified_adam()")
    hypothesized_recipe_graph.load_and_init_all_recipes()

    # Order-preserving dedupe into a fresh list. The graph's own attribute is
    # never shuffled in place. Unlike set(), the result does not depend on
    # PYTHONHASHSEED, so a seeded `random` gives reproducible picks.
    hypothesized_item_names = list(dict.fromkeys(hypothesized_recipe_graph.hypothesized_item_names))

    random.shuffle(hypothesized_item_names)
    selected_int_goal = hypothesized_recipe_graph.select_non_conflicting_goal(hypothesized_item_names)

    # No revision of hypothesized recipes in adam
    return selected_int_goal


# ours
def feasibility_min_count_frontier(hypothesized_recipe_graph: HypothesizedRecipeGraph, logger: logging.Logger, cfg: DictConfig):
    """Select the least-explored, most-feasible frontier item (the proposed method).

    Frontier items are ranked by:
      1. exploration count, ascending (less-explored first), then
      2. feasibility score ``1 / level``, descending (lower level first).
    The ranked list is passed to ``select_non_conflicting_goal``. If the chosen
    goal is a frontier item already explored more than once, its hypothesized
    recipe is revised with ``update_hypothesis``.

    Args:
        hypothesized_recipe_graph: Recipe graph providing frontiers, exploration
            counts, levels, goal selection and hypothesis revision.
        logger: Logger for progress messages.
        cfg: Unused. Accepted so all strategies share one signature.

    Returns:
        The value returned by ``select_non_conflicting_goal``.
    """
    logger.info("In feasibility_min_count_frontier()")
    hypothesized_recipe_graph.load_and_init_all_recipes()
    frontier_item_names = hypothesized_recipe_graph.find_frontiers()
    exploration_count_dict = hypothesized_recipe_graph.get_exploration_count_all_hypothesized()
    level_dict = hypothesized_recipe_graph.calculate_level_all_hypothesized()

    frontier_exploration_count_dict = {item: exploration_count_dict[item] for item in frontier_item_names}
    # NOTE(review): 1 / level below assumes frontier levels are non-zero -- confirm.
    frontier_level_dict = {item: level_dict[item] for item in frontier_item_names}

    sorted_item_names = sorted(
        frontier_item_names,
        key=lambda item: (
            frontier_exploration_count_dict[item],  # fewer exploration count is better
            -(1 / frontier_level_dict[item]),  # higher feasibility score is better
        ),
    )

    selected_int_goal = hypothesized_recipe_graph.select_non_conflicting_goal(sorted_item_names)

    # The selected goal may be None, or not a frontier item at all. Indexing the
    # dict directly would raise KeyError, so such goals are just not revised.
    if frontier_exploration_count_dict.get(selected_int_goal, 0) > 1:
        hypothesized_recipe_graph.update_hypothesis(selected_int_goal)
    return selected_int_goal


# # without feasibility score
# def min_count_frontier(hypothesized_recipe_graph: HypothesizedRecipeGraph, logger: logging.Logger, cfg: DictConfig):
#     logger.info(f"In min_count_frontier(). without feasbility score")
#     hypothesized_recipe_graph.load_and_init_all_recipes()
#     frontier_item_names = hypothesized_recipe_graph.find_frontiers()
#     exploration_count_dict = hypothesized_recipe_graph.get_exploration_count_all_hypothesized()

#     frontier_exploration_count_dict = {}
#     for item_name in frontier_item_names:
#         frontier_exploration_count_dict[item_name] = exploration_count_dict[item_name]

#     random.shuffle(frontier_item_names)
#     frontier_item_names.sort(key=frontier_exploration_count_dict.get)

#     sorted_item_names = copy.deepcopy(frontier_item_names)
#     selected_int_goal = hypothesized_recipe_graph.select_non_conflicting_goal(sorted_item_names)

#     if frontier_exploration_count_dict[selected_int_goal] > 1:
#         hypothesized_recipe_graph.update_hypothesis(selected_int_goal)
#     return selected_int_goal


# def feasibility_min_count_wo_frontier(hypothesized_recipe_graph: HypothesizedRecipeGraph, logger: logging.Logger, cfg: DictConfig):
#     logger.info(f"In feasibility_min_count_wo_frontier().")
#     hypothesized_recipe_graph.load_and_init_all_recipes()
#     hypothesized_item_names = copy.deepcopy(hypothesized_recipe_graph.hypothesized_item_names)
#     exploration_count_dict = hypothesized_recipe_graph.get_exploration_count_all_hypothesized()
#     level_dict = hypothesized_recipe_graph.calculate_level_all_hypothesized()

#     sorted_item_names = sorted(
#         hypothesized_item_names,
#         key=lambda item: (
#             exploration_count_dict[item], # fewer exploration count is better
#             - (1 / level_dict[item]) # higher feasibility score is better
#         )
#     )

#     selected_int_goal = hypothesized_recipe_graph.select_non_conflicting_goal(sorted_item_names)

#     if exploration_count_dict[selected_int_goal] > 1:
#         hypothesized_recipe_graph.update_hypothesis(selected_int_goal)
#     return selected_int_goal


def uniform_random_goal(hypothesized_recipe_graph: HypothesizedRecipeGraph, logger: logging.Logger, cfg: DictConfig):
    """Select an intermediate goal uniformly at random among hypothesized items.

    Distinct hypothesized item names are shuffled and passed to
    ``select_non_conflicting_goal``. Hypothesis revision is intentionally not
    performed for this baseline.

    Args:
        hypothesized_recipe_graph: Recipe graph exposing ``hypothesized_item_names``
            and conflict-free goal selection.
        logger: Logger for progress messages.
        cfg: Unused. Accepted so all strategies share one signature.

    Returns:
        The value returned by ``select_non_conflicting_goal`` for the shuffled
        candidate list.
    """
    logger.info("In uniform_random_goal()")
    hypothesized_recipe_graph.load_and_init_all_recipes()

    # Order-preserving dedupe into a fresh list. Unlike set(), the base order
    # does not depend on PYTHONHASHSEED, so seeded runs are reproducible.
    hypothesized_item_names = list(dict.fromkeys(hypothesized_recipe_graph.hypothesized_item_names))

    random.shuffle(hypothesized_item_names)
    selected_int_goal = hypothesized_recipe_graph.select_non_conflicting_goal(hypothesized_item_names)

    return selected_int_goal


def select_int_goal(hypothesized_recipe_graph: HypothesizedRecipeGraph, logger: logging.Logger, cfg: DictConfig):
    """Dispatch to the intermediate-goal strategy named by ``cfg["prefix"]``.

    The prefix is matched by substring, and the first matching branch wins:
      - "ours" / "feasibility" / "frontier" -> feasibility_min_count_frontier
      - "uniform_random_goal" / "pure_llm"  -> uniform_random_goal
      - "deckard"                           -> modified_deckard
      - "adam"                              -> modified_adam
      - "self_correction"                   -> uniform_random_goal

    Args:
        hypothesized_recipe_graph: Recipe graph forwarded to the chosen strategy.
        logger: Logger forwarded to the strategy, and used for the error message.
        cfg: Config; ``cfg.get("prefix")`` selects the strategy.

    Returns:
        The goal returned by the chosen strategy. Returns None after logging an
        error when the prefix is missing, not a string, or matches no strategy.
    """
    prefix = cfg.get("prefix")

    # A missing prefix would otherwise crash with TypeError on `"ours" in None`.
    if not isinstance(prefix, str):
        logger.error(f"prefix: {prefix} is not supported.")
        return None

    # Branch order matters: prefixes are matched by substring.
    if "ours" in prefix or "feasibility" in prefix or "frontier" in prefix:
        int_goal = feasibility_min_count_frontier(hypothesized_recipe_graph, logger, cfg)
    elif "uniform_random_goal" in prefix or "pure_llm" in prefix:
        int_goal = uniform_random_goal(hypothesized_recipe_graph, logger, cfg)
    elif "deckard" in prefix:
        int_goal = modified_deckard(hypothesized_recipe_graph, logger, cfg)
    elif "adam" in prefix:
        int_goal = modified_adam(hypothesized_recipe_graph, logger, cfg)
    elif "self_correction" in prefix:
        int_goal = uniform_random_goal(hypothesized_recipe_graph, logger, cfg)
    else:
        logger.error(f"prefix: {prefix} is not supported.")
        return None

    return int_goal
