import json
import random
import re
from json_task import JsonTask

def loadDataFromFile(filename):
    """Read a JSON file from disk and return the decoded object.

    The file is opened explicitly as UTF-8 instead of relying on the
    platform's locale encoding, so non-ASCII text decodes the same way on
    every machine.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The object produced by ``json.load`` for the file's contents.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return json.load(f)

def _extractOptionLabels(question):
    """Return the label token of every non-blank line after the last "Options:"."""
    option_text = question.split("Options:")[-1]
    tokenized = [line.split() for line in option_text.split("\n")]
    # Blank or whitespace-only lines carry no label; skipping them avoids an
    # IndexError when options are separated by empty lines.
    tokenized = [tokens for tokens in tokenized if tokens]
    if not tokenized:
        return []
    # Bulleted options ("- Yes") keep their label in the second token. As
    # before, the first option line decides which layout is in use.
    label_index = 1 if tokenized[0][0] == "-" else 0
    return [tokens[label_index] for tokens in tokenized if len(tokens) > label_index]


def extractQuestionsWithAnswers(raw_data, append_instruction=True, step_by_step=True, custom_instruction=""):
    """Build (question, answer) pairs from a dict holding an "examples" list.

    Each example's "input" is stripped and terminated with a newline. When
    ``append_instruction`` is true, an answer-format instruction is appended;
    if the question contains "Options:", the instruction also lists the
    option labels, joined with " or ".

    Args:
        raw_data: Mapping with an "examples" sequence; every example must
            provide "input" and "target".
        append_instruction: Whether to append the instruction to the question.
        step_by_step: Whether the instruction starts with
            "Let's think step by step. ".
        custom_instruction: Extra text inserted (followed by a space) before
            the answer-format sentence; skipped when empty.

    Returns:
        A list of ``(question, answer)`` tuples, where ``answer`` is the
        example's "target" value unchanged.
    """
    instruction = ""
    if step_by_step:
        instruction += "Let's think step by step. "
    if custom_instruction:
        instruction += custom_instruction + " "
    # The run of spaces before "Finally" is part of the prompt text this
    # function has always produced; it is kept so prompts stay identical.
    instruction += (
        "At the end show the answer bracketed with <answer> and </answer>. "
        "    Finally generate END at the end of the solution."
    )

    qa_pairs = []
    for sample in raw_data["examples"]:
        question = sample["input"].strip() + "\n"
        if append_instruction:
            labels = _extractOptionLabels(question) if "Options:" in question else []
            if labels:
                question += instruction.replace(
                    "answer bracketed",
                    "answer (" + " or ".join(labels) + ") bracketed",
                )
            else:
                # No usable option labels: fall back to the generic wording.
                question += instruction
        qa_pairs.append((question, sample["target"]))
    return qa_pairs

def extractQuestionsWithAnswersFromBigBench(raw_data, step_by_step=True, enumerate_choices=True, permute_choices=True, custom_instruction=""):
    """Build (question, answer) pairs by loading ``raw_data`` through ``JsonTask``.

    The "example_input_prefix" and "example_output_prefix" entries of
    ``raw_data`` are overwritten with empty strings (the caller's mapping is
    modified in place) before the task is constructed.

    Args:
        raw_data: Task definition passed to ``JsonTask``.
        step_by_step: Whether the instruction starts with
            "Let's think step by step. ".
        enumerate_choices: Forwarded to ``JsonTask``.
        permute_choices: Forwarded to ``JsonTask``.
        custom_instruction: Extra text inserted (followed by a space) before
            the answer-format sentence; skipped when empty.

    Returns:
        A list of ``(question, answer)`` tuples. Each question is the stripped
        task prefix, the example's "input", and the instruction; ``answer`` is
        the example's "target".
    """
    pieces = []
    if step_by_step:
        pieces.append("Let's think step by step. ")
    if custom_instruction:
        pieces.append(custom_instruction + " ")
    pieces.append(
        "At the end show the answer bracketed with <answer> and </answer>. "
        "    Finally generate END at the end of the solution."
    )
    base_instruction = "".join(pieces)

    for prefix_key in ("example_input_prefix", "example_output_prefix"):
        raw_data[prefix_key] = ""

    task = JsonTask(raw_data, enumerate_choices=enumerate_choices, permute_choices=permute_choices)

    def render(example):
        # Task prefix first, then the example text, then the instruction.
        prompt = task.task_prefix.strip() + "\n" + example["input"] + "\n"
        if "choice" not in example:
            return prompt + base_instruction
        # NOTE(review): assumes example["choice"] is an iterable of strings - confirm against JsonTask.
        listed = " or ".join(example["choice"])
        return prompt + base_instruction.replace("answer bracketed", "answer (" + listed + ") bracketed")

    return [(render(example), example["target"]) for example in task._ds]

def _questionKey(question, compare_options):
    """Return a whitespace-free key used to detect overlapping questions."""
    # Unless options are compared, only the text before the first "Options:"
    # identifies a question.
    text = question if compare_options else question.split("Options:")[0]
    return text.replace("\n", "").replace(" ", "")


def createTrainingData(full_data, test_data, train_data_size, compare_options=False):
    """Sample training pairs from ``full_data`` that do not overlap ``test_data``.

    Questions are compared after removing newlines and spaces. With
    ``compare_options`` false, everything from the first "Options:" onward is
    ignored in that comparison.

    If fewer than ``train_data_size`` non-overlapping pairs exist, all of them
    are returned in shuffled order, followed by pairs sampled from
    ``test_data`` to make up the difference.

    Args:
        full_data: Iterable of ``(question, answer)`` pairs.
        test_data: Sequence of ``(question, answer)`` pairs to exclude.
        train_data_size: Number of pairs to return.
        compare_options: Whether the options text is part of the comparison.

    Returns:
        A list of ``train_data_size`` ``(question, answer)`` pairs.

    Raises:
        ValueError: If ``train_data_size`` is negative, or if the shortfall
            cannot be covered by ``test_data``.
    """
    held_out_keys = {_questionKey(q, compare_options) for q, _ in test_data}
    train_data = [
        (q, a)
        for q, a in full_data
        if _questionKey(q, compare_options) not in held_out_keys
    ]

    if len(train_data) >= train_data_size:
        return random.sample(train_data, train_data_size)

    shortfall = train_data_size - len(train_data)
    if shortfall > len(test_data):
        # random.sample would raise ValueError here anyway; say why instead.
        raise ValueError(
            "train_data_size=%d cannot be satisfied: %d non-overlapping examples "
            "plus %d test examples are available"
            % (train_data_size, len(train_data), len(test_data))
        )
    random.shuffle(train_data)
    return train_data + random.sample(test_data, shortfall)

def scoreSolution(solution, correct_answer, pre_answer_text=""):
    """Score a generated solution as 1 (correct) or 0 (incorrect).

    The answer is taken from the text after the last ``pre_answer_text`` when
    that marker is given and present in ``solution`` (one trailing "." is
    dropped). Otherwise it is taken from the last ``<answer>...</answer>``
    pair found by the regular expression.

    When ``correct_answer`` is parenthesised, such as "(A)", only the first
    whitespace-separated token of the answer is compared, and it may match
    with or without the parentheses. Otherwise the whole answer must equal
    ``correct_answer``, ignoring surrounding whitespace.

    Args:
        solution: Full text produced for the question.
        correct_answer: Expected answer string.
        pre_answer_text: Optional marker that precedes the answer.

    Returns:
        1 if the extracted answer matches, otherwise 0. Also 0 when no answer
        can be extracted or the answer is blank.
    """
    if pre_answer_text and pre_answer_text in solution:
        answer = solution.split(pre_answer_text)[-1].strip()
        if answer.endswith("."):
            answer = answer[:-1]
    else:
        # The greedy, newline-spanning prefix makes the last <answer> tag win.
        result = re.search("(?s:.*)<answer>(.*)</answer>", solution)
        if not result:
            return 0
        answer = result.group(1)

    tokens = answer.split()
    if not tokens:
        return 0

    if correct_answer.startswith("(") and correct_answer.endswith(")"):
        return 1 if tokens[0] in (correct_answer, correct_answer[1:-1]) else 0

    # Padding inside the tags ("<answer> yes </answer>") must not turn a
    # correct answer into a miss. The exact comparison is kept first, so
    # everything that matched before still matches.
    if answer == correct_answer or answer.strip() == correct_answer.strip():
        return 1
    return 0