import os
import sys
import datetime
import copy

# os.environ["LOGURU_LEVEL"] = "INFO"

from loguru import logger

# When LOGURU_LEVEL is exactly "INFO", route logs to stderr using a format
# that contains only the message text.
if os.environ.get("LOGURU_LEVEL") == "INFO":
    fmt = "{message}"
    config = dict(handlers=[dict(sink=sys.stderr, format=fmt)])
    logger.configure(**config)

import argparse
import json
import yaml
import gym
import numpy as np

from utils import format_semantic_loc_for_llm, format_plan
from constants import ALL_POSSIBLE_OBJECTS_TO_PLACE_DICT

from general_prompt_builder.prompt_llm import prompt_llm  
# from llms_for_planning.prompt_builder.prompt_llm import prompt_llm
from llms_for_planning.planners.geometric_feasibility.v0_no_llm_scoring import plan
from llms_for_planning.utils import fetch_messages as planner_fetch_message
from utils import *


class NonInteractivePrefLearning(object):
    """Non-interactive preference learner.

    Infers a single preference from demonstrations with one LLM call and then
    asks an LLM for a text plan that places objects into the fridge locations
    named in ``FRIDGE_LOCATION_BBOX``. Intermediate results can be cached in a
    JSON progress file.
    """

    # Key used to select this class's section of the prompt configuration file.
    NAME = "NonInteractive"

    # Presumably ROS parameter / topic names; they are not referenced by the
    # methods of this class — TODO confirm where they are consumed.
    OBJS_TO_PUT_AWAY_PARAM_NAME = "/perception/obj_to_put_away"
    OBJS_IN_FRIDGE_PARAM_NAME = "/perception/init_obj_in_fridge"
    STATE_TRACKER_TOPIC = "/state_tracker/fridgepcl"
    OBJECT_DETECTION_TOPIC = "/perception/detected_obj"
    ###### Fridge Range
    # NOTE(review): all dimensions below are presumably in meters — confirm.
    _BBOX_OFFSET = 0.02  # Define an offset around the fridge for safer planner output
    # Usable shelf footprint: the offset is removed from both sides.
    _SHELF_WIDTH = 0.42 - _BBOX_OFFSET * 2
    _SHELF_DEPTH = 0.24 - _BBOX_OFFSET * 2
    # Origin (Middle shelf's top left corner)
    O_X = -0.0275
    O_Y = 0.10
    O_Z = 0.0
    # Top shelf
    _TOP_SHELF_Y_OFFSET = 0.0 + _BBOX_OFFSET  # Don't want to sample too close to the bottom of the shelf
    _TOP_SHELF_HEIGHT = 0.23 - _BBOX_OFFSET * 2  # Don't want to sample too close to the top of the shelf
    # Middle shelf
    _MIDDLE_SHELF_Y_OFFSET = -0.27 + _BBOX_OFFSET
    _MIDDLE_SHELF_HEIGHT = 0.27 - _BBOX_OFFSET * 2
    # Bottom shelf
    _BOTTOM_SHELF_Y_OFFSET = -0.42 + _BBOX_OFFSET
    _BOTTOM_SHELF_HEIGHT = 0.15 - _BBOX_OFFSET * 2
    # Maps each semantic location name to its bounding box. The keys are also
    # the location names listed in the plan-generation prompt.
    # Bbox format (bottom_left_x, bottom_left_y, bottom_left_z, w, h, d)
    FRIDGE_LOCATION_BBOX = {
        # Top shelf locations
        "top shelf":                    (O_X,      O_Y + _TOP_SHELF_Y_OFFSET,    O_Z - _BBOX_OFFSET, _SHELF_WIDTH,       _TOP_SHELF_HEIGHT,      _SHELF_DEPTH),
        "left side of top shelf":            (O_X,      O_Y + _TOP_SHELF_Y_OFFSET,    O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _TOP_SHELF_HEIGHT,      _SHELF_DEPTH),
        "right side of top shelf":           (O_X + _SHELF_WIDTH/2.0,  O_Y + _TOP_SHELF_Y_OFFSET,    O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _TOP_SHELF_HEIGHT,      _SHELF_DEPTH),
        # Middle shelf locations
        "middle shelf":                    (O_X,      O_Y + _MIDDLE_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH,       _MIDDLE_SHELF_HEIGHT,   _SHELF_DEPTH),
        "left side of middle shelf":            (O_X,      O_Y + _MIDDLE_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _MIDDLE_SHELF_HEIGHT,   _SHELF_DEPTH),
        "right side of middle shelf":           (O_X + _SHELF_WIDTH/2.0,  O_Y + _MIDDLE_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _MIDDLE_SHELF_HEIGHT,   _SHELF_DEPTH),
        # Bottom shelf locations
        "bottom shelf":                    (O_X,      O_Y + _BOTTOM_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH,       _BOTTOM_SHELF_HEIGHT,   _SHELF_DEPTH),
        "left side of bottom shelf":            (O_X,      O_Y + _BOTTOM_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _BOTTOM_SHELF_HEIGHT,   _SHELF_DEPTH),
        "right side of bottom shelf":           (O_X + _SHELF_WIDTH/2.0,  O_Y + _BOTTOM_SHELF_Y_OFFSET, O_Z - _BBOX_OFFSET, _SHELF_WIDTH/2.0,   _BOTTOM_SHELF_HEIGHT,   _SHELF_DEPTH),
    }

    """===================================================================================================

    Initialization

    ==================================================================================================="""
    

    def __init__(self, seed,
                 prompt_config_path,
                 learn_pref_config=None,
                 planner_config=None,
                 sim_config=None,
                 debug=False,
                 use_ros=False,
                 use_ros_service=False):
        """Set up the learner state and load the LLM prompts.

        Parameters:
            seed - stored on the instance as ``self.seed``
            prompt_config_path (str) - path to the YAML prompt configuration
                passed to ``init_llm_prompts``
            learn_pref_config (dict or None) - preference-learning settings with
                keys ``n_pref``, ``not_agreeable_threshold``, ``n_q_per_pair``
                and ``max_step``; missing keys (or None) fall back to
                5 / 0.1 / 5 / 5
            planner_config (dict or None) - planner settings, stored as given;
                None means ``num_plans=10, beam_size=10, num_samples=10``
            sim_config (dict or None) - rendering settings, stored as given;
                None means ``fps=4, gif_path=None``
            debug (bool) - when true, ``bp()`` drops into the debugger
            use_ros (bool) - stored on the instance
            use_ros_service (bool) - stored on the instance
        """
        # The defaults are built per call: dict defaults in the signature are
        # created once and would be shared (and mutable) across all instances.
        learn_pref_settings = dict(
            n_pref=5,
            not_agreeable_threshold=0.1,
            n_q_per_pair=5,
            max_step=5,
        )
        if learn_pref_config:
            learn_pref_settings.update(learn_pref_config)

        if planner_config is None:
            planner_config = dict(
                num_plans=10,  # Number of plans to generate
                beam_size=10,  # Size of the beams to maintain
                num_samples=10,  # Number of samples to take for each object placement
            )

        if sim_config is None:
            sim_config = dict(
                fps=4,  # Frames per second to render the GIF at
                gif_path=None,  # Path to save the GIF to; doesn't save if None
            )

        self.need_to_plan = True  # Control whether the planner needs to get replanned
        self.curr_state = {}

        self.debug = debug
        self.use_ros = use_ros
        self.use_ros_service = use_ros_service
        self.planner_config = planner_config
        self.sim_config = sim_config
        self.seed = seed

        self.progress_save_file = ""
        self.saved_progress_dict = {}

        # Because these get used frequently during preference learning, we define individual variables
        self.n_preferences = learn_pref_settings["n_pref"]
        self.not_agreeable_plan_threshold = learn_pref_settings["not_agreeable_threshold"]
        self.n_q_per_pair = learn_pref_settings["n_q_per_pair"]
        self.max_pref_learning_step = learn_pref_settings["max_step"]

        self.init_llm_prompts(prompt_config_path)

    
    def bp(self):
        """Drop into the debugger, but only when ``self.debug`` is truthy."""
        if not self.debug:
            return
        breakpoint()

    def init_llm_prompts(self, prompt_config_path):
        """
        Read the YAML prompt configuration and attach messages to each prompt.

        Parameters:
            prompt_config_path (str) - path to the YAML file; its top level is
                indexed by ``self.NAME``

        Effect:
            Sets self.prompt_config to this class's section of the file and
            fills ``["llm_config"]["messages"]`` for every prompt type in it.
        """
        with open(prompt_config_path, "r") as config_file:
            all_prompt_configs = yaml.safe_load(config_file)

        self.prompt_config = all_prompt_configs[self.NAME]

        for prompt_type, entry in self.prompt_config.items():
            # The plan-generation prompt uses the planner package's fetcher;
            # every other prompt type uses fetch_messages.
            if prompt_type == "generate_plan":
                fetch = planner_fetch_message
            else:
                fetch = fetch_messages

            entry["llm_config"]["messages"] = fetch(
                entry["experiment_name"],
                entry["prompt_description"],
                entry["prompt_version"],
            )

    """===================================================================================================

    Generating preferences

    ==================================================================================================="""

    def generate_preference(self, demos):
        """Return the preference inferred from the demonstrations.

        A preference already stored in the progress dict is returned as is.
        Otherwise the LLM is prompted with ``demos``, its raw output is saved
        under "preference_writing_reasoning", and the "preference" field is read
        from the JSON payload of the answer. If that payload cannot be parsed,
        the operator is asked to type the preference in manually.

        Parameters:
            demos - demonstrations, passed to the LLM as the user prompt

        Returns:
            The preference (whatever the JSON "preference" field holds, or the
            manually entered string).
        """
        can_load, cached_preference = self.load_progress("preference")
        if can_load:
            return cached_preference

        logger.debug("------- Generating preferences ...")
        logger.debug(f"\n{demos}")
        output = prompt_llm(demos, **self.prompt_config["generate_one_preference_baseline"]["llm_config"])

        logger.debug(f"\n{output}")

        self.save_progress("preference_writing_reasoning", output)

        try:
            preference = json.loads(self._extract_fenced_payload(output))["preference"]
        except (ValueError, KeyError, TypeError):
            # ValueError covers json.JSONDecodeError; KeyError / TypeError cover
            # valid JSON that has no "preference" field or is not an object.
            logger.info("Failed to load the answer as a json")

            logger.info(output)

            preference = input("Please manually copy the preference here (without quotations): ")

            # Only stop in the debugger when debugging was requested, so an
            # unattended run is not left hanging at a breakpoint.
            self.bp()

        self.save_progress("preference", preference)

        return preference

    @staticmethod
    def _extract_fenced_payload(output):
        """Return the text of the first fenced block in ``output``, or all of it if unfenced."""
        for fence in ("```json", "```"):
            if fence in output:
                _, _, after_fence = output.partition(fence)
                # Stop at the closing fence so trailing commentary does not
                # end up in the JSON payload.
                payload, _, _ = after_fence.partition("```")
                return payload.strip()

        return output.strip()
    
    """===================================================================================================

    Generate text plans from preferences

    ==================================================================================================="""

    def generate_text_plan(self, objs_to_put_away, init_state, preference):
        """Return a text plan of ``pickandplace`` calls for the given objects.

        A plan already stored in the progress dict is reused. Otherwise the LLM
        is prompted with the objects, the fridge location names, the initial
        state and the preference, and everything from the first
        "pickandplace" onward is taken as the plan and saved.

        Parameters:
            objs_to_put_away - objects to place, interpolated into the prompt
            init_state - initial state, interpolated into the prompt
            preference - preference text, interpolated into the prompt

        Returns:
            str - the plan with surrounding whitespace removed

        Raises:
            ValueError - the LLM response contains no "pickandplace" call
            AssertionError - the extracted plan still contains a "#" comment
        """
        can_load, value = self.load_progress("text_plan")

        if can_load:
            text_plan = value
        else:
            user_prompt = f"""
            Objects: {objs_to_put_away}
            Locations: {list(self.FRIDGE_LOCATION_BBOX.keys())}
            Initial State: {init_state}
            Preference: "{preference}"
            """
            logger.debug("------- Generating a plan ...")
            logger.debug(user_prompt)
            response = prompt_llm(user_prompt, **self.prompt_config["generate_plan"]["llm_config"])

            logger.debug(f"\n{response}")

            # Discard whatever precedes the first action call.
            _, marker, remainder = response.partition("pickandplace")

            if not marker:
                # Without this check a response with no action would be turned
                # into the bogus plan "pickandplace" and persisted.
                raise ValueError(f"LLM response contains no 'pickandplace' call:\n{response}")

            text_plan = marker + remainder  # Add back the first 'pickandplace' that got removed

            # Raised explicitly (rather than via `assert`) so the check still
            # runs under `python -O`; the exception type is unchanged.
            if "#" in text_plan:
                raise AssertionError(f"Should not have more comments in the code:\n{text_plan}")

            self.save_progress("text_plan", text_plan)

        return text_plan.strip()  # Remove unnecessary whitespace
    
    """===================================================================================================

    Log progress

    ==================================================================================================="""
    
    def init_progress_save_file(self, progress_save_file):
        if progress_save_file:
            self.progress_save_file = progress_save_file
            
            if not os.path.exists(self.progress_save_file):
                self.saved_progress_dict = {}

                with open(self.progress_save_file, "w") as fout:
                    json.dump(self.saved_progress_dict, fout, indent=4)
            else:
                with open(self.progress_save_file, "r") as fin:
                    self.saved_progress_dict = json.load(fin)

    def reset_progress_save_file(self):
        with open(self.progress_save_file, "w") as fout:
            json.dump(self.saved_progress_dict, fout, indent=4)

        self.progress_save_file = []
        self.saved_progress_dict = {}

    def load_progress(self, key):
        if key in self.saved_progress_dict:
            if key == "reasoning_worksheet" or key == "preference_writing_reasoning":
                reasoning_worksheet_fpath = self.saved_progress_dict[key]

                with open(reasoning_worksheet_fpath, "r") as fin:
                    reasoning_worksheet = fin.read()

                return True, reasoning_worksheet
            else:
                return True, copy.deepcopy(self.saved_progress_dict[key])
        else:
            return False, None
        
    def save_progress(self, key, value):
        if key == "reasoning_worksheet" or key == "preference_writing_reasoning":
            if key == "reasoning_worksheet":
                extension = "_reasoning.md"
            elif key == "preference_writing_reasoning":
                extension = "_preference_writing.md"

            fpath = str(self.progress_save_file).replace(".json", extension)

            with open(fpath, "w") as fout:
                fout.write(value)

            self.saved_progress_dict[key] = fpath
        else:
            self.saved_progress_dict[key] = value

        with open(self.progress_save_file, "w") as fout:
            json.dump(self.saved_progress_dict, fout, indent=4)


    def main(self, demos, initial_state, objs_to_put_away, gt_pref = "", progress_save_file = ""):
        """Run inference end to end: infer the preference, then plan with it.

        Parameters:
            demos - demonstrations used to infer the preference
            initial_state - initial state given to the plan prompt
            objs_to_put_away - objects the plan should place
            gt_pref - accepted but not used by this method
            progress_save_file - optional JSON progress file to resume from
                and save into

        Returns:
            (text plan, preference, 0) - the last element is always 0; its
            meaning is not evident here (presumably a placeholder — TODO confirm).
        """
        self.init_progress_save_file(progress_save_file)

        inferred_preference = self.generate_preference(demos)
        text_plan = self.generate_text_plan(objs_to_put_away, initial_state, inferred_preference)

        return text_plan, inferred_preference, 0
        
