import re
import string
from abc import ABC, abstractmethod

from datasets import load_dataset

from utils import load_json


def multiple_choice_query(inputs, targets):
    """
    Format a question and its candidate answers as a lettered multiple-choice prompt.

    inputs: the question text; surrounding newlines and a trailing "A:" answer cue are removed
    targets: the candidate answers, labelled A.), B.), ... in order
    returns: "<question>\\nA.) <t0>\\nB.) <t1>\\n...Answer:"
    raises: ValueError if there are more candidates than single-letter labels (26)
    """
    labels = string.ascii_uppercase
    if len(targets) > len(labels):
        raise ValueError(
            f"multiple_choice_query supports at most {len(labels)} choices, got {len(targets)}"
        )
    question = inputs.strip("\n")
    # Remove only a trailing "A:" cue (e.g. "...\nA:"). The previous str.strip("\nA:")
    # stripped any 'A'/':' characters, mangling text such as "A man ..." or "vitamin A".
    question = re.sub(r"\s*\bA:\s*\Z", "", question)
    choices = "".join(f"{letter}.) {target}\n" for letter, target in zip(labels, targets))
    return question + "\n" + choices + "Answer:"


class PromptBuilder(ABC):
    """
    Abstract base for building zero- and few-shot prompts from an indexable dataset.

    Subclasses define how one datapoint is rendered (build_open_prompt for free-form
    answers, build_mc_prompt for multiple choice). This base class concatenates rendered
    examples into n-shot prompts, joined by example_separator.
    """

    def __init__(self, dataset, example_separator='\n'):
        """
        dataset: collection supporting len(dataset) and dataset[i]
        example_separator: string inserted between consecutive examples in n-shot prompts
        """
        self.dataset = dataset
        self.example_separator = example_separator

    @abstractmethod
    def build_open_prompt(self, index, include_answer=False):
        """
        a function that takes dataset[index] and builds a prompt for a single example. It differs between tasks
        include answer should only be True for few shot examples
        returns: prompt and answer
        """
        raise NotImplementedError("Building the prompt depends on the dataset and needs to be specified")

    @abstractmethod
    def build_mc_prompt(self, index, include_answer=False):
        """
        a function that takes dataset[index] and builds a prompt for a single example with multiple choices. It differs between tasks
        include answer should only be True for few shot examples
        returns: prompt, the list of candidate targets, and the index in that list of the correct answer
        """
        raise NotImplementedError("Building the prompt depends on the dataset and needs to be specified")

    def _shot_indices(self, nshots, index):
        """
        Return the dataset indices of the nshots examples immediately preceding index,
        in ascending order, wrapping around to the end of the dataset when index < nshots.
        raises: ValueError if nshots >= len(dataset); the wrapped window would then
        include the test example itself (leaking its answer) or run past the dataset.
        """
        if nshots <= 0:
            return []
        size = len(self)
        if nshots >= size:
            raise ValueError(
                f"nshots={nshots} requires at least {nshots + 1} datapoints, but the dataset has {size}"
            )
        # Explicit modulo instead of relying on negative indexing, so wrapping works
        # for any indexable dataset.
        return [(index - offset) % size for offset in range(nshots, 0, -1)]

    def nshot_open_prompt(self, nshots, index):
        """
        nshots: int number of shots. 0 means just return dataset[index] prompt. 1 means 1 example and 1 test prompt
        index: the index of the test example. The nshots datapoints before it are used as examples, wrapping if necessary
        returns: the full prompt string, the answer for dataset[index]
        raises: ValueError if nshots >= len(dataset)
        """
        shots = [
            self.build_open_prompt(i, include_answer=True)[0]
            for i in self._shot_indices(nshots, index)
        ]
        final_prompt, gt = self.build_open_prompt(index)
        return self.example_separator.join(shots + [final_prompt]), gt

    def nshot_mc_prompt(self, nshots, index):
        """
        nshots: int number of shots. 0 means just return dataset[index] prompt. 1 means 1 example and 1 test prompt
        index: the index of the test example. The nshots datapoints before it are used as examples, wrapping if necessary
        returns: prompt, targets for the last question, the index in targets of the right answer
        raises: ValueError if nshots >= len(dataset)
        """
        shots = [
            self.build_mc_prompt(i, include_answer=True)[0]
            for i in self._shot_indices(nshots, index)
        ]
        final_prompt, targets, gt_idx = self.build_mc_prompt(index)
        # Use the configured separator, consistent with nshot_open_prompt.
        return self.example_separator.join(shots + [final_prompt]), targets, gt_idx

    def __len__(self):
        return len(self.dataset)

class DefaultOpenPrompter(PromptBuilder):
    """
    Prompt builder that uses each datapoint's 'inputs' text verbatim as the question.

    Datapoints are read as mappings with an 'inputs' string and a 'targets' list (its first
    element is used as the reference answer). Multiple-choice methods additionally read
    'multiple_choice_targets'.
    """

    def __init__(self, dataset, example_separator='\n'):
        super().__init__(dataset, example_separator=example_separator)

    def build_open_prompt(self, index, include_answer=False):
        """
        Build the open-ended prompt for dataset[index].
        index: position of the datapoint in the dataset
        include_answer: if True, append ' ' + answer (use only for few-shot examples)
        returns: a string prompt, the answer (first element of 'targets') as a string
        """
        datapoint = self.dataset[index]
        gt = datapoint['targets'][0]
        prompt = datapoint['inputs']
        if include_answer:
            prompt += ' ' + gt
        return prompt, gt

    def build_mc_prompt(self, index, include_answer=False):
        """
        Build the multiple-choice prompt for dataset[index].
        index: position of the datapoint in the dataset
        include_answer: if True, append ' ' + answer (use only for few-shot examples)
        returns: a string prompt, the list of candidate targets ('multiple_choice_targets'),
                 the index of the answer (first element of 'targets') in that list
        raises: ValueError if the answer is not among 'multiple_choice_targets'
        """
        datapoint = self.dataset[index]
        targets = datapoint['multiple_choice_targets']
        gt = datapoint['targets'][0]
        # NOTE(review): assumes targets[0] is the correct choice and appears verbatim in
        # 'multiple_choice_targets' — confirm per task.
        gt_idx = targets.index(gt)
        prompt = datapoint['inputs']
        if include_answer:
            prompt += ' ' + gt
        return prompt, targets, gt_idx

    def get_mc_targets(self, idx):
        """Return the 'multiple_choice_targets' field of dataset[idx]."""
        return self.dataset[idx]['multiple_choice_targets']
    

def load_bigbench_task(task):
    """Load the BIG-bench config named ``task`` via ``load_dataset("bigbench", task)`` and return the result unchanged."""
    return load_dataset("bigbench", task)

def load_bigbench_from_results(jsonpath):
    """Return the contents of the JSON file at ``jsonpath`` as parsed by ``utils.load_json``.

    Presumably a previously saved BIG-bench results file (inferred from the name; the
    structure is whatever the file contains).
    """
    return load_json(jsonpath)
