from abc import ABC, abstractmethod
from src.prompt.cot_prompt import cot_system_prompt, cot_gsm8k_shots, cot_algebra_shots, cot_math_shots, cot_folio_shots
from src.prompt.pal_prompt import pal_system_prompt, pal_gsm8k_shots, pal_algebra_shots, pal_math_shots, pal_folio_shots
from src.prompt.codenl_prompt import codenl_math_shots, codenl_gsm8k_shots, codenl_algebra_shots, codenl_folio_shots
from src.prompt.nlcode_prompt import nlcode_math_shots, nlcode_gsm8k_shots, nlcode_algebra_shots, nlcode_folio_shots
from src.prompt.metamath_prompt import incmath_system_prompt, incmath_prompt
from src.prompt.codenl_one_round import codenl_one_round_system_prompt, codenl_one_round_base_prompt, codenl_one_round_shots_without_output, codenl_one_round_shots_with_output
from src.prompt.nlcode_one_round import nlcode_one_round_system_prompt, nlcode_one_round_base_prompt, nlcode_one_round_shots
import random
from openai import OpenAI
from typing import List, Union
from utils.parser_utils import extract_answer, extract_program
import os
from dotenv import load_dotenv
from utils.python_executor import PythonExecutor
from utils.logical_prover import evaluate_prover9
from utils.utils import read_jsonl, write_jsonl
from together import Together

# Load environment variables (e.g. OPENAI_API_KEY, OPENAI_ORG_ID, TOGETHER_API_KEY,
# read in Reasoning.__init__) from a .env file; override=True makes the .env values
# take precedence over variables already set in the process environment.
load_dotenv(override=True)

class Reasoning(ABC):
    """
    Base class for all reasoning strategies.

    Holds the chat-completion client plus the model configuration shared by
    every concrete strategy; subclasses implement ``reason``.
    """

    def __init__(self, model_name, zero_shot=False):
        """
        Args:
            model_name: str, model identifier such as "gpt-3.5-turbo". Names
                containing "gpt" are served through the OpenAI client, every
                other name through the Together client.
            zero_shot: bool, whether prompts are built without few-shot examples
        """
        uses_openai = 'gpt' in model_name
        if uses_openai:
            self.client = OpenAI(
                organization=os.getenv('OPENAI_ORG_ID'),
                api_key=os.getenv('OPENAI_API_KEY'),
            )
        else:
            self.client = Together(api_key=os.getenv('TOGETHER_API_KEY'))
        self.model_name = model_name
        self.zero_shot = zero_shot

    @abstractmethod
    def reason(self, question):
        """
        Answer ``question``.

        Args:
            question: str, the question to be answered
        Returns:
            answer: str, the extracted answer
            reasoning_path: str, the raw model output the answer was taken from
        """
        raise NotImplementedError


class OneRoundReasoning(Reasoning):
    """
    One round reasoning, call the client once to get the answer.

    Subclasses must set ``self.system_prompt`` and implement
    ``construct_prompt`` and ``get_answer``.
    """

    def reason(self, question, question_type, temperature=0, max_tokens=2048, verbose=False, num_shots=1):
        """
        Get the answer to the question
        Args:
            question: str (a dict with 'premises'/'conclusion' for FOLIO), the question to be answered
            question_type: str, the type of the question, like "math", "GSM8K"
            temperature: float, the temperature for the completion
            max_tokens: int, the maximum number of tokens to generate
            verbose: bool, whether to print the prompt
            num_shots: int, number of completions to request (sent as ``n``)
        Returns:
            If num_shots == 1: (answer, reasoning_path).
            Otherwise: (list of answers, list of reasoning paths), one per
            returned choice.
            If the API call fails: ('', '') regardless of num_shots.
        """
        prompt = self.construct_prompt(question, question_type)
        if verbose:
            print(f"Prompt: {prompt}")
        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": prompt},
                ],
                temperature=temperature,
                max_tokens=max_tokens,
                n=num_shots,
            )
        except Exception as e:
            # Best-effort: an API failure yields empty results so a batch run can continue.
            print(f"Error when calling the API: {e}")
            return '', ''

        # Iterate over the choices actually returned instead of assuming exactly
        # `num_shots` of them; a provider may return fewer than requested.
        reasoning_paths = [choice.message.content for choice in completion.choices]
        answers = [self.get_answer(reasoning_path) for reasoning_path in reasoning_paths]
        if num_shots == 1:
            return answers[0], reasoning_paths[0]
        return answers, reasoning_paths

    @abstractmethod
    def construct_prompt(self, question, question_type):
        """
        Construct the prompt for the question
        Args:
            question: str, the question to be answered
            question_type: str, the type of the question, like "math", "GSM8K"
        Returns:
            prompt: str, the prompt for the question
        """
        raise NotImplementedError

    @abstractmethod
    def get_answer(self, reasoning_path):
        """
        Extract the answer from the completion
        Args:
            reasoning_path: str, the reasoning path from the OpenAI API
        Returns:
            answer: str, the answer to the question
        """
        raise NotImplementedError

class COT(OneRoundReasoning):
    """
    Chain of thought reasoning: the model reasons step by step in natural
    language and the answer is parsed from its text.
    """

    # MATH subsets that share the generic MATH few-shot examples.
    _MATH_SUBSETS = ('algebra', 'prealgebra', 'intermediate algebra', 'precalculus',
                     'geometry', 'number theory', 'counting & probability')

    def __init__(self, model_name, zero_shot=False):
        """
        Args:
            model_name: str, the model to query
            zero_shot: bool, whether to omit the few-shot examples (defaults to
                False like the other reasoning classes)
        """
        super().__init__(model_name, zero_shot)
        self.system_prompt = cot_system_prompt
        self.base_prompt = "Please think step by step. "

        # FOLIO instruction
        self.folio_instruction = """The following is a first-order logic (FOL) problem.
The problem is to determine whether the conclusion follows from the premises.
The premises are given in the form of a set of first-order logic sentences.
The conclusion is given in the form of a single first-order logic sentence.
The task is to evaluate the conclusion as 'True', 'False', or 'Uncertain' given the premises.
"""

        if not self.zero_shot:
            self.base_prompt += "Here are some examples: \n"

    def construct_prompt(self, question, question_type):
        """
        Construct the prompt for the question
        Args:
            question: str, the question to be answered; for FOLIO a dict with
                'premises' and 'conclusion' keys
            question_type: str, the type of the question, like "math", "GSM8K"
                (case-insensitive)
        Returns:
            prompt: str, the prompt for the question
        Raises:
            ValueError: if few-shot prompting is requested for an unsupported type
        """
        qtype = question_type.lower()
        if qtype == "folio":
            # FOLIO questions are structured; they always need the task instruction and
            # the premises/conclusion layout, with examples only in few-shot mode.
            shots = '' if self.zero_shot else '-----'.join(cot_folio_shots)
            return (self.folio_instruction
                    + self.base_prompt
                    + shots
                    + f"<PREMISES>\n{question['premises']}\n</PREMISES>\n<CONCLUSION>{question['conclusion']}</CONCLUSION><EVALUATE>")

        if self.zero_shot:  # zero-shot prompting, don't use the examples
            return self.base_prompt + f"Question: {question}\nAnswer:"

        if qtype == "gsm8k":
            shots = cot_gsm8k_shots
        elif qtype == "aime":  # left for future use
            shots = cot_algebra_shots[:4]
        elif qtype in self._MATH_SUBSETS:
            shots = cot_math_shots[:4]
        else:
            raise ValueError(f"Invalid question type: {question_type}")
        return self.base_prompt + '-----'.join(shots) + f'Question: {question}\nAnswer:'

    def get_answer(self, reasoning_path):
        """
        Extract the answer from the completion
        Args:
            reasoning_path: str, the reasoning path from the OpenAI API
        Returns:
            answer: str, the answer to the question, as parsed by extract_answer
        """
        return extract_answer(reasoning_path)

class PAL(OneRoundReasoning):
    """
    Program aided language model: the model writes a program (Python, or FOL
    expressions for FOLIO) and the answer is obtained by executing it.
    """

    # MATH subsets that share the generic MATH few-shot examples.
    _MATH_SUBSETS = ('algebra', 'prealgebra', 'intermediate algebra', 'precalculus',
                     'geometry', 'number theory', 'counting & probability')

    def __init__(self, model_name, zero_shot=False):
        super().__init__(model_name, zero_shot)
        self.system_prompt = pal_system_prompt
        self.base_prompt = "Let's use python to solve the math problem. You will only write code block. "

        # FOLIO instruction
        self.folio_instruction = """The following is a first-order logic (FOL) problem.
The problem is to determine whether the conclusion follows from the premises.
The premises are given in the form of a set of first-order logic sentences.
The conclusion is given in the form of a single first-order logic sentence.
The task is to translate each of the premises and conclusions into FOL expressions, so that the expressions can be evaluated by a theorem solver to determine whether the conclusion follows from the premises.
Expressions should be adhere to the format of the Python NLTK package logic module. Follow the format of examples shown below to format your output. 
"""

        if not zero_shot:
            self.base_prompt += "Here are some examples: \n"

    def construct_prompt(self, question, question_type):
        """
        Construct the prompt for the question
        Args:
            question: str, the question to be answered; for FOLIO a dict with
                'premises' and 'conclusion' keys
            question_type: str, the type of the question, like "math", "GSM8K"
                (case-insensitive)
        Returns:
            prompt: str, the prompt for the question
        Raises:
            ValueError: if few-shot prompting is requested for an unsupported type
        """
        qtype = question_type.lower()
        if qtype == "folio":
            # FOLIO always uses its own instruction (not the Python base prompt) and the
            # structured premises/conclusion layout; examples only in few-shot mode.
            shots = '' if self.zero_shot else '-----'.join(pal_folio_shots)
            return (self.folio_instruction
                    + shots
                    + f"<PREMISES>\n{question['premises']}\n</PREMISES>\n<CONCLUSION>{question['conclusion']}</CONCLUSION><EVALUATE>")

        if self.zero_shot:  # zero-shot prompting, don't use the examples
            return self.base_prompt + f"Question: {question}\n"

        if qtype == "gsm8k":
            shots = pal_gsm8k_shots
        elif qtype == "aime":  # left for future use
            shots = pal_algebra_shots[:4]
        elif qtype in self._MATH_SUBSETS:
            shots = pal_math_shots[:4]
        else:
            raise ValueError(f"Invalid question type: {question_type}")
        return self.base_prompt + '-----'.join(shots) + f"Question: {question}\n"

    def get_answer(self, reasoning_path):
        """
        Extract the program from the completion and execute it.
        Args:
            reasoning_path: str, the reasoning path from the OpenAI API
        Returns:
            answer: the executor's prediction for a Python program, the prover
                result for a FOL program, or '' when no runnable program was found
        """
        code, token = extract_program(reasoning_path)
        if token == "PYTHON":
            executor = PythonExecutor(get_answer_expr="solution()")
            prediction, _report = executor.apply(code)
            return prediction
        if token == "FOL" and code:
            # The last extracted expression is the conclusion, the rest are premises.
            premises, conclusion = code[:-1], code[-1]
            return evaluate_prover9(premises=premises, conclusion=conclusion)
        # No program we know how to run: return an empty answer instead of failing
        # with an unbound local / IndexError.
        return ''


class CodeNL_OneRound(OneRoundReasoning):
    """
    Single-call code-then-language reasoning: the model first writes a Python
    solution, then analyzes the problem step by step in natural language using
    that code and its executed results to reach the final answer.
    """

    def __init__(self, model_name, zero_shot=False):
        super().__init__(model_name, zero_shot)
        self.system_prompt = codenl_one_round_system_prompt
        examples_header = "" if zero_shot else "Here are some examples: \n"
        self.base_prompt = codenl_one_round_base_prompt + examples_header

    def construct_prompt(self, question, question_type):
        """
        Build the user prompt for ``question``.

        Args:
            question: str, the question to be answered
            question_type: str, e.g. "aime" or a MATH subset name (case-insensitive)
        Returns:
            prompt: str, base prompt + (few-shot examples) + the question
        Raises:
            ValueError: in few-shot mode, for a type without examples (e.g. "gsm8k")
        """
        if self.zero_shot:
            return f"{self.base_prompt}Question: {question}\nAnswer:"

        # NOTE(review): every supported type, including "folio", reuses the same
        # first four examples; folio questions are embedded as-is. Confirm intended.
        supported_types = {
            'aime', 'folio',
            'algebra', 'prealgebra', 'intermediate algebra', 'precalculus',
            'geometry', 'number theory', 'counting & probability',
        }
        if question_type.lower() not in supported_types:
            raise ValueError(f"Invalid question type: {question_type}")

        examples = '-----'.join(codenl_one_round_shots_without_output[:4])
        return f"{self.base_prompt}{examples}Question: {question}\nAnswer:"

    def get_answer(self, reasoning_path):
        """
        Parse the final answer out of the model's text.

        Args:
            reasoning_path: str, the model completion
        Returns:
            answer: str, as returned by extract_answer
        """
        return extract_answer(reasoning_path)

class NLCode_OneRound(OneRoundReasoning):
    """
    Single-call language-then-code reasoning: the model first writes a detailed
    natural-language reasoning path, then translates it into Python code whose
    execution result is the answer.
    """

    def __init__(self, model_name, zero_shot=False):
        super().__init__(model_name, zero_shot)
        self.system_prompt = nlcode_one_round_system_prompt
        examples_header = "" if zero_shot else "Here are some examples: \n"
        self.base_prompt = nlcode_one_round_base_prompt + examples_header

    def construct_prompt(self, question, question_type):
        """
        Build the user prompt for ``question``.

        Args:
            question: str, the question to be answered
            question_type: str, "aime" or a MATH subset name (case-insensitive)
        Returns:
            prompt: str, base prompt + (few-shot examples) + the question
        Raises:
            ValueError: in few-shot mode, for any other question type
        """
        if self.zero_shot:
            return f"{self.base_prompt}Question: {question}\nAnswer:"

        supported_types = {
            'aime',
            'algebra', 'prealgebra', 'intermediate algebra', 'precalculus',
            'geometry', 'number theory', 'counting & probability',
        }
        if question_type.lower() not in supported_types:
            raise ValueError(f"Invalid question type: {question_type}")

        examples = '-----'.join(nlcode_one_round_shots[:4])
        return f"{self.base_prompt}{examples}Question: {question}\nAnswer:"

    def get_answer(self, reasoning_path):
        """
        Extract the program from the completion and run it.

        Args:
            reasoning_path: str, the model completion
        Returns:
            answer: the executor's prediction for the extracted program
                (the program-type token is not checked here)
        """
        runner = PythonExecutor(get_answer_expr="solution()")
        program, _token = extract_program(reasoning_path)
        result, _exec_report = runner.apply(program)
        return result



class TwoRoundsReasoning(Reasoning):
    """
    Two rounds reasoning. The first-round output is read from stored result
    files (see ``extract_first_round_results``) and the client is called once
    for the second round.

    Subclasses must set ``self.first_stage_answers`` and ``self.system_prompt_2``
    and implement the abstract prompt/answer methods.
    """

    # MATH subsets whose first-round results are stored per difficulty level.
    _MATH_SUBSETS = ('algebra', 'counting & probability', 'geometry', 'number theory',
                     'intermediate algebra', 'precalculus', 'prealgebra')

    def extract_first_round_results(self, first_round_method='cot', problem_type='algebra', hard_level='Level 5', is_train=False):
        """
        Load the stored first-round reasoning paths.
        Args:
            first_round_method: str, the method used in the first round (e.g. 'cot', 'pal');
                part of the result file name
            problem_type: str, a MATH subset name or "folio" (case-sensitive)
            hard_level: str, MATH difficulty level used in the file name, like 'Level 5'
            is_train: bool, whether to load the "_train" variant of the result file
        Returns:
            answers: dict mapping the question id ('idx' for MATH, 'example_id' for
                FOLIO) to the stored reasoning path. Empty when the file is missing
                or the problem type has no stored results.
        Raises:
            ValueError: for a MATH subset when the model name contains neither
                "gpt" nor "meta-llama".
        """
        answers = {}
        if problem_type in self._MATH_SUBSETS:
            if 'gpt' in self.model_name:
                file_path = f'new_results/gpt/{self.model_name}_{problem_type}_{hard_level}_{first_round_method}_greedy.jsonl'
            elif 'meta-llama' in self.model_name:
                file_path = f'new_results/{self.model_name}_{problem_type}_{hard_level}_{first_round_method}_greedy.jsonl'
            else:
                raise ValueError(f"Invalid model name: {self.model_name}")
            id_key = 'idx'
        elif problem_type == "folio":
            file_path = f"folio_results/{self.model_name}_folio_{first_round_method}.jsonl"
            id_key = 'example_id'
        else:
            # Previously fell through and returned None, which broke reason() later.
            print(f"No stored first-round results for problem type: {problem_type}")
            return answers

        if is_train:
            file_path = file_path.replace(".jsonl", "_train.jsonl")
        if not os.path.exists(file_path):
            # Return early for every problem type instead of trying to read a missing file.
            print(f"File {file_path} does not exist")
            return answers

        for row in read_jsonl(file_path):
            answers[row[id_key]] = row['reasoning_path']
        return answers

    def reason(self, question, question_type, idx, temperature=0, max_tokens=2048, verbose=False, num_shots=1):
        """
        Reason and get the answer to the question via two rounds of reasoning.
        Args:
            question: str, the question to be answered (a dict for FOLIO)
            question_type: str, the type of the question, like "math", "gsm8k"
            idx: the id of the question, used to look up its stored first-round result
            temperature: float, the temperature for the completion
            max_tokens: int, the maximum number of tokens to generate
            verbose: bool, whether to print the second-round prompt
            num_shots: int, sent to the API as ``n``; only the first choice is used
        Returns:
            (answer, reasoning_path_1, reasoning_path_2); ('', '', '') if the
            second-round API call fails.
        Raises:
            KeyError: if there is no stored first-round result for ``idx``.
        """
        # The first round is not re-run: its output comes from the stored results
        # that the subclass loaded into self.first_stage_answers.
        reasoning_path_1 = self.first_stage_answers[idx]

        ##------second round------##
        prompt2 = self.construct_prompt_second_round(question, question_type, reasoning_path_1)
        if verbose:
            print(f"Second round prompt: {prompt2}")
        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": self.system_prompt_2},
                    {"role": "user", "content": prompt2},
                ],
                temperature=temperature,
                max_tokens=max_tokens,
                n=num_shots,
            )
        except Exception as e:
            # Best-effort: an API failure yields empty results so a batch run can continue.
            print('Api error in second round call')
            print(f"Error: {e}")
            return '', '', ''

        # NOTE(review): n=num_shots is requested but only the first choice is used.
        reasoning_path_2 = completion.choices[0].message.content

        answer = self.get_answer(reasoning_path_2)
        return answer, reasoning_path_1, reasoning_path_2

    @abstractmethod
    def construct_prompt_first_round(self, question, question_type):
        """
        Construct the prompt for the first round
        Args:
            question: str, the question to be answered
            question_type: str, the type of the question, like "math", "GSM8K"
        Returns:
            prompt: str, the prompt for the question
        """
        raise NotImplementedError

    @abstractmethod
    def construct_prompt_second_round(self, question, question_type, reasoning_path_1):
        """
        Construct the prompt for the second round
        Args:
            question: str, the question to be answered
            question_type: str, the type of the question, like "math", "GSM8K"
            reasoning_path_1: str, the reasoning path from the first round
        Returns:
            prompt: str, the prompt for the question
        """
        raise NotImplementedError

    @abstractmethod
    def get_answer(self, reasoning_path):
        """
        Extract the answer from the completion
        Args:
            reasoning_path: str, the reasoning path from the OpenAI API
        Returns:
            answer: str, the answer to the question
        """
        raise NotImplementedError
    
class CodeNL(TwoRoundsReasoning):
    """
    First generate code (stored PAL results), then generate natural-language
    reasoning conditioned on that code and its execution output.
    """

    # MATH subsets that share the generic MATH few-shot examples.
    _MATH_SUBSETS = ('algebra', 'prealgebra', 'intermediate algebra', 'precalculus',
                     'geometry', 'number theory', 'counting & probability')

    def __init__(self, model_name, data_type, zero_shot=False, is_train=False, hard_level='Level 5'):
        super().__init__(model_name, zero_shot)
        self.system_prompt_1 = pal_system_prompt
        self.system_prompt_2 = cot_system_prompt

        self.base_prompt_1 = "Let's use python to solve the math problem. "
        self.base_prompt_2 = "Please think step by step about the question based on the provided code(it may be wrong) and the executed output. "

        self.folio_base_prompt_1 = "Translate each of the premises and conclusion into first-order logic (FOL) expressions, so that the expression can be evaluated by a theorm solver to determine whether the conclusion follows from the premises or not. Expressions should be adhere to the format of the Python NLTK package logic module."
        self.folio_base_prompt_2 = "Please think step by step to determine whether the conclusion follows from the premises. The task is to evaluate the conclusion as 'True', 'False', or 'Uncertain' given the premises."

        if not zero_shot:
            self.base_prompt_1 += "Here are some examples: \n"
            self.base_prompt_2 += "Here are some examples: \n"

        # First-round (code) outputs are loaded from stored PAL results.
        self.first_stage_answers = self.extract_first_round_results(first_round_method='pal', problem_type=data_type, is_train=is_train, hard_level=hard_level)

    def construct_prompt_first_round(self, question, question_type):
        """
        Construct the prompt for the first (code-writing) round.
        Args:
            question: str, the question to be answered; for FOLIO a dict with
                'premises' and 'conclusion' keys
            question_type: str, e.g. "gsm8k", "aime", a MATH subset or "folio"
                (case-insensitive)
        Returns:
            prompt: str, the prompt for the question
        Raises:
            ValueError: if few-shot prompting is requested for an unsupported type
        """
        qtype = question_type.lower()
        if qtype == "folio":
            shots = '' if self.zero_shot else "-----".join(pal_folio_shots)
            return (self.folio_base_prompt_1
                    + shots
                    + f"<PREMISES>\n{question['premises']}\n</PREMISES>\n<CONCLUSION>{question['conclusion']}</CONCLUSION><EVALUATE>")

        if self.zero_shot:  # zero-shot prompting, don't use the examples
            return self.base_prompt_1 + f"Question: {question}"

        # use four shots for the first round and four shots for the second round
        if qtype == "gsm8k":
            shots = pal_gsm8k_shots[:4]
        elif qtype == "aime":  # left for future use
            shots = pal_algebra_shots[:4]
        elif qtype in self._MATH_SUBSETS:
            shots = pal_math_shots[:4]
        else:
            raise ValueError(f"Invalid question type: {question_type}")
        return self.base_prompt_1 + '-----'.join(shots) + f"Question: {question}"

    def construct_prompt_second_round(self, question, question_type, reasoning_path_1):
        """
        Construct the prompt for the second (natural-language) round.

        The first-round program is executed and its result (or, for Python, the
        error report when execution did not finish with 'Done') is shown to the model.
        Args:
            question: str, the question to be answered; for FOLIO a dict with
                'premises' and 'conclusion' keys
            question_type: str, the type of the question, like "math", "GSM8K"
                (case-insensitive)
            reasoning_path_1: str, the reasoning path (code) from the first round
        Returns:
            prompt: str, the prompt for the question
        Raises:
            ValueError: if few-shot prompting is requested for an unsupported type
        """
        ### get the answer from the first round
        code, token = extract_program(reasoning_path_1)
        # Defaults cover an unrecognized program type, which previously raised UnboundLocalError.
        prediction, report = '', ''
        if token == "PYTHON":
            executor = PythonExecutor(get_answer_expr="solution()")
            prediction, report = executor.apply(code)
        elif token == "FOL" and code:
            # The last extracted expression is the conclusion, the rest are premises.
            premises, conclusion = code[:-1], code[-1]
            prediction = evaluate_prover9(premises=premises, conclusion=conclusion)
            report = 'Done'
        # Show the computed value when execution finished, otherwise the error report.
        output = prediction if report == 'Done' else report

        qtype = question_type.lower()
        if qtype == "folio":
            shots = '' if self.zero_shot else '-----'.join(codenl_folio_shots)
            return (self.folio_base_prompt_2
                    + shots
                    + f"<PREMISES>\n{question['premises']}\n</PREMISES>\n\n<CONCLUSION>{question['conclusion']}\n</CONCLUSION>\n<EVALUATE>\n{reasoning_path_1}\n\n Output: {prediction}")

        if self.zero_shot:  # zero-shot prompting, don't use the examples
            return self.base_prompt_2 + f"Question: {question}\n Code: {reasoning_path_1}\nOutput: {output}\nAnswer:"

        if qtype == "gsm8k":
            shots = codenl_gsm8k_shots[:4]
        elif qtype == "aime":
            shots = codenl_algebra_shots[:4]
        elif qtype in self._MATH_SUBSETS:
            shots = codenl_math_shots[:4]
        else:
            raise ValueError(f"Invalid question type: {question_type}")

        # f-string so a non-str prediction/report does not raise TypeError.
        return (self.base_prompt_2
                + '-----'.join(shots)
                + f"Question: {question}\nCode: {reasoning_path_1}\nOutput: {output}\nAnswer:")

    def get_answer(self, reasoning_path):
        """
        Extract the answer from the completion
        Args:
            reasoning_path: str, the reasoning path from the OpenAI API
        Returns:
            answer: str, the answer to the question, as parsed by extract_answer
        """
        return extract_answer(reasoning_path)
    

class NLCode(TwoRoundsReasoning):
    """
    First generate natural-language reasoning (stored CoT results), then
    translate it into code whose execution gives the answer.
    """

    # MATH subsets that share the generic MATH few-shot examples.
    _MATH_SUBSETS = ('algebra', 'prealgebra', 'intermediate algebra', 'precalculus',
                     'geometry', 'number theory', 'counting & probability')

    def __init__(self, model_name, data_type, zero_shot=False, is_train=False, hard_level='Level 5'):
        super().__init__(model_name, zero_shot)
        self.system_prompt_1 = cot_system_prompt
        self.system_prompt_2 = pal_system_prompt

        # Trailing space so appended text does not run into the sentence
        # (previously produced "question.Here are some examples").
        self.base_prompt_1 = "Please think step by step about the question. "
        self.base_prompt_2 = "Write a Python code that translates a natural language (NL) reasoning path into executable code to answer a given question. The output of the generated code should be the final answer to the question. "

        self.folio_base_prompt_1 = "Please think step by step to determine whether the conclusion follows from the premises. Figure out what predicates and relations are used in each premise."
        self.folio_base_prompt_2 = "Translate each of the premises and conclusion into first-order logic (FOL) expressions, so that the expression can be evaluated by a theorm solver to determine whether the conclusion follows from the premises or not. Expressions should be adhere to the format of the Python NLTK package logic module."

        if not zero_shot:
            self.base_prompt_1 += "Here are some examples: \n"
            self.base_prompt_2 += "Here are some examples: \n"

        # First-round (natural-language) outputs are loaded from stored CoT results.
        self.first_stage_answers = self.extract_first_round_results(first_round_method='cot', problem_type=data_type, is_train=is_train, hard_level=hard_level)

    def construct_prompt_first_round(self, question, question_type):
        """
        Construct the prompt for the first (natural-language) round.
        Args:
            question: str, the question to be answered; for FOLIO a dict with
                'premises' and 'conclusion' keys
            question_type: str, e.g. "gsm8k", "aime", a MATH subset or "folio"
                (case-insensitive)
        Returns:
            prompt: str, the prompt for the question
        Raises:
            ValueError: if few-shot prompting is requested for an unsupported type
        """
        qtype = question_type.lower()
        if qtype == "folio":
            shots = '' if self.zero_shot else "-----".join(cot_folio_shots)
            return (self.folio_base_prompt_1
                    + shots
                    + f"<PREMISES>\n{question['premises']}\n</PREMISES>\n<CONCLUSION>{question['conclusion']}</CONCLUSION><EVALUATE>")

        if self.zero_shot:  # zero-shot prompting, don't use the examples
            return self.base_prompt_1 + f"Question: {question}\nAnswer:"

        if qtype == "gsm8k":
            shots = cot_gsm8k_shots[:4]
        elif qtype == "aime":
            shots = cot_algebra_shots[:4]
        elif qtype in self._MATH_SUBSETS:
            shots = cot_math_shots[:4]
        else:
            raise ValueError(f"Invalid question type: {question_type}")
        return self.base_prompt_1 + '-----'.join(shots) + f"Question: {question}\nAnswer:"

    def construct_prompt_second_round(self, question, question_type, reasoning_path_1):
        """
        Construct the prompt for the second (code-writing) round.
        Args:
            question: str, the question to be answered; for FOLIO a dict with
                'premises' and 'conclusion' keys
            question_type: str, the type of the question, like "math", "GSM8K"
                (case-insensitive)
            reasoning_path_1: str, the reasoning path from the first round
        Returns:
            prompt: str, the prompt for the question
        Raises:
            ValueError: if few-shot prompting is requested for an unsupported type
        """
        qtype = question_type.lower()
        if qtype == "folio":
            shots = '' if self.zero_shot else '-----'.join(nlcode_folio_shots)
            return (self.folio_base_prompt_2
                    + shots
                    + f"<PREMISES>\n{question['premises']}\n</PREMISES>\n\n<CONCLUSION>{question['conclusion']}\n</CONCLUSION>\n\n{reasoning_path_1}\n\n<EVALUATE>")

        if self.zero_shot:  # zero-shot prompting, don't use the examples
            return self.base_prompt_2 + f"Question: {question}\nReasoning path: {reasoning_path_1}\nCode:"

        if qtype == "gsm8k":
            shots = nlcode_gsm8k_shots[:4]
        elif qtype == "aime":  # left for future use
            shots = nlcode_algebra_shots[:4]
        elif qtype in self._MATH_SUBSETS:
            shots = nlcode_math_shots[:4]
        else:
            raise ValueError(f"Invalid question type: {question_type}")
        return (self.base_prompt_2
                + '-----'.join(shots)
                + f"Question: {question}\nReasoning path: {reasoning_path_1}\nCode:")

    def get_answer(self, reasoning_path):
        """
        Extract the program from the completion and execute it.
        Args:
            reasoning_path: str, the reasoning path from the OpenAI API
        Returns:
            answer: the executor's prediction for a Python program, the prover
                result for a FOL program, or '' when no runnable program was found
        """
        code, token = extract_program(reasoning_path)
        if token == "PYTHON":
            executor = PythonExecutor(get_answer_expr="solution()")
            prediction, _report = executor.apply(code)
            return prediction
        if token == "FOL" and code:
            # The last extracted expression is the conclusion, the rest are premises.
            premises, conclusion = code[:-1], code[-1]
            return evaluate_prover9(premises=premises, conclusion=conclusion)
        # No program we know how to run: return an empty answer instead of failing
        # with an unbound local / IndexError.
        return ''
    
def retrieve_answer_math(model_name,
                         hard_levels=None,
                         is_train=False):
    """
    Load the stored MATH results of every reasoning method for one model.

    For each method in ('cot', 'pal', 'codenl', 'nlcode'), each MATH subject and
    each requested difficulty level, this reads
    ``<dir>/<model_name>_<subject>_<level>_<method>_<split>_eval.jsonl``.
    ``<dir>`` is ``results/gpt`` for GPT models and ``results`` for meta-llama
    models. ``<split>`` is ``train`` when ``is_train`` is set and ``new``
    otherwise. Missing files are reported on stdout and skipped.

    Args:
        model_name: str, the name of the model; must contain 'gpt' or 'meta-llama'
        hard_levels: list of str, the difficulty levels to load
            (default: ['Level 5'])
        is_train: bool, whether to load the ``train`` result files instead of
            the ``new`` ones

    Returns:
        answers: dict mapping method -> subject -> row idx -> record, e.g.
            {'cot': {'algebra': {idx: {'question': ..., 'answer': ...,
            'gold_reasoning_path': ..., 'level': ..., 'reasoning_path': ...,
            'ground_truth': ..., 'correct': ...}}}}

    Raises:
        ValueError: if ``model_name`` is neither a GPT nor a meta-llama model.
    """
    if hard_levels is None:
        hard_levels = ['Level 5']

    if 'gpt' in model_name:
        result_dir = 'results/gpt'
    elif 'meta-llama' in model_name:
        result_dir = 'results'
    else:
        raise ValueError(f"Invalid model name: {model_name}")
    # Build the split tag directly rather than with str.replace, so a 'new'
    # occurring inside model_name is never rewritten.
    split = 'train' if is_train else 'new'

    dataset_list = ['algebra', 'counting & probability', 'geometry', 'number theory',
                    'intermediate algebra', 'precalculus', 'prealgebra']
    methods = ['cot', 'pal', 'codenl', 'nlcode']
    answers = {}
    for method_name in methods:
        answers[method_name] = {}
        for data_type in dataset_list:
            answers[method_name][data_type] = {}

            # Concatenate the rows of every requested difficulty level.
            data = []
            for hard_level in hard_levels:
                file_path = (f'{result_dir}/{model_name}_{data_type}_{hard_level}'
                             f'_{method_name}_{split}_eval.jsonl')
                if not os.path.exists(file_path):
                    print(f"File {file_path} does not exist")
                    continue
                data.extend(read_jsonl(file_path))

            for row in data:
                method = method_name.lower()
                if method in ('cot', 'pal'):
                    answer, reasoning_path, correct = row['pred_answer'], row['reasoning_path'], row['correct']
                elif method in ('codenl', 'nlcode'):
                    # Two-round methods store each round separately; join them.
                    answer, correct = row['pred_answer'], row['correct']
                    reasoning_path = row['reasoning_path_1'] + "\n" + row['reasoning_path_2']
                else:
                    raise ValueError(f"Invalid method name: {method_name}")
                answers[method_name][data_type][row['idx']] = {'question': row['question'],
                                                               'answer': answer,
                                                               'gold_reasoning_path': row['answer'],
                                                               'level': row['level'],
                                                               'reasoning_path': reasoning_path,
                                                               'ground_truth': row['ground_truth'],
                                                               'correct': correct}
        print(f"Method {method_name} done")
    return answers
  
def retrieve_answer_math_greedy(model_name,
                                hard_levels=None,
                                methods=None,
                                is_train=False):
    """
    Load the stored greedy-decoding MATH results of selected methods for one model.

    For each method, each MATH subject and each requested difficulty level,
    this reads ``<dir>/<model_name>_<subject>_<level>_<method>_greedy[_train]_eval.jsonl``.
    ``<dir>`` is ``new_results/gpt`` for GPT models and ``new_results`` for
    meta-llama models. The ``_train`` infix is present only when ``is_train``
    is set. Missing files are reported on stdout and skipped.

    Args:
        model_name: str, the name of the model; must contain 'gpt' or 'meta-llama'
        hard_levels: list of str, the difficulty levels to load
            (default: ['Level 5'])
        methods: list of str, the methods to load
            (default: ['cot', 'pal', 'codenl', 'nlcode']); 'codenl_single' and
            'nlcode_single' are also understood
        is_train: bool, whether to load the training-split result files

    Returns:
        answers: dict mapping method -> subject -> row idx -> record, e.g.
            {'cot': {'algebra': {idx: {'question': ..., 'answer': ...,
            'gold_reasoning_path': ..., 'level': ..., 'reasoning_path': ...,
            'ground_truth': ..., 'correct': ...}}}}

    Raises:
        ValueError: if ``model_name`` is neither a GPT nor a meta-llama model,
            or a loaded row belongs to an unknown method.
    """
    if hard_levels is None:
        hard_levels = ['Level 5']
    if methods is None:
        methods = ['cot', 'pal', 'codenl', 'nlcode']

    if 'gpt' in model_name:
        result_dir = 'new_results/gpt'
    elif 'meta-llama' in model_name:
        result_dir = 'new_results'
    else:
        raise ValueError(f"Invalid model name: {model_name}")
    suffix = '_greedy_train_eval.jsonl' if is_train else '_greedy_eval.jsonl'

    # Methods whose rows carry one 'reasoning_path' versus two rounds of reasoning.
    single_round = ('cot', 'pal', 'codenl_single', 'nlcode_single')
    two_round = ('codenl', 'nlcode')

    dataset_list = ['algebra', 'counting & probability', 'geometry', 'number theory',
                    'intermediate algebra', 'precalculus', 'prealgebra']
    answers = {}
    for method_name in methods:
        answers[method_name] = {}
        for data_type in dataset_list:
            answers[method_name][data_type] = {}

            # Concatenate the rows of every requested difficulty level.
            data = []
            for hard_level in hard_levels:
                file_path = f'{result_dir}/{model_name}_{data_type}_{hard_level}_{method_name}{suffix}'
                if not os.path.exists(file_path):
                    print(f"File {file_path} does not exist")
                    continue
                data.extend(read_jsonl(file_path))

            for row in data:
                method = method_name.lower()
                if method in single_round:
                    answer, reasoning_path, correct = row['pred_answer'], row['reasoning_path'], row['correct']
                elif method in two_round:
                    answer, correct = row['pred_answer'], row['correct']
                    reasoning_path = row['reasoning_path_1'] + "\n" + row['reasoning_path_2']
                else:
                    raise ValueError(f"Invalid method name: {method_name}")

                answers[method_name][data_type][row['idx']] = {'question': row['question'],
                                                               'answer': answer,
                                                               'gold_reasoning_path': row['answer'],
                                                               'level': row['level'],
                                                               'reasoning_path': reasoning_path,
                                                               'ground_truth': row['ground_truth'],
                                                               'correct': correct}
        print(f"Method {method_name} done")
    return answers
class INCMath(Reasoning):
    """
    Meta-Math reasoning.

    First asks the decision model which reasoning method to use ('cot', 'pal',
    'codenl' or 'nlcode'), then answers the question with that method. For MATH
    subjects the answer is looked up in the stored results of the reasoning
    model. Otherwise, or when no stored answer exists, the chosen reasoner is
    run with the reasoning model.
    """

    # Reasoning methods the decision model may choose between.
    _METHODS = ('cot', 'pal', 'codenl', 'nlcode')
    # Question types (lower-cased) whose answers may exist in the stored MATH results.
    _MATH_TYPES = ('algebra', 'counting & probability', 'geometry', 'number theory',
                   'intermediate algebra', 'precalculus', 'prealgebra')

    def __init__(self, decision_model, reasoning_model, hard_level='Level 5', is_train=False, zero_shot=False, retry=3):
        """
        Args:
            decision_model: str, the model that chooses the reasoning method
            reasoning_model: str, the model whose stored MATH results are looked
                up, and which runs the chosen reasoner when no stored answer exists
            hard_level: str or list of str, the MATH difficulty level(s) to load
            is_train: bool, whether to load the training-split result files
            zero_shot: bool, forwarded to the base class and to the fallback reasoners
            retry: int, how many times to query the decision model for a valid decision
        """
        super().__init__(decision_model, zero_shot)
        self.system_prompt = incmath_system_prompt
        self.base_prompt = incmath_prompt
        self.zero_shot = zero_shot
        self.retry = retry
        self.reasoning_model = reasoning_model

        # Accept a single level or a list of levels.
        hard_levels = [hard_level] if isinstance(hard_level, str) else hard_level
        self.math_answers = retrieve_answer_math(reasoning_model, hard_levels, is_train)

    def extract_decision(self, raw_answer):
        """
        Extract the decision from the raw answer
        Args:
            raw_answer: str, the raw answer from the completion
        Returns:
            decision: str, the decision made by the model
        """
        return extract_answer(raw_answer)

    def construct_prompt(self, question, question_type):
        """
        Construct the prompt for the question
        Args:
            question: str, the question to be answered
            question_type: str, the type of the question, like "math", "GSM8K" (unused)
        Returns:
            prompt: str, the prompt for the question
        """
        return self.base_prompt + f"\n Here is the question:\n {question}"

    def reason(self, question, question_type, idx, temperature=0, max_tokens=2048, verbose=False, num_shots=1):
        """
        Choose a reasoning method and get the answer to the question.

        Args:
            question: str, the question to be answered
            question_type: str, the type of the question, like "algebra", "GSM8K"
            idx: the index of the question in the stored results
            temperature: float, the temperature for the completion
            max_tokens: int, the maximum number of tokens to generate
            verbose: bool, whether to print the prompt
            num_shots: int, number of completions requested (``n``); only the
                first is used
        Returns:
            (answer, decision, raw_decision, reasoning_path). All four are empty
            strings if the decision API call raises.
        """
        ##------first round: choose a reasoning method------##
        prompt = self.construct_prompt(question, question_type)

        if verbose:
            print(f"First round prompt: {prompt}")

        flag = False  # set once the model returns one of the known method names
        raw_decision = ''
        decision = None
        for _ in range(self.retry):
            try:
                completion = self.client.chat.completions.create(
                    model = self.model_name,
                    messages = [
                        {"role": "system", "content": self.system_prompt},
                        {"role": "user", "content": prompt}
                    ],
                    temperature = temperature,
                    max_tokens = max_tokens,
                    n= num_shots,
                )
            except Exception as e:
                # API failures abort immediately; retries only cover invalid decisions.
                print('Api error in first round call')
                print(f"Error: {e}")
                return '', '', '', ''

            raw_decision = completion.choices[0].message.content
            print('raw decision', raw_decision)
            decision = self.extract_decision(raw_decision)
            if isinstance(decision, str) and decision.lower() in self._METHODS:
                flag = True
                break

        if not flag:
            print(f"Failed to get the decision from the model")
            print('Here is the raw decision:', raw_decision)
            # Fall back to a uniformly random method rather than aborting.
            decision = random.choice(self._METHODS)

        print(f"Decision: {decision}")
        decision = decision.lower()

        ##------second round: produce the answer------##
        # Prefer the stored result of the chosen method when one exists.
        subject = question_type.lower()
        if subject in self._MATH_TYPES:
            try:
                stored = self.math_answers[decision][subject][idx]
                return stored['answer'], decision, raw_decision, stored['reasoning_path']
            except (KeyError, TypeError) as e:
                print(f"Error when getting the answer from the stored file: {e}")
                print('start reasoning')

        # No stored answer: run the chosen reasoner with the reasoning model, so
        # live answers come from the same model as the stored ones.
        reasoner_classes = {'cot': COT, 'pal': PAL, 'codenl': CodeNL, 'nlcode': NLCode}
        reasoner = reasoner_classes[decision](self.reasoning_model, self.zero_shot)

        if isinstance(reasoner, OneRoundReasoning):
            answer, reasoning_path = reasoner.reason(question, question_type, temperature, max_tokens, verbose)
        else:
            # Two-round reasoners return both rounds; join them into one path.
            answer, reasoning_path_1, reasoning_path_2 = reasoner.reason(question, question_type, temperature, max_tokens, verbose)
            reasoning_path = reasoning_path_1 + "\n" + reasoning_path_2

        return answer, decision, raw_decision, reasoning_path
    

class MajorityVoting(Reasoning):
    """
    Majority voting over the stored greedy answers of several reasoning methods.

    Answers come from ``retrieve_answer_math_greedy``. Answers that
    ``utils.grader.math_equal`` judges equal count as the same vote, and ties
    are broken uniformly at random.
    """
    def __init__(self, model_name, zero_shot=False, hard_levels=None, methods=None, is_train=False):
        """
        Args:
            model_name: str, the model whose stored greedy results are used
            zero_shot: bool, forwarded to the base class
            hard_levels: list of str, MATH difficulty levels to load
                (default: ['Level 5'])
            methods: list of str, the reasoning methods that vote
                (default: ['cot', 'pal', 'codenl', 'nlcode'])
            is_train: bool, whether to load the training-split result files
        """
        super().__init__(model_name, zero_shot)
        if hard_levels is None:
            hard_levels = ['Level 5']
        if methods is None:
            methods = ['cot', 'pal', 'codenl', 'nlcode']
        # Vote over exactly the methods that were loaded.
        self.voting_methods = list(methods)
        self.math_answers = retrieve_answer_math_greedy(model_name, hard_levels=hard_levels,
                                                        methods=self.voting_methods, is_train=is_train)

    @staticmethod
    def _normalize_votes(answers):
        """Map each answer to the first earlier answer that ``math_equal`` deems equal to it."""
        from utils.grader import math_equal
        representatives = []
        normalized = []
        for answer in answers:
            for candidate in representatives:
                try:
                    equal = math_equal(answer, candidate)
                except Exception:
                    # math_equal may fail on malformed answers; treat them as different.
                    equal = False
                if equal:
                    normalized.append(candidate)
                    break
            else:
                # No equivalent answer seen yet: this answer starts a new group.
                representatives.append(answer)
                normalized.append(answer)
        return normalized

    def reason(self, question, question_type, idx, temperature=0, max_tokens=2048, verbose=False):
        """
        Return the majority answer among the stored answers of the voting methods.

        Args:
            question: str, the question to be answered (unused; answers are looked up)
            question_type: str, the MATH subject key, like "algebra"
            idx: the index of the question in the stored results
            temperature: float, unused
            max_tokens: int, unused
            verbose: bool, whether to print the votes and the result
        Returns:
            majority_answer: the most common (math-equivalent) answer, or None if
            no voting method has a stored answer for this question.
        """
        from collections import Counter

        # Collect one vote per method; methods without a stored answer abstain.
        votes = {}
        for method in self.voting_methods:
            record = self.math_answers.get(method, {}).get(question_type, {}).get(idx)
            if record is None:
                if verbose:
                    print(f"No stored answer for method {method}, question {idx}")
                continue
            votes[method] = record['answer']

        if not votes:
            return None

        # Merge math-equivalent answers so differently formatted answers count together.
        normalized_votes = self._normalize_votes(votes.values())

        counter = Counter(normalized_votes)
        majority_answers = counter.most_common()
        max_count = majority_answers[0][1]

        # Break ties uniformly at random.
        candidates = [answer for answer, count in majority_answers if count == max_count]
        majority_answer = random.choice(candidates)

        if verbose:
            print(f"Votes: {votes}")
            print(f"Normalized votes: {normalized_votes}")
            print(f"Majority answer: {majority_answer}")

        return majority_answer

if __name__ == '__main__':
    # Script entry point: dump all stored MATH answers of one model to a local file.
    # NOTE(review): retrieve_answer_math returns a nested dict; write_jsonl presumably
    # expects a list of row dicts — confirm the written file is what is intended.
    target_model = 'gpt-4o-mini'
    output_path = f'./{target_model}_answers.jsonl'
    write_jsonl(retrieve_answer_math(target_model), output_path)