from typing import (
    Any,
    List,
    Dict,
    Union,
    Tuple,
    Optional,
)
import os
from PIL.Image import Image
import numpy as np
from threading import Lock

from harl import constants
from harl.common.llm_logger import Logger
from harl.configs.config import Config
from harl.common.memory.base import BaseMemory
from harl.utils.json_utils import load_json, save_json
from harl.utils.singleton import Singleton
from harl.common.base.base_embedding import EmbeddingProvider

import sys

def calculate_memory_usage(obj):
    """Estimate the deep memory footprint of ``obj`` in bytes.

    Sums ``sys.getsizeof`` over ``obj`` and every object reachable from it
    through dicts (keys and values), lists, tuples, sets and frozensets.
    Each distinct object (identified by ``id``) is counted once, so shared
    references and reference cycles do not inflate the total or loop forever.
    Attributes of arbitrary objects are not followed.

    The traversal uses an explicit stack rather than recursion so that very
    deeply nested structures do not hit the interpreter's recursion limit.

    Args:
        obj: The root object to measure.

    Returns:
        The summed ``sys.getsizeof`` of all distinct objects visited.
    """
    seen_ids = set()
    pending = [obj]
    total = 0

    while pending:
        current = pending.pop()
        if id(current) in seen_ids:
            continue
        seen_ids.add(id(current))
        total += sys.getsizeof(current)

        if isinstance(current, dict):
            for key, value in current.items():
                pending.append(key)
                pending.append(value)
        elif isinstance(current, (list, tuple, set, frozenset)):
            pending.extend(current)

    return total

class GlobalMemory(BaseMemory, metaclass=Singleton):

    storage_filename = "global_memory.json"

    def __init__(
        self,
        memory_path: str = '',
        max_recent_steps: int = constants.MAX_RECENT_STEPS,
        embedding_provider: Optional[EmbeddingProvider] = None,
        unit_races = None,
        max_hops: int = 3,
    ) -> None:
        """Set up the shared memory and one empty record per rollout environment.

        Args:
            memory_path: Directory used by ``save`` when no explicit path is
                given. It is created if it does not exist; an empty string
                (the default) skips directory creation.
            max_recent_steps: Maximum number of entries kept per keyed
                history bucket.
            embedding_provider: Optional provider used by ``similarity_search``.
            unit_races: Looked up as ``unit_races[env_id][agent_id]`` when
                combining relative positions during information sharing.
            max_hops: Maximum number of ally-to-ally hops followed when
                sharing unit information.
        """
        self.logger = Logger()
        self.config = Config()

        self.max_recent_steps = max_recent_steps
        self.memory_path = memory_path
        # os.makedirs('') raises FileNotFoundError, so the default empty path
        # must not be passed to it.
        if self.memory_path:
            os.makedirs(self.memory_path, exist_ok=True)

        self.embedding_provider = embedding_provider

        self.unit_races = unit_races

        # Public working space for the agent to store information during loop
        self.working_area: Dict[str, Any] = {}

        # Stored as the remaining-duration counter by add_task_guidance.
        self.task_duration = 3
        self.max_hops = max_hops

        # @TODO First memory summary should be based on environment spec
        # One record per rollout thread:
        #   "agent_knowledge": {agent_id: {"ally": {...}, "enemy": {...},
        #                                  "own_info": ..., "timestep": ...}}
        #   "agent_task": {agent_id: task}
        self.recent_history = [{
            "agent_knowledge": {},
            "agent_task": {},
        } for _ in range(self.config.algo_args["train"]["n_rollout_threads"])]
    
    def reset(self, env_id=None) -> None:
        if env_id is not None:
            self.recent_history[env_id] = {
                "agent_knowledge": {},
                "agent_task": {},
            }
        else:
            self.recent_history = [{
                "agent_knowledge": {},
                "agent_task": {},
            } for _ in range(self.config.algo_args["train"]["n_rollout_threads"])]
    
    def reset_knowledge(self):
        for env_id in range(self.config.algo_args["train"]["n_rollout_threads"]):
            self.recent_history[env_id]["agent_knowledge"] = {}
            # self.recent_history[env_id]["agent_task"] = {}
    
    def update_ally_task(self, env_id: int, agent_id: int, task: str) -> None:
        """Update unit position from an agent's perspective and share information"""
            
        # Initialize unit entry if not exists
        # if agent_id not in self.recent_history[env_id]["agent_task"]:
        self.recent_history[env_id]["agent_task"][agent_id] = task
    
    def get_ally_task(self, env_id: int, agent_id: int) -> str:
        #Get ally's task except agent_id to string
        ally_task = ""
        for agent in self.recent_history[env_id]["agent_task"]:
            if agent != agent_id:
                ally_task += f"Agent {agent}: {self.recent_history[env_id]['agent_task'][agent]}\n"
        return ally_task
    
    def update_self_position(self, env_id: int, agent_id: int, own_info, timestep: int) -> None:
        # Initialize unit entry if not exists
        self.recent_history[env_id]["agent_knowledge"][agent_id] = {
            "ally": {},
            "enemy": {},
            "own_info": own_info,
            "timestep": timestep
        }

    def update_unit_position(self, env_id: int, agent_id: int, own_info, unit_id: int, unit_type: str, race: str,
                           relative_pos: Tuple[float, float], health, shield, timestep: int) -> None:
        """Update unit position from an agent's perspective and share information"""
            
        # Initialize unit entry if not exists
        if agent_id not in self.recent_history[env_id]["agent_knowledge"]:
            self.recent_history[env_id]["agent_knowledge"][agent_id] = {
                "ally": {},
                "enemy": {},
                "own_info": own_info,
                "timestep": timestep
            }
        
        # Update position from this agent's perspective
        self.recent_history[env_id]["agent_knowledge"][agent_id][unit_type][unit_id] = {
            "pos": relative_pos,
            "race": race,
            "health": health,
            "shield": shield,
            "last_seen": timestep,
            "directly": True
        }
    
    def share_unit_information_all(self) -> None:
        """Run information sharing for every environment, then report memory use.

        After all environments are processed, the deep size of
        ``self.recent_history`` is computed and printed to stdout.
        """
        for env_id, _ in enumerate(self.recent_history):
            self.share_unit_information_env(env_id)

        history_bytes = calculate_memory_usage(self.recent_history)
        print(f"Memory usage of self.recent_history: {history_bytes} bytes")
    
    def share_unit_information_env(self, env_id: int) -> None:
        # Update shared knowledge for all agents
        all_agents = set(self.recent_history[env_id]["agent_knowledge"].keys())
        for observer_id in all_agents:
            own_info = self.recent_history[env_id]["agent_knowledge"][observer_id]["own_info"]
            for agent_id in all_agents:
                if observer_id == agent_id or agent_id in self.recent_history[env_id]["agent_knowledge"][observer_id]["ally"]:
                    continue
                agent_info = self.recent_history[env_id]["agent_knowledge"][agent_id]["own_info"]
                timestep = self.recent_history[env_id]["agent_knowledge"][agent_id]["timestep"]
                relative_pos = ((agent_info["pos"][0] - own_info["pos"][0])*32/own_info["sight_range"], (agent_info["pos"][1] - own_info["pos"][1])*32/own_info["sight_range"])
                # relative_pos = (round((agent_info["pos"][0] - own_info["pos"][0])*32/own_info["sight_range"],2), round((agent_info["pos"][1] - own_info["pos"][1])*32/own_info["sight_range"]),2)
                self.recent_history[env_id]["agent_knowledge"][observer_id]["ally"][agent_id] = {
                    "pos": relative_pos,
                    "race": agent_info["type"],
                    "health": agent_info["health"],
                    "shield": agent_info["shield"],
                    "last_seen": timestep,
                    "directly": False
                }

            minimap = self.share_unit_information(env_id, observer_id, self.max_hops)
            # Update observer's knowledge based on shared information
            for unit_type, unit_type_info in minimap.items():
                for shared_unit_id, shared_info in unit_type_info.items():
                    if unit_type == "ally" and observer_id == shared_unit_id:
                        continue
                    # Only update if information is newer
                    existing_info = self.recent_history[env_id]["agent_knowledge"][observer_id][unit_type].get(shared_unit_id, {})
                    if not existing_info or shared_info["last_seen"] > existing_info.get("last_seen", 0):
                        self.recent_history[env_id]["agent_knowledge"][observer_id][unit_type][shared_unit_id] = {
                            "pos": shared_info["pos"],
                            "race": shared_info["race"],
                            "health": shared_info["health"],
                            "shield": shared_info["shield"],
                            "last_seen": shared_info["last_seen"],
                            "directly": True if shared_info["hops"] == 0 else False 
                        }

    def share_unit_information(self, env_id: int, agent_id: int, max_hops: int = 2) -> Dict[int, Dict]:
        """Share unit information between agents through multi-hop propagation."""
        minimap = {
            "ally": {},
            "enemy": {}
        }
        visited_sources = set()
        
        def get_allies_in_sight(observer_id: int) -> List[int]:
            """Get all ally IDs that the given observer can see"""
            allies = []
            if observer_id in self.recent_history[env_id]["agent_knowledge"]:
                for unit_id in self.recent_history[env_id]["agent_knowledge"][observer_id]["ally"]:
                    if unit_id != observer_id:  # Exclude self
                        allies.append(unit_id)
            return allies
        
        def combine_relative_positions(pos1: Tuple[float, float], 
                                    pos2: Tuple[float, float],
                                    current_agent) -> Tuple[float, float]:
            """Combine two relative positions"""
            own_sight_range = constants.SMAC_SIGHT_RANGE.get(self.unit_races[env_id][agent_id], 11)
            current_agent_sight_range = constants.SMAC_SIGHT_RANGE.get(self.unit_races[env_id][current_agent], 11)
            return (pos1[0] + pos2[0] * current_agent_sight_range / own_sight_range, pos1[1] + pos2[1] * current_agent_sight_range / own_sight_range)
            # return (round(pos1[0] + pos2[0] * current_agent_sight_range / own_sight_range, 2), round(pos1[1] + pos2[1] * current_agent_sight_range / own_sight_range, 2))
        
        def process_hop(current_agent: int, hop_count: int, 
                    accumulated_pos: Tuple[float, float]) -> None:
            """Recursively process information sharing through multiple hops"""
            if hop_count > max_hops or current_agent in visited_sources:
                return
                
            visited_sources.add(current_agent)
            
            if current_agent in self.recent_history[env_id]["agent_knowledge"]:
                # Process ally units
                for unit_id, unit_info in self.recent_history[env_id]["agent_knowledge"][current_agent]["ally"].items():
                    if unit_id == agent_id:
                        continue
                    final_pos = combine_relative_positions(accumulated_pos, unit_info["pos"], current_agent)
                    if unit_id not in minimap["ally"] or unit_info["last_seen"] > minimap["ally"][unit_id]["last_seen"]:
                        directly = True if hop_count == 0 else False
                        minimap["ally"][unit_id] = {
                            "pos": final_pos,
                            "type": "ally",
                            "race": unit_info["race"],
                            "health": unit_info["health"],
                            "shield": unit_info["shield"],
                            "last_seen": unit_info["last_seen"],
                            "source": current_agent,
                            "hops": hop_count,
                            "directly": directly 
                        }
                
                # Process enemy units
                for unit_id, unit_info in self.recent_history[env_id]["agent_knowledge"][current_agent]["enemy"].items():
                    final_pos = combine_relative_positions(accumulated_pos, unit_info["pos"], current_agent)
                    if unit_id not in minimap["enemy"] or unit_info["last_seen"] > minimap["enemy"][unit_id]["last_seen"]:
                        minimap["enemy"][unit_id] = {
                            "pos": final_pos,
                            "type": "enemy",
                            "race": unit_info["race"],
                            "health": unit_info["health"],
                            "shield": unit_info["shield"],
                            "last_seen": unit_info["last_seen"],
                            "source": current_agent,
                            "hops": hop_count,
                            "directly": True if hop_count == 0 else False   
                        }
            
            # Process next hop through visible allies
            allies = get_allies_in_sight(current_agent)
            for ally_id in allies:
                if ally_id not in visited_sources:
                    ally_pos = self.recent_history[env_id]["agent_knowledge"][current_agent]["ally"][ally_id]["pos"]
                    new_accumulated_pos = combine_relative_positions(accumulated_pos, ally_pos, current_agent)
                    process_hop(ally_id, hop_count + 1, new_accumulated_pos)
        
        # Start processing from the requesting agent
        process_hop(agent_id, 0, (0, 0))
        
        return minimap

    def add_recent_history_kv(
        self,
        key: str,
        info: Any,
    ) -> None:

        """Add recent info (skill/image/reasoning) to memory."""
        if key not in self.recent_history:
            self.recent_history[key] = []

        self.recent_history[key].append(info)

        if len(self.recent_history[key]) > self.max_recent_steps:
            self.recent_history[key].pop(0)


    def add_recent_history(
        self,
        information
    ) -> None:

        """Add recent info to memory."""
        for key, value in information.items():
            if key not in self.recent_history:
                self.recent_history[key] = []
            self.recent_history[key].append(value)

            if len(self.recent_history[key]) > self.max_recent_steps:
                self.recent_history[key].pop(0)


    def get_recent_history(
        self,
        key: str,
        k: int = 1,
    ) -> List[Any]:

        """Query recent info (skill/image/reasoning) from memory."""

        if key not in self.recent_history or len(self.recent_history[key]) == 0:
            return [""]

        if k is None:
            k = 1

        return self.recent_history[key][-k:] if len(self.recent_history[key]) >= k else self.recent_history[key]


    def update_info_history(self, data: Dict[str, Any]):
        self.working_area.update(data)
        self.add_recent_history(data)


    def add_summarization(self, summary: str) -> None:
        """Store ``summary`` as the only entry of the summarization bucket, replacing any previous one."""
        bucket_key = constants.SUMMARIZATION_MEM_BUCKET
        self.recent_history[bucket_key] = [summary]


    def get_summarization(self) -> str:
        """Return the most recent entry of the summarization bucket.

        Raises whatever the lookup raises if no summary has been stored yet.
        """
        summaries = self.recent_history[constants.SUMMARIZATION_MEM_BUCKET]
        return summaries[-1]


    def add_task_guidance(self, task_description: str, long_horizon: bool) -> None:
        """Store the latest task guidance and restart its duration counter.

        Args:
            task_description: Guidance text to remember as the last task.
            long_horizon: When true, the text is additionally stored under
                'long_horizon_task'.
        """
        history = self.recent_history
        history[constants.LAST_TASK_GUIDANCE] = task_description
        # The counter starts from self.task_duration and is decremented by
        # get_task_guidance(use_last=False).
        history[constants.LAST_TASK_DURATION] = self.task_duration
        if not long_horizon:
            return
        history['long_horizon_task'] = task_description


    def get_task_guidance(self, use_last = True) -> str:
        """Return the task guidance currently in effect.

        Args:
            use_last: When true, return the last stored guidance without
                touching the duration counter. When false, decrement the
                counter first; the last guidance is returned while the
                counter is still >= 0, afterwards the entry stored under
                'long_horizon_task' is returned (the lookup fails if no
                long-horizon task was ever stored).
        """
        history = self.recent_history

        if use_last:
            return history[constants.LAST_TASK_GUIDANCE]

        history[constants.LAST_TASK_DURATION] -= 1
        still_active = history[constants.LAST_TASK_DURATION] >= 0
        if still_active:
            return history[constants.LAST_TASK_GUIDANCE]
        return history['long_horizon_task']
    
    def similarity_search(
        self,
        data: Union[str, Image],
        top_k: int,
        **kwargs: Any,
    ) -> List[Union[str, Image]]:
        """Retrieve the keys from the store.

        Args:
            data: the query data.
            top_k: the number of results to return.
            **kwargs: Other keyword arguments that subclasses might use.

        Returns:
            the corresponding values from the memory.
        """
        if not self.embedding_provider:
            self.logger.error("No embedding provider available for similarity search")
            return [""]
        
        if isinstance(data, Image):
            # Get embedding for query image
            query_embedding = self.embedding_provider.embed_image_query(data)
            
            # Get embeddings for recent images in memory
            recent_images = self.recent_history[constants.IMAGES_MEM_BUCKET]
            if not recent_images:
                return []
                
            # Calculate similarities with recent images
            similarities = []
            for img in recent_images:
                img_embedding = self.embedding_provider.embed_image_query(img)
                similarity = np.dot(query_embedding, img_embedding)
                similarities.append(similarity)
                
            # Get indices of top k similar images
            top_k_indices = np.argsort(similarities)[-top_k:][::-1]
            
            # Return corresponding images and their associated information
            results = []
            for idx in top_k_indices:
                img_info = {
                    'image': recent_images[idx],
                    'action': self.recent_history['action'][idx] if idx < len(self.recent_history['action']) else None,
                    'description': self.recent_history['image_description'][idx] if idx < len(self.recent_history['image_description']) else None,
                    'similarity': similarities[idx]
                }
                results.append(img_info)
                
            return results
        else:
            # Handle text similarity search if needed
            self.logger.error("Text similarity search not implemented")
            return [""]

    def load(self, load_path=None) -> None:
        """Load the memory from the local file."""
        # @TODO load and store whole memory
        if load_path != None:
            if os.path.exists(os.path.join(load_path)):
                self.recent_history = load_json(load_path)
                self.logger.write(f"{load_path} has been loaded.")
            else:
                self.logger.error(f"{load_path} does not exist.")


    def save(self, local_path=None) -> None:
        """Write ``self.recent_history`` as indented JSON.

        Args:
            local_path: Destination file. When empty or ``None``, the file
                ``storage_filename`` inside ``self.memory_path`` is used.
        """
        # @TODO load and store whole memory
        if local_path:
            destination = local_path
        else:
            destination = os.path.join(self.memory_path, self.storage_filename)

        save_json(file_path=destination, json_dict=self.recent_history, indent=4)
