import itertools
import json

from .profile import Profile
from .registry import auto_register


class ToolAgentProfile(Profile):
    """Profile for tool-using agents.

    Adds accessors for the ``"policy"`` and ``"cache"`` sections of
    ``self.config`` (presumably populated by ``Profile.__init__`` from
    ``args`` -- confirm against ``Profile``).
    """

    def __init__(self, args) -> None:
        # No state of its own; construction is delegated entirely to Profile.
        super(ToolAgentProfile, self).__init__(args)

    def load_policy(self):
        """Return the ``"policy"`` entry of the profile config."""
        policy_cfg = self.config["policy"]
        return policy_cfg

    def load_cache(self):
        """Return the ``"cache"`` entry of the profile config."""
        cache_cfg = self.config["cache"]
        return cache_cfg


@auto_register("api_agent")
class APIAgentProfile(ToolAgentProfile):
    """Profile registered under the name ``"api_agent"``.

    Besides the inherited policy/cache accessors, it exposes the ``"env"``,
    ``"core"``, ``"prompt"`` and ``"final"`` config sections. It also loads
    JSON-lines data from the file selected by ``config["current_file"]``.
    """

    def __init__(self, args) -> None:
        super().__init__(args)

        # Selector for the active data file; load_data reads its "split" and
        # "prefix" keys.
        self._cur = self.config["current_file"]

    def load_data(self, lazy_load, max_instance):
        """Load records from the data file chosen by ``current_file``.

        The path is ``config["train_file"][prefix]`` for split ``"train"``
        and ``config["eval_file"][prefix]`` for split ``"eval"``. Any other
        split value falls back to ``config["test_file"][prefix]``.

        Args:
            lazy_load: If truthy, return a generator that reads the file on
                demand; otherwise return a fully materialized list.
            max_instance: Maximum number of records (lines) to read. ``-1``
                or ``None`` means no limit.

        Returns:
            A generator or a list of parsed JSON values, one per line.
        """
        split = self._cur["split"]
        prefix = self._cur["prefix"]

        if split == "train":
            file_key = "train_file"
        elif split == "eval":
            file_key = "eval_file"
        else:
            # Every split that is not train/eval is served from the test files.
            file_key = "test_file"
        cur_file = self.config[file_key][prefix]

        if lazy_load:
            return self._load_data_generator(cur_file, max_instance)
        return self._load_data_list(cur_file, max_instance)

    def _load_data_list(self, cur_file, max_instance):
        """Eagerly read up to ``max_instance`` records into a list.

        Delegates to ``_load_data_generator`` so both loading modes share the
        same limit semantics (``-1``/``None`` = no limit). Full consumption
        closes the file before returning.
        """
        return list(self._load_data_generator(cur_file, max_instance))

    def _load_data_generator(self, cur_file, max_instance):
        """Lazily yield up to ``max_instance`` records from a JSON-lines file.

        Args:
            cur_file: Path to a file containing one JSON document per line.
            max_instance: Maximum number of lines to parse. ``-1`` or
                ``None`` reads to end of file; any other negative value
                yields nothing.

        Yields:
            The value produced by ``json.loads`` for each line.
        """
        if max_instance is None or max_instance == -1:
            limit = None
        else:
            # Clamp other negatives to 0 so islice yields nothing instead of
            # raising ValueError.
            limit = max(max_instance, 0)

        # JSON exchanged between systems is UTF-8 (RFC 8259); don't depend on
        # the platform's locale encoding.
        with open(cur_file, encoding="utf-8") as f:
            for text in itertools.islice(f, limit):
                yield json.loads(text)

    def load_env(self):
        """Return the ``"env"`` entry of the profile config."""
        return self.config["env"]

    def load_core(self):
        """Return the ``"core"`` entry of the profile config."""
        return self.config["core"]

    def load_prompt(self):
        """Return the ``"prompt"`` entry of the profile config."""
        return self.config["prompt"]

    def load_final(self):
        """Return the ``"final"`` entry of the profile config."""
        return self.config["final"]
