from abc import ABC
from abc import abstractmethod
from enum import Enum
from utils import api_util
import threading


class Action(Enum):
    """Enumeration of action names, each mapped to a lowercase string value.

    Only names and values are defined here; how each action is dispatched
    is decided by code outside this class — NOTE(review): confirm against
    the caller that selects a generator by Action.
    """

    Single = "single"
    Single_java = "single_java"
    Parallel = "parallel"
    Multiple = "multiple"
    MultipleParallel = "multiple_parallel"
    Format = "format"

class Generator(ABC):
    """Abstract base for generators that carry a ``profile``.

    Concrete subclasses must override both ``run`` and ``load_generator``.
    The base implementations are not meant to be relied on; they simply
    return ``NotImplemented`` if reached through ``super()``.
    """

    def __init__(self, profile):
        """Keep *profile* on the instance as ``self.profile``."""
        self.profile = profile

    @abstractmethod
    def load_generator(self, *args, **kwargs):
        """Abstract hook; subclasses must provide an implementation."""
        return NotImplemented

    @abstractmethod
    def run(self, *args, **kwargs):
        """Abstract entry point; subclasses must provide an implementation."""
        return NotImplemented

class BaseGenerator(Generator):
    """Shared prompt-building and reply-checking logic for generators.

    Wraps a ``policy`` callable and builds prompts from ``self.template``
    (a mapping of named format strings) and ``self.examples`` (a list of
    few-shot example dicts). Neither attribute is assigned in ``__init__``;
    they must be set on the instance (presumably by a subclass, e.g. in
    ``load_generator``) before any ``build_*`` or ``*_check`` method is
    called — TODO confirm against the concrete subclasses.
    """

    def __init__(self, profile, env, policy):
        """Store the environment and the policy callable.

        Args:
            profile: Forwarded to ``Generator.__init__`` (``self.profile``).
            env: Stored as ``self.env``; not used by methods of this class.
            policy: Callable invoked with either a list of chat messages
                (``{"content": ..., "role": ...}`` dicts) or a plain prompt
                string; its return value is used as a ``str``.
        """
        super().__init__(profile)

        self.env = env
        self.policy = policy
        # Not used by this class; presumably guards shared state in
        # subclasses that generate in parallel — verify.
        self._lock = threading.Lock()

    def run(self, *args, **kwargs):
        """No-op placeholder; returns None."""

    def load_generator(self, *args, **kwargs):
        """No-op placeholder; returns None."""

    @staticmethod
    def _extract_answer(text):
        """Strip leading answer markers from a policy reply.

        First cuts everything up to and including the first ``"[Answer]:"``,
        then (on the remainder) up to and including the first ``"Answer:"``;
        each cut is followed by ``strip()``. Text without either marker is
        returned unchanged.
        """
        # NOTE(review): the second pass also fires when the answer body itself
        # contains "Answer:" (e.g. "[Answer]: Final Answer: 42" -> "42").
        for marker in ("[Answer]:", "Answer:"):
            if marker in text:
                text = text[text.find(marker) + len(marker):].strip()
        return text

    @staticmethod
    def _is_affirmative(result):
        """Return True if *result* contains "YES" or "Yes".

        Otherwise print the reply (for debugging) and return False.
        """
        if "YES" in result or "Yes" in result:
            return True
        print(result)
        return False

    def build_answer(self, query, function_info):
        """Ask the policy to answer *query* from a single API call's info.

        Builds a few-shot chat from ``self.examples`` (each example's first
        ``golden`` call and its ``single_answer``), appends the real query,
        and returns the reply with backslashes removed and answer markers
        stripped.

        Args:
            query: The user query to answer.
            function_info: API information substituted into the
                ``single_answer`` template.

        Returns:
            The extracted answer string.
        """
        instruction = self.template["single_answer"]

        messages = []
        for example in self.examples:
            golden = example["golden"][0]
            sample_api_info = {
                "name": golden["name"],
                "description": golden["description"],
                "call_parameter": golden["call_parameter"],
                "call_response": golden["call_response"],
            }
            user_info = instruction.format(query=example["query"], function_info=sample_api_info)
            assistant_info = "[Answer]:{}".format(example["single_answer"])
            messages.append({"content": user_info, "role": "user"})
            messages.append({"content": assistant_info, "role": "assistant"})

        user_info = instruction.format(query=query, function_info=function_info)
        messages.append({"content": user_info, "role": "user"})
        result = self.policy(messages).replace("\\", "")
        return self._extract_answer(result)

    def build_selection(self, target_plan, api_name, api_description):
        """Ask the policy why *api_name* suits *target_plan*.

        Args:
            target_plan: The plan/subtask text substituted into the prompt.
            api_name: Raw API name; normalized via ``api_util.standardize``
                then ``api_util.change_name``.
            api_description: Description substituted into the prompt.

        Returns:
            ``{"ID": 0, "api_name": <normalized name>, "reason": <text>}``.
            The reason is the reply after ``"[Selection reason]:"`` (if
            present), narrowed to the text between the first ``[`` and the
            last ``]`` when both appear in that order.
        """
        normalized_name = api_util.change_name(api_util.standardize(api_name))
        api_info = {
            "api_name": normalized_name,
            "description": api_description,
        }
        prompt = self.template["selection_reason"].format(function=api_info, plan=target_plan)
        result = self.policy(prompt).replace("\\", "")

        marker = "[Selection reason]:"
        if marker in result:
            result = result[result.find(marker) + len(marker):]

        # Only narrow to the bracketed span when "]" follows "["; otherwise
        # the slice would be empty and the reason silently lost.
        open_idx = result.find("[")
        close_idx = result.rfind("]")
        if open_idx != -1 and close_idx > open_idx:
            result = result[open_idx + 1:close_idx]
        return {"ID": 0, "api_name": normalized_name, "reason": result}

    def build_final_answer(self, plan_list, answer_list, query, final_plan):
        """Ask the policy to compose a final answer from subtask answers.

        Args:
            plan_list: Subtask descriptions, in order.
            answer_list: Answers aligned by index with *plan_list*; must be
                at least as long (an ``IndexError`` is raised otherwise).
            query: The original user query.
            final_plan: Accepted for interface compatibility; not used.

        Returns:
            The reply with backslashes removed and answer markers stripped.
        """
        instruction = self.template["final_answer"]
        messages = []
        for example in self.examples:
            user_info = instruction.format(query=example["query"], context=example["answer_context"])
            assistant_info = "[Answer]:{}".format(example["final_answer"])
            messages.append({"content": user_info, "role": "user"})
            messages.append({"content": assistant_info, "role": "assistant"})

        # answer_list is indexed by plan position (not zipped) so a length
        # mismatch still raises instead of silently truncating.
        context = "".join(
            "Subtask{idx}: {task}\nSubanswer{idx}: {summary}\n".format(
                idx=idx, task=task, summary=answer_list[idx]
            )
            for idx, task in enumerate(plan_list)
        )
        user_info = instruction.format(query=query, context=context)
        messages.append({"content": user_info, "role": "user"})
        answer = self.policy(messages).replace("\\", "")
        return self._extract_answer(answer)

    def query_check(self, query, parameter):
        """Ask the policy whether *parameter* fits *query*; True on YES/Yes."""
        prompt = self.template["query_check"].format(Query=query, parameters=parameter)
        return self._is_affirmative(self.policy(prompt))

    def answer_check(self, query, api_resp, answer):
        """Ask the policy whether *answer* is supported by *api_resp*; True on YES/Yes."""
        prompt = self.template["response_check"].format(query=query, function_resp=api_resp, answer=answer)
        return self._is_affirmative(self.policy(prompt))

    def final_answer_check(self, query, answer):
        """Ask the policy whether *answer* addresses *query*; True on YES/Yes."""
        prompt = self.template["final_answer_check"].format(query=query, answer=answer)
        return self._is_affirmative(self.policy(prompt))
