from algo.reasoning_adapter import Reasoning_Adapter
from utils.gsm8k_metric import get_accuracy, stop_criterion, is_correct

# imported from https://github.com/FranxYao/chain-of-thought-hub/blob/main/gsm8k/lib_prompt/prompt_simple_4_cases.txt

# Few-shot chain-of-thought prompt: four worked examples, each a "Q:" line
# followed by "A: Let's think step by step." and numbered steps ending in
# "So the answer is <n>.", then an "[END OF EXAMPLE]" marker and an instruction
# line. Leading/trailing whitespace is stripped and exactly one "\n" is
# appended, presumably so a question can be concatenated directly after it.
# NOTE: this text is fed to the model as-is. Do not reformat it; several step
# lines carry trailing spaces, and any edit changes the model input.
PROMPT = '''Please complete the plans to solve the question. Here are several examples:
Q: Four years ago, Kody was only half as old as Mohamed. If Mohamed is currently twice 30 years old, how old is Kody?
A: Let's think step by step.
1. We were told that Mohamed is currently twice 30 years old, so he is currently 30 * 2 = 60 years old. 
2. That means that four years ago he must have been 60 - 4 = 56 years old. 
3. Four years ago, Kody was half as old as Mohamed, so Kody must have been 56 / 2 = 28 years old then. 
4. Since Kody was 28 years old four years ago, she must now be 28 + 4 = 32 years old. 
5. So the answer is 32.

Q: Carla bought 2 bags of mini peanut butter cups on clearance. Each bag was $6.00 but was 75% off. How much did she spend on 2 bags of candy?
A: Let's think step by step.
1. Each bag was $6.00 but was 75% off. 
2. So each bag cost $6.00 * (1 - 0.75) = $6.00 * 0.25 = $1.50. 
3. Carla bought 2 bags. So she spent $1.50 * 2 = $3.00. 
4. So the answer is 3.

Q: If Pam is currently twice as young as Rena is, and in 10 years Rena will be 5 years older than her, how old is Pam now?
A: Let's think step by step.
1. Since Rena will be 5 years older than Pam in 10 years, she must be 5 years older than Pam now as well. 
2. If Pam is currently twice as young as Rena, that means that Rena is currently twice as old as Pam is. 
3. So if P stands for Pam’s age now and R stands for Rena’s age now, then we know that R = 2 * P And since Rena is 5 years older than Pam now, we know that R = P + 5. 
4. By substitution, we have P + 5 = 2 * P, which means that P = 5. 
5. So the answer is 5.

Q: Cappuccinos cost $2, iced teas cost $3, cafe lattes cost $1.5 and espressos cost $1 each. Sandy orders some drinks for herself and some friends. She orders three cappuccinos, two iced teas, two cafe lattes, and two espressos. How much change does she receive back for a twenty-dollar bill?
A: Let's think step by step.
1. Sandy ordered three cappuccinos, which cost $2 each, so she spent $2 * 3 = $6 on cappuccinos. 
2. She ordered two iced teas, which cost $3 each, so she spent $3 * 2 = $6 dollars on ice teas. 
3. She ordered two cafe lattes, which cost $1.5 each, so she spent $1.5 * 2 = $3 on cafe lattes. 
4. She ordered two espressos, which cost $1 each, so she spent $1 * 2 = $2 on espressos. 
5. So altogether, Sandy spent $6 + $6 + $3 + $2 = $17 on drinks, which means that sandy will get $20 - $17 = $3 as change. 
6. So the answer is 3.
[END OF EXAMPLE]
Please answer the following question:
'''.strip() + "\n"

class GSM8K_Adapter(Reasoning_Adapter):
    """Reasoning adapter for GSM8K-style math word problems.

    Attaches the GSM8K helpers from ``utils.gsm8k_metric`` (``stop_criterion``,
    ``get_accuracy``, ``is_correct``) to the generic ``Reasoning_Adapter`` and
    converts reference solutions into the numbered-step format used by the
    few-shot prompt.
    """

    def __init__(self, prompt, config, accelerator):
        """Record the run settings and attach the GSM8K helpers.

        Args:
            prompt: Forwarded to ``Reasoning_Adapter.__init__``.
            config: Configuration mapping; indexing ``config["qa_template"]``
                must succeed (a missing key raises).
            accelerator: Forwarded to ``Reasoning_Adapter.__init__``.
        """
        # Set before delegating to the base class, as before. The base
        # initialiser may rely on these attributes; TODO confirm whether this
        # is redundant with what it sets itself.
        self.config, self.prompt, self.accelerator = config, prompt, accelerator
        super().__init__(config=config, prompt=prompt, accelerator=accelerator)

        # GSM8K-specific scoring hooks (imported from utils.gsm8k_metric).
        self.stop_criterion = stop_criterion
        self.get_accuracy = get_accuracy
        self.is_correct = is_correct
        self.qa_template = config["qa_template"]

    def get_positive_ans(self, b):
        """Render the reference solution in ``b['answer']`` as numbered steps.

        Each newline-separated line becomes ``"<k>. <line>"`` with ``k``
        starting at 1. A line that starts with ``"####"`` has ``"####"``
        replaced by ``"So the answer is"`` and gets no trailing newline. Every
        other line is followed by a newline.

        Args:
            b: Example mapping with an ``'answer'`` string.

        Returns:
            A one-element list holding the formatted solution string.
        """
        rendered = []
        for step_no, line in enumerate(b['answer'].split('\n'), start=1):
            if line.startswith("####"):
                rendered.append(f'{step_no}. {line.replace("####", "So the answer is")}')
            else:
                rendered.append(f'{step_no}. {line}\n')
        return [''.join(rendered)]

    def formulate_question(self, b):
        """Return the question text stored under ``b['question']`` unchanged."""
        question_text = b['question']
        return question_text

    def extract_ground_truth(self, b):
        """Replace every ``"####"`` in ``b['answer']`` with ``"So the answer is"``.

        NOTE: ``b`` is modified in place and the same object is returned.
        NOTE(review): after this runs, ``b['answer']`` no longer contains
        ``"####"``. A later ``get_positive_ans`` call on the same object would
        therefore number the final line like any other step, with a trailing
        newline. Confirm callers do not depend on that ordering.

        Args:
            b: Example mapping with an ``'answer'`` string.

        Returns:
            ``b`` itself, after the in-place update.
        """
        raw_answer = b['answer']
        b['answer'] = raw_answer.replace("####", "So the answer is")
        return b
    
    

