from tenacity import retry, stop_after_attempt
from .base import BaseGenerator
from utils import api_util
import json
import uuid


class ParallelGenerator(BaseGenerator):
    """Generate parallel-call samples: one query answered by several calls to one API.

    For each input API record, ``run`` executes up to ``MAX_PARALLEL_CALLS`` of
    the record's ``call_parameter`` entries against the environment. It then asks
    the policy model for a user query, a per-call task plan, per-call answers and
    tool selections, and a final answer. Finally it appends the enriched record
    to ``self.out_file`` as one JSON line.
    """

    # Maximum number of recorded call parameters used per sample.
    MAX_PARALLEL_CALLS = 3

    def __init__(self, profile, env, policy):
        super().__init__(profile, env, policy)

        self.in_file = profile.load_input()[env.name]["parallel"]
        self.out_file = profile.load_output()[env.name]["parallel"]
        self.template = profile.load_template(env.name)
        self.examples = profile.load_examples(env.name, "parallel")

        # Output handle; opened lazily on first write by ``_get_out_fw``.
        self._fw = None

    def load_generator(self):
        """Yield one parsed JSON record per non-blank line of ``self.in_file``."""
        # The context manager closes the file once iteration ends (the handle
        # used to leak), and iterating the file object streams line by line
        # instead of loading the whole file with readlines().
        with open(self.in_file, encoding="utf-8") as fr:
            for line in fr:
                if line.strip():  # tolerate blank / trailing empty lines
                    yield json.loads(line)

    def _get_out_fw(self):
        """Return the append-mode output handle, opening it on first use.

        Callers that may run concurrently should hold ``self._lock`` while
        calling this, so the file is opened only once.
        """
        if self._fw is None:
            self._fw = open(self.out_file, 'a', encoding="utf-8")
        return self._fw

    @staticmethod
    def _text_after(text, marker):
        """Return the part of *text* after the first *marker*, or *text* if absent."""
        index = text.find(marker)
        if index == -1:
            return text
        return text[index + len(marker):]

    def run(self, api):
        """Build one parallel-call sample from *api* and append it to the output file.

        *api* is mutated in place. ``call_parameter`` is truncated to at most
        ``MAX_PARALLEL_CALLS`` entries, and the generated fields are added:
        ``function_calling``, ``query``, ``plan_list``, ``final_plan``,
        ``answer_list``, ``final_answer`` and ``selection_list``.

        Returns ``None``. The sample is dropped (nothing is written) when an
        environment call fails or a generated piece fails its check.
        """
        category_name = api['category_name']
        tool_name = api['tool_name']
        api_name = api['api_name']
        api_param = api['call_parameter'][:self.MAX_PARALLEL_CALLS]
        api['call_parameter'] = api_param
        api_description = api["api_description"]

        function_calling = []
        for para in api_param:
            req_status, req_response = self.env(category_name, tool_name, api_name, para, fmt=True, check=True)
            if not req_status:
                return
            # NOTE(review): json.dumps (no indent) escapes newlines, so the
            # "\n\n" replace looks like a no-op. Removing backslashes also
            # turns escapes such as \n into a bare "n". This is kept as-is to
            # preserve the existing dataset format.
            observation = json.dumps(req_response['response']).replace("\\", "").replace("\n\n", "")[:4096]
            function_calling.append({"call_parameter": para, "call_response": observation})
        api["function_calling"] = function_calling

        query = self.build_query(api)
        print("[query]:{}\n".format(query))
        if not self.query_check(query, api_param):
            print("skip: {}".format(query))
            return

        plan_list = self.build_plan(query, api)
        if not plan_list:
            return
        print("[plan]:{}\n".format(plan_list))

        # build_plan returns a non-empty list only when it has exactly one
        # sub-plan per call parameter, so plan_list and function_calling align.
        standardized_name = api_util.change_name(api_util.standardize(api_name))
        answer_list, selection_list = [], []
        for plan, call in zip(plan_list, function_calling):
            api_info = {
                "name": standardized_name,
                "description": api_description,
                "call_parameter": call["call_parameter"],
                "call_response": call["call_response"],
            }
            answer = self.build_answer(plan, api_info)
            if not self.answer_check(plan, call, answer):
                print("skip: Q: {} A: {}".format(query, answer))
                return
            print("[subplan]:[{}]\n[sub answer]:[{}]".format(plan, answer))
            answer_list.append(answer)

            selection = self.build_selection(plan, api_name, api_description)
            print("[selection]:[{}]\n".format(selection))
            selection_list.append(selection)

        final_answer = self.build_final_answer(plan_list, answer_list, query, "")
        if not self.final_answer_check(query, final_answer):
            print("skip: Q: {} A: {}".format(query, final_answer))
            return
        print("[final_answer]:[{}]".format(final_answer))

        api["query"] = query
        api["plan_list"] = plan_list
        # run does not call build_final_plan, so the final plan is left empty.
        api["final_plan"] = ""
        api["answer_list"] = answer_list
        api["final_answer"] = final_answer
        api["selection_list"] = selection_list

        record = json.dumps(api) + "\n"
        # NOTE(review): self._lock is not set in this class; it is presumably
        # provided by BaseGenerator — confirm.
        # Opening the handle inside the lock prevents two concurrent callers
        # from both opening the file. Flushing makes each finished
        # (expensive) sample durable even if the process dies later.
        with self._lock:
            fw = self._get_out_fw()
            fw.write(record)
            fw.flush()

    def build_query(self, api):
        """Ask the policy model for a user query that the calls in *api* would answer.

        The prompt is few-shot. Each entry in ``self.examples`` contributes a
        user turn and an assistant turn ("query: ..."). The user turn is the
        ``single_query`` template filled with that example's name, description,
        parameter setting and call chain. The target API is formatted the same
        way.

        Returns the model output after the first "query:" marker (or the whole
        output when there is no marker), with backslashes removed and
        surrounding whitespace stripped.
        """
        instruction = self.template["single_query"]

        messages = []
        for item in self.examples:
            golden = item["golden"]
            function_calling = {
                "name": golden[0]["name"],
                "description": golden[0]["description"],
                "parameter_setting": golden[0]["parameters"],
                "function_call_chain": [call["call_parameter"] for call in golden],
            }
            # Use the same dict shape as the target API below. Previously this
            # dict was built but only the bare call chain was passed, so the
            # few-shot examples and the target were formatted differently.
            user_info = instruction.format(function_calling=function_calling)
            assistant_info = "query: {}".format(item["query"])
            messages.append({"content": user_info,  "role": "user"})
            messages.append({"content": assistant_info,  "role": "assistant"})

        function_calling = {
            "name": api_util.change_name(api_util.standardize(api['api_name'])),
            "description": api["api_description"],
            "parameter_setting": api["parameters"],
            "function_call_chain": api["call_parameter"],
        }
        user_info = instruction.format(function_calling=function_calling)
        messages.append({"content": user_info,  "role": "user"})
        result = self.policy(messages).replace("\\", "")

        # Strip so the stored query does not start with the space after "query:".
        return self._text_after(result, "query:").strip()

    def build_plan(self, target_query, target_api):
        """Ask the policy model for one numbered sub-plan per call of *target_api*.

        Returns the lines of the model output that start with a digit, after
        leading/trailing whitespace is stripped from each line. Returns an
        empty list when the number of such lines differs from the number of
        entries in ``target_api["call_parameter"]``.
        """
        instruction = self.template["parallel_plan"]
        messages = []
        for item in self.examples:
            golden = item["golden"]
            api_info = {
                "name": golden[0]["name"],
                "description": golden[0]["description"],
            }
            call_parameters = [call['call_parameter'] for call in golden]
            user_info = instruction.format(query=item["query"], function_info=json.dumps(api_info), call_parameters=call_parameters, call_num=len(call_parameters))
            assistant_info = "[Task Plan]:{}".format(item["tool_planning"])
            messages.append({"content": user_info,  "role": "user"})
            messages.append({"content": assistant_info,  "role": "assistant"})

        api_info = {
            "name": api_util.change_name(api_util.standardize(target_api['api_name'])),
            "description": target_api["api_description"],
        }
        call_parameters = target_api["call_parameter"]
        user_info = instruction.format(query=target_query, function_info=json.dumps(api_info), call_parameters=call_parameters, call_num=len(call_parameters))
        messages.append({"content": user_info,  "role": "user"})
        result = self.policy(messages).replace("\\", "")

        # Accept both "[Task Plan]:" and "Task Plan:" headers.
        result = self._text_after(result, "[Task Plan]:")
        result = self._text_after(result, "Task Plan:")
        result = result.strip("[]")

        plan_list = []
        for raw_line in result.split('\n'):
            # Stripping tolerates indented numbered lines and "\r\n" endings,
            # which previously failed the digit test and emptied the plan.
            line = raw_line.strip()
            if line and line[0].isdigit():
                plan_list.append(line)
        if len(plan_list) != len(call_parameters):
            return []

        return plan_list

    def build_final_plan(self, query, plan_list):
        """Ask the policy model to merge *plan_list* into a single final plan for *query*.

        Not called by ``run`` in this class. Returns the model output after the
        first "Final plan:" marker (stripped), or the raw output when the
        marker is absent.
        """
        messages = []
        instruction = self.template["final_plan"]
        prompt_format = "{}\n***\n[User Query]:{}\n***\n[Task Plan list]:{}"
        for item in self.examples:
            user_info = prompt_format.format(instruction, item["query"], item["tool_planning"])
            assistant_info = "[Final plan]:{}".format(item["final_plan"])
            messages.append({"content": user_info,  "role": "user"})
            messages.append({"content": assistant_info,  "role": "assistant"})
        messages.append({"content": prompt_format.format(instruction, query, plan_list),  "role": "user"})
        final_plan = self.policy(messages)
        if 'Final plan:' in final_plan:
            final_plan = self._text_after(final_plan, 'Final plan:').strip()

        return final_plan

    