from .base import BaseGenerator, Action
from utils import api_util
import json
import uuid


class SingleGenerator(BaseGenerator):

    def __init__(self, profile, env, policy):
        """Set up file locations, prompt template and few-shot examples.

        Args:
            profile: Object providing ``load_input``, ``load_output``,
                ``load_template`` and ``load_examples``.
            env: Environment object; ``env.name`` selects the per-environment
                entries, and the object is forwarded to the base class.
            policy: Forwarded unchanged to the base class.
        """
        super().__init__(profile, env, policy)

        # Input/output locations are looked up per environment, then under
        # the "single" key.
        inputs_for_env = profile.load_input()[env.name]
        self.in_file = inputs_for_env["single"]
        outputs_for_env = profile.load_output()[env.name]
        self.out_file = outputs_for_env["single"]

        self.template = profile.load_template(env.name)
        self.examples = profile.load_examples(env.name, "single")

        # Output file handle; stays None until _get_out_fw() opens it.
        self._fw = None

    def load_generator(self):
        
        for line in open(self.in_file).readlines():
            yield json.loads(line)

    def _get_out_fw(self):
        
        if self._fw is None:
            self._fw = open(self.out_file, 'w')
        return self._fw

    def run(self, api):
        
        category_name = api['category_name']
        tool_name = api['tool_name']
        api_name = api['api_name']
        api_param = api['call_parameter']
        api_description = api["api_description"]
        
        # import pdb;pdb.set_trace()
        query = self.build_query(api)
        print("[query]:{}\n".format(query))
        if not self.query_check(query, api_param):
            print("skip: {}".format(query))
            return
        

        req_status, req_response = self.env(category_name, tool_name, api_name, api_param, fmt=True, check=True)
        if req_status:
            response = str(req_response["response"])[:4096]
            print("[response]:{}\n".format(response))
            api_info = {
                "name": api['api_name'],
                "description": api["api_description"],
                "call_parameter": api_param,
                "call_response": response
            }
            api["function_calling"] = {
                "call_parameter": api_param,
                "call_response": response
            }
            answer = self.build_answer(query, api_info).strip("Answer:")
            print("[answer]:{}\n".format(answer))
            if not self.answer_check(query, response, answer):
                print("skip: Q: {} A: {}".format(query, answer))
                return
            if not self.final_answer_check(query, answer):
                print("skip: Q: {} A: {}".format(query, answer))
                return
            
            plan = self.build_plan(query, api).strip('[]')
            print("[plan]:{}\n".format(plan))
            selection = self.build_selection(plan, api_name, api_description)
            print("[selection]:{}".format(selection))

            api["query"] = query
            api["answer"] = answer
            api["plan"] = plan
            api["selection"] =selection
            
            # import pdb;pdb.set_trace()
            fw = self._get_out_fw()
            with self._lock:
                fw.writelines(json.dumps(api) + "\n")
                print("_flush")
                fw.flush()

    def build_query(self,api):
        instruction = self.template["single_query"]
        
        messages = []
        for i,item in enumerate(self.examples):
            function_calling = {
                "name": item["golden"][0]["name"],
                "description": item["golden"][0]["description"],
                "parameter_setting":item["golden"][0]["parameters"],
                "call_parameter": item["golden"][0]["call_parameter"]
            }
            user_info = instruction.format(function_calling=function_calling)
            assistant_info = "query: {}".format(item["query"])
            messages.append({"content": user_info,  "role": "user"})
            messages.append({"content": assistant_info,  "role": "assistant"})

        function_calling = {
            "name": api['api_name'],
            "description": api["api_description"],
            "parameter_setting":api["parameters"],
            "call_parameter": api["call_parameter"]
        }
        user_info = instruction.format(function_calling=function_calling)
        messages.append({"content": user_info,  "role": "user"})

        result = self.policy(messages).replace("\\","")
        # import pdb;pdb.set_trace()
        if "query:" in result:
            start_index = result.find("query:") + len("query:")
            result = result[start_index:]
        return result
    
    def build_plan(self, query, api):

        instruction = self.template["single_plan"]
        
        messages = []
        for i,item in enumerate(self.examples):
            api_info = {
                "name": item["golden"][0]['name'],
                "description": item["golden"][0]["description"],
                "parameters": item["golden"][0]["call_parameter"]
            }
            user_info = instruction.format(query=item["query"],function_info=api_info)
            assistant_info = "Task Plan:{}".format(item["tool_planning"])
            messages.append({"content": user_info,  "role": "user"})
            messages.append({"content": assistant_info,  "role": "assistant"})

        api_info = {
            "name": api['api_name'],
            "description": api["api_description"],
            "parameters": api["call_parameter"]
        }
        user_info = instruction.format(query=query, function_info=api_info)
        messages.append({"content": user_info,  "role": "user"})
        
        # import pdb;pdb.set_trace()
        result = self.policy(messages).replace("\\", "")

        if "Task Plan:" in result:
            start_index = result.find("Task Plan:") + len("Task Plan:")
            result = result[start_index:].strip()
            
        if result[0] == '[':
            result = result[1:]
        return result
    
