import ast
import json
import os
import threading
import uuid

from env.leetcode_env import LeetcodeEnv
from server.policy import LLMPolicy
from utils import log


class LeetcodeCollector:
    """Collect LeetCode tool APIs together with parameter sets that actually work.

    For every API listed in ``<root_dir>/<category>/<tool>/<tool>_add_schema.json``
    this class asks an LLM policy for candidate parameter sets and validates each
    candidate through ``LeetcodeEnv``. It then appends APIs that end up with at
    least one parameter set (the schema's own ``examples`` count) to
    ``available_api_file``, one JSON object per line.
    """

    def __init__(self, profile) -> None:
        """Read configuration from *profile* and build the env and LLM policy.

        Args:
            profile: project profile. ``load_env()["leetcode"]["root_dir"]`` is
                the tool-schema root directory, and
                ``load_output()["available_api_file"]`` is the JSONL output path.
        """
        self.logger = log.get_loguru()

        self.root_dir = profile.load_env()["leetcode"]['root_dir']

        # NOTE(review): these mappings are never populated inside this class, so
        # the log line below always reports 0; presumably used by other code.
        self.id2api = {}
        self.api2id = None
        self.logger.info("load api size = {}".format(len(self.id2api)))

        self.leetcode_env = LeetcodeEnv(profile)
        self.policy = LLMPolicy(profile)

        # Serializes lazy opening of, and appends to, the shared output file.
        self._lock = threading.Lock()
        self.available_api_file = profile.load_output()['available_api_file']
        self._available_api_f = None

    @staticmethod
    def _make_api_id(category_name, tool_name, api_name):
        """Return the deterministic hex id (uuid5 of ``cat##tool##api``) of an API."""
        _key = "{}##{}##{}".format(category_name, tool_name, api_name)
        return uuid.uuid5(uuid.NAMESPACE_DNS, _key).hex

    def _get_available_api_f(self):
        """Return the append-mode output file handle, opening it on first use.

        Callers must hold ``self._lock``. The check-then-open is not atomic, so
        two unsynchronized callers could each open a handle and leak one.
        """
        if self._available_api_f is None:
            self._available_api_f = open(self.available_api_file, 'a', encoding='utf-8')
        return self._available_api_f

    def close(self):
        """Close the output file if it has been opened; calling it again is a no-op."""
        with self._lock:
            if self._available_api_f is not None:
                self._available_api_f.close()
                self._available_api_f = None

    def _load_done_api_ids(self):
        """Return the set of ``api_id`` values already recorded in the output file.

        Blank lines are ignored. Malformed lines, such as a line truncated by an
        interrupted run, are skipped with a warning so resuming does not crash.
        """
        api_ids = set()
        if not os.path.exists(self.available_api_file):
            return api_ids
        with open(self.available_api_file, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    api_ids.add(json.loads(line)['api_id'])
                except (ValueError, KeyError, TypeError) as e:
                    self.logger.warning("skip malformed record in {}: {!r}".format(
                        self.available_api_file, e))
        return api_ids

    def load_tool_generator(self):
        """Yield ``(category_name, tool_name, api_item)`` for every API not yet recorded.

        Walks ``root_dir/<category>/<tool>/<tool>_add_schema.json`` in sorted
        order. APIs whose id already appears in ``available_api_file`` are
        skipped, so an interrupted collection can be resumed.
        """
        api_ids = self._load_done_api_ids()

        for category_name in sorted(os.listdir(self.root_dir)):
            cate_dir = os.path.join(self.root_dir, category_name)
            if not os.path.isdir(cate_dir):
                continue
            for tool_name in sorted(os.listdir(cate_dir)):
                tool_dir = os.path.join(cate_dir, tool_name)
                if not os.path.isdir(tool_dir):
                    continue
                tool_json = os.path.join(tool_dir, tool_name + "_add_schema.json")
                # Read fully before yielding so no handle stays open across yields.
                with open(tool_json, encoding='utf-8') as f:
                    api_list = json.load(f)["api_list"]
                for item in api_list:
                    api_id = self._make_api_id(category_name, tool_name, item['name'])
                    if api_id in api_ids:
                        self.logger.info("SKIP: {}".format(api_id))
                        continue
                    yield (category_name, tool_name, item)

    def run_tool(self, tool):
        """Probe one API with LLM-generated parameters and record it if usable.

        Args:
            tool: ``(category_name, tool_name, api)`` tuple as yielded by
                ``load_tool_generator``. ``api`` is one ``api_list`` entry with
                ``name``, ``description``, ``required_parameters`` and
                ``optional_parameters``, plus optional ``examples``.

        A candidate is kept when ``leetcode_env.check_param`` succeeds and the
        call through ``leetcode_env`` reports success. Kept candidates are
        appended after the schema's ``examples``. If the resulting list is
        non-empty, one JSON line is appended to ``available_api_file``. Appends
        are serialized by ``self._lock``.
        """
        category_name, tool_name, api = tool
        api_name = api['name']
        api_id = self._make_api_id(category_name, tool_name, api_name)

        # Each parameter entry is presumably a single-key dict {name: schema};
        # its first key is taken as the parameter name.
        required = []
        properties = {}
        for parameter in api['required_parameters']:
            properties.update(parameter)
            required.append(list(parameter.keys())[0])
        for parameter in api['optional_parameters']:
            properties.update(parameter)

        api_info = {
            "api_id": api_id,
            "api_description": api['description'],
            "category_name": category_name,
            "tool_name": tool_name,
            "api_name": api_name,
            "method": "GET",
            "parameters": {
                "properties": properties,
                "required": required,
                "type": "object",
            }
        }

        parameters = self.build_parameter(api_info)
        # The schema's own examples are kept without re-validation.
        call_parameter = list(api.get('examples', []))

        for parameter in parameters:
            try:
                r_status, _ = self.leetcode_env.check_param(
                    category_name, tool_name, api_name, parameter, fmt=True)
                if not r_status:
                    continue
                req_status, _ = self.leetcode_env(
                    category_name, tool_name, api_name, parameter, fmt=True, check=True)
            except Exception as e:
                # Best effort: one failing candidate must not abort the whole API.
                self.logger.warning("tool {} parameter {} failed: {!r}".format(
                    api_name, parameter, e))
                continue
            if req_status:
                call_parameter.append(parameter)

        if call_parameter:
            record = json.dumps({
                "api_id": api_id,
                "api_info": api_info,
                "call_parameter": call_parameter,
            }) + "\n"
            # Open and write under the lock so concurrent callers neither open
            # duplicate handles nor interleave partial lines.
            with self._lock:
                api_f = self._get_available_api_f()
                api_f.write(record)
                api_f.flush()

        self.logger.info("tool {} pass rate {}".format(api_name, len(call_parameter)))

    @staticmethod
    def _parse_param(text):
        """Parse *text* as JSON, falling back to a Python literal; return None on failure.

        ``ast.literal_eval`` is used instead of ``eval`` because the text comes
        from LLM output and must never be executed as code.
        """
        try:
            return json.loads(text)
        except (ValueError, RecursionError):
            pass
        try:
            return ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return None

    def build_parameter(self, text):
        """Ask the LLM policy for candidate parameter dicts for one API.

        Args:
            text: API description. ``run_tool`` passes its ``api_info`` dict,
                which is interpolated into the prompt with ``str.format``.

        Returns:
            A list of dicts parsed from the policy response. Each response line
            that contains both ``{`` and ``}`` is cut from its first ``{`` to its
            last ``}``. The cut text is parsed as JSON or, failing that, as a
            Python literal. Lines that do not yield a dict are dropped.
        """
        prompts = "You are tasked with generating a set of candidate parameters for an API. The API has specific requirements for its parameters, including the names, types, and whether they are required. Your goal is to create valid and well-structured examples of parameter sets that strictly follow the API's specifications.\n***\nHere are the details you need to consider when generating these examples:\n1. **Parameter Names**: Use the exact names as defined by the API. Do not modify or abbreviate them.\n2. **Parameter Types**: Ensure that the values assigned to each parameter match the required data types (e.g., integer, float, string, boolean).\n3. **Required Parameters**: Include all parameters that are marked as required by the API. Optional parameters can be included or omitted, depending on the example.\n4. **Value Range and Context**: Generate realistic and varied values that make sense within the context of the API's purpose.\n***\n[Example]:\n[Output]:\n{{\"parameter1\": <value_of_type_defined>, \"parameter2\": <value_of_type_defined>,...\"parameterN\":<value_of_type_defined>}}\nPlease generate 3-5 different examples of these parameter sets, ensuring that each one adheres strictly to the API's requirements: {text}.\n[Output]:\n"

        result = self.policy(prompts.format(text=text))

        param_list = []
        for item in result.split("\n"):
            if "{" not in item or "}" not in item:
                continue
            param = item[item.find("{"):item.rfind("}") + 1]
            self.logger.debug("candidate parameter: {}".format(param))
            parsed = self._parse_param(param)
            if isinstance(parsed, dict):
                param_list.append(parsed)
        return param_list