"""
Main orchestrator of the inference pipeline for the APRICOT project
"""

import os
import sys
import datetime
import copy

# os.environ["LOGURU_LEVEL"] = "INFO"

from loguru import logger

# When LOGURU_LEVEL is set to INFO, reconfigure loguru's handlers so that log
# records are written to stderr as the bare message text (format "{message}").
if os.environ.get("LOGURU_LEVEL") == "INFO":
    fmt = "{message}"
    config = dict(handlers=[dict(sink=sys.stderr, format=fmt)])
    logger.configure(**config)

import argparse
import json
import yaml
import gym
import numpy as np

from utils import format_semantic_loc_for_llm, format_plan
from constants import ALL_POSSIBLE_OBJECTS_TO_PLACE_DICT

from general_prompt_builder.prompt_llm import (
    prompt_llm,
)  
# from llms_for_planning.prompt_builder.prompt_llm import prompt_llm
from llms_for_planning.planners.geometric_feasibility.v0_no_llm_scoring import plan
from llms_for_planning.utils import fetch_messages as planner_fetch_message
from utils import *


class QueryingOnlyInteractivePrefLearning(object):
    """Querying-only interactive preference learning for putting objects in a fridge.

    Pipeline (see ``main``):
      1. Generate candidate preferences from demonstrations with an LLM.
      2. Generate one text plan per candidate preference with an LLM.
      3. Let a single LLM prompt alternately ask the user yes/no questions and
         finally pick the index of the best candidate preference.

    Intermediate results can be checkpointed to a JSON progress file so an
    interrupted run can resume where it stopped.
    """

    NAME = "QueryingOnlyInteractive"  # Section of the prompt-config YAML used by this class

    OBJS_TO_PUT_AWAY_PARAM_NAME = "/perception/obj_to_put_away"
    OBJS_IN_FRIDGE_PARAM_NAME = "/perception/init_obj_in_fridge"
    STATE_TRACKER_TOPIC = "/state_tracker/fridgepcl"
    OBJECT_DETECTION_TOPIC = "/perception/detected_obj"
    ###### Fridge Range
    _BBOX_OFFSET = 0.02  # Define an offset around the fridge for safer planner output
    _SHELF_WIDTH = 0.42 - _BBOX_OFFSET * 2
    _SHELF_DEPTH = 0.24 - _BBOX_OFFSET * 2
    # Origin (Middle shelf's top left corner)
    O_X = -0.0275
    O_Y = 0.10
    O_Z = 0.0
    # Top shelf
    _TOP_SHELF_Y_OFFSET = (
        0.0 + _BBOX_OFFSET
    )  # Don't want to sample too close to the bottom of the shelf
    _TOP_SHELF_HEIGHT = (
        0.23 - _BBOX_OFFSET * 2
    )  # Don't want to sample too close to the top of the shelf
    # Middle shelf
    _MIDDLE_SHELF_Y_OFFSET = -0.27 + _BBOX_OFFSET
    _MIDDLE_SHELF_HEIGHT = 0.27 - _BBOX_OFFSET * 2
    # Bottom shelf
    _BOTTOM_SHELF_Y_OFFSET = -0.42 + _BBOX_OFFSET
    _BOTTOM_SHELF_HEIGHT = 0.15 - _BBOX_OFFSET * 2
    # Bbox format (bottom_left_x, bottom_left_y, bottom_left_z, w, h, d)
    FRIDGE_LOCATION_BBOX = {
        # Top shelf locations
        "top shelf": (
            O_X,
            O_Y + _TOP_SHELF_Y_OFFSET,
            O_Z - _BBOX_OFFSET,
            _SHELF_WIDTH,
            _TOP_SHELF_HEIGHT,
            _SHELF_DEPTH,
        ),
        "left side of top shelf": (
            O_X,
            O_Y + _TOP_SHELF_Y_OFFSET,
            O_Z - _BBOX_OFFSET,
            _SHELF_WIDTH / 2.0,
            _TOP_SHELF_HEIGHT,
            _SHELF_DEPTH,
        ),
        "right side of top shelf": (
            O_X + _SHELF_WIDTH / 2.0,
            O_Y + _TOP_SHELF_Y_OFFSET,
            O_Z - _BBOX_OFFSET,
            _SHELF_WIDTH / 2.0,
            _TOP_SHELF_HEIGHT,
            _SHELF_DEPTH,
        ),
        # Middle shelf locations
        "middle shelf": (
            O_X,
            O_Y + _MIDDLE_SHELF_Y_OFFSET,
            O_Z - _BBOX_OFFSET,
            _SHELF_WIDTH,
            _MIDDLE_SHELF_HEIGHT,
            _SHELF_DEPTH,
        ),
        "left side of middle shelf": (
            O_X,
            O_Y + _MIDDLE_SHELF_Y_OFFSET,
            O_Z - _BBOX_OFFSET,
            _SHELF_WIDTH / 2.0,
            _MIDDLE_SHELF_HEIGHT,
            _SHELF_DEPTH,
        ),
        "right side of middle shelf": (
            O_X + _SHELF_WIDTH / 2.0,
            O_Y + _MIDDLE_SHELF_Y_OFFSET,
            O_Z - _BBOX_OFFSET,
            _SHELF_WIDTH / 2.0,
            _MIDDLE_SHELF_HEIGHT,
            _SHELF_DEPTH,
        ),
        # Bottom shelf locations
        "bottom shelf": (
            O_X,
            O_Y + _BOTTOM_SHELF_Y_OFFSET,
            O_Z - _BBOX_OFFSET,
            _SHELF_WIDTH,
            _BOTTOM_SHELF_HEIGHT,
            _SHELF_DEPTH,
        ),
        "left side of bottom shelf": (
            O_X,
            O_Y + _BOTTOM_SHELF_Y_OFFSET,
            O_Z - _BBOX_OFFSET,
            _SHELF_WIDTH / 2.0,
            _BOTTOM_SHELF_HEIGHT,
            _SHELF_DEPTH,
        ),
        "right side of bottom shelf": (
            O_X + _SHELF_WIDTH / 2.0,
            O_Y + _BOTTOM_SHELF_Y_OFFSET,
            O_Z - _BBOX_OFFSET,
            _SHELF_WIDTH / 2.0,
            _BOTTOM_SHELF_HEIGHT,
            _SHELF_DEPTH,
        ),
    }

    # Progress keys whose (long, free-text) values are stored in markdown side files
    # next to the JSON progress file; the JSON then only holds the side file's path.
    _MARKDOWN_PROGRESS_SUFFIXES = {
        "reasoning_worksheet": "_reasoning.md",
        "preference_writing_reasoning": "_preference_writing.md",
    }

    # ==================================================================================================
    # Initialization
    # ==================================================================================================

    def __init__(
        self,
        seed,
        prompt_config_path,
        learn_pref_config=dict(
            n_pref=5,
            not_agreeable_threshold=0.1,
            n_q_per_pair=5,
            max_step=5,
        ),
        planner_config=dict(
            num_plans=10,  # Number of plans to generate
            beam_size=10,  # Size of the beams to maintain
            num_samples=10,  # Number of samples to take for each object placement
        ),
        sim_config=dict(
            fps=4,  # Frames per second to render the GIF at
            gif_path=None,  # Path to save the GIF to; doesn't save if None
        ),
        debug=False,
        use_ros=False,
        use_ros_service=False,
    ):
        """Initialize the querying-only interactive preference learner.

        Parameters:
            seed - random seed; stored as ``self.seed``
            prompt_config_path (str) - path to a YAML file whose ``NAME`` section
                configures every LLM prompt (see ``init_llm_prompts``)
            learn_pref_config (dict) - must contain the keys ``n_pref``,
                ``not_agreeable_threshold``, ``n_q_per_pair`` and ``max_step``
            planner_config (dict) - planner settings; stored as ``self.planner_config``
            sim_config (dict) - simulation settings; stored as ``self.sim_config``
            debug (bool) - when True, ``bp()`` drops into the debugger
            use_ros (bool), use_ros_service (bool) - stored as instance flags
        """
        self.need_to_plan = True  # Control whether the planner needs to get replanned
        self.curr_state = {}

        self.debug = debug
        self.use_ros = use_ros
        self.use_ros_service = use_ros_service
        # Shallow copies so the shared default dicts can never be mutated through an instance
        self.planner_config = dict(planner_config)
        self.sim_config = dict(sim_config)
        self.seed = seed

        self.progress_save_file = ""
        self.saved_progress_dict = {}

        # Because these get used frequently during preference learning, we define individual variables
        self.n_preferences = learn_pref_config["n_pref"]
        self.not_agreeable_plan_threshold = learn_pref_config["not_agreeable_threshold"]
        self.n_q_per_pair = learn_pref_config["n_q_per_pair"]
        self.max_pref_learning_step = learn_pref_config["max_step"]

        self.init_llm_prompts(prompt_config_path)

    def bp(self):
        """Enter the debugger, but only in debug mode."""
        if self.debug:
            breakpoint()

    def init_llm_prompts(self, prompt_config_path):
        """Load the LLM prompt configuration and fetch each prompt's messages.

        Effect:
            Sets ``self.prompt_config`` to the ``NAME`` section of the YAML file.
            For every prompt type, ``llm_config["messages"]`` is filled in from
            its experiment name / description / version, and the count
            placeholders in message 1 are substituted for the preference- and
            question-generation prompts.
        """
        with open(prompt_config_path, "r") as fin:
            prompt_config = yaml.safe_load(fin)

        self.prompt_config = prompt_config[self.NAME]

        for prompt_type, prompt_settings in self.prompt_config.items():
            # The plan-generation prompt lives in the planner package; all others are local
            fetch_fn = (
                planner_fetch_message
                if prompt_type == "generate_plan"
                else fetch_messages
            )
            messages = fetch_fn(
                prompt_settings["experiment_name"],
                prompt_settings["prompt_description"],
                prompt_settings["prompt_version"],
            )
            prompt_settings["llm_config"]["messages"] = messages

            if prompt_type == "generate_preferences":
                placeholders = {
                    "<num_preferences>": str(self.n_preferences),
                    "<num_preferences_minus_one>": str(self.n_preferences - 1),
                }
            elif prompt_type == "generate_questions":
                placeholders = {"<num_questions>": str(self.n_q_per_pair)}
            else:
                placeholders = {}

            # Placeholders are substituted in message index 1 only
            for placeholder, replacement in placeholders.items():
                messages[1]["content"] = messages[1]["content"].replace(
                    placeholder, replacement
                )

    # ==================================================================================================
    # Small helpers
    # ==================================================================================================

    @staticmethod
    def _alert(message):
        """Run the shell command ``say "<message>"`` to get the operator's attention."""
        os.system(f'say "{message}"')

    @staticmethod
    def _strip_code_fence(output):
        """Extract the payload of a markdown code fence from an LLM reply.

        If ``output`` contains a ```json (or plain ```) fence, returns the text
        after the first fence with every remaining ``` removed and surrounding
        whitespace stripped; otherwise returns ``output`` unchanged.
        """
        if "```" not in output:
            return output
        fence = "```json" if "```json" in output else "```"
        _, _, payload = output.partition(fence)
        return payload.replace("```", "").strip()  # Also drops the closing ```

    @staticmethod
    def _parse_pref_index(value):
        """Convert a preference index (int, numeric string, None or "") to int, or None if not convertible."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _parse_yes_no_answer(response):
        """Return the lower-cased, stripped "Answer (yes/no)" field of JSON ``response``.

        Returns None when the text is not JSON, the field is missing or not a
        string, or the answer contains neither "yes" nor "no".
        """
        try:
            answer = json.loads(response)["Answer (yes/no)"].lower().strip()
        except (ValueError, KeyError, TypeError, AttributeError):
            return None
        return answer if ("yes" in answer or "no" in answer) else None

    # ==================================================================================================
    # One prompt to do the entire interactive process
    # ==================================================================================================

    def one_prompt_interactive(
        self, preferences, initial_state, objs_to_put_away, chat_log, questions_left
    ):
        """Run one step of the single-prompt interactive loop.

        Sends the candidate preferences, the fridge's initial state, the objects
        to put away, the chat log so far and the remaining question budget to
        the "one_prompt_querying" LLM, and parses its JSON reply. If the reply
        cannot be parsed, the operator is alerted and asked to type each field.

        Parameters:
            preferences (str) - numbered list of candidate preferences
            initial_state - initial fridge state, interpolated into the prompt
            objs_to_put_away - objects to put away, interpolated into the prompt
            chat_log (list) - conversation so far, serialized with json.dumps
            questions_left (int) - questions the LLM may still ask; <= 0 forces a choice

        Returns:
            (terminate, reasoning, question, best_preference): terminate is 0/1,
            question is "" when none was asked, and best_preference is the raw
            "index_of_best_preference" value (may be a string, or None if absent).
        """
        can_ask_question = (
            f"Yes! You can ask {questions_left} more questions."
            if questions_left > 0
            else "No. You must generate the preference now."
        )

        # The query contains: the list of preferences, the new initial state of the
        # fridge, the objects to put away, the chat log, and whether questions remain.
        user_query = f"""
        ## List of preferences
        {preferences}

        ## New initial state of the fridge
        {initial_state}

        ## Objects to put away
        {objs_to_put_away}

        ## Chat log
        {json.dumps(chat_log, indent=4)}

        ## Can you still ask question?
        {can_ask_question}
        """

        logger.debug(f"\n{user_query}")

        output = prompt_llm(
            user_query, **self.prompt_config["one_prompt_querying"]["llm_config"]
        )

        logger.debug(f"\n{output}")

        try:
            reasonings_and_actions = json.loads(self._strip_code_fence(output))
        except ValueError:  # json.JSONDecodeError is a ValueError
            logger.info("Failed to load the answer as a json")

            logger.info(output)

            self._alert("Hello, something went wrong. Please check!")

            # Fall back to the operator transcribing each field by hand
            reasonings_and_actions = {}

            for key in [
                "terminate? (yes/no)",
                "reasoning",
                "question",
                "index_of_best_preference",
            ]:
                value = input(
                    f"Please manually copy what's under key={key} here (without quotations): "
                )

                reasonings_and_actions[key] = value.strip()

            breakpoint()

        # Extract the LLM's output; str() tolerates non-string JSON values
        terminate_field = reasonings_and_actions.get("terminate? (yes/no)", "")
        terminate = terminate_field is True or "yes" in str(terminate_field).strip().lower()
        reasoning = reasonings_and_actions.get("reasoning", "")
        question = reasonings_and_actions.get("question") or ""  # JSON null -> no question
        best_preference = reasonings_and_actions.get("index_of_best_preference", None)

        assert (terminate and best_preference is not None) or (
            question != ""
        ), "Either there has to be a preference generated and terminate is true, or question is generated"

        return int(terminate), reasoning, question, best_preference

    # ==================================================================================================
    # Query the user
    # ==================================================================================================

    def query_user(self, question, gt_pref=""):
        """Ask ``question`` and return 1 for a "yes" answer, 0 for "no".

        With an empty ``gt_pref`` a real user is prompted on stdin until the
        reply contains "yes" or "no"; otherwise an LLM answers on behalf of a
        user holding preference ``gt_pref`` (see ``get_answer_probs``).
        """
        logger.info(f"\nQuestion:\n{question}")

        if gt_pref:
            _, _, answer = self.get_answer_probs(q=question, theta=gt_pref)

            if "yes" in answer or "no" in answer:
                return 1 if "yes" in answer else 0
            # raise (not assert) so the check survives python -O
            raise AssertionError(f"{answer} does not contain yes or no")

        # No ground-truth preference given: query a real user
        while True:
            user_answer = input("Please only answer 'yes' or 'no': ").lower()

            if "yes" in user_answer or "no" in user_answer:
                return 1 if "yes" in user_answer else 0

            print(
                "You have entered an invalid response. Please only answer in 'yes' or 'no'"
            )

    def get_answer_probs(self, q, theta):
        """Have the "answer_questions" LLM answer question ``q`` as a user with preference ``theta``.

        Returns:
            (positive_ans_prob, negative_ans_prob, answer) where answer is the
            lower-cased text of the "Answer (yes/no)" field.

        Raises:
            ValueError if no yes/no answer can be parsed from the LLM response
            (after alerting the operator and pausing in the debugger).
        """
        user_prompt = f"""
        # Your preference
        {theta}

        # Questions for you to answer
        {q}
        """
        logger.debug("------- Answering questions ...")
        logger.debug(f"{user_prompt}")
        response, logprobs = prompt_llm(
            user_prompt, **self.prompt_config["answer_questions"]["llm_config"]
        )
        logger.debug(f"{response}")

        # Best-effort repair of single-quoted keys/values so the response parses as JSON
        response = response.replace("'Reasoning'", '"Reasoning"')
        response = response.replace("'Answer (yes/no)'", '"Answer (yes/no)"')
        response = response.replace(": 'yes'", ': "yes"')
        response = response.replace(": 'no'", ': "no"')

        answer = self._parse_yes_no_answer(response)

        if answer is None:
            logger.debug("Failed to load the answer as a json")
            logger.debug(f"\n{response}")

            self._alert("Hello, something went wrong. Please check!")

            breakpoint()

            raise ValueError(
                f"Could not parse a yes/no answer from the LLM response:\n{response}"
            )

        # NOTE(review): presumably the probability of the answer the LLM gave,
        # derived from its token logprobs — confirm against utils.get_answer_prob.
        answer_prob = get_answer_prob(logprobs)

        if "yes" in answer:
            return answer_prob, 1 - answer_prob, answer
        return 1 - answer_prob, answer_prob, answer

    # ==================================================================================================
    # Generate text plans from preferences
    # ==================================================================================================

    def generate_text_plan(self, objs_to_put_away, init_state, preference):
        """Ask the "generate_plan" LLM for a pickandplace text plan satisfying ``preference``.

        The plan is everything from the first "pickandplace" in the reply
        onwards, stripped of surrounding whitespace. Results are checkpointed
        per preference so that each candidate preference gets its own plan.
        """
        # One progress key per preference; a single shared key would make every
        # call after the first return the first preference's plan.
        progress_key = f"text_plan::{preference}"
        can_load, value = self.load_progress(progress_key)

        if can_load:
            text_plan = value
        else:
            user_prompt = f"""
            Objects: {objs_to_put_away}
            Locations: {list(self.FRIDGE_LOCATION_BBOX.keys())}
            Initial State: {init_state}
            Preference: "{preference}"
            """
            logger.debug("------- Generating a plan ...")
            logger.debug(user_prompt)
            response = prompt_llm(
                user_prompt, **self.prompt_config["generate_plan"]["llm_config"]
            )

            logger.debug(f"\n{response}")

            # Drop any preamble before the code; partition() removes the separator itself
            _, _, text_plan = response.partition("pickandplace")
            text_plan = "pickandplace" + text_plan

            assert (
                "#" not in text_plan
            ), f"Should not have more comments in the code:\n{text_plan}"

            self.save_progress(progress_key, text_plan)

        return text_plan.strip()  # Remove unnecessary whitespace

    def generate_plans_from_preferences(
        self, preference_list, initial_state, objs_to_put_away
    ):
        """Generate one text plan per candidate preference.

        Return:
            (list) A list of strings, one plan (a string of code) per preference,
            in the same order as ``preference_list``.
        """
        can_load, value = self.load_progress("text_plan_list")

        if can_load:
            text_plan_list = value
        else:
            text_plan_list = [
                self.generate_text_plan(
                    objs_to_put_away=objs_to_put_away,
                    init_state=initial_state,
                    preference=preference,
                )
                for preference in preference_list
            ]

            self.save_progress("text_plan_list", text_plan_list)

        self.bp()

        plan_str = "\n======================\n".join(text_plan_list)
        logger.debug(f"Generated Text Plans:\n{plan_str}")

        return text_plan_list

    # ==================================================================================================
    # Log progress
    # ==================================================================================================

    def _write_progress_file(self):
        """Dump ``self.saved_progress_dict`` to ``self.progress_save_file`` as indented JSON."""
        with open(self.progress_save_file, "w") as fout:
            json.dump(self.saved_progress_dict, fout, indent=4)

    def init_progress_save_file(self, progress_save_file):
        """Attach a JSON progress file, creating it (empty) or loading it if it exists.

        A falsy ``progress_save_file`` leaves the current progress state untouched.
        """
        if not progress_save_file:
            return

        self.progress_save_file = progress_save_file

        if os.path.exists(self.progress_save_file):
            with open(self.progress_save_file, "r") as fin:
                self.saved_progress_dict = json.load(fin)
        else:
            self.saved_progress_dict = {}
            self._write_progress_file()

    def reset_progress_save_file(self):
        """Flush the current progress to disk (if a file is attached), then detach and clear it."""
        if self.progress_save_file:
            self._write_progress_file()

        self.progress_save_file = ""
        self.saved_progress_dict = {}

    def load_progress(self, key):
        """Return ``(True, value)`` for a checkpointed ``key``, else ``(False, None)``.

        Markdown-backed keys (``reasoning_worksheet``,
        ``preference_writing_reasoning``) return the text of their side file;
        other values are deep-copied so callers cannot mutate saved progress.
        """
        if key not in self.saved_progress_dict:
            return False, None

        if key in self._MARKDOWN_PROGRESS_SUFFIXES:
            with open(self.saved_progress_dict[key], "r") as fin:
                return True, fin.read()

        return True, copy.deepcopy(self.saved_progress_dict[key])

    def save_progress(self, key, value):
        """Checkpoint ``value`` under ``key`` and rewrite the JSON progress file.

        Markdown-backed keys are written to a side file named after the
        progress file (its extension replaced by e.g. ``_reasoning.md``) and
        only that path is stored in the JSON. Without a progress file, nothing
        is persisted.
        """
        if not self.progress_save_file:
            return  # Persistence disabled (no progress file was given)

        if key in self._MARKDOWN_PROGRESS_SUFFIXES:
            # splitext so a path without ".json" never makes the side file overwrite the JSON file
            root, _ = os.path.splitext(str(self.progress_save_file))
            fpath = root + self._MARKDOWN_PROGRESS_SUFFIXES[key]

            with open(fpath, "w") as fout:
                fout.write(value)

            self.saved_progress_dict[key] = fpath
        else:
            self.saved_progress_dict[key] = value

        self._write_progress_file()

    def generate_preferences(self, demos):
        """Generate the list of candidate preferences from demonstrations.

        Two LLM calls, each checkpointed: "analyze_demos" turns ``demos`` into a
        reasoning worksheet, and "generate_preferences" turns the worksheet into
        a JSON list of preferences. If that JSON cannot be parsed, the operator
        types the preferences in by hand (one per line, empty line to finish).
        The operator is always asked to verify before returning.
        """
        can_load, value = self.load_progress("preference_list")

        if can_load:
            preference_list = value
        else:
            can_load, value = self.load_progress("reasoning_worksheet")

            if can_load:
                reasoning_worksheet = value
            else:
                logger.debug("------- Analyzing demonstrations ...")
                reasoning_worksheet = prompt_llm(
                    demos, **self.prompt_config["analyze_demos"]["llm_config"]
                )

                # Pause for a human check if an expected category section is missing
                for category_section in [
                    "(Fruits)",
                    "(Vegetables)",
                    "(Juice-and-soft-drinks)",
                    "(Dairy-Products)",
                    "(Condiments)",
                ]:
                    if category_section not in reasoning_worksheet:
                        logger.debug(
                            f"{category_section} is not in the reasoning worksheet"
                        )

                        self._alert("Hello, please verify reasoning worksheet!")
                        breakpoint()

                logger.debug(f"\n{reasoning_worksheet}")

                self.save_progress("reasoning_worksheet", reasoning_worksheet)

            self.bp()

            can_load, value = self.load_progress("preference_writing_reasoning")

            if can_load:
                output = value
            else:
                logger.debug("------- Generating preferences ...")
                user_prompt = reasoning_worksheet

                logger.debug(f"\n{user_prompt}")
                output = prompt_llm(
                    user_prompt,
                    **self.prompt_config["generate_preferences"]["llm_config"],
                )

                self.save_progress("preference_writing_reasoning", output)

            logger.debug(f"\n{output}")

            try:
                preference_list = json.loads(self._strip_code_fence(output))
            except ValueError:  # json.JSONDecodeError is a ValueError
                logger.info("Failed to load the answer as a json")

                logger.info(output)

                self._alert("Hello, something went wrong. Please check!")

                # Collect preferences by hand until an empty line is entered
                preference_list = []
                while True:
                    manual_preference = input(
                        "Please manually copy the preference here (without quotations): "
                    )

                    if manual_preference == "":
                        break

                    preference_list.append(manual_preference.strip())

                breakpoint()

            self.save_progress("preference_list", preference_list)

            self._alert("Hello, please verify demonstration!")
            breakpoint()

        self.bp()

        return preference_list

    def main(
        self, demos, initial_state, objs_to_put_away, gt_pref="", progress_save_file=""
    ):
        """Main inference loop.

        Generates candidate preferences and one plan per preference, then runs
        up to ``max_pref_learning_step + 1`` rounds of ``one_prompt_interactive``,
        asking the user (or an LLM simulating ``gt_pref``) each generated question.
        With a ``progress_save_file`` every round is checkpointed, and a
        finished or interrupted run resumes from the saved state.

        Returns:
            (plan, best_preference_idx, best_preference, step) when the LLM
            terminated with a valid preference index; otherwise
            ("", "", "", step). ``step`` is the round at which the loop ended.
        """
        # Initialize the progress_save_file
        self.init_progress_save_file(progress_save_file)

        can_load, value = self.load_progress("chat_log")
        chat_log = value if can_load else []

        can_load, value = self.load_progress("step")
        step = value if can_load else 0

        can_load, value = self.load_progress("terminate")
        terminate = value if can_load else 0

        _, value = self.load_progress("index_of_best_preference")
        best_preference_idx = self._parse_pref_index(value)

        # Generate candidate preferences (N x 1) and one plan per preference
        preference_list = self.generate_preferences(demos)
        pref_list_str = "\n".join(f"{i}. {s}" for i, s in enumerate(preference_list))
        plan_list = self.generate_plans_from_preferences(
            preference_list, initial_state, objs_to_put_away
        )

        # Plus one because the last step should be a chance to generate a preference
        while (step < self.max_pref_learning_step + 1) and (not terminate):
            terminate, reasoning, question, raw_best_idx = self.one_prompt_interactive(
                pref_list_str,
                initial_state,
                objs_to_put_away,
                chat_log,
                self.max_pref_learning_step - step - 1,
            )
            # The index may arrive as a string (e.g. typed in by hand)
            best_preference_idx = self._parse_pref_index(raw_best_idx)

            chat_log.append({"role": "thought", "content": reasoning})

            if question != "":
                user_answer = self.query_user(question, gt_pref)

                user_answer_in_text = "yes" if int(user_answer) == 1 else "no"

                chat_log.append({"role": "assistant", "content": question})
                chat_log.append({"role": "user", "content": user_answer_in_text})

            # Advance before checkpointing so a resumed run never repeats a
            # round whose Q&A is already in the saved chat log.
            if not terminate:
                step += 1

            self.save_progress("step", step)
            self.save_progress("chat_log", chat_log)
            self.save_progress("terminate", terminate)
            self.save_progress("index_of_best_preference", best_preference_idx)

            if terminate:
                break

            self.bp()

        has_valid_choice = best_preference_idx is not None and 0 <= best_preference_idx < min(
            len(preference_list), len(plan_list)
        )

        if terminate and has_valid_choice:
            chosen_plan = plan_list[best_preference_idx]
            best_preference = preference_list[best_preference_idx]

            self.bp()

            return chosen_plan, best_preference_idx, best_preference, step

        return "", "", "", step
