import json
import os
from typing import Dict, Any, List
import time
import asyncio
import re
from harl.common.llm_logger import Logger
from harl.common.planner.base import BasePlanner
from harl.utils.check import check_planner_params
from harl.utils.json_utils import parse_semi_formatted_text, JsonFrameStructure
from harl.utils.template_matching import match_templates_images, selection_box_identifier
from harl.utils.file_utils import assemble_project_path, read_resource_file
from harl.utils.json_utils import load_json, parse_semi_formatted_text
from harl.utils.image_utils import process_minimap_targets
from harl.utils.skill_utils import test_skill_code, convert_expression_to_skill
from harl.utils.singleton import Singleton
from harl import constants
from harl.configs.config import Config

# Module-level logger and configuration objects used by every helper and class below.
logger = Logger()
config = Config()
# File extensions, presumably for prompt-template and JSON resources — confirm usage.
PROMPT_EXT = ".prompt"
JSON_EXT = ".json"


async def gather_information_get_completion_parallel(llm_provider, text_input_map, current_frame_path, time_stamp,
                                                     text_input, get_text_template, i, video_prefix, gathered_information_JSON):
    """Ask the LLM to describe one frame and store the parsed result.

    The last entry of ``text_input["image_introduction"]`` is pointed at
    ``current_frame_path`` (``text_input`` is mutated in place), the prompt is
    assembled from ``get_text_template`` and sent with
    ``llm_provider.create_completion_async``. The request is retried every 2
    seconds, without limit, until ``parse_semi_formatted_text`` accepts the
    response. The parsed result is stored via
    ``gathered_information_JSON.add_instance`` under ``"{video_prefix}_{time_stamp}"``.

    Args:
        llm_provider: Provider exposing ``assemble_prompt`` and ``create_completion_async``.
        text_input_map: Fallback prompt parameters used when ``text_input`` is None.
        current_frame_path: Path of the frame image to describe.
        time_stamp: Frame identifier used in the storage key.
        text_input: Prompt parameters; must contain a non-empty ``image_introduction``
            list whose last entry has ``introduction`` and ``assistant`` keys.
        get_text_template: Prompt template string.
        i: Zero-based frame index, used only in log messages.
        video_prefix: Prefix of the storage key.
        gathered_information_JSON: Collector exposing ``add_instance(key, value)``.

    Returns:
        True once the result has been stored.
    """
    logger.write(f"Start gathering text information from the {i + 1}th frame")

    text_input = text_input_map if text_input is None else text_input
    image_introduction = text_input["image_introduction"]

    # Point the last image entry at the current frame. There is no ``await``
    # between this mutation and prompt assembly, so concurrent callers that
    # share ``text_input`` cannot interleave here.
    image_introduction[-1] = {
        "introduction": image_introduction[-1]["introduction"],
        "path": f"{current_frame_path}",
        "assistant": image_introduction[-1]["assistant"]
    }
    text_input["image_introduction"] = image_introduction
    message_prompts = llm_provider.assemble_prompt(template_str=get_text_template, params=text_input)

    logger.debug(f'{logger.UPSTREAM_MASK}{json.dumps(message_prompts, ensure_ascii=False)}\n')

    while True:
        try:
            response, info = await llm_provider.create_completion_async(message_prompts)
            logger.debug(f'{logger.DOWNSTREAM_MASK}{response}\n')

            # Convert the response to dict
            processed_response = parse_semi_formatted_text(response)
            break
        except Exception as e:
            logger.error(f"Response is not in the correct format: {e}, retrying...")

            # wait 2 seconds (without blocking the event loop) and retry
            await asyncio.sleep(2)

    if processed_response is None or not response:
        logger.warn('Empty response in gather text information call')
        # One pre-formatted message, like every other logger call in this module.
        logger.debug(f"response: {response}, processed_response: {processed_response}")

    objects_index = str(video_prefix) + '_' + str(time_stamp)
    gathered_information_JSON.add_instance(objects_index, processed_response)
    logger.write(f"Finish gathering text information from the {i + 1}th frame")

    return True


def gather_information_get_completion_sequence(llm_provider, text_input_map, current_frame_path, time_stamp,
                                               text_input, get_text_template, i, video_prefix, gathered_information_JSON):
    """Synchronously ask the LLM to describe one frame and store the parsed result.

    Blocking counterpart of ``gather_information_get_completion_parallel``:
    the last entry of ``text_input["image_introduction"]`` is pointed at
    ``current_frame_path`` (``text_input`` is mutated in place), the prompt is
    sent with ``llm_provider.create_completion`` and re-requested every 2
    seconds, without limit, until ``parse_semi_formatted_text`` accepts the
    response. The result is stored via
    ``gathered_information_JSON.add_instance`` under ``"{video_prefix}_{time_stamp}"``.

    Args:
        llm_provider: Provider exposing ``assemble_prompt`` and ``create_completion``.
        text_input_map: Fallback prompt parameters used when ``text_input`` is None.
        current_frame_path: Path of the frame image to describe.
        time_stamp: Frame identifier used in the storage key (any type; stringified).
        text_input: Prompt parameters; must contain a non-empty ``image_introduction`` list.
        get_text_template: Prompt template string.
        i: Zero-based frame index, used only in log messages.
        video_prefix: Prefix of the storage key.
        gathered_information_JSON: Collector exposing ``add_instance(key, value)``.

    Returns:
        True once the result has been stored.
    """
    logger.write(f"Start gathering text information from the {i + 1}th frame")
    text_input = text_input_map if text_input is None else text_input

    image_introduction = text_input["image_introduction"]

    # Set the last frame path as the current frame path
    image_introduction[-1] = {
        "introduction": image_introduction[-1]["introduction"],
        "path": f"{current_frame_path}",
        "assistant": image_introduction[-1]["assistant"]
    }
    text_input["image_introduction"] = image_introduction

    message_prompts = llm_provider.assemble_prompt(template_str=get_text_template, params=text_input)

    logger.debug(f'{logger.UPSTREAM_MASK}{json.dumps(message_prompts, ensure_ascii=False)}\n')

    # The request must be inside the retry loop: re-parsing the same malformed
    # response (as the previous version did) can never succeed and spun forever.
    while True:
        try:
            response, info = llm_provider.create_completion(message_prompts)
            logger.debug(f'{logger.DOWNSTREAM_MASK}{response}\n')

            # Convert the response to dict
            processed_response = parse_semi_formatted_text(response)
            break
        except Exception as e:
            logger.error(f"Response is not in the correct format: {e}, retrying...")

            time.sleep(2)

    if processed_response is None or not response:
        logger.warn('Empty response in gather text information call')
        # One pre-formatted message, like every other logger call in this module.
        logger.debug(f"response: {response}, processed_response: {processed_response}")

    # str() on time_stamp: plain concatenation raised TypeError for numeric stamps.
    objects_index = str(video_prefix) + '_' + str(time_stamp)
    gathered_information_JSON.add_instance(objects_index, processed_response)

    logger.write(f"Finish gathering text information from the {i + 1}th frame")

    return True


async def get_completion_in_parallel(llm_provider, text_input_map, extracted_frame_paths, text_input, get_text_template,
                                     video_prefix, gathered_information_JSON):
    """Describe every frame concurrently, starting one request every 2 seconds.

    Args:
        llm_provider: Forwarded to ``gather_information_get_completion_parallel``.
        text_input_map: Forwarded; fallback prompt parameters.
        extracted_frame_paths: Iterable of ``(frame_path, time_stamp)`` pairs.
        text_input: Forwarded; shared prompt parameters (mutated per frame).
        get_text_template: Forwarded prompt template.
        video_prefix: Forwarded storage-key prefix.
        gathered_information_JSON: Forwarded result collector.

    Returns:
        List of per-frame results, in input order.
    """
    tasks = []
    try:
        for i, (current_frame_path, time_stamp) in enumerate(extracted_frame_paths):
            # create_task starts the request immediately. Previously plain
            # coroutines were only started by gather(), so the blocking
            # time.sleep never spaced out requests and froze the event loop.
            tasks.append(asyncio.create_task(gather_information_get_completion_parallel(
                llm_provider, text_input_map, current_frame_path, time_stamp,
                text_input, get_text_template, i, video_prefix, gathered_information_JSON)))

            # wait 2 seconds (without blocking the event loop) before the next request
            await asyncio.sleep(2)
    except BaseException:
        # Do not leave already-started requests running unobserved.
        for task in tasks:
            task.cancel()
        raise

    return await asyncio.gather(*tasks)


async def get_completion_in_parallel_tool(
        llm_provider,
        text_input_map,
        extracted_frame_paths,
        inventory_names,
        text_input,
        get_text_template,
        video_prefix,
        gathered_information_JSON,
):
    """Describe each inventory frame concurrently, one request every 2 seconds.

    Frame ``i`` is stored under ``"{video_prefix}_{i}"``. Before its prompt is
    assembled, ``inventory_names[i]`` is written to
    ``text_input["image_introduction"][0]["inventory_name"]``.

    Args:
        llm_provider: Forwarded to ``gather_information_get_completion_parallel``.
        text_input_map: Forwarded; fallback prompt parameters.
        extracted_frame_paths: Sequence of frame image paths.
        inventory_names: Inventory name per frame, indexed like ``extracted_frame_paths``.
        text_input: Shared prompt parameters (must not be None; mutated per frame).
        get_text_template: Forwarded prompt template.
        video_prefix: Forwarded storage-key prefix.
        gathered_information_JSON: Forwarded result collector.

    Returns:
        List of per-frame results, in input order.
    """

    async def _gather_frame(index, frame_path, inventory_name):
        # Bind this frame's inventory name right before its prompt is built.
        # Previously all names were written in the loop before any coroutine
        # ran, so every prompt saw the last name. The callee assembles its
        # prompt before its first await, so no other task can interleave.
        text_input["image_introduction"][0]["inventory_name"] = inventory_name
        return await gather_information_get_completion_parallel(
            llm_provider,
            text_input_map,
            frame_path,
            index,
            text_input,
            get_text_template,
            index,
            video_prefix,
            gathered_information_JSON,
        )

    tasks = []
    try:
        for i, current_frame_path in enumerate(extracted_frame_paths):
            tasks.append(asyncio.create_task(_gather_frame(i, current_frame_path, inventory_names[i])))

            # wait 2 seconds (without blocking the event loop) before the next request
            await asyncio.sleep(2)
    except BaseException:
        # Do not leave already-started requests running unobserved.
        for task in tasks:
            task.cancel()
        raise

    return await asyncio.gather(*tasks)


def get_completion_in_sequence(llm_provider, text_input_map, extracted_frame_paths, text_input, get_text_template,
                               video_prefix, gathered_information_JSON):
    """Describe each frame one after another, blocking until each result is stored.

    Args:
        extracted_frame_paths: Iterable of ``(frame_path, time_stamp)`` pairs.
        All other arguments are forwarded unchanged to
        ``gather_information_get_completion_sequence``.

    Returns:
        Always True.
    """
    for frame_index, frame_entry in enumerate(extracted_frame_paths):
        frame_path, frame_time = frame_entry
        gather_information_get_completion_sequence(
            llm_provider,
            text_input_map,
            frame_path,
            frame_time,
            text_input,
            get_text_template,
            frame_index,
            video_prefix,
            gathered_information_JSON,
        )
    return True


class InformationGathering():
    """Collect information for the agent from web search and/or an LLM description.

    Which sources run is selected per call by
    ``input["gather_information_configurations"]``. The class also holds
    helpers for toolbar template matching and unit-label normalisation.
    """

    def __init__(
            self,
            input_map: Dict = None,
            template: str = None,
            object_detector: Any = None,
            search_engine: Any = None,
            llm_provider: Any = None,
            query_input_map: Dict = None,
            get_query_template: str = None,
    ):
        # Default prompt parameters, used when ``__call__`` gets ``input=None``.
        self.input_map = input_map
        # Prompt template for the LLM-description step.
        self.template = template
        # Stored, but not used by any active code path in this class.
        self.object_detector = object_detector
        # Needs ``search`` and ``get_distill_websearch_prompt`` when web search is enabled.
        self.search_engine = search_engine
        self.llm_provider = llm_provider
        self.query_input_map = query_input_map
        # Prompt template used to generate web-search queries.
        self.get_query_template = get_query_template

    def _pre(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Preprocessing hook; returns ``input`` unchanged."""
        return input

    async def __call__(self, *args, input: Dict[str, Any] = None, class_=None, **kwargs) -> Dict[str, Any]:
        """Run the enabled gathering steps and return the result envelope.

        Args:
            input: Prompt parameters. Must contain
                ``gather_information_configurations`` and ``ego_minimap``;
                ``task_description`` is also read when web search is enabled.
                Falls back to ``self.input_map`` when None.
            class_: Unused; kept for interface compatibility.

        Returns:
            dict with ``flag`` (always True), ``success`` (see
            ``_check_success``), ``input`` (after ``_pre``) and ``res_dict``:
            the parsed LLM description (``{}`` if disabled) plus
            ``ego_minimap`` and, when available, the web-search summary under
            ``constants.WEB_SEARCH``.
        """
        # Resolve the fallback before reading from ``input``; previously a
        # None input raised TypeError on the very first line.
        input = self.input_map if input is None else input
        gather_information_configurations = input["gather_information_configurations"]
        input = self._pre(input=input)

        ego_minimap = input["ego_minimap"]

        flag = True
        processed_response = {}
        web_search_gathered_information = None

        # Gather information by web search
        if gather_information_configurations[constants.WEB_SEARCH] is True:
            logger.write("Using web search to gather information")
            web_search_gathered_information = await self._gather_by_web_search(input)

        # Gather information by LLM provider
        if gather_information_configurations["llm_description"] is True:
            logger.write("Using LLM description to gather information")
            processed_response = await self._gather_by_llm_description(input)

        # NOTE: the object-detector source (constants.OBJECT_DETECTOR) is currently disabled.

        processed_response.update({"ego_minimap": ego_minimap})

        if web_search_gathered_information is not None:
            processed_response.update({constants.WEB_SEARCH: web_search_gathered_information})

        success = self._check_success(data=processed_response)

        data = dict(
            flag=flag,
            success=success,
            input=input,
            res_dict=processed_response,
        )

        data = self._post(data=data)

        return data

    async def _gather_by_web_search(self, input):
        """Generate queries with the LLM, search them and return the distilled summary.

        Returns the ``summarization_websearch`` field of the distilled response,
        or None if any step fails (the failure is logged; no retry is made).
        """
        try:
            # Call the LLM provider to generate search queries
            message_prompts = self.llm_provider.assemble_prompt(template_str=self.get_query_template, params=input)

            logger.debug(f'{logger.UPSTREAM_MASK}{json.dumps(message_prompts, ensure_ascii=False)}\n')

            # Await directly: asyncio.run() cannot be called from a running
            # event loop, so the previous code always failed here.
            generated_query_response, _ = await self.llm_provider.create_completion_async(message_prompts)

            generated_queries = parse_semi_formatted_text(generated_query_response)
            generated_queries = generated_queries['recommended_tactics'].split(',')[:config.max_keywords]

            logger.write(f"Generated queries: {generated_queries}")

            search_results = [self.search_engine.search(query) for query in generated_queries]

            logger.write(f"The search results: {search_results}")

            distill_prompt = self.search_engine.get_distill_websearch_prompt(
                question=input["task_description"], query='.\n'.join(generated_queries), results='.\n'.join(search_results))

            response, _ = await self.llm_provider.create_completion_async(distill_prompt)

            web_search_response = parse_semi_formatted_text(response)

            return web_search_response['summarization_websearch']

        except Exception as e:
            logger.error(f"Response of queries is not in the correct format: {e}, retrying...")
            return None

    async def _gather_by_llm_description(self, input):
        """Ask the LLM to describe ``input`` and return the parsed response dict.

        On any failure (including a missing ``region_of_interest`` field) the
        returned dict has ``region_of_interest`` set to ``''``.
        """
        processed_response = {}
        try:
            message_prompts = self.llm_provider.assemble_prompt(template_str=self.template, params=input)
            logger.debug(f'{logger.UPSTREAM_MASK}{json.dumps(message_prompts, ensure_ascii=False)}\n')

            response, _ = await self.llm_provider.create_completion_async(message_prompts)
            logger.debug(f'{logger.DOWNSTREAM_MASK}{response}\n')

            # Convert the response to dict
            processed_response = parse_semi_formatted_text(response)
            if 'region_of_interest' not in processed_response:
                raise KeyError('region_of_interest')

        except Exception as e:
            logger.error(f"Response of region of interest is not in the correct format: {e}, retrying...")
            # The parser may yield a non-dict (e.g. None); fall back to an empty
            # dict so the caller can still attach its fields.
            if not isinstance(processed_response, dict):
                processed_response = {}
            processed_response['region_of_interest'] = ''

        return processed_response

    def map_unit_description(self, phrase: str) -> str:
        """Map a detector phrase such as ``"red marine"`` to ``"enemy marine"``.

        The phrase is lower-cased and the substrings ``round``, ``&`` and
        ``text`` are removed. If at least two words remain, a leading
        ``green``/``blue`` becomes ``ally`` and ``red`` becomes ``enemy``;
        other leading words are kept.
        """
        # Color mappings
        color_maps = {
            "green": "ally",
            "red": "enemy",
            "blue": "ally"
        }

        # Clean up common patterns
        phrase = phrase.lower()
        phrase = re.sub(r'round|&|text', '', phrase).strip()

        # Split into color and unit type
        parts = phrase.split()
        if len(parts) >= 2:
            color = parts[0]
            unit_type = ' '.join(parts[1:])

            # Map color to ally/enemy
            prefix = color_maps.get(color, color)
            return f"{prefix} {unit_type}"

        return phrase

    def _post(self, *args, data: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Postprocessing hook; returns ``data`` unchanged."""
        return data

    def _check_success(self, *args, data, **kwargs):
        """Return True when ``data`` has a non-None, non-empty ``image_description``."""
        success = False

        prop_name = "image_description"

        if prop_name in data.keys():
            desc = data[prop_name]
            success = desc is not None and len(desc) > 0
        return success

    def _replace_icon(self, extracted_frame_paths):
        """Run ``self.icon_replacer`` over the frame images, keeping their time steps.

        NOTE(review): ``icon_replacer`` is not assigned in ``__init__``; it must be
        attached externally before this is called — confirm.
        """
        extracted_frames = [frame[0] for frame in extracted_frame_paths]
        extracted_timesteps = [frame[1] for frame in extracted_frame_paths]
        extracted_frames = self.icon_replacer(image_paths=extracted_frames)
        extracted_frame_paths = list(zip(extracted_frames, extracted_timesteps))
        return extracted_frame_paths

    def gather_information_of_new_icon(self, cur_new_icon_image_shot_path, cur_new_icon_name_image_shot_path):
        """Not implemented; always returns None.

        Planned steps: save a newly seen icon for later template matching, ask
        the LLM whether it is new and what it is called, rename the saved images
        accordingly (or delete them if the LLM response is empty), and return the
        list of icon paths.
        """
        pass

    def template_matching_for_current_toolbar(self, sr_file_list, base_template_file_list, work_template_file_list):
        """Match toolbar slot images against templates and find the selected slot.

        Returns:
            ``(matching_dict, selected_position)``: ``matching_dict`` maps the keys
            returned by ``match_templates_images`` to the matched template's file
            name without extension; ``selected_position`` is the 1-based index of
            the first slot where ``selection_box_identifier`` is truthy, or None.
        """
        matching_dict = match_templates_images(sr_file_list, base_template_file_list, work_template_file_list)
        selected_position = None

        for position, sr_file in enumerate(sr_file_list, start=1):
            if selection_box_identifier(sr_file, config.selection_box_region):
                selected_position = position
                break

        for key in matching_dict:
            matching_dict[key] = os.path.splitext(os.path.basename(matching_dict[key]))[0]

        return matching_dict, selected_position

class SuccessDetection():
    """Ask the LLM whether the current task has succeeded (synchronous call).

    Any failure while prompting or parsing is logged and reported through
    ``flag=False`` in the returned envelope.
    """

    def __init__(self,
                 input_map: Dict = None,
                 template: Dict = None,
                 llm_provider: Any = None,
                 ):
        # Provider used to assemble prompts and run blocking completions.
        self.llm_provider = llm_provider
        # Prompt template for the success-detection request.
        self.template = template
        # Default prompt parameters, used when ``__call__`` gets ``input=None``.
        self.input_map = input_map

    def _pre(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Preprocessing hook; returns ``input`` unchanged."""
        return input

    def __call__(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Run one success-detection request.

        Returns:
            dict with ``flag`` (False if any step raised), ``input`` (after
            ``_pre``) and ``res_dict`` (parsed response, or ``{}`` if parsing
            did not complete).
        """
        if input is None:
            input = self.input_map
        params = self._pre(input=input)

        parsed = {}
        ok = True
        try:
            prompts = self.llm_provider.assemble_prompt(template_str=self.template, params=params)
            logger.debug(f'{logger.UPSTREAM_MASK}{json.dumps(prompts, ensure_ascii=False)}\n')

            reply, _info = self.llm_provider.create_completion(prompts)
            logger.debug(f'{logger.DOWNSTREAM_MASK}{reply}\n')

            parsed = parse_semi_formatted_text(reply)
        except Exception as exc:
            logger.error(f"Error in success detection: {exc}")
            ok = False

        return self._post(data={"flag": ok, "input": params, "res_dict": parsed})

    def _post(self, *args, data: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Postprocessing hook; returns ``data`` unchanged."""
        return data


class SelfReflection():
    """Ask the LLM to reflect on the previous step and return its parsed answer.

    Any failure while prompting or parsing is logged and reported through
    ``flag=False`` in the returned envelope.
    """

    def __init__(self,
                 input_map: Dict = None,
                 template: Dict = None,
                 llm_provider: Any = None,
                 ):
        # Provider used to assemble prompts and run async completions.
        self.llm_provider = llm_provider
        # Prompt template for the reflection request.
        self.template = template
        # Default prompt parameters, used when ``__call__`` gets ``input=None``.
        self.input_map = input_map

    def _pre(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Preprocessing hook; returns ``input`` unchanged."""
        return input

    async def __call__(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Run one self-reflection request.

        Returns:
            dict with ``flag`` (False if any step raised), ``input`` (after
            ``_pre``) and ``res_dict`` (parsed response, or ``{}`` if parsing
            did not complete).
        """
        if input is None:
            input = self.input_map
        params = self._pre(input=input)

        parsed = {}
        ok = True
        try:
            prompts = self.llm_provider.assemble_prompt(template_str=self.template, params=params)
            logger.debug(f'{logger.UPSTREAM_MASK}{json.dumps(prompts, ensure_ascii=False)}\n')

            reply, _info = await self.llm_provider.create_completion_async(prompts)
            logger.debug(f'{logger.DOWNSTREAM_MASK}{reply}\n')

            parsed = parse_semi_formatted_text(reply)
        except Exception as exc:
            logger.error(f"Error in self reflection: {exc}")
            ok = False

        return self._post(data={"flag": ok, "input": params, "res_dict": parsed})

    def _post(self, *args, data: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Postprocessing hook; returns ``data`` unchanged."""
        return data
    
class TaskInference():
    """Ask the LLM to summarise the current task information.

    The parsed response gets an ``info_summary`` field, filled by
    ``combine_response_values`` when the LLM's own summary is missing or short.
    """

    def __init__(self,
                 input_map: Dict = None,
                 template: Dict = None,
                 llm_provider: Any = None,
                 ):
        # Default prompt parameters, used when ``__call__`` gets ``input=None``.
        self.input_map = input_map
        self.template = template
        self.llm_provider = llm_provider

    def _pre(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Preprocessing hook; returns ``input`` unchanged."""
        return input

    async def __call__(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Run one task-inference request.

        Returns:
            dict with ``flag`` (False if any step raised), ``input`` (after
            ``_pre``) and ``res_dict`` (parsed response with ``info_summary``,
            or ``{}`` if parsing did not complete).
        """
        input = self.input_map if input is None else input
        input = self._pre(input=input)

        flag = True
        processed_response = {}

        try:
            # Call the LLM provider for information summary
            message_prompts = self.llm_provider.assemble_prompt(template_str=self.template, params=input)

            logger.debug(f'{logger.UPSTREAM_MASK}{json.dumps(message_prompts, ensure_ascii=False)}\n')

            response, info = await self.llm_provider.create_completion_async(message_prompts)

            logger.debug(f'{logger.DOWNSTREAM_MASK}{response}\n')

            # Convert the response to dict
            processed_response = parse_semi_formatted_text(response)

            processed_response['info_summary'] = self.combine_response_values(processed_response)

        except Exception as e:
            logger.error(f"Error in task inference: {e}")
            flag = False

        data = dict(
            flag=flag,
            input=input,
            res_dict=processed_response,
        )

        data = self._post(data=data)
        return data

    def combine_response_values(self, processed_response):
        """Return an information summary for ``processed_response``.

        If ``info_summary`` is present and at least 20 characters long it is
        returned unchanged. Otherwise every value except ``task_guidance``
        (including a short ``info_summary``) that is truthy and not blank is
        stringified and joined with single spaces.
        """
        summary = processed_response.get('info_summary')
        if summary and len(summary) >= 20:
            return summary

        # str(v) before strip(): parsed responses can hold non-string values
        # (lists, numbers), which previously raised AttributeError here.
        values = [
            str(v) for k, v in processed_response.items()
            if k != 'task_guidance' and v and str(v).strip()
        ]
        return ' '.join(values)

    def _post(self, *args, data: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Postprocessing hook; returns ``data`` unchanged."""
        return data
    
class SkillGeneration():
    """Generate skill code with the LLM and keep only variants that pass ``test_skill_code``.

    The LLM is queried up to ``config.max_retries`` times. When no generated
    skill passes, the collected test errors are fed back to the prompt through
    ``input['error_info']`` on the next attempt.
    """

    def __init__(self,
                 input_map: Dict = None,
                 template: Dict = None,
                 llm_provider: Any = None,
                 ):
        # Default prompt parameters, used when ``__call__`` gets ``input=None``.
        self.input_map = input_map
        self.template = template
        self.llm_provider = llm_provider

    def _pre(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Preprocessing hook; returns ``input`` unchanged."""
        return input

    @staticmethod
    def _compose_docstring(updated_docstring, key_strategies):
        """Extend an LLM-proposed docstring with key strategies and the ``obs`` argument.

        Returns None when ``updated_docstring`` is None (``key_strategies`` is
        then ignored).
        """
        if updated_docstring is None:
            return None
        obs_arg_doc = 'obs (str): Observation string containing game state'
        if key_strategies is not None:
            updated_docstring = f"{updated_docstring}\nKey strategies:\n    {key_strategies}"
        return f"{updated_docstring}\nArgs:\n    {obs_arg_doc}"

    async def __call__(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Generate and validate skills.

        Args:
            input: Prompt parameters; must contain ``action_code``. Its
                ``error_info`` key is overwritten on every attempt.

        Returns:
            dict with ``flag`` (True iff at least one skill passed),
            ``input`` (after ``_pre``) and ``res_dict`` (last parsed response
            with ``constants.SKILL_GENERATION_MODULE`` replaced by the list of
            validated skills).
        """
        input = self.input_map if input is None else input
        input = self._pre(input=input)

        flag = False
        processed_response = {}
        error_info = ""
        valid_skills = []
        for _ in range(config.max_retries):
            try:
                input['error_info'] = error_info
                message_prompts = self.llm_provider.assemble_prompt(template_str=self.template, params=input)

                logger.debug(f'{logger.UPSTREAM_MASK}{json.dumps(message_prompts, ensure_ascii=False)}\n')

                response, _ = await self.llm_provider.create_completion_async(message_prompts)

                logger.debug(f'{logger.DOWNSTREAM_MASK}{response}\n')

                # Convert the response to dict
                processed_response = parse_semi_formatted_text(response)
                replaced_name = processed_response.get('replaced_name')
                updated_docstring = self._compose_docstring(
                    processed_response.get('updated_docstring'),
                    processed_response.get('key_strategies'),
                )

                if constants.SKILL_GENERATION_MODULE in processed_response:
                    all_generated_actions = processed_response[constants.SKILL_GENERATION_MODULE]
                    if not all_generated_actions or all_generated_actions[0]['code'] == '':
                        break
                    error_info = ""
                    for extracted_skills in all_generated_actions:
                        if extracted_skills['code'] == '':
                            continue
                        passed, test_info, skill_code = test_skill_code(
                            skill_code=extracted_skills['code'],
                            original_code=input['action_code'],
                            replaced_name=replaced_name,
                            updated_docstring=updated_docstring,
                        )
                        if passed:
                            extracted_skills['code'] = skill_code
                            valid_skills.append(extracted_skills)
                        else:
                            error_info += test_info + "\n"
                    if not valid_skills:
                        raise ValueError(error_info)
                    # Success means at least one skill passed; previously the flag
                    # reflected only the last tested skill.
                    flag = True
                break

            except Exception as e:
                logger.error(f"Response of skill generation is not in the correct format: {e}, retrying...")
                valid_skills = []
                flag = False

                # Non-blocking back-off; time.sleep would stall the event loop.
                await asyncio.sleep(2)

        # The parser may return None; keep the envelope a dict.
        if processed_response is None:
            processed_response = {}
        processed_response[constants.SKILL_GENERATION_MODULE] = valid_skills

        data = dict(
            flag=flag,
            input=input,
            res_dict=processed_response,
        )

        data = self._post(data=data)
        return data

    def _post(self, *args, data: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Postprocessing hook; returns ``data`` unchanged."""
        return data
    
class SkillRefine():
    """Refine existing skill code with the LLM, keeping only variants that pass ``test_skill_code``.

    The LLM is queried up to ``config.max_retries`` times. When no refined
    skill passes, the collected test errors are fed back to the prompt through
    ``input['error_info']`` on the next attempt.
    """

    def __init__(self,
                 input_map: Dict = None,
                 template: Dict = None,
                 llm_provider: Any = None,
                 ):
        # Default prompt parameters, used when ``__call__`` gets ``input=None``.
        self.input_map = input_map
        self.template = template
        self.llm_provider = llm_provider

    def _pre(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Preprocessing hook; returns ``input`` unchanged."""
        return input

    async def __call__(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Refine and validate skills.

        Args:
            input: Prompt parameters; must contain ``action_code``. Its
                ``error_info`` key is overwritten on every attempt.

        Returns:
            dict with ``flag`` (True iff at least one skill passed),
            ``input`` (after ``_pre``) and ``res_dict`` (last parsed response
            with ``constants.SKILL_GENERATION_MODULE`` replaced by the list of
            validated skills).
        """
        input = self.input_map if input is None else input
        input = self._pre(input=input)

        flag = False
        processed_response = {}
        error_info = ""
        valid_skills = []
        for _ in range(config.max_retries):
            try:
                input['error_info'] = error_info
                message_prompts = self.llm_provider.assemble_prompt(template_str=self.template, params=input)

                logger.debug(f'{logger.UPSTREAM_MASK}{json.dumps(message_prompts, ensure_ascii=False)}\n')

                response, _ = await self.llm_provider.create_completion_async(message_prompts)

                logger.debug(f'{logger.DOWNSTREAM_MASK}{response}\n')

                # Convert the response to dict
                processed_response = parse_semi_formatted_text(response)
                if constants.SKILL_GENERATION_MODULE in processed_response:
                    all_generated_actions = processed_response[constants.SKILL_GENERATION_MODULE]
                    # Treat an empty list like an empty first skill (as SkillGeneration
                    # does) instead of raising IndexError and retrying.
                    if not all_generated_actions or all_generated_actions[0]['code'] == '':
                        break
                    error_info = ""
                    for extracted_skills in all_generated_actions:
                        if extracted_skills['code'] == '':
                            continue
                        passed, test_info, skill_code = test_skill_code(
                            skill_code=extracted_skills['code'],
                            original_code=input['action_code'],
                        )
                        if passed:
                            extracted_skills['code'] = skill_code
                            valid_skills.append(extracted_skills)
                        else:
                            error_info += test_info + "\n"
                    if not valid_skills:
                        raise ValueError(error_info)
                    # Success means at least one skill passed; previously the flag
                    # reflected only the last tested skill.
                    flag = True
                break

            except Exception as e:
                logger.error(f"Response of skill refine is not in the correct format: {e}, retrying...")
                # Drop skills from the failed attempt so they are not reported
                # alongside (or instead of) the next attempt's results.
                valid_skills = []
                flag = False
                # Non-blocking back-off; time.sleep would stall the event loop.
                await asyncio.sleep(2)

        # The parser may return None; keep the envelope a dict.
        if processed_response is None:
            processed_response = {}
        processed_response[constants.SKILL_GENERATION_MODULE] = valid_skills

        data = dict(
            flag=flag,
            input=input,
            res_dict=processed_response,
        )

        data = self._post(data=data)
        return data

    def _post(self, *args, data: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Postprocessing hook; returns ``data`` unchanged."""
        return data
    
class ActionPlanning():
    """Ask the LLM which skills to execute next and dry-run them before accepting the plan.

    The prompt is assembled from ``template`` and the input dict, the response is
    parsed with ``parse_semi_formatted_text`` and its ``'skills'`` entries are
    checked. Failures are retried up to ``config.max_retries`` times, feeding the
    previous error back to the LLM through ``input['error_info']``.
    """

    # Skill expression substituted into the response when a planning attempt fails.
    DEFAULT_SKILL = "race_melee_ranged_medivac_navi_A_star_score_type_default_center(obs='current')"

    def __init__(self,
                 input_map: Dict = None,
                 template: Dict = None,
                 llm_provider: Any = None,
                 ):
        """
        Args:
            input_map: default prompt parameters used when ``__call__`` gets no ``input``.
            template: prompt template handed to ``llm_provider.assemble_prompt``.
            llm_provider: object providing ``assemble_prompt`` and the awaitable
                ``create_completion_async``.
        """
        self.input_map = input_map
        self.template = template
        self.llm_provider = llm_provider

    def _pre(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Pre-processing hook; the default returns ``input`` unchanged."""
        return input

    @staticmethod
    def _build_observation(input: Dict[str, Any]) -> str:
        """Return the observation text extended with the ego minimap (7.) and region of interest (8.)."""
        obs = input.get("observation", "")

        # 7. ego_minimap (blank lines collapsed to keep the text compact)
        ego_minimap = input.get("ego_minimap", "")
        ego_minimap = ego_minimap.replace("\n\n", "\n")
        obs += "\n7. Ego Minimap:\n"
        obs += f"{ego_minimap}\n"

        # 8. region_of_interest
        region_of_interest = input.get("region_of_interest", "")
        obs += "8. Region of Interest:\n"
        obs += f"{region_of_interest}\n"
        return obs

    @staticmethod
    def _validate_skills(skill_steps: List[str], obs: str, skill_registry: Any = None) -> None:
        """Dry-run every planned skill so a bad plan raises and triggers a retry.

        Every expression is parsed with ``convert_expression_to_skill``; when a
        registry is given, the matching ``skill_function`` is also invoked with
        ``obs``. A failed lookup in ``skill_registry.skills`` (reported to the LLM
        as an undefined skill when it is a ``KeyError``) and any exception raised
        by the skill itself propagate to the caller.

        Raises:
            ValueError: if a skill expression is the empty string.
        """
        for skill in skill_steps:
            if skill == '':
                raise ValueError("Empty skill")
            skill_name, _ = convert_expression_to_skill(skill)
            if skill_registry is not None:
                skill_function = skill_registry.skills[skill_name].skill_function
                skill_function(obs=obs)

    async def __call__(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Plan the next skills, retrying up to ``config.max_retries`` times.

        Args:
            input: prompt parameters; defaults to ``self.input_map``. Its
                ``'error_info'`` key is overwritten on every attempt.
            skill_registry (keyword): optional registry used to dry-run the skills.

        Returns:
            dict with ``flag`` (True only if an attempt succeeded; False when all
            attempts failed or none were made), ``input``, and ``res_dict`` (the
            parsed response). After a failed attempt, ``res_dict['skills']`` is
            replaced by ``[DEFAULT_SKILL]``.
        """
        input = self.input_map if input is None else input
        input = self._pre(input=input)

        processed_response = {}
        skill_registry = kwargs.get('skill_registry')

        action_planning_success_flag = False
        error_info = ""
        for attempt in range(config.max_retries):
            try:
                # Feed the previous attempt's error back to the LLM.
                input['error_info'] = error_info
                message_prompts = self.llm_provider.assemble_prompt(template_str=self.template, params=input)

                logger.debug(f'{logger.UPSTREAM_MASK}{json.dumps(message_prompts, ensure_ascii=False)}\n')

                # Call the LLM provider for decision making
                response, info = await self.llm_provider.create_completion_async(message_prompts)

                logger.debug(f'{logger.DOWNSTREAM_MASK}{response}\n')

                if response is None or len(response) == 0:
                    logger.warn('No response in decision making call')
                    logger.debug(input)

                # Convert the response to dict
                processed_response = parse_semi_formatted_text(response)

                skill_steps = [step for step in processed_response.get('skills', []) if step != '']
                skill_steps = skill_steps[:config.number_of_execute_skills]

                obs = self._build_observation(input)
                self._validate_skills(skill_steps, obs, skill_registry)

                action_planning_success_flag = True
                break

            except Exception as e:
                logger.error(f"Response of action planning is not in the correct format: {type(e).__name__}: {str(e)}, retrying...")
                error_info = f"Undefined Skill: {str(e)}" if isinstance(e, KeyError) else f"{type(e).__name__}: {str(e)}"
                action_planning_success_flag = False
                processed_response['skills'] = [self.DEFAULT_SKILL]
                # Non-blocking back-off: time.sleep would stall the whole event loop.
                # No point waiting after the final attempt.
                if attempt < config.max_retries - 1:
                    await asyncio.sleep(2)

        data = dict(
            flag=action_planning_success_flag,
            input=input,
            res_dict=processed_response,
        )

        data = self._post(data=data)
        return data

    def _post(self, *args, data: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Post-processing hook; the default returns ``data`` unchanged."""
        return data
    
class SMACv2Planner(BasePlanner):
    """SMACv2 planner that wires each reasoning module to its prompt input and template.

    ``set_internal_params`` (re)builds the module objects (``information_gathering_``,
    ``action_planning_``, ...); the public async methods delegate to them, falling
    back to the module's default input map when no ``input`` is supplied.
    """

    def __init__(self,
                 llm_provider: Any = None,
                 planner_params: Dict = None,
                 use_task_inference: bool = False,
                 use_self_reflection: bool = False,
                 information_gathering_max_steps: int = 1,  # 5,
                 object_detector: Any = None,
                 search_engine: Any = None,
                 ):
        """
        Args:
            llm_provider: LLM client shared by all modules.
            planner_params: must pass ``check_planner_params``; ``prompt_paths``
                lists the input and template files per module.
            use_task_inference: also build a TaskInference module.
            use_self_reflection: also build a SelfReflection module.
            information_gathering_max_steps: maximum information-gathering attempts
                (at least one attempt is always made).
            object_detector, search_engine: passed through to InformationGathering.
        """

        # NOTE(review): this skips BasePlanner.__init__ and calls the next class in
        # the MRO instead; confirm that is intentional.
        super(BasePlanner, self).__init__()

        self.llm_provider = llm_provider

        self.use_task_inference = use_task_inference
        self.use_self_reflection = use_self_reflection
        self.information_gathering_max_steps = information_gathering_max_steps

        self.object_detector = object_detector
        self.search_engine = search_engine
        self.set_internal_params(planner_params=planner_params,
                                 use_task_inference=use_task_inference)

    # Allow re-configuring planner
    def set_internal_params(self,
                            planner_params: Dict = None,
                            use_task_inference: bool = False):
        """Validate ``planner_params``, reload prompts and rebuild every module.

        Raises:
            ValueError: if ``check_planner_params`` rejects ``planner_params``.
        """

        self.planner_params = planner_params
        if not check_planner_params(self.planner_params):
            raise ValueError(f"Error in planner_params: {self.planner_params}")

        self.inputs = self._init_inputs()
        self.templates = self._init_templates()

        self.information_gathering_ = InformationGathering(input_map=self.inputs[constants.INFORMATION_GATHERING_MODULE],
                                                           template=self.templates[constants.INFORMATION_GATHERING_MODULE],
                                                           query_input_map=self.inputs[constants.INFORMATION_QUERY_GATHERING_MODULE],
                                                           get_query_template=self.templates[constants.INFORMATION_QUERY_GATHERING_MODULE],
                                                           object_detector=self.object_detector,
                                                           search_engine=self.search_engine,
                                                           llm_provider=self.llm_provider)

        self.action_planning_ = ActionPlanning(input_map=self.inputs[constants.ACTION_PLANNING_MODULE],
                                               template=self.templates[constants.ACTION_PLANNING_MODULE],
                                               llm_provider=self.llm_provider)

        self.success_detection_ = SuccessDetection(input_map=self.inputs["success_detection"],
                                                   template=self.templates["success_detection"],
                                                   llm_provider=self.llm_provider)

        self.skill_generation_ = SkillGeneration(input_map=self.inputs[constants.SKILL_GENERATION_MODULE],
                                                 template=self.templates[constants.SKILL_GENERATION_MODULE],
                                                 llm_provider=self.llm_provider)

        self.skill_refine_ = SkillRefine(input_map=self.inputs[constants.SKILL_REFINE_MODULE],
                                         template=self.templates[constants.SKILL_REFINE_MODULE],
                                         llm_provider=self.llm_provider)

        # Optional modules stay None when disabled; calling their wrappers then fails.
        if self.use_self_reflection:
            self.self_reflection_ = SelfReflection(input_map=self.inputs[constants.SELF_REFLECTION_MODULE],
                                                   template=self.templates[constants.SELF_REFLECTION_MODULE],
                                                   llm_provider=self.llm_provider)
        else:
            self.self_reflection_ = None

        if use_task_inference:
            self.task_inference_ = TaskInference(input_map=self.inputs[constants.TASK_INFERENCE_MODULE],
                                                 template=self.templates[constants.TASK_INFERENCE_MODULE],
                                                 llm_provider=self.llm_provider)
        else:
            self.task_inference_ = None

    def _load_prompt_resources(self, section: str) -> Dict[str, Any]:
        """Load every file listed under ``planner_params['prompt_paths'][section]``.

        ``.prompt`` files are read as text, ``.json`` files are parsed, and any
        other extension (or an unresolved ``None`` path) maps to an empty dict.
        """
        resources = dict()
        resource_paths = self.planner_params["prompt_paths"][section]

        for key, value in resource_paths.items():
            path = assemble_project_path(value)
            # Check for None before calling string methods on the path.
            if path is not None and path.endswith(PROMPT_EXT):
                resources[key] = read_resource_file(path)
            elif path is not None and path.endswith(JSON_EXT):
                resources[key] = load_json(path)
            else:
                resources[key] = dict()

        return resources

    def _init_inputs(self):
        """Load the per-module input examples (``prompt_paths['inputs']``)."""
        return self._load_prompt_resources("inputs")

    def _init_templates(self):
        """Load the per-module prompt templates (``prompt_paths['templates']``)."""
        return self._load_prompt_resources("templates")

    async def information_gathering(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Run information gathering until it reports success or the step budget is spent."""

        if input is None:
            input = self.inputs[constants.INFORMATION_GATHERING_MODULE]

        # Always make at least one attempt so `data` is bound even for a non-positive budget.
        for _ in range(max(1, self.information_gathering_max_steps)):
            data = await self.information_gathering_(input=input, class_=None)

            if data["success"]:
                break

        return data

    async def action_planning(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Delegate to the ActionPlanning module, forwarding the optional ``skill_registry`` kwarg."""

        if input is None:
            input = self.inputs[constants.ACTION_PLANNING_MODULE]

        data = await self.action_planning_(input=input, skill_registry=kwargs.get('skill_registry', None))

        return data

    async def success_detection(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Delegate to the SuccessDetection module."""

        if input is None:
            input = self.inputs["success_detection"]

        data = await self.success_detection_(input=input)

        return data

    async def self_reflection(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Delegate to SelfReflection; requires ``use_self_reflection=True`` at construction."""

        if input is None:
            input = self.inputs[constants.SELF_REFLECTION_MODULE]

        data = await self.self_reflection_(input=input)

        return data

    async def task_inference(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Delegate to TaskInference; requires ``use_task_inference=True`` when configured."""

        if input is None:
            input = self.inputs[constants.TASK_INFERENCE_MODULE]

        data = await self.task_inference_(input=input)

        return data

    async def skill_generation(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Delegate to the SkillGeneration module."""

        if input is None:
            input = self.inputs[constants.SKILL_GENERATION_MODULE]

        data = await self.skill_generation_(input=input)

        return data

    async def skill_refine(self, *args, input: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
        """Delegate to the SkillRefine module, defaulting to the skill-refine input map."""

        if input is None:
            # Must match the input map skill_refine_ was constructed with.
            input = self.inputs[constants.SKILL_REFINE_MODULE]

        data = await self.skill_refine_(input=input)

        return data
