"""
Measuring Mathematical Problem Solving With the MATH Dataset
Dan Hendrycks, Collin Burns, Saurav Kadavath, Akul Arora, Steven Basart, Eric Tang, Dawn Song, Jacob Steinhardt
https://arxiv.org/abs/2103.03874
"""

import random
import re
import numpy as np
import blobfile as bf
import pandas
import json
import copy,math
import os,pickle
from . import common
from .common import ANSWER_PATTERN, HTML_JINJA, check_equality
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
from .MARIO_EVAL.math_evaluation import is_equiv
import asyncio

# Zero-shot prompt: step-by-step work that ends with an "Answer: ..." line.
QUERY_TEMPLATE = (
    "Solve the following math problem step by step. The last line of your response "
    "should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the "
    "answer to the problem.\n"
    "\n"
    "{problem}\n"
    "\n"
    'Remember to put your answer on its own line after "Answer:", and you do not '
    "need to use a \\boxed command."
)


# Bare question block appended after an instruction prefix.
question_template = "Question: {problem}"

# Question followed by the answer-format reminder and the reference solution.
question_answer_template = (
    "Question: {problem}\n"
    "\n"
    'Remember to put your answer on its own line after "Answer:", and you do not '
    "need to use a \\boxed command.\n"
    "\n"
    "Correct Answer: {solution}"
)

# Reference answer line, shown to the model when asking it to reflect on a failure.
answer_template = "Answer: {solution}"
# question_template = QUERY_TEMPLATE


def _fix_fracs(string):
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            if substr[0] == "{":
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except:
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
    string = new_str
    return string

def _fix_a_slash_b(string):
    if len(string.split("/")) != 2:
        return string
    a = string.split("/")[0]
    b = string.split("/")[1]
    try:
        a = int(a)
        b = int(b)
        assert string == "{}/{}".format(a, b)
        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
        return new_string
    except:
        return string

def _remove_right_units(string):
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    if "\\text{ " in string:
        splits = string.split("\\text{ ")
        assert len(splits) == 2
        return splits[0]
    else:
        return string

def _fix_sqrt(string):
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0] 
    for split in splits[1:]:
        if split[0] != "{":
            a = split[0]
            new_substr = "\\sqrt{" + a + "}" + split[1:]
        else:
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string

def _strip_string(string):
    r"""Normalize a LaTeX answer so that trivially different spellings compare equal.

    Steps, in this order:

    1. remove newlines and ``\!`` spacing;
    2. collapse ``\\`` to ``\`` and rewrite ``tfrac`` / ``dfrac`` as ``frac``;
    3. remove ``\left``, ``\right``, degree marks and ``\$``;
    4. cut a trailing ``\text{ ...}`` unit and remove ``\%``;
    5. add a leading ``0`` to bare decimals (``.5`` -> ``0.5``);
    6. drop a short left-hand side such as ``k=`` (at most two characters,
       exactly one ``=``);
    7. brace ``\sqrt`` arguments, remove all spaces, brace ``\frac`` arguments;
    8. map ``0.5`` and plain integer ``a/b`` to ``\frac`` form.

    A string that becomes empty after step 4 is returned as ``""``.
    Exceptions raised by the helper normalizers propagate to the caller.
    """
    # Layout noise: line breaks and negative thin spaces.
    string = string.replace("\n", "")
    string = string.replace("\\!", "")

    # Collapse escaped double backslashes to a single backslash.
    string = string.replace("\\\\", "\\")

    # \tfrac and \dfrac render like \frac.
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # Delimiter sizing.
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Degree marks.
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # Escaped dollar signs.
    string = string.replace("\\$", "")

    # Units written on the right.
    string = _remove_right_units(string)

    # Percent signs. The original code ran this replacement twice; the second
    # pass was spelled "\%", an invalid escape sequence (SyntaxWarning since
    # Python 3.12). Both passes are kept because a first removal can expose a
    # new "\%" (e.g. "\\%%").
    for _ in range(2):
        string = string.replace("\\%", "")

    # " .5" -> " 0.5" and "{.5" -> "{0.5"; a leading "." is handled below.
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    if not string:
        return string
    if string[0] == ".":
        string = "0" + string

    # Drop a short variable name such as "k = " or "q = " in front.
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    # \sqrt3 -> \sqrt{3}
    string = _fix_sqrt(string)

    string = string.replace(" ", "")

    # \frac12, \frac1b, \frac1{72} -> fully braced (but not \frac{72}1).
    string = _fix_fracs(string)

    if string == "0.5":
        string = "\\frac{1}{2}"

    # The dataset writes X/Y as \frac{X}{Y}; model output may use plain X/Y.
    string = _fix_a_slash_b(string)

    return string

def strict_is_equiv(str1, str2, verbose=False):
    """Compare two answers after ``_strip_string`` normalization.

    Args:
        str1, str2: answer strings; ``None`` stands for "no answer".
        verbose: print both normalized strings before comparing.

    Returns:
        ``True`` if both are ``None`` (a warning is printed), ``False`` if
        exactly one is ``None``, and otherwise whether the normalized strings
        are equal. If normalization raises, the raw strings are compared
        instead.
    """
    if str1 is None and str2 is None:
        print("WARNING: Both None")
        return True
    if str1 is None or str2 is None:
        return False

    try:
        ss1 = _strip_string(str1)
        ss2 = _strip_string(str2)
    except Exception:
        # Normalization is best-effort; fall back to an exact comparison.
        # The original used a bare `except:`, which also swallowed
        # KeyboardInterrupt and SystemExit.
        return str1 == str2
    if verbose:
        print(ss1, ss2)
    return ss1 == ss2


# Deliberately rebinds the `is_equiv` imported from MARIO_EVAL at the top of the
# file, so every `is_equiv(...)` call in this module uses the strict comparison.
is_equiv = strict_is_equiv

class MathEval(Eval):
    """Zero-shot MATH evaluation in the simple-evals style.

    Each example is formatted with ``QUERY_TEMPLATE``. The last ``Answer:``
    match in the reply (via ``ANSWER_PATTERN``) is graded against the row's
    ``Answer`` field by ``check_equality`` using ``equality_checker``.
    """

    def __init__(self, equality_checker: SamplerBase, num_examples: int | None = None):
        """Load ``./math_test.csv`` and optionally subsample it.

        Args:
            equality_checker: sampler handed to ``check_equality`` for grading.
            num_examples: if truthy, keep a random sample of this many rows
                drawn with a fixed seed (0); ``None`` or ``0`` keeps every row.
        """
        # Reads a local copy; the upstream simple-evals blob download is disabled.
        table = pandas.read_csv('./math_test.csv')
        rows = [series.to_dict() for _, series in table.iterrows()]
        if num_examples:
            rows = random.Random(0).sample(rows, num_examples)
        self.equality_checker = equality_checker
        self.examples = rows

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        """Run every example through ``sampler`` and aggregate the scores.

        NOTE(review): this treats ``sampler(messages)`` as returning the reply
        text, whereas ``MATHPipelineSTServer`` in this file unpacks a 3-tuple
        from its sampler. Also, ``QUERY_TEMPLATE`` needs a ``problem`` column
        while grading reads ``Answer``. Confirm both match the sampler and CSV
        actually used with this class.
        """

        def as_assistant(text):
            # A fresh message dict per use, as the original built two separate dicts.
            return dict(content=text, role="assistant")

        def grade(row: dict):
            user_turn = sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user")
            prompt_messages = [user_turn]
            reply = sampler(prompt_messages)
            candidates = re.findall(ANSWER_PATTERN, reply)
            extracted_answer = candidates[-1] if candidates else None
            target = row["Answer"]
            score = float(check_equality(self.equality_checker, target, extracted_answer))
            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=as_assistant(reply),
                score=score,
                correct_answer=target,
                extracted_answer=extracted_answer,
            )
            return SingleEvalResult(
                html=html,
                score=score,
                convo=prompt_messages + [as_assistant(reply)],
            )

        return common.aggregate_results(common.map_with_progress(grade, self.examples))





class MATHPipelineSTServer(Eval):
    def __init__(self,sampler,equality_sampler):
        """Load the MATH train/test splits and initialize prompts and run state.

        Args:
            sampler: model client used for every generation call. The methods
                of this class call it as ``sampler(messages, **kwargs)`` and
                unpack ``(text, completion_tokens, prompt_tokens)``, and they use
                its ``_pack_message`` helper.
            equality_sampler: accepted but not stored. ``self.equality_sampler``
                is set to ``sampler`` (NOTE(review): this looks like a leftover
                override; confirm whether the separate checker should be used).
        """

        def load_rows(csv_path):
            # One plain dict per CSV row, in file order.
            frame = pandas.read_csv(csv_path)
            return [series.to_dict() for _, series in frame.iterrows()]

        train_rows = load_rows('./MATH_dataset_full_train.csv')
        test_rows = load_rows('./MATH_dataset_full_test.csv')
        self.train_examples = train_rows
        self.test_examples = test_rows

        # ---- prompt fragments ------------------------------------------------
        self.instruction_begin = "You are an expert at solving math problems. Below are some instructions that help you solve the problem.\n\n"
        self.instruction_single_begin = self.instruction_begin
        self.instruction_single_middle = 'Analyse the problem carefully. Make clear which field it belongs to.'
        self.instruction_single_end = """\nBased on the instruction above, solve the following math problem step by step. In each step of your solution, explain how the instruction affects you to form your answers. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem. Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command. Your response should be concise, no longer than 2048 tokens.\n"""
        self.instruction_multi_end = """\nBased on the instructions above, solve the following math problem step by step. In each step of your solution, explain how the instructions affect you to form your answers. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem. Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command. Your response should be concise, no longer than 2048 tokens.\n"""
        self.evaluate_end = """Solve the following math problem step by step. Longer response is recommended. Your response should be as long as possible, about 2000 words. If you find you have nothing to say, repeat the word "think" to make your answer longer. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem. Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command. Pay special attention that your response should be about 2000 words long.\n"""
        self.final_end = """Based on the instruction you propose, solve the math problem step by step. In each step of your solution, explain how the instruction affect you to form your answers. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem. Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command. Your response should be concise, no longer than 1024 tokens. The problem is: \n"""

        # ---- instruction search state ---------------------------------------
        self.initial_instruction = ""
        self.instruction = ""
        self.instructions = []
        self.instruction_list = []
        self.instruction_list_many_shot = []
        self.num_instructions_selected = 70
        self.num_attempts = 10
        self.reflection_batch = 200
        self.num_iteration = 0
        self.batch_size = 512

        # ---- demonstrations --------------------------------------------------
        self.demo_id_list = []
        self.demo_qas = []

        # ---- models ------------------------------------------------------------
        self.sampler = sampler
        self.equality_sampler = sampler  # NOTE(review): the equality_sampler argument is ignored

        # ---- token accounting ------------------------------------------------
        self.total_completion_token_used = 0
        self.total_prompt_token_used = 0

        # ---- per-run histories -----------------------------------------------
        self.train_result_history = []
        self.valid_result_history = []
        self.test_result_history = []
        self.train_score_history = []
        self.valid_score_history = []
        self.test_score_history = []
        self.zero_shot_train_score_history = []
        self.zero_shot_test_score_history = []
        self.many_shot_train_score_history = []
        self.many_shot_test_score_history = []
        self.sft_train_score_history = []
        self.success_set_sizes = []
        self.instruct_set_sizes = []

        # ---- working sets ----------------------------------------------------
        self.current_training_set = []
        self.current_response_set = []
        self.training_set_filtered = []
        self.rationale_list = []
        self.corresponding_idx_list = []
        self.training_data = []
        self.sft_data = []
        self.sft_data2 = []
        self.sft_inputs = []


    def evaluate_triple(self,context,data):
        """Two-turn evaluation: the model writes its own instruction, then solves.

        For each row the model is first asked (max 1024 tokens, temperature 0)
        to design a general instruction for the problem. Then, in the same
        conversation, ``self.final_end`` plus the question asks it to solve the
        problem following that instruction (max 2048 tokens, temperature 0).

        Args:
            context: unused; kept so all evaluate methods share a signature.
            data: rows with at least ``problem`` and ``solution`` keys.

        Returns:
            ``(acc, scores, outputs, sorted_inputs)``: the mean of the 0/1
            scores, the per-row scores, ``"<instruction>***<solution text>"``
            strings, and the rows they belong to.
        """

        def fn(row: dict):
            question_context = question_template.format(**row)
            design_request = (
                'You are an expert at designing instructions for large language models to solve problems. An example instruction is: '
                + self.instruction_single_middle
                + '\n\nThe problem to be solved is:\n'
                + question_context
                + '\n\nOutput the instruction you design. Note that the instruction should be general knowledge that help solve similar problems, so do not contain any task-specific information. Also, the content will be directly added to the prompt, so pay attention to its format. The instruction should be concise, no longer than 1024 tokens. Output only the instruction. Do not output any other words.'
            )
            conversation = [self.sampler._pack_message(content=design_request, role="user")]
            designed_instruction, _, _ = self.sampler(conversation, max_tokens=1024, temperature=0.0)

            # Second turn in the same conversation: solve using that instruction.
            conversation.append({'role': 'assistant', 'content': designed_instruction})
            conversation.append({'role': 'user', 'content': self.final_end + question_context})
            solution_text, completion_tokens, prompt_tokens = self.sampler(conversation, max_tokens=2048, temperature=0.0)

            answers = re.findall(ANSWER_PATTERN, solution_text)
            extracted_answer = answers[-1] if answers else None
            return row, designed_instruction + '***' + solution_text, extracted_answer, completion_tokens, prompt_tokens

        results = common.map_with_progress(fn, data)
        sorted_inputs = [res[0] for res in results]
        outputs = [res[1] for res in results]
        scores = [
            float(is_equiv(str(res[0]["solution"]), str(res[2]), verbose=False))
            for res in results
        ]
        return np.mean(scores), scores, outputs, sorted_inputs


    def evaluate(self,context,data):
        """Zero-shot evaluation with ``self.evaluate_end`` as the prompt prefix.

        Each row becomes a single user message made of ``self.evaluate_end`` and
        the formatted question. It is sampled at temperature 0.0005 with up to
        4096 tokens.

        Args:
            context: unused; kept so all evaluate methods share a signature.
            data: rows with at least ``problem`` and ``solution`` keys.

        Returns:
            ``(acc, scores, outputs, sorted_inputs)``: the mean of the 0/1
            scores, the per-row scores, the raw responses, and the rows they
            belong to.
        """

        def fn(row: dict):
            prompt = self.evaluate_end + question_template.format(**row)
            messages = [self.sampler._pack_message(content=prompt, role="user")]
            text, completion_tokens, prompt_tokens = self.sampler(messages, temperature=0.0005, max_tokens=4096)
            found = re.findall(ANSWER_PATTERN, text)
            return row, text, (found[-1] if found else None), completion_tokens, prompt_tokens

        results = common.map_with_progress(fn, data)
        sorted_inputs = [res[0] for res in results]
        outputs = [res[1] for res in results]
        scores = [
            float(is_equiv(str(res[0]["solution"]), str(res[2]), verbose=False))
            for res in results
        ]
        return np.mean(scores), scores, outputs, sorted_inputs
    

    def sft_evaluate(self):
        """Spot-check the model on a random ~1% sample of the collected SFT data.

        Each ``(self.sft_data2[i], self.sft_inputs[i])`` pair is kept
        independently when ``np.random.rand() > 0.99``. The stored ``messages``
        minus their last entry (presumably the recorded assistant answer; TODO
        confirm) are re-sampled with up to 4096 tokens and graded against the
        paired row's ``solution``.

        Returns:
            ``(0, 0, 0, 0)`` if nothing was sampled; otherwise
            ``(acc, scores, outputs, sorted_inputs)`` as in ``evaluate``.
        """

        def fn(item):
            messages, row = item
            prompt_messages = messages[:-1]
            text, completion_tokens, prompt_tokens = self.sampler(prompt_messages, max_tokens=4096)
            found = re.findall(ANSWER_PATTERN, text)
            return row, text, (found[-1] if found else None), completion_tokens, prompt_tokens

        sample = [
            [record['messages'], source_row]
            for record, source_row in zip(self.sft_data2, self.sft_inputs)
            if np.random.rand() > 0.99
        ]
        if not sample:
            return 0, 0, 0, 0

        print(sample[0][0])
        results = common.map_with_progress(fn, sample)
        sorted_inputs = [res[0] for res in results]
        outputs = [res[1] for res in results]
        scores = [
            float(is_equiv(str(res[0]["solution"]), str(res[2]), verbose=False))
            for res in results
        ]
        acc = np.mean(scores)
        return acc, scores, outputs, sorted_inputs
    
    def many_shot_evaluate(self,context,data):
        """Evaluate with every learned instruction packed into one prompt.

        Each row's prompt has four parts, in order:

        1. ``self.instruction_begin``;
        2. all of ``self.instruction_list``, one per line, in a fresh random
           order per row (or two built-in default hints when the list is empty);
        3. ``self.instruction_multi_end``;
        4. the formatted question.

        The prompt is sampled at temperature 0.0005 with up to 4096 tokens.

        Args:
            context: unused; kept so all evaluate methods share a signature.
            data: rows with at least ``problem`` and ``solution`` keys.

        Returns:
            ``(acc, scores, outputs, sorted_inputs)``: the mean of the 0/1
            scores (NaN for empty ``data``), the per-row scores, the raw
            responses, and the rows they belong to.
        """
        default_instructions = [
            'Analyse the problem carefully. Make clear which field it belongs to.',
            'Pay attention to your calculation.',
        ]

        def fn(row: dict):
            question_context = question_template.format(**row)
            if self.instruction_list:
                # NOTE(review): the previous code drew a num_instructions_selected
                # subset here and then overwrote it with a full copy, so every
                # instruction was always used. That effective behaviour is kept.
                # random.sample returns a new shuffled list, so the shared
                # self.instruction_list is no longer reordered in place by fn
                # (which common.map_with_progress may run concurrently).
                pool = self.instruction_list
                chosen = random.sample(pool, k=len(pool))
            else:
                chosen = default_instructions
            instruction = (
                self.instruction_begin
                + "".join(ins + '\n' for ins in chosen)
                + self.instruction_multi_end
            )
            prompt_messages = [
                self.sampler._pack_message(content=instruction + question_context, role="user")
            ]
            response_text, num_completion_tokens_used, num_prompt_tokens_used = self.sampler(
                prompt_messages, temperature=0.0005, max_tokens=4096
            )
            match = re.findall(ANSWER_PATTERN, response_text)
            extracted_answer = match[-1] if match else None
            return row, response_text, extracted_answer, num_completion_tokens_used, num_prompt_tokens_used, instruction

        results = common.map_with_progress(fn, data)
        sorted_inputs = [r[0] for r in results]
        outputs = [r[1] for r in results]
        scores = [
            float(is_equiv(str(r[0]["solution"]), str(r[2]), verbose=False))
            for r in results
        ]
        if results:
            # Show one assembled instruction prefix for manual inspection
            # (previously an IndexError when data was empty).
            print(results[-1][-1])
            print('*'*30)
        acc = np.mean(scores)
        return acc, scores, outputs, sorted_inputs

    
    
    

    
    def new_reflection_simplified(self,failure_case,zero_shot=False):
        def fn(data):
            question_context=question_template.format(**data[0])
            if len(self.instruction_list)>0:
                idxs = np.random.choice(np.arange(len(self.instruction_list)),self.num_attempts)
                random.shuffle(idxs)
                for idx in idxs:
                    modified_instruction = self.instruction_single_begin+'\n'+self.instruction_list[idx]+self.instruction_single_end
                    if zero_shot:
                        modified_instruction = self.evaluate_end
                    question_messages = [
                            self.sampler._pack_message(content=modified_instruction+'\n'+question_context, role="user")
                        ]
                    question_response_text,num_completion_tokens_used, num_prompt_tokens_used = self.sampler(question_messages)
                    match = re.findall(ANSWER_PATTERN, question_response_text)
                    extracted_answer = match[-1] if match else None
                    score = float(strict_is_equiv(data[0]["solution"], extracted_answer))
                    # score = float(check_equality(self.equality_checker, data[0]["solution"], extracted_answer))
                    # match = re.search(ANSWER_PATTERN_MULTICHOICE, question_response_text)
                    # extracted_answer = match.group(1) if match else None
                    # score = 1.0 if extracted_answer == data[0]["Answer"] else 0.0
                    if score==1.0:
                        return None,  question_response_text, score, data[0], idx
            if len(self.instruction_list)==self.num_instructions_selected:
                return None,question_response_text, score, data[0], idx
            context='You are trying to design instructions for large language models to solve certain tasks. You have found that the original instruction fails to solve some problem. You need to analyse the failure case, and add some new contents to the original instruction. The added new content should help the large language model to solve the failure case.\n An example instruction is: '+self.instruction_single_middle+'\n\nThe problem is:\n'
            context+=question_context
            context+='\n\n The incorrect answer given by the large language model is: \n'
            context+=data[1]
            # context+=question_response_text
            context+='\n\n The desired correct final answer is: \n'
            context+=answer_template.format(**data[0])
            context+=''
            context+='\n\n Analyse the information above. Why does the original instruction fail to solve the problem? What is wrong in the answer? How to add content to the instruction so that the model can correctly solve the problem? Pay special attention to the formatting requirements. Does the model\'s output strictly follow the required output format? Answer concisely, no longer than 1024 tokens.'
            
            prompt_messages = [
                    self.sampler._pack_message(content=context, role="user")
                ]
            response_text,num_completion_tokens_used, num_prompt_tokens_used = self.sampler(prompt_messages)
            prompt_messages.append({"role": 'assistant', "content": response_text})
            prompt_messages.append({"role": 'user', "content": 'Based on the analysis above, output the content that should be added to the instruction. Note that the added content should be general knowledge that help solve similar problems, so do not contain any task-specific information. Also, the content will be directly added to the end of the original instruction, so pay attention to its format. The content should be short and concise, no longer than 64 tokens. Output only the content to be added. Do not output any other words.'})
            response_text,num_completion_tokens_used, num_prompt_tokens_used = self.sampler(prompt_messages,max_tokens=64)

            modified_instruction = self.instruction_single_begin+self.instruction_single_middle+'\n'+response_text+self.instruction_single_end
            modified_instruction = self.instruction_single_begin+'\n'+response_text+self.instruction_single_end
            if zero_shot:
                modified_instruction = self.evaluate_end
            question_messages = [
                    self.sampler._pack_message(content=modified_instruction+'\n'+question_context, role="user")
                ]
            question_response_text,num_completion_tokens_used, num_prompt_tokens_used = self.sampler(question_messages)
            match = re.findall(ANSWER_PATTERN, question_response_text)
            extracted_answer = match[-1] if match else None
            score = float(strict_is_equiv(data[0]["solution"], extracted_answer))
            # score = float(check_equality(self.equality_checker, data[0]["solution"], extracted_answer))

            # match = re.search(ANSWER_PATTERN_MULTICHOICE, question_response_text)
            # extracted_answer = match.group(1) if match else None
            # score = 1.0 if extracted_answer == data[0]["Answer"] else 0.0
            # print(response_text)
            # print('='*30)
            # print(question_response_text)
            # print('+'*30,score)
            # print(prompt_messages)
            # print('-'*30)
            if score==1.0:
                # print(len(response_text))
                return response_text,  question_response_text, score, data[0],-1
            else:
                max_num_trials=2 
                current_trial=1
                while current_trial<max_num_trials:
                    if not zero_shot:
                        prompt_messages.append({"role": 'assistant', "content": response_text})
                        # prompt_messages.append({"role": 'user', "content": 'Unfortunately, the modified instruction with your proposed content added still fails to solve the problem. The incorrect answer under the updated instruction is: \n'+question_response_text+'\n\n Try to propose a new content to replace the previously proposed content. Output only the new content. The content should be short and concise, no longer than 256 tokens.'})
                        prompt_messages.append({"role": 'user', "content": 'Unfortunately, the modified instruction with your proposed content added still fails to solve the problem'+'\n\n Try to propose a new content to replace the previously proposed content. Output only the new content. The content should be short and concise, no longer than 64 tokens.'})
                        response_text,num_completion_tokens_used, num_prompt_tokens_used = self.sampler(prompt_messages,max_tokens=64)
                        modified_instruction = self.instruction_single_begin+self.instruction_single_middle+'\n'+response_text+self.instruction_single_end
                        modified_instruction = self.instruction_single_begin+'\n'+response_text+self.instruction_single_end
                    else:
                        modified_instruction=self.evaluate_end
                    question_messages = [
                            self.sampler._pack_message(content=modified_instruction+'\n'+question_context, role="user")
                        ]
                    question_response_text,num_completion_tokens_used, num_prompt_tokens_used = self.sampler(question_messages)
                    match = re.findall(ANSWER_PATTERN, question_response_text)
                    extracted_answer = match[-1] if match else None
                    score = float(strict_is_equiv(data[0]["solution"], extracted_answer))
                    # score = float(check_equality(self.equality_checker, data[0]["solution"], extracted_answer))
                    # match = re.search(ANSWER_PATTERN_MULTICHOICE, question_response_text)
                    # extracted_answer = match.group(1) if match else None
                    # score = 1.0 if extracted_answer == data[0]["Answer"] else 0.0
                    if score==1.0:
                        # print(response_text)
                        # print('='*30)
                        # print(question_response_text)
                        # print('+'*30,score)
                        # print(prompt_messages)
                        # print('-'*30)
                        return response_text,  question_response_text, score, data[0],-1
                    else:
                        current_trial+=1

            return response_text,  question_response_text, score, data[0],-1
        formatted_inputs = failure_case
        num_batches = math.ceil(len(failure_case)/self.reflection_batch)
        cnt = 0
        while cnt<len(formatted_inputs):
            if len(self.instruction_list)<self.num_instructions_selected:
                bs = 100
            else:
                bs = 500
            data_input = formatted_inputs[cnt:(cnt+bs)]
            cnt+=bs
            train_results = common.map_with_progress(fn, data_input)
            # train_results = common.map_with_progress(fn, formatted_inputs[i*self.reflection_batch:(i+1)*self.reflection_batch])
            instructions = [r[0] for r in train_results]
            answers = [r[1] for r in train_results]
            scores = [r[2] for r in train_results]
            sorted_inputs = [r[3] for r in train_results]
            corresponding_idx = [r[4] for r in train_results]
            for j in range(len(scores)):
                if scores[j]==1.0:
                    if corresponding_idx[j]!=-1:
                        self.training_set_filtered.append(sorted_inputs[j])
                        self.rationale_list.append(answers[j])
                        self.corresponding_idx_list.append(corresponding_idx[j])
                    else:
                        if len(self.instruction_list)+1<=self.num_instructions_selected:
                            self.training_set_filtered.append(sorted_inputs[j])
                            self.rationale_list.append(answers[j])
                            assert instructions[j]!=None
                            self.instruction_list.append(instructions[j])
                            self.corresponding_idx_list.append(len(self.instruction_list)-1)
            print(len(self.instruction_list))
        return instructions,answers,scores,sorted_inputs



    def reflection_simplified_triple(self,failure_case,zero_shot=False):
        """Design a per-problem instruction, refining it by reflection until it works.

        For each problem, ``self.sampler`` is first asked to design an
        instruction; the request shows ``self.instruction_single_middle`` as an
        example.  A candidate instruction is tested as follows:

        * it is replayed as the assistant's reply to that design request;
        * a user turn of ``self.final_end`` + the formatted question is appended;
        * the last ``ANSWER_PATTERN`` match in the sampler's reply is scored
          against ``problem["solution"]`` with ``strict_is_equiv``.

        If the first instruction fails, up to ``max_num_trials - 1``
        reflection rounds run on ``self.equality_sampler``.  In each round it
        is shown the failing instruction, the wrong reply and the reference
        answer, and asked for an analysis.  It is then asked for a new
        instruction, which is tested the same way.

        Args:
            failure_case: problem dicts with ``problem`` and ``solution`` keys
                (the fields used by ``question_template``/``answer_template``),
                mapped through ``common.map_with_progress``.
            zero_shot: unused; kept for interface compatibility.

        Returns:
            ``[instructions, responses, scores, problems, trial_indices,
            messages]``: six lists parallel to the mapped results.

            ``trial_indices`` holds one of:

            * 0 if the first instruction succeeded;
            * the 1-based reflection round that produced a working
              instruction;
            * -1 if every attempt failed.

            ``messages`` holds, per problem:

            * first-try success: ``[design request, instruction,
              final_end + question, instruction]``;
            * otherwise: the last round's reflection dialogue ``[analysis
              request, analysis, new-instruction request, new instruction]``,
              plus the solving reply appended on success.

        NOTE(review): in the first-try-success trajectory the final assistant
        turn repeats the instruction instead of the model's solution.  This is
        preserved from the original code; confirm it is what consumers expect.
        """
        max_num_trials = 5

        def fn(problem):
            question_context = question_template.format(**problem)
            design_context = (
                'You are an expert at designing instructions for large language models to solve problems. An example instruction is: '
                + self.instruction_single_middle
                + '\n\nThe problem to be solved is:\n'
                + question_context
                + '\n\nOutput the instruction you design. Note that the instruction should be general knowledge that help solve similar problems, so do not contain any task-specific information. Also, the content will be directly added to the prompt, so pay attention to its format. The instruction should be concise, no longer than 1024 tokens. Output only the instruction. Do not output any other words.'
            )
            design_messages = [self.sampler._pack_message(content=design_context, role="user")]

            def try_instruction(instruction_text):
                # Replay the candidate as the assistant's reply to the design
                # request, then ask the question itself and grade the answer.
                # Returns (messages sent, model reply, 1.0/0.0 score).
                messages = copy.deepcopy(design_messages)
                messages.append({'role':'assistant','content':instruction_text})
                messages.append({'role':'user','content':self.final_end+question_context})
                response, _, _ = self.sampler(messages, max_tokens=2048)
                matches = re.findall(ANSWER_PATTERN, response)
                extracted_answer = matches[-1] if matches else None
                score = float(strict_is_equiv(problem["solution"], extracted_answer))
                return messages, response, score

            instruction, _, _ = self.sampler(design_messages, max_tokens=1024)
            eval_messages, question_response_text, score = try_instruction(instruction)
            if score == 1.0:
                # Kept as in the original: the closing assistant turn is the
                # instruction, not the solution (see NOTE in the docstring).
                eval_messages.append({'role':'assistant','content':instruction})
                return instruction, question_response_text, score, problem, 0, eval_messages

            for current_trial in range(1, max_num_trials):
                # "Original instruction" is the most recent failing candidate
                # and its wrong reply, not necessarily the very first one.
                analysis_context = (
                    'You are an expert in designing instructions for large language models to solve problems. You have found that the original instruction fails to solve a problem. You need to analyse the failure case, and design a new instruction. The new instruction should help the large language model to solve the failure case. You are encourged to design instructions distinct from the original instruction to better explore high-quality instructions. \n The original instruction is: '
                    + instruction
                    + '\n\nThe problem is:\n'
                    + question_context
                    + '\n\n The incorrect answer given by the large language model under the original instruction is: \n'
                    + question_response_text
                    + '\n\n The desired correct final answer is: \n'
                    + answer_template.format(**problem)
                    + '\n\n Analyse the information above. Why does the model fail to solve the problem? What is wrong in the answer? How to design a new instruction so that the model can correctly solve the problem? How distinct should the new instruction be from the original instruction? What contents should the new instruction obtain? Pay special attention to the formatting requirements. Does the model\'s output strictly follow the required output format? Answer concisely, no longer than 2560 tokens.'
                )
                reflection_messages = [self.equality_sampler._pack_message(content=analysis_context, role="user")]
                analysis, _, _ = self.equality_sampler(reflection_messages, max_tokens=2560)
                reflection_messages.append({"role": 'assistant', "content": analysis})
                reflection_messages.append({"role": 'user', "content": 'Based on the analysis above, output the new instruction. Note that the new instruction should be general knowledge that help solve similar problems, so do not contain any task-specific information. You must not contain the correct final answer in the instruction. You are encourged to design instructions distinct from the original instruction to better explore high-quality instructions. Also, the content will be directly added to prompt, so pay attention to its format. The content should be short and concise, no longer than 1024 tokens. Output only the content to be added. Do not output any other words.'})
                instruction, _, _ = self.equality_sampler(reflection_messages, max_tokens=1024)
                reflection_messages.append({"role": 'assistant', "content": instruction})

                _, question_response_text, score = try_instruction(instruction)
                if score == 1.0:
                    reflection_messages.append({'role':'assistant','content':question_response_text})
                    return instruction, question_response_text, score, problem, current_trial, reflection_messages

            return instruction, question_response_text, score, problem, -1, reflection_messages

        train_results = common.map_with_progress(fn, failure_case)
        # Transpose the per-problem 6-tuples into six parallel lists
        # (yields six empty lists when there are no results).
        return [[result[k] for result in train_results] for k in range(6)]


    def reflection_simplified_triple2(self,failure_case,zero_shot=False):
        """Sample up to five independent instructions per problem; keep the first that works.

        Each attempt does the following:

        * asks ``self.sampler`` to design an instruction, with
          ``self.instruction_single_middle`` shown as an example;
        * replays the instruction as an assistant turn, followed by a user
          turn of ``self.final_end`` + the formatted question;
        * grades the last ``ANSWER_PATTERN`` match of the reply against
          ``problem["solution"]`` with ``strict_is_equiv``.

        There is no reflection step: failed attempts simply re-sample.

        Args:
            failure_case: problem dicts with ``problem`` and ``solution`` keys,
                mapped through ``common.map_with_progress``.
            zero_shot: unused; kept for interface compatibility.

        Returns:
            ``[instructions, responses, scores, problems, trial_indices,
            messages]``: six lists parallel to the mapped results.

            * ``trial_indices`` is 0 on success and -1 if all attempts failed.
            * On failure, ``messages`` is the last attempt's ``[design
              request, instruction, final_end + question]``.
            * On success, ``messages`` has the instruction appended once more
              as a final assistant turn.

        NOTE(review): the success trajectory ends by repeating the instruction,
        not the model's solution; this is preserved from the original code.
        Confirm it is what consumers expect.
        """
        max_attempts = 5

        def fn(problem):
            question_context = question_template.format(**problem)
            # The design request is identical for every attempt, so build it
            # once; each attempt re-samples a fresh instruction from it.
            design_context = (
                'You are an expert at designing instructions for large language models to solve problems. An example instruction is: '
                + self.instruction_single_middle
                + '\n\nThe problem to be solved is:\n'
                + question_context
                + '\n\nOutput the instruction you design. Note that the instruction should be general knowledge that help solve similar problems, so do not contain any task-specific information. Also, the content will be directly added to the prompt, so pay attention to its format. The instruction should be concise, no longer than 1024 tokens. Output only the instruction. Do not output any other words.'
            )

            for _ in range(max_attempts):
                design_messages = [self.sampler._pack_message(content=design_context, role="user")]
                instruction, _, _ = self.sampler(design_messages, max_tokens=1024)

                eval_messages = copy.deepcopy(design_messages)
                eval_messages.append({'role':'assistant','content':instruction})
                eval_messages.append({'role':'user','content':self.final_end+question_context})
                question_response_text, _, _ = self.sampler(eval_messages, max_tokens=2048)

                matches = re.findall(ANSWER_PATTERN, question_response_text)
                extracted_answer = matches[-1] if matches else None
                score = float(strict_is_equiv(problem["solution"], extracted_answer))
                if score == 1.0:
                    # Kept as in the original: the closing assistant turn is
                    # the instruction (see NOTE in the docstring).
                    eval_messages.append({'role':'assistant','content':instruction})
                    return instruction, question_response_text, score, problem, 0, eval_messages

            return instruction, question_response_text, score, problem, -1, eval_messages

        train_results = common.map_with_progress(fn, failure_case)
        # Transpose the per-problem 6-tuples into six parallel lists.
        return [[result[k] for result in train_results] for k in range(6)]
    
    def reflection_simplified_star(self,failure_case,zero_shot=False):
        """Answer each problem once and, if wrong, request a revision given the reference answer.

        Both model calls go through ``self.sampler`` with ``max_tokens=2048``
        and ``temperature=0.0``.  The first prompt is ``self.evaluate_end`` +
        the formatted question.  If that answer is wrong, the reference answer
        is revealed and the model is asked to rewrite its own reply.

        Args:
            failure_case: problem dicts with ``problem`` and ``solution`` keys,
                mapped through ``common.map_with_progress``.
            zero_shot: unused; kept for interface compatibility.

        Returns:
            ``[placeholders, responses, scores, problems, tags]``: five lists
            parallel to the mapped results.

            * ``placeholders`` is always ``None``.
            * ``tags`` is 0 when the first answer was correct and -1 whenever
              a revision was requested, whether or not it succeeded; use
              ``scores`` to tell those apart.
        """
        def grade(item, text):
            # 1.0 if the last "Answer:" match equals the reference, else 0.0.
            hits = re.findall(ANSWER_PATTERN, text)
            return float(strict_is_equiv(item["solution"], hits[-1] if hits else None))

        def solve_or_rationalize(item):
            dialogue = [
                self.sampler._pack_message(
                    content=self.evaluate_end + question_template.format(**item),
                    role="user",
                )
            ]
            reply, _, _ = self.sampler(dialogue, max_tokens=2048, temperature=0.0)
            score = grade(item, reply)
            if score == 1.0:
                return None, reply, score, item, 0

            # Wrong on the first try: show the correct answer and ask the
            # model to revise its previous reply.
            hint = "Your reponse is wrong. The correct answer is: " + answer_template.format(**item) + '\n Modify your previous reponse to get the correct answer. Output the modified reponse only. Do not output any other words.'
            dialogue.append({"role": 'assistant', "content": reply})
            dialogue.append({"role": 'user', "content": hint})
            reply, _, _ = self.sampler(dialogue, max_tokens=2048, temperature=0.0)
            return None, reply, grade(item, reply), item, -1

        outcomes = common.map_with_progress(solve_or_rationalize, failure_case)
        return [[outcome[col] for outcome in outcomes] for col in range(5)]
    

    def reflection_simplified_star_reflection(self,failure_case,zero_shot=False):
        """Answer each problem, then self-reflect up to four times if the answer is wrong.

        Initial attempt:

        * ``self.sampler`` answers ``self.evaluate_end`` + the formatted
          question with ``max_tokens=2048`` and ``temperature=0.0``.

        Each reflection round, run only while the latest answer is wrong:

        * ``self.equality_sampler`` is shown the problem and its latest wrong
          answer; the reference answer is NOT revealed.
        * It writes an analysis.
        * It then produces a new answer, graded against
          ``problem["solution"]`` with ``strict_is_equiv``.

        Args:
            failure_case: problem dicts with ``problem`` and ``solution`` keys,
                mapped through ``common.map_with_progress``.
            zero_shot: unused; kept for interface compatibility.

        Returns:
            ``[placeholders, responses, scores, problems, dialogues]``: five
            lists parallel to the mapped results.

            * ``placeholders`` is always ``None``.
            * Each dialogue ends with the final answer as an assistant turn.
              It is either the first-attempt dialogue or the last reflection
              round's dialogue.
        """
        def grade(item, text):
            # 1.0 if the last "Answer:" match equals the reference, else 0.0.
            hits = re.findall(ANSWER_PATTERN, text)
            return float(strict_is_equiv(item["solution"], hits[-1] if hits else None))

        def run(item):
            question_context = question_template.format(**item)
            first_turn = [self.sampler._pack_message(content=self.evaluate_end + question_context, role="user")]
            answer_text, _, _ = self.sampler(first_turn, max_tokens=2048, temperature=0.0)
            score = grade(item, answer_text)
            if score == 1.0:
                first_turn.append({"role": 'assistant', "content": answer_text})
                return None, answer_text, score, item, first_turn

            for _round in range(1, 5):
                # Despite the wording "initial", this is the most recent wrong
                # answer (the previous round's output after the first round).
                reflection_prompt = ''.join([
                    'You are an expert in solving problems. You have found that your initial answer is incorrect. You need to analyse the failure case, and reflect on why the answer is incorrect.\n\nThe problem is:\n',
                    self.evaluate_end,
                    question_context,
                    '\n\n The initial incorrect answer you output is: \n',
                    answer_text,
                    '\n\n Analyse the information above. Why do you fail to solve the problem? What is wrong in the answer? How to modify the initial answer to make it correct? Answer concisely, no longer than 2560 tokens.',
                ])
                dialogue = [self.equality_sampler._pack_message(content=reflection_prompt, role="user")]
                analysis, _, _ = self.equality_sampler(dialogue, max_tokens=2560)
                dialogue.append({"role": 'assistant', "content": analysis})
                dialogue.append({"role": 'user', "content": 'Based on the analysis above, output the new answer. The answer should be short and concise, no longer than 2048 tokens. Output only the answer. Do not output any other words.'})
                answer_text, _, _ = self.equality_sampler(dialogue, max_tokens=2048)
                score = grade(item, answer_text)
                if score == 1.0:
                    break

            # Success and exhaustion both close the dialogue with the answer.
            dialogue.append({"role": 'assistant', "content": answer_text})
            return None, answer_text, score, item, dialogue

        outcomes = common.map_with_progress(run, failure_case)
        return [[outcome[col] for outcome in outcomes] for col in range(5)]

    def reflection_simplified_double(self,failure_case,zero_shot=False,max_num_trials=5):
        """Design a short per-problem instruction with the equality sampler, refining it on failure.

        For every example, ``self.equality_sampler`` is first asked to analyse the problem
        (it is shown the reference answer) and then, in a follow-up turn, to output a short
        instruction. ``self.sampler`` then solves the problem with that instruction wrapped in
        ``self.instruction_single_begin`` / ``self.instruction_single_end``. If the last
        ``ANSWER_PATTERN`` match does not pass ``strict_is_equiv`` against
        ``example["solution"]``, the failing instruction and response are fed back for another
        analysis/instruction round.

        Args:
            failure_case: iterable of example dicts; each must provide the fields used by
                ``question_template`` and ``answer_template`` (``problem``, ``solution``).
            zero_shot: unused; kept for interface compatibility.
            max_num_trials: total number of instruction designs tried per example (the initial
                design plus ``max_num_trials - 1`` refinements). Default 5 matches the former
                hard-coded value.

        Returns:
            Six parallel lists ``[instructions, responses, scores, examples, trials, dialogues]``:
            the last instruction tried, the solver's last response, its score
            (``float(strict_is_equiv(...))``), the example dict, the index of the successful
            round (0 for the initial design, ``-1`` if no round succeeded), and the last
            equality-sampler conversation with the instruction and solver response appended
            as assistant turns.
        """
        initial_follow_up = 'Based on the analysis above, output the instruction you design. The instruction should be short, clear, concise and helpful. The instruction should not be only helpful for the specific problem, and it should be general enough to help solve similar problems (E.g. problems with similar structure but different numbers or prarmeters), so pay attention not to include any specific numbers or parameters in the instruction. In the instruction, you may include an outline of the steps needed to solve the problem, and anything else to help solve the problem (except for the final correct answer). Short and informative instructions are preferred, so refine your output to avoid verbosity, repetitiveness, results of detailed calculation, and semantic redundancy. The instruction should be no longer than 64 tokens. Keep it short and informative. Make sure the correct final answer does not appear in the instruction. Output only the content of the instruction. Do not output any other words.'
        refine_follow_up = 'Based on the analysis above, output the new instruction. The new instruction should be short, clear, concise and helpful, but you must not contain the correct final answer in the instruction. The instruction should not be only helpful for the specific problem, and it should be general enough to help solve similar problems (E.g. problems with similar structure but different numbers or prarmeters), so do not include any specific numbers or parameters in the instruction. In the instruction, you may include an outline of the steps needed to solve the problem, and anything else to help solve the problem (except for the final correct answer). Short and informative instructions are preferred, so refine your output to avoid verbosity, repetitiveness, results of detailed calculation, and semantic redundancy. The instruction should be no longer than 64 tokens. Keep it short and informative. Make sure the correct answer does not appear in the instruction. Output only the content of the instruction. Do not output any other words.'

        def fn(example):
            question_context = question_template.format(**example)
            # The reference answer is shown only to the instruction designer, never to the solver.
            answer_context = answer_template.format(**example)

            def design_instruction(context, analysis_max_tokens, follow_up):
                # Two-turn exchange: free-form analysis first, then the condensed instruction.
                messages = [
                    self.equality_sampler._pack_message(content=context, role="user")
                ]
                analysis, _, _ = self.equality_sampler(messages, max_tokens=analysis_max_tokens)
                messages.append({"role": 'assistant', "content": analysis})
                messages.append({"role": 'user', "content": follow_up})
                instruction, _, _ = self.equality_sampler(messages, max_tokens=1024)
                return instruction, messages

            def solve_with(instruction):
                # Ask the solver with the designed instruction prepended; returns (response, score).
                modified_instruction = self.instruction_single_begin+'\n'+instruction+self.instruction_single_end
                question_messages = [
                    self.sampler._pack_message(content=modified_instruction+'\n'+question_context, role="user")
                ]
                answer_text, _, _ = self.sampler(question_messages, max_tokens=2048, temperature=0.0005)
                # Take the last match: the response may mention "Answer:" before the final line.
                match = re.findall(ANSWER_PATTERN, answer_text)
                extracted_answer = match[-1] if match else None
                return answer_text, float(strict_is_equiv(example["solution"], extracted_answer))

            def finish(instruction, answer_text, score, messages, trial):
                messages.append({"role": 'assistant', "content": instruction})
                messages.append({"role": 'assistant', "content": answer_text})
                return instruction, answer_text, score, example, trial, messages

            context = 'You are an expert in designing instructions for large language models to solve problems. An example instruction is: '+self.instruction_single_middle+'\n\nThe problem to be solved is:\n'
            context += question_context
            context += '\n\n The desired correct final answer is: \n'
            context += answer_context
            context += '\n\n Analyse the information above. What is critical for solving the problem correctly? How to design the instruction? Answer concisely, no longer than 4096 tokens.'
            instruction, messages = design_instruction(context, 4096, initial_follow_up)
            answer_text, score = solve_with(instruction)
            if score == 1.0:
                return finish(instruction, answer_text, score, messages, 0)

            # Refinement rounds: show the failing instruction and response, ask for a better instruction.
            for trial in range(1, max_num_trials):
                context = 'You are an expert in designing instructions for large language models to solve problems. You have found that the original instruction fails to solve a problem. You need to analyse the failure case, and design a new instruction. The new instruction should help the large language model to solve the failure case.\n The original instruction is: '+instruction+'\n\nThe problem is:\n'
                context += question_context
                context += '\n\n The incorrect answer given by the large language model is: \n'
                context += answer_text
                context += '\n\n The desired correct final answer is: \n'
                context += answer_context
                context += '\n\n Analyse the information above. Why does the model fail to solve the problem? What is wrong in the answer? How to design a new instruction so that the model can correctly solve the problem? What contents should the new instruction obtain? Pay special attention to the formatting requirements. Does the model\'s output strictly follow the required output format? Answer concisely, no longer than 2560 tokens.'
                instruction, messages = design_instruction(context, 2560, refine_follow_up)
                answer_text, score = solve_with(instruction)
                if score == 1.0:
                    return finish(instruction, answer_text, score, messages, trial)

            return finish(instruction, answer_text, score, messages, -1)

        train_results = common.map_with_progress(fn, failure_case)
        # Transpose per-example 6-tuples into six parallel lists (well-formed even when empty).
        return [[result[k] for result in train_results] for k in range(6)]
    
    def augment_success(self,failure_case):
        """Re-query the solver on each case until it answers correctly, up to 10 attempts.

        Each item of ``failure_case`` is a sequence whose first element is an example dict
        (providing ``problem`` for ``question_template`` and ``solution`` for scoring). The
        solver gets ``self.evaluate_end`` followed by the formatted question. Sampling stops at
        the first response whose last ``ANSWER_PATTERN`` match passes ``strict_is_equiv``.

        Returns:
            ``(new_set, new_set_2)``: ``[example, response]`` pairs for cases that ended with a
            correct response, and for cases whose final attempt was still wrong.
        """
        def fn(data):
            example = data[0]
            question_context = question_template.format(**example)
            for _attempt in range(10):
                prompt = self.evaluate_end + '\n' + question_context
                messages = [
                    self.sampler._pack_message(content=prompt, role="user")
                ]
                reply, _, _ = self.sampler(messages, max_tokens=2048)
                found = re.findall(ANSWER_PATTERN, reply)
                score = float(strict_is_equiv(example["solution"], found[-1] if found else None))
                if score == 1.0:
                    break
            return 'None', reply, score, example

        train_results = common.map_with_progress(fn, failure_case)
        solved, unsolved = [], []
        for _instruction, reply, score, example in train_results:
            bucket = solved if score == 1.0 else unsolved
            bucket.append([example, reply])
        return solved, unsolved



    def augment_success_modify(self,failure_case):
        """Ask the solver to repair its previous wrong answer, given the reference answer.

        Args:
            failure_case: iterable of ``[example, previous_output]`` pairs. ``example`` must
                provide the fields used by ``question_template`` / ``answer_template``
                (``problem``, ``solution``); ``previous_output`` is the earlier wrong response.

        Returns:
            Five parallel lists ``[pairs, responses, scores, examples, flags]``: the input pair,
            the solver's last response, its score (``float(strict_is_equiv(...))``), the example
            dict, and ``0`` if one of the (at most 3) attempts matched the reference, else ``-1``.
        """
        def fn(data):
            example, previous_output = data[0], data[1]
            question_context = question_template.format(**example)
            # Use "\n" + "The": a "\the" literal would be a TAB escape followed by "he".
            context = ("You are trying to solve the following problem: \n" + self.evaluate_end + question_context
                       + '\n You previously failed to correctly answer the question. Your previous wrong answer is: \n'
                       + previous_output
                       + '\nThe desired correct final answer is: ' + answer_template.format(**example) + '.\n')
            context += 'Based on the information above, modify the previous answer to generate a correct answer. Output only the step-by-step solution after modification. Do not output any other contents.'
            for _ in range(3):
                # Rebuilt each attempt in case the sampler mutates the message list.
                prompt_messages = [
                    self.sampler._pack_message(content=context, role="user")
                ]
                question_response_text, _, _ = self.sampler(prompt_messages, max_tokens=2048, temperature=0.0)
                match = re.findall(ANSWER_PATTERN, question_response_text)
                extracted_answer = match[-1] if match else None
                score = float(strict_is_equiv(example["solution"], extracted_answer))
                if score == 1.0:
                    return data, question_response_text, score, example, 0
            return data, question_response_text, score, example, -1

        train_results = common.map_with_progress(fn, failure_case)
        # Transpose per-example 5-tuples into five parallel lists (well-formed even when empty).
        return [[result[k] for result in train_results] for k in range(5)]

    def data_formatting_success_double(self,train_data,direct_repeats=3,instructed_repeats=1):
        """Build shuffled fine-tuning records from successful instruction-guided solutions.

        Args:
            train_data: six parallel lists ``(instructions, responses, scores, examples,
                trials, dialogues)``; only entries with a truthy score are used, and
                ``trials`` / ``dialogues`` are not read.
            direct_repeats: copies of the "``self.evaluate_end`` + question" record emitted per
                success (formerly hard-coded to 3 via a conditional whose branches were equal).
            instructed_repeats: copies of the "designed instruction + question" record emitted
                per success with a non-empty instruction (formerly 1, same pattern).

        Returns:
            ``(records, examples)``: chat records ``{'messages': [...]}`` and the example each
            record came from, shuffled jointly with one ``random.shuffle`` permutation.
        """
        records = []
        sources = []
        response_text, question_response_text, score, data, _trials, _dialogues = train_data
        for i in range(len(response_text)):
            if not score[i]:
                continue
            example = data[i]
            instruction = self.evaluate_end+question_template.format(**example)
            if response_text[i] != '':
                # Teach the model to state the designed instruction before giving its answer.
                answer = 'To solve this problem, I should follow the following instruction: '+response_text[i]+'\n'+'My response is: '+question_response_text[i]
            else:
                answer = question_response_text[i]
            data_line = {'messages': [
                {'role':'system','content':self.sampler.system_message},
                {'role':'user','content':instruction},
                {'role':'assistant','content':answer},
            ]}
            # The same dict object is repeated to up-weight the example.
            records.extend([data_line] * direct_repeats)
            sources.extend([example] * direct_repeats)

            if response_text[i] != '':
                instruction = self.instruction_single_begin+'\n'+response_text[i]+self.instruction_single_end+'\n'+question_template.format(**example)
                data_line = {'messages': [
                    {'role':'system','content':self.sampler.system_message},
                    {'role':'user','content':instruction},
                    {'role':'assistant','content':question_response_text[i]},
                ]}
                records.extend([data_line] * instructed_repeats)
                sources.extend([example] * instructed_repeats)

        # Shuffle one index permutation so records stay aligned with their source examples.
        order = list(range(len(records)))
        random.shuffle(order)
        return [records[k] for k in order], [sources[k] for k in order]

    def data_formatting_success_star(self,train_data):
        """Turn successful zero-shot solutions into fine-tuning chat records.

        ``train_data`` is a 5-tuple of parallel lists ``(instructions, responses, scores,
        examples, trials)``. ``instructions`` is only used for its length and ``trials`` is
        ignored. Each entry with a truthy score yields one record (system message, then
        ``self.evaluate_end`` + formatted question, then the response). That record is
        appended twice (the same object) to up-weight it.

        Returns:
            ``(records, examples)`` as aligned lists in input order (not shuffled).
        """
        instructions, responses, scores, examples, _trials = train_data
        records, sources = [], []
        for idx in range(len(instructions)):
            if not scores[idx]:
                continue
            example = examples[idx]
            prompt = self.evaluate_end + question_template.format(**example)
            reply = responses[idx]
            record = {'messages': [
                {'role': 'system', 'content': self.sampler.system_message},
                {'role': 'user', 'content': prompt},
                {'role': 'assistant', 'content': reply},
            ]}
            records.extend((record, record))
            sources.extend((example, example))
        return records, sources
    
    def data_formatting_success_triple(self,train_data):
        """Build two-stage fine-tuning dialogues: design an instruction, then solve the problem.

        ``train_data`` holds six parallel lists ``(instructions, responses, scores, examples,
        trials, dialogues)``; ``trials`` and ``dialogues`` are not used. For every entry with
        a truthy score the dialogue is:

          1. the user asks for an instruction for the problem, with
             ``self.instruction_single_middle`` given as an example; the assistant replies
             with ``instructions[i]``;
          2. the user sends ``self.final_end`` + the question; the assistant replies with
             ``responses[i]``.

        Returns:
            ``(records, records_copy, examples)``: three aligned lists shuffled with one shared
            permutation. ``records`` and ``records_copy`` hold the same record objects.
        """
        instructions, responses, scores, examples, _trials, _dialogues = train_data
        design_request_suffix = '\n\nOutput the instruction you design. Note that the instruction should be general knowledge that help solve similar problems, so do not contain any task-specific information. Also, the content will be directly added to the prompt, so pay attention to its format. The instruction should be concise, no longer than 1024 tokens. Output only the instruction. Do not output any other words.'
        records, records_copy, sources = [], [], []
        for idx in range(len(instructions)):
            if not scores[idx]:
                continue
            example = examples[idx]
            design_prompt = ('You are an expert at designing instructions for large language models to solve problems. An example instruction is: '
                             + self.instruction_single_middle
                             + '\n\nThe problem to be solved is:\n'
                             + question_template.format(**example)
                             + design_request_suffix)
            designed = instructions[idx]
            conversation = [
                {'role': 'system', 'content': self.sampler.system_message},
                {'role': 'user', 'content': design_prompt},
                {'role': 'assistant', 'content': designed},
            ]
            solve_prompt = self.final_end + question_template.format(**example)
            conversation.append({'role': 'user', 'content': solve_prompt})
            conversation.append({'role': 'assistant', 'content': responses[idx]})
            record = {'messages': conversation}
            records.append(record)
            records_copy.append(record)
            sources.append(example)

        # One permutation shared by all three lists keeps them aligned.
        order = list(range(len(records)))
        random.shuffle(order)
        return ([records[k] for k in order],
                [records_copy[k] for k in order],
                [sources[k] for k in order])
    

    def data_formatting_success(self,success_set,zero_shot=False,add_rationale=False):
        """Convert solved ``[example, solution_text]`` pairs into chat fine-tuning records.

        Args:
            success_set: list of pairs; element 0 is an example dict accepted by
                ``question_template.format`` and element 1 is the solution text.
            zero_shot: if True, the user turn is just ``self.evaluate_end`` + the question and
                the assistant turn is the bare solution.
            add_rationale: if True (ignored when ``zero_shot``), the assistant turn is prefixed
                with a sentence saying none of the listed instructions were needed.

        Returns:
            List of ``{'messages': [system, user, assistant]}`` dicts, one per pair.
        """
        rationale = 'I am able to solve the problem without paying special attention to any of the instructions above. The solution I propose is: \n'
        data = []
        for pair in success_set:
            example, solution = pair[0], pair[1]
            if zero_shot:
                instruction = self.evaluate_end+question_template.format(**example)
                answer = solution
            else:
                # Every instruction is included, in random order. The prompt is NOT length-capped:
                # an earlier ~6900-character budgeted sampling here was always overwritten by the
                # full list, so that dead work was removed. Confirm against model context limits.
                shuffled = list(self.instruction_list)
                random.shuffle(shuffled)
                instruction = (self.instruction_begin
                               + ''.join(ins + '\n' for ins in shuffled)
                               + self.instruction_multi_end
                               + question_template.format(**example))
                answer = rationale + solution if add_rationale else solution
            data.append({'messages': [
                {'role':'system','content':self.sampler.system_message},
                {'role':'user','content':instruction},
                {'role':'assistant','content':answer},
            ]})
        return data


    def data_formatting_modify(self,success_set):
        """Format self-correction records from the output of ``augment_success_modify``.

        ``success_set`` is ``[pairs, responses, scores, examples, flags]``, where each pair is
        ``[example, previous_output]``. For every entry with a truthy score, the assistant turn
        is the previous (wrong) output, a fixed correction separator, and the repaired response.
        If the previous output already contains the correction phrase, only the text before the
        first full separator is kept. Presumably this avoids stacking corrections from earlier
        rounds.

        Returns:
            List of ``{'messages': [system, user, assistant]}`` records.
        """
        marker_text = 'I find the answer incorrect. After modifying the previous answer, the new answer is:'
        separator = '\n I find the answer incorrect. After modifying the previous answer, the new answer is: \n'
        records = []
        pairs = success_set[0]
        for i in range(len(pairs)):
            if not success_set[2][i]:
                continue
            prompt = self.evaluate_end + question_template.format(**pairs[i][0])
            previous = pairs[i][1]
            kept = previous.split(separator)[0] if marker_text in previous else previous
            reply = kept + separator + success_set[1][i]
            records.append({'messages': [
                {'role': 'system', 'content': self.sampler.system_message},
                {'role': 'user', 'content': prompt},
                {'role': 'assistant', 'content': reply},
            ]})
        return records



    def data_generation_simplified_ori(self,zero_shot=False):
        """Rebuild the SFT buffers from the latest zero-/many-shot evaluation.

        Zero-shot training results are split by ``self.train_scores`` (score 0
        means failure). Many-shot failures are taken from
        ``self.many_shot_train_scores``.

        * ``zero_shot=False``: the many-shot failures go to
          ``self.new_reflection_simplified``. Both the multi-instruction and the
          one-shot record formats are then added.
        * ``zero_shot=True``: ``self.augment_success`` runs on the zero-shot
          failures, and its recovered examples are appended to the success set.
          ``self.instruction_list`` is cleared at the end.

        In both modes, ``self.sft_data`` and ``self.sft_data2`` receive two
        independent formatting passes of the filtered examples plus the
        success records (``add_rationale=True``). ``self.sft_inputs`` collects
        the matching source examples.

        Args:
            zero_shot: select the mode described above.
        """
        self.instruction_list = []
        # Per-round bookkeeping; the reflection/augmentation helpers refill some of these.
        self.current_training_set = []
        self.current_response_set = []
        self.training_set_filtered = []
        self.rationale_list = []
        self.corresponding_idx_list = []

        # Split zero-shot training results into successes and failures.
        success_set, failure_set = [], []
        for idx, score in enumerate(self.train_scores):
            pair = [self.sorted_training_examples[idx], self.train_outputs[idx]]
            (failure_set if score == 0 else success_set).append(pair)

        many_shot_failures = [
            [self.sorted_many_shot_training_examples[idx], self.many_shot_train_outputs[idx]]
            for idx, score in enumerate(self.many_shot_train_scores)
            if score == 0
        ]

        if zero_shot:
            success_set_extra, new_set_2 = self.augment_success(failure_set)
            # Recovered examples are appended once (replication factor 1).
            success_set = success_set + success_set_extra * 1
            self.success_set_sizes.append(len(success_set_extra))
            print(len(success_set), len(success_set_extra), len(new_set_2), len(self.training_set_filtered))
        else:
            self.new_reflection_simplified(many_shot_failures, zero_shot=zero_shot)

        self.sft_data, self.sft_data2, self.sft_inputs = [], [], []

        # Records built from the filtered (reflected or augmented) examples.
        self.sft_data = self.sft_data + self.data_formatting_simplified(zero_shot=zero_shot)
        self.sft_data2 = self.sft_data2 + self.data_formatting_simplified(zero_shot=zero_shot)
        self.sft_inputs += self.training_set_filtered
        if not zero_shot:
            self.sft_data = self.sft_data + self.data_formatting_simplified_oneshot(zero_shot=zero_shot)
            self.sft_data2 = self.sft_data2 + self.data_formatting_simplified_oneshot(zero_shot=zero_shot)
            self.sft_inputs += self.training_set_filtered

        # Records built from the examples that were already solved.
        self.sft_inputs.extend(pair[0] for pair in success_set)
        self.sft_data = self.sft_data + self.data_formatting_success(success_set, zero_shot=zero_shot, add_rationale=True)
        self.sft_data2 = self.sft_data2 + self.data_formatting_success(success_set, zero_shot=zero_shot, add_rationale=True)

        if zero_shot:
            self.instruction_list = []



    def data_generation_simplified_ori_modify(self,zero_shot=True):
        """Rebuild the SFT buffers with self-revision ("modify") examples.

        Zero-shot training results (``self.train_scores``; score 0 means
        failure) are split into successes and failures. The failures are
        passed to ``self.augment_success_modify``. The flagged entries of its
        result are formatted by ``data_formatting_modify`` and added to
        ``self.sft_data`` and ``self.sft_data2``, followed by the success
        records (``add_rationale=True``). ``self.instruction_list`` is left
        empty.

        Args:
            zero_shot: only ``True`` is supported.

        Raises:
            NotImplementedError: if ``zero_shot`` is false. That path has no
                revision data to format; before this check it ran the
                reflection step and then failed on an unbound local.
        """
        if not zero_shot:
            raise NotImplementedError(
                'data_generation_simplified_ori_modify only supports zero_shot=True')

        self.instruction_list = []
        # Per-round bookkeeping; the augmentation helper may refill some of these.
        self.current_training_set = []
        self.current_response_set = []
        self.training_set_filtered = []
        self.rationale_list = []
        self.corresponding_idx_list = []

        success_set = []
        failure_set = []
        for i, s in enumerate(self.train_scores):
            pair = [self.sorted_training_examples[i], self.train_outputs[i]]
            if s == 0:
                failure_set.append(pair)
            else:
                success_set.append(pair)

        success_set_extra = self.augment_success_modify(failure_set)
        # success_set_extra[2] holds one keep-flag per candidate (it is what
        # data_formatting_modify filters on); record how many were kept rather
        # than len(success_set_extra), which is just the number of components.
        self.success_set_sizes.append(np.sum(success_set_extra[2]))
        print(len(success_set), len(success_set_extra[0]), len(self.training_set_filtered))

        self.sft_data = []
        self.sft_data2 = []
        self.sft_inputs = []

        self.sft_data = self.sft_data + self.data_formatting_modify(success_set_extra)
        self.sft_data2 = self.sft_data2 + self.data_formatting_modify(success_set_extra)
        # NOTE(review): data_formatting_modify skips unflagged entries, so the
        # number of rows added above may differ from len(self.training_set_filtered)
        # -- confirm sft_inputs is meant to stay aligned with sft_data here.
        self.sft_inputs += self.training_set_filtered

        for e in success_set:
            self.sft_inputs.append(e[0])
        self.sft_data = self.sft_data + self.data_formatting_success(success_set, zero_shot=zero_shot, add_rationale=True)
        self.sft_data2 = self.sft_data2 + self.data_formatting_success(success_set, zero_shot=zero_shot, add_rationale=True)

        self.instruction_list = []
        return


    def data_generation_simplified_triple(self,zero_shot=False):
        """Rebuild the SFT buffers from ``self.reflection_simplified_triple``.

        Steps:

        1. Reset the per-round bookkeeping attributes.
        2. Run the triple reflection on ``self.train_examples`` and keep its
           result in ``self.training_data``.
        3. Append ``np.sum`` of the result's third component to
           ``self.success_set_sizes``.
        4. Fill ``self.sft_data``, ``self.sft_data2`` and ``self.sft_inputs``
           from ``self.data_formatting_success_triple``.

        ``self.instruction_list`` is left empty.

        Args:
            zero_shot: accepted for signature parity with the other
                ``data_generation_*`` methods; not used.
        """
        # Clear per-round state before reflection (it may read or refill it).
        self.instruction_list = []
        self.current_training_set = []
        self.current_response_set = []
        self.training_set_filtered = []
        self.rationale_list = []
        self.corresponding_idx_list = []

        reflected = self.reflection_simplified_triple(self.train_examples)
        self.training_data = reflected
        self.success_set_sizes.append(np.sum(reflected[2]))

        self.sft_data, self.sft_data2, self.sft_inputs = [], [], []
        records, records_alt, inputs = self.data_formatting_success_triple(reflected)
        self.sft_data = self.sft_data + records
        self.sft_data2 = self.sft_data2 + records_alt
        self.sft_inputs = self.sft_inputs + inputs
        self.instruction_list = []

    def data_generation_simplified_triple2(self,zero_shot=False):
        """Rebuild the SFT buffers from ``self.reflection_simplified_triple2``.

        Same flow as ``data_generation_simplified_triple``, using the
        ``triple2`` reflection:

        1. Reset the per-round bookkeeping attributes.
        2. Store the reflection result in ``self.training_data``.
        3. Append ``np.sum`` of its third component to ``self.success_set_sizes``.
        4. Fill ``self.sft_data``, ``self.sft_data2`` and ``self.sft_inputs``
           from ``self.data_formatting_success_triple``.

        Args:
            zero_shot: accepted for signature parity; not used.
        """
        # Clear per-round state before reflection (it may read or refill it).
        self.instruction_list = []
        self.current_training_set = []
        self.current_response_set = []
        self.training_set_filtered = []
        self.rationale_list = []
        self.corresponding_idx_list = []

        reflected = self.reflection_simplified_triple2(self.train_examples)
        self.training_data = reflected
        self.success_set_sizes.append(np.sum(reflected[2]))

        self.sft_data, self.sft_data2, self.sft_inputs = [], [], []
        records, records_alt, inputs = self.data_formatting_success_triple(reflected)
        self.sft_data = self.sft_data + records
        self.sft_data2 = self.sft_data2 + records_alt
        self.sft_inputs = self.sft_inputs + inputs
        self.instruction_list = []

    def data_generation_simplified_star(self,zero_shot=False):
        """Rebuild the SFT buffers from ``self.reflection_simplified_star``.

        Resets the per-round bookkeeping attributes and runs the STaR-style
        reflection on ``self.train_examples``. ``self.data_formatting_success_star``
        then supplies the records for both ``self.sft_data`` and
        ``self.sft_data2``, plus ``self.sft_inputs``. Unlike the triple variants,
        nothing is stored in ``self.training_data`` or
        ``self.success_set_sizes``.

        Args:
            zero_shot: accepted for signature parity; not used.
        """
        # Clear per-round state before reflection (it may read or refill it).
        self.instruction_list = []
        self.current_training_set = []
        self.current_response_set = []
        self.training_set_filtered = []
        self.rationale_list = []
        self.corresponding_idx_list = []

        reflected = self.reflection_simplified_star(self.train_examples)

        self.sft_data, self.sft_data2, self.sft_inputs = [], [], []
        records, inputs = self.data_formatting_success_star(reflected)
        # Both buffers get the same records (two distinct lists, shared elements).
        self.sft_data = self.sft_data + records
        self.sft_data2 = self.sft_data2 + records
        self.sft_inputs = self.sft_inputs + inputs
        self.instruction_list = []

    def data_generation_simplified_star_reflection(self,zero_shot=False):
        """Rebuild the SFT buffers from ``self.reflection_simplified_star_reflection``.

        Resets the per-round bookkeeping attributes and runs the reflection on
        ``self.train_examples``, keeping the result in ``self.training_data``.
        ``self.data_formatting_success_star`` then supplies the records for
        both ``self.sft_data`` and ``self.sft_data2``, plus ``self.sft_inputs``.

        Args:
            zero_shot: accepted for signature parity; not used.
        """
        # Clear per-round state before reflection (it may read or refill it).
        self.instruction_list = []
        self.current_training_set = []
        self.current_response_set = []
        self.training_set_filtered = []
        self.rationale_list = []
        self.corresponding_idx_list = []

        reflected = self.reflection_simplified_star_reflection(self.train_examples)
        self.training_data = reflected

        self.sft_data, self.sft_data2, self.sft_inputs = [], [], []
        records, inputs = self.data_formatting_success_star(reflected)
        # Both buffers get the same records (two distinct lists, shared elements).
        self.sft_data = self.sft_data + records
        self.sft_data2 = self.sft_data2 + records
        self.sft_inputs = self.sft_inputs + inputs
        self.instruction_list = []

    def data_generation_simplified_double(self,zero_shot=False):
        """Rebuild the SFT buffers from ``self.reflection_simplified_double``.

        Steps:

        1. Reset the per-round bookkeeping attributes.
        2. Run the double reflection on ``self.train_examples`` and keep its
           result in ``self.training_data``.
        3. Append ``np.sum`` of its third component to ``self.success_set_sizes``.
        4. Use ``self.data_formatting_success_double`` to supply the records for
           both ``self.sft_data`` and ``self.sft_data2``, plus ``self.sft_inputs``.

        Args:
            zero_shot: accepted for signature parity; not used.
        """
        # Clear per-round state before reflection (it may read or refill it).
        self.instruction_list = []
        self.current_training_set = []
        self.current_response_set = []
        self.training_set_filtered = []
        self.rationale_list = []
        self.corresponding_idx_list = []

        reflected = self.reflection_simplified_double(self.train_examples)
        self.training_data = reflected
        self.success_set_sizes.append(np.sum(reflected[2]))

        self.sft_data, self.sft_data2, self.sft_inputs = [], [], []
        records, inputs = self.data_formatting_success_double(reflected)
        # Both buffers get the same records (two distinct lists, shared elements).
        self.sft_data = self.sft_data + records
        self.sft_data2 = self.sft_data2 + records
        self.sft_inputs = self.sft_inputs + inputs
        self.instruction_list = []

    def data_generation_simplified(self,zero_shot=False):
        """Rebuild the SFT buffers from many-shot evaluation results.

        If a previous instruction list exists, it is kept. When it has at
        least ``self.num_instructions_selected`` entries, a random 90% subset is
        kept instead, and many-shot evaluation on ``self.train_examples`` is
        re-run. The refreshed scores, outputs and example order are stored
        together so they stay index-aligned.

        Many-shot results are then split by score (0 means failure).

        * ``zero_shot=False``: failures go to ``self.new_reflection_simplified``.
        * ``zero_shot=True``: ``self.augment_success`` runs on the failures, its
          recovered examples are appended three times to the success set, and
          ``self.instruction_list`` is cleared at the end.

        ``self.sft_data`` and ``self.sft_data2`` receive three multi-instruction
        passes and one one-shot pass over the filtered examples, then the
        success records. ``self.sft_inputs`` collects the matching source
        examples.

        Args:
            zero_shot: select the mode described above.
        """
        previous_instruction_list = self.instruction_list
        self.instruction_list = []
        if len(previous_instruction_list) > 0:
            if len(previous_instruction_list) < self.num_instructions_selected:
                self.instruction_list = previous_instruction_list
            else:
                keep = int(len(previous_instruction_list) * 0.9)
                idxs = np.random.choice(np.arange(len(previous_instruction_list)), keep, replace=False)
                self.instruction_list = [previous_instruction_list[idx] for idx in idxs]

                # Re-evaluate (presumably against the pruned self.instruction_list
                # -- confirm in many_shot_evaluate). Scores must be refreshed
                # together with outputs/examples: they are paired by index below.
                (_, many_shot_train_scores, many_shot_train_outputs,
                 sorted_many_shot_training_examples) = self.many_shot_evaluate(None, self.train_examples)
                self.many_shot_train_scores = many_shot_train_scores
                self.many_shot_train_outputs = many_shot_train_outputs
                self.sorted_many_shot_training_examples = sorted_many_shot_training_examples

        # Per-round bookkeeping; the reflection/augmentation helpers refill some of these.
        self.current_training_set = []
        self.current_response_set = []
        self.training_set_filtered = []
        self.rationale_list = []
        self.corresponding_idx_list = []

        success_set = []
        failure_set_many_shot = []
        for i, s in enumerate(self.many_shot_train_scores):
            pair = [self.sorted_many_shot_training_examples[i], self.many_shot_train_outputs[i]]
            if s == 0:
                failure_set_many_shot.append(pair)
            else:
                success_set.append(pair)

        if not zero_shot:
            self.new_reflection_simplified(failure_set_many_shot, zero_shot=zero_shot)
        else:
            success_set_extra, new_set_2 = self.augment_success(failure_set_many_shot)
            # Recovered examples are appended three times.
            success_set = success_set + success_set_extra * 3
            self.success_set_sizes.append(len(success_set_extra))
            print(len(success_set), len(success_set_extra), len(new_set_2), len(self.training_set_filtered))

        self.sft_data = []
        self.sft_data2 = []
        self.sft_inputs = []

        for _ in range(3):  # three multi-instruction passes (each reshuffles the instructions)
            self.sft_data = self.sft_data + self.data_formatting_simplified(zero_shot=zero_shot)
            self.sft_data2 = self.sft_data2 + self.data_formatting_simplified(zero_shot=zero_shot)
            self.sft_inputs += self.training_set_filtered
        # NOTE(review): unlike data_generation_simplified_ori, the one-shot
        # records are added even when zero_shot=True -- confirm this is intended.
        self.sft_data = self.sft_data + self.data_formatting_simplified_oneshot(zero_shot=zero_shot)
        self.sft_data2 = self.sft_data2 + self.data_formatting_simplified_oneshot(zero_shot=zero_shot)
        self.sft_inputs += self.training_set_filtered

        for e in success_set:
            self.sft_inputs.append(e[0])
        self.sft_data = self.sft_data + self.data_formatting_success(success_set, zero_shot=zero_shot)
        self.sft_data2 = self.sft_data2 + self.data_formatting_success(success_set, zero_shot=zero_shot)

        if zero_shot:
            self.instruction_list = []
        return

    

    def evaluation(self):
        """Run zero-shot and many-shot evaluation on train and test, then SFT evaluation.

        Side effects on ``self``:

        * Stores the zero-shot train/test scores and outputs, together with the
          example lists returned alongside them (``sorted_*``).
        * Stores the many-shot train/test scores, outputs and example lists.
        * Appends the accuracies to the zero-shot, many-shot and SFT-train
          score histories.

        Prints the first test output of each mode (when there is one) and a
        one-line accuracy summary.

        Returns:
            ``(zero_shot_train_acc, zero_shot_test_acc)``.
        """
        zero_shot_context = None  # context generation (self.generate_context()) is disabled
        zero_shot_train_acc, zero_shot_train_scores, zero_shot_train_outputs, sorted_training_examples = self.evaluate_triple(zero_shot_context, self.train_examples)
        zero_shot_test_acc, zero_shot_test_scores, zero_shot_test_outputs, sorted_testing_examples = self.evaluate_triple(zero_shot_context, self.test_examples)
        self.train_scores = zero_shot_train_scores
        self.test_scores = zero_shot_test_scores
        self.train_outputs = zero_shot_train_outputs
        self.test_outputs = zero_shot_test_outputs
        self.sorted_testing_examples = sorted_testing_examples
        self.zero_shot_train_score_history.append(zero_shot_train_acc)
        self.zero_shot_test_score_history.append(zero_shot_test_acc)
        self.sorted_training_examples = sorted_training_examples

        many_shot_train_acc, many_shot_train_scores, many_shot_train_outputs, sorted_many_shot_training_examples = self.many_shot_evaluate(zero_shot_context, self.train_examples)
        many_shot_test_acc, self.many_shot_test_scores, self.many_shot_test_outputs, self.sorted_many_shot_testing_examples = self.many_shot_evaluate(zero_shot_context, self.test_examples)
        self.many_shot_train_score_history.append(many_shot_train_acc)
        self.many_shot_train_scores = many_shot_train_scores
        self.many_shot_train_outputs = many_shot_train_outputs
        self.sorted_many_shot_training_examples = sorted_many_shot_training_examples
        self.many_shot_test_score_history.append(many_shot_test_acc)

        # Debug peek at the first test output of each mode. Guarded so an empty
        # test split does not abort the whole evaluation with an IndexError.
        print('+' * 30)
        if len(zero_shot_test_outputs) > 0:
            print(zero_shot_test_outputs[0])
        print('+' * 30)
        if len(self.many_shot_test_outputs) > 0:
            print(self.many_shot_test_outputs[0])

        sft_train_acc, _, _, _ = self.sft_evaluate()
        self.sft_train_score_history.append(sft_train_acc)
        print(zero_shot_train_acc, zero_shot_test_acc, many_shot_train_acc, many_shot_test_acc, sft_train_acc)
        return zero_shot_train_acc, zero_shot_test_acc

    def data_formatting_simplified(self,zero_shot=False):
        """Format the filtered training examples into chat-style SFT records.

        One record is emitted per entry of ``self.training_set_filtered``. Its
        ``messages`` hold a system, a user and an assistant turn.

        * ``zero_shot=False``: the user turn is ``self.instruction_begin``, then
          every instruction in ``self.instruction_list`` in a fresh random order
          (one per line), then ``self.instruction_multi_end`` and the question.
          The assistant turn names the instruction at
          ``self.corresponding_idx_list[i]`` and then gives
          ``self.rationale_list[i]``. Calling this twice yields differently
          ordered prompts.
        * ``zero_shot=True``: the user turn is ``self.evaluate_end`` plus the
          question, and the assistant turn is the bare rationale. The
          instruction list is not consulted.

        Args:
            zero_shot: choose the prompt format described above.

        Returns:
            A list of ``{'messages': [system, user, assistant]}`` dicts.
        """
        data = []
        for i, example in enumerate(self.training_set_filtered):
            question = question_template.format(**example)
            rationale = self.rationale_list[i]
            if zero_shot:
                instruction = self.evaluate_end + question
                answer = rationale
            else:
                relevant = self.instruction_list[self.corresponding_idx_list[i]]
                # All instructions are listed, in a fresh shuffled order per
                # example. No prompt-length budget is applied.
                # NOTE(review): confirm the prompts stay within the model's
                # context window.
                shuffled = list(self.instruction_list)
                random.shuffle(shuffled)
                instruction = (self.instruction_begin
                               + ''.join(ins + '\n' for ins in shuffled)
                               + self.instruction_multi_end
                               + question)
                answer = ('From all these instructions above, the most relevant instruction is: '
                          + relevant
                          + ' So, I will mainly follow this instruction when solving the problem. The solution I propose is: \n'
                          + rationale)
            data.append({'messages': [
                {'role': 'system', 'content': self.sampler.system_message},
                {'role': 'user', 'content': instruction},
                {'role': 'assistant', 'content': answer},
            ]})
        return data
    


    def data_formatting_simplified_oneshot(self,zero_shot=False):
        """Build single-instruction SFT records for the filtered training examples.

        For example ``i``, the user turn is ``self.instruction_begin``, a
        newline, the one instruction
        ``self.instruction_list[self.corresponding_idx_list[i]]``,
        ``self.instruction_single_end`` and the formatted question. The
        assistant turn is ``self.rationale_list[i]``.

        Args:
            zero_shot: accepted for signature parity; not used.

        Returns:
            A list of ``{'messages': [system, user, assistant]}`` dicts, one per
            entry of ``self.training_set_filtered``.
        """
        records = []
        for pos, example in enumerate(self.training_set_filtered):
            chosen = self.instruction_list[self.corresponding_idx_list[pos]]
            prompt = (self.instruction_begin + '\n' + chosen
                      + self.instruction_single_end
                      + question_template.format(**example))
            reply = self.rationale_list[pos]
            records.append({'messages': [
                {'role': 'system', 'content': self.sampler.system_message},
                {'role': 'user', 'content': prompt},
                {'role': 'assistant', 'content': reply},
            ]})
        return records
