from abc import ABC
from abc import abstractmethod
from utils import log, api_util
from parallel.runner import ActionQueue
import json
import copy


class Action(ABC):
    """
        Abstract base for every action.

        Stores the profile it was created with and a logger obtained from
        ``log.get_loguru()``. Concrete subclasses must implement ``run``.
    """

    def __init__(self, profile) -> None:
        self.profile = profile
        self.logger = log.get_loguru()

    @abstractmethod
    def run(self, *args, **kwargs):
        """Execute the action; must be overridden by concrete subclasses."""
        return NotImplemented


class CachedToolAction(Action):
    """
        Tools action with caching capabilities.

        At construction the action registers its db table with the handler via
        ``agent_handler.build_action_db(self.name, self.fields)``.
        NOTE(review): ``name`` and ``fields`` are not defined in this class;
        presumably subclasses provide them (e.g. as class attributes) - confirm.
    """

    def __init__(self, profile, agent_handler):
        super().__init__(profile)

        self.handler = agent_handler
        self.handler.build_action_db(self.name, self.fields)

    def run(self, state, message):
        """Placeholder that does nothing and returns None; subclasses override it."""
        pass

    @staticmethod
    def _build_action_param(session_id, traj_id):
        """
            Build the lookup key shared by load and save.

            Both directions must use the same key, otherwise saved entries can
            never be found again. "step" is always the literal "fixed_step",
            so entries are effectively keyed by session and trajectory.
        """
        return {
            "sess_id": session_id,
            "step": "fixed_step",
            "traj_id": traj_id,
        }

    def load_serialized_obj(self, state, traj_id):
        '''
            Query the cached object based on the current execution state.

            Each action corresponds to a db table, and each table contains
            fields with the same names.

            Args:
                state: execution state; must provide ``get_session_id()`` and
                    ``loads(traj_id, serial)``.
                traj_id: trajectory identifier used in the lookup key.

            Returns:
                True if a cached entry was found and loaded into ``state``,
                False otherwise (``state`` is left untouched).
        '''
        session_id = state.get_session_id()
        action_param = self._build_action_param(session_id, traj_id)
        # NOTE(review): self.name is passed twice - presumably table name and
        # action name; verify against the handler's signature.
        serial = self.handler.load_action_feedback(session_id, self.name, self.name, action_param)
        if serial is None:
            return False
        state.loads(traj_id, serial)
        return True

    def save_serialized_obj(self, state, traj_id):
        '''
            Store the cached object based on the current execution state.

            Args:
                state: execution state; must provide ``get_session_id()`` and
                    ``dumps(traj_id)``.
                traj_id: trajectory identifier used in the storage key.
        '''
        session_id = state.get_session_id()
        action_param = self._build_action_param(session_id, traj_id)

        serial = state.dumps(traj_id)
        self.handler.save_action_feedback(session_id, self.name, self.name, serial, action_param)


class BeamCachedToolAction(CachedToolAction):
    """
        Cached tool action that also keeps the profile's core configuration
        (the result of ``profile.load_core()``) on the instance.
    """

    def __init__(self, profile, agent_handler):
        super().__init__(profile, agent_handler)
        self.core_profile = self.profile.load_core()

    @property
    def core_proflile(self):
        """Misspelled legacy alias of ``core_profile``, kept for existing callers."""
        return self.core_profile

    @core_proflile.setter
    def core_proflile(self, value):
        # Keep writes through the legacy name in sync with the real attribute.
        self.core_profile = value

    def run(self, state, message, json_format=False, recover=False):
        """Placeholder that does nothing and returns None; subclasses override it."""
        pass

class APIAction(BeamCachedToolAction):
    """
        Builds system messages that describe the available APIs (tools).

        1. Load the formatted APIs and find the tools from it.
        2. Load the system message with specific format.
    """

    def __init__(self, profile, agent_handler):
        super().__init__(profile, agent_handler)

        # The policy's "aka" selects which prompt file to load.
        self.policy_aka = self.profile.load_policy()["aka"]
        prompt_path = self.profile.load_prompt()[self.policy_aka]
        # Context manager closes the handle deterministically; JSON is UTF-8.
        with open(prompt_path, 'r', encoding="utf-8") as prompt_file:
            self.prompt = json.load(prompt_file)

        core_profile = self.profile.load_core()
        self.teacher_forcing_list = core_profile["teacher_forcing"]
        self.reward = core_profile["reward"]

        self.action_runner = ActionQueue()

    def build_sys_msg(self, message):
        """
            Render the common system prompt listing every API in
            ``message["api_list"]``.

            The template is ``self.prompt["system"]["common"]`` formatted with
            ``template=<list of API dicts>``. All backslashes are then removed
            from the result (NOTE(review): presumably to drop escapes from the
            rendered list repr - confirm).
        """
        template = self._load_template(message["api_list"])
        sys_msg = self.prompt["system"]["common"].format(template=template)
        return sys_msg.replace("\\", "")

    @staticmethod
    def _api_entry(api):
        """Return the name/description/parameters dict describing one API."""
        # API names are used verbatim; the former api_util.standardize /
        # change_name normalization is intentionally disabled.
        api_info = api["api_info"]
        return {
            "name": api_info["api_name"],
            "description": api_info["api_description"],
            "parameters": api_info["api_param"],
        }

    def _load_template(self, api_list):
        """
            Describe every API in ``api_list``, tagging each with its
            position as "ID" (the first key, matching the original layout).
        """
        return [
            {"ID": tool_idx, **self._api_entry(api)}
            for tool_idx, api in enumerate(api_list)
        ]

    def build_parameter_sys_msg(self, message, api_idx):
        """
            Render the parameter system prompt for the single API at
            ``message["api_list"][api_idx]`` using
            ``self.prompt["system"]["parameter"]``.
        """
        template = self._load_api_template(message["api_list"], api_idx)
        sys_msg = self.prompt["system"]["parameter"].format(template=template)
        return sys_msg

    def _load_api_template(self, api_list, api_idx):
        """Describe the API at ``api_list[api_idx]`` (no "ID" key)."""
        return self._api_entry(api_list[api_idx])
