import json
import numpy as np

from general_prompt_builder.constants import PROMPT_HISTORY_PATH
import general_prompt_builder.serializer as prompt_serializer
import general_prompt_builder.utils as prompt_utils

def fetch_messages(experiment_name, prompt_description, prompt_version):
    """Load the chat messages for a versioned prompt.

    The prompt file path is resolved under ``PROMPT_HISTORY_PATH`` with
    ``prompt_utils.get_prompt_path`` and the file is then turned into
    messages by ``prompt_serializer.serialize_into_messages``.

    Parameters:
        experiment_name (str)
            Name of the experiment the prompt belongs to.
        prompt_description (str)
            Description identifying the prompt.
        prompt_version (str)
            Version of the prompt to load.

    Returns:
        messages (List[Dict[str, str]])
            The messages to query the LLM with.
    """
    return prompt_serializer.serialize_into_messages(
        prompt_utils.get_prompt_path(
            PROMPT_HISTORY_PATH,
            experiment_name,
            prompt_description,
            prompt_version,
        )
    )

def parse_language_skill(language_skill):
    """Parse a language skill string into an executable skill tuple.

    Expects text of the form ``pickandplace('<object>', '<location>')``.
    Surrounding whitespace is ignored, any amount of whitespace may follow the
    separating comma, and single/double quotes around each argument are
    removed. Anything after the first closing parenthesis is ignored.

    Parameters:
        language_skill (str)
            The single-line skill text to parse.

    Returns:
        skill (tuple)
            ``("pickandplace", obj_name, loc_name)`` to execute in the
            environment.

    Raises:
        ValueError
            If the text is not a well-formed ``pickandplace(obj, loc)`` call.
    """
    skill_name, open_paren, rest = language_skill.strip().partition("(")
    skill_name = skill_name.strip()
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not open_paren or skill_name != "pickandplace":
        raise ValueError(f"Unsupported or malformed skill: {language_skill!r}")

    args, close_paren, _ = rest.partition(")")
    # Split only at the first comma so a location name may itself contain one.
    obj_part, comma, loc_part = args.partition(",")
    if not close_paren or not comma:
        raise ValueError(f"Expected pickandplace(obj, loc), got: {language_skill!r}")

    def _unquote(text):
        # Same quote-stripping order as before: single quotes, then double.
        return text.strip().strip("'").strip('"')

    return ("pickandplace", _unquote(obj_part), _unquote(loc_part))


def format_semantic_loc_for_llm(obj_detect_dict):
    """Group detected objects by semantic location and return it as JSON.

    Each value of ``obj_detect_dict`` must have a ``"locations"`` iterable of
    semantic location names. The result maps every location to the list of
    object names seen there, with locations and objects in first-seen order.

    Parameters:
        obj_detect_dict (dict)
            Object name -> detection info containing ``"locations"``.

    Returns:
        str: ``json.dumps`` of the location -> [object names] mapping.
    """
    objects_by_location = {}
    for name, detection in obj_detect_dict.items():
        for location in detection["locations"]:
            objects_by_location.setdefault(location, []).append(name)
    return json.dumps(objects_by_location)



def format_plan(text_plan, best_action_sequence):
    """Turn a text plan plus placement poses into executable plan code.

    Each non-blank line of ``text_plan`` is parsed with
    ``parse_language_skill``; the i-th parsed skill is paired with
    ``best_action_sequence[i]``. A summary line is printed for every skill.

    Parameters:
        text_plan (str)
            Newline-separated ``pickandplace(obj, loc)`` skills. Blank or
            whitespace-only lines are skipped.
        best_action_sequence (sequence)
            One ``(x, y, z, obj_w, obj_h, obj_d)`` entry per skill, in order.

    Returns:
        str: the generated code snippets joined by newlines, stripped.
    """
    # Skip blank lines (e.g. a trailing newline in LLM output), which would
    # otherwise reach the parser and fail.
    skills = [line for line in text_plan.split("\n") if line.strip()]

    snippets = []
    for i, skill_text in enumerate(skills):
        _, obj_name, loc_name = parse_language_skill(skill_text)

        # Determine place location
        x, y, z, obj_w, obj_h, obj_d = best_action_sequence[i]

        print(f"pickandplace('{obj_name}', '{loc_name}', loc=({x:.2f}, {y:.2f}, {z:.2f}), obj_dim=({obj_w:.2f}, {obj_h:.2f}, {obj_d:.2f}))")

        snippets.append(generate_code_for_pickandplace(obj_name, x, y, z))

    return "\n".join(snippets).strip()


def generate_code_for_pickandplace(obj_name, x, y, z):
    """Generate the code snippet for one pick-and-place step.

    The snippet picks up ``obj_name``, goes to ``FRIDGE``, places the item at
    ``(x, y, z)`` (each formatted to 4 decimals and passed as a quoted
    string), then goes back to ``TABLE``.

    Parameters:
        obj_name (str)
            Name of the object to pick up.
        x, y, z (float)
            Placement coordinates.

    Returns:
        str: the four-line code snippet (no leading/trailing newline).
    """
    # repr() yields a quoted string literal: identical to '<name>' for ordinary
    # names, but still well-formed if the name contains a quote. Formatting each
    # field once (instead of chained str.replace) also stops an object name
    # containing "{x}"/"{y}"/"{z}" from being re-substituted.
    lines = [
        f"pick_up_item({str(obj_name)!r})",
        "go_to(FRIDGE)",
        f"place_item_at('{x:.4f}', '{y:.4f}', '{z:.4f}')",
        "go_to(TABLE)",
    ]
    return "\n".join(lines)

def get_answer_prob(logprobs, answers=("yes", "no")):
    """Return the probability of the last yes/no answer token.

    ``logprobs`` is scanned from the end. A token matches when, after
    lower-casing and stripping whitespace and surrounding punctuation, it is
    exactly one of ``answers``. Whole-word matching avoids mistaking tokens
    such as "know", "not" or "yesterday" for an answer.

    Parameters:
        logprobs (sequence)
            Entries exposing ``.token`` (str) and ``.logprob`` (float).
        answers (iterable of str)
            Answer words to look for; defaults to ``("yes", "no")``.

    Returns:
        float or None: ``exp(logprob)`` of the last matching token, or ``None``
        if no token matches.
    """
    answer_set = {answer.lower() for answer in answers}
    strip_chars = " \t\r\n.,!?:;*\"'`()[]"
    for entry in reversed(logprobs):
        word = entry.token.lower().strip(strip_chars)
        if word in answer_set:
            return np.exp(entry.logprob)
    return None

def pretty_str_preference_with_p_theta(preference_list, p_theta):
    """Render preference candidates next to their P(theta) as a text table.

    Parameters:
        preference_list (sequence)
            Preference candidates; ``p_theta[i]`` is the probability of the
            i-th one.
        p_theta (sequence of float)
            Probabilities, formatted with 2 decimals.

    Returns:
        str: header line followed by one row per candidate, stripped.
    """
    rows = ["P(theta)  | Preference candidates"]
    rows.extend(
        f"{p_theta[idx]:.2f}  | {idx}. {candidate}"
        for idx, candidate in enumerate(preference_list)
    )
    return "\n".join(rows).strip()

def pretty_str_preference_with_p_theta_before_after(preference_list, p_theta_before, p_theta_after):
    """Render P(theta) before/after an update, color-coded by direction of change.

    Each row is wrapped in ``<n><color>...</color></n>`` markup. The color is
    yellow when the value decreased, green when it increased, and white
    otherwise (equal values, or values that don't compare, such as NaN).

    Parameters:
        preference_list (sequence)
            Preference candidates; index ``i`` pairs with ``p_theta_before[i]``
            and ``p_theta_after[i]``.
        p_theta_before (sequence of float)
            Probabilities before the update.
        p_theta_after (sequence of float)
            Probabilities after the update.

    Returns:
        str: header line followed by one colored row per candidate, stripped.
    """
    rows = ["P(theta) Before | P(theta) After | Preference candidates (<yellow>Yellow=Decrease</yellow>, <green>Green=Increase</green>, White=No Change)"]
    for i, preference in enumerate(preference_list):
        before = p_theta_before[i]
        after = p_theta_after[i]

        if before > after:
            # Yellow represents a decrease in value
            color_to_use = "yellow"
        elif before < after:
            # Green represents an increase in value
            color_to_use = "green"
        else:
            # White represents no change
            color_to_use = "white"

        rows.append(f"<n><{color_to_use}>{before:.2f} | {after:.2f}  | {i}. {preference}</{color_to_use}></n>")

    return "\n".join(rows).strip()

def pretty_str_scores_for_questions(scores, question_list, best_question_idx):
    """Render candidate questions with their scores, highlighting the best one.

    Parameters:
        scores (sequence)
            Score for each question; ``scores[i]`` pairs with
            ``question_list[i]``.
        question_list (sequence of (pref_pair, question))
            Candidate questions with the preference pair each was built from.
        best_question_idx (int)
            Index of the row to mark with ``<b><green>``; all other rows are
            wrapped in ``<n><yellow>``.

    Returns:
        str: header line followed by one marked-up row per question, stripped.
    """
    rows = ["Score  | Preference pair | Question list (Best one highlighted in green)"]
    for idx, (pref_pair, question) in enumerate(question_list):
        body = f"{scores[idx]}  | {pref_pair} | {idx}. {question}"
        if idx == best_question_idx:
            rows.append(f"<b><green>{body}</green></b>")
        else:
            rows.append(f"<n><yellow>{body}</yellow></n>")
    return "\n".join(rows).strip()

def pretty_str_plans_with_agreeableness(plans_list, sum_weighted_opinion_matrix, is_agreeable_matrix):
    """Render each plan with its agreeableness flag and weighted-opinion sum.

    Agreeable plans are shown as ``<b><green>T | ...`` rows, the others as
    ``<n><red>F | ...`` rows. Multi-line plans keep their line breaks: the
    first plan line follows the ``T/F | score |`` prefix with a 2-space pad,
    and each continuation line is indented by 13 spaces.

    Parameters:
        plans_list (sequence of str)
            Newline-separated text plans.
        sum_weighted_opinion_matrix (sequence of float)
            Weighted opinion sum per plan, formatted with 2 decimals.
        is_agreeable_matrix (sequence of bool)
            Whether each plan is agreeable.

    Returns:
        str: header line followed by one marked-up entry per plan, stripped.
    """
    continuation_indent = " " * 13
    rows = ["Agreeable? | Sum Weighted Opinions | Plan"]
    for idx, plan in enumerate(plans_list):
        first_line, *other_lines = plan.split("\n")
        padded = ["  " + first_line] + [continuation_indent + line for line in other_lines]
        plan_block = "\n".join(padded).rstrip()
        opinion = f"{sum_weighted_opinion_matrix[idx]:.2f}"

        if is_agreeable_matrix[idx]:
            rows.append(f"<b><green>T | {opinion} | {plan_block}</green></b>")
        else:
            rows.append(f"<n><red>F | {opinion} | {plan_block}</red></n>")

    return "\n".join(rows).strip()


def pretty_str_agreeable_plans_details(idx_of_agreeable_plans, plan_list, p_theta, score_matrix, sum_weighted_opinion_matrix):
    """Render a detailed report for each agreeable plan.

    For every index in ``idx_of_agreeable_plans`` the report lists the shared
    P(theta) values, that plan's per-preference scores ("Disadv"), its
    weighted sum, and the plan text indented by 4 spaces. All numbers are
    formatted with 2 decimals.

    Parameters:
        idx_of_agreeable_plans (sequence of int)
            Indices into ``plan_list`` / ``score_matrix`` /
            ``sum_weighted_opinion_matrix``.
        plan_list (sequence of str)
            Newline-separated text plans.
        p_theta (sequence of float)
            Preference probabilities, shown once per plan.
        score_matrix (sequence of sequence of float)
            Per-plan score rows.
        sum_weighted_opinion_matrix (sequence of float)
            Per-plan weighted sums.

    Returns:
        str: the full report, stripped.
    """
    def _two_decimals(values):
        return ", ".join(f"{value:.2f}" for value in values)

    sections = [f"<green>[[[[ Found {len(idx_of_agreeable_plans)} agreeable plans ]]]]</green>\n"]
    p_theta_text = _two_decimals(p_theta)

    for plan_idx in idx_of_agreeable_plans:
        scores_text = _two_decimals(score_matrix[plan_idx])
        plan_block = "\n".join("    " + line for line in plan_list[plan_idx].split("\n")).rstrip()
        sections.append(
            f"<green>Plan {plan_idx}</green>\n"
            f"P(theta):  {p_theta_text}\n"
            f"Disadv:    {scores_text}\n"
            f"Weighted sum:  {sum_weighted_opinion_matrix[plan_idx]:.2f}\n"
            f"{plan_block}\n\n"
        )

    return "".join(sections).strip()


def format_state_dictionary_to_str(state_dict):
    """Render a fridge state as indented JSON that lists only object names.

    ``state_dict`` maps each semantic location to a list of entries whose
    first element is the object name (the second, ignored here, is the
    object's x/y position). Example input::

        {
            "left side of top shelf": [
                ["bell pepper", [0.0, 0.67]],
                ["orange", [0.5, 0.67]]
            ],
            "right side of top shelf": [],
            "left side of middle shelf": [],
            "right side of middle shelf": [],
            "left side of bottom shelf": [],
            "right side of bottom shelf": []
        }

    Returns:
        str: ``json.dumps(..., indent=4)`` of location -> [object names].
    """
    names_by_location = {
        location: [entry[0] for entry in entries]
        for location, entries in state_dict.items()
    }
    return json.dumps(names_by_location, indent=4)

def format_demos_dict_for_gb(demos):
    """Render demonstrations as Markdown-style sections.

    Each demo dict must provide ``objects_to_put_away``, ``initial_state`` and
    ``final_state``; the two states are rendered with
    ``format_state_dictionary_to_str`` inside fenced code blocks. Sections are
    numbered from 1 and separated by a blank line.

    Parameters:
        demos (iterable of dict)
            The demonstrations to render.

    Returns:
        str: all sections concatenated and stripped ("" for no demos).
    """
    sections = []
    for number, demo in enumerate(demos, start=1):
        put_away = demo['objects_to_put_away']
        initial_text = format_state_dictionary_to_str(demo['initial_state'])
        final_text = format_state_dictionary_to_str(demo['final_state'])
        sections.append(
            f"# Demonstration {number}\n"
            f"## Objects that got put away: {put_away}\n"
            f"## Initial state of the fridge:\n```\n{initial_text}\n```\n"
            f"## Final state of the fridge:\n```\n{final_text}\n```\n"
            "\n"
        )
    return "".join(sections).strip()
