from .agent import ToolAgent
from cache.db import AgentAlchemyDatabase
from cache.handler import AgentHandler
from .registry import auto_register
from env.rapid_env import RapidEnv
from env.leetcode_env import LeetcodeEnv
from env.pseudo_env import PseudoEnv
from core.registry import CoreRegistry
import json


@auto_register("api_agent")
class APIAgent(ToolAgent):
    """Tool agent that answers user queries by calling APIs from a configured environment.

    The API environment is selected from the env profile's ``platform`` field,
    the reasoning core is created by name through ``CoreRegistry``, and the
    trajectories/conversations collected by the core are written out as
    JSON Lines files in :meth:`finalize`.
    """

    # Maps the env profile's ``platform`` value to the environment class to build.
    _ENV_CLASSES = {
        "rapidapi": RapidEnv,
        "leetcode": LeetcodeEnv,
        "pseudo": PseudoEnv,
    }

    def __init__(self, profile) -> None:
        """Build the API environment, cache handler, core, and output file paths.

        Args:
            profile: Agent profile. Its ``load_env``, ``load_core``,
                ``load_policy``, ``load_cache`` and ``load_final`` methods are
                called, and ``split``/``prefix`` are read from ``profile._cur``.

        Raises:
            ValueError: if the env profile's ``platform`` is not one of
                ``_ENV_CLASSES``.
        """
        super().__init__(profile)

        self.env_profile = self.profile.load_env()
        platform = self.env_profile['platform']
        env_cls = self._ENV_CLASSES.get(platform)
        if env_cls is None:
            # ValueError is a subclass of Exception, so existing
            # `except Exception` handlers still catch this.
            raise ValueError("Not Support Env: {}".format(platform))
        self.api_env = env_cls(profile)

        self.core_profile = self.profile.load_core()
        self.core_name = self.core_profile["name"]
        self.core_memory = self.core_profile["memory"]

        self.policy_profile = self.profile.load_policy()
        # Placeholder values shared by the cache-db and output-path templates.
        path_vars = {
            "policy_aka": self.policy_profile['aka'],
            "core_name": self.core_name,
            "core_memory": self.core_memory,
            "split": self.profile._cur["split"],
            "prefix": self.profile._cur["prefix"],
        }

        db = self.profile.load_cache()["db"].format(**path_vars)
        self.handler = AgentHandler(AgentAlchemyDatabase(db))
        # NOTE(review): `self.policy` is not assigned in this class; presumably
        # set by ToolAgent.__init__ — confirm.
        self.core = CoreRegistry.create_instance(
            self.core_name,
            self.profile,
            self.policy,
            self.api_env,
            self.handler,
        )

        self.final_profile = self.profile.load_final()
        self.traj_file = self.final_profile["traj"].format(**path_vars)
        self.conv_file = self.final_profile["conv"].format(**path_vars)

    def run(self, message):
        """Process one user query with the core.

        Args:
            message: mapping with 'session_id', 'query' and 'api_list'
                (the session ID, the user query, and the list of available
                APIs, respectively).

        Raises:
            ValueError: if any of the required keys is missing. An explicit
                check is used instead of ``assert`` so validation still
                happens under ``python -O``.
        """
        missing = [key for key in ("session_id", "query", "api_list") if key not in message]
        if missing:
            raise ValueError("message is missing required keys: {}".format(missing))
        self.logger.info("[user]: {} [session]: {}".format(message["query"], message['session_id']))

        state = self.core.pre_process(message)
        # Each core.run call executes all actions of the current step; keep
        # stepping while it returns a truthy value.
        while self.core.run(state, message):
            pass

        # Let the core save the result for this query.
        self.core.post_process(state, message)
        self.logger.info("user query processing completed !!!")

    def finalize(self):
        """Log the processed count and write collected traj/conv records as JSON Lines."""
        self.logger.info("processed number: {}".format(self.core.processed.value))

        traj_count = self._dump_jsonl(self.traj_file, self.core.traj)
        self.logger.info("save traj file: {} size: {}".format(self.traj_file, traj_count))

        conv_count = self._dump_jsonl(self.conv_file, self.core.conv)
        self.logger.info("save conv file: {} size: {}".format(self.conv_file, conv_count))

    @staticmethod
    def _dump_jsonl(path, records):
        """Write each non-None record as one JSON line to *path* (overwriting it).

        Returns:
            The number of records written.
        """
        count = 0
        with open(path, 'w', encoding='utf-8') as f:
            for record in records:
                if record is None:
                    continue
                # write(), not writelines(): writelines on a str writes it
                # one character at a time.
                f.write(json.dumps(record) + "\n")
                count += 1
        return count
        