from utils.llm.chain import Chain

class Query2FOL(Chain):
    """Chain whose parser extracts the text after "following FOL query:"."""

    def __init__(self, prompts_dir, llm_name):
        """Initialise the underlying ``Chain``.

        Args:
            prompts_dir: Stringified and forwarded as the first argument of
                ``Chain.__init__``. NOTE(review): despite the name, no file
                name is appended here -- presumably callers pass the full
                prompt path; confirm against callers.
            llm_name: Forwarded unchanged to ``Chain.__init__``.
        """
        prompt_f = f"{prompts_dir}"
        super().__init__(prompt_f, llm_name)

    def parse_response(self, response):
        """Extract the answer that follows the FOL-query marker.

        Args:
            response: Raw response to parse.

        Returns:
            A ``(final_answer, response)`` tuple. ``final_answer`` is the
            stripped text found after the first "following FOL query:"
            marker (up to the next occurrence of the marker, if any). When
            the marker is missing or ``response`` is not a string, the raw
            ``response`` is returned unchanged as ``final_answer``.
        """
        marker = "following FOL query:"
        try:
            final_answer = response.strip().split(marker)[1].strip()
        except (AttributeError, IndexError, TypeError):
            # Best-effort fallback: marker absent (IndexError) or response is
            # not a str (AttributeError/TypeError). Deliberately narrower than
            # a bare ``except`` so KeyboardInterrupt/SystemExit propagate.
            final_answer = response
        return final_answer, response

class GetTypeAxiom(Chain):
    """Chain whose parser extracts the text after "First-Order Logic Statement:"."""

    def __init__(self, prompts_dir, llm_name):
        """Initialise the underlying ``Chain``.

        Args:
            prompts_dir: Stringified and forwarded as the first argument of
                ``Chain.__init__``. NOTE(review): despite the name, no file
                name is appended here -- presumably callers pass the full
                prompt path; confirm against callers.
            llm_name: Forwarded unchanged to ``Chain.__init__``.
        """
        prompt_f = f"{prompts_dir}"
        super().__init__(prompt_f, llm_name)

    def parse_response(self, response):
        """Extract the statement that follows the FOL-statement marker.

        Args:
            response: Raw response to parse.

        Returns:
            A ``(final_answer, response)`` tuple. ``final_answer`` is the
            stripped text found after the first
            "First-Order Logic Statement:" marker (up to the next occurrence
            of the marker, if any). When the marker is missing or
            ``response`` is not a string, the raw ``response`` is returned
            unchanged as ``final_answer``.
        """
        marker = "First-Order Logic Statement:"
        try:
            final_answer = response.strip().split(marker)[1].strip()
        except (AttributeError, IndexError, TypeError):
            # Best-effort fallback: marker absent (IndexError) or response is
            # not a str (AttributeError/TypeError). Deliberately narrower than
            # a bare ``except`` so KeyboardInterrupt/SystemExit propagate.
            final_answer = response
        return final_answer, response
    


class GetMonolithicProof(Chain):
    """Chain that reduces a proof-style response to a true/false verdict."""

    def __init__(self, llm_name, dataset_name, prompts_dir):
        """Initialise the chain with a dataset-specific prompt file.

        Args:
            llm_name: Forwarded unchanged to ``Chain.__init__``.
            dataset_name: Dataset identifier; selects the prompt file and is
                stored on the instance for use by ``parse_response``.
            prompts_dir: Directory containing
                ``GetMonolithicProof-<dataset_name>.yaml``.
        """
        prompt_f = f"{prompts_dir}/GetMonolithicProof-{dataset_name}.yaml"
        self.dataset_name = dataset_name
        super().__init__(prompt_f, llm_name)

    def parse_response(self, input_response):
        """Extract the conclusion of a response and classify it.

        The conclusion is the lower-cased text from the first "therefore,"
        up to (not including) the first "<|eot_". A missing start marker
        means "from the beginning"; a missing end marker means "to the end".

        Args:
            input_response: Raw response string.

        Returns:
            A ``(final_answer, input_response)`` tuple. For the
            'Recipe-MPR' dataset ``final_answer`` is the lower-cased
            conclusion text itself. Otherwise it is 'true' or 'false' when
            exactly one of those words occurs in the conclusion, and
            'misformatted answer' when both or neither occur.
        """
        lowered = input_response.lower()
        start = lowered.find("therefore,")
        end = lowered.find("<|eot_")
        # str.find returns -1 when the marker is absent. Using -1 directly as
        # a slice bound would drop the final character (missing end marker)
        # or yield an almost-empty slice (missing start marker), so map the
        # sentinel to the corresponding edge of the string instead.
        if start == -1:
            start = 0
        if end == -1:
            end = len(lowered)
        conclusion = lowered[start:end]

        if self.dataset_name == 'Recipe-MPR':
            return conclusion, input_response

        has_true = 'true' in conclusion
        has_false = 'false' in conclusion
        if has_false and not has_true:
            final_answer = 'false'
        elif has_true and not has_false:
            final_answer = 'true'
        else:
            final_answer = 'misformatted answer'
        return final_answer, input_response
    
class GetMonolithicProofRAG(Chain):
    """RAG-prompt chain that reduces a proof-style response to a true/false verdict."""

    def __init__(self, llm_name, dataset_name, prompts_dir):
        """Initialise the chain with a dataset-specific prompt file.

        Args:
            llm_name: Forwarded unchanged to ``Chain.__init__``.
            dataset_name: Dataset identifier used to select the prompt file.
            prompts_dir: Directory containing
                ``GetMonolithicProofRAG-<dataset_name>.yaml``.
        """
        prompt_f = f"{prompts_dir}/GetMonolithicProofRAG-{dataset_name}.yaml"
        super().__init__(prompt_f, llm_name)

    def parse_response(self, input_response):
        """Extract the conclusion of a response and classify it.

        The conclusion is the lower-cased text from the first "therefore,"
        up to (not including) the first "<|eot_". A missing start marker
        means "from the beginning"; a missing end marker means "to the end".

        Args:
            input_response: Raw response string.

        Returns:
            A ``(final_answer, input_response)`` tuple where ``final_answer``
            is 'true' or 'false' when exactly one of those words occurs in
            the conclusion, and 'misformatted answer' when both or neither
            occur.
        """
        lowered = input_response.lower()
        start = lowered.find("therefore,")
        end = lowered.find("<|eot_")
        # str.find returns -1 when the marker is absent. Using -1 directly as
        # a slice bound would drop the final character (missing end marker)
        # or yield an almost-empty slice (missing start marker), so map the
        # sentinel to the corresponding edge of the string instead.
        if start == -1:
            start = 0
        if end == -1:
            end = len(lowered)
        conclusion = lowered[start:end]

        has_true = 'true' in conclusion
        has_false = 'false' in conclusion
        if has_false and not has_true:
            final_answer = 'false'
        elif has_true and not has_false:
            final_answer = 'true'
        else:
            final_answer = 'misformatted answer'
        return final_answer, input_response


class GetFewshotProof(Chain):
    """Few-shot chain that reduces a proof-style response to a true/false verdict."""

    def __init__(self, llm_name, dataset_name, prompts_dir):
        """Initialise the chain with a dataset-specific prompt file.

        Args:
            llm_name: Forwarded unchanged to ``Chain.__init__``.
            dataset_name: Dataset identifier; selects the prompt file and is
                stored on the instance for use by ``parse_response``.
            prompts_dir: Directory containing
                ``GetFewshotProof-<dataset_name>.yaml``.
        """
        self.dataset_name = dataset_name
        prompt_f = f"{prompts_dir}/GetFewshotProof-{dataset_name}.yaml"
        super().__init__(prompt_f, llm_name)

    def parse_response(self, input_response):
        """Extract the conclusion of a response and classify it.

        The conclusion is the lower-cased text from the first "therefore,"
        up to (not including) the first "<|eot_". A missing start marker
        means "from the beginning"; a missing end marker means "to the end".

        Args:
            input_response: Raw response string.

        Returns:
            A ``(final_answer, input_response)`` tuple. For the
            'Recipe-MPR' dataset ``final_answer`` is the lower-cased
            conclusion text itself. Otherwise it is 'true' or 'false' when
            exactly one of those words occurs in the conclusion, and
            'misformatted answer' when both or neither occur.
        """
        lowered = input_response.lower()
        start = lowered.find("therefore,")
        end = lowered.find("<|eot_")
        # str.find returns -1 when the marker is absent. Using -1 directly as
        # a slice bound would drop the final character (missing end marker)
        # or yield an almost-empty slice (missing start marker), so map the
        # sentinel to the corresponding edge of the string instead.
        if start == -1:
            start = 0
        if end == -1:
            end = len(lowered)
        conclusion = lowered[start:end]

        if self.dataset_name == 'Recipe-MPR':
            return conclusion, input_response

        has_true = 'true' in conclusion
        has_false = 'false' in conclusion
        if has_false and not has_true:
            final_answer = 'false'
        elif has_true and not has_false:
            final_answer = 'true'
        else:
            final_answer = 'misformatted answer'
        return final_answer, input_response
    
class GetFewshotProofRAG(Chain):
    """Few-shot RAG chain that reduces a proof-style response to a true/false verdict."""

    def __init__(self, llm_name, dataset_name, prompts_dir):
        """Initialise the chain with a dataset-specific prompt file.

        Args:
            llm_name: Forwarded unchanged to ``Chain.__init__``.
            dataset_name: Dataset identifier used to select the prompt file.
            prompts_dir: Directory containing
                ``GetFewShotProofRAG-<dataset_name>.yaml``.
        """
        # NOTE(review): the file name uses "FewShot" (capital S) while the
        # class name uses "Fewshot" -- confirm this matches the file on disk,
        # especially on case-sensitive file systems.
        prompt_f = f"{prompts_dir}/GetFewShotProofRAG-{dataset_name}.yaml"
        super().__init__(prompt_f, llm_name)

    def parse_response(self, input_response):
        """Extract the conclusion of a response and classify it.

        The conclusion is the lower-cased text from the first "therefore,"
        up to (not including) the first "<|eot_". A missing start marker
        means "from the beginning"; a missing end marker means "to the end".

        Args:
            input_response: Raw response string.

        Returns:
            A ``(final_answer, input_response)`` tuple where ``final_answer``
            is 'true' or 'false' when exactly one of those words occurs in
            the conclusion, and 'misformatted answer' when both or neither
            occur.
        """
        lowered = input_response.lower()
        start = lowered.find("therefore,")
        end = lowered.find("<|eot_")
        # str.find returns -1 when the marker is absent. Using -1 directly as
        # a slice bound would drop the final character (missing end marker)
        # or yield an almost-empty slice (missing start marker), so map the
        # sentinel to the corresponding edge of the string instead.
        if start == -1:
            start = 0
        if end == -1:
            end = len(lowered)
        conclusion = lowered[start:end]

        has_true = 'true' in conclusion
        has_false = 'false' in conclusion
        if has_false and not has_true:
            final_answer = 'false'
        elif has_true and not has_false:
            final_answer = 'true'
        else:
            final_answer = 'misformatted answer'
        return final_answer, input_response