import inspect
import base64
from typing import Type, AnyStr, List, Any, Dict, Tuple
import numpy as np
import importlib
import re
import os
import backoff
import asyncio

from harl import constants
from harl.common.skills.skill_registry import SkillRegistry
from harl.common.skills.skill import Skill
from harl.utils.singleton import Singleton
from harl.common.skills.smacv2 import *
from harl.common.llm_logger import Logger
from harl.utils.skill_utils import parse_obs
from harl.common.skills.smacv2.atomic_actions.move import move_north, move_south, move_east, move_west
from harl.common.skills.smacv2.atomic_actions.combat import attack
from harl.common.skills.smacv2.atomic_actions.heal import heal
from harl.common.skills.smacv2.atomic_actions.basic import stop
from harl.common.skills.smacv2.composite_skills import default_tactic as default_action, find_path
import random

# Module-wide logger used by the decorator and by SMACv2SkillRegistry.
logger = Logger()
# Skills collected by @register_skill, keyed by skill name. SMACv2SkillRegistry
# uses this mapping when its config provides no registered skills.
SKILLS: Dict[str, Skill] = dict()
def register_skill(name):
    """Decorator factory that registers a function as the skill ``name``.

    The decorated function's source is re-executed in a namespace that also
    exposes the SMACv2 atomic/composite actions and helpers. The resulting
    function, its cleaned source (``@register_skill`` lines removed) and a
    base64 copy of that source are wrapped in a ``Skill``.

    Args:
        name: skill name. It must equal the decorated function's name, because
            the function is looked up under this name after re-execution.

    Returns:
        A decorator that stores the ``Skill`` in ``SKILLS[name]`` and returns
        it. The decorated name is therefore bound to the ``Skill`` object, not
        to the plain function.
    """
    import textwrap

    def decorator(skill):
        # Dedent so that functions defined in a nested scope can still be
        # exec'd at top level.
        skill_code = textwrap.dedent(inspect.getsource(skill))
        # Remove every @register_skill(...) line, whatever quote style is used.
        # Leaving it in would re-invoke this decorator on the exec'd copy, whose
        # source inspect cannot retrieve. It is also noise in the skill library.
        skill_code = re.sub(r"^[ \t]*@register_skill\(.*\)[ \t]*\r?\n", "",
                            skill_code, flags=re.MULTILINE)

        skill_code_base64 = base64.b64encode(skill_code.encode('utf-8')).decode('utf-8')

        # Create execution context with required imports.
        namespace = globals().copy()
        namespace.update({
            'move_north': move_north,
            'move_south': move_south,
            'move_east': move_east,
            'move_west': move_west,
            'attack': attack,
            'heal': heal,
            'stop': stop,
            'default_action': default_action,
            'find_path': find_path,
            'parse_obs': parse_obs,
            'random': random,
            'List': List,
            'Tuple': Tuple,
            'Dict': Dict
        })
        # Drop any pre-existing module-level binding of this name, so the
        # lookup below can only find the freshly executed definition.
        namespace.pop(name, None)

        # Use one dict as both globals and locals. The function's __globals__
        # then contain its own name and anything else defined by the source.
        exec(skill_code, namespace)
        skill_function = namespace[name]

        skill_ins = Skill(name,
                          skill_function,
                          "",  # skill_embedding: left empty here; presumably filled in by the registry - TODO confirm
                          skill_code,
                          skill_code_base64)
        SKILLS[name] = skill_ins

        return skill_ins

    return decorator

class SMACv2SkillRegistry(SkillRegistry, metaclass=Singleton):
    """Skill registry for SMACv2 agents.

    Combines the skills registered at import time through ``@register_skill``
    (collected in the module-level ``SKILLS``) with skills generated at runtime
    from code strings. It provides embedding-based retrieval and execution.
    """

    # A newly generated skill is rejected as a near-duplicate when its
    # dot-product similarity to the closest existing skill exceeds this value.
    # The check only runs when skill_configs["offline"] is True.
    SIMILARITY_THRESHOLD = 0.95

    def __init__(self,
                 *args,
                 skill_configs: Dict[str, Any],
                 embedding_provider=None,
                 **kwargs):
        """Create the registry.

        Args:
            skill_configs: registry configuration. When the entry under
                ``constants.SKILL_CONFIG_REGISTERED_SKILLS`` is missing or None,
                it is set (in place) to the decorator-registered ``SKILLS``.
            embedding_provider: forwarded to ``SkillRegistry.__init__``.

        ``*args`` and ``**kwargs`` are accepted for call-site compatibility
        but are not used.
        """
        if skill_configs.get(constants.SKILL_CONFIG_REGISTERED_SKILLS) is None:
            skill_configs[constants.SKILL_CONFIG_REGISTERED_SKILLS] = SKILLS

        super().__init__(skill_configs=skill_configs,
                         embedding_provider=embedding_provider)

    def retrieve_skills(self, query_task: str, skill_num: int, unit_type: str, scenario_name: str = None) -> List[str]:
        """Return the names of the skills most relevant to ``query_task``.

        Skills recorded in ``self.recent_skills`` always come first. The
        remaining slots, up to ``skill_num`` (capped at the number of
        registered skills), are filled with skills ranked by the dot product
        between their embedding and the query embedding. ``self.recent_skills``
        is cleared afterwards.

        Args:
            query_task: task description used to build the query embedding.
            skill_num: desired number of skills.
            unit_type: unit type, forwarded to ``get_embedding``.
            scenario_name: unused; kept for interface compatibility.

        Returns:
            Skill names without duplicates. There may be more than
            ``skill_num`` of them if ``recent_skills`` alone exceeds it.
        """
        skill_num = min(skill_num, len(self.skills))
        # De-duplicate while preserving order: recent_skills can contain the
        # same name more than once.
        target_skills = list(dict.fromkeys(self.recent_skills))

        task_emb = self.get_embedding(unit_type, query_task)
        ranked = sorted(self.skills.items(),
                        key=lambda item: -np.dot(item[1].skill_embedding, task_emb))

        for skill_name, _ in ranked:
            if len(target_skills) >= skill_num:
                break
            if skill_name not in target_skills:
                target_skills.append(skill_name)

        self.recent_skills = []
        return target_skills

    def get_skill_information(self,
                              skill_list,
                              skill_library_with_code = False
                              ):
        """Return the skill-library entries for the names in ``skill_list``, in order.

        Args:
            skill_list: iterable of skill names.
            skill_library_with_code: forwarded to ``get_from_skill_library``.
                Presumably it controls whether each entry includes the
                skill's code - TODO confirm.
        """
        return [
            self.get_from_skill_library(skill_name, skill_library_with_code=skill_library_with_code)
            for skill_name in skill_list
        ]

    async def execute_skill(self,
                            skill_name: str = "move_north",
                            skill_params: Dict = None):
        """Call the registered skill ``skill_name`` with ``skill_params`` as keyword arguments.

        The skill function is called synchronously inside this coroutine, so a
        slow skill blocks the event loop while it runs.

        Args:
            skill_name: name of a skill in ``self.skills``.
            skill_params: keyword arguments for the skill. ``None`` means no
                arguments.

        Returns:
            Whatever the skill function returns.

        Raises:
            ValueError: if ``skill_name`` is not registered.

        Any exception raised by the skill function propagates unchanged.
        """
        if skill_name not in self.skills:
            raise ValueError(f"Function '{skill_name}' not found in the skill library.")
        if skill_params is None:
            skill_params = {}
        skill_function = self.skills[skill_name].skill_function
        return skill_function(**skill_params)

    @staticmethod
    def _build_exec_namespace() -> Dict[str, Any]:
        """Return a fresh namespace for exec'ing skill code.

        The namespace is a copy of this module's globals plus the action and
        helper names that skill code is expected to use.
        """
        namespace = globals().copy()
        namespace.update({
            'move_north': move_north,
            'move_south': move_south,
            'move_east': move_east,
            'move_west': move_west,
            'attack': attack,
            'heal': heal,
            'stop': stop,
            'default_action': default_action,
            'find_path': find_path,
            'parse_obs': parse_obs,
            'random': random,
            'List': List,
            'Tuple': Tuple,
            'Dict': Dict
        })
        return namespace

    def register_skill_from_code(self, skill_code: str, overwrite = False) -> Tuple[bool, str]:
        """Register the skill function defined in a code string.

        Args:
            skill_code: the code of skill.
            overwrite: the flag indicates whether to overwrite the skill with the same name or not.

        Returns:
            bool: True means there is no problem with ``skill_code``. This
                includes the cases where it is skipped as trivial, protected,
                or already registered. False means it may need to be
                re-generated.
            str: the detailed information about the bool.
        """

        def rename_if_exists(code: str) -> str:
            # Skills that already exist (e.g. the ones defined in .py files) must
            # not be overwritten by generated code, so the new definition is
            # renamed to '<name>_generated'. Only the ``def <name>`` header is
            # rewritten. A plain substring replace would also corrupt calls such
            # as ``attack(...)`` or identifiers that merely contain the name.
            name, _ = self.convert_code_to_skill_info(code)
            if name not in self.skills:
                return code
            return re.sub(rf"(\bdef\s+){re.escape(name)}\b",
                          lambda match: f"{match.group(1)}{name}_generated",
                          code)

        def has_param_descriptions(func) -> bool:
            # Each parameter must be described in the docstring as
            # "<name> ...: <text>". Functions without a docstring pass.
            docstring = inspect.getdoc(func)
            if not docstring:
                return True
            return all(
                re.search(rf"\s+{re.escape(param.name)}.*:\s*([^\n]+)", docstring)
                for param in inspect.signature(func).parameters.values()
            )

        def passes_protection(name: str) -> bool:
            # An allowed word anywhere in the name wins over any denied word.
            if any(word in name for word in self.skill_names_allow):
                return True
            return not any(word in name for word in self.skill_names_deny)

        if skill_code.count('(') < 2:
            # Reported as "no problem" (True), i.e. the code is not flagged
            # for re-generation.
            info = "Skill code contains no functionality."
            logger.error(info)
            return True, info

        skill_code = rename_if_exists(skill_code)
        skill_name, _ = self.convert_code_to_skill_info(skill_code)

        # Always avoid adding skills that are ambiguous with existing pre-defined
        # ones. Offer the matching protected basic skills instead.
        if not passes_protection(skill_name):
            info = f"Skill '{skill_name}' conflicts with protected skills."
            for word in self.skill_names_deny:
                if word not in skill_name:
                    continue
                for protected_skill in self.skill_names_basic:
                    if word in protected_skill and protected_skill not in self.recent_skills:
                        self.recent_skills.append(protected_skill)
            logger.write(info)
            return True, info

        if overwrite and skill_name in self.skills:
            self.delete_skill(skill_name)
            logger.write(f"Skill '{skill_name}' will be overwritten.")

        if skill_name in self.skills:
            info = f"Skill '{skill_name}' already exists."
            logger.write(info)
            return True, info

        # SECURITY NOTE: this executes arbitrary generated code in-process.
        try:
            namespace = self._build_exec_namespace()
            # Ensure the lookup below finds the freshly executed definition, not
            # a module-level name that happens to match.
            namespace.pop(skill_name, None)
            # Use one dict as both globals and locals. With separate dicts, the
            # skill code's top-level imports and helper functions would land in
            # locals only, and would be invisible to the skill function when
            # it is executed later.
            exec(skill_code, namespace)
            skill = namespace[skill_name]
        except Exception:
            info = "The skill code is invalid."
            logger.error(info)
            return False, info

        if not has_param_descriptions(skill):
            info = "The format of parameter description is wrong."
            logger.error(info)
            return False, info

        # Use the first constants.SMAC_UNIT_TYPE key contained in the skill
        # name, or "all" if there is none.
        lowered_name = skill_name.lower()
        unit_race = next(
            (race for race in constants.SMAC_UNIT_TYPE.keys() if race in lowered_name),
            "all",
        )

        skill_code_base64 = base64.b64encode(skill_code.encode('utf-8')).decode('utf-8')
        skill_ins = Skill(skill_name,
                          skill,
                          self.get_embedding(unit_race, inspect.getdoc(skill)),
                          skill_code,
                          skill_code_base64)

        if self.skill_configs["offline"] == True and self.skills:
            # Compare against the most similar existing skill (dot product;
            # presumably the embeddings are normalized - TODO confirm).
            top_skill_name, top_skill = max(
                self.skills.items(),
                key=lambda item: np.dot(item[1].skill_embedding, skill_ins.skill_embedding))
            similarity = np.dot(top_skill.skill_embedding, skill_ins.skill_embedding)
            if similarity > self.SIMILARITY_THRESHOLD:
                info = f"Skill '{skill_name}' is too similar to existing skill '{top_skill_name}'"
                logger.write(info)
                return False, info

        self.skills[skill_name] = skill_ins
        self.recent_skills.append(skill_name)

        info = f"Skill \033[32m'{skill_name}'\033[0m has been\033[34m registered\033[0m."
        logger.write(info)
        return True, info