from datasets import concatenate_datasets

class IMDBScenario:
    """Multiple-choice prompt builder for IMDB sentiment classification.

    Each review is rendered as the review text, a fixed question, and one line
    per option (``self.choices`` zipped with ``self.choice_content``), followed
    by an ``Answer:`` line. Only the ``'gpt'`` question format is implemented.

    Args:
        choices: Option labels. When None, defaults to
            ``['(A)', '(B)', '(C)', '(D)']`` and the ``choice_pivot_pos`` /
            ``trigger_token`` arguments are ignored in favour of ``1`` / ``' ('``.
            Only as many labels as there are contents (two) are rendered,
            because labels are ``zip``-ped with ``choice_content``.
        choice_pivot_pos: Stored on the instance; not used within this class.
        trigger_token: Text appended after ``Answer:`` in test prompts; may be None.
        question_format: Prompt style; only ``'gpt'`` is supported.
        perm_idx: Index into ``self.perms`` selecting the displayed option order.
    """

    def __init__(self, choices=None, choice_pivot_pos=None, trigger_token=None, question_format='gpt', perm_idx=0) -> None:
        self.choice_content = ["negative.", "positive."]
        self.question_format = question_format
        if choices is None:
            self.choices = ['(A)', '(B)', '(C)', '(D)']
            self.choice_pivot_pos = 1
            self.trigger_token = ' ('
        else:
            self.choices = choices
            self.choice_pivot_pos = choice_pivot_pos
            self.trigger_token = trigger_token

        self.instruction = None
        # perms[k][j] is the original choice_content index displayed at position j.
        self.perms = [
            [0, 1],
            [1, 0],
        ]
        self.perm_idx = perm_idx
        assert 0 <= self.perm_idx < len(self.perms), (
            f"perm_idx must be in [0, {len(self.perms)}), got {perm_idx!r}"
        )

        # Reorder the displayed option contents according to the permutation.
        if self.perm_idx != 0:
            self.choice_content = [self.choice_content[i] for i in self.perms[self.perm_idx]]

    def _answer_index(self, example, model_answer):
        """Return the displayed option index for the gold or model answer.

        ``model_answer`` (when not None) overrides ``example['label']``. Either
        value is taken in the original (unpermuted) label order and mapped to
        its position under the active permutation.
        """
        answer = example['label'] if model_answer is None else model_answer
        if self.perm_idx != 0:
            # Position j shows original label perms[perm_idx][j]; find that j.
            answer = self.perms[self.perm_idx].index(answer)
        return answer

    def _question_prompt(self, example):
        """Render review, question and option lines, ending in ``"\\nAnswer:"`` (no trailing space)."""
        options = '\n'.join(
            f"{label}. {content}" for label, content in zip(self.choices, self.choice_content)
        )
        return (
            f"{example['text'].strip()}\nQuestion: The sentiment of the review above is:\n"
            f"{options}\nAnswer:"
        )

    # Format reference: https://github.com/stanford-crfm/helm/blob/main/src/helm/benchmark/scenarios/imdb_scenario.py
    def format_data_pure(self, example, test=False, model_answer=None):
        """Format one example as a single prompt string.

        Args:
            example: Mapping with ``'text'`` and ``'label'`` keys.
            test: If True, return the unanswered prompt (plus ``trigger_token``
                when it is not None); otherwise append the answer label and a
                blank line, as used for in-context demonstrations.
            model_answer: Optional label overriding ``example['label']``.

        Raises:
            NotImplementedError: If ``question_format`` is not ``'gpt'``.
        """
        # Resolved first, as before, so a missing/invalid label surfaces even in test mode.
        answer = self._answer_index(example, model_answer)
        if self.question_format != 'gpt':
            raise NotImplementedError(f"question_format {self.question_format!r} is not supported")

        prompt = self._question_prompt(example)
        if test:
            return prompt if self.trigger_token is None else prompt + self.trigger_token
        return f"{prompt} {self.choices[answer]}\n\n"

    def format_data_dialog(self, example, model_answer=None):
        """Format one example as a ``(prompt, answer_label)`` pair.

        The answer is remapped through the active permutation exactly as in
        ``format_data_pure``, so both methods name the same displayed option.

        Raises:
            NotImplementedError: If ``question_format`` is not ``'gpt'``.
        """
        answer = self._answer_index(example, model_answer)
        if self.question_format != 'gpt':
            raise NotImplementedError(f"question_format {self.question_format!r} is not supported")
        return self._question_prompt(example), f"{self.choices[answer]}"

    def get_dataset_split(self, dataset, prompt_seed=42, icl_config='few-shot', num_test_samples=5000):
        """Build the in-context-example pool and a label-balanced test set.

        Args:
            dataset: Mapping with ``"train"`` and ``"test"`` splits supporting
                ``select`` / ``filter`` / ``shuffle`` and a ``'label'`` column
                (presumably a Hugging Face ``DatasetDict`` -- confirm with callers).
            prompt_seed: Seed used to shuffle the ICL pool.
            icl_config: ``'zero-shot'`` disables the ICL pool; any other value builds it.
            num_test_samples: Total test size; ``num_test_samples // 2`` examples
                are taken per label (an odd value yields one fewer).

        Returns:
            ``(icl_dataset, test_dataset)``. ``icl_dataset`` is None for zero-shot,
            otherwise 50 label-0 plus 50 label-1 examples drawn from the first
            10,000 training rows, shuffled. ``test_dataset`` contains the label-0
            examples followed by the label-1 examples (not re-shuffled).
        """
        if icl_config == 'zero-shot':
            icl_dataset = None
        else:
            icl_dataset = dataset["train"].select(range(10000))
            icl_dataset_true = icl_dataset.filter(lambda x: x['label'] == 1).select(range(50))
            icl_dataset_false = icl_dataset.filter(lambda x: x['label'] == 0).select(range(50))
            icl_dataset = concatenate_datasets([icl_dataset_false, icl_dataset_true]).shuffle(seed=prompt_seed)

        # Fixed seed so the test subset is the same regardless of prompt_seed.
        test_dataset = dataset["test"].shuffle(seed=2023)
        test_dataset_true = test_dataset.filter(lambda x: x['label'] == 1).select(range(num_test_samples // 2))
        test_dataset_false = test_dataset.filter(lambda x: x['label'] == 0).select(range(num_test_samples // 2))
        test_dataset = concatenate_datasets([test_dataset_false, test_dataset_true])
        return icl_dataset, test_dataset