from typing import Dict, Any
from copy import deepcopy
import numpy as np

from harl import constants
from harl.common.llm_logger import Logger
from harl.common.memory import LocalMemory, GlobalMemory
from harl.common.base.base_provider import BaseProvider
from harl.configs.config import Config

# Module-level collaborators shared by the providers defined below.
logger = Logger()
config = Config()
# Queried below via get_ally_task(env_idx, unit_id) by the SMACv2 preprocess provider.
global_memory = GlobalMemory()
# NOTE(review): several non-SMACv2 providers below read a bare `memory` name
# that is not defined at module level in this file as shown — confirm where it
# is expected to come from.

class SkillGenerationPreprocessProvider(BaseProvider):
    """Build the ``image_introduction`` parameters for a skill-generation prompt.

    In screenshot mode (``use_video`` false) a single entry is produced from the
    last item returned for ``constants.IMAGES_MEM_BUCKET``, or for
    ``constants.AUGMENTED_IMAGES_MEM_BUCKET`` when ``use_screenshot_augmented``
    is set. In video mode one entry per recent event is produced
    (``config.event_count`` of them), each pairing an image with the matching
    ``decision_making_reasoning`` text.

    Args:
        gm: Stored on the instance as ``self.gm``; not read by this class.
        use_screenshot_augmented: Use the augmented screenshot bucket in
            screenshot mode.
        use_video: Build one entry per recent event instead of a single
            screenshot entry.
        memory: Optional memory object to read from and write to. When omitted
            the provider falls back to a module-level ``memory`` name, which
            raises ``NameError`` if the module does not define it.
    """

    def __init__(self, *args,
                 gm: Any,
                 use_screenshot_augmented=False,
                 use_video=False,
                 memory: Any = None,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.gm = gm
        self.use_screenshot_augmented = use_screenshot_augmented
        self.use_video = use_video
        self.memory = memory

    def _get_memory(self):
        """Return the injected memory, else the legacy module-level ``memory``."""
        if self.memory is not None:
            return self.memory
        # Legacy lookup kept for backward compatibility; raises NameError when
        # the module does not define ``memory``.
        return memory

    @staticmethod
    def _ordinal(index):
        """Return the label for a zero-based event index.

        The first five events keep the original word labels; later events get
        a numeric ordinal ('6th', '21st', ...) instead of raising IndexError.
        """
        words = ('first', 'second', 'third', 'fourth', 'fifth')
        if index < len(words):
            return words[index]
        number = index + 1
        if 10 <= number % 100 <= 20:
            suffix = 'th'
        else:
            suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')
        return f'{number}{suffix}'

    def _screenshot_params(self, mem):
        """Build the single-screenshot parameters, reading only the bucket in use."""
        if self.use_screenshot_augmented:
            bucket = constants.AUGMENTED_IMAGES_MEM_BUCKET
        else:
            bucket = constants.IMAGES_MEM_BUCKET

        # Only the selected bucket is queried, so an empty unused bucket can no
        # longer break the call.
        screenshot_path = mem.get_recent_history(bucket)[-1]

        return {
            "image_introduction": [
                {
                    "introduction": "This screenshot is the current step of the game.",
                    "path": screenshot_path,
                    "assistant": ""
                }
            ]
        }

    def _video_params(self, mem):
        """Build one image entry per recent event, described by its reasoning text."""
        event_count = config.event_count
        images = mem.get_recent_history(constants.IMAGES_MEM_BUCKET, event_count)
        reasonings = mem.get_recent_history('decision_making_reasoning', event_count)

        image_introduction = [
            {
                "path": images[event_i],
                "assistant": "",
                "introduction": 'This is the {} screenshot of recent events. The description of this image: {}'.format(
                    self._ordinal(event_i), reasonings[event_i])
            } for event_i in range(event_count)
        ]

        return {
            "image_introduction": image_introduction,
            "event_count": event_count
        }

    def __call__(self):
        """Assemble the parameters, store them in the working area and return them.

        Returns:
            dict with ``image_introduction`` and, in video mode, ``event_count``.
        """
        mem = self._get_memory()

        if self.use_video:
            processed_params = self._video_params(mem)
        else:
            processed_params = self._screenshot_params(mem)

        mem.working_area.update(processed_params)

        return processed_params

class SkillGenerationPostprocessProvider(BaseProvider):
    """Post-process a skill-generation response and record it in memory.

    Args:
        use_subtask: When true, the response's ``"subtask"`` value is also
            exposed as ``"subtask_description"``. When false, every key whose
            name contains ``"subtask"`` is removed from the result.
        memory: Optional memory object whose ``update_info_history`` receives
            the result. When omitted the provider falls back to a module-level
            ``memory`` name, which raises ``NameError`` if the module does not
            define it.
    """

    def __init__(self,
                 *args,
                 use_subtask=False,
                 memory: Any = None,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.use_subtask = use_subtask
        self.memory = memory

    def _get_memory(self):
        """Return the injected memory, else the legacy module-level ``memory``."""
        if self.memory is not None:
            return self.memory
        # Legacy lookup kept for backward compatibility; raises NameError when
        # the module does not define ``memory``.
        return memory

    def __call__(self, response: Dict):
        """Return a processed deep copy of ``response`` after storing it in memory.

        Raises:
            KeyError: if ``use_subtask`` is set and ``response`` has no
                ``"subtask"`` entry.
        """
        processed_response = deepcopy(response)

        if self.use_subtask:
            processed_response["subtask_description"] = processed_response["subtask"]
        else:
            # The subtask value is only needed when it is kept, so a response
            # without a "subtask" key is accepted here. Keys are collected
            # first because the dict cannot be resized while iterating it.
            subtask_keys = [key for key in processed_response if "subtask" in key]
            for key in subtask_keys:
                processed_response.pop(key)

        self._get_memory().update_info_history(processed_response)

        return processed_response


class SMACv2SkillGenerationPreprocessProvider(BaseProvider):
    """Collect the prompt parameters for SMACv2 skill generation for one environment.

    Args:
        memory: Container indexed by ``env_idx``; each element provides
            ``get_recent_history``, ``get_share_frame_paths``,
            ``get_frame_paths_obs``, ``unit_type``, ``agent_id`` and
            ``working_area``.
        skill_registry: Provides ``convert_expression_to_skill`` and
            ``get_skill_code``.
    """

    # Skill expression used to look up the action code when no previous action
    # is recorded.
    _DEFAULT_SKILL_EXPRESSION = "race_melee_ranged_medivac_navi_A_star_score_type_default_center(obs='current')"

    def __init__(self, *args,
                 memory: Any,
                 skill_registry: Any,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.memory = memory
        self.skill_registry = skill_registry

    def __call__(self, env_idx: int):
        """Build the parameters, store them in the env's working area and return them.

        Args:
            env_idx: Index of the environment whose memory is read.

        Returns:
            dict of prompt parameters, including ``image_introduction``,
            ``action_code``, ``ally_task`` and ``game_situation``.
        """
        logger.write("SMACv2 Skill Generation Preprocess for Env {}".format(env_idx))

        mem = self.memory[env_idx]

        def latest(key):
            # First element of a k=1 history query for `key`.
            return mem.get_recent_history(key, k=1)[0]

        unit_id = mem.agent_id

        # The stored self-reflection is only forwarded when the
        # "pre_self_reflection_reasoning" entry is truthy.
        previous_self_reflection_reasoning = ""
        if latest("pre_self_reflection_reasoning"):
            previous_self_reflection_reasoning = latest("self_reflection_reasoning")

        historical_lesson = '\n'.join(
            mem.get_recent_history("historical_lesson", k=config.max_summarization_num))

        processed_params = {
            "task_description": latest("task_description"),
            "ego_minimap": latest("ego_minimap"),
            "win_lose": latest("win_lose"),
            "unit_type": mem.unit_type,
            "unit_id": unit_id,
            "scenario_name": config.env_args["map_name"],
            "info_summary": latest("summarization"),
            "historical_lesson": historical_lesson,
            "last_episode_reasoning": latest("last_episode_reasoning"),
            "previous_self_reflection_reasoning": previous_self_reflection_reasoning,
            "skill_library": latest("skill_library"),
        }

        start_frame_id = latest("start_frame_id")
        end_frame_id = latest("end_frame_id")

        screenshot_paths = mem.get_share_frame_paths(start_frame_id, end_frame_id)
        # assumes obs_texts is index-aligned with screenshot_paths — TODO confirm
        _, obs_texts = mem.get_frame_paths_obs(start_frame_id, end_frame_id)

        max_images = config.max_images_in_skill_generation
        num_frames = len(screenshot_paths)
        if num_frames > max_images:
            # Evenly spaced indices running from the first to the last frame.
            indices = np.linspace(0, num_frames - 1, max_images, dtype=int)
            screenshot_paths = [screenshot_paths[i] for i in indices]
            obs_texts = [obs_texts[i] for i in indices]
        obs_texts = [f"Frame {number}: {text}" for number, text in enumerate(obs_texts, start=1)]

        # A single entry carries every frame: the joined per-frame texts plus
        # the list of paths (or an empty string when images are disabled).
        image_introduction = [
            {
                "introduction": '\n'.join(obs_texts),
                "path": screenshot_paths if config.use_image else "",
                "assistant": "",
                "resolution": "low"
            }
        ]

        # Only the skill name of the previous action is needed to fetch its code.
        pre_action = latest("pre_action")
        pre_action_name, _ = self.skill_registry.convert_expression_to_skill(
            pre_action or self._DEFAULT_SKILL_EXPRESSION)

        action_code, action_code_info = self.skill_registry.get_skill_code(pre_action_name)
        if action_code is None:
            # Fall back to the accompanying info value when no code is returned.
            action_code = action_code_info

        processed_params.update({
            "image_introduction": image_introduction,
            "action_code": action_code,
            "ally_task": global_memory.get_ally_task(env_idx, unit_id),
            "game_situation": latest("game_situation"),
        })

        mem.working_area.update(processed_params)

        return processed_params
    

class RDR2TaskInferencePreprocessProvider(BaseProvider):
    """Build the prompt parameters for RDR2 task inference.

    Args:
        gm: Stored on the instance as ``self.gm``; not read by this class.
        memory: Optional memory object to read from and write to. When omitted
            the provider falls back to a module-level ``memory`` name, which
            raises ``NameError`` if the module does not define it.
    """

    def __init__(self, *args,
                 gm: Any,
                 memory: Any = None,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.gm = gm
        self.memory = memory

    def _get_memory(self):
        """Return the injected memory, else the legacy module-level ``memory``."""
        if self.memory is not None:
            return self.memory
        # Legacy lookup kept for backward compatibility; raises NameError when
        # the module does not define ``memory``.
        return memory

    @staticmethod
    def _ordinal(index):
        """Return the label for a zero-based event index.

        The first five events keep the original word labels; later events get
        a numeric ordinal ('6th', '21st', ...) instead of raising IndexError.
        """
        words = ('first', 'second', 'third', 'fourth', 'fifth')
        if index < len(words):
            return words[index]
        number = index + 1
        if 10 <= number % 100 <= 20:
            suffix = 'th'
        else:
            suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')
        return f'{number}{suffix}'

    def __call__(self):
        """Assemble the parameters, store them in the working area and return them.

        Returns:
            dict with ``task_description`` and the current screenshot path;
            when the summary condition below holds it also carries
            ``image_introduction``, ``previous_summarization`` and
            ``event_count`` (as a string).
        """
        logger.write('RDR2 Task Inference Preprocess')

        mem = self._get_memory()

        task_description = mem.get_recent_history("task_description", k=1)[0]
        screenshot_path = mem.get_recent_history(constants.IMAGES_MEM_BUCKET, k=1)[0]

        processed_params = {
            "task_description": task_description,
            constants.IMAGES_MEM_BUCKET: screenshot_path
        }

        # The information summary is only prepared once the reasoning history
        # holds `max_recent_steps` entries.
        recent_reasonings = mem.get_recent_history("decision_making_reasoning",
                                                   mem.max_recent_steps)
        if len(recent_reasonings) == mem.max_recent_steps:
            logger.write('> Information summary call...')

            event_count = config.event_count
            images = mem.get_recent_history(constants.IMAGES_MEM_BUCKET, event_count)
            reasonings = mem.get_recent_history('decision_making_reasoning', event_count)

            image_introduction = [
                {
                    "path": images[event_i], "assistant": "",
                    "introduction": 'This is the {} screenshot of recent events. The description of this image: {}'.format(
                        self._ordinal(event_i), reasonings[event_i])
                } for event_i in range(event_count)
            ]

            processed_params.update({
                "image_introduction": image_introduction,
                "previous_summarization": mem.get_summarization(),
                "event_count": str(event_count)
            })

        mem.working_area.update(processed_params)

        return processed_params

class SMACv2SkillGenerationPostprocessProvider(BaseProvider):
    """Record a SMACv2 skill-generation response in the memory of one environment.

    Args:
        memory: Container indexed by ``env_idx``; each element provides
            ``update_info_history``.
    """

    def __init__(self, *args, memory: Any, **kwargs):
        super().__init__(*args, **kwargs)
        self.memory = memory

    def __call__(self, response: Dict, env_idx: int):
        """Return a deep copy of ``response`` extended with ``all_generated_actions``.

        ``all_generated_actions`` mirrors the entry stored under
        ``constants.SKILL_GENERATION_MODULE`` and is an empty list when that
        entry is absent. The result is also passed to the environment's
        ``update_info_history``.
        """
        logger.write("SMACv2 Skill Generation Postprocess for Env {}".format(env_idx))

        result = deepcopy(response)
        result["all_generated_actions"] = result.get(constants.SKILL_GENERATION_MODULE, [])

        env_memory = self.memory[env_idx]
        env_memory.update_info_history(result)

        return result
    

class RDR2TaskInferencePostprocessProvider(BaseProvider):
    """Post-process an RDR2 task-inference response and record it in memory.

    Args:
        memory: Optional memory object whose ``update_info_history`` receives
            the result. When omitted the provider falls back to a module-level
            ``memory`` name, which raises ``NameError`` if the module does not
            define it.
    """

    def __init__(self,
                 *args,
                 memory: Any = None,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.memory = memory

    def _get_memory(self):
        """Return the injected memory, else the legacy module-level ``memory``."""
        if self.memory is not None:
            return self.memory
        # Legacy lookup kept for backward compatibility; raises NameError when
        # the module does not define ``memory``.
        return memory

    def __call__(self, response: Dict):
        """Return a deep copy of ``response`` extended with ``summarization``.

        ``summarization`` is the response's ``"info_summary"`` value, or an
        empty string when that key is absent. The caller's ``response`` dict
        is left unmodified.
        """
        logger.write('RDR2 Task Inference Postprocess')

        processed_response = deepcopy(response)

        # Read with a default instead of writing the default back into the
        # caller's dict.
        processed_response["summarization"] = response.get("info_summary", "")

        self._get_memory().update_info_history(processed_response)

        return processed_response


class StardewTaskInferencePreprocessProvider(BaseProvider):
    """Build the prompt parameters for Stardew task inference.

    Args:
        gm: Stored on the instance as ``self.gm``; not read by this class.
        memory: Optional memory object to read from and write to. When omitted
            the provider falls back to a module-level ``memory`` name, which
            raises ``NameError`` if the module does not define it.
    """

    # Introduction text attached to the current augmented screenshot.
    _SCREENSHOT_INTRODUCTION = "This screenshot is the current step of the game. The blue band represents the left side and the yellow band represents the right side."

    def __init__(self, *args,
                 gm: Any,
                 memory: Any = None,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.gm = gm
        self.memory = memory

    def _get_memory(self):
        """Return the injected memory, else the legacy module-level ``memory``."""
        if self.memory is not None:
            return self.memory
        # Legacy lookup kept for backward compatibility; raises NameError when
        # the module does not define ``memory``.
        return memory

    def __call__(self):
        """Assemble the parameters, store them in the working area and return them.

        Returns:
            dict with ``image_introduction``, ``previous_summarization``,
            ``task_description``, ``subtask_description``,
            ``subtask_reasoning``, ``previous_reasoning``,
            ``self_reflection_reasoning`` and ``toolbar_information``.
        """
        logger.write('Stardew Task Inference Preprocess')

        mem = self._get_memory()

        task_description = mem.get_recent_history("task_description", k=1)[0]
        previous_summarization = mem.get_recent_history("summarization", 1)[0]
        subtask_description = mem.get_recent_history("subtask_description", 1)[0]
        subtask_reasoning = mem.get_recent_history("subtask_reasoning", 1)[0]
        toolbar_information = mem.get_recent_history("toolbar_information", 1)[0]

        # These three are forwarded as returned by get_recent_history (not
        # unwrapped with [0]), matching how the values are consumed downstream
        # — NOTE(review): confirm the prompt expects list values here.
        images = mem.get_recent_history(constants.AUGMENTED_IMAGES_MEM_BUCKET, 1)
        decision_making_reasoning = mem.get_recent_history('decision_making_reasoning', 1)
        self_reflection_reasoning = mem.get_recent_history('self_reflection_reasoning', 1)

        image_introduction = [
            {
                "introduction": self._SCREENSHOT_INTRODUCTION,
                "path": images,
                "assistant": ""
            }
        ]

        processed_params = {
            "image_introduction": image_introduction,
            "previous_summarization": previous_summarization,
            "task_description": task_description,
            "subtask_description": subtask_description,
            "subtask_reasoning": subtask_reasoning,
            "previous_reasoning": decision_making_reasoning,
            "self_reflection_reasoning": self_reflection_reasoning,
            "toolbar_information": toolbar_information
        }

        mem.working_area.update(processed_params)

        return processed_params


class StardewTaskInferencePostprocessProvider(BaseProvider):
    """Post-process a Stardew task-inference response and record it in memory.

    Args:
        memory: Optional memory object whose ``update_info_history`` receives
            the result. When omitted the provider falls back to a module-level
            ``memory`` name, which raises ``NameError`` if the module does not
            define it.
    """

    def __init__(self,
                 *args,
                 memory: Any = None,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.memory = memory

    def _get_memory(self):
        """Return the injected memory, else the legacy module-level ``memory``."""
        if self.memory is not None:
            return self.memory
        # Legacy lookup kept for backward compatibility; raises NameError when
        # the module does not define ``memory``.
        return memory

    def __call__(self, response: Dict):
        """Return a deep copy of ``response`` with the memory-facing keys added.

        Adds ``summarization`` (from ``history_summary``),
        ``subtask_description`` (from ``subtask``) and ``subtask_reasoning``.

        Raises:
            KeyError: if ``response`` lacks ``history_summary``, ``subtask``
                or ``subtask_reasoning``.
        """
        logger.write('Stardew Task Inference Postprocess')

        processed_response = deepcopy(response)

        processed_response.update({
            'summarization': response['history_summary'],
            'subtask_description': response['subtask'],
            'subtask_reasoning': response['subtask_reasoning']
        })

        self._get_memory().update_info_history(processed_response)

        return processed_response