from typing import (
    Any,
    List,
    Dict,
    Union,
    Tuple,
    Optional,
)
import os
from PIL.Image import Image
import numpy as np

from harl import constants
from harl.common.llm_logger import Logger
from harl.common.memory.base import BaseMemory
from harl.utils.json_utils import load_json, save_json
from harl.common.base.base_embedding import EmbeddingProvider

# Module-level logger shared by all LocalMemory instances (used by
# ``similarity_search`` and ``load`` to report errors and load status).
logger = Logger()


class LocalMemory(BaseMemory):
    """Local, in-process agent memory persisted as a JSON file.

    All history lives in ``self.recent_history``, a dict mapping bucket names to
    (mostly) lists of entries appended during the agent loop. ``save``/``load``
    round-trip that dict through ``memory.json`` (or an explicit path), and a
    ``screenshots`` sub-directory of ``memory_path`` is created on construction.
    """

    storage_filename = "memory.json"

    # Default summary used at construction time and after ``clean_memory``.
    _DEFAULT_SUMMARY = "The controlled agent is cooperating with the other agents to complete the task."

    def __init__(
        self,
        memory_path: str,
        max_recent_steps: int = constants.MAX_RECENT_STEPS,
        embedding_provider: Optional[EmbeddingProvider] = None,
        agent_id: int = 0,
        unit_type: Optional[str] = None,
    ) -> None:
        """Create the memory and its on-disk screenshot directory.

        Args:
            memory_path: Directory holding ``memory.json`` and a ``screenshots``
                sub-directory (created if missing).
            max_recent_steps: Stored on the instance. History truncation is
                currently disabled, so it does not bound the bucket lists.
            embedding_provider: Optional provider used by ``similarity_search``.
            agent_id: Identifier of the agent that owns this memory.
            unit_type: Optional unit-type label; replaced by ``clean_memory``.
        """
        self.max_recent_steps = max_recent_steps
        self.memory_path = memory_path
        self.screenshot_path = os.path.join(self.memory_path, "screenshots")
        os.makedirs(self.screenshot_path, exist_ok=True)

        self.embedding_provider = embedding_provider
        self.agent_id = agent_id
        self.unit_type = unit_type
        # Public working space for the agent to store information during the loop.
        self.working_area: Dict[str, Any] = {}

        # Number of ``get_task_guidance(use_last=False)`` calls that return the last
        # short-horizon guidance before falling back to the long-horizon task.
        self.task_duration = 3
        self.current_step = 0

        # @TODO First memory summary should be based on environment spec
        self.recent_history = {
            constants.IMAGES_MEM_BUCKET: [],
            constants.AUGMENTED_IMAGES_MEM_BUCKET: [],
            constants.IMAGES_EVENT_BUCKET: [],
            "observation": [],
            "reward": [],
            "dones": [],
            "action": [],
            "action_error": [],
            "decision_making_reasoning": [],
            "success_detection_reasoning": [],
            "self_reflection_reasoning": [],
            "image_description": [],
            "task_guidance": [],
            "dialogue": [],
            "task_description": [],
            constants.SKIIL_LIB_MEM_BUCKET: [],
            constants.SUMMARIZATION_MEM_BUCKET: [self._DEFAULT_SUMMARY],
            constants.LAST_TASK_GUIDANCE: [],
            "long_horizon_task": [],
            # Countdown read/decremented by get_task_guidance; stored as an int,
            # matching what add_task_guidance writes.
            constants.LAST_TASK_DURATION: self.task_duration,
            constants.KEY_REASON_OF_LAST_ACTION: [],
            constants.SUCCESS_DETECTION: [],
            "historical_lessons": [],
        }

    def clean_memory(self, unit_type) -> None:
        """Reset per-episode buckets and the working area, and set a new unit type.

        Only the buckets listed below are reset; any other keys already in
        ``recent_history`` are left untouched.
        """
        # A fresh literal per call, so no list object is shared between resets.
        # dict.update applies entries in this order, like sequential assignment.
        self.recent_history.update({
            "observation": [],
            "reward": [],
            "dones": [],
            "skill_steps": [],
            "ego_minimap": [],
            "game_situation": [],
            "region_of_interest": [],
            "reasoning_of_region_of_interest": [],
            "start_frame_id": [0],
            "end_frame_id": [0],
            constants.IMAGES_MEM_BUCKET: [],
            constants.SHARE_IMAGES_MEM_BUCKET: [],
            constants.AUGMENTED_IMAGES_MEM_BUCKET: [],
            constants.IMAGES_EVENT_BUCKET: [],
            constants.SHARE_IMAGES_EVENT_BUCKET: [],
            "self_reflection_reasoning": [],
            "pre_self_reflection_reasoning": [],
            # Same bucket __init__/add_summarization/get_summarization use.
            constants.SUMMARIZATION_MEM_BUCKET: [self._DEFAULT_SUMMARY],
            "decision_making_reasoning": [""],
            "pre_decision_making_reasoning": [""],
            "pre_action": [],
            "action": [],
            "exec_error": [],
        })
        self.working_area = {}
        self.unit_type = unit_type

    def update_current_step(self, step: int) -> None:
        """Record the current step index."""
        self.current_step = step

    def add_recent_history_kv(
        self,
        key: str,
        info: Any,
    ) -> None:
        """Append ``info`` to bucket ``key``, creating the bucket if missing."""
        # NOTE: truncation to max_recent_steps is currently disabled, so
        # buckets grow without bound.
        self.recent_history.setdefault(key, []).append(info)

    def add_recent_history(
        self,
        information
    ) -> None:
        """Append each value of the ``information`` mapping to its key's bucket.

        Missing buckets are created. Truncation to ``max_recent_steps`` is
        currently disabled, so buckets grow without bound.
        """
        for key, value in information.items():
            self.recent_history.setdefault(key, []).append(value)

    def get_recent_history(
        self,
        key: str,
        k: int = 1,
    ) -> List[Any]:
        """Return up to the ``k`` most recent entries stored under ``key``.

        Args:
            key: Bucket name.
            k: Number of trailing entries to return; ``None`` is treated as 1.

        Returns:
            ``[""]`` when the bucket is missing or empty; ``[]`` when ``k <= 0``;
            otherwise the last ``min(k, len(bucket))`` entries. When the bucket
            has fewer than ``k`` entries, the bucket list itself is returned.
        """
        if key not in self.recent_history or len(self.recent_history[key]) == 0:
            return [""]

        if k is None:
            k = 1

        # Guard explicitly: lst[-0:] is a full slice, and negative k would
        # silently drop leading entries instead of selecting trailing ones.
        if k <= 0:
            return []

        history = self.recent_history[key]
        return history[-k:] if len(history) >= k else history

    def get_history(
        self,
        key: str,
    ) -> List[Any]:
        """Return the whole bucket ``key`` (the stored list itself), or ``[""]`` if missing/empty."""
        if key not in self.recent_history or len(self.recent_history[key]) == 0:
            return [""]

        return self.recent_history[key]

    def update_info_history(self, data: Dict[str, Any]):
        """Merge ``data`` into ``working_area`` and append its values to history."""
        self.working_area.update(data)
        self.add_recent_history(data)

    def add_summarization(self, summary: str) -> None:
        """Replace the summarization bucket with the single entry ``summary``."""
        self.recent_history[constants.SUMMARIZATION_MEM_BUCKET] = [summary]

    def get_summarization(self) -> str:
        """Return the latest summary."""
        return self.recent_history[constants.SUMMARIZATION_MEM_BUCKET][-1]

    def _clamp_end(self, key: str, end_frame_id: Optional[int]) -> int:
        """Return ``end_frame_id`` clamped to the last index of bucket ``key`` (None -> last)."""
        last_index = len(self.recent_history[key]) - 1
        if end_frame_id is None or end_frame_id > last_index:
            return last_index
        return end_frame_id

    def get_frame_paths_obs(self, start_frame_id, end_frame_id=None):
        """Return (image entries, observations) for the inclusive range [start, end].

        ``end_frame_id`` is clamped against the image bucket's length, and that
        same end is used to slice the observations.
        """
        end = self._clamp_end(constants.IMAGES_MEM_BUCKET, end_frame_id)
        frame_paths = self.recent_history[constants.IMAGES_MEM_BUCKET][start_frame_id:end + 1]
        obs_texts = self.recent_history["observation"][start_frame_id:end + 1]
        return frame_paths, obs_texts

    def get_share_frame_paths(self, start_frame_id, end_frame_id=None):
        """Return shared-image entries for the inclusive range [start, end] (end clamped)."""
        end = self._clamp_end(constants.SHARE_IMAGES_MEM_BUCKET, end_frame_id)
        return self.recent_history[constants.SHARE_IMAGES_MEM_BUCKET][start_frame_id:end + 1]

    def get_last_action_return(self, start_frame_id, end_frame_id=None):
        """Return the sum of rewards (each cast to float) over the inclusive range [start, end]."""
        end = self._clamp_end("reward", end_frame_id)
        return sum(float(reward) for reward in self.recent_history["reward"][start_frame_id:end + 1])

    def get_last_action_error(self, start_frame_id, end_frame_id=None):
        """Return execution errors over the inclusive range [start, end], newline-joined.

        Entries are assumed to be strings (``str.join`` raises otherwise).
        """
        end = self._clamp_end("exec_error", end_frame_id)
        return '\n'.join(self.recent_history["exec_error"][start_frame_id:end + 1])

    def add_task_guidance(self, task_description: str, long_horizon: bool) -> None:
        """Store the latest task guidance and restart its duration countdown.

        When ``long_horizon`` is true, the guidance is also stored as the
        long-horizon fallback task.
        """
        self.recent_history[constants.LAST_TASK_GUIDANCE] = task_description
        self.recent_history[constants.LAST_TASK_DURATION] = self.task_duration
        if long_horizon:
            self.recent_history['long_horizon_task'] = task_description

    def get_task_guidance(self, use_last=True) -> str:
        """Return the current task guidance.

        With ``use_last=True``, the last guidance is returned as-is. Otherwise the
        duration countdown is decremented (a state change). The last guidance is
        returned while the counter stays ``>= 0``; after that, the long-horizon
        task is returned.
        """
        if use_last:
            return self.recent_history[constants.LAST_TASK_GUIDANCE]

        self.recent_history[constants.LAST_TASK_DURATION] -= 1
        if self.recent_history[constants.LAST_TASK_DURATION] >= 0:
            return self.recent_history[constants.LAST_TASK_GUIDANCE]
        return self.recent_history['long_horizon_task']

    def similarity_search(
        self,
        data: Union[str, Image],
        top_k: int,
        **kwargs: Any,
    ) -> List[Union[str, Image]]:
        """Find the stored images most similar to a query image.

        Similarity is the dot product of the embeddings returned by
        ``embedding_provider.embed_image_query``. Every stored image is
        re-embedded on each call.

        Args:
            data: The query; only PIL ``Image`` queries are supported.
            top_k: Maximum number of results to return.
            **kwargs: Other keyword arguments that subclasses might use.

        Returns:
            For image queries, a list of up to ``top_k`` dicts with keys ``image``,
            ``action``, ``description`` and ``similarity``, most similar first. The
            list is empty when there are no stored images or ``top_k <= 0``.
            ``[""]`` is returned when no embedding provider is set or the query
            is not an image.
        """
        if not self.embedding_provider:
            logger.error("No embedding provider available for similarity search")
            return [""]

        if not isinstance(data, Image):
            # Handle text similarity search if needed
            logger.error("Text similarity search not implemented")
            return [""]

        # argsort(...)[-0:] would select every image, so reject non-positive
        # top_k before spending any embedding calls.
        if top_k <= 0:
            return []

        query_embedding = self.embedding_provider.embed_image_query(data)

        recent_images = self.recent_history[constants.IMAGES_MEM_BUCKET]
        if not recent_images:
            return []

        similarities = [
            np.dot(query_embedding, self.embedding_provider.embed_image_query(img))
            for img in recent_images
        ]

        # Indices of the top_k largest similarities, most similar first.
        top_k_indices = np.argsort(similarities)[-top_k:][::-1]

        actions = self.recent_history['action']
        descriptions = self.recent_history['image_description']
        results = []
        for idx in top_k_indices:
            results.append({
                'image': recent_images[idx],
                'action': actions[idx] if idx < len(actions) else None,
                'description': descriptions[idx] if idx < len(descriptions) else None,
                'similarity': similarities[idx],
            })
        return results

    def load(self, load_path=None) -> None:
        """Replace ``recent_history`` with the JSON stored at ``load_path``.

        Does nothing when ``load_path`` is None. Logs an error (without raising)
        when the path does not exist.
        """
        # @TODO load and store whole memory
        if load_path is None:
            return
        if os.path.exists(load_path):
            self.recent_history = load_json(load_path)
            logger.write(f"{load_path} has been loaded.")
        else:
            logger.error(f"{load_path} does not exist.")

    def save(self, local_path=None) -> None:
        """Write ``recent_history`` as indented JSON.

        The target is ``local_path`` when truthy, otherwise
        ``<memory_path>/memory.json``.
        """
        # @TODO load and store whole memory
        target = local_path if local_path else os.path.join(self.memory_path, self.storage_filename)
        save_json(file_path=target, json_dict=self.recent_history, indent=4)

    def get_frames(self, start_frame_id, end_frame_id=None):
        """Return image-bucket entries whose first element lies in [start, end].

        Assumes each entry is indexable with a frame id at position 0 and that
        entries are ordered by that id, because iteration stops at the first id
        past ``end_frame_id``. TODO confirm against the writers of this bucket.
        """
        frames = []
        for frame in self.recent_history[constants.IMAGES_MEM_BUCKET]:
            if frame[0] < start_frame_id:
                continue
            if end_frame_id is not None and frame[0] > end_frame_id:
                break
            frames.append(frame)
        return frames
