from tenacity import retry, stop_after_attempt
from .base import BaseGenerator
from utils import api_util
import json
import os
import random
import uuid


class MultipleParallelGenerator(BaseGenerator):

    def __init__(self, profile, env, policy):
        """Initialise the generator and resolve its per-environment resources.

        ``profile``, ``env`` and ``policy`` are forwarded unchanged to the base
        class. The "multiple_parallel" input path, output path, prompt
        templates and few-shot examples are then looked up on ``profile``
        using ``env.name``.
        """
        super().__init__(profile, env, policy)

        task = "multiple_parallel"
        env_name = env.name

        # The input/output tables are indexed by environment name, then task.
        self.in_file = profile.load_input()[env_name][task]
        self.out_file = profile.load_output()[env_name][task]
        self.template = profile.load_template(env_name)
        self.examples = profile.load_examples(env_name, task)

        # Output file handle; opened on first use by _get_out_fw().
        self._fw = None

    def load_generator(self):
        
        for line in open(self.in_file).readlines():
            yield json.loads(line)

    def _get_out_fw(self):
        
        if self._fw is None:
            self._fw = open(self.out_file, 'a')
        return self._fw
    
    def run(self, api_set):
        '''
            api:dict{'api_1', 'api_2', 'api_3', 'api_4'}
                'api_1':list(dict{'api_id', 'tool_description', 'api_description', 'category_name', 'tool_name', 'api_name', 'method', 'parameters', 'call_parameter'})
        '''
        api_list,param_list = [] ,[]
        for idx,target_api in api_set.items():
            random_number = random.randint(1, min(3,len(target_api)))
            for api_calling in target_api[:random_number]:
                category_name = api_calling['category_name']
                tool_name = api_calling['tool_name']
                api_name = api_calling['api_name']
                api_param = api_calling['call_parameter'] # dict(single para setting)
                
                # API Validation
                try:
                    req_status, req_response = self.env(category_name, tool_name, api_name, api_param, fmt=True, check=True)
                    if not req_status:
                        return
                except:
                    return
                observation = json.dumps(req_response['response']).replace("\\", "").replace("\n\n", "")[:4096]
                api_calling["api_calling"]={"call_parameter":api_param,"call_response":observation}
                api_list.append(api_calling)

        query = self.build_query(api_list)
        print("[query]:[{}]\n".format(query))
        param_list = [api_calling['call_parameter'] for api_calling in api_list]
        if not self.query_check(query,param_list):
            print("skip: {}".format(query))
            return

        plan_list = self.build_plan(query, api_list)
        print("[plan_list]:{}\n".format(str(plan_list)))
        if len(plan_list) == 0:
            return 

        answer_list ,selection_list = [],[]
        for idx,plan in enumerate(plan_list):
            api_info = {
                "name": api_util.change_name(api_util.standardize(api_list[idx]['api_name'])),
                "description": api_list[idx]['api_description'],
                "call_parameter": api_list[idx]["api_calling"]["call_parameter"],
                "call_response": api_list[idx]["api_calling"]["call_response"]
            }
            answer = self.build_answer(plan, api_info)
            if not self.answer_check(plan, api_list[idx]["api_calling"] , answer):
                print("skip: Q: {} A: {}".format(query, answer))
                return
            print("[subplan]:[{}]\n[sub answer]:[{}]".format(plan,answer))
            answer_list.append(answer)

            selection = self.build_selection(plan, api_info["name"], api_info['description'])
            print("[selection]:[{}]\n".format(selection))
            selection_list.append(selection)
            # import pdb;pdb.set_trace()
        
        final_plan = ""
        answer_without_plan = self.build_final_answer(plan_list,answer_list,query,final_plan)
        if not self.final_answer_check(query , answer_without_plan):
            print("skip: Q: {} A: {}".format(query, answer_without_plan))
            return   
        print("[final_answer_without_plan]:[{}]".format(answer_without_plan))
    

        api_set["query"] = query
        api_set["golden_api_list"] = api_list
        api_set["plan_list"] = plan_list
        api_set["final_plan"] = final_plan
        api_set["answer_list"] = answer_list
        api_set["final_answer"] = answer_without_plan
        api_set["selection_list"] = selection_list
        
        # import pdb;pdb.set_trace()
        fw = self._get_out_fw()
        with self._lock:
            fw.writelines(json.dumps(api_set) + "\n") 
        
        # self._multiple_format(api_set)
    
    def build_query(self, api_list):
        instruction = self.template["multiple_query"]
        messages = []
        for i, item in enumerate(self.examples):
            function_call_chain = []
            for call in item["golden"]:
                function_calling = {
                    "name": call["name"],
                    "description": call["description"],
                    "parameter_setting":call["parameters"],
                    "call_parameter":call["call_parameter"]
                }
                function_call_chain.append(function_calling)
            user_info = instruction.format(function_calling=function_call_chain)
            assistant_info = "User Query:{}".format(item["query"])
            messages.append({"content": user_info,  "role": "user"})
            messages.append({"content": assistant_info,  "role": "assistant"})

        function_calling_chain = []
        for api in api_list:
            function_calling = {
                    "name": api_util.change_name(api_util.standardize(api['api_name'])),
                    "description": api["api_description"],
                    "parameter_setting":api["parameters"],
                    "call_parameter":api["call_parameter"]
                }
            function_calling_chain.append(function_calling)
        user_info = instruction.format(function_calling=function_calling_chain)
        messages.append({"content": user_info,  "role": "user"})
        result = self.policy(messages).replace("\\","")

        if "User Query:" in result:
            start_index = result.find("User Query:") + len("User Query:")
            result = result[start_index:]
        
        return result
    
    
    def build_plan(self, target_query, target_api_list):
        instruction = self.template["multiple_plan"]
        messages = []
        for i, item in enumerate(self.examples):
            query = item["query"]
            function_info = []

            for a in item["golden"]:
                function_info.append({
                    "name":a["name"],
                    "description":a["description"],
                    "call_parameter":a["call_parameter"]
                })
            user_info = instruction.format(query = query,function_info = function_info,call_num = len(item["golden"]))
            assistant_info = "[Task Plan]:{}".format(item["tool_planning"])
            messages.append({"content": user_info,  "role": "user"})
            messages.append({"content": assistant_info,  "role": "assistant"})
             
        function_info = []
        for api in target_api_list:
            function_info.append({
                    "name":api_util.change_name(api_util.standardize(api['api_name'])),
                    "description":api["api_description"],
                    "call_parameter":api["call_parameter"]
                })
        user_info = instruction.format(query = target_query,function_info= function_info,call_num = len(target_api_list))
        messages.append({"content": user_info,  "role": "user"})
        result = self.policy(messages)

        # import pdb;pdb.set_trace()
        result = result.replace("\\", "")
        if "[Task Plan]:" in result:
            start_index = result.find("[Task Plan]:") + len("[Task Plan]:")
            result = result[start_index:]
        result = result.strip("[]")
        plan_list = []
        for plan in result.split('\n'):
            if len(plan) and plan[0].isdigit():
                plan_list.append(plan)
        if len(plan_list) != len(target_api_list):
            return []
        return plan_list

    def build_final_plan(self,query,plan_list):
        messages = []
        instruction = self.template["final_plan"]
        for i, item in enumerate(self.examples):
            user_info = "{}\n***\n[User Query]:{}\n***\n[Task Plan list]:{}".format(instruction, item["query"],item["tool_planning"])
            assistant_info = "[Final plan]:{}".format(item["final_plan"])
            messages.append({"content": user_info,  "role": "user"})
            messages.append({"content": assistant_info,  "role": "assistant"})
        messages.append({"content":"{}\n***\n[User Query]:{}\n***\n[Task Plan list]:{}".format(instruction,query,plan_list) ,  "role": "user"})
        final_plan = self.policy(messages)
        if 'Final plan:' in final_plan:
            start_index = final_plan.find('Final plan:') + len('Final plan:')
            final_plan = final_plan[start_index:].strip()
        if "[" in result and "]" in result:
            start_index = result.find("[") + 1
            end_index = result.find("]")
            result = result[start_index:end_index]
        # import pdb;pdb.set_trace()
        return final_plan
    
    
    