import numpy as np

from datasets import ClassLabel

class HellaSWAGScenario:
    """Build multiple-choice prompts for sentence-completion examples.

    Each example is expected to be a mapping with the keys ``'ctx'`` (the
    sentence context), ``'endings'`` (the candidate endings) and ``'label'``
    (the index of the correct ending). The scenario can render the options in
    one of several fixed orders (``perms``) selected by ``perm_idx``.
    """

    def __init__(self, choices=None, choice_pivot_pos=None, trigger_token=None, question_format='gpt', perm_idx=0) -> None:
        """Configure option labels, prompt format and option ordering.

        Args:
            choices: Labels shown in front of each ending. When ``None`` the
                defaults ``'(A)'``..``'(D)'`` are used and ``choice_pivot_pos``
                / ``trigger_token`` are overridden with matching defaults
                (``1`` and ``' ('``), regardless of what was passed.
            choice_pivot_pos: Stored as-is when ``choices`` is given; it is not
                read anywhere inside this class.
            trigger_token: Text appended after ``"Answer:"`` in test prompts;
                ``None`` appends nothing.
            question_format: ``'gpt'`` or ``'anthropic'``. Any other value makes
                the formatting methods raise ``NotImplementedError``.
            perm_idx: Index into ``self.perms`` selecting the option order.
                ``0`` leaves the endings untouched.
        """
        self.question_format = question_format
        if choices is None:
            self.choices = ['(A)', '(B)', '(C)', '(D)']
            self.choice_pivot_pos = 1
            self.trigger_token = ' ('
        else:
            self.choices = choices
            self.choice_pivot_pos = choice_pivot_pos
            self.trigger_token = trigger_token

        # NOTE(review): "commensense" looks like a typo for "commonsense". It is
        # kept byte-for-byte because it is runtime prompt text and changing it
        # would alter previously evaluated prompts; confirm before fixing.
        self.instruction = "The following are multiple choice questions (with answers) about commensense reasoning."
        # Entry j of a permutation is the original index of the ending that is
        # displayed at position j.
        self.perms = [
            [0, 1, 2, 3],
            [3, 2, 1, 0],
            [1, 0, 3, 2],
            [2, 1, 0, 3],
            [0, 2, 3, 1],
        ]
        self.perm_idx = perm_idx
        assert self.perm_idx < len(self.perms) and self.perm_idx >= 0

    def _permute_candidates(self, candidates):
        """Return the endings in the order selected by ``perm_idx``."""
        if self.perm_idx == 0:
            # Identity order: hand back the input untouched so examples with a
            # different number of endings keep working as before.
            return candidates
        return [candidates[i] for i in self.perms[self.perm_idx]]

    def _permute_answer(self, answer):
        """Map an original answer index to its position after permutation."""
        if self.perm_idx == 0:
            return answer
        # The ending originally at `answer` is displayed where the permutation
        # lists that original index.
        return self.perms[self.perm_idx].index(answer)

    def _question_prompt(self, example, candidates):
        """Render the question and options, ending with ``"\\nAnswer:"``.

        Raises:
            NotImplementedError: if ``question_format`` is not supported.
        """
        if self.question_format == 'anthropic':
            header = "Question: How does this sentence end?\n"
            options_intro = "Choices:\n"
            separator = " "
        elif self.question_format == 'gpt':
            header = "How does this sentence end?\n"
            options_intro = ""
            separator = ". "
        else:
            raise NotImplementedError()

        # zip() stops at the shorter of the label list and the ending list.
        options = '\n'.join(
            f"{choice}{separator}{candidate}"
            for choice, candidate in zip(self.choices, candidates)
        )
        return f"{header}Sentence: {example['ctx'].strip()}\n{options_intro}{options}\nAnswer:"

    # Template from promptsource
    def format_data_pure(self, example, test=False, model_answer=None):
        """Format one example as a single prompt string.

        Args:
            example: Mapping with ``'ctx'``, ``'endings'`` and (unless ``test``
                is true or ``model_answer`` is given) ``'label'``.
            test: When true, the prompt stops after ``"Answer:"`` plus the
                trigger token (if any) and the label is never read, so
                unlabeled examples can be formatted.
            model_answer: Optional answer index (in the original, unpermuted
                ordering) used instead of ``example['label']``.

        Returns:
            The prompt. Demonstration prompts (``test`` false) end with the
            answer label followed by a blank line.

        Raises:
            NotImplementedError: if ``question_format`` is not supported.
        """
        candidates = self._permute_candidates(example['endings'])
        prompt = self._question_prompt(example, candidates)

        if test:
            if self.trigger_token is not None:
                return prompt + self.trigger_token
            return prompt

        # Only demonstrations print the answer, so the label is resolved here
        # rather than up front.
        answer = example['label'] if model_answer is None else model_answer
        answer = self._permute_answer(answer)
        return prompt + f" {self.choices[answer]}\n\n"

    def format_data_dialog(self, example, model_answer=None):
        """Format one example as a ``(question, answer)`` pair for dialog use.

        The option order follows ``perm_idx`` in the same way as
        ``format_data_pure`` so both formats present identical options.

        Args:
            example: Mapping with ``'ctx'``, ``'endings'`` and ``'label'``.
            model_answer: Optional answer index (in the original, unpermuted
                ordering) used instead of ``example['label']``.

        Returns:
            A tuple of the question text ending in ``"\\nAnswer:"`` and the
            label of the answer.

        Raises:
            NotImplementedError: if ``question_format`` is not supported.
        """
        answer = example['label'] if model_answer is None else model_answer
        candidates = self._permute_candidates(example['endings'])
        prompt = self._question_prompt(example, candidates)
        answer = self._permute_answer(answer)
        return prompt, f"{self.choices[answer]}"

    def get_dataset_split(self, dataset, prompt_seed=42, icl_config='few-shot', num_test_samples=5000, eval_split='test'):
        """Return the in-context (demonstration) dataset and the evaluation dataset.

        Args:
            dataset: Mapping of split name to dataset; ``'validation'`` is
                always read and ``'train'`` is read unless ``icl_config`` is
                ``'zero-shot'``.
            prompt_seed: Seed used to shuffle the training split.
            icl_config: ``'zero-shot'`` disables the demonstration dataset.
            num_test_samples: Maximum number of evaluation examples; clamped to
                the size of the validation split.
            eval_split: Accepted for interface compatibility but not used; the
                evaluation data always comes from ``'validation'``.

        Returns:
            ``(icl_dataset, test_dataset)`` where ``icl_dataset`` is ``None``
            in the zero-shot configuration.
        """
        if icl_config == 'zero-shot':
            icl_dataset = None
        else:
            icl_dataset = dataset['train'].shuffle(seed=prompt_seed)
            icl_dataset = icl_dataset.cast_column("label", ClassLabel(num_classes=4, names=['0', '1', '2', '3'], id=None))

        # Fixed seed keeps the evaluation subset identical across runs.
        shuffled = dataset["validation"].shuffle(seed=2023)
        # Clamp so that asking for more samples than exist does not raise.
        sample_count = min(num_test_samples, len(shuffled))
        test_dataset = shuffled.select(range(sample_count))
        test_dataset = test_dataset.cast_column("label", ClassLabel(num_classes=4, names=['0', '1', '2', '3'], id=None))
        # Ignore the test set since no label is provided.
        return icl_dataset, test_dataset

