from datasets import concatenate_datasets

class OpenBookQAScenario:
    """Prompt formatter and dataset splitter for OpenBookQA-style multiple-choice data.

    Examples are dict-like rows with a ``question_stem`` string, a ``choices``
    mapping whose ``text`` entry lists the candidate answers, and an integer
    ``label`` giving the index of the correct candidate (``get_dataset_split``
    derives ``label`` from the ``answerKey`` letter column).

    Args:
        choices: Labels rendered in front of each candidate. When None, defaults
            to ``['(A)', '(B)', '(C)', '(D)']`` and the ``choice_pivot_pos`` /
            ``trigger_token`` arguments are ignored.
        choice_pivot_pos: Stored on the instance but not used inside this class
            (presumably the offset of the letter inside each choice label -- the
            default is 1, i.e. ``'A'`` in ``'(A)'``; confirm with callers).
        trigger_token: Text appended right after ``"Answer:"`` in test prompts;
            ``None`` appends nothing.
        question_format: ``'gpt'`` or ``'anthropic'`` prompt layout; any other
            value makes the formatting methods raise ``NotImplementedError``.
        perm_idx: Index into ``self.perms`` choosing how candidates are reordered
            before display (0 keeps the original order).
    """

    def __init__(self, choices=None, choice_pivot_pos=None, trigger_token=None, question_format='gpt', perm_idx=0) -> None:
        self.question_format = question_format
        if choices is None:
            # Default labels; test prompts end with " (" so the model's
            # continuation presumably starts with the answer letter.
            self.choices = ['(A)', '(B)', '(C)', '(D)']
            self.choice_pivot_pos = 1
            self.trigger_token = ' ('
        else:
            self.choices = choices
            self.choice_pivot_pos = choice_pivot_pos
            self.trigger_token = trigger_token

        self.instruction = None  # not read anywhere in this class
        self.letter2idx = {"A": 0, "B": 1, "C": 2, "D": 3}
        # Row k lists, for each display position j, the ORIGINAL candidate index
        # shown at j when perm_idx == k. Row 0 is the identity.
        self.perms = [
            [0, 1, 2, 3],
            [3, 2, 1, 0],
            [1, 0, 3, 2],
            [2, 1, 0, 3],
            [0, 2, 3, 1],
        ]
        self.perm_idx = perm_idx
        assert 0 <= self.perm_idx < len(self.perms), (
            f"perm_idx must be in [0, {len(self.perms)}), got {perm_idx!r}"
        )

    def _candidates(self, example):
        """Return the candidate texts in display order (after the active permutation)."""
        candidates = example['choices']['text']
        if self.perm_idx != 0:
            candidates = [candidates[i] for i in self.perms[self.perm_idx]]
        return candidates

    def _answer_index(self, example, model_answer=None):
        """Return the display position of the answer.

        ``model_answer`` (if given) overrides ``example['label']``; either way the
        value is an index into the ORIGINAL candidate order and is remapped to
        where that candidate appears after the permutation.
        """
        answer = example['label'] if model_answer is None else model_answer
        if self.perm_idx != 0:
            # Candidate `answer` is displayed at the position j where perm[j] == answer.
            answer = self.perms[self.perm_idx].index(answer)
        return answer

    def _question_text(self, example):
        """Render the question and the labelled candidate list (no "Answer:" line).

        Raises:
            NotImplementedError: if ``question_format`` is not 'anthropic' or 'gpt'.
        """
        stem = example['question_stem'].strip()
        candidates = self._candidates(example)
        if self.question_format == 'anthropic':
            options = '\n'.join(f"{label} {text}" for label, text in zip(self.choices, candidates))
            return f"Question: {stem}\nChoices:\n{options}"
        if self.question_format == 'gpt':
            options = '\n'.join(f"{label}. {text}" for label, text in zip(self.choices, candidates))
            return f"{stem}\n{options}"
        raise NotImplementedError()

    def format_data_pure(self, example, test=False, model_answer=None):
        """Format one example as a single completion-style prompt string.

        Args:
            example: Row with ``question_stem``, ``choices['text']`` and ``label``.
            test: If True, return the prompt ending in ``"Answer:"`` plus
                ``trigger_token`` (when not None) with no answer filled in. If
                False, return a complete demonstration ending in
                ``"Answer: <choice label>\\n\\n"``.
            model_answer: Optional index (original candidate order) used instead
                of ``example['label']``; only consulted when ``test`` is False.

        Raises:
            NotImplementedError: for an unknown ``question_format``.
        """
        prompt = self._question_text(example) + "\nAnswer:"
        if test:
            return prompt + (self.trigger_token if self.trigger_token is not None else "")
        answer = self._answer_index(example, model_answer)
        return prompt + f" {self.choices[answer]}\n\n"

    def format_data_dialog(self, example, model_answer=None):
        """Format one example as a ``(prompt, answer_label)`` pair for chat-style use.

        The prompt ends with ``"Answer:"``. The answer label is the choice label
        of the correct candidate after the permutation, consistent with
        ``format_data_pure``. ``model_answer`` (original candidate order)
        overrides ``example['label']``.

        Raises:
            NotImplementedError: for an unknown ``question_format``.
        """
        prompt = self._question_text(example) + "\nAnswer:"
        # Remap the answer too; permuting only the candidates would make the
        # returned label point at the wrong option whenever perm_idx != 0.
        answer = self._answer_index(example, model_answer)
        return prompt, f"{self.choices[answer]}"

    def _label_from_answer_key(self, row):
        """Map a row's ``answerKey`` letter to an integer ``label`` column value."""
        return {'label': self.letter2idx[row['answerKey']]}

    def get_dataset_split(self, dataset, prompt_seed=42, icl_config='few-shot', **kwargs):
        """Build the in-context-example pool and the evaluation set.

        Args:
            dataset: Mapping of split name to a ``datasets.Dataset`` providing
                ``train``, ``validation`` and ``test`` splits, each with an
                ``answerKey`` column whose values are keys of ``letter2idx``.
            prompt_seed: Seed used to shuffle the ``train`` split.
            icl_config: ``'zero-shot'`` yields no ICL pool; any other value yields
                the shuffled ``train`` split.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``(icl_dataset, test_dataset)``. ``test_dataset`` is ``validation``
            concatenated with ``test``. In both datasets ``answerKey`` is replaced
            by an integer ``label`` column. ``icl_dataset`` is None for zero-shot.
        """
        if icl_config == 'zero-shot':
            icl_dataset = None
        else:
            icl_dataset = dataset['train'].shuffle(seed=prompt_seed)
            icl_dataset = icl_dataset.map(self._label_from_answer_key, remove_columns=['answerKey'])

        test_dataset = concatenate_datasets([dataset['validation'], dataset['test']])
        test_dataset = test_dataset.map(self._label_from_answer_key, remove_columns=['answerKey'])

        return icl_dataset, test_dataset