import os
import sys
import datetime
import copy

# os.environ["LOGURU_LEVEL"] = "INFO"

from loguru import logger

# When LOGURU_LEVEL is exactly "INFO", replace loguru's handlers with a single
# stderr sink that prints only the message text.
if os.environ.get("LOGURU_LEVEL") == "INFO":
    fmt = "{message}"
    config = dict(handlers=[dict(sink=sys.stderr, format=fmt)])
    logger.configure(**config)

import argparse
import json
import yaml
import gym
import numpy as np

from utils import format_semantic_loc_for_llm, format_plan
from constants import ALL_POSSIBLE_OBJECTS_TO_PLACE_DICT

from general_prompt_builder.prompt_llm import prompt_llm  
# from llms_for_planning.prompt_builder.prompt_llm import prompt_llm
from llms_for_planning.planners.geometric_feasibility.v0_no_llm_scoring import plan
from llms_for_planning.utils import fetch_messages as planner_fetch_message
from utils import *


class InteractiveBaselinePrefLearning(object):
    # Key used by init_llm_prompts to select this method's section of the prompt config file
    NAME = "Interactive"

    # NOTE(review): the four names below are not referenced in the visible code;
    # they look like ROS parameter / topic names - confirm against the callers.
    OBJS_TO_PUT_AWAY_PARAM_NAME = "/perception/obj_to_put_away"
    OBJS_IN_FRIDGE_PARAM_NAME = "/perception/init_obj_in_fridge"
    STATE_TRACKER_TOPIC = "/state_tracker/fridgepcl"
    OBJECT_DETECTION_TOPIC = "/perception/detected_obj"
    ###### Fridge Range
    # NOTE(review): all dimensions below are presumably in meters - confirm.
    _BBOX_OFFSET = 0.02  # Define an offset around the fridge for safer planner output
    # Width and depth are shrunk by the offset on both sides
    _SHELF_WIDTH = 0.42 - _BBOX_OFFSET * 2
    _SHELF_DEPTH = 0.24 - _BBOX_OFFSET * 2
    # Origin (Middle shelf's top left corner)
    O_X = -0.0275
    O_Y = 0.10
    O_Z = 0.0
    # Top shelf
    _TOP_SHELF_Y_OFFSET = 0.0 + _BBOX_OFFSET  # Don't want to sample too close to the bottom of the shelf
    _TOP_SHELF_HEIGHT = 0.23 - _BBOX_OFFSET * 2  # Don't want to sample too close to the top of the shelf
    # Middle shelf
    _MIDDLE_SHELF_Y_OFFSET = -0.27 + _BBOX_OFFSET
    _MIDDLE_SHELF_HEIGHT = 0.27 - _BBOX_OFFSET * 2
    # Bottom shelf
    _BOTTOM_SHELF_Y_OFFSET = -0.42 + _BBOX_OFFSET
    _BOTTOM_SHELF_HEIGHT = 0.15 - _BBOX_OFFSET * 2
    # Maps each location name offered to the LLM (see generate_text_plan) to its bounding box.
    # Every shelf has a full-width entry plus a left and a right half; the right half
    # starts at O_X + _SHELF_WIDTH/2.0 and both halves are _SHELF_WIDTH/2.0 wide.
    # Bbox format (bottom_left_x, bottom_left_y, bottom_left_z, w, h, d)
    FRIDGE_LOCATION_BBOX = {
        # Top shelf locations
        "top shelf":                    (O_X,      O_Y + _TOP_SHELF_Y_OFFSET,    O_Z - _BBOX_OFFSET, _SHELF_WIDTH,       _TOP_SHELF_HEIGHT,      _SHELF_DEPTH),
        "left side of top shelf":            (O_X,      O_Y + _TOP_SHELF_Y_OFFSET,    O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _TOP_SHELF_HEIGHT,      _SHELF_DEPTH),
        "right side of top shelf":           (O_X + _SHELF_WIDTH/2.0,  O_Y + _TOP_SHELF_Y_OFFSET,    O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _TOP_SHELF_HEIGHT,      _SHELF_DEPTH),
        # Middle shelf locations
        "middle shelf":                    (O_X,      O_Y + _MIDDLE_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH,       _MIDDLE_SHELF_HEIGHT,   _SHELF_DEPTH),
        "left side of middle shelf":            (O_X,      O_Y + _MIDDLE_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _MIDDLE_SHELF_HEIGHT,   _SHELF_DEPTH),
        "right side of middle shelf":           (O_X + _SHELF_WIDTH/2.0,  O_Y + _MIDDLE_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _MIDDLE_SHELF_HEIGHT,   _SHELF_DEPTH),
        # Bottom shelf locations
        "bottom shelf":                    (O_X,      O_Y + _BOTTOM_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH,       _BOTTOM_SHELF_HEIGHT,   _SHELF_DEPTH),
        "left side of bottom shelf":            (O_X,      O_Y + _BOTTOM_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _BOTTOM_SHELF_HEIGHT,   _SHELF_DEPTH),
        "right side of bottom shelf":           (O_X + _SHELF_WIDTH/2.0,  O_Y + _BOTTOM_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _BOTTOM_SHELF_HEIGHT,   _SHELF_DEPTH),
    }

    """===================================================================================================

    Initialization

    ==================================================================================================="""
    

    def __init__(self,  seed,
                        prompt_config_path,
                        learn_pref_config = None,
                        planner_config = None,
                        sim_config = None,
                        debug=False,
                        use_ros=False,
                        use_ros_service=False
        ):
        """Set up the interactive preference-learning baseline and load its LLM prompts.

        Parameters:
            seed - stored on the instance as ``self.seed``
            prompt_config_path (str) - path of the YAML prompt configuration handed to ``init_llm_prompts``
            learn_pref_config (dict | None) - preference-learning settings. Recognised keys and their
                defaults: ``n_pref`` (5), ``not_agreeable_threshold`` (0.1), ``n_q_per_pair`` (5),
                ``max_step`` (5). Keys that are left out fall back to these defaults.
            planner_config (dict | None) - planner settings, stored as-is. ``None`` means
                ``num_plans=10, beam_size=10, num_samples=10``.
            sim_config (dict | None) - simulation settings, stored as-is. ``None`` means
                ``fps=4, gif_path=None``.
            debug (bool) - when true, ``bp()`` drops into the debugger
            use_ros (bool) - stored on the instance as ``self.use_ros``
            use_ros_service (bool) - stored on the instance as ``self.use_ros_service``
        """
        # The defaults are built inside the call so that every instance gets its own
        # dictionaries; dict defaults in the signature would be shared between instances.
        learn_pref_defaults = dict(
            n_pref = 5,
            not_agreeable_threshold = 0.1,
            n_q_per_pair = 5,
            max_step = 5,
        )
        learn_pref_config = {**learn_pref_defaults, **(learn_pref_config or {})}

        if planner_config is None:
            planner_config = dict(
                num_plans = 10,  # Number of plans to generate
                beam_size = 10,  # Size of the beams to maintain
                num_samples = 10,  # Number of samples to take for each object placement
            )

        if sim_config is None:
            sim_config = dict(
                fps = 4, # Frames per second to render the GIF at
                gif_path = None, # Path to save the GIF to; doesn't save if None
            )

        self.need_to_plan = True  # Control whether the planner needs to get replanned
        self.curr_state = {}

        self.debug = debug
        self.use_ros = use_ros
        self.use_ros_service = use_ros_service
        self.planner_config = planner_config
        self.sim_config = sim_config
        self.seed = seed

        self.progress_save_file = ""
        self.saved_progress_dict = {}

        # Because these get used frequently during preference learning, we define individual variables
        self.n_preferences = learn_pref_config["n_pref"]
        self.not_agreeable_plan_threshold = learn_pref_config["not_agreeable_threshold"]
        self.n_q_per_pair = learn_pref_config["n_q_per_pair"]
        self.max_pref_learning_step = learn_pref_config["max_step"]

        self.init_llm_prompts(prompt_config_path)

    
    def bp(self):
        """Drop into the debugger, but only when the instance was created with ``debug`` set."""
        if not self.debug:
            return

        breakpoint()

    def init_llm_prompts(self, prompt_config_path):
        """
        Read the prompt configuration file and attach the prompt messages to every prompt type

        Parameters:
            prompt_config_path (str) - path of the YAML file; its top-level entry named after
                                       ``self.NAME`` is the one that gets used

        Effect:
            Set the variable self.prompt_config, and fill in ["llm_config"]["messages"] for each
            prompt type listed in it
        """
        with open(prompt_config_path, "r") as config_file:
            all_prompt_configs = yaml.safe_load(config_file)

        self.prompt_config = all_prompt_configs[self.NAME]

        for prompt_type, prompt_entry in self.prompt_config.items():
            # The plan-generation prompt is fetched with the planner package's helper,
            # every other prompt with the local fetch_messages
            if prompt_type == "generate_plan":
                fetch_fn = planner_fetch_message
            else:
                fetch_fn = fetch_messages

            prompt_entry["llm_config"]["messages"] = fetch_fn(
                prompt_entry["experiment_name"],
                prompt_entry["prompt_description"],
                prompt_entry["prompt_version"],
            )
            
            

    """===================================================================================================

    One prompt to do the entire interactive process

    ==================================================================================================="""
    def one_prompt_interactive(self, demos, initial_state, objs_to_put_away, chat_log, questions_left):
        """Ask the LLM for its next move: one more yes/no question, or the final preference.

        Parameters:
            demos - text placed at the top of the prompt
            initial_state - description of the fridge's initial state, inserted into the prompt
            objs_to_put_away - objects to put away, inserted into the prompt
            chat_log (list) - JSON-serializable conversation so far; dumped into the prompt
            questions_left (int) - number of questions the LLM may still ask; when it is not
                                   positive the prompt says a preference must be generated now

        Returns:
            (terminate, reasoning, question, preference) where terminate is 1 when the reply's
            "terminate? (yes/no)" value contains "yes" and 0 otherwise; the other three are
            taken from the reply and default to "" when absent

        Raises:
            AssertionError - when the reply carries neither a question nor a preference
                             together with terminate
        """
        can_ask_question = f"Yes! You can ask {questions_left} more questions." if questions_left > 0 else "No. You must generate the preference now."

        user_query = f"""
        {demos}

        ## New initial state of the fridge
        {initial_state}

        ## Objects to put away
        {objs_to_put_away}

        ## Chat log
        {json.dumps(chat_log, indent=4)}

        ## Can you still ask question?
        {can_ask_question}
        """

        logger.debug(f"\n{user_query}")

        output = prompt_llm(user_query, **self.prompt_config["one_prompt_interactive"]["llm_config"])

        logger.debug(f"\n{output}")

        # Keep only what sits between the opening fence and the next fence, so that any
        # commentary the model adds after the closing fence does not break JSON parsing.
        opening_fence = "```json" if "```json" in output else "```"

        if opening_fence in output:
            _, _, after_fence = output.partition(opening_fence)
            payload, _, _ = after_fence.partition("```")
            payload = payload.strip()
        else:
            payload = output

        try:
            reasonings_and_actions = json.loads(payload)

            if not isinstance(reasonings_and_actions, dict):
                raise ValueError("The answer is valid json but not a json object")
        except (ValueError, TypeError):
            # json.JSONDecodeError is a ValueError. Fall back to asking the operator to
            # type the fields in by hand instead of losing the whole interaction.
            logger.info("Failed to load the answer as a json")

            logger.info(output)

            # Audible alert; the `say` command is only available on macOS
            os.system('say "Hello, something went wrong. Please check!"')

            reasonings_and_actions = {}

            for key in ["terminate? (yes/no)", "reasoning", "question", "preference"]:
                value = input(f"Please manually copy what's under key={key} here (without quotations): ")

                reasonings_and_actions[key] = value.strip()

            # Only stop in the debugger when debugging was requested
            self.bp()

        terminate = "yes" in reasonings_and_actions.get("terminate? (yes/no)", "").strip().lower()
        reasoning = reasonings_and_actions.get("reasoning", "")
        question = reasonings_and_actions.get("question", "")
        preference = reasonings_and_actions.get("preference", "")

        # Raised explicitly (rather than with `assert`) so the check survives `python -O`
        if not ((terminate and preference != "") or (question != "")):
            raise AssertionError("Either there has to be a preference generated and terminate is true, or question is generated")

        return int(terminate), reasoning, question, preference
    
    """===================================================================================================

    Query the user

    ==================================================================================================="""
    
    def query_user(self, question, gt_pref = ""):
        """Get a yes/no answer to `question`, from a person or from a simulated user.

        Parameters:
            question (str) - the question to ask
            gt_pref (str) - ground-truth preference text. When empty, the answer is read from
                            standard input; otherwise it comes from ``get_answer_probs``

        Returns:
            1 for "yes", 0 for "no"

        Raises:
            AssertionError - when the simulated answer contains neither "yes" nor "no"
        """
        logger.info(f"\nQuestion:\n{question}")

        if gt_pref:
            _, _, answer = self.get_answer_probs(q = question, theta = gt_pref)

            if "yes" in answer:
                return 1
            if "no" in answer:
                return 0

            # Raised explicitly so the check is not stripped under `python -O`
            raise AssertionError(f"{answer} does not contain yes or no")

        # If there is no gt_preferences written out, assume that we have to query a real user
        while True:
            user_answer = input("Please only answer 'yes' or 'no': ").strip().lower()

            # Exact matching: a substring test would read "I don't know" as "no"
            if user_answer in ("yes", "y"):
                return 1
            if user_answer in ("no", "n"):
                return 0

            print("You have entered an invalid response. Please only answer in 'yes' or 'no'")

    def get_answer_probs(self, q, theta):
        """Have the LLM answer a yes/no question as a user who holds preference `theta`.

        Parameters:
            q (str) - the question(s) to answer
            theta (str) - the preference the simulated user holds

        Returns:
            (positive_ans_prob, negative_ans_prob, answer) where answer is the lower-cased,
            stripped "Answer (yes/no)" field of the reply. The value returned by
            ``get_answer_prob(logprobs)`` is used as the probability of the given answer,
            and its complement as the probability of the opposite answer.

        Raises:
            Exception - whatever went wrong while parsing or validating the reply is logged
                        and re-raised, so a malformed reply never comes back as ``None``
        """
        user_prompt = f"""
        # Your preference
        {theta}

        # Questions for you to answer
        {q}
        """
        logger.debug("------- Answering questions ...")
        logger.debug(f"{user_prompt}")
        response, logprobs = prompt_llm(user_prompt, **self.prompt_config["answer_questions"]["llm_config"])
        logger.debug(f"{response}")

        # Try to reduce json loading error: turn the single-quoted keys/values the model
        # sometimes emits into valid double-quoted json
        quote_fixes = (
            ("'Reasoning'", '"Reasoning"'),
            ("'Answer (yes/no)'", '"Answer (yes/no)"'),
            (": 'yes'", ': "yes"'),
            (": 'no'", ': "no"'),
        )
        for single_quoted, double_quoted in quote_fixes:
            response = response.replace(single_quoted, double_quoted)

        try:
            parsed_response = json.loads(response)
            answer = parsed_response["Answer (yes/no)"].lower().strip()

            if "yes" not in answer and "no" not in answer:
                raise ValueError(f"{answer} does not contain yes or no")

            answer_prob = get_answer_prob(logprobs)
        except Exception:
            logger.debug("Failed to load the answer as a json")
            logger.debug(f"\n{response}")

            # Audible alert; the `say` command is only available on macOS
            os.system('say "Hello, something went wrong. Please check!"')

            # Only stop in the debugger when debugging was requested, then surface the
            # failure instead of silently returning None
            self.bp()
            raise

        if "yes" in answer:
            return answer_prob, 1 - answer_prob, answer

        return 1 - answer_prob, answer_prob, answer

    """===================================================================================================

    Generate text plans from preferences

    ==================================================================================================="""

    def generate_text_plan(self, objs_to_put_away, init_state, preference):
        """Return the text plan for putting the objects away according to `preference`.

        A plan already stored in the saved progress under "text_plan" is reused; otherwise
        the LLM is prompted and the part of its reply starting at the first "pickandplace"
        is stored and returned.

        Parameters:
            objs_to_put_away - objects to put away, inserted into the prompt
            init_state - initial state, inserted into the prompt
            preference (str) - the preference the plan has to follow

        Returns:
            The plan text with surrounding whitespace removed

        Raises:
            AssertionError - when the reply contains no "pickandplace" call, or the
                             extracted plan still contains a "#" comment
        """
        can_load, value = self.load_progress("text_plan")

        if can_load:
            text_plan = value
        else:
            user_prompt = f"""
            Objects: {objs_to_put_away}
            Locations: {list(self.FRIDGE_LOCATION_BBOX.keys())}
            Initial State: {init_state}
            Preference: "{preference}"
            """
            logger.debug("------- Generating a plan ...")
            logger.debug(user_prompt)
            response = prompt_llm(user_prompt, **self.prompt_config["generate_plan"]["llm_config"])

            logger.debug(f"\n{response}")

            # Without this check a reply with no plan would be stored as the bare string
            # "pickandplace". AssertionError matches the validation error raised below.
            if "pickandplace" not in response:
                raise AssertionError(f"Should have at least one pickandplace call in the code:\n{response}")

            # Drop everything the model wrote before the first pickandplace call
            text_plan = response[response.index("pickandplace"):]

            if "#" in text_plan:
                raise AssertionError(f"Should not have more comments in the code:\n{text_plan}")

            self.save_progress("text_plan", text_plan)

        return text_plan.strip()  # Remove unnecessary whitespace

    """===================================================================================================

    Log progress

    ==================================================================================================="""

    def init_progress_save_file(self, progress_save_file):
        """Point the instance at a progress file, creating it or loading what it holds.

        Nothing happens when `progress_save_file` is empty. Otherwise the path is remembered;
        an existing file is loaded into self.saved_progress_dict, and a missing one is
        created holding an empty json object.
        """
        if not progress_save_file:
            return

        self.progress_save_file = progress_save_file

        if os.path.exists(self.progress_save_file):
            with open(self.progress_save_file, "r") as existing_file:
                self.saved_progress_dict = json.load(existing_file)
            return

        self.saved_progress_dict = {}

        with open(self.progress_save_file, "w") as new_file:
            json.dump(self.saved_progress_dict, new_file, indent=4)

    def reset_progress_save_file(self):
        """Write the current progress to its file (when one is set), then forget both.

        Afterwards self.progress_save_file is "" and self.saved_progress_dict is empty, the
        same state __init__ leaves them in. Calling this without a progress file is a no-op
        apart from clearing the in-memory progress.
        """
        if self.progress_save_file:
            with open(self.progress_save_file, "w") as fout:
                json.dump(self.saved_progress_dict, fout, indent=4)

        # An empty string (not a list) keeps the attribute usable as a path everywhere else
        self.progress_save_file = ""
        self.saved_progress_dict = {}

    def load_progress(self, key):
        """Look up `key` in the saved progress.

        Returns:
            (False, None) when the key has not been saved. Otherwise (True, value): for the
            keys "reasoning_worksheet" and "preference_writing_reasoning" the saved entry is
            a file path and the value is that file's text; for every other key the value is
            a deep copy of the saved entry.
        """
        if key not in self.saved_progress_dict:
            return False, None

        stored = self.saved_progress_dict[key]

        if key in ("reasoning_worksheet", "preference_writing_reasoning"):
            with open(stored, "r") as worksheet_file:
                worksheet_text = worksheet_file.read()

            return True, worksheet_text

        return True, copy.deepcopy(stored)
        
    def save_progress(self, key, value):
        """Record `value` under `key` in the saved progress and persist it.

        For the keys "reasoning_worksheet" and "preference_writing_reasoning" the value is
        written to a markdown file next to the progress file (progress file name without its
        extension, plus "_reasoning.md" / "_preference_writing.md") and the path of that file
        is what gets recorded. Every other value is recorded directly.

        The progress file is rewritten after every call. When no progress file has been set
        the progress is only kept in memory.

        Raises:
            FileNotFoundError - when one of the two markdown keys is saved while no
                                progress file has been set
        """
        markdown_suffixes = {
            "reasoning_worksheet": "_reasoning.md",
            "preference_writing_reasoning": "_preference_writing.md",
        }

        if key in markdown_suffixes:
            if not self.progress_save_file:
                raise FileNotFoundError(f"Cannot save {key}: no progress save file has been set")

            # splitext (instead of replacing ".json") keeps the markdown path distinct from
            # the progress file even when the progress file is not named *.json
            path_without_extension, _ = os.path.splitext(str(self.progress_save_file))
            fpath = path_without_extension + markdown_suffixes[key]

            with open(fpath, "w") as fout:
                fout.write(value)

            self.saved_progress_dict[key] = fpath
        else:
            self.saved_progress_dict[key] = value

        if self.progress_save_file:
            with open(self.progress_save_file, "w") as fout:
                json.dump(self.saved_progress_dict, fout, indent=4)


    def main(self, demos, initial_state, objs_to_put_away, gt_pref = "", progress_save_file = ""):
        """Main inference loop: ask questions until a preference is produced, then plan.

        Parameters:
            demos - demonstration text passed on to ``one_prompt_interactive``
            initial_state - initial state of the fridge, used for both the questions and the plan
            objs_to_put_away - objects to put away
            gt_pref (str) - ground-truth preference handed to ``query_user``; when empty the
                            questions are answered on standard input
            progress_save_file (str) - json file used to save and resume progress; when empty
                                       no file is initialised

        Returns:
            (plan, preference, step) when the loop ended with a preference, where plan is the
            text plan and step is the index of the step that produced the preference;
            ("", "", step) otherwise
        """
        # Initialize the progress_save_file
        self.init_progress_save_file(progress_save_file)

        # Resume from the saved progress when there is any
        can_load, value = self.load_progress("chat_log")
        chat_log = value if can_load else []

        can_load, value = self.load_progress("step")
        step = value if can_load else 0

        can_load, value = self.load_progress("terminate")
        terminate = value if can_load else 0

        # NOTE(review): "step" is saved before it is incremented, so a resumed run starts
        # again at the index of the last completed step - confirm this is intended.
        # Plus because the last step should be a chance to generate a preference
        while (step < self.max_pref_learning_step + 1) and (not terminate):
            # NOTE(review): with this formula the prompt already reports that no question
            # may be asked at step max_step - 1 - confirm against the comment above.
            terminate, reasoning, question, preference = self.one_prompt_interactive(demos,
                                                                                initial_state,
                                                                                objs_to_put_away,
                                                                                chat_log,
                                                                                self.max_pref_learning_step - step - 1)

            chat_log.append({"role": "thought", "content": reasoning})

            if question != "":
                user_answer = self.query_user(question, gt_pref)

                user_answer_in_text = "yes" if int(user_answer) == 1 else "no"

                chat_log.append({"role": "assistant", "content": question})
                chat_log.append({"role": "user", "content": user_answer_in_text})

            self.save_progress("step", step)
            self.save_progress("chat_log", chat_log)
            self.save_progress("terminate", terminate)
            self.save_progress("preference", preference)

            if terminate:
                break

            step += 1

        # Read the preference back from the saved progress so that a resumed run that had
        # already terminated still finds it
        can_load, value = self.load_progress("preference")
        preference = value if can_load else ""

        if terminate and preference != "":
            # Generate plan
            text_plan = self.generate_text_plan(objs_to_put_away, initial_state, preference)

            self.bp()

            return text_plan, preference, step

        return "", "", step
            



        

