from typing import Dict, Any
from copy import deepcopy
import numpy as np

from harl import constants
from harl.common.llm_logger import Logger
from harl.common.memory import LocalMemory, GlobalMemory
from harl.common.base.base_provider import BaseProvider
from harl.configs.config import Config

# Module-level instances created once at import time.
# `logger` receives the progress messages written by the providers below.
logger = Logger()
# `config` supplies run settings read by the providers (e.g. `event_count`, `env_args`).
config = Config()
# Module-level GlobalMemory instance.
global_memory = GlobalMemory()

class SkillRefinePreprocessProvider(BaseProvider):
    """Assemble the ``image_introduction`` prompt input for skill refinement.

    Depending on ``use_video`` the provider either describes the single most
    recent screenshot or the last ``config.event_count`` screenshots together
    with their stored decision-making reasoning.
    """

    def __init__(self, *args,
                 gm: Any,
                 use_screenshot_augmented: bool = False,
                 use_video: bool = False,
                 **kwargs):
        """Store the provider options.

        Args:
            gm: Stored on the instance as ``self.gm``; not read by this class.
            use_screenshot_augmented: When true, the augmented screenshot
                bucket is used instead of the plain screenshot bucket
                (single-screenshot mode only).
            use_video: When true, build one entry per recent event instead of
                a single-screenshot entry.
        """
        super().__init__(*args, **kwargs)
        self.gm = gm
        self.use_screenshot_augmented = use_screenshot_augmented
        self.use_video = use_video

    def __call__(self):
        """Build the parameters, merge them into the working area, return them.

        Returns:
            dict with ``image_introduction`` and, in video mode, ``event_count``.
        """
        # The original body referenced a bare name `memory` that is never
        # defined in this module (NameError on every call). The only memory
        # object the module creates is `global_memory`.
        # NOTE(review): confirm GlobalMemory exposes get_recent_history / working_area.
        memory = global_memory

        if not self.use_video:
            # Read only the bucket that is actually used, so an empty unused
            # bucket cannot make the `[-1]` lookup fail.
            if self.use_screenshot_augmented:
                bucket = constants.AUGMENTED_IMAGES_MEM_BUCKET
            else:
                bucket = constants.IMAGES_MEM_BUCKET
            screenshot_path = memory.get_recent_history(bucket)[-1]

            processed_params = {
                "image_introduction": [
                    {
                        "introduction": "This screenshot is the current step of the game.",
                        "path": screenshot_path,
                        "assistant": ""
                    }
                ]
            }
        else:
            images = memory.get_recent_history(constants.IMAGES_MEM_BUCKET, config.event_count)
            reasonings = memory.get_recent_history('decision_making_reasoning', config.event_count)

            # Only five ordinal words exist, and both histories are indexed up
            # to config.event_count - 1; a larger event_count or a shorter
            # history raises IndexError.
            ordinals = ['first', 'second', 'third', 'fourth', 'fifth']
            image_introduction = [
                {
                    "path": images[event_i],
                    "assistant": "",
                    "introduction": 'This is the {} screenshot of recent events. The description of this image: {}'.format(
                        ordinals[event_i], reasonings[event_i])
                } for event_i in range(config.event_count)
            ]

            processed_params = {
                "image_introduction": image_introduction,
                "event_count": config.event_count
            }

        memory.working_area.update(processed_params)

        return processed_params

class SkillGenerationPostprocessProvider(BaseProvider):
    """Post-process a skill-generation response and record it in memory."""

    def __init__(self,
                 *args,
                 use_subtask: bool = False,
                 **kwargs):
        """Store whether subtask-related keys are kept in the processed response.

        Args:
            use_subtask: When false, every key containing ``"subtask"`` is
                removed from the processed response before it is stored.
        """
        super().__init__(*args, **kwargs)
        self.use_subtask = use_subtask

    def __call__(self, response: Dict):
        """Copy ``response``, derive ``subtask_description``, store and return it.

        Args:
            response: Response dict; must contain the key ``"subtask"``.

        Returns:
            The processed deep copy of ``response``.

        Raises:
            KeyError: If ``response`` has no ``"subtask"`` entry.
        """
        processed_response = deepcopy(response)

        processed_response["subtask_description"] = processed_response["subtask"]

        if not self.use_subtask:
            # Collect the keys first: a dict cannot be resized while iterating it.
            subtask_keys = [key for key in processed_response if "subtask" in key]
            for key in subtask_keys:
                del processed_response[key]

        # The original called `memory.update_info_history`, but no name
        # `memory` exists in this module (NameError). Use the module-level
        # memory instance instead.
        # NOTE(review): confirm GlobalMemory exposes update_info_history.
        global_memory.update_info_history(processed_response)

        return processed_response


class SMACv2SkillRefinePreprocessProvider(BaseProvider):
    """Collect the prompt inputs for refining the previously executed skill."""

    def __init__(self, *args,
                 memory: Any,
                 skill_registry: Any,
                 **kwargs):
        """Keep references to the per-environment memories and the skill registry.

        Args:
            memory: Container indexed by environment index; each entry is
                queried for history, unit attributes and its working area.
            skill_registry: Object used to parse the previous action
                expression and to look up the skill's code.
        """
        super().__init__(*args, **kwargs)
        self.memory = memory
        self.skill_registry = skill_registry

    def __call__(self, env_idx: int):
        """Build the refine-prompt parameters for environment ``env_idx``.

        The parameters are also merged into that environment's working area.

        Returns:
            dict of prompt parameters, including ``action_code`` for the
            previously executed action.
        """
        logger.write(f"SMACv2 Skill Refine Preprocess for Env {env_idx}")

        def latest(bucket):
            # Most recent entry stored under `bucket` for this environment.
            return self.memory[env_idx].get_recent_history(bucket, k=1)[0]

        task_description = latest("task_description")
        unit_type = self.memory[env_idx].unit_type
        unit_id = self.memory[env_idx].agent_id
        scenario_name = config.env_args["map_name"]
        win_lose = latest("win_lose")
        # A falsy error entry is reported as an empty string.
        executing_action_error = latest("exec_error") or ""

        # Minimap info tracking
        ego_minimap = latest("ego_minimap")
        info_summary = latest("summarization")

        # Resolve the source of the last executed skill; the parsed
        # parameters of the expression are not needed here.
        pre_action_name, _ = self.skill_registry.convert_expression_to_skill(latest("pre_action"))
        action_code, action_code_info = self.skill_registry.get_skill_code(pre_action_name)
        if action_code is None:
            # No code available: fall back to the accompanying info value.
            action_code = action_code_info

        processed_params = {
            "task_description": task_description,
            "ego_minimap": ego_minimap,
            "win_lose": win_lose,
            "unit_type": unit_type,
            "unit_id": unit_id,
            "scenario_name": scenario_name,
            "info_summary": info_summary,
            "executing_action_error": executing_action_error,
            "action_code": action_code,
        }

        self.memory[env_idx].working_area.update(processed_params)

        return processed_params
    

class RDR2TaskInferencePreprocessProvider(BaseProvider):
    """Prepare the prompt inputs for RDR2 task inference."""

    def __init__(self, *args,
                 gm: Any,
                 **kwargs):
        """Store ``gm`` on the instance; it is not read by this class."""
        super().__init__(*args, **kwargs)
        self.gm = gm

    def __call__(self):
        """Build the parameters, merge them into the working area, return them.

        The summary-related entries (``image_introduction``,
        ``previous_summarization``, ``event_count``) are only added once
        ``max_recent_steps`` reasoning entries are available.

        Returns:
            dict of prompt parameters.
        """
        logger.write('RDR2 Task Inference Preprocess')

        # The original body used an undefined name `memory` (NameError on
        # every call); the module only defines `global_memory`.
        # NOTE(review): confirm GlobalMemory exposes the members used below.
        memory = global_memory

        task_description = memory.get_recent_history("task_description", k=1)[0]
        screenshot_path = memory.get_recent_history(constants.IMAGES_MEM_BUCKET, k=1)[0]

        processed_params = {
            "task_description": task_description,
            constants.IMAGES_MEM_BUCKET: screenshot_path
        }

        # Information summary preparation
        recent_reasonings = memory.get_recent_history("decision_making_reasoning",
                                                      memory.max_recent_steps)
        if len(recent_reasonings) == memory.max_recent_steps:
            logger.write('> Information summary call...')

            images = memory.get_recent_history(constants.IMAGES_MEM_BUCKET, config.event_count)
            reasonings = memory.get_recent_history('decision_making_reasoning', config.event_count)

            # Only five ordinal words exist, and both histories are indexed up
            # to config.event_count - 1; a larger event_count or a shorter
            # history raises IndexError.
            ordinals = ['first', 'second', 'third', 'fourth', 'fifth']
            image_introduction = [
                {
                    "path": images[event_i], "assistant": "",
                    "introduction": 'This is the {} screenshot of recent events. The description of this image: {}'.format(
                        ordinals[event_i], reasonings[event_i])
                } for event_i in range(config.event_count)
            ]

            processed_params.update({
                "image_introduction": image_introduction,
                "previous_summarization": memory.get_summarization(),
                # Passed on as a string here.
                "event_count": str(config.event_count)
            })

        memory.working_area.update(processed_params)

        return processed_params

class SMACv2SkillRefinePostprocessProvider(BaseProvider):
    """Record a skill-refine response in the memory of one environment."""

    def __init__(self,
                 *args,
                 memory: Any,
                 **kwargs):
        """Keep a reference to the per-environment memory container."""
        super().__init__(*args, **kwargs)
        self.memory = memory

    def __call__(self, response: Dict, env_idx: int):
        """Copy ``response``, add ``all_generated_actions``, store and return it.

        ``all_generated_actions`` is the value found under
        ``constants.SKILL_GENERATION_MODULE`` or an empty list when that key
        is absent.
        """
        logger.write(f"SMACv2 Skill Refine Postprocess for Env {env_idx}")

        result = deepcopy(response)
        result["all_generated_actions"] = result.get(constants.SKILL_GENERATION_MODULE, [])

        self.memory[env_idx].update_info_history(result)

        return result
    

class RDR2TaskInferencePostprocessProvider(BaseProvider):
    """Post-process an RDR2 task-inference response and record it in memory."""

    def __init__(self,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, response: Dict):
        """Copy ``response``, add ``summarization``, store and return the copy.

        Args:
            response: Response dict. If it lacks ``"info_summary"`` the key is
                added to this dict in place with an empty string.

        Returns:
            Deep copy of ``response`` (taken before the in-place default is
            added) extended with ``summarization``.
        """
        logger.write('RDR2 Task Inference Postprocess')

        processed_response = deepcopy(response)

        # setdefault keeps the existing contract: a missing "info_summary" is
        # written back into the caller's dict, not into the copy.
        info_summary = response.setdefault("info_summary", "")

        processed_response.update({
            "summarization": info_summary
        })

        # The original called `memory.update_info_history`, but no name
        # `memory` exists in this module (NameError). Use the module-level
        # memory instance instead.
        # NOTE(review): confirm GlobalMemory exposes update_info_history.
        global_memory.update_info_history(processed_response)

        return processed_response


class StardewTaskInferencePreprocessProvider(BaseProvider):
    """Prepare the prompt inputs for Stardew task inference."""

    def __init__(self, *args,
                 gm: Any,
                 **kwargs):
        """Store ``gm`` on the instance; it is not read by this class."""
        super().__init__(*args, **kwargs)
        self.gm = gm

    def __call__(self):
        """Build the parameters, merge them into the working area, return them.

        Returns:
            dict of prompt parameters assembled from the most recent memory
            entries.
        """
        logger.write('Stardew Task Inference Preprocess')

        # The original body used an undefined name `memory` (NameError on
        # every call); the module only defines `global_memory`.
        # NOTE(review): confirm GlobalMemory exposes get_recent_history / working_area.
        memory = global_memory

        prompts = [
            "This screenshot is the current step of the game. The blue band represents the left side and the yellow band represents the right side."
        ]

        task_description = memory.get_recent_history("task_description", k=1)[0]
        previous_summarization = memory.get_recent_history("summarization", 1)[0]
        subtask_description = memory.get_recent_history("subtask_description", 1)[0]
        subtask_reasoning = memory.get_recent_history("subtask_reasoning", 1)[0]
        toolbar_information = memory.get_recent_history("toolbar_information", 1)[0]
        # NOTE(review): the next three values are kept as the sequences
        # returned by get_recent_history (not unwrapped with [0]), exactly as
        # before — confirm the prompt assembly expects that for "path" and
        # the reasoning fields.
        images = memory.get_recent_history(constants.AUGMENTED_IMAGES_MEM_BUCKET, 1)
        decision_making_reasoning = memory.get_recent_history('decision_making_reasoning', 1)
        self_reflection_reasoning = memory.get_recent_history('self_reflection_reasoning', 1)

        image_introduction = [
            {
                "introduction": prompts[-1],
                "path": images,
                "assistant": ""
            }
        ]

        processed_params = {
            "image_introduction": image_introduction,
            "previous_summarization": previous_summarization,
            "task_description": task_description,
            "subtask_description": subtask_description,
            "subtask_reasoning": subtask_reasoning,
            "previous_reasoning": decision_making_reasoning,
            "self_reflection_reasoning": self_reflection_reasoning,
            "toolbar_information": toolbar_information
        }

        memory.working_area.update(processed_params)

        return processed_params


class StardewTaskInferencePostprocessProvider(BaseProvider):
    """Post-process a Stardew task-inference response and record it in memory."""

    def __init__(self,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, response: Dict):
        """Copy ``response``, add the derived keys, store and return the copy.

        Args:
            response: Response dict; must contain ``history_summary``,
                ``subtask`` and ``subtask_reasoning``.

        Returns:
            Deep copy of ``response`` extended with ``summarization``,
            ``subtask_description`` and ``subtask_reasoning``.

        Raises:
            KeyError: If one of the required keys is missing.
        """
        logger.write('Stardew Task Inference Postprocess')

        processed_response = deepcopy(response)

        processed_response.update({
            'summarization': response['history_summary'],
            'subtask_description': response['subtask'],
            'subtask_reasoning': response['subtask_reasoning']
        })

        # The original called `memory.update_info_history`, but no name
        # `memory` exists in this module (NameError). Use the module-level
        # memory instance instead.
        # NOTE(review): confirm GlobalMemory exposes update_info_history.
        global_memory.update_info_history(processed_response)

        return processed_response