from env.rapid_env import RapidEnv
from utils import api_util, log
import json
import threading
import random


class RapidSampler:
    """Build raw single / parallel / multiple API-call samples from RapidAPI data.

    On construction the sampler loads the tool blacklist and the category
    catalog (category -> tool -> api -> info), dropping blacklisted tools.
    API records (read with :meth:`load_single_generator`) are turned into
    JSON-lines samples:

    * :meth:`run_single` writes one sample per non-empty call-parameter set;
    * :meth:`run_parallel` writes one sample bundling all call-parameter sets;
    * :meth:`run_multiple` collects per-API samples in ``self.id2api`` and
      :meth:`run_multiple_sampling` writes random pairs of APIs drawn from the
      same category.

    Output files and ``self.id2api`` are guarded by ``self._lock``. Call
    :meth:`close` (or use the sampler as a context manager) to flush and close
    the output files.
    """

    def __init__(self, profile) -> None:
        """Load configuration, the blacklist and the API catalog.

        Args:
            profile: configuration object exposing ``load_env()``,
                ``load_input()`` and ``load_output()`` mappings.
        """
        self.logger = log.get_loguru()
        env_cfg = profile.load_env()['rapid']
        self.blacklist_file = env_cfg['black_list']
        # Attribute name keeps its historical spelling for backward compatibility.
        self.catgegory_file = env_cfg['category_file']

        # category -> tool -> set of api_ids
        self.category2tool = {}

        self._load_blacklist()
        self._load_api()

        self.logger.info("load category size = {}".format(len(self.category2tool)))
        self.logger.info("load tool size = {}".format(sum(len(v) for v in self.category2tool.values())))

        self.rapid_env = RapidEnv(profile)
        # Non-reentrant: helpers called while it is held must not acquire it.
        self._lock = threading.Lock()

        self.rapid_check_file = profile.load_input()["rapid_check"]

        # api json output files (JSON lines)
        out_cfg = profile.load_output()["rapid"]
        self.single_raw_api_file = out_cfg["single"]
        self.parallel_raw_api_file = out_cfg["parallel"]
        self.multiple_raw_api_file = out_cfg["multiple"]
        # Lazily opened output handles keyed by path. Each output gets its own
        # handle; previously a single shared handle sent every kind of sample
        # to whichever file was opened first.
        self._raw_files = {}

        # api_id -> list of sample dicts built by run_multiple()
        self.id2api = {}

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

    def close(self):
        """Flush and close every output file opened so far (safe to call twice)."""
        with self._lock:
            for fw in self._raw_files.values():
                fw.close()
            self._raw_files.clear()

    def _load_blacklist(self):
        """Load the blacklist JSON into ``self.black_list``."""
        with open(self.blacklist_file, encoding='utf-8') as f:
            self.black_list = json.load(f)

    def _load_api(self):
        """Fill ``self.category2tool`` from the catalog, skipping blacklisted tools.

        A tool is skipped if its raw name or ``api_util.standardize(name)`` is
        in ``self.black_list``.
        """
        with open(self.catgegory_file, encoding='utf-8') as f:
            catalog = json.load(f)

        filter_number = 0
        for c, c_info in catalog.items():
            self.category2tool[c] = {}
            for t, t_info in c_info.items():
                if t in self.black_list or api_util.standardize(t) in self.black_list:
                    filter_number += 1
                    continue
                self.category2tool[c][t] = {a_info['api_id'] for a_info in t_info.values()}
        self.logger.info("filter tool size = {}".format(filter_number))

    def load_single_generator(self):
        """Yield one decoded JSON object per non-blank line of ``rapid_check_file``."""
        with open(self.rapid_check_file, encoding='utf-8') as f:
            for line in f:
                if line.strip():  # tolerate blank / trailing lines
                    yield json.loads(line)

    def _get_raw_f(self, path):
        """Return the output handle for *path*, opening (truncating) it on first use.

        Caller must hold ``self._lock`` so two threads cannot both open the
        file with ``'w'``; this helper does not take the (non-reentrant) lock.
        """
        fw = self._raw_files.get(path)
        if fw is None:
            fw = open(path, 'w', encoding='utf-8')
            self._raw_files[path] = fw
        return fw

    def _get_single_raw_f(self):
        """Handle for the single-call output file; caller must hold ``self._lock``."""
        return self._get_raw_f(self.single_raw_api_file)

    def _get_parallel_raw_f(self):
        """Handle for the parallel-call output file; caller must hold ``self._lock``."""
        return self._get_raw_f(self.parallel_raw_api_file)

    def _get_multiple_raw_f(self):
        """Handle for the multiple-call output file; caller must hold ``self._lock``."""
        return self._get_raw_f(self.multiple_raw_api_file)

    @staticmethod
    def _build_parameters(api):
        """Build the object-schema ``parameters`` dict from ``api["api_info"]``.

        Parameter names are lower-cased. Only required parameters are listed
        under ``"required"``, in their original order.
        """
        properties = {}
        required = []
        for p in api["api_info"]['required_parameters']:
            name = p['name'].lower()
            properties[name] = {"description": p['description'], "type": p['type']}
            required.append(name)
        for p in api["api_info"]['optional_parameters']:
            properties[p['name'].lower()] = {"description": p['description'], "type": p['type']}
        return {"properties": properties, "required": required, "type": "object"}

    @staticmethod
    def _build_item(api, parameters, call_parameter):
        """Assemble one sample dict; the key order defines the JSON output order."""
        info = api["api_info"]
        return {
            "api_id": api["api_id"],
            "tool_description": info["tool_description"],
            "api_description": info["api_description"],
            "category_name": info["category_name"],
            "tool_name": info["tool_name"],
            "api_name": info["api_name"],
            "method": info["method"],
            "parameters": parameters,
            "call_parameter": call_parameter,
        }

    def run_single(self, api):
        """Write one sample per non-empty entry of ``api["parameters"]``.

        The single output file is opened (and created) even when nothing is written.
        """
        parameters = self._build_parameters(api)
        lines = [
            json.dumps(self._build_item(api, parameters, param['parameter'])) + "\n"
            for param in api["parameters"]
            if param['parameter']
        ]
        with self._lock:
            self._get_single_raw_f().writelines(lines)

    def run_parallel(self, api):
        """Write one sample whose ``call_parameter`` lists every parameter set.

        Only APIs with more than one parameter set produce a sample. Unlike
        run_single, empty parameter sets are kept here.
        """
        if len(api["parameters"]) <= 1:
            return
        call_parameter = [param['parameter'] for param in api["parameters"]]
        item = self._build_item(api, self._build_parameters(api), call_parameter)
        line = json.dumps(item) + "\n"
        with self._lock:
            self._get_parallel_raw_f().write(line)

    def run_multiple(self, api):
        """Collect one sample per non-empty parameter set into ``self.id2api[api_id]``.

        Nothing is recorded for an API without non-empty parameter sets.
        """
        parameters = self._build_parameters(api)
        items = [
            self._build_item(api, parameters, param['parameter'])
            for param in api["parameters"]
            if param['parameter']
        ]
        if not items:
            return
        # Guarded like the file writes: check-then-insert on a shared dict is
        # not atomic across threads.
        with self._lock:
            self.id2api.setdefault(api["api_id"], []).extend(items)

    def run_multiple_sampling(self):
        """Sample random API pairs per category and write them to the multiple file.

        For every category with more than 3 tools, repeatedly draws 2 tools,
        pools their APIs that have samples in ``self.id2api``, and, when the
        pool holds more than 3 APIs, draws an ordered pair of distinct APIs
        from it. Both APIs may come from the same tool. Sampling stops once
        the category has as many distinct pairs as it has usable APIs, or
        after 5000 draws.

        NOTE(review): pairs are ordered tuples, so (a, b) and (b, a) count as
        different samples.
        """
        sampling_set = set()
        for tool_dict in self.category2tool.values():
            # tool -> api_ids in that tool that produced samples in run_multiple()
            tool2apis = {}
            count = 0
            for tool, api_ids in tool_dict.items():
                known = [api_id for api_id in api_ids if api_id in self.id2api]
                tool2apis[tool] = known
                count += len(known)

            if len(tool2apis) <= 3:
                continue

            tools = list(tool2apis)
            category_pairs = set()
            for _ in range(5000):  # hard cap so sparse categories terminate
                if len(category_pairs) >= count:
                    break
                pool = []
                for tool in random.sample(tools, 2):
                    pool.extend(tool2apis[tool])
                if len(pool) > 3:
                    category_pairs.add(tuple(random.sample(pool, 2)))
            sampling_set |= category_pairs

        with self._lock:
            fw = self._get_multiple_raw_f()
            for first, second in sampling_set:
                fw.write(json.dumps({
                    "api_1": self.id2api[first],
                    "api_2": self.id2api[second],
                }) + "\n")