from utils.utils import read_jsonl, write_jsonl
from reasoner import *
import tqdm
from utils.python_executor import PythonExecutor
from eval import eval_file
class Experiment:
    """Run one reasoning method over one dataset file and store the predictions."""

    def __init__(self, model_name, data_type, method, store_path, size=-1, zero_shot=False, verbose=False, reasoning_model=None, is_train=False, num_shots=1, **kwargs):
        """
        Initialize the Experiment
        Args:
            model_name: str, the name of the model, like gpt-4o-mini
            data_type: str, the name of the dataset, like gsm8k; the rows are read
                from data/train/<data_type>.jsonl when is_train is true, otherwise
                from data/test/<data_type>.jsonl
            method: str, the method to be used (lower-cased before matching), one of
                "cot", "pal", "codenl_single", "nlcode_single", "codenl", "nlcode",
                "incmath", "majorvote"
            store_path: str, the path to store the results
            size: int, the number of leading rows to run on; -1 keeps every row
            zero_shot: bool, whether to use zero-shot prompt
            verbose: bool, whether to print the prompt
            reasoning_model: str, the reasoning model handed to INCMath; falls back
                to model_name when not given
            is_train: bool, whether to use the training data
            num_shots: int, forwarded to one-round reasoners in run()
            **kwargs: only `methods` (list of method names) is read, for "majorvote"
        Raises:
            NotImplementedError: if method is not one of the supported names
        """
        self.model_name = model_name
        self.data_type = data_type
        self.method = method.lower()
        self.zero_shot = zero_shot
        self.is_train = is_train

        ## set the data path
        base_path = 'data/train/' if is_train else 'data/test/'
        self.data_path = base_path + data_type + '.jsonl'
        self.data = read_jsonl(self.data_path)

        if self.method == "cot":
            self.reasoner = COT(model_name, zero_shot=zero_shot)
        elif self.method == "pal":
            self.reasoner = PAL(model_name, zero_shot=zero_shot)
        elif self.method == "codenl_single":
            self.reasoner = CodeNL_OneRound(model_name, zero_shot=zero_shot)
        elif self.method == "nlcode_single":
            self.reasoner = NLCode_OneRound(model_name, zero_shot=zero_shot)
        elif self.method == "codenl":
            self.reasoner = CodeNL(model_name, data_type, zero_shot=zero_shot, is_train=is_train)
        elif self.method == "nlcode":
            self.reasoner = NLCode(model_name, data_type, zero_shot=zero_shot, is_train=is_train)
        elif self.method == "incmath":
            decision_model = model_name
            if not reasoning_model:
                reasoning_model = decision_model
            self.reasoner = INCMath(decision_model, reasoning_model, zero_shot=zero_shot, is_train=is_train)
        elif self.method == "majorvote":
            methods = kwargs.get('methods', ['cot', 'pal', 'codenl', 'nlcode'])
            # `is_train` is a named parameter, so it can never show up in kwargs;
            # use the argument itself instead of silently resetting it to False.
            self.reasoner = MajorityVoting(model_name, zero_shot=zero_shot, methods=methods, is_train=is_train)
        else:
            raise NotImplementedError("method {} isn't supported now".format(self.method))
        self.store_path = store_path
        self.size = size
        self.verbose = verbose
        self.num_shots = num_shots

    def run(self):
        """
        reasoning the dataset and store the results

        Every row of self.data is updated in place with the reasoner's output
        (the keys written depend on the reasoner type, see the branches below)
        and the rows are then written to self.store_path.
        """
        if self.size != -1:
            # keep only the first `size` rows; this truncation is permanent
            self.data = self.data[:self.size]
        for row in tqdm.tqdm(self.data):
            question = row['question']
            if isinstance(self.reasoner, OneRoundReasoning):
                answer, reasoning_path = self.reasoner.reason(question, self.data_type, verbose=self.verbose, num_shots=self.num_shots)
                if self.num_shots > 1:
                    row['pred_answers'] = answer  # a list of answers
                    row['reasoning_paths'] = reasoning_path  # a list of reasoning paths
                else:
                    row['pred_answer'] = answer
                    row['reasoning_path'] = reasoning_path

            elif isinstance(self.reasoner, TwoRoundsReasoning):
                answer, reasoning_path_1, reasoning_path_2 = self.reasoner.reason(question, self.data_type, row['idx'], verbose=self.verbose)
                row['pred_answer'] = answer
                row['reasoning_path_1'] = reasoning_path_1
                row['reasoning_path_2'] = reasoning_path_2

            elif isinstance(self.reasoner, INCMath):
                answer, decision, raw_answer, reasoning_path = self.reasoner.reason(question, self.data_type, row['idx'], verbose=self.verbose)
                row['pred_answer'] = answer
                row['decision'] = decision
                row['raw_answer'] = raw_answer
                row['reasoning_path'] = reasoning_path

            elif isinstance(self.reasoner, MajorityVoting):
                answer = self.reasoner.reason(question, self.data_type, row['idx'], verbose=self.verbose)
                row['pred_answer'] = answer

        ## store the results
        write_jsonl(self.data, self.store_path)

    def __str__(self):
        """Return a short summary built only from attributes set in __init__."""
        return (
            f"Experiment: {self.model_name} / {self.data_type} / {self.method}\n"
            f"Description: zero_shot={self.zero_shot}, is_train={self.is_train}, num_shots={self.num_shots}\n"
            f"Data: {len(self.data)} rows from {self.data_path}"
        )


class HardLevelExperiment(Experiment):
    """Experiment restricted to the rows of one math subject whose 'level' equals hard_level."""

    def __init__(self, model_name, data_type, method, store_path, hard_level='Level 5', size=-1, zero_shot=False, verbose=False, reasoning_model=None, is_train=False, num_shots=1, **kwargs):
        """
        Initialize the Experiment
        Args:
            model_name: str, the name of the model, like gpt-4o-mini
            data_type: str, the name of the dataset, must be one of the math
                subjects listed in the assertion below
            method: str, the method to be used, like "cot"
            store_path: str, the path to store the results
            hard_level: str, only rows whose 'level' field equals this value are kept
            size: int, the size of the dataset
            zero_shot: bool, whether to use zero-shot prompt
            verbose: bool, whether to print the prompt
            reasoning_model: str, the reasoning model handed to INCMath; falls back
                to model_name when not given
            is_train: bool, whether to use the training data
            num_shots: int, forwarded to the base class
            **kwargs: forwarded to the base class; `methods` is also read here
                for "majorvote"
        """
        assert data_type in ['algebra', 'counting & probability', 'geometry', 'number theory', 'intermediate algebra', 'precalculus', 'prealgebra'], "data type must be in math, your data type is " + data_type
        super().__init__(model_name, data_type, method, store_path, size, zero_shot, verbose, reasoning_model, is_train, num_shots=num_shots, **kwargs)
        self.hard_level = hard_level
        self.data = [row for row in self.data if row['level'] == hard_level]
        print(f"Data size: {len(self.data)} of level {hard_level}")

        # The base class already built a reasoner; the methods below are rebuilt
        # so that the hard level is passed to them as well.
        if self.method == "codenl":
            self.reasoner = CodeNL(model_name, data_type, zero_shot=zero_shot, is_train=is_train, hard_level=hard_level) # new to specify the hard level to retrieve the data
        elif self.method == "nlcode":
            self.reasoner = NLCode(model_name, data_type, zero_shot=zero_shot, is_train=is_train, hard_level=hard_level)
        elif self.method == "incmath":
            decision_model = model_name
            if not reasoning_model:
                reasoning_model = decision_model
            self.reasoner = INCMath(decision_model, reasoning_model, zero_shot=zero_shot, is_train=is_train, hard_level=hard_level)
        elif self.method == "majorvote":
            hard_levels = [self.hard_level]
            methods = kwargs.get('methods', ['cot', 'pal', 'codenl', 'nlcode'])
            # `is_train` is a named parameter and therefore never present in
            # kwargs; pass the argument itself instead of resetting it to False.
            self.reasoner = MajorityVoting(model_name, zero_shot=zero_shot, methods=methods, is_train=is_train, hard_levels=hard_levels)


class FOLIOExperiment(Experiment):
    """Experiment on the 'folio' dataset, where the whole row is handed to the reasoner as the question."""

    def __init__(self, model_name, data_type, method, store_path, hard_level=None, size=-1, zero_shot=False, verbose=False, reasoning_model=None, is_train=False, num_shots=1):
        """
        Initialize the Experiment
        Args:
            model_name: str, the name of the model, like gpt-4o-mini
            data_type: str, the name of the dataset, must be 'folio'
            method: str, the method to be used, like "cot"
            store_path: str, the path to store the results
            hard_level: forwarded to the CodeNL / NLCode / INCMath reasoners; the
                rows themselves are not filtered by it
            size: int, the size of the dataset
            zero_shot: bool, whether to use zero-shot prompt
            verbose: bool, whether to print the prompt
            reasoning_model: str, the reasoning model handed to INCMath; falls back
                to model_name when not given
            is_train: bool, whether to use the training data
            num_shots: int, forwarded to the base class
        """
        assert data_type == 'folio'
        super().__init__(model_name, data_type, method, store_path, size, zero_shot, verbose, reasoning_model, is_train, num_shots=num_shots)
        self.hard_level = hard_level
        print(f"Data size: {len(self.data)}")

        # The base class already built a reasoner; these methods are rebuilt so
        # that hard_level is passed to them as well.
        if self.method == "codenl":
            self.reasoner = CodeNL(model_name, data_type, zero_shot=zero_shot, is_train=is_train, hard_level=hard_level) # new to specify the hard level to retrieve the data
        elif self.method == "nlcode":
            self.reasoner = NLCode(model_name, data_type, zero_shot=zero_shot, is_train=is_train, hard_level=hard_level)
        elif self.method == "incmath":
            decision_model = model_name
            if not reasoning_model:
                reasoning_model = decision_model
            self.reasoner = INCMath(decision_model, reasoning_model, zero_shot=zero_shot, is_train=is_train, hard_level=hard_level)

    @staticmethod
    def _row_id(row, *keys):
        """Return row[key] for the first of `keys` present in row; raise KeyError for the first key if none is."""
        for key in keys:
            if key in row:
                return row[key]
        raise KeyError(keys[0])

    def run(self):
        """
        reasoning the dataset and store the results

        Every row of self.data is passed whole to the reasoner, updated in place
        with its output, and the rows are then written to self.store_path.
        """
        if self.size != -1:
            # keep only the first `size` rows; this truncation is permanent
            self.data = self.data[:self.size]
        for row in tqdm.tqdm(self.data):
            question = row  # the reasoner receives the full row, not a 'question' field
            if isinstance(self.reasoner, OneRoundReasoning):
                answer, reasoning_path = self.reasoner.reason(question, self.data_type, verbose=self.verbose, num_shots=self.num_shots)
                if self.num_shots > 1:
                    row['pred_answers'] = answer  # a list of answers
                    row['reasoning_paths'] = reasoning_path  # a list of reasoning paths
                else:
                    row['pred_answer'] = answer
                    row['reasoning_path'] = reasoning_path

            elif isinstance(self.reasoner, TwoRoundsReasoning):
                # The branches of this method disagreed on the id key
                # ('example_id' here, 'idx' below). Each keeps its original key
                # first and only falls back to the other one when it is absent.
                row_id = self._row_id(row, 'example_id', 'idx')
                answer, reasoning_path_1, reasoning_path_2 = self.reasoner.reason(question, self.data_type, row_id, verbose=self.verbose)
                row['pred_answer'] = answer
                row['reasoning_path_1'] = reasoning_path_1
                row['reasoning_path_2'] = reasoning_path_2

            elif isinstance(self.reasoner, INCMath):
                # NOTE(review): presumably FOLIO rows carry 'example_id' rather
                # than 'idx' — confirm against the data file.
                row_id = self._row_id(row, 'idx', 'example_id')
                answer, decision, raw_answer, reasoning_path = self.reasoner.reason(question, self.data_type, row_id, verbose=self.verbose)
                row['pred_answer'] = answer
                row['decision'] = decision
                row['raw_answer'] = raw_answer
                row['reasoning_path'] = reasoning_path

            elif isinstance(self.reasoner, MajorityVoting):
                row_id = self._row_id(row, 'idx', 'example_id')
                answer = self.reasoner.reason(question, self.data_type, row_id, verbose=self.verbose)
                row['pred_answer'] = answer

        ## store the results
        write_jsonl(self.data, self.store_path)



def get_code_result(file):
    """
    Re-run the program found in every row of a results file and overwrite its prediction
    Args:
        file: str, the path to the jsonl results file; it is rewritten in place
    Returns:
        None; each row's 'pred_answer' is replaced by the executed result and
        all rows are written back to `file`
    """
    rows = read_jsonl(file)
    runner = PythonExecutor(get_answer_expr="solution()")
    for entry in rows:
        program = extract_program(entry['reasoning_path'])
        # apply() hands back a (prediction, report) pair; only the prediction is kept
        result, _ = runner.apply(program)
        print(result)
        entry['pred_answer'] = result

    write_jsonl(rows, file)

if __name__ == '__main__':
    # Script entry point: run one method over every dataset in dataset_list,
    # store the predictions under results/ and evaluate each stored file.
    model_name = "gpt-4o-mini"
    # api model
    # model_name = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
    # data_type = "gsm8k"
    dataset_list = ['algebra', 'counting & probability', 'geometry', 'number theory', 'intermediate algebra', 'precalculus', 'prealgebra']
    # data_type = 'aime'
    # NOTE(review): 'metamath' is not among the method names accepted by
    # Experiment.__init__ in this file, so the constructor raises
    # NotImplementedError — confirm which method was intended.
    method = 'metamath'
    num_shots = 3
    verbose = False
    zero_shot = False
    for data_type in dataset_list:
        if not zero_shot:
            store_path = f"results/{model_name}_{data_type}_{method}_nshots_{num_shots}.jsonl"
        else:
            store_path = f"results/{model_name}_{data_type}_{method}_nshots_{num_shots}_zero_shot_new.jsonl"
        # num_shots is part of the file name, so it has to reach the experiment
        # too; otherwise the run silently falls back to the default of 1.
        exp = Experiment(model_name, data_type, method, store_path, size=-1, zero_shot=zero_shot, verbose=verbose, num_shots=num_shots)
        exp.run()
        eval_file(store_path)
        # get_code_result(store_path)