from env.leetcode_env import LeetcodeEnv
from utils import log
import json
import threading
import random


class LeetcodeSampler:
    """Turn Leetcode check records into single / parallel / multiple API samples.

    Each ``api`` record handled here is a dict with the keys ``api_id``,
    ``api_info`` (holding ``api_description``, ``category_name``,
    ``tool_name``, ``api_name``, ``method``, ``parameters``) and
    ``call_parameter`` (an iterable of per-call parameter values).

    Samples are written as JSON lines to one output file per mode. Output
    files are opened lazily on first use and stay open until ``close()`` is
    called (or the sampler is used as a context manager).
    """

    def __init__(self, profile) -> None:

        self.logger = log.get_loguru()

        # api_id -> list of per-call items, filled by run_multiple().
        self.id2api = {}

        self.rapid_env = LeetcodeEnv(profile)
        # Re-entrant so the lazy file getters may be called while the lock is held.
        self._lock = threading.RLock()

        self.leetcode_check_file = profile.load_input()["leetcode_check"]

        # api json
        # Open output handles keyed by path, so each mode writes to its own file
        # (previously one shared handle made later modes write into the first file).
        self._raw_files = {}
        # Most recently returned handle; kept for backward compatibility with
        # code that reads `_raw_f` directly.
        self._raw_f = None
        leetcode_outputs = profile.load_output()["leetcode"]
        self.single_raw_api_file = leetcode_outputs["single"]
        self.parallel_raw_api_file = leetcode_outputs["parallel"]
        self.multiple_raw_api_file = leetcode_outputs["multiple"]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def close(self):
        """Flush and close every output file this sampler has opened."""
        with self._lock:
            for fw in self._raw_files.values():
                fw.close()
            self._raw_files.clear()
            self._raw_f = None

    def load_single_generator(self):
        """Yield one parsed JSON object per non-blank line of the check file.

        The file is streamed and closed once the generator is exhausted or
        closed; whitespace-only lines (e.g. a trailing newline) are skipped.
        """
        with open(self.leetcode_check_file, encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    yield json.loads(line)

    def _open_raw_f(self, path):
        """Return the write handle for ``path``, opening (truncating) it once."""
        with self._lock:
            fw = self._raw_files.get(path)
            if fw is None:
                fw = open(path, 'w', encoding="utf-8")
                self._raw_files[path] = fw
            self._raw_f = fw
            return fw

    @staticmethod
    def _build_item(api, call_parameter):
        """Flatten ``api`` into one output record carrying ``call_parameter``."""
        api_info = api["api_info"]
        return {
            "api_id": api["api_id"],
            "tool_description": "",
            "api_description": api_info["api_description"],
            "category_name": api_info["category_name"],
            "tool_name": api_info["tool_name"],
            "api_name": api_info["api_name"],
            "method": api_info["method"],
            "parameters": api_info["parameters"],
            "call_parameter": call_parameter,
        }

    def _get_single_raw_f(self):
        return self._open_raw_f(self.single_raw_api_file)

    def run_single(self, api):
        """Write one JSON line per entry of ``api["call_parameter"]`` to the single file."""
        items = [self._build_item(api, param) for param in api["call_parameter"]]

        fw = self._get_single_raw_f()
        with self._lock:
            fw.write("".join(json.dumps(item) + "\n" for item in items))

    def _get_parallel_raw_f(self):
        return self._open_raw_f(self.parallel_raw_api_file)

    def run_parallel(self, api):
        """Write ``api`` with all of its call parameters as one JSON line.

        APIs with at most one call parameter are skipped.
        """
        if len(api["call_parameter"]) <= 1:
            return

        json_item = self._build_item(api, api["call_parameter"])
        fw = self._get_parallel_raw_f()
        with self._lock:
            fw.write(json.dumps(json_item) + "\n")

    def _get_multiple_raw_f(self):
        return self._open_raw_f(self.multiple_raw_api_file)

    def run_multiple(self, api):
        """Record one item per call parameter of ``api`` in ``self.id2api``.

        Nothing is written here; run_multiple_sampling() emits the pairs later.
        APIs with no call parameters get no ``id2api`` entry.
        """
        items = [self._build_item(api, param) for param in api["call_parameter"]]
        if not items:
            return
        with self._lock:
            self.id2api.setdefault(api["api_id"], []).extend(items)

    def run_multiple_sampling(self, num_candidates=2):
        """Pair every recorded api_id with randomly drawn partners and write the pairs.

        For each api_id, ``num_candidates`` partners are drawn uniformly (with
        replacement) from all recorded api_ids. Duplicate pairs are written
        once, in first-drawn order. Each output line is
        ``{"api_1": [...items...], "api_2": [...items...]}``.
        NOTE(review): a partner may equal the api_id itself; confirm whether
        self-pairs are intended.
        """
        with self._lock:
            # Built once; the original rebuilt this list on every draw.
            api_ids = list(self.id2api)
            # Insertion-ordered dedup keeps output reproducible under a fixed seed.
            pairs = {}
            for api_id in api_ids:
                for _ in range(num_candidates):
                    candidate = random.sample(api_ids, 1)[0]
                    pairs[(api_id, candidate)] = None

            fw = self._get_multiple_raw_f()
            for first_id, second_id in pairs:
                fw.write(json.dumps({
                    "api_1": self.id2api[first_id],
                    "api_2": self.id2api[second_id],
                }) + "\n")
