import random
from typing import List
import re
from prompts.gsm8k import Decompose_Question, Decompose_Answer, CoT_Question, CoT_Answer, DecomposePrompt_Question, DecomposePrompt_Answer

def nshot_prompt_instructwhiteBox(datapool: list, nshot: int, question: str, selected_prompts: List) -> List[dict]:
    """Build an n-shot chat prompt that asks for a '####'-terminated answer.

    Args:
        datapool: Pool of examples, each a mapping with "question" and
            "answer" keys; only sampled from when *selected_prompts* is falsy.
        nshot: Number of examples drawn from *datapool* with
            ``random.sample`` (which raises ``ValueError`` if *nshot* exceeds
            the pool size).
        question: The target question, placed in the final user turn.
        selected_prompts: Explicit examples (same shape as *datapool* items)
            to use instead of sampling; ``None`` or ``[]`` triggers sampling.

    Returns:
        A list of ``{"role": ..., "content": ...}`` message dicts: one
        user/assistant pair per example, then a user turn for *question*
        carrying the step-by-step / '####' answer-format instruction.
    """
    def question_prompt(s):
        return f'Question: {s}'

    def answer_prompt(s):
        return f'Answer: {s}'

    # Explicitly selected examples win; otherwise draw nshot of them at random.
    prompts = selected_prompts if selected_prompts else random.sample(datapool, nshot)

    chats = []
    for prompt in prompts:
        chats.append({"role": "user", "content": question_prompt(prompt["question"])})
        chats.append({"role": "assistant", "content": answer_prompt(prompt["answer"])})

    chats.append({"role": "user", "content": question_prompt(question) + " Let's think step by step. At the end, you MUST write the answer as an integer after '####'."})

    return chats

def zero_shot_prompt_instructwhiteBox(question: str, intermediate_response: str, choice: str) -> List[dict]:
    """Build a single-turn zero-shot chat for *question*.

    Args:
        question: The target question.
        intermediate_response: Extra text spliced into the prompt; ignored
            for ``'zero_shot'``, placed before the instruction for
            ``'decompose'`` and after it for ``'cot'``.
        choice: One of ``'zero_shot'``, ``'decompose'`` or ``'cot'``.

    Returns:
        A one-element list holding a ``{"role": "user", "content": ...}`` dict.

    Raises:
        NotImplementedError: if *choice* is not one of the supported values.
    """
    def question_prompt(s):
        return f'Question: {s}'

    # Shared answer-format instruction. The leading space is part of the
    # prompt text and must be preserved.
    instruction = " Let's think step by step. You MUST write the final answer only as an integer after the phrase 'So the answer is'."

    if choice == 'zero_shot':
        content = question_prompt(question) + '\n' + instruction
    elif choice == 'decompose':
        content = question_prompt(question) + '\n' + intermediate_response + instruction
    elif choice == 'cot':
        content = question_prompt(question) + '\n' + instruction + " " + intermediate_response
    else:
        # Fail loudly instead of returning an empty chat for a typo'd choice.
        raise NotImplementedError(
            f"zero-shot prompting is not implemented for choice={choice!r}"
        )

    return [{"role": "user", "content": content}]

def few_shot_prompt_instructwhiteBox(question: str, intermediate_response: str, choice: str) -> List[dict]:
    """Build a few-shot chat that answers *question* using a decomposition hint.

    Only ``choice == 'decompose'`` is supported. Each pair from
    ``DecomposePrompt_Question`` / ``DecomposePrompt_Answer`` becomes a
    user/assistant exchange; the final user turn is the target question,
    then "Let's break down this problem: " + *intermediate_response*
    (presumably a decomposition of the question — confirm with callers),
    then the answer-format instruction.

    Returns:
        A list of ``{"role": ..., "content": ...}`` message dicts.

    Raises:
        NotImplementedError: if *choice* is anything other than ``'decompose'``.
    """
    if choice != 'decompose':
        raise NotImplementedError(
            f"few-shot prompting is only implemented for choice='decompose', got {choice!r}"
        )

    def question_prompt(s):
        return f'Question: {s}'

    chats = []
    # zip() stops at the shorter of the two example lists.
    for prompt_question, prompt_answer in zip(DecomposePrompt_Question, DecomposePrompt_Answer):
        chats.append({"role": "user", "content": question_prompt(prompt_question)})
        chats.append({"role": "assistant", "content": prompt_answer})
    chats.append({"role": "user", "content": question_prompt(question) + '\n' + "Let's break down this problem: " + intermediate_response + " Let's think step by step. You MUST write the final answer only as an integer after the phrase 'So the answer is'."})

    return chats

def nshot_cot_instructwhiteBox(prompt: str) -> List[dict]:
    """Build an n-shot chain-of-thought chat for the question *prompt*.

    Each pair from ``CoT_Question`` / ``CoT_Answer`` becomes a user
    ("Question: ...") / assistant ("Answer: ...") exchange. The final user
    turn is *prompt* followed by "Let's think step by step."

    Returns:
        A list of ``{"role": ..., "content": ...}`` message dicts.
    """
    def question_prompt(s):
        return f'Question: {s}'

    def answer_prompt(s):
        return f'Answer: {s}'

    chats = []
    # zip() stops at the shorter of the two example lists.
    for question, answer in zip(CoT_Question, CoT_Answer):
        chats.append({"role": "user", "content": question_prompt(question)})
        chats.append({"role": "assistant", "content": answer_prompt(answer)})

    chats.append({"role": "user", "content": question_prompt(prompt) + " Let's think step by step."})

    return chats

def nshot_decompose_instructwhiteBox(prompt: str) -> List[dict]:
    """Build an n-shot chat asking the model to decompose *prompt* into sub-questions.

    The chat starts with a system message describing the decomposition task
    and output format. Each pair from ``Decompose_Question`` /
    ``Decompose_Answer`` then becomes a user/assistant exchange (answers are
    used verbatim, without an "Answer:" prefix). The final user turn is
    *prompt* followed by "Let's break down this problem:".

    Returns:
        A list of ``{"role": ..., "content": ...}`` message dicts.
    """
    def question_prompt(s):
        return f'Question: {s}'

    chats = []
    chats.append({"role": "system", "content": "Your task is to decompose a given question into a set of subquestions that outline the steps or information needed to arrive at the answer. You should not provide the direct answer to the question. Instead, focus on identifying the smaller pieces of information or steps required to solve the problem. Use the following format: Let’s break down this problem: 1. '[Subquestion 1]' 2. '[Subquestion 2]', ... , N. '[Subquestion N]' N+1. '[Main question]'"})
    # zip() stops at the shorter of the two example lists.
    for question, answer in zip(Decompose_Question, Decompose_Answer):
        chats.append({"role": "user", "content": question_prompt(question)})
        chats.append({"role": "assistant", "content": answer})

    chats.append({"role": "user", "content": question_prompt(prompt) + " Let's break down this problem:"})

    return chats

def get_answer(response, eos=None, splits='####'):
    """Extract the final integer answer from a model response.

    Args:
        response: Raw model output.
        eos: Optional end-of-sequence marker; everything from its first
            occurrence onward is discarded.
        splits: Marker that precedes the answer. Only the text after its
            last occurrence is parsed; if the marker is absent, the whole
            (possibly eos-truncated) response is parsed.

    Returns:
        An ``int`` when the cleaned text is an integer or contains exactly
        one integer (a leading minus sign is kept). Otherwise the cleaned
        string is returned unchanged.
    """
    if eos:
        response = response.split(eos)[0].strip()

    answer = response.split(splits)[-1].strip()
    # Cut at the first '.', dropping a sentence-ending period and any decimal
    # part (e.g. "70,000." -> "70,000", "5.00" -> "5").
    answer = answer.split(".")[0].strip()
    # For an equation such as "2+3=5", keep only the right-hand side.
    answer = answer.split("=")[-1].strip()

    # Strip thousands separators and unit symbols.
    # NOTE(review): this removes *every* 'g', not only a trailing gram unit.
    for remove_char in [",", "$", "%", "g"]:
        answer = answer.replace(remove_char, "")

    try:
        return int(answer)
    except ValueError:
        # Fall back to a lone integer embedded in text. Keep the sign so
        # "-5 dollars" parses the same way the bare "-5" does above.
        numbers = re.findall(r'-?\d+', answer)
        if len(numbers) == 1:
            return int(numbers[0])
        return answer
    
def get_full_decompose(answer):
    """Render a worked answer as a numbered list of "question ** answer" steps.

    *answer* is expected to hold lines shaped like ``"<sub-question>? ** <sub-answer>"``.
    Each recognised sub-question is paired, in order, with the text that
    follows a ``"** "`` marker on its line.

    Returns:
        ``"Let's break down this problem:"`` followed by a blank line and
        ``"i. <question> ** <answer>"`` lines. The header alone is returned
        when nothing matches.
    """
    header = "Let's break down this problem:\n\n"
    questions = re.findall(r'(.*?\?)\s*\*\*', answer)
    # `(?:\n|$)` also captures a final sub-answer that has no trailing newline.
    sub_answers = re.findall(r"\*\* (.*?)(?:\n|$)", answer)

    steps = [
        f"{i}. {question.strip()} ** {sub_answer.strip()}"
        for i, (question, sub_answer) in enumerate(zip(questions, sub_answers), 1)
    ]
    return (header + "\n".join(steps)).strip()
    
def get_decompose(answer, question_only=True):
    """Render the sub-questions of a worked answer as a numbered list.

    *answer* is expected to hold lines shaped like ``"<sub-question>? ** <sub-answer>"``.

    Args:
        answer: The worked answer text.
        question_only: If true, list only the sub-questions. If false, every
            sub-question except the last keeps its ``"** <sub-answer>"``, and
            the last one is listed without its answer.

    Returns:
        ``"Let's break down this problem:"`` followed by a blank line and the
        numbered steps. The header alone is returned when no sub-question is
        found.
    """
    header = "Let's break down this problem:\n\n"

    questions = [q.strip() for q in re.findall(r'(.*?\?)\s*\*\*', answer)]
    if not questions:
        return header.strip()

    if question_only:
        steps = [f"{i}. {question}" for i, question in enumerate(questions, 1)]
    else:
        # `(?:\n|$)` also captures a final sub-answer with no trailing newline.
        sub_answers = re.findall(r"\*\* (.*?)(?:\n|$)", answer)
        # zip() stops at questions[:-1], so the last sub-answer is withheld.
        steps = [
            f"{i}. {question} ** {sub_answer.strip()}"
            for i, (question, sub_answer) in enumerate(zip(questions[:-1], sub_answers), 1)
        ]
        steps.append(f"{len(questions)}. {questions[-1]}")

    return (header + "\n".join(steps)).strip()

def cnt_decompose_steps(response):
    """Return how many lines of *response* start with a step number like "1.".

    The response is stripped once before being split on "\\n", so only the
    first line loses leading whitespace. Indented later lines do not count.
    """
    lines = response.strip().split("\n")
    return sum(1 for line in lines if re.match(r'^\d+\.', line))


if __name__ == '__main__':
    # Demo: for each sample worked answer, print the question-only and the
    # question+answer decompositions, each followed by its step count.
    sample_answers = [
        "How much did the house cost? ** The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000\nHow much did the repairs increase the value of the house? ** He increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000\nWhat is the new value of the house? ** So the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\nHow much profit did he make? ** So he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000\n#### 70000",
        "How many bolts of white fiber does it take? ** It takes 2/2=<<2/2=1>>1 bolt of white fiber\nHow many bolts in total does it take? ** So the total amount of fabric is 2+1=<<2+1=3>>3 bolts of fabric\n#### 3",
        "How many hours does the candle burn? ** The candle burns for 5 - 1 = <<5-1=4>>4 hours.\nHow many centimeters shorter will the candle be after burning from 1:00 PM to 5:00 PM? ** Thus, the candle will be 2 * 4 = <<2*4=8>>8 centimeters shorter.\n#### 8"
    ]
    for sample in sample_answers:
        for question_only in (True, False):
            decomposition = get_decompose(sample, question_only=question_only)
            print(decomposition)
            print(cnt_decompose_steps(decomposition))