from env.rapid_env import RapidEnv
from utils import api_util, log
import json
import threading

class RapidCollector:
    """Probe RapidAPI tools through ``RapidEnv`` and collect the working APIs.

    Stages (each reads/writes files named in the profile):

    * ``run_tool``: call every GET API of a tool with its default parameters,
      append each result to ``invoke_cache_file`` and record tools whose GET
      pass rate is >= 0.5 in ``available_api_file``.
    * ``run_xlam_api``: replay generated calls (see ``load_xlam_generator``)
      and append the successful ones to ``xlam_gen_out_file``.
    * ``run__api``: the same for ``_gen_in_file`` / ``_gen_out_file``; these
      two paths are not read from the profile and must be assigned by the
      caller.
    * ``run_merge_api`` + ``merge``: re-run every known parameter set and dump
      the passing ones to ``merge_gen_out_file``.
    * ``run_check_parameter`` + ``merge_check``: re-check merged parameters via
      ``RapidEnv.check_param`` and dump the passing ones to
      ``check_gen_out_file``.

    Shared dicts and output files are guarded by ``self._lock``; calls into
    ``RapidEnv`` are made outside the lock, as ``run_tool`` / ``run_xlam_api``
    / ``run__api`` do. NOTE(review): this assumes ``RapidEnv`` tolerates
    concurrent calls — confirm.
    """

    # Generic "_gen" stage. Nothing in this class reads these paths from the
    # profile, so they default to None and must be assigned before
    # load__generator / run__api are used. NOTE(review): the profile keys for
    # this stage look like they were removed from __init__ — confirm.
    _gen_in_file = None
    _gen_out_file = None
    __gen_f = None

    def __init__(self, profile) -> None:
        """Load config paths from *profile*, the blacklist, the API catalogue and the invoke cache."""
        self.logger = log.get_loguru()
        # Created first so every lazily-opened file and shared dict is guarded.
        self._lock = threading.Lock()

        rapid_conf = profile.load_env()['rapid']
        output_conf = profile.load_output()

        self.blacklist_file = rapid_conf['black_list']
        self.category_file = rapid_conf['category_file']
        # Misspelled attribute kept as an alias for backward compatibility.
        self.catgegory_file = self.category_file

        self.invoke_cache_file = output_conf['invoke_cache_file']
        self.available_api_file = output_conf['available_api_file']

        self.category2tool = {}  # category name -> tool name -> set of api_id
        self.id2api = {}         # api_id -> API description from the category file
        self.api2id = None       # API name and normalized aliases -> api_id (built lazily)

        self._invoke_cache_f = None
        self.invoke_cache = {}   # api_id -> {"status": ..., "response": ...}
        self._available_api_f = None

        self._load_blacklist()
        self._load_api()
        self._load_available_api()

        self.logger.info("load category size = {}".format(len(self.category2tool)))
        self.logger.info("load tool size = {}".format(sum(len(v) for v in self.category2tool.values())))
        self.logger.info("load api size = {}".format(len(self.id2api)))

        self.rapid_env = RapidEnv(profile)

        # Extend API Parameter
        self.xlam_gen_in_file = profile.load_input()['xlam_gen']
        self.xlam_gen_out_file = output_conf['xlam_gen']
        self._xlam_gen_f = None

        # Merge API
        self.merge_gen_out_file = output_conf['merge_gen']
        self.merge_available_api = {}

        # Check API
        self.check_gen_out_file = output_conf['check_gen']
        self.check_available_api = {}

    # ------------------------------------------------------------------ files

    def _get_invoke_cache_f(self):
        """Return the append-mode handle on the invoke cache, opening it once."""
        with self._lock:
            if self._invoke_cache_f is None:
                self._invoke_cache_f = open(self.invoke_cache_file, 'a')
        return self._invoke_cache_f

    def _get_available_api_f(self):
        """Return the handle on ``available_api_file``; the first open truncates it ('w')."""
        with self._lock:
            if self._available_api_f is None:
                self._available_api_f = open(self.available_api_file, 'w')
        return self._available_api_f

    def _append_record(self, fw, record):
        """Serialize *record* as one JSON line and append it to *fw* under the lock."""
        line = json.dumps(record) + "\n"
        with self._lock:
            fw.write(line)
            fw.flush()

    def close(self):
        """Close every lazily opened output file; safe to call more than once.

        A later ``run_*`` call reopens its file, and 'w'-mode files are
        truncated again on reopen.
        """
        with self._lock:
            handles = [self._invoke_cache_f, self._available_api_f, self._xlam_gen_f, self.__gen_f]
            self._invoke_cache_f = None
            self._available_api_f = None
            self._xlam_gen_f = None
            self.__gen_f = None
        for fh in handles:
            if fh is not None:
                fh.close()

    # ---------------------------------------------------------------- loading

    def _load_blacklist(self):
        """Load the tool blacklist (JSON) into ``self.black_list``."""
        with open(self.blacklist_file) as f:
            self.black_list = json.load(f)

    def _load_api(self):
        """Fill ``category2tool`` and ``id2api`` from the category file, skipping blacklisted tools.

        A tool is skipped when its raw name or ``api_util.standardize(name)``
        is in the blacklist.
        """
        filter_number = 0
        with open(self.category_file) as f:
            categories = json.load(f)

        for category, tools_info in categories.items():
            tools = self.category2tool[category] = {}
            for tool, apis_info in tools_info.items():
                if tool in self.black_list or api_util.standardize(tool) in self.black_list:
                    filter_number += 1
                    continue

                api_ids = tools[tool] = set()
                for a_info in apis_info.values():
                    api_id = a_info['api_id']
                    api_ids.add(api_id)
                    self.id2api[api_id] = a_info

        self.logger.info("filter tool size = {}".format(filter_number))

    def _load_available_api(self):
        """Load previously cached invocation results from ``invoke_cache_file``.

        A missing file means an empty cache; on a first run the file is only
        created later by ``_get_invoke_cache_f``. Malformed lines, e.g. a torn
        last line from an interrupted run, are skipped.
        """
        try:
            with open(self.invoke_cache_file) as f:
                for line in f:
                    try:
                        json_item = json.loads(line)
                        key = json_item['api_id']
                        value = json_item['content']
                    except (ValueError, KeyError, TypeError):
                        continue
                    self.invoke_cache[key] = value
        except FileNotFoundError:
            self.logger.info("invoke cache file {} not found, starting empty".format(self.invoke_cache_file))

        self.logger.info("Load available api cache size {}".format(len(self.invoke_cache)))

    def load_tool_generator(self):
        """Yield, for every loaded tool, the set of its api_ids."""
        for tools in self.category2tool.values():
            yield from tools.values()

    def _get_parameter(self, api_info):
        """Return ``{name: default}`` for every required/optional parameter that declares a default.

        Optional parameters are applied last and win on a name clash.
        """
        default_parameter = {}
        for group in ("required_parameters", "optional_parameters"):
            for param in api_info[group]:
                if "default" in param:
                    default_parameter[param['name']] = param['default']
        return default_parameter

    # ------------------------------------------------------------- tool stage

    def run_tool(self, tools):
        """Invoke every GET API in *tools* (api_ids of one tool) with default parameters.

        Each fresh result is appended to the invoke cache, with a successful
        response truncated to 48 items. It is also kept in memory, so a repeat
        call does not re-invoke the API. If at least half of the GET APIs
        succeed, a summary line is appended to ``available_api_file``.
        """
        if not tools:
            return

        cache_f = self._get_invoke_cache_f()
        available_f = self._get_available_api_f()

        count = 0
        pass_api = []
        tool_info = None  # any GET API of this tool; used for category/tool names

        for api_id in tools:
            api_info = self.id2api[api_id]
            if api_info['method'] != "GET":
                continue

            count += 1
            tool_info = api_info

            cached = self.invoke_cache.get(api_id)
            if cached is not None:
                req_status = cached['status']
            else:
                req_status, req_response = self.rapid_env(
                    api_info['category_name'], api_info['tool_name'], api_info['api_name'],
                    self._get_parameter(api_info), fmt=True, check=True)
                if req_status:
                    req_response['response'] = req_response['response'][:48]
                content = {"status": req_status, "response": req_response}
                self.invoke_cache[api_id] = content
                self._append_record(cache_f, {"api_id": api_id, "content": content})

            if req_status:
                pass_api.append(api_id)

        if count <= 0:
            return

        pass_count = len(pass_api)
        pass_rate = pass_count / count
        if pass_rate >= 0.5:
            self._append_record(available_f, {
                "category_name": tool_info['category_name'],
                "tool_name": tool_info['tool_name'],
                "pass_rate": pass_rate,
                "pass_count": pass_count,
                "pass_api": pass_api,
            })

        self.logger.info("tool {} pass rate {}".format(tool_info['tool_name'], pass_rate))

    # ------------------------------------------------------------- xlam stage

    def _load_api2id(self):
        """Build ``api2id`` (name -> api_id) once, including normalized aliases.

        For every API, ``change_name(standardize(name))`` and ``change_name(name)``
        are registered first. Exact names are written last, so an exact match
        always wins over another API's alias. The table is published only when
        it is complete.
        """
        if self.api2id:
            return

        api2id = {}
        for api_id, api in self.id2api.items():
            name = api['api_name']
            api2id[api_util.change_name(api_util.standardize(name))] = api_id
            api2id[api_util.change_name(name)] = api_id
        for api_id, api in self.id2api.items():
            api2id[api['api_name']] = api_id

        self.api2id = api2id
        self.logger.info("load api2id size={}".format(len(self.api2id)))

    def load_xlam_generator(self):
        """Yield the ``param_list`` of every non-blank JSON line of ``xlam_gen_in_file``."""
        with open(self.xlam_gen_in_file) as f:
            for line in f:
                if line.strip():
                    yield json.loads(line)["param_list"]

    def _get_xlam_gen_f(self):
        """Return the handle on ``xlam_gen_out_file``; the first open truncates it ('w')."""
        with self._lock:
            if self._xlam_gen_f is None:
                self._xlam_gen_f = open(self.xlam_gen_out_file, 'w')
        return self._xlam_gen_f

    def run_xlam_api(self, apis):
        """Replay generated calls (``tool_name`` + ``tool_arguments``) and record the successful ones.

        ``tool_name`` is resolved through ``api2id``; unknown names are skipped.
        """
        self._load_api2id()
        out_f = self._get_xlam_gen_f()

        for api in apis:
            api_id = self.api2id.get(api['tool_name'])
            if api_id is None:
                continue

            api_info = self.id2api[api_id]
            parameter = api['tool_arguments']
            req_status, req_response = self.rapid_env(
                api_info['category_name'], api_info['tool_name'], api_info['api_name'],
                parameter, fmt=True, check=True)
            if not req_status:
                continue

            req_response['response'] = req_response['response'][:48]
            self._append_record(out_f, {
                "api_id": api_id,
                "parameter": parameter,
                "content": {
                    "status": req_status,
                    "response": req_response,
                },
            })

    # ----------------------------------------------------- generic "_gen" stage

    def load__generator(self):
        """Yield every non-blank JSON line of ``_gen_in_file`` (must be assigned by the caller)."""
        if self._gen_in_file is None:
            raise ValueError("_gen_in_file is not configured")
        with open(self._gen_in_file) as f:
            for line in f:
                if line.strip():
                    yield json.loads(line)

    def _get__gen_f(self):
        """Return the handle on ``_gen_out_file``; the first open truncates it ('w')."""
        if self._gen_out_file is None:
            raise ValueError("_gen_out_file is not configured")
        with self._lock:
            if self.__gen_f is None:
                self.__gen_f = open(self._gen_out_file, 'w')
        return self.__gen_f

    def run__api(self, api):
        """Invoke one ``{"api_id", "parameter"}`` call and record it if it succeeds.

        Unknown api_ids are skipped.
        """
        self._load_api2id()
        out_f = self._get__gen_f()

        api_id = api['api_id']
        api_info = self.id2api.get(api_id)
        if api_info is None:
            return

        parameter = api['parameter']
        req_status, req_response = self.rapid_env(
            api_info['category_name'], api_info['tool_name'], api_info['api_name'],
            parameter, fmt=True, check=True)
        if not req_status:
            return

        req_response['response'] = req_response['response'][:48]
        self._append_record(out_f, {
            "api_id": api_id,
            "parameter": parameter,
            "content": {
                "status": req_status,
                "response": req_response,
            },
        })

    # ------------------------------------------------------------ merge stage

    def load_merge_generator(self):
        """Yield every parseable JSON line from the tool, "_gen" and xlam output files.

        The "_gen" output is skipped while ``_gen_out_file`` is unconfigured (None).
        """
        for path in (self.available_api_file, self._gen_out_file, self.xlam_gen_out_file):
            if path is None:
                continue
            with open(path) as f:
                for line in f:
                    try:
                        json_item = json.loads(line)
                    except ValueError:
                        continue
                    yield json_item

    def add2available(self, api_id, api_info, parameter):
        """Re-run *parameter* (keys sorted) against the API and keep it if the call succeeds.

        The api_id entry is registered even if no parameter set passes. The
        network call is made outside the lock so workers do not serialize.
        """
        with self._lock:
            entry = self.merge_available_api.setdefault(api_id, {
                "api_info": api_info,
                "parameter": set(),
            })

        sorted_parameter = dict(sorted(parameter.items()))
        req_status, _ = self.rapid_env(
            api_info['category_name'], api_info['tool_name'], api_info['api_name'],
            sorted_parameter, fmt=True, check=True)
        if req_status:
            key = json.dumps(sorted_parameter)
            with self._lock:
                entry["parameter"].add(key)

    def run_merge_api(self, api):
        """Feed one merge-input record into ``add2available``.

        Tool summaries (with ``pass_api``) use each API's default parameters.
        Call records use their own ``parameter``. Unknown api_ids are logged
        and skipped, as in ``run__api``.
        """
        if "pass_api" in api:
            calls = [(api_id, None) for api_id in api['pass_api']]
        else:
            calls = [(api['api_id'], api['parameter'])]

        for api_id, parameter in calls:
            api_info = self.id2api.get(api_id)
            if api_info is None:
                self.logger.warning("skip unknown api_id {}".format(api_id))
                continue
            if parameter is None:
                parameter = self._get_parameter(api_info)
            self.add2available(api_id, api_info, parameter)

    def merge(self):
        """Write ``merge_available_api`` to ``merge_gen_out_file``, one JSON object per API."""
        with open(self.merge_gen_out_file, 'w') as fw:
            for key, value in self.merge_available_api.items():
                fw.write(json.dumps({
                    "api_id": key,
                    "api_info": value["api_info"],
                    "parameter": list(value["parameter"]),
                }) + "\n")

    # ------------------------------------------------------------ check stage

    def load_check_generator(self):
        """Yield every non-blank JSON line of ``merge_gen_out_file``."""
        with open(self.merge_gen_out_file) as f:
            for line in f:
                if line.strip():
                    yield json.loads(line)

    def add2check(self, api_id, api_info, parameters):
        """Format and check each JSON-encoded parameter set; keep those that pass.

        Any exception while decoding, formatting or checking a parameter set
        counts as a failure for that set only. The network calls are made
        outside the lock; this call's passing results are appended together.
        """
        with self._lock:
            entry = self.check_available_api.setdefault(api_id, {
                "api_info": api_info,
                "parameters": [],
            })

        category_name = api_info['category_name']
        tool_name = api_info['tool_name']
        api_name = api_info['api_name']

        passed = []
        for param_text in parameters:
            try:
                param = json.loads(param_text)
                param = self.rapid_env.parameter_format(api_info, param)
                req_status, req_response = self.rapid_env.check_param(category_name, tool_name, api_name, param, fmt=True)
            except Exception:
                continue

            if req_status:
                passed.append({
                    "parameter": param,
                    "req_status": req_status,
                    "req_response": req_response,
                })

        if passed:
            with self._lock:
                entry["parameters"].extend(passed)

    def run_check_parameter(self, api):
        """Feed one merged record (``api_id``, ``api_info``, ``parameter``) into ``add2check``."""
        self.add2check(api['api_id'], api['api_info'], api['parameter'])

    def merge_check(self):
        """Write ``check_available_api`` to ``check_gen_out_file``, one JSON object per API."""
        with open(self.check_gen_out_file, 'w') as fw:
            for key, value in self.check_available_api.items():
                fw.write(json.dumps({
                    "api_id": key,
                    "api_info": value["api_info"],
                    "parameters": value["parameters"],
                }) + "\n")

