from datasets import concatenate_datasets

class MMLUScenario:
    """Prompt formatting and dataset-split selection for one MMLU subject.

    Examples are read as mappings with a ``'question'`` string, a
    ``'choices'`` sequence of option texts and an integer ``'answer'`` index
    into ``'choices'``.  The option *labels* shown in the prompt come from
    ``self.choices`` (``'(A)'`` ... ``'(D)'`` by default).  The option order can
    be changed with one of a fixed set of permutations selected by
    ``perm_idx`` (0 is the identity).
    """

    def __init__(self, subset, choices=None, choice_pivot_pos=None, trigger_token=None, question_format='gpt', perm_idx=0) -> None:
        """Configure the scenario.

        Args:
            subset: MMLU subject name (underscores are shown as spaces in the
                instruction line).
            choices: option labels.  When ``None``, the defaults
                ``['(A)', '(B)', '(C)', '(D)']`` are used, together with
                ``choice_pivot_pos=1`` and ``trigger_token=' ('``; the
                ``choice_pivot_pos``/``trigger_token`` arguments are then
                ignored.
            choice_pivot_pos: stored as-is; presumably the character index
                of the distinguishing letter inside each label (e.g. 1 for
                'A' in '(A)').  Not used by this class — confirm with callers.
            trigger_token: text appended after ``"Answer:"`` in test prompts,
                or ``None`` to append nothing.
            question_format: ``'gpt'`` or ``'anthropic'`` prompt layout.
            perm_idx: index into ``self.perms`` selecting the option order.

        Raises:
            AssertionError: if ``perm_idx`` is outside ``range(len(self.perms))``.
        """
        self.question_format = question_format
        if choices is None:
            self.choices = ['(A)', '(B)', '(C)', '(D)']
            self.choice_pivot_pos = 1
            self.trigger_token = ' ('
        else:
            self.choices = choices
            self.choice_pivot_pos = choice_pivot_pos
            self.trigger_token = trigger_token

        self.subset_name = subset
        self.instruction = "The following are multiple choice questions (with answers) about {}.".format(self.subset_name.replace("_", " "))
        self.class_list = range(len(self.choices))
        # Fixed option orderings; row 0 is the identity.
        # NOTE(review): every row assumes exactly four options.
        self.perms = [
            [0, 1, 2, 3],
            [3, 2, 1, 0],
            [1, 0, 3, 2],
            [2, 1, 0, 3],
            [0, 2, 3, 1],
        ]
        self.perm_idx = perm_idx
        assert 0 <= self.perm_idx < len(self.perms), (
            f"perm_idx must be in [0, {len(self.perms)}), got {self.perm_idx}"
        )

    def _permute(self, candidates, answer):
        """Return ``(candidates, answer)`` reordered by the active permutation.

        Position ``j`` of the result holds ``candidates[perm[j]]``, so the
        option originally at index ``answer`` moves to ``perm.index(answer)``.
        With ``perm_idx == 0`` the inputs are returned unchanged.
        """
        if self.perm_idx == 0:
            return candidates, answer
        perm = self.perms[self.perm_idx]
        return [candidates[i] for i in perm], perm.index(answer)

    def _format_question(self, example, candidates):
        """Render the question stem and labelled options, ending in ``"\\nAnswer:"``.

        Raises:
            NotImplementedError: for an unknown ``question_format``.
        """
        question = example['question'].strip()
        if self.question_format == 'anthropic':
            options = '\n'.join(f"{label} {text}" for label, text in zip(self.choices, candidates))
            return f"Question: {question}\nChoices:\n{options}\nAnswer:"
        if self.question_format == 'gpt':
            options = '\n'.join(f"{label}. {text}" for label, text in zip(self.choices, candidates))
            return f"{question}\n{options}\nAnswer:"
        raise NotImplementedError()

    def format_data_pure(self, example, test=False, model_answer=None):
        """Format ``example`` as a single plain-text prompt.

        Args:
            example: mapping with ``'question'``, ``'choices'`` and ``'answer'``.
            test: if true, build a query prompt ending in ``"Answer:"`` plus
                the trigger token (if any).  Otherwise build a completed
                demonstration ending in ``"Answer: <label>\\n\\n"``.
            model_answer: optional answer index (in the original, unpermuted
                option order) used instead of ``example['answer']``.

        Returns:
            The prompt string.

        Raises:
            NotImplementedError: for an unknown ``question_format``.
        """
        answer = example['answer'] if model_answer is None else model_answer
        candidates, answer = self._permute(example['choices'], answer)
        prompt = self._format_question(example, candidates)
        if test:
            if self.trigger_token is not None:
                prompt += self.trigger_token
            return prompt
        return f"{prompt} {self.choices[answer]}\n\n"

    def format_data_dialog(self, example, model_answer=None):
        """Format ``example`` as a ``(prompt, answer_label)`` pair.

        The same option permutation as in ``format_data_pure`` is applied, so
        both prompt styles present the options in the same order and the
        returned label refers to the permuted position.

        Args:
            example: mapping with ``'question'``, ``'choices'`` and ``'answer'``.
            model_answer: optional answer index (in the original option order)
                used instead of ``example['answer']``.

        Returns:
            ``(prompt ending in "Answer:", label string such as "(B)")``.

        Raises:
            NotImplementedError: for an unknown ``question_format``.
        """
        answer = example['answer'] if model_answer is None else model_answer
        candidates, answer = self._permute(example['choices'], answer)
        return self._format_question(example, candidates), f"{self.choices[answer]}"

    def get_dataset_split(self, dataset, prompt_seed=42, icl_config='few-shot', num_test_samples=None, eval_split="test", test_seed=2023):
        """Select the in-context-example pool and the evaluation set.

        Args:
            dataset: mapping from split name to split.  ``'dev'`` is read
                unless zero-shot.  ``'validation'`` or ``'test'`` is read
                depending on ``eval_split``.  Splits that are shuffled or
                subsampled must support ``.shuffle(seed=...)`` and
                ``.select(...)`` (presumably HuggingFace ``Dataset``).
            prompt_seed: seed used to shuffle ``'dev'`` for demonstrations.
            icl_config: ``'zero-shot'`` disables demonstrations.  Any other
                value uses the shuffled ``'dev'`` split.
            num_test_samples: if given, evaluate on a seeded random subset of
                at most this many ``'test'`` examples.
                NOTE(review): ignored when ``eval_split == "val"``.
            eval_split: ``"val"`` evaluates on ``'validation'``.  Any other
                value evaluates on ``'test'``.
            test_seed: seed for the test subsample.  The default 2023
                reproduces the historical subset.

        Returns:
            ``(icl_dataset or None, test_dataset)``.
        """
        icl_dataset = None if icl_config == 'zero-shot' else dataset['dev'].shuffle(seed=prompt_seed)

        if eval_split == "val":
            return icl_dataset, dataset["validation"]

        test_dataset = dataset["test"]
        if num_test_samples is not None:
            n_keep = min(num_test_samples, len(test_dataset))
            test_dataset = test_dataset.shuffle(seed=test_seed).select(range(n_keep))
        return icl_dataset, test_dataset

