from typing import Dict, Any
from copy import deepcopy

from harl.common.llm_logger  import Logger
from harl.common.memory import LocalMemory, GlobalMemory
from harl.configs.config import Config
from harl.common.base.base_provider import BaseProvider
from harl import constants

# Module-level helpers shared by the provider classes in this file.
logger = Logger()  # used through logger.write(...) for progress messages
config = Config()  # read for image counts, env_args and number_of_execute_skills
global_memory = GlobalMemory()  # queried through get_ally_task(...) by the SMACv2 preprocess provider

class ActionPlanningPreprocessProvider(BaseProvider):
    """Prepare the screenshot inputs for a generic action-planning prompt.

    Calling the provider pairs the most recent screenshot paths held in
    memory with fixed captions and publishes them as ``image_introduction``
    in the memory's working area.
    """

    def __init__(self, *args,
                 gm: Any,
                 use_screenshot_augmented = False,
                 memory: Any = None,
                 **kwargs):
        """Store the collaborators used by ``__call__``.

        Args:
            gm: Stored as ``self.gm``; this class does not use it itself.
            use_screenshot_augmented: When truthy, the
                "screenshot_augmented_path" history is read instead of
                "screenshot_path".
            memory: Object exposing ``get_recent_history(key, k=...)`` and a
                ``working_area`` mapping. Optional at construction time for
                backward compatibility, but required when the provider is
                called.
        """
        super().__init__(*args, **kwargs)
        self.gm = gm
        self.use_screenshot_augmented = use_screenshot_augmented
        # Only overwrite when a memory was actually supplied, so an attribute
        # that the base class may already have set is not replaced by None.
        if memory is not None:
            self.memory = memory

    def __call__(self):
        """Build ``image_introduction`` and store it in the working area.

        Returns:
            dict with the single key ``"image_introduction"``: a list of
            ``{"introduction", "path", "assistant"}`` dicts ordered from the
            oldest screenshot to the current one.

        Raises:
            RuntimeError: If no memory object is available on the instance.
            IndexError: If memory returns more screenshots than there are
                captions (two).
        """
        memory = getattr(self, "memory", None)
        if memory is None:
            # This module binds no global ``memory``; fail with an actionable
            # message instead of a NameError.
            raise RuntimeError(
                f"{type(self).__name__} needs a memory object; "
                "pass memory=... to the constructor."
            )

        prompts = [
            "This screenshot is the previous step of the game.",
            "This screenshot is the current step of the game."
        ]

        # Only the history that is actually used gets fetched.
        if self.use_screenshot_augmented:
            history_key = "screenshot_augmented_path"
        else:
            history_key = "screenshot_path"
        paths = memory.get_recent_history(history_key, k=config.action_planning_image_num)

        # Negative indexing aligns the newest path with the last caption,
        # whatever number of paths came back.
        image_introduction = [
            {
                "introduction": prompts[-i],
                "path": paths[-i],
                "assistant": ""
            }
            for i in range(len(paths), 0, -1)
        ]

        processed_params = {
            "image_introduction": image_introduction
        }

        memory.working_area.update(processed_params)

        return processed_params


class SMACv2ActionPlanningPreprocessProvider(BaseProvider):
    """Collect the prompt parameters for SMACv2 action planning.

    ``memory`` is indexed by ``env_idx``; each entry supplies the recent
    history, frame observations and unit attributes read below.
    """

    def __init__(self, *args,
                 memory: Any,
                 **kwargs):
        """Keep a reference to the per-environment memory collection."""
        super().__init__(*args, **kwargs)
        self.memory = memory

    def __call__(self, env_idx: int):
        """Assemble the planning parameters for one environment.

        Args:
            env_idx: Index into ``self.memory`` selecting the memory to read
                from and to update.

        Returns:
            dict of prompt parameters. The same dict is also merged into the
            selected memory's ``working_area``.
        """
        logger.write("SMACv2 Action Planning Preprocess for Env {}".format(env_idx))

        mem = self.memory[env_idx]

        def latest(key):
            # Most recent entry stored under ``key``.
            return mem.get_recent_history(key, k=1)[0]

        start_frame_id = latest("start_frame_id")
        end_frame_id = latest("end_frame_id")
        pre_action = latest("pre_action")
        pre_self_reflection_reasoning = latest("pre_self_reflection_reasoning")
        skill_library = latest("skill_library")
        task_description = latest("task_description")
        unit_type = mem.unit_type
        unit_id = mem.agent_id
        scenario_name = config.env_args["map_name"]

        # Details of the previous action are only looked up when one was recorded.
        previous_action = ""
        previous_reasoning = ""
        executing_action_error = ""
        if pre_action:
            previous_action = latest("action")
            previous_reasoning = latest("decision_making_reasoning")
            executing_action_error = mem.get_last_action_error(start_frame_id, end_frame_id)

        previous_self_reflection_reasoning = ""
        if pre_self_reflection_reasoning:
            previous_self_reflection_reasoning = latest("self_reflection_reasoning")

        info_summary = latest("summarization")

        # @TODO Temporary solution with fake augmented entries if no bounding box exists. Ideally it should read images, then check for possible augmentation.
        frame_legend = (
            "The cyan circle shows what your unit can see (sight range), the red dashed circle shows where you can attack/heal (shooting range). "
            "Your unit is highlighted in green at the center, with allies (green) and enemies (red) visible within your sight range.\n"
        )

        # Only the current frame is requested. To feed a window of frames
        # instead, start from max(end_frame_id - config.action_planning_image_num, 0).
        images, obs_texts = mem.get_frame_paths_obs(end_frame_id, end_frame_id)

        image_introduction = []
        total = len(images)
        for idx, image_path in enumerate(images):
            steps_back = total - idx - 1
            if steps_back == 0:
                caption = "This screenshot is the current step of the game\n"
            else:
                caption = f"This screenshot is {steps_back} step{'s' if steps_back > 1 else ''} before the current step of the game\n"
            image_introduction.append(
                {
                    "introduction": caption + frame_legend + obs_texts[idx],
                    # The text observation is still sent when images are disabled.
                    "path": image_path if config.use_image else "",
                    "assistant": "",
                    "resolution": "high",
                }
            )

        ally_task = global_memory.get_ally_task(env_idx, unit_id)
        game_situation = latest("game_situation")

        processed_params = {
            "previous_action": previous_action,
            "executing_action_error": executing_action_error,
            "previous_reasoning": previous_reasoning,
            "previous_self_reflection_reasoning": previous_self_reflection_reasoning,
            "skill_library": skill_library,
            "task_description": task_description,
            "info_summary": info_summary,
            "image_introduction": image_introduction,
            "unit_type": unit_type,
            "unit_id": unit_id,
            "scenario_name": scenario_name,
            "last_episode_reasoning": latest("last_episode_reasoning"),
            "ally_task": ally_task,
            "game_situation": game_situation,
        }

        mem.working_area.update(processed_params)

        return processed_params
    

class RDR2ActionPlanningPreprocessProvider(BaseProvider):
    """Collect the prompt parameters for RDR2 action planning."""

    def __init__(self, *args,
                 gm: Any,
                 memory: Any = None,
                 **kwargs):
        """Store the collaborators used by ``__call__``.

        Args:
            gm: Stored as ``self.gm``; this class does not use it itself.
            memory: Object exposing ``get_recent_history(key, k=...)`` and a
                ``working_area`` mapping. Optional at construction time for
                backward compatibility, but required when the provider is
                called.
        """
        super().__init__(*args, **kwargs)
        self.gm = gm
        # Only overwrite when a memory was actually supplied, so an attribute
        # that the base class may already have set is not replaced by None.
        if memory is not None:
            self.memory = memory

    def __call__(self):
        """Assemble the planning parameters and store them in the working area.

        Returns:
            dict of prompt parameters, including ``image_introduction`` and
            the minimap summary text.

        Raises:
            RuntimeError: If no memory object is available on the instance.
        """
        memory = getattr(self, "memory", None)
        if memory is None:
            # This module binds no global ``memory``; fail with an actionable
            # message instead of a NameError.
            raise RuntimeError(
                f"{type(self).__name__} needs a memory object; "
                "pass memory=... to the constructor."
            )

        logger.write("RDR2 Action Planning Preprocess")

        def latest(key):
            # Most recent entry stored under ``key``.
            return memory.get_recent_history(key, k=1)[0]

        prompts = [
            "Now, I will give you five screenshots for decision making.",
            "This screenshot is five steps before the current step of the game",
            "This screenshot is three steps before the current step of the game",
            "This screenshot is two steps before the current step of the game",
            "This screenshot is the previous step of the game",
            "This screenshot is the current step of the game"
        ]

        response_keys = latest("response_keys")
        response = latest("response")
        pre_action = latest("pre_action")
        pre_self_reflection_reasoning = latest("pre_self_reflection_reasoning")
        pre_screen_classification = latest("pre_screen_classification")
        screen_classification = latest("screen_classification")
        skill_library = latest("skill_library")
        task_description = latest("task_description")

        # Details of the previous action are only looked up when one was recorded.
        previous_action = ""
        previous_reasoning = ""
        if pre_action:
            previous_action = latest("action")
            previous_reasoning = latest("decision_making_reasoning")

        previous_self_reflection_reasoning = ""
        if pre_self_reflection_reasoning:
            previous_self_reflection_reasoning = latest("self_reflection_reasoning")

        info_summary = latest("summarization")

        # @TODO Temporary solution with fake augmented entries if no bounding box exists. Ideally it should read images, then check for possible augmentation.
        image_memory = memory.get_recent_history("screenshot_path", k=config.action_planning_image_num)
        augmented_image_memory = memory.get_recent_history(constants.AUGMENTED_IMAGES_MEM_BUCKET,
                                                           k=config.action_planning_image_num)

        image_introduction = []
        oldest = len(image_memory)
        for i in range(oldest, 0, -1):
            # Prefer the augmented image when one exists for this step.
            has_augmented = (len(augmented_image_memory) >= i
                             and augmented_image_memory[-i] != constants.NO_IMAGE)
            entry = {
                "introduction": prompts[-i],
                "path": augmented_image_memory[-i] if has_augmented else image_memory[-i],
                "assistant": "",
            }
            # Only the oldest screenshot, when augmented, is sent in high resolution.
            if has_augmented and i == oldest:
                entry["resolution"] = "high"
            image_introduction.append(entry)

        # Minimap info tracking
        minimap_information = ""
        if constants.MINIMAP_INFORMATION in response_keys:
            minimap_information = response[constants.MINIMAP_INFORMATION]
            logger.write(f"{constants.MINIMAP_INFORMATION}: {minimap_information}")

            # One "<key> <index>: angle <theta> degree" line per detected item.
            minimap_lines = []
            for key, value in minimap_information.items():
                if value:
                    for index, item in enumerate(value):
                        minimap_lines.append(f"{key} {index}: angle {int(item['theta'])} degree")
            minimap_info_str = "\n".join(minimap_lines)

            logger.write(f'minimap_info_str: {minimap_info_str}')
            minimap_information = minimap_info_str

        processed_params = {
            "pre_screen_classification": pre_screen_classification,
            "screen_classification": screen_classification,
            "previous_action": previous_action,
            "previous_reasoning": previous_reasoning,
            "previous_self_reflection_reasoning": previous_self_reflection_reasoning,
            "skill_library": skill_library,
            "task_description": task_description,
            "minimap_information": minimap_information,
            "info_summary": info_summary,
            "image_introduction": image_introduction
        }

        memory.working_area.update(processed_params)

        return processed_params

class StardewActionPlanningPreprocessProvider(BaseProvider):
    """Collect the prompt parameters for Stardew action planning."""

    def __init__(self, *args,
                 gm: Any,
                 toolbar_information: str,
                 memory: Any = None,
                 **kwargs):
        """Store the collaborators used by ``__call__``.

        Args:
            gm: Stored as ``self.gm``; this class does not use it itself.
            toolbar_information: Fallback toolbar description used when
                memory holds ``None`` for "toolbar_information".
            memory: Object exposing ``get_recent_history(key, k=...)`` and a
                ``working_area`` mapping. Optional at construction time for
                backward compatibility, but required when the provider is
                called.
        """
        super().__init__(*args, **kwargs)
        self.gm = gm
        self.toolbar_information = toolbar_information
        # Only overwrite when a memory was actually supplied, so an attribute
        # that the base class may already have set is not replaced by None.
        if memory is not None:
            self.memory = memory

    def __call__(self):
        """Assemble the planning parameters and store them in the working area.

        Returns:
            dict of prompt parameters, including ``image_introduction``.

        Raises:
            RuntimeError: If no memory object is available on the instance.
        """
        memory = getattr(self, "memory", None)
        if memory is None:
            # This module binds no global ``memory``; fail with an actionable
            # message instead of a NameError.
            raise RuntimeError(
                f"{type(self).__name__} needs a memory object; "
                "pass memory=... to the constructor."
            )

        logger.write("Stardew Action Planning Preprocess")

        def latest(key):
            # Most recent entry stored under ``key``.
            return memory.get_recent_history(key, k=1)[0]

        # The comma after the first entry is essential: without it Python
        # concatenates the first two literals and every caption counted from
        # the end beyond the fourth is wrong.
        prompts = [
            "Now, I will give you five screenshots for decision making.",
            "This screenshot is five steps before the current step of the game",
            "This screenshot is three steps before the current step of the game",
            "This screenshot is two steps before the current step of the game",
            "This screenshot is the previous step of the game. The blue band represents the left side and the yellow band represents the right side.",
            "This screenshot is the current step of the game. The blue band represents the left side and the yellow band represents the right side."
        ]

        pre_action = latest("pre_action")
        pre_self_reflection_reasoning = latest("pre_self_reflection_reasoning")
        toolbar_information = latest("toolbar_information")
        selected_position = latest("selected_position")
        summarization = latest("summarization")
        skill_library = latest("skill_library")
        task_description = latest("task_description")
        subtask_description = latest("subtask_description")
        # Both output keys are fed from the same "summarization" history entry.
        history_summary = summarization

        # Decision making preparation
        if toolbar_information is None:
            toolbar_information = self.toolbar_information
        if selected_position is None:
            selected_position = 1

        # Details of the previous action are only looked up when one was recorded.
        previous_action = ""
        previous_reasoning = ""
        if pre_action:
            previous_action = latest("action")
            previous_reasoning = latest("decision_making_reasoning")

        previous_self_reflection_reasoning = ""
        if pre_self_reflection_reasoning:
            previous_self_reflection_reasoning = latest("self_reflection_reasoning")

        # @TODO Temporary solution with fake augmented entries if no bounding box exists. Ideally it should read images, then check for possible augmentation.
        image_memory = memory.get_recent_history("augmented_image", k=config.action_planning_image_num)

        # Negative indexing aligns the newest image with the last caption.
        image_introduction = [
            {
                "introduction": prompts[-i],
                "path": image_memory[-i],
                "assistant": ""
            }
            for i in range(len(image_memory), 0, -1)
        ]

        processed_params = {
            "pre_self_reflection_reasoning": pre_self_reflection_reasoning,
            "toolbar_information": toolbar_information,
            "selected_position": selected_position,
            "summarization": summarization,
            "skill_library": skill_library,
            "task_description": task_description,
            "subtask_description": subtask_description,
            "history_summary": history_summary,
            "previous_action": previous_action,
            "previous_reasoning": previous_reasoning,
            "previous_self_reflection_reasoning": previous_self_reflection_reasoning,
            "image_introduction": image_introduction,
        }

        memory.working_area.update(processed_params)

        return processed_params

class ActionPlanningPostprocessProvider(BaseProvider):
    """Normalise an action-planning response and record it in memory."""

    def __init__(self, *args, memory: Any = None, **kwargs):
        """Store the memory that ``__call__`` writes to.

        Args:
            memory: Object exposing ``update_info_history(dict)``. Optional
                at construction time for backward compatibility, but
                required when the provider is called.
        """
        super().__init__(*args, **kwargs)
        # Only overwrite when a memory was actually supplied, so an attribute
        # that the base class may already have set is not replaced by None.
        if memory is not None:
            self.memory = memory

    @staticmethod
    def _select_skill_steps(raw_steps):
        """Return the skill calls to execute, or ``['']`` when there are none."""
        if isinstance(raw_steps, str):
            # A bare string would otherwise be iterated character by character.
            raw_steps = [raw_steps]
        steps = [step for step in (raw_steps or []) if step != '']
        steps = steps[:config.number_of_execute_skills]
        # Fall back after filtering, so a list of only empty strings cannot
        # leave an empty list behind for the ``[0]`` lookup in ``__call__``.
        return steps or ['']

    def __call__(self, response: Dict):
        """Derive ``actions``, ``decision_making_reasoning`` and ``skill_steps``.

        Args:
            response: Parsed planner output. ``'actions'`` is optional;
                ``'reasoning'`` is required.

        Returns:
            Deep copy of ``response`` with the derived keys set. The same
            dict is passed to ``memory.update_info_history``.

        Raises:
            RuntimeError: If no memory object is available on the instance.
            KeyError: If ``response`` has no ``'reasoning'`` entry.
        """
        memory = getattr(self, "memory", None)
        if memory is None:
            # This module binds no global ``memory``; fail with an actionable
            # message instead of a NameError.
            raise RuntimeError(
                f"{type(self).__name__} needs a memory object; "
                "pass memory=... to the constructor."
            )

        processed_response = deepcopy(response)

        skill_steps = self._select_skill_steps(response.get('actions'))

        # Several skills are rendered as a bracketed list, a single one as-is.
        if config.number_of_execute_skills > 1:
            actions = "[" + ",".join(skill_steps) + "]"
        else:
            actions = str(skill_steps[0])

        decision_making_reasoning = response['reasoning']

        processed_response.update({
            "actions": actions,
            "decision_making_reasoning": decision_making_reasoning,
            "skill_steps": skill_steps,
        })
        memory.update_info_history(processed_response)

        return processed_response

class SMACv2ActionPlanningPostprocessProvider(BaseProvider):
    """Normalise a SMACv2 action-planning response and record it in memory.

    ``memory`` is indexed by ``env_idx``; the selected entry receives the
    processed response through ``update_info_history``.
    """

    def __init__(self, *args, 
                 memory: Any,
                 **kwargs):
        """Keep a reference to the per-environment memory collection."""
        super().__init__(*args, **kwargs)
        self.memory = memory

    @staticmethod
    def _select_skill_steps(raw_steps):
        """Return the skill calls to execute, or ``['']`` when there are none."""
        if isinstance(raw_steps, str):
            # A bare string (plausible for the singular 'skill' key) would
            # otherwise be iterated character by character.
            raw_steps = [raw_steps]
        steps = [step for step in (raw_steps or []) if step != '']
        steps = steps[:config.number_of_execute_skills]
        # Fall back after filtering, so a list of only empty strings cannot
        # leave an empty list behind for the ``[0]`` lookup in ``__call__``.
        return steps or ['']

    def __call__(self, response: Dict, env_idx: int):
        """Derive the action fields for one environment and store them.

        Args:
            response: Parsed planner output. Skills are read from
                ``'skills'`` or, failing that, ``'skill'``; ``'reasoning'``
                is required.
            env_idx: Index into ``self.memory`` selecting the memory to
                update.

        Returns:
            Deep copy of ``response`` extended with ``pre_action``,
            ``action``, ``pre_decision_making_reasoning``,
            ``decision_making_reasoning`` and ``skill_steps``.

        Raises:
            ValueError: Wraps any failure, tagged with ``env_idx``; the
                original exception is chained as ``__cause__``.
        """
        try:
            logger.write("SMACv2 Action Planning Postprocess for Env {}".format(env_idx))

            processed_response = deepcopy(response)

            raw_steps = []
            if 'skills' in response:
                raw_steps = response['skills']
            elif 'skill' in response:
                raw_steps = response['skill']
            skill_steps = self._select_skill_steps(raw_steps)

            # Several skills are rendered as a bracketed list, a single one as-is.
            if config.number_of_execute_skills > 1:
                actions = "[" + ",".join(skill_steps) + "]"
            else:
                actions = str(skill_steps[0])

            decision_making_reasoning = response['reasoning']

            processed_response.update({
                "pre_action": actions,
                "action": actions,
                "pre_decision_making_reasoning": decision_making_reasoning,
                "decision_making_reasoning": decision_making_reasoning,
                "skill_steps": skill_steps,
            })
            self.memory[env_idx].update_info_history(processed_response)

            return processed_response
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise ValueError(f"SMACv2 Action Planning Postprocess Error Env {env_idx}: {e}") from e
    

class RDR2ActionPlanningPostprocessProvider(BaseProvider):
    """Normalise an RDR2 action-planning response and record it in memory."""

    def __init__(self, *args, memory: Any = None, **kwargs):
        """Store the memory that ``__call__`` writes to.

        Args:
            memory: Object exposing ``update_info_history(dict)``. Optional
                at construction time for backward compatibility, but
                required when the provider is called.
        """
        super().__init__(*args, **kwargs)
        # Only overwrite when a memory was actually supplied, so an attribute
        # that the base class may already have set is not replaced by None.
        if memory is not None:
            self.memory = memory

    @staticmethod
    def _select_skill_steps(raw_steps):
        """Return the skill calls to execute, or ``['']`` when there are none."""
        if isinstance(raw_steps, str):
            # A bare string would otherwise be iterated character by character.
            raw_steps = [raw_steps]
        steps = [step for step in (raw_steps or []) if step != '']
        steps = steps[:config.number_of_execute_skills]
        # Fall back after filtering, so a list of only empty strings cannot
        # leave an empty list behind for the ``[0]`` lookup in ``__call__``.
        return steps or ['']

    def __call__(self, response: Dict):
        """Derive the action fields from ``response`` and store them.

        Args:
            response: Parsed planner output. ``'actions'`` is optional;
                ``'reasoning'`` is required.

        Returns:
            Deep copy of ``response`` extended with ``action``,
            ``pre_decision_making_reasoning``, ``decision_making_reasoning``
            and ``skill_steps``.

        Raises:
            RuntimeError: If no memory object is available on the instance.
            KeyError: If ``response`` has no ``'reasoning'`` entry.
        """
        memory = getattr(self, "memory", None)
        if memory is None:
            # This module binds no global ``memory``; fail with an actionable
            # message instead of a NameError.
            raise RuntimeError(
                f"{type(self).__name__} needs a memory object; "
                "pass memory=... to the constructor."
            )

        logger.write("RDR2 Action Planning Postprocess")

        processed_response = deepcopy(response)

        skill_steps = self._select_skill_steps(response.get('actions'))

        # Several skills are rendered as a bracketed list, a single one as-is.
        if config.number_of_execute_skills > 1:
            actions = "[" + ",".join(skill_steps) + "]"
        else:
            actions = str(skill_steps[0])

        decision_making_reasoning = response['reasoning']

        processed_response.update({
            "action": actions,
            "pre_decision_making_reasoning": decision_making_reasoning,
            "decision_making_reasoning": decision_making_reasoning,
            "skill_steps": skill_steps,
        })
        memory.update_info_history(processed_response)

        return processed_response


class StardewActionPlanningPostprocessProvider(BaseProvider):
    """Normalise a Stardew action-planning response and record it in memory."""

    def __init__(self, *args, memory: Any = None, **kwargs):
        """Store the memory that ``__call__`` writes to.

        Args:
            memory: Object exposing ``update_info_history(dict)``. Optional
                at construction time for backward compatibility, but
                required when the provider is called.
        """
        super().__init__(*args, **kwargs)
        # Only overwrite when a memory was actually supplied, so an attribute
        # that the base class may already have set is not replaced by None.
        if memory is not None:
            self.memory = memory

    @staticmethod
    def _select_skill_steps(raw_steps):
        """Return the skill calls to execute, or ``['']`` when there are none."""
        if isinstance(raw_steps, str):
            # A bare string would otherwise be iterated character by character.
            raw_steps = [raw_steps]
        steps = [step for step in (raw_steps or []) if step != '']
        steps = steps[:config.number_of_execute_skills]
        # Fall back after filtering, so a list of only empty strings cannot
        # leave an empty list behind for the ``[0]`` lookup in ``__call__``.
        return steps or ['']

    def __call__(self, response: Dict):
        """Derive the action fields from ``response`` and store them.

        Args:
            response: Parsed planner output. ``'actions'`` is optional;
                ``'reasoning'`` is required.

        Returns:
            Deep copy of ``response`` extended with ``pre_action``,
            ``action``, ``pre_decision_making_reasoning``,
            ``decision_making_reasoning`` and ``skill_steps``.

        Raises:
            RuntimeError: If no memory object is available on the instance.
            KeyError: If ``response`` has no ``'reasoning'`` entry.
        """
        memory = getattr(self, "memory", None)
        if memory is None:
            # This module binds no global ``memory``; fail with an actionable
            # message instead of a NameError.
            raise RuntimeError(
                f"{type(self).__name__} needs a memory object; "
                "pass memory=... to the constructor."
            )

        logger.write("Stardew Action Planning Postprocess")

        processed_response = deepcopy(response)

        skill_steps = self._select_skill_steps(response.get('actions'))

        # ``pre_action`` always uses the bracketed list form, while ``action``
        # collapses to the bare skill when only one skill is executed.
        bracketed = "[" + ",".join(skill_steps) + "]"
        pre_action = bracketed
        if config.number_of_execute_skills > 1:
            actions = bracketed
        else:
            actions = str(skill_steps[0])

        decision_making_reasoning = response['reasoning']

        processed_response.update({
            "pre_action": pre_action,
            "action": actions,
            "pre_decision_making_reasoning": decision_making_reasoning,
            "decision_making_reasoning": decision_making_reasoning,
            "skill_steps": skill_steps,
        })
        memory.update_info_history(processed_response)

        return processed_response