from .base import BaseGenerator
from utils import api_util
import json
import profile
import random
import uuid


class FormatGenerator(BaseGenerator):
    """Rewrite generated single / parallel / multiple samples into one JSONL layout.

    ``load_generator`` reads the raw ``.jsonl`` file of each data type and, per
    type, opens a sibling ``*_format.jsonl`` output file. ``run`` converts one
    ``(data_type, json_data)`` item and appends one JSON line to the output
    file that is currently open.
    """

    def __init__(self, profile, env, policy):
        super().__init__(profile, env, policy)

        # Output handle for the data type currently being read by load_generator().
        self._fw = None
        # Lookup tables keyed by api_id; injected on the first run() call.
        self.id2tool = None
        self.id2category = None
        self.id2api = None

        # data_type -> raw .jsonl path (read by load_generator()).
        self.out_files = profile.load_output()[env.name]

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _param_entry(param):
        """Return the description/type/default schema entry of one parameter."""
        return {
            "description": param['description'],
            "type": param['type'],
            "default": param['default'],
        }

    def _build_api_param(self, api_info):
        """Return the parameter schema of *api_info*.

        Uses ``api_info['parameters']`` verbatim when present; otherwise builds
        an object schema from ``required_parameters`` and ``optional_parameters``
        (only the required ones are listed under ``"required"``).
        """
        if 'parameters' in api_info:
            return api_info['parameters']
        properties = {}
        required = []
        for param in api_info['required_parameters']:
            properties[param['name']] = self._param_entry(param)
            required.append(param['name'])
        for param in api_info['optional_parameters']:
            properties[param['name']] = self._param_entry(param)
        return {
            "properties": properties,
            "required": required,
            "type": "object",
        }

    def _build_candidate(self, candidate_id):
        """Return the ``api_list`` entry describing the API *candidate_id*."""
        api_info = self.id2api[candidate_id]
        return {
            # Each entry carries its own id (not the golden one), so distractors
            # are distinguishable from the target API.
            "api_id": candidate_id,
            "category_name": api_info['category_name'],
            "tool_name": api_info['tool_name'],
            "tool_description": api_info.get('tool_description', ""),
            "api_info": {
                "api_name": api_info['api_name'],
                "api_description": api_info['api_description'],
                "api_param": self._build_api_param(api_info),
                "method": api_info['method'],
            },
        }

    def _write_record(self, format_data):
        """Append *format_data* as one JSON line to the current output file."""
        fw = self._get_format_f()
        with self._lock:
            fw.write(json.dumps(format_data) + "\n")

    # --------------------------------------------------------- candidate lists

    def get_candidate_api(self, api_id, max_number=3):
        """Build a shuffled candidate list containing *api_id* plus distractors.

        Distractors are sampled from the same tool when it has more than
        *max_number* APIs, else from the same category when that has more
        than *max_number*, else the whole category list is used.

        Returns:
            ``(select_idx, candidate_apis)`` where ``select_idx`` is the
            position of *api_id* in ``candidate_apis``.
        """
        candidate_tools = self.id2tool[api_id]
        candidate_categorys = self.id2category[api_id]
        if len(candidate_tools) > max_number:
            candidates = random.sample(candidate_tools, max_number)
        elif len(candidate_categorys) > max_number:
            candidates = random.sample(candidate_categorys, max_number)
        else:
            # Copy: the append/shuffle below must not mutate the shared lookup table.
            candidates = list(candidate_categorys)

        if api_id not in candidates:
            candidates.append(api_id)
        random.shuffle(candidates)

        select_idx = -1
        for idx, candidate_id in enumerate(candidates):
            if candidate_id == api_id:
                select_idx = idx
        assert select_idx > -1

        candidate_apis = [self._build_candidate(cid) for cid in candidates]
        return select_idx, candidate_apis

    def get_raw_candidate_api(self, api_id, max_number=3):
        """Return ``(0, [entry])`` where *entry* describes only *api_id*.

        *max_number* is unused; it is kept so the signature mirrors
        ``get_candidate_api``.
        """
        return 0, [self._build_candidate(api_id)]

    def get_raw_java_candidate_api(self, json_data):
        """Return ``(0, [entry])`` built directly from the fields of *json_data*."""
        candidate_apis = [{
            "api_id": json_data["api_id"],
            "category_name": json_data['category_name'],
            "tool_name": json_data['tool_name'],
            "tool_description": json_data.get('tool_description', ""),
            "api_info": {
                "api_name": json_data['api_name'],
                "api_description": json_data['api_description'],
                "api_param": json_data['parameters'],
                "method": json_data['method'],
            },
        }]
        return 0, candidate_apis

    # ------------------------------------------------------------- I/O driving

    def load_generator(self):
        """Yield ``(data_type, record)`` for every parseable line of each raw file.

        Before yielding a type's records, opens ``<raw>_format.jsonl`` for
        writing and stores it in ``self._fw``. Lines that are not valid JSON
        are skipped.
        """
        for data_type in ["single", "parallel", "multiple"]:
            in_file = self.out_files[data_type]
            out_file_format = in_file.replace(".jsonl", "_format.jsonl")
            # NOTE(review): the previous type's handle is replaced but not closed
            # here; run() fetches self._fw before taking self._lock, so closing it
            # here could race with in-flight writes. Confirm who closes these.
            self._fw = open(out_file_format, 'w')

            print(in_file)
            with open(in_file) as fr:
                lines = fr.readlines()
            for line in lines:
                try:
                    record = json.loads(line)
                except ValueError:  # JSONDecodeError: blank / malformed line
                    continue
                # yield stays outside the try so GeneratorExit and exceptions
                # thrown into the generator are not swallowed.
                yield (data_type, record)

    def _get_format_f(self):
        """Return the currently open output handle (load_generator must have run)."""
        assert self._fw is not None
        return self._fw

    def run(self, data, id2tool, id2category, id2api):
        """Format one ``(data_type, json_data)`` item and write it out.

        The lookup tables are stored only on the first call. Unknown data
        types are ignored.
        """
        if self.id2tool is None:
            self.id2tool = id2tool
            self.id2category = id2category
            self.id2api = id2api

        data_type, json_data = data

        handlers = {
            "single": self._single_format,
            "parallel": self._parallel_format,
            "multiple": self._multiple_format,
            # multiple_parallel samples are handled by the same routine as multiple.
            "multiple_parallel": self._multiple_format,
        }
        handler = handlers.get(data_type)
        if handler is not None:
            handler(json_data)

    # ------------------------------------------------------------- formatters

    def _single_format(self, json_data):
        """Format a sample with one golden API call and write it."""
        if "api_id" not in json_data:
            # Samples without an api_id get fixed metadata and a deterministic
            # uuid5 id derived from category/tool/api name.
            json_data['category_name'] = "pcode"
            json_data['tool_name'] = "java"
            json_data['method'] = "GET"
            json_data['tool_description'] = ""
            _key = "{}##{}##{}".format(json_data['category_name'], json_data['tool_name'], json_data['api_name'])
            json_data['api_id'] = uuid.uuid5(uuid.NAMESPACE_DNS, _key).hex

        api_id = json_data['api_id']
        selected_idx, candidate_apis = self.get_raw_candidate_api(api_id)
        json_data["selection"]["ID"] = selected_idx
        golden = [{
            "api_id": api_id,
            "category_name": json_data['category_name'],
            "tool_name": json_data['tool_name'],
            "tool_description": json_data['tool_description'],
            "api_info": {
                "api_name": json_data['api_name'],
                "api_description": json_data['api_description'],
                "api_param": json_data['parameters'],
                "method": json_data['method'],
            },
            "parameters": json_data['call_parameter'],
            "conv": {
                "ID": selected_idx,
                "selection": json_data["selection"],
                "plan": json_data['plan'],
                "answer": json_data['answer'],
            },
        }]

        self._write_record({
            "session_id": uuid.uuid4().hex,
            "api_list": candidate_apis,
            "query": json_data['query'],
            "golden": golden,
            "answer": json_data['answer'],
        })

    def _parallel_format(self, json_data):
        """Format a sample that calls one API several times and write it."""
        api_id = json_data['api_id']
        selected_idx, candidate_apis = self.get_raw_candidate_api(api_id)

        if "call_parameter" in json_data:
            _call_parameters = json_data["call_parameter"]
        else:
            _call_parameters = json_data["call_parameters"]

        golden = []
        for idx, item in enumerate(_call_parameters):
            # Each call is either wrapped under 'input_paramters' (sic) or bare.
            _input_paramters = item['input_paramters'] if 'input_paramters' in item else item

            json_data["selection_list"][idx]["ID"] = selected_idx
            golden.append({
                "api_id": api_id,
                "category_name": json_data['category_name'],
                "tool_name": json_data['tool_name'],
                "tool_description": json_data['tool_description'],
                "api_info": {
                    "api_name": json_data['api_name'],
                    "api_description": json_data['api_description'],
                    "api_param": json_data['parameters'],
                    "method": json_data['method'],
                },
                "parameters": _input_paramters,
                "conv": {
                    "ID": selected_idx,
                    "selection": json_data["selection_list"][idx],
                    "plan": json_data['plan_list'][idx],
                    "answer": json_data['answer_list'][idx],
                },
            })

        self._write_record({
            "session_id": uuid.uuid4().hex,
            "api_list": candidate_apis,
            "query": json_data['query'],
            "golden": golden,
            "answer": json_data["final_answer"],
        })

    def _multiple_format(self, json_data):
        """Format a sample that calls several different APIs and write it.

        Skips the sample when its first two golden calls use the same api_id,
        and (printing ``"error"``) when a golden API cannot be found in the
        candidate list.
        """
        golden_api_list = json_data["golden_api_list"]
        if golden_api_list[0]['api_id'] == golden_api_list[1]['api_id']:
            return

        # Candidates for every golden API, de-duplicated by "<api_id>_<api_name>"
        # (first occurrence wins), so each golden call can be located below.
        candi_name = {}
        for api in golden_api_list:
            _, candidates = self.get_raw_candidate_api(api['api_id'])
            for cand in candidates:
                key = "{}_{}".format(cand["api_id"], cand["api_info"]["api_name"])
                candi_name.setdefault(key, cand)
        candidate_apis = list(candi_name.values())
        random.shuffle(candidate_apis)

        api2id = {
            "{}_{}".format(api["api_id"], api["api_info"]["api_name"]): index
            for index, api in enumerate(candidate_apis)
        }

        golden = []
        for idx, api in enumerate(golden_api_list):
            selection = json_data["selection_list"][idx]
            try:
                api_index = api2id["{}_{}".format(api["api_id"], api["api_name"])]
            except KeyError:
                print("error")
                return
            selection["ID"] = api_index

            api_calling = api['api_calling']
            if 'input_paramters' in api_calling:
                parameters = api_calling['input_paramters']
            else:
                parameters = api_calling['call_parameter']

            golden.append({
                "api_id": api['api_id'],
                "category_name": api['category_name'],
                "tool_name": api['tool_name'],
                "tool_description": api['tool_description'],
                "api_info": {
                    "api_name": api['api_name'],
                    "api_description": api['api_description'],
                    "api_param": api['parameters'],
                    "method": api['method'],
                },
                "parameters": parameters,
                "conv": {
                    "ID": api_index,
                    "selection": selection,
                    "plan": json_data['plan_list'][idx],
                    "answer": json_data['answer_list'][idx],
                },
            })

        self._write_record({
            "session_id": uuid.uuid4().hex,
            "api_list": candidate_apis,
            "query": json_data['query'],
            "golden": golden,
            "answer": json_data["final_answer"],
        })