from .action import APIAction
from utils import api_util
import json
import copy
import uuid


class ChainToolExecution(APIAction):
    """Execute the API chosen for a trajectory node and record its observation.

    ``run`` reads the API index and parameters stored on the neighbouring
    API-selection trajectory (via ``state.retrieve_api(..., neighbor=True)``),
    calls ``api_env`` with them and writes the resulting observation, a
    ``reward`` and ``beam`` fresh child trajectory ids into ``state``. Results
    are persisted with ``save_serialized_obj`` and reused on a later run when
    ``load_serialized_obj`` finds them.

    When ``"execution"`` is in ``teacher_forcing_list`` the observation passed
    forward comes from the oracle API/parameters rather than the policy's own
    choice; the policy's call is still executed and stored alongside it.
    """

    name: str = "tool_execution"
    fields: list = ["sess_id", "step", "traj_id", name]

    # Maximum number of characters of the chosen observation stored under
    # "observation" (the full oracle/free observations are stored untruncated).
    max_observation_len: int = 4096
    # Number of characters of the free observation shown in the info log.
    log_preview_len: int = 124

    def __init__(self, profile, policy, api_env, handler):
        super(ChainToolExecution, self).__init__(profile, handler)
        self.policy = policy
        # Callable(category_name, tool_name, api_name, parameter, fmt=True)
        # returning a (status, result) pair.
        self.api_env = api_env
        # Number of child trajectories spawned from each executed node.
        self.beam = 1
        # teacher_forcing_list is provided by APIAction -- presumably derived
        # from `profile`; TODO confirm.
        self.teacher_forcing = "execution" in self.teacher_forcing_list

    @staticmethod
    def _clean_response(response):
        """Stringify a tool response and drop backslashes and blank-line pairs."""
        return str(response).replace("\\", "").replace("\n\n", "")

    def _execute_stored_api(self, state, traj_id, message, api_key, parameter_key):
        """Call ``api_env`` with an API/parameter pair stored in trajectory memory.

        Args:
            state: Search state providing ``retrieve_api``.
            traj_id: Id of the current trajectory node.
            message: Task message whose ``"api_list"`` is indexed by the stored
                API ``"ID"``.
            api_key: Memory key holding the selected API (must have ``"ID"``).
            parameter_key: Memory key holding the call parameters.

        Returns:
            The cleaned ``"response"`` of the tool result, as a ``str``.

        Raises:
            AssertionError: if ``api_env`` does not return a dict as its result.
        """
        memory = state.retrieve_api(traj_id, neighbor=True).get_memory()
        api = message["api_list"][memory[api_key]["ID"]]

        _status, tool_result = self.api_env(
            api["category_name"],
            api["tool_name"],
            api["api_info"]["api_name"],
            memory[parameter_key],
            fmt=True,
        )
        assert isinstance(tool_result, dict), (
            "api_env returned {}, expected dict".format(type(tool_result).__name__)
        )
        return self._clean_response(tool_result["response"])

    def observation_gen(self, state, traj_id, message):
        """Execute the policy-selected API and return its cleaned observation."""
        return self._execute_stored_api(state, traj_id, message, "api", "parameter")

    def get_oracle_observation(self, state, traj_id, message):
        """Execute the oracle API/parameters and return the cleaned observation.

        Uses the same ``str()`` conversion as ``observation_gen`` so that a
        non-string oracle response no longer fails with ``AttributeError``.
        """
        return self._execute_stored_api(
            state, traj_id, message, "oracle_api", "oracle_parameter"
        )

    def run(self, state, message, traj_id, recover=False):
        """Execute the tool for ``traj_id`` (or reuse a cached result) and branch.

        Args:
            state: Search state holding trajectories.
            message: Task message containing ``"api_list"``.
            traj_id: Id of the trajectory node to execute.
            recover: When True and no serialized result exists, do not call
                any API; return ``"leaf_node"`` instead.

        Returns:
            ``"leaf_node"`` in the recover-without-cache case, otherwise None.
        """
        if not self.load_serialized_obj(state, traj_id):
            if recover:
                return "leaf_node"

            # The oracle call (if any) is issued before the policy's call,
            # preserving the original order of side effects on api_env.
            if self.teacher_forcing:
                oracle_observation = self.get_oracle_observation(state, traj_id, message)
            else:
                oracle_observation = ""

            # The policy's own call runs even under teacher forcing so both
            # observations are recorded.
            free_observation = self.observation_gen(state, traj_id, message)
            self.logger.info("free observation: {}".format(
                free_observation[:self.log_preview_len]
            ))

            chosen = oracle_observation if self.teacher_forcing else free_observation
            state.update(traj_id, {
                "observation": chosen[:self.max_observation_len],
                "oracle_observation": oracle_observation,
                "free_observation": free_observation,
                "reward": 1,
                "child_id": [str(uuid.uuid4()) for _ in range(self.beam)],
            })
            self.save_serialized_obj(state, traj_id)

        traj = state.get(traj_id)
        traj_memory = traj.get_memory()
        child_ids = traj_memory["child_id"]

        for idx in range(self.beam):
            state.next_traj(traj, child_ids[idx])

        self.logger.info("[tool] step-{} {}: observation={}.".format(
            str(state.step), self.name, json.dumps(traj_memory["observation"])
        ))


class TreeToolExecution(ChainToolExecution):
    """Tool-execution action that branches every executed node into two children.

    Behaves exactly like ``ChainToolExecution``; the only difference is that
    ``beam`` is 2, so ``run`` generates and advances two child trajectory
    ids per node instead of one.
    """

    name: str = "tool_execution"
    fields: list = ["sess_id", "step", "traj_id", name]

    def __init__(self, profile, policy, api_env, handler):
        super().__init__(profile, policy, api_env, handler)
        # Override the parent's beam of 1 to expand the search into a tree.
        self.beam = 2