from .generator.single import SingleGenerator
from .generator.parallel import ParallelGenerator
from .generator.multiple import MultipleGenerator
from .generator.multiple_parallel import MultipleParallelGenerator
from .generator.format import FormatGenerator
from pipeline.generator.base import Action
from server.policy import LLMPolicy
from env.rapid_env import RapidEnv
from utils import api_util, log
import json
import threading
import random
import uuid


class RapidGenerator:
    """Pick the generator matching an action and index the candidate APIs.

    On construction the blacklist and the category file are loaded and three
    lookup tables keyed by ``api_id`` are built:

    * ``id2tool``     -- set of GET api ids that belong to the same tool.
    * ``id2category`` -- set of GET api ids that belong to the same category
      (APIs of blacklisted tools are excluded).
    * ``id2api``      -- the raw API description taken from the category file.
    """

    def __init__(self, profile, action):
        """Create the policy, environment and action-specific generator.

        Args:
            profile: project profile object; ``profile.load_env()["rapid"]``
                must provide the ``black_list`` and ``category_file`` paths.
            action: one of the ``Action`` members handled below.

        Raises:
            Exception: if ``action`` is not one of the supported actions.
        """
        self.logger = log.get_loguru()

        self.action = action
        self.policy = LLMPolicy(profile)
        self.rapid_env = RapidEnv(profile)

        # for candidate api
        rapid_config = profile.load_env()["rapid"]
        self.blacklist_file = rapid_config['black_list']
        # NOTE: the attribute name is misspelled ("catgegory"); it is kept
        # as-is because code outside this class may read it.
        self.catgegory_file = rapid_config['category_file']

        self.id2tool = {}
        self.id2category = {}
        self.id2api = {}

        if self.action == Action.Single:
            self.generator = SingleGenerator(profile, self.rapid_env, self.policy)
        elif self.action == Action.Parallel:
            self.generator = ParallelGenerator(profile, self.rapid_env, self.policy)
        elif self.action == Action.Multiple:
            self.generator = MultipleGenerator(profile, self.rapid_env, self.policy)
        elif self.action == Action.MultipleParallel:
            self.generator = MultipleParallelGenerator(profile, self.rapid_env, self.policy)
        elif self.action == Action.Format:
            self.generator = FormatGenerator(profile, self.rapid_env, self.policy)
        else:
            raise Exception("Not support this action {}".format(self.action))

        self._load_blacklist()
        self._load_candidate_api()

    def _load_blacklist(self):
        """Load the JSON blacklist from ``self.blacklist_file`` into ``self.black_list``."""
        # Context manager so the file handle is closed deterministically.
        with open(self.blacklist_file) as blacklist_fp:
            self.black_list = json.load(blacklist_fp)

    def _load_candidate_api(self):
        """Fill ``id2tool``, ``id2category`` and ``id2api`` from the category file.

        The category file is read as nested JSON objects of the form
        ``{category: {tool: {api_name: {"api_id": ..., "method": ..., ...}}}}``.
        Tools whose name (raw, or after ``api_util.standardize``) appears in
        the blacklist are skipped, and only APIs whose method is ``"GET"``
        are indexed. Does nothing when ``id2tool`` is already populated.
        """
        if self.id2tool:
            return

        with open(self.catgegory_file) as category_fp:
            categories = json.load(category_fp)

        for tools in categories.values():
            # One set per category, shared by reference by every API of that
            # category; it keeps growing while the remaining tools are visited,
            # so every entry ends up seeing the complete category.
            category_ids = set()
            for tool_name, apis in tools.items():
                if tool_name in self.black_list or api_util.standardize(tool_name) in self.black_list:
                    continue

                # Likewise one shared set per tool.
                tool_ids = set()
                for api_info in apis.values():
                    api_id = api_info['api_id']

                    if api_info['method'] != "GET":
                        continue

                    category_ids.add(api_id)
                    tool_ids.add(api_id)
                    self.id2tool[api_id] = tool_ids
                    self.id2category[api_id] = category_ids
                    self.id2api[api_id] = api_info

        self.logger.info("filter category size = {} tool size = {}".format(len(self.id2category), len(self.id2tool)))

    def load_generator(self):
        """Delegate to the selected generator's ``load_generator`` and return its result."""
        return self.generator.load_generator()

    def run(self, api):
        """Run the selected generator on ``api`` and return its result.

        The ``Action.Format`` generator additionally receives the three
        candidate-API lookup tables; every other generator gets only ``api``.
        """
        if self.action == Action.Format:
            return self.generator.run(api, self.id2tool, self.id2category, self.id2api)
        return self.generator.run(api)
