import numpy as np

from datasets import ClassLabel

class LogiQAScenario:
    """Prompt formatter for LogiQA-style multiple-choice questions.

    Examples are dict-like records with 'context', 'query', 'options' (the
    candidate answers) and 'label' (index of the correct option). Two prompt
    layouts are supported via ``question_format``: 'gpt' and 'anthropic'.
    """

    def __init__(self, choices=None, choice_pivot_pos=None, trigger_token=None, question_format='gpt', perm_idx=0) -> None:
        """Configure option labels, prompt layout and answer-order permutation.

        Args:
            choices: Labels printed in front of each option. When None, the
                defaults '(A)'..'(D)' are used, and ``choice_pivot_pos`` /
                ``trigger_token`` are overridden with 1 and ' (' respectively.
            choice_pivot_pos: Stored as-is; not read by this class's own methods.
            trigger_token: Text appended right after "Answer:" in test prompts
                produced by ``format_data_pure``; None appends nothing.
            question_format: 'gpt' or 'anthropic'; other values make the
                formatting methods raise NotImplementedError.
            perm_idx: Index into ``self.perms``; 0 keeps the original option order.
        """
        self.question_format = question_format
        if choices is None:
            self.choices = ['(A)', '(B)', '(C)', '(D)']
            self.choice_pivot_pos = 1
            self.trigger_token = ' ('
        else:
            self.choices = choices
            self.choice_pivot_pos = choice_pivot_pos
            self.trigger_token = trigger_token

        self.instruction = None
        # Fixed option orderings. perms[k][j] is the ORIGINAL index of the
        # option displayed at position j under permutation k.
        self.perms = [
            [0, 1, 2, 3],
            [3, 2, 1, 0],
            [1, 0, 3, 2],
            [2, 1, 0, 3],
            [0, 2, 3, 1],
        ]
        self.perm_idx = perm_idx
        assert 0 <= self.perm_idx < len(self.perms), f"perm_idx must be in [0, {len(self.perms)}), got {perm_idx!r}"

    def _apply_perm(self, candidates, answer):
        """Reorder ``candidates`` by the active permutation and remap ``answer``.

        Returns the inputs unchanged when ``perm_idx`` is 0.
        """
        if self.perm_idx == 0:
            return candidates, answer
        perm = self.perms[self.perm_idx]
        # The option originally at index `answer` now sits at the position j
        # where perm[j] == answer.
        return [candidates[i] for i in perm], perm.index(answer)

    def _format_question(self, example, candidates):
        """Build the question header plus enumerated options (no "Answer:" line).

        Raises:
            NotImplementedError: if ``self.question_format`` is not supported.
        """
        if self.question_format == 'anthropic':
            header = f"Question: {example['context']}\n{example['query']}\n" + "Choices:\n"
            separator = ' '
        elif self.question_format == 'gpt':
            header = f"{example['context']}\n{example['query']}\n"
            separator = '. '
        else:
            raise NotImplementedError(f"Unsupported question_format: {self.question_format!r}")
        # `option` (not `answer`) so the loop variable cannot be confused with the gold index.
        options = '\n'.join(f"{label}{separator}{option}" for label, option in zip(self.choices, candidates))
        return header + options

    # Template from promptsource
    def format_data_pure(self, example, test=False, model_answer=None):
        """Render one example as a single prompt string.

        Args:
            example: Record with 'context', 'query', 'options' and 'label'.
            test: If True, end the prompt at "Answer:" (plus ``trigger_token``
                when it is not None). If False, append the answer label and a
                blank line, as a demonstration for few-shot prompts.
            model_answer: Optional index used instead of ``example['label']``.
                NOTE(review): it is remapped through the permutation exactly
                like the gold label, i.e. it is assumed to index the ORIGINAL
                option order — confirm with callers.

        Raises:
            NotImplementedError: if ``question_format`` is not supported.
        """
        answer = example['label'] if model_answer is None else model_answer
        candidates, answer = self._apply_perm(example['options'], answer)
        body = self._format_question(example, candidates)

        if test:
            suffix = "\nAnswer:"
            if self.trigger_token is not None:
                suffix += self.trigger_token
            return body + suffix
        return body + f"\nAnswer: {self.choices[answer]}\n\n"

    def format_data_dialog(self, example, model_answer=None):
        """Render one example as a (prompt, answer_label) pair for dialog use.

        NOTE(review): unlike ``format_data_pure`` this does not apply the
        ``perm_idx`` permutation and never appends ``trigger_token`` — confirm
        that is intended.

        Returns:
            Tuple of the prompt ending in "Answer:" and the label string of
            ``model_answer`` (or ``example['label']`` when it is None).

        Raises:
            NotImplementedError: if ``question_format`` is not supported.
        """
        answer = example['label'] if model_answer is None else model_answer
        prompt = self._format_question(example, example['options']) + "\nAnswer:"
        return prompt, f"{self.choices[answer]}"

    @staticmethod
    def _rename_label(split):
        """Rename the 'correct_option' column of ``split`` to 'label'."""
        return split.map(lambda x: {'label': x['correct_option']}, remove_columns=['correct_option'])

    def get_dataset_split(self, dataset, prompt_seed=42, icl_config='few-shot', **kwargs):
        """Return ``(icl_dataset, test_dataset)`` with a 'label' column.

        Args:
            dataset: Mapping of split name to dataset; presumably a
                ``datasets.DatasetDict`` with 'train' and 'test' splits that
                carry a 'correct_option' column.
            prompt_seed: Seed used to shuffle the train split for demonstrations.
            icl_config: 'zero-shot' yields no ICL split (None); any other value
                yields the shuffled train split.
            **kwargs: Accepted for interface compatibility; unused.
        """
        if icl_config == 'zero-shot':
            icl_dataset = None
        else:
            icl_dataset = self._rename_label(dataset['train'].shuffle(seed=prompt_seed))

        # The validation split is intentionally not used.
        test_dataset = self._rename_label(dataset["test"])
        return icl_dataset, test_dataset

