from typing import Dict, Any
from copy import deepcopy
import numpy as np
from harl import constants
from harl.common.llm_logger import Logger
from harl.common.memory import LocalMemory, GlobalMemory
from harl.common.base.base_provider import BaseProvider
from harl.configs.config import Config
from harl.utils.envs_tools import get_relative_direction

# Module-level instances shared by every provider in this file.
# NOTE(review): several providers below (the generic, RDR2 and Stardew ones)
# reference a bare `memory` name that is not defined in this module as shown;
# presumably a module-level `memory = LocalMemory()` was intended — confirm
# before relying on those providers.
logger = Logger()
config = Config()
# Cross-environment store; the SMACv2 providers read (`get_ally_task`) and
# publish (`update_ally_task`) each unit's current task through it.
global_memory = GlobalMemory()

class TaskInferencePreprocessProvider(BaseProvider):
    """Build the image-introduction parameters for the generic task-inference prompt.

    Args:
        gm: stored on the instance; not read by ``__call__``.
        use_screenshot_augmented: when not in video mode, use the latest
            augmented screenshot instead of the raw one.
        use_video: describe the last ``config.event_count`` screenshots together
            with their decision-making reasonings instead of a single screenshot.

    NOTE(review): ``__call__`` reads a bare ``memory`` name that is not defined
    in this module as shown (presumably a module-level ``LocalMemory``) —
    confirm it is provided before using this provider.
    """

    def __init__(self, *args,
                 gm: Any,
                 use_screenshot_augmented = False,
                 use_video = False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.gm = gm
        self.use_screenshot_augmented = use_screenshot_augmented
        self.use_video = use_video

    @staticmethod
    def _ordinal(index: int) -> str:
        """Return an English ordinal for a 0-based index.

        Indices 0-4 give "first" ... "fifth"; larger indices fall back to
        numeric ordinals ("6th", "11th", "22nd") so any event count works.
        """
        words = ('first', 'second', 'third', 'fourth', 'fifth')
        if index < len(words):
            return words[index]
        number = index + 1
        if 10 <= number % 100 <= 20:
            suffix = 'th'
        else:
            suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')
        return f'{number}{suffix}'

    def __call__(self):
        """Collect image introductions from memory and merge them into the working area.

        Returns:
            dict with ``"image_introduction"`` (and ``"event_count"`` in video
            mode); the same dict is merged into ``memory.working_area``.
        """
        if not self.use_video:
            # Only touch the bucket that is actually used, so an empty unused
            # bucket cannot raise IndexError.
            bucket = (constants.AUGMENTED_IMAGES_MEM_BUCKET
                      if self.use_screenshot_augmented
                      else constants.IMAGES_MEM_BUCKET)
            screenshot_path = memory.get_recent_history(bucket)[-1]

            image_introduction = [
                {
                    "introduction": "This screenshot is the current step of the game.",
                    "path": screenshot_path,
                    "assistant": ""
                }
            ]

            processed_params = {
                "image_introduction": image_introduction
            }

        else:
            event_count = config.event_count
            images = memory.get_recent_history(constants.IMAGES_MEM_BUCKET, event_count)
            reasonings = memory.get_recent_history('decision_making_reasoning', event_count)

            image_introduction = [
                {
                    "path": images[event_i],
                    "assistant": "",
                    "introduction": 'This is the {} screenshot of recent events. The description of this image: {}'.format(
                        self._ordinal(event_i), reasonings[event_i])
                } for event_i in range(event_count)
            ]

            processed_params = {
                "image_introduction": image_introduction,
                "event_count": event_count
            }

        memory.working_area.update(processed_params)

        return processed_params

class TaskInferencePostprocessProvider(BaseProvider):
    """Normalize the generic task-inference response and record it in memory.

    Args:
        use_subtask: when True, copy ``response["subtask"]`` to
            ``"subtask_description"``; when False, strip every key whose name
            contains ``"subtask"``.

    NOTE(review): ``__call__`` writes to a bare ``memory`` name that is not
    defined in this module as shown — confirm it is provided.
    """

    def __init__(self,
                 *args,
                 use_subtask = False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.use_subtask = use_subtask

    def __call__(self, response: Dict):
        """Post-process ``response`` without mutating it.

        Args:
            response: parsed model response; must contain ``"subtask"`` when
                ``use_subtask`` is enabled.

        Returns:
            A deep copy of ``response`` with the subtask fields adjusted; it is
            also appended to the memory info history.

        Raises:
            KeyError: if ``use_subtask`` is enabled and ``"subtask"`` is missing.
        """
        processed_response = deepcopy(response)

        if self.use_subtask:
            processed_response["subtask_description"] = processed_response["subtask"]
        else:
            # The subtask is discarded in this mode, so don't require the model
            # to have produced it; just drop every subtask-related key.
            for key in [k for k in processed_response if "subtask" in k]:
                del processed_response[key]

        memory.update_info_history(processed_response)

        return processed_response


class SMACv2TaskInferencePreprocessProvider(BaseProvider):
    """Assemble the task-inference prompt parameters for one SMACv2 agent.

    Args:
        memory: per-environment memory objects indexed by ``env_idx``. Each
            entry is used via ``get_recent_history``, ``get_frame_paths_obs``,
            ``get_last_action_return``, ``get_summarization``, ``working_area``,
            ``unit_type`` and ``agent_id``.
        skill_registry: used to parse the previous action expression
            (``convert_expression_to_skill``) and fetch its code
            (``get_skill_code``).
    """

    # Skill expression reported as the previous action before the agent has acted.
    _DEFAULT_PREVIOUS_ACTION = "race_melee_ranged_medivac_navi_A_star_score_type_default_center(obs='current')"

    def __init__(self, *args,
                 memory: Any,
                 skill_registry: Any,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.memory = memory
        self.skill_registry = skill_registry

    def __call__(self, env_idx: int):
        """Collect prompt inputs for environment ``env_idx``.

        Returns:
            dict of prompt parameters; it is also merged into
            ``self.memory[env_idx].working_area``.
        """
        logger.write("SMACv2 Task Inference Preprocess for Env {}".format(env_idx))

        env_memory = self.memory[env_idx]

        task_description = env_memory.get_recent_history("task_description", k=1)[0]
        game_situation = env_memory.get_recent_history("game_situation", k=1)[0]
        pre_action = env_memory.get_recent_history("pre_action", k=1)[0]
        pre_decision_making_reasoning = env_memory.get_recent_history("pre_decision_making_reasoning", k=1)[0]
        unit_type = env_memory.unit_type
        unit_id = env_memory.agent_id
        scenario_name = config.env_args["map_name"]

        start_frame_id = env_memory.get_recent_history("start_frame_id", k=1)[0]
        end_frame_id = env_memory.get_recent_history("end_frame_id", k=1)[0]
        _, obs_texts = env_memory.get_frame_paths_obs(start_frame_id, end_frame_id)
        # Minimap info tracking
        ego_minimap_text = env_memory.get_recent_history("ego_minimap", k=1)[0]
        historical_lesson = '\n'.join(env_memory.get_recent_history("historical_lesson", k=config.max_summarization_num))
        last_episode_reasoning = env_memory.get_recent_history("last_episode_reasoning", k=1)[0]

        last_action_return = env_memory.get_last_action_return(start_frame_id, end_frame_id)

        if pre_action:
            pre_action_name, _ = self.skill_registry.convert_expression_to_skill(pre_action)
            # Only the skill name of the previous action is shown to the model.
            previous_action = pre_action_name
        else:
            previous_action = self._DEFAULT_PREVIOUS_ACTION
            pre_action_name, _ = self.skill_registry.convert_expression_to_skill(previous_action)
        action_code, action_code_info = self.skill_registry.get_skill_code(pre_action_name)
        # Fall back to the registry's info message when no code is available.
        action_code = action_code if action_code is not None else action_code_info

        ally_task = global_memory.get_ally_task(env_idx, unit_id)

        skill_library = env_memory.get_recent_history("skill_library", k=1)[0]
        processed_params = {
            "task_description": task_description,
            "previous_reasoning": pre_decision_making_reasoning,
            "previous_action": previous_action,
            "action_code": action_code,
            "ego_minimap": ego_minimap_text,
            "unit_type": unit_type,
            "unit_id": unit_id,
            "scenario_name": scenario_name,
            "ally_task": ally_task,
            "skill_library": skill_library,
            "game_situation": game_situation,
        }
        logger.write(f'> Information summary call...')

        images = env_memory.get_recent_history(constants.IMAGES_MEM_BUCKET, config.event_count)
        if len(obs_texts) > config.event_count:
            # Evenly subsample the frames of the last action down to event_count.
            indices = np.linspace(0, len(obs_texts) - 1, config.event_count, dtype=int)
            obs_texts = [obs_texts[i] for i in indices]
        obs_texts = [f"Frame {i + 1}: {text}" for i, text in enumerate(obs_texts)]

        # zip stops at the shorter list: an action spanning fewer frames than
        # there are recent images no longer raises IndexError.
        image_introduction = [
            {
                "path": image_path if config.use_image else "",
                "assistant": "",
                "introduction": obs_text,
            } for image_path, obs_text in zip(images, obs_texts)
        ]

        previous_summarization = env_memory.get_summarization()
        event_count = str(len(image_introduction))

        pre_self_reflection_reasoning = env_memory.get_recent_history("pre_self_reflection_reasoning", k=1)[0]
        previous_self_reflection_reasoning = ""
        if pre_self_reflection_reasoning:
            previous_self_reflection_reasoning = env_memory.get_recent_history("self_reflection_reasoning", k=1)[0]

        processed_params.update({
            "image_introduction": image_introduction,
            "previous_summarization": previous_summarization,
            "event_count": event_count,
            "historical_lesson": historical_lesson,
            "last_episode_reasoning": last_episode_reasoning,
            "cumulative_reward": last_action_return,
            "previous_self_reflection_reasoning": previous_self_reflection_reasoning,
        })

        env_memory.working_area.update(processed_params)

        return processed_params

class RDR2TaskInferencePreprocessProvider(BaseProvider):
    """Build the RDR2 task-inference prompt parameters.

    Args:
        gm: stored on the instance; not read by ``__call__``.

    NOTE(review): ``__call__`` uses a bare ``memory`` name that is not defined
    in this module as shown — confirm it is provided.
    """

    def __init__(self, *args,
                 gm: Any,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.gm = gm

    @staticmethod
    def _ordinal(index: int) -> str:
        """Return an English ordinal for a 0-based index.

        Indices 0-4 give "first" ... "fifth"; larger indices fall back to
        numeric ordinals ("6th", "11th", "22nd") so any event count works.
        """
        words = ('first', 'second', 'third', 'fourth', 'fifth')
        if index < len(words):
            return words[index]
        number = index + 1
        if 10 <= number % 100 <= 20:
            suffix = 'th'
        else:
            suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')
        return f'{number}{suffix}'

    def __call__(self):
        """Collect the latest task and screenshot from memory.

        Once ``memory.max_recent_steps`` decision-making reasonings are stored,
        also adds the last ``config.event_count`` screenshots (with their
        reasonings) and the previous summarization for the information summary.

        Returns:
            dict of prompt parameters, also merged into ``memory.working_area``.
        """
        logger.write(f'RDR2 Task Inference Preprocess')

        task_description = memory.get_recent_history("task_description", k=1)[0]
        screenshot_path = memory.get_recent_history(constants.IMAGES_MEM_BUCKET, k=1)[0]

        processed_params = {
            "task_description": task_description,
            constants.IMAGES_MEM_BUCKET: screenshot_path
        }

        # Information summary preparation: only once a full window of steps exists.
        if len(memory.get_recent_history("decision_making_reasoning",
                                         memory.max_recent_steps)) == memory.max_recent_steps:
            logger.write(f'> Information summary call...')

            images = memory.get_recent_history(constants.IMAGES_MEM_BUCKET, config.event_count)
            reasonings = memory.get_recent_history('decision_making_reasoning', config.event_count)

            image_introduction = [
                {
                    "path": images[event_i], "assistant": "",
                    "introduction": 'This is the {} screenshot of recent events. The description of this image: {}'.format(
                        self._ordinal(event_i), reasonings[event_i])
                } for event_i in range(config.event_count)
            ]

            previous_summarization = memory.get_summarization()
            event_count = str(config.event_count)

            processed_params.update({
                "image_introduction": image_introduction,
                "previous_summarization": previous_summarization,
                "event_count": event_count
            })

        memory.working_area.update(processed_params)

        return processed_params

class SMACv2TaskInferencePostprocessProvider(BaseProvider):
    """Normalize a SMACv2 task-inference response and record it.

    Args:
        memory: per-environment memory objects indexed by ``env_idx``.
    """

    def __init__(self,
                 *args,
                 memory: Any,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.memory = memory

    def __call__(self, response: Dict, env_idx: int):
        """Post-process ``response`` for environment ``env_idx`` without mutating it.

        If the task guidance is missing, empty, not a string (e.g. a JSON null),
        or contains "null", the most recent ``task_description`` is reused.
        ``skill_guidance`` is normalized to a bool (missing -> False).

        Returns:
            A deep copy of ``response`` with ``summarization``,
            ``task_description`` and ``skill_guidance`` filled in. It is also
            appended to the env's info history, and the guidance is published to
            ``global_memory`` as this agent's task.
        """
        logger.write("SMACv2 Task Inference Postprocess for Env {}".format(env_idx))

        processed_response = deepcopy(response)

        guidance = processed_response.get(constants.TASK_GUIDANCE)
        if not isinstance(guidance, str) or guidance == "" or "null" in guidance.lower():
            guidance = self.memory[env_idx].get_recent_history("task_description", k=1)[0]
            processed_response[constants.TASK_GUIDANCE] = guidance

        # Read with a default rather than inserting one into the caller's dict.
        info_summary = response.get("info_summary", "")

        # The model may answer with a bool or a string; normalize to a real bool.
        processed_response["skill_guidance"] = str(processed_response.get("skill_guidance", False)).lower() == "true"

        processed_response.update({
            "summarization": info_summary,
            "task_description": guidance,
        })

        self.memory[env_idx].update_info_history(processed_response)

        global_memory.update_ally_task(env_idx, self.memory[env_idx].agent_id, guidance)

        return processed_response
    

class RDR2TaskInferencePostprocessProvider(BaseProvider):
    """Copy the RDR2 ``info_summary`` into ``summarization`` and record the response.

    NOTE(review): ``__call__`` writes to a bare ``memory`` name that is not
    defined in this module as shown — confirm it is provided.
    """

    def __init__(self,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, response: Dict):
        """Post-process ``response`` without mutating it.

        Returns:
            A deep copy of ``response`` with ``"summarization"`` set to its
            ``"info_summary"`` (empty string when absent); it is also appended
            to the memory info history.
        """
        logger.write(f'RDR2 Task Inference Postprocess')

        processed_response = deepcopy(response)

        # Read with a default rather than inserting one into the caller's dict.
        info_summary = response.get("info_summary", "")

        processed_response.update({
            "summarization": info_summary
        })

        memory.update_info_history(processed_response)

        return processed_response


class StardewTaskInferencePreprocessProvider(BaseProvider):
    """Gather the Stardew task-inference prompt inputs from memory.

    Args:
        gm: stored on the instance; not read by ``__call__``.

    NOTE(review): ``__call__`` reads a bare ``memory`` name that is not defined
    in this module as shown — confirm it is provided.
    """

    def __init__(self, *args, gm: Any, **kwargs):
        super().__init__(*args, **kwargs)
        self.gm = gm

    def __call__(self):
        """Return the prompt parameters and merge them into ``memory.working_area``.

        NOTE(review): ``path``, ``previous_reasoning`` and
        ``self_reflection_reasoning`` receive the sequences returned by
        ``get_recent_history`` (not their single element), unlike the other
        fields — confirm the prompt consumer expects that.
        """
        logger.write('Stardew Task Inference Preprocess')

        screenshot_note = "This screenshot is the current step of the game. The blue band represents the left side and the yellow band represents the right side."

        recent = memory.get_recent_history
        task = recent("task_description", k=1)[0]
        last_summary = recent("summarization", 1)[0]
        current_subtask = recent("subtask_description", 1)[0]
        current_subtask_why = recent("subtask_reasoning", 1)[0]
        toolbar = recent("toolbar_information", 1)[0]
        augmented_shots = recent(constants.AUGMENTED_IMAGES_MEM_BUCKET, 1)
        last_decisions = recent('decision_making_reasoning', 1)
        last_reflections = recent('self_reflection_reasoning', 1)

        params = dict(
            image_introduction=[
                dict(introduction=screenshot_note, path=augmented_shots, assistant=""),
            ],
            previous_summarization=last_summary,
            task_description=task,
            subtask_description=current_subtask,
            subtask_reasoning=current_subtask_why,
            previous_reasoning=last_decisions,
            self_reflection_reasoning=last_reflections,
            toolbar_information=toolbar,
        )

        memory.working_area.update(params)
        return params


class StardewTaskInferencePostprocessProvider(BaseProvider):
    """Map the Stardew response fields onto the names stored in memory.

    NOTE(review): ``__call__`` writes to a bare ``memory`` name that is not
    defined in this module as shown — confirm it is provided.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, response: Dict):
        """Return a deep copy of ``response`` with memory-facing aliases added.

        Sets ``summarization`` from ``history_summary``, ``subtask_description``
        from ``subtask`` and re-assigns ``subtask_reasoning``. The result is also
        appended to the memory info history.

        Raises:
            KeyError: if ``history_summary``, ``subtask`` or ``subtask_reasoning``
                is missing from ``response``.
        """
        logger.write('Stardew Task Inference Postprocess')

        result = deepcopy(response)
        result['summarization'] = response['history_summary']
        result['subtask_description'] = response['subtask']
        result['subtask_reasoning'] = response['subtask_reasoning']

        memory.update_info_history(result)
        return result