"""
Main orchestrator for inference pipeline for APRICOT project
"""
import os
import sys
import datetime
import copy

# os.environ["LOGURU_LEVEL"] = "INFO"

from loguru import logger

# When LOGURU_LEVEL is exactly "INFO", configure loguru with a single stderr
# handler whose format is only the message text.
if os.environ.get("LOGURU_LEVEL") == "INFO":
    fmt = "{message}"
    config = {"handlers": [{"sink": sys.stderr, "format": fmt}]}
    logger.configure(**config)


import argparse
import json
import yaml
import gym
import numpy as np

from utils import format_semantic_loc_for_llm, format_plan
from constants import ALL_POSSIBLE_OBJECTS_TO_PLACE_DICT

from general_prompt_builder.prompt_llm import prompt_llm, get_accumulated_cost
# from llms_for_planning.prompt_builder.prompt_llm import prompt_llm
from llms_for_planning.planners.geometric_feasibility.v0_no_llm_scoring import plan
from llms_for_planning.utils import fetch_messages as planner_fetch_message
from utils import *


class APRICOTPrefLearning(object):
    # Key used to pick this method's section out of the prompt config file
    # (init_llm_prompts) and the substring swapped for "NonInteractive" in the
    # progress-file path (inject_noninteractive_pref).
    NAME = "APRICOT"

    # NOTE(review): these look like ROS parameter / topic names; none of them is
    # referenced in the visible part of this file -- confirm against the ROS-facing code.
    OBJS_TO_PUT_AWAY_PARAM_NAME = "/perception/obj_to_put_away"
    OBJS_IN_FRIDGE_PARAM_NAME = "/perception/init_obj_in_fridge"
    STATE_TRACKER_TOPIC = "/state_tracker/fridgepcl"
    OBJECT_DETECTION_TOPIC = "/perception/detected_obj"
    ###### Fridge Range
    # NOTE(review): all dimensions below are presumably in meters -- TODO confirm.
    _BBOX_OFFSET = 0.02  # Define an offset around the fridge for safer planner output
    # Usable shelf width/depth: the raw size shrunk by the offset on both sides
    _SHELF_WIDTH = 0.42 - _BBOX_OFFSET * 2
    _SHELF_DEPTH = 0.24 - _BBOX_OFFSET * 2
    # Origin (Middle shelf's top left corner)
    O_X = -0.0275
    O_Y = 0.10
    O_Z = 0.0
    # Each shelf has a Y offset (added to O_Y to get the bbox's bottom_left_y)
    # and a usable height (raw height shrunk by the offset on both sides).
    # Top shelf
    _TOP_SHELF_Y_OFFSET = 0.0 + _BBOX_OFFSET  # Don't want to sample too close to the bottom of the shelf
    _TOP_SHELF_HEIGHT = 0.23 - _BBOX_OFFSET * 2  # Don't want to sample too close to the top of the shelf
    # Middle shelf
    _MIDDLE_SHELF_Y_OFFSET = -0.27 + _BBOX_OFFSET
    _MIDDLE_SHELF_HEIGHT = 0.27 - _BBOX_OFFSET * 2
    # Bottom shelf
    _BOTTOM_SHELF_Y_OFFSET = -0.42 + _BBOX_OFFSET
    _BOTTOM_SHELF_HEIGHT = 0.15 - _BBOX_OFFSET * 2
    # Maps a semantic location name to its bounding box.
    # The keys of this dict are the location names listed in the plan-generation prompt
    # (see generate_text_plan).
    # Bbox format (bottom_left_x, bottom_left_y, bottom_left_z, w, h, d)
    # "left side" boxes start at O_X and "right side" boxes start at O_X + half the
    # shelf width; both are half the shelf width wide.
    FRIDGE_LOCATION_BBOX = {
        # Top shelf locations
        "top shelf":                    (O_X,      O_Y + _TOP_SHELF_Y_OFFSET,    O_Z - _BBOX_OFFSET, _SHELF_WIDTH,       _TOP_SHELF_HEIGHT,      _SHELF_DEPTH),
        "left side of top shelf":            (O_X,      O_Y + _TOP_SHELF_Y_OFFSET,    O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _TOP_SHELF_HEIGHT,      _SHELF_DEPTH),
        "right side of top shelf":           (O_X + _SHELF_WIDTH/2.0,  O_Y + _TOP_SHELF_Y_OFFSET,    O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _TOP_SHELF_HEIGHT,      _SHELF_DEPTH),
        # Middle shelf locations
        "middle shelf":                    (O_X,      O_Y + _MIDDLE_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH,       _MIDDLE_SHELF_HEIGHT,   _SHELF_DEPTH),
        "left side of middle shelf":            (O_X,      O_Y + _MIDDLE_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _MIDDLE_SHELF_HEIGHT,   _SHELF_DEPTH),
        "right side of middle shelf":           (O_X + _SHELF_WIDTH/2.0,  O_Y + _MIDDLE_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _MIDDLE_SHELF_HEIGHT,   _SHELF_DEPTH),
        # Bottom shelf locations
        "bottom shelf":                    (O_X,      O_Y + _BOTTOM_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH,       _BOTTOM_SHELF_HEIGHT,   _SHELF_DEPTH),
        "left side of bottom shelf":            (O_X,      O_Y + _BOTTOM_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _BOTTOM_SHELF_HEIGHT,   _SHELF_DEPTH),
        "right side of bottom shelf":           (O_X + _SHELF_WIDTH/2.0,  O_Y + _BOTTOM_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _BOTTOM_SHELF_HEIGHT,   _SHELF_DEPTH),
    }

    """===================================================================================================

    Initialization

    ==================================================================================================="""
    

    def __init__(self,  seed,
                        prompt_config_path,
                        learn_pref_config = dict(
                            n_pref = 5,
                            not_agreeable_threshold = 0.1,
                            n_q_per_pair = 5,
                            max_step = 5,
                        ),
                        planner_config = dict(
                            num_plans = 10,  # Number of plans to generate
                            beam_size = 10,  # Size of the beams to maintain
                            num_samples = 10,  # Number of samples to take for each object placement
                        ),
                        sim_config = dict(
                            fps = 4, # Frames per second to render the GIF at
                            gif_path = None, # Path to save the GIF to; doesn't save if None
                        ),
                        debug=False,
                        use_ros=False,
                        use_ros_service=False,
                        batch_output=True
        ):
        """Initialize the APRICOTPrefLearning main function for inference

        Parameters:
            seed - stored on the instance as self.seed
            prompt_config_path (str) - path to the YAML prompt configuration, loaded by init_llm_prompts
            learn_pref_config (dict) - must contain the keys "n_pref", "not_agreeable_threshold",
                "n_q_per_pair" and "max_step"; each is unpacked into its own attribute
            planner_config (dict) - planner settings (num_plans, beam_size, num_samples), stored as-is
            sim_config (dict) - simulation rendering settings (fps, gif_path), stored as-is
            debug (bool) - when True, self.bp() drops into the debugger
            use_ros (bool) - stored on the instance
            use_ros_service (bool) - stored on the instance
            batch_output (bool) - when True, score_plan scores a whole plan with one LLM call

        Raises:
            KeyError - if learn_pref_config is missing one of the required keys
        """
        self.need_to_plan = True  # Control whether the planner needs to get replanned
        self.curr_state = {}

        self.debug = debug
        self.use_ros = use_ros
        self.use_ros_service = use_ros_service
        # The default config dicts in the signature are created once and shared by
        # every call. Store a shallow copy so that mutating one instance's config
        # cannot leak into other instances (or into the defaults themselves).
        self.planner_config = copy.copy(planner_config)
        self.sim_config = copy.copy(sim_config)
        self.seed = seed
        self.batch_output = batch_output

        self.progress_save_file = ""
        self.saved_progress_dict = {}
        self.logger_handler_id = None

        # Because these get used frequently during preference learning, we define individual variables
        self.n_preferences = learn_pref_config["n_pref"]
        self.not_agreeable_plan_threshold = learn_pref_config["not_agreeable_threshold"]
        self.n_q_per_pair = learn_pref_config["n_q_per_pair"]
        self.max_pref_learning_step = learn_pref_config["max_step"]

        self.init_llm_prompts(prompt_config_path)

    
    def bp(self):
        """Drop into the debugger when self.debug is truthy; otherwise do nothing."""
        if not self.debug:
            return

        # os.system('say "Hit a break point"')
        breakpoint()

    def init_llm_prompts(self, prompt_config_path):
        """
        Load the LLM prompt configuration and resolve the messages of every prompt

        Parameters:
            prompt_config_path (str) - path to a YAML file; only its top-level entry keyed by self.NAME is used

        Effect:
            Set self.prompt_config. For each prompt type, ["llm_config"]["messages"] is filled with the
            fetched messages, and the placeholders <num_preferences>, <num_preferences_minus_one> and
            <num_questions> in the second message (index 1) are substituted where applicable.
        """
        with open(prompt_config_path, "r") as fin:
            all_prompt_configs = yaml.safe_load(fin)

        self.prompt_config = all_prompt_configs[self.NAME]

        for prompt_type, entry in self.prompt_config.items():
            # The plan-generation prompt is fetched with the planner package's own fetcher
            if prompt_type == "generate_plan":
                fetch_fn = planner_fetch_message
            else:
                fetch_fn = fetch_messages

            messages = fetch_fn(entry["experiment_name"],
                                entry["prompt_description"],
                                entry["prompt_version"])
            entry["llm_config"]["messages"] = messages

            if prompt_type == "generate_preferences":
                # One slot is reserved for the injected non-interactive preference
                n_to_generate = self.n_preferences
                if self.prompt_config["generate_preferences"]["inject_noninteractive"]:
                    n_to_generate -= 1

                content = messages[1]["content"]
                content = content.replace("<num_preferences>", str(n_to_generate))
                content = content.replace("<num_preferences_minus_one>", str(n_to_generate - 1))
                messages[1]["content"] = content
            elif prompt_type == "generate_questions":
                messages[1]["content"] = messages[1]["content"].replace("<num_questions>", str(self.n_q_per_pair))
            

    """===================================================================================================

    Generating preferences

    ==================================================================================================="""

    def inject_noninteractive_pref(self):
        # Load the saved progress
        noninteractive_progress_save_file = self.progress_save_file.replace(self.NAME, "NonInteractive")

        with open(noninteractive_progress_save_file, "r") as fin:
            noninteractive_progress_save_dict = json.load(fin)

        return noninteractive_progress_save_dict["preference"]

    def generate_preferences(self, demos):
        """
        Produce the list of candidate preferences from the demonstrations

        The work is checkpointed in three stages through load_progress / save_progress:
        "reasoning_worksheet" (LLM analysis of the demos), "preference_writing_reasoning"
        (raw LLM output containing the preferences) and "preference_list" (the parsed result).
        A stage that can be loaded is not recomputed.

        If the LLM output cannot be parsed as JSON, the operator is asked to type the
        preferences in manually (one per prompt, empty line to finish).

        Parameters:
            demos - passed as the user prompt to the "analyze_demos" LLM call

        Return:
            The preference list (parsed JSON, or the manually entered strings), with the
            non-interactive preference appended when "inject_noninteractive" is enabled
        """
        can_load, value = self.load_progress("preference_list")

        if can_load:
            preference_list = value
        else:
            can_load, value = self.load_progress("reasoning_worksheet")

            if can_load:
                reasoning_worksheet = value
            else:
                logger.debug("------- Analyzing demonstrations ...")
                reasoning_worksheet = prompt_llm(demos, **self.prompt_config["analyze_demos"]["llm_config"])

                # Every category section is expected in the worksheet; stop for a human check if one is missing
                for category_section in ["(Fruits)", "(Vegetables)", "(Juice-and-soft-drinks)", "(Dairy-Products)", "(Condiments)"]:
                    if category_section not in reasoning_worksheet:
                        logger.debug(f"{category_section} is not in the reasoning worksheet")

                        os.system('say "Hello, please verify reasoning worksheet!"')
                        breakpoint()

                logger.debug(f"\n{reasoning_worksheet}")

                self.save_progress("reasoning_worksheet", reasoning_worksheet)

            self.bp()

            can_load, value = self.load_progress("preference_writing_reasoning")

            if can_load:
                output = value
            else:
                logger.debug("------- Generating preferences ...")
                # user_prompt = f"Input-Demonstrations:\n{demos}\n\nInput-Reasoning-Worksheet:\n{reasoning_worksheet}"
                user_prompt = reasoning_worksheet

                logger.debug(f"\n{user_prompt}")
                output = prompt_llm(user_prompt, **self.prompt_config["generate_preferences"]["llm_config"])

                self.save_progress("preference_writing_reasoning", output)

            logger.debug(f"\n{output}")

            # Keep only what follows the first code fence ("```json" preferred over a bare "```")
            if "```" in output:
                partition_keywords = "```json" if "```json" in output else "```"

                _, _, preferences = output.partition(partition_keywords)
                preferences = preferences.replace("```", "")  # Remove the second half of ```
                preferences = preferences.strip()
            else:
                preferences = output

            try:
                preference_list = json.loads(preferences)
            except (ValueError, TypeError):
                # json.JSONDecodeError is a ValueError; TypeError covers a non-string output.
                logger.info("Failed to load the answer as a json")

                logger.info(output)

                os.system('say "Hello, something went wrong. Please check!"')

                # The list must exist before the manual entries are appended; without this
                # the first append raised UnboundLocalError.
                preference_list = []

                while True:
                    preference = input("Please manually copy the preference here (without quotations): ")

                    if preference == "":
                        break

                    preference_list.append(preference.strip())

                breakpoint()

            if self.prompt_config["generate_preferences"]["inject_noninteractive"]:
                # For Active Preference Learning experiment where we add the Non-Interactive Preference as a candidate
                noninteractive_preference = self.inject_noninteractive_pref()
                preference_list.append(noninteractive_preference)

            self.save_progress("preference_list", preference_list)

            # os.system('say "Hello, please verify demonstration!"')
            # breakpoint()

        self.bp()

        return preference_list


    """===================================================================================================

    Generate text plans from preferences

    ==================================================================================================="""

    def generate_plans_from_preferences(self, preference_list, initial_state, objs_to_put_away):
        """
        Generate one text plan per candidate preference (or reload them from saved progress)

        Parameters:
            preference_list (list) - candidate preferences; one plan is generated for each, in order
            initial_state - forwarded to generate_text_plan as init_state
            objs_to_put_away - forwarded to generate_text_plan

        Return:
            (list) A list of string. The text plan is a string of code
        """
        can_load, value = self.load_progress("text_plan_list")

        if not can_load:
            value = [
                self.generate_text_plan(objs_to_put_away=objs_to_put_away,
                                        init_state=initial_state,
                                        preference=candidate)
                for candidate in preference_list
            ]
            self.save_progress("text_plan_list", value)

        text_plan_list = value

        self.bp()

        plan_str = "\n======================\n".join(text_plan_list)
        logger.debug(f"Generated Text Plans:\n{plan_str}")

        return text_plan_list

    def generate_text_plan(self, objs_to_put_away, init_state, preference):
        """
        Ask the LLM for a placement plan that follows one preference

        Parameters:
            objs_to_put_away - objects listed in the prompt
            init_state - initial state listed in the prompt
            preference - preference text, quoted in the prompt

        Return:
            (str) the LLM output from its first "pickandplace" onwards, stripped of surrounding whitespace

        Raises:
            AssertionError - if the extracted code contains a "#"
        """
        location_names = list(self.FRIDGE_LOCATION_BBOX)
        user_prompt = f"""
        Objects: {objs_to_put_away}
        Locations: {location_names}
        Initial State: {init_state}
        Preference: "{preference}"
        """
        logger.debug("------- Generating a plan ...")
        logger.debug(user_prompt)
        llm_output = prompt_llm(user_prompt, **self.prompt_config["generate_plan"]["llm_config"])

        logger.debug(f"\n{llm_output}")

        # Drop everything before the first 'pickandplace', keeping the keyword itself
        code = "pickandplace" + llm_output.partition("pickandplace")[2]

        assert "#" not in code, f"Should not have more comments in the code:\n{code}"

        return code.strip()
    

    """===================================================================================================

    Score plans based on preferences and initial condition

    ==================================================================================================="""

    def score_plans_on_preferences(self, preference_list, plan_list, initial_state):
        """
        Score every plan under every preference (or reload the scores from saved progress)

        Parameters:
            preference_list (list) - candidate preferences
            plan_list (list) - text plans
            initial_state - forwarded to score_plan

        Return:
            (np.array) matrix of shape (len(plan_list), len(preference_list)) where entry [i][j]
            is the score of plan i under preference j. This is N x N when there is one plan per
            preference candidate.
        """
        can_load, value = self.load_progress("score_matrix")

        if can_load:
            score_matrix = np.array(value)
        else:
            # Rows are plans and columns are preferences, matching the [i][j] indexing below.
            # The previous (preferences, plans) allocation only worked when both lengths were equal.
            score_matrix = np.zeros((len(plan_list), len(preference_list)))
            for j, preference in enumerate(preference_list):
                for i, plan in enumerate(plan_list):
                    score_matrix[i][j] = self.score_plan(preference, initial_state, plan)

            self.save_progress("score_matrix", score_matrix.tolist())  # Need to convert back to list because json does not take objects

        self.bp()
        return score_matrix


    def score_plan(self, preference, initial_state, plan, debug=False):
        """
        Score how well one text plan satisfies one preference

        The score is (number of plan lines judged to satisfy the preference) / (number of plan lines).
        Three paths exist:
            * debug=True: a human types the number of correctly placed objects.
            * self.batch_output: one "score_plan_batch" LLM call judges the whole plan; the reply is
              expected to be a JSON list whose items carry the key
              "Does this plan satisfy your preference (yes/no)".
            * otherwise: one "score_plan" LLM call per non-empty plan line; each reply is expected to be
              a JSON object with that same key.
        When a reply cannot be parsed, the operator is asked to enter the 0/1 scores by hand and the
        code then stops at a breakpoint.

        Parameters:
            preference - preference text inserted into the prompt
            initial_state - objects initially in the fridge, inserted into the prompt
            plan (str) - text plan, one object placement per line
            debug (bool) - use the manual scoring path instead of the LLM

        Return:
            (float) fraction of plan lines that satisfy the preference

        NOTE(review): the denominator counts every newline-separated line of the plan, including
        blank ones, while the per-line path skips blank lines when scoring -- confirm that plans
        never contain blank lines.
        """
        if debug:
            # Manual scoring: note the line count here is taken from the un-stripped plan
            print(f"Preference:\n{preference}\nPlan:\n{plan}")
            sum_obj_score = int(input("What is the total number of object correctly placed?: "))
            num_of_object = len(plan.split("\n"))
            logger.debug(f"Score: {sum_obj_score / num_of_object:.2f}")
            return sum_obj_score / num_of_object 
        else:
            # Denominator of the score: number of lines in the stripped plan
            object_plan_list = plan.strip().split("\n")
            num_of_object = len(object_plan_list)
            
            if self.batch_output:
                clean_plan = plan.strip()
                # clean_plan = clean_plan.replace('"', "'")  # Replace double quote with single quote

                user_prompt = f"""
                # Your preference
                {preference}
                
                # Objects already initially in the fridge
                ```
                {initial_state}
                ```

                # Object placement plan
                ```
                {clean_plan}
                ```
                """

                logger.debug("------- Scoring questions (in batch) ...")
                logger.debug(f"{user_prompt}")

                response = prompt_llm(user_prompt, **self.prompt_config["score_plan_batch"]["llm_config"])

                logger.debug(f"{response}")

                # Keep only what follows the first code fence ("```json" preferred over a bare "```")
                if "```json" in response or "```" in response:
                    partition_keywords = "```json" if "```json" in response else "```"

                    _, _, response = response.partition(partition_keywords)
                    response = response.replace("```", "")  # Remove the second half of ```
                    response = response.strip()

                try:
                    response = json.loads(response)
                    
                    per_obj_answer_list = [o["Does this plan satisfy your preference (yes/no)"].lower().strip() for o in response]

                    # Any answer containing "yes" counts as satisfied; everything else counts as 0
                    per_obj_score_list = [1 if "yes" in a else 0 for a in per_obj_answer_list]
                except:
                    # NOTE(review): bare except -- this also catches KeyboardInterrupt/SystemExit
                    logger.info("Failed to load the answer as a json")
                    logger.info(f"\n{response}")

                    os.system('say "Hello, something went wrong. Please check!"')

                    per_obj_score_list = []

                    # Manual fallback: one integer per prompt, empty input ends the entry
                    while True:
                        print(f"Current {per_obj_score_list=}")
                        answer = input("Manually enter 0/1 based on what is printed out: ")

                        if answer == "":
                            break
                        else:
                            try:
                                per_obj_score_list.append(int(answer.strip()))
                            except:
                                print("The response that you wrote is invalid, please try again. You must only write 0 or 1, nothing else.")

                    breakpoint()
            else:
                object_plan_list = plan.strip().split("\n")
                per_obj_score_list = []

                # One LLM call per non-empty plan line
                for object_plan in object_plan_list:
                    if object_plan != "":
                        user_prompt = f"""
                        # Your preference
                        {preference}
                        
                        # Objects already initially in the fridge
                        ```
                        {initial_state}
                        ```

                        # Object placement plan
                        ```
                        {object_plan}
                        ```
                        """
                        logger.debug("------- Scoring questions ...")
                        logger.debug(f"{user_prompt}")
                        response = prompt_llm(user_prompt, **self.prompt_config["score_plan"]["llm_config"])

                        logger.debug(f"{response}")

                        try:
                            response = json.loads(response)

                            # A failed assert is caught below and triggers the manual fallback
                            assert "Does this plan satisfy your preference (yes/no)" in response
                            answer = response["Does this plan satisfy your preference (yes/no)"].lower().strip()
                            assert "yes" in answer or "no" in answer

                            if "yes" in answer:
                                per_obj_score_list.append(1)
                            else:
                                per_obj_score_list.append(0)
                        except:
                            logger.info("Failed to load the answer as a json")
                            logger.info(response)

                            os.system('say "Hello, something went wrong. Please check!"')

                            # NOTE(review): unlike the batch path, a non-integer entry here raises ValueError
                            answer = input("Manually enter 0/1 based on what is printed out: ")
                            per_obj_score_list.append(int(answer))

                            breakpoint()

            logger.debug(f"\nInitial Condition:\n{initial_state}\nPreference:\n{preference}\nPlan:\n{plan}\n\n{np.sum(per_obj_score_list) / num_of_object:.2f}: {per_obj_score_list}")
                
            # self.bp()
            return np.sum(per_obj_score_list) / num_of_object  # Fraction of plan lines that satisfy the preference
    
    
    """===================================================================================================

    Check if there exists a plan that every preference agrees on

    ==================================================================================================="""
    def check_exist_agreeable_plan(self, score_matrix, plans_list):
        """
        Check whether some plan is acceptable under the current belief self.p_theta

        For each plan (row), the shortfall of its score from the best score in each column is
        weighted by self.p_theta and summed. A plan is agreeable when that sum is below
        self.not_agreeable_plan_threshold.

        Parameters:
            score_matrix (np.array) - N x N (where N is self.n_preferences)
            plans_list (list) - the plans; only used for the pretty-printed log output

        Return:
           (int) - number of agreeable plans (truthy if at least one exists)
           (int) - index of the plan with the smallest weighted disadvantage sum
        """
        n = self.n_preferences

        logger.debug(f"\nscore_matrix\n{score_matrix}")

        # Best score in each column, stacked n times so it lines up with score_matrix
        optimal_score = np.max(score_matrix, axis=0).reshape(1, n)
        stacked_optimal = np.tile(optimal_score, (n, 1))

        logger.debug(f"\noptimal_score\n{optimal_score}")

        # Disadvantage = (optimal score - current plan score); higher means worse
        # NOTE(review): a constant 0.01 is added to every entry; the rationale is not visible here
        disadvantage_matrix = stacked_optimal - score_matrix
        disadvantage_matrix += 0.01

        logger.debug(f"\ndisadvantage_matrix\n{disadvantage_matrix}")

        # Every row of the disadvantage matrix is multiplied element-wise by P(theta)
        weights = np.tile(self.p_theta.reshape(1, n), (n, 1))
        logger.debug(f"\nweights\n{weights}")

        weighted_opinion_matrix = weights * disadvantage_matrix

        # One total per row, shape = (N, )
        sum_weighted_opinion_matrix = np.sum(weighted_opinion_matrix, axis=1)

        logger.debug(f"\nweighted_opinion_matrix\n{weighted_opinion_matrix}\n{sum_weighted_opinion_matrix=}")

        # shape = (N, ); True where the total disadvantage stays below the threshold
        is_agreeable_matrix = sum_weighted_opinion_matrix < self.not_agreeable_plan_threshold

        logger.opt(colors=True).info(f"\n{pretty_str_plans_with_agreeableness(plans_list, sum_weighted_opinion_matrix, is_agreeable_matrix)}")

        idx_of_agreeable_plans = np.where(is_agreeable_matrix)[0]
        exist_agreeable_plan = idx_of_agreeable_plans.shape[0]

        if exist_agreeable_plan > 0:
            logger.opt(colors=True).info(f"\n{pretty_str_agreeable_plans_details(idx_of_agreeable_plans, plans_list, self.p_theta, disadvantage_matrix, sum_weighted_opinion_matrix)}\n")

        idx_to_best_plan = sum_weighted_opinion_matrix.argmin()

        return exist_agreeable_plan, idx_to_best_plan
    
    """===================================================================================================

    Generate set of questions Q that we can pick and ask

    ==================================================================================================="""

    def generate_all_pairwise_questions(self, preference_list, initial_state, objs_to_put_away, plan_list, score_matrix):
        """
        Generate questions for every unordered pair of preference candidates
        (or reload them from saved progress)

        Parameters:
            preference_list (list) - candidate preferences
            initial_state - forwarded to generate_questions
            objs_to_put_away - forwarded to generate_questions
            plan_list (list) - plan_list[k] is the plan generated from preference_list[k]
            score_matrix - indexed as score_matrix[plan_idx][pref_idx]

        Return
            (list) A flat list of [[i, j], question] entries, where i < j are the indices of the
            preference pair the question was generated for
        """

        can_load, value = self.load_progress("question_list")

        if can_load:
            question_list = value
        else:
            question_list = []
            n_candidates = len(preference_list)

            for i in range(n_candidates):
                for j in range(i + 1, n_candidates):
                    pref_i = preference_list[i]
                    pref_j = preference_list[j]

                    plan_i = plan_list[i]
                    # Previously plan_list[i], which showed the LLM plan i under both headings
                    plan_j = plan_list[j]

                    pref_i_on_plan_i_score = score_matrix[i][i]
                    pref_j_on_plan_i_score = score_matrix[i][j]

                    pref_i_on_plan_j_score = score_matrix[j][i]
                    pref_j_on_plan_j_score = score_matrix[j][j]

                    raw_questions = self.generate_questions(pref_i, pref_j,
                                                            initial_state, objs_to_put_away,
                                                            plan_i, plan_j,
                                                            pref_i_on_plan_i_score, pref_j_on_plan_i_score,
                                                            pref_i_on_plan_j_score, pref_j_on_plan_j_score)

                    # Tag each question with the pair it came from
                    question_list.extend([[i, j], q] for q in raw_questions)

            self.save_progress("question_list", copy.deepcopy(question_list))

        self.bp()

        return question_list

    def generate_questions(self, pref_i, pref_j, 
                           initial_state, objs_to_put_away, 
                           plan_i, plan_j,
                           pref_i_on_plan_i_score, pref_j_on_plan_i_score,
                           pref_i_on_plan_j_score, pref_j_on_plan_j_score):
        """
        Ask the LLM for questions about one pair of preference candidates

        The prompt shows both preferences, the plan derived from each, and the score each
        preference gave to each plan.

        Return
            (list) a list of string; the lines after "# Questions" in the reply that contain a "?"

        Raises:
            AssertionError - if the reply has no "# Questions" section
        """
        user_prompt = f"""
        # New Initial Condition To Solve
        ## Objects initially in the fridge
        ```json
        {initial_state}
        ```

        ## Objects that must be put away into the fridge
        ```json
        {objs_to_put_away}
        ```

        # Preference Candidate 1
        {pref_i}

        ## Plan based on Preference Candidate 1
        {plan_i}

        ### Preference Candidate 1's Score on Plan 1
        {pref_i_on_plan_i_score}
        ### Preference Candidate 2's Score on Plan 1
        {pref_j_on_plan_i_score}

        # Preference Candidate 2
        {pref_j}

        ## Plan based on Preference Candidate 2
        {plan_j}

        ### Preference Candidate 1's Score on Plan 2
        {pref_i_on_plan_j_score}
        ### Preference Candidate 2's Score on Plan 2
        {pref_j_on_plan_j_score}
        """
        logger.debug("------- Generating questions ...")
        logger.debug(f"\n{user_prompt}")

        llm_config = self.prompt_config["generate_questions"]["llm_config"]
        response = prompt_llm(user_prompt, **llm_config)

        logger.debug(f"\n{response}")

        assert "# Questions" in response

        _, _, questions_section = response.partition("# Questions")

        # Lines without a question mark are not questions, so they are dropped
        question_list = [line for line in questions_section.strip().split("\n") if "?" in line]

        logger.debug(f"\n{question_list=}")

        # self.bp()

        return question_list

    # def generate_questions(self, pref_i, pref_j):
    #     """
    #     Return
    #         (list) a list of string, where each string is the question
    #     """
        # user_prompt = f"""
        # # Preference Candidate 1
        # {pref_i}
        
        # # Preference Candidate 2
        # {pref_j}
        # """
    #     logger.debug("------- Generating questions ...")
    #     logger.debug(f"\n{user_prompt}")
    #     response = prompt_llm(user_prompt, **self.prompt_config["generate_questions"]["llm_config"])

    #     logger.debug(f"\n{response}")

    #     assert "# Questions" in response
            
    #     question_list = response.partition("# Questions")[2].strip().split("\n")

    #     # Filter out times when the LLM does not generate the question
    #     question_list = [q for q in question_list if "?" in q]

    #     logger.debug(f"\n{question_list=}")

    #     # self.bp()

    #     return question_list  # Remove uncessary whitespace
    

    """===================================================================================================

    Calculate P(o=+ | theta, q) and P(o=- | theta, q) for each question q in set of questions Q

    ==================================================================================================="""
    
    def get_likelihoods(self, question_list, preference_list):
        can_load, value = self.load_progress("pos_ans_likelihood_matrix")

        if can_load:
            pos_ans_likelihood_matrix = np.array(value)
        else:

            pos_ans_likelihood_matrix = np.zeros((len(question_list), len(preference_list)))

            for i in range(len(question_list)):
                for j in range(len(preference_list)):
                    _, q = question_list[i]
                    theta = preference_list[j]

                    pos_ans_likelihood_matrix[i][j] = self.calculate_likelihood(q, theta)

            self.save_progress("pos_ans_likelihood_matrix", copy.deepcopy(pos_ans_likelihood_matrix.tolist()))

        neg_ans_likelihood_matrix = 1 - pos_ans_likelihood_matrix

        self.bp()

        return pos_ans_likelihood_matrix, neg_ans_likelihood_matrix

    def score_questions(self, pos_ans_likelihood_matrix, neg_ans_likelihood_matrix):
        """
        Parameters:
            pos_ans_likelihood_matrix - numpy array, shape = Q x N, P(o=yes | q, theta)
                Q = number of questions, N = number of preferences
            neg_ans_likelihood_matrix - numpy array, shape = Q x N, P(o=no | q, theta)

        What is the score that we are calculating?
            H[P(theta)] - E_o H[P(theta | o, q)]
            = H[P(theta)] - sum_of ans in {yes, no} ( P(o=ans | q) H[P(theta | o=ans, q)] )

            Since there are only 2 possible answers, we can concretely write out
            E_o H[P(theta | o, q)]
                = sum_of ans in {yes, no} ( P(o=ans | q) H[P(theta | o=ans, q)] )
                = P(o=yes | q)H[P(theta | o=yes, q)] + P(o=no | q)H[P(theta | o=no, q)]

            The pieces that we need are:
            * P(o=yes | q) and P(o=no | q)
                Expected shape = (Q x 1)

                Equestion: 
                    P(o=yes | q) = E_theta P(o=yes | q, theta)
                                = P(o=yes | q, theta) P(theta)

                    To double check P(o=yes | q, theta) is Q x N and P(theta) is N x 1,
                        so matmul will give us Q x 1

            * P(theta | o=yes, q) and P(theta | o=no, q) Posteriors
                Expected shape = (Q x N)

                Equestion:
                    P(theta | o=yes, q) = (P(o=yes | q, theta) P(theta)) / P(o=yes, q)

                    Within this, 
                        P(o=yes, q) can be treated as an normalizing factor
                        We need to stack P(theta) Q times to get Q x N shape
                            Each row of P(o=yes | q, theta) represents: given a specific question, what is 
                            the likelihood that a preference would answer yes. 
                            We want to scale each of those terms by the probability of that specific preference.

        Return:
            Score matrix: H[P(theta)] - E_o H[P(theta | o, q)]
                shape: Q x N
        """
        num_questions = pos_ans_likelihood_matrix.shape[0]

        # 1. Calculate P(o=yes | q) and P(o=no | q)
        pos_ans_given_q_vector = np.matmul(pos_ans_likelihood_matrix, self.p_theta.reshape(self.n_preferences, 1))  # (Q x N) * (N x 1) = Q x 1
        neg_ans_given_q_vector = np.matmul(neg_ans_likelihood_matrix, self.p_theta.reshape(self.n_preferences, 1))

        # logger.debug(f"\n{pos_ans_given_q_vector.shape}: {pos_ans_given_q_vector=}\n{neg_ans_given_q_vector.shape}: {neg_ans_given_q_vector=}")

        # 2. Calculate posteriors P(theta | o=yes, q) and P(theta | o=no, q)
        p_theta_matrix = np.repeat(self.p_theta.reshape(1, self.n_preferences), num_questions, axis=0)

        pos_ans_posterior_matrix = pos_ans_likelihood_matrix * p_theta_matrix  # Q x N
        neg_ans_posterior_matrix = neg_ans_likelihood_matrix * p_theta_matrix

        # logger.debug(f"\nBefore:\n{pos_ans_posterior_matrix.shape}: {pos_ans_posterior_matrix=}\n{neg_ans_posterior_matrix.shape}: {neg_ans_posterior_matrix=}")

        # Normalize P(theta | q, o) for a given q
        #   In other words, the sum of a row needs to be 1
        pos_ans_rowsum = np.sum(pos_ans_posterior_matrix, axis=1).reshape(num_questions, 1)
        pos_ans_rowsum_matrix = np.repeat(pos_ans_rowsum, self.n_preferences, axis=1)
        # logger.debug(f"\n{pos_ans_rowsum_matrix=}")
        pos_ans_posterior_matrix = pos_ans_posterior_matrix / pos_ans_rowsum_matrix

        neg_ans_rowsum = np.sum(neg_ans_posterior_matrix, axis=1).reshape(num_questions, 1)
        neg_ans_rowsum_matrix = np.repeat(neg_ans_rowsum, self.n_preferences, axis=1)
        # logger.debug(f"\n{neg_ans_rowsum=}")
        neg_ans_posterior_matrix = neg_ans_posterior_matrix / neg_ans_rowsum_matrix

        # logger.debug(f"\n{pos_ans_posterior_matrix.shape}: {pos_ans_posterior_matrix=}\n{neg_ans_posterior_matrix.shape}: {neg_ans_posterior_matrix=}")

        # 3. Calculate the score
        pos_ans_given_q_matrix = np.repeat(pos_ans_given_q_vector, self.n_preferences, axis=1)  # Q x N
        neg_ans_given_q_matrix = np.repeat(neg_ans_given_q_vector, self.n_preferences, axis=1)  # Q x N

        # logger.debug(f"\n{pos_ans_given_q_matrix.shape}: {pos_ans_given_q_matrix=}\n{neg_ans_given_q_matrix.shape}: {neg_ans_given_q_matrix=}")

        # Entropy H[P(x)] = sum_x - P(x) log(P(x))
        h_p_theta_matrix = - p_theta_matrix * np.log(p_theta_matrix)  # Q x N
        h_pos_ans_posterior_matrix = - pos_ans_posterior_matrix * np.log(pos_ans_posterior_matrix)  # Q x N
        h_neg_ans_posterior_matrix = - neg_ans_posterior_matrix * np.log(neg_ans_posterior_matrix)  # Q x N

        # logger.debug(f"\n{h_p_theta_matrix.shape}: {h_p_theta_matrix=}\n{h_pos_ans_posterior_matrix.shape}: {h_pos_ans_posterior_matrix=}\n{h_neg_ans_posterior_matrix.shape}: {h_neg_ans_posterior_matrix=}")

        # H[P(theta)] - sum_of ans in {yes, no} ( P(o=ans | q) H[P(theta | o=ans, q)] )
        scores = h_p_theta_matrix - (pos_ans_given_q_matrix * h_pos_ans_posterior_matrix + neg_ans_given_q_matrix * h_neg_ans_posterior_matrix)

        # logger.debug(f"\n{(pos_ans_given_q_matrix * h_pos_ans_posterior_matrix)=}\n{(neg_ans_given_q_matrix * h_neg_ans_posterior_matrix)=}")
        # logger.debug(f"\n{scores.shape}: {scores}\nSum:\n{scores.sum(axis=1)}")

        # self.bp()

        return scores, pos_ans_posterior_matrix, neg_ans_posterior_matrix
    

    def calculate_likelihood(self, q, theta):
        """Return the likelihood of a "yes" answer for question `q` under preference `theta`.

        The positive and negative answer values come from `self.get_answer_probs(q, theta)`
        and are combined with a two-way softmax:

            exp(pos) / (exp(pos) + exp(neg))
        """
        yes_value, no_value, _ = self.get_answer_probs(q, theta)

        # Exponentiate each value once, then normalize over the two possible answers
        yes_weight = np.exp(yes_value)
        no_weight = np.exp(no_value)

        return yes_weight / (yes_weight + no_weight)
    

    def get_answer_probs(self, q, theta):
        """Ask the LLM to answer question `q` on behalf of a user holding preference `theta`.

        The prompt is sent with the "answer_questions" LLM config. The response is
        expected to be a JSON object with an "Answer (yes/no)" entry; a few common
        single-quote mistakes are patched before parsing.

        Args:
            q: question text inserted into the prompt.
            theta: preference text inserted into the prompt.

        Returns:
            `(positive_ans_prob, negative_ans_prob, answer)` where `answer` is the
            lower-cased, stripped answer string and the two probabilities are
            `get_answer_prob(logprobs)` and its complement, assigned according to
            whether the answer contains "yes".

            If the response cannot be parsed or validated, the failure is logged, an
            audible alert is attempted, and `breakpoint()` is entered; once execution
            continues the method returns None, so callers that unpack the result
            will fail.
        """
        user_prompt = f"""
        # Your preference
        {theta}

        # Questions for you to answer
        {q}
        """
        logger.debug("------- Answering questions ...")
        logger.debug(f"{user_prompt}")
        response, logprobs = prompt_llm(user_prompt, **self.prompt_config["answer_questions"]["llm_config"])
        logger.debug(f"{response}")
        logger.info(f"Accumulated Cost: ${get_accumulated_cost()}")

        # Try to reduce json loading error
        response = response.replace("'Reasoning'", '"Reasoning"')
        # The search string must include the closing quote, otherwise a stray `'`
        # is left behind and the JSON is still invalid.
        response = response.replace("'Answer (yes/no)'", '"Answer (yes/no)"')
        response = response.replace(": 'yes'", ': "yes"')
        response = response.replace(": 'no'", ': "no"')

        try:
            response = json.loads(response)

            # Explicit checks instead of `assert`, which is removed under `python -O`
            if "Answer (yes/no)" not in response:
                raise ValueError('Response is missing the "Answer (yes/no)" field')
            answer = response["Answer (yes/no)"].lower().strip()
            if "yes" not in answer and "no" not in answer:
                raise ValueError(f"Answer does not contain yes or no: {answer}")

            answer_prob = get_answer_prob(logprobs)

            if "yes" in answer:
                positive_ans_prob = answer_prob
                negative_ans_prob = 1 - answer_prob
            else:
                positive_ans_prob = 1 - answer_prob
                negative_ans_prob = answer_prob

            return positive_ans_prob, negative_ans_prob, answer
        except Exception:
            # `Exception` rather than a bare `except:` so Ctrl-C / SystemExit still propagate
            logger.debug("Failed to load the answer as a json")
            logger.debug(f"\n{response}")

            # Audible alert via the `say` command so an operator notices the stall
            os.system('say "Hello, something went wrong. Please check!"')

            breakpoint()


    def query_user(self, question_tuple, gt_pref = ""):
        """Get a yes/no answer to a question, from a human or from the ground-truth preference.

        Args:
            question_tuple: a 2-tuple `(preference_pair, question)`; only `question` is used here.
            gt_pref: ground-truth preference text. When empty, the answer is read from
                stdin (re-prompting until it contains "yes" or "no"). Otherwise the
                answer comes from `self.get_answer_probs(q=question, theta=gt_pref)`.

        Returns:
            1 if the answer contains "yes", otherwise 0 if it contains "no".

        Raises:
            AssertionError: if the answer derived from `gt_pref` contains neither
                "yes" nor "no".
        """
        _, question = question_tuple

        logger.info(f"\nQuestion:\n{question}")

        if not gt_pref:
            # If there is no gt_preferences written out, assume that we have to query a real user
            while True:
                user_answer = input("Please only answer 'yes' or 'no': ").lower()

                # "yes" is checked first, matching the precedence of the answer parsing below
                if "yes" in user_answer:
                    return 1
                if "no" in user_answer:
                    return 0

                print("You have entered an invalid response. Please only answer in 'yes' or 'no'")

        _, _, answer = self.get_answer_probs(q = question, theta = gt_pref)

        if "yes" in answer:
            return 1
        if "no" in answer:
            return 0

        # Raised explicitly: `assert False` would be stripped under `python -O`
        # and the method would then silently return None.
        raise AssertionError(f"{answer} does not contain yes or no")
    
    def init_progress_save_file(self, progress_save_file):
        """Attach a JSON progress file to this instance, creating it when missing.

        Does nothing when `progress_save_file` is falsy. Otherwise the path is stored
        on `self.progress_save_file`, and `self.saved_progress_dict` is either loaded
        from the existing file or started as an empty dict that is written out to a
        newly created file.
        """
        if not progress_save_file:
            return

        self.progress_save_file = progress_save_file

        if os.path.exists(self.progress_save_file):
            # Resume from whatever was saved previously
            with open(self.progress_save_file, "r") as existing:
                self.saved_progress_dict = json.load(existing)
            return

        # No file yet: start empty and create the file on disk
        self.saved_progress_dict = {}
        with open(self.progress_save_file, "w") as created:
            json.dump(self.saved_progress_dict, created, indent=4)

    def reset_progress_save_file(self):
        """Write the in-memory progress to the progress file, then clear the tracked state.

        After the dump, `self.progress_save_file` is set to an empty list and
        `self.saved_progress_dict` to an empty dict.
        """
        with open(self.progress_save_file, "w") as sink:
            json.dump(self.saved_progress_dict, sink, indent=4)

        # Forget both the file path and the cached progress
        self.progress_save_file, self.saved_progress_dict = [], {}

    def load_progress(self, key):
        """Look up `key` in the saved progress.

        Returns:
            `(False, None)` when the key has not been saved. Otherwise `(True, value)`:
            for "reasoning_worksheet" and "preference_writing_reasoning" the stored
            entry is a file path and `value` is that file's text; for every other key
            `value` is a deep copy of the stored entry.
        """
        if key not in self.saved_progress_dict:
            return False, None

        stored = self.saved_progress_dict[key]

        if key in ("reasoning_worksheet", "preference_writing_reasoning"):
            # These entries point at a text file instead of holding the text inline
            with open(stored, "r") as worksheet_file:
                return True, worksheet_file.read()

        # Hand back a copy so callers cannot mutate the cached progress
        return True, copy.deepcopy(stored)
        
    def save_progress(self, key, value):
        """Record `value` under `key` and rewrite the JSON progress file.

        Three kinds of keys are handled:

        * "reasoning_worksheet" / "preference_writing_reasoning": `value` (text) is
          written to a sibling markdown file next to the progress file, and the path
          of that markdown file is what gets stored under `key`.
        * Per-run results ("p_theta", "idx_to_best_plan", ...): a deep copy of `value`
          is stored in a nested dict keyed by `str(self.not_agreeable_plan_threshold)`.
        * Anything else: `value` is stored directly under `key`.

        The whole `self.saved_progress_dict` is then dumped to `self.progress_save_file`.
        """
        markdown_suffixes = {
            "reasoning_worksheet": "_reasoning.md",
            "preference_writing_reasoning": "_preference_writing.md",
        }
        per_run_keys = (
            "p_theta",
            "idx_to_best_plan",
            "best_pref_idx",
            "query_steps",
            "plan_score_given_best_pref",
            "gt_score_best_plan",
            "gt_score_best_preference_plan",
        )

        if key in markdown_suffixes:
            # Swap only the file extension. A plain `.replace(".json", ...)` would also
            # rewrite ".json" inside directory names, and would leave the path unchanged
            # (so the markdown would overwrite the progress file) when there is no ".json".
            path_root, _ = os.path.splitext(str(self.progress_save_file))
            fpath = path_root + markdown_suffixes[key]

            with open(fpath, "w") as fout:
                fout.write(value)

            self.saved_progress_dict[key] = fpath
        elif key in per_run_keys:
            run_key = str(self.not_agreeable_plan_threshold)

            if run_key not in self.saved_progress_dict:
                self.saved_progress_dict[run_key] = {}

            self.saved_progress_dict[run_key][key] = copy.deepcopy(value)
        else:
            self.saved_progress_dict[key] = value

        with open(self.progress_save_file, "w") as fout:
            json.dump(self.saved_progress_dict, fout, indent=4)


    def main(self, demos, initial_state, objs_to_put_away, gt_pref = "", progress_save_file = ""):
        """Main inference loop

        Generates candidate preferences and one plan per preference, scores every plan
        against every preference, and then asks yes/no questions to update P(theta)
        until one of the following holds:

        * `check_exist_agreeable_plan` reports an agreeable plan,
        * `self.max_pref_learning_step` questions have been asked, or
        * there are no questions left to ask.

        Args:
            demos: passed to `generate_preferences`.
            initial_state: passed to plan generation, plan scoring and question generation.
            objs_to_put_away: passed to plan generation and question generation.
            gt_pref: ground-truth preference text. It is passed to `score_plan` for the
                "no_apl" progress record and to `query_user`, which prompts a human on
                stdin when it is empty.
            progress_save_file: path handed to `init_progress_save_file`.

        Returns:
            A 6-tuple `(idx_to_best_plan, plan_list, best_pref_idx, preference_list,
            plan_score_given_best_pref, step)`. When the loop ends without an agreeable
            plan, the first and third entries are both the argmax of P(theta) and the
            score entry is None.
        """
        # Initialize the progress_save_file
        self.init_progress_save_file(progress_save_file)

        step = 0

        # Generate candidate preferences (N x 1)
        preference_list = self.generate_preferences(demos)

        # Initialize P(theta) as uniform distribution
        self.p_theta = np.ones((self.n_preferences,)) * (1 / self.n_preferences)
        logger.info(f"\n{pretty_str_preference_with_p_theta(preference_list, self.p_theta)}")

        # Generate plans from each preference (N x 1)
        plan_list = self.generate_plans_from_preferences(preference_list, initial_state, objs_to_put_away)

        # Create the score metrics (N x N)
        score_matrix = self.score_plans_on_preferences(preference_list, plan_list, initial_state)

        exist_agreeable_plan, idx_to_best_plan = self.check_exist_agreeable_plan(score_matrix, plan_list)
        can_load, _ = self.load_progress("no_apl")

        # Record the result obtained before any question is asked, unless already saved
        if not can_load:
            self.save_progress("no_apl", {
                "idx_to_best_plan": int(idx_to_best_plan),
                "gt_score": self.score_plan(gt_pref, initial_state, plan_list[idx_to_best_plan])
            })

        # If there exists a plan that all preferences agree on
        if exist_agreeable_plan:
            best_pref_idx = np.argwhere(self.p_theta == np.max(self.p_theta)).flatten().tolist()

            # If there are multiple best preference and idx_to_best_plan is in it, we use that
            if idx_to_best_plan in best_pref_idx:
                best_pref_idx = idx_to_best_plan
            else:
                best_pref_idx = best_pref_idx[0]

            plan_score_given_best_pref = score_matrix[idx_to_best_plan][best_pref_idx]
            return idx_to_best_plan, plan_list, best_pref_idx, preference_list, plan_score_given_best_pref, step

        exist_agreeable_plan = False
        idx_to_best_plan = None

        # Generate questions
        question_list = self.generate_all_pairwise_questions(preference_list, initial_state, objs_to_put_away, plan_list, score_matrix)

        pos_ans_likelihood_matrix, neg_ans_likelihood_matrix = self.get_likelihoods(question_list, preference_list)

        # Keep asking while no agreeable plan exists, the step budget is not exhausted,
        # and there is still a question left. Without the last check, scoring an empty
        # question set would fail once every question has been asked.
        while (step < self.max_pref_learning_step) and (not exist_agreeable_plan) and len(question_list) > 0:
            scores, pos_posterior_matrix, neg_posterior_matrix = self.score_questions(pos_ans_likelihood_matrix, neg_ans_likelihood_matrix)

            # Find best question
            best_question_idx = np.argmax(scores.sum(axis=1))

            logger.opt(colors=True).info(f"\n{pretty_str_scores_for_questions(scores.sum(axis=1), question_list, int(best_question_idx))}")

            # Show every question index tied for the top score, then let the operator
            # confirm (or override) which one is asked.
            print(np.argwhere(scores.sum(axis=1) == np.max(scores.sum(axis=1))).flatten().tolist())
            best_question_idx = int(input("Confirm best question to ask: "))

            # Query user with the best question
            user_answer = self.query_user(question_list[best_question_idx], gt_pref)

            p_theta_before = self.p_theta

            # Update P(theta)
            assert user_answer == 1 or user_answer == 0
            if user_answer == 1:
                self.p_theta = pos_posterior_matrix[best_question_idx]
            elif user_answer == 0:
                self.p_theta = neg_posterior_matrix[best_question_idx]

            logger.opt(colors=True).info(f"\n{pretty_str_preference_with_p_theta_before_after(preference_list, p_theta_before, self.p_theta)}")

            # Remove questions that has already been asked
            pos_ans_likelihood_matrix = np.delete(pos_ans_likelihood_matrix, best_question_idx, axis=0)
            neg_ans_likelihood_matrix = np.delete(neg_ans_likelihood_matrix, best_question_idx, axis=0)
            question_list.pop(best_question_idx)

            exist_agreeable_plan, idx_to_best_plan = self.check_exist_agreeable_plan(score_matrix, plan_list)

            step += 1

            self.bp()

        # If there exists a plan that all preferences agree on
        if exist_agreeable_plan:
            best_pref_idx = np.argwhere(self.p_theta == np.max(self.p_theta)).flatten().tolist()

            # If there are multiple best preference and idx_to_best_plan is in it, we use that
            if idx_to_best_plan in best_pref_idx:
                best_pref_idx = idx_to_best_plan
            else:
                best_pref_idx = best_pref_idx[0]

            plan_score_given_best_pref = score_matrix[idx_to_best_plan][best_pref_idx]
            return idx_to_best_plan, plan_list, best_pref_idx, preference_list, plan_score_given_best_pref, step
        else:
            # No agreeable plan: fall back to the most likely preference and use its
            # index for both the plan and the preference.
            best_pref_idx = np.argmax(self.p_theta)
            return best_pref_idx, plan_list, best_pref_idx, preference_list, None, step
        


if __name__ == "__main__":
    # Command-line entry point: parse the run configuration, build the preference
    # learner, and run it on a hard-coded example fridge state.

    parser = argparse.ArgumentParser(description="APRICOT Preference Learning main function")
    # `type=str` with nargs="+" yields the list of object names directly
    # (`type=list` split every name into characters that then had to be re-joined).
    parser.add_argument("-o", "--objs_to_put_away", type=str, required=True, nargs="+", help="List of objects to put away")
    parser.add_argument("-p", "--preference", type=str, required=False, default="", help="Text description of the user's preference")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")

    parser.add_argument("--prompt_config_path", type=str, required=False, default="", help="Path to the LLM prompt config file")

    # Preference-learning settings
    parser.add_argument("--num_pref_candidates", type=int, default=5, help="Number of candidate preferences to generate")
    parser.add_argument("--not_agreeable_plan_threshold", type=float, default=0.1, help="Threshold of whether a plan is agreeable")
    parser.add_argument("--num_questions_per_pair", type=int, default=5, help="Number of questions to generate per pair of preferences")
    parser.add_argument("--max_pref_learn_step", type=int, default=5, help="Maximum number of preference learning steps")

    # Planner settings
    parser.add_argument("--num_plans", type=int, default=10, help="Number of plans to generate")
    parser.add_argument("--beam_size", type=int, default=10, help="Size of the beam to maintain")
    parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to take for each object placement")

    # Simulation / rendering settings
    parser.add_argument("--fps", type=int, default=4, help="Frames per second to render the GIF at")
    parser.add_argument("--gif_path", type=str, default=None, help="Path to save the GIF to; doesn't save if None")

    parser.add_argument("--debug", action="store_true", default=False, help="Whether to allow debug")
    parser.add_argument("--ros", action="store_true", default=False, help="Whether or not to communicate with ROS at all")
    parser.add_argument("--ros_service", action="store_true", default=False, help="Whether or not to use ROS service")
    args = parser.parse_args()

    if args.ros:
        # Imported lazily so that ROS is only required when --ros is passed
        import rospy
        rospy.init_node("APRICOT_main")
        print("Node initialized")

    gb = APRICOTPrefLearning(
        args.seed,
        args.prompt_config_path,
        learn_pref_config = dict(
            n_pref = args.num_pref_candidates,
            not_agreeable_threshold = args.not_agreeable_plan_threshold,
            n_q_per_pair = args.num_questions_per_pair,
            max_step = args.max_pref_learn_step
        ),
        planner_config = dict(
            num_plans = args.num_plans,
            beam_size = args.beam_size,
            num_samples = args.num_samples,
        ),
        sim_config = dict(
            fps = args.fps,
            gif_path = args.gif_path,
        ),
        debug=args.debug,
        use_ros=args.ros,
        use_ros_service=args.ros_service
    )

    # Ex 1
    initial_state = {
        "top shelf":
            {
                "left side of the top shelf": ["blueberry"],
                "right side of the top shelf": ["kale"]
            },
        "middle shelf":
            {
                "left side of the middle shelf": ["pear", "ranch sauce"],
                "right side of the middle shelf": ["carrot", "mustard"]
            },
        "bottom shelf":
            {
                "left side of the bottom shelf": ["grape", "strawberry", "peach", "watermelon"],
                "right side of the bottom shelf": ["cucumber", "squash", "celery", "potato"]
            }
    }

    # Ex 2
    # initial_state = {
    #     "top shelf":
    #         {
    #             "left side of top shelf": [],
    #             "right side of top shelf": []
    #         },
    #     "middle shelf":
    #         {
    #             "left side of middle shelf": ["yogurt", "whole milk"],
    #             "right side of middle shelf": ["mustard"]
    #         },
    #     "bottom shelf":
    #         {
    #             "left side of bottom shelf": ["grape", "strawberry"],
    #             "right side of bottom shelf": ["cucumber", "carrot"]
    #         }
    # }

    # NOTE(review): `args.preference` is parsed but not forwarded here, and the method
    # called is `preference_learning_main` — confirm both are intended.
    idx_to_agreeable_plan = gb.preference_learning_main(demos="",
                                initial_state=json.dumps(initial_state, indent=4),
                                objs_to_put_away=list(args.objs_to_put_away))

    logger.info(f"Finished: {idx_to_agreeable_plan=}")