from datasets import concatenate_datasets

class CivilCommentsScenario:
    """Prompt formatting and balanced sampling for a toxic / non-toxic task.

    Each example is rendered as a two-option multiple-choice question. The
    order in which the two option texts are listed is selected by
    ``perm_idx``; every formatter maps the gold label (0 = non-toxic,
    1 = toxic) to the choice marker printed next to the matching option text.

    Attributes set by ``__init__``:
        choice_content: option texts, already reordered by the permutation.
        choices: choice markers such as ``'(A)'``.
        choice_pivot_pos / trigger_token: stored as given (or the defaults
            ``1`` / ``' ('`` when ``choices`` is None); only ``trigger_token``
            is read inside this class.
        instruction: always ``None`` here.
    """

    # Upper bound on the number of training rows scanned when sampling.
    _TRAIN_POOL_SIZE = 50000
    # Number of in-context examples kept per class.
    _ICL_SAMPLES_PER_CLASS = 50

    def __init__(self, choices=None, choice_pivot_pos=None, trigger_token=None, question_format='gpt', perm_idx=0) -> None:
        """Configure choice markers, prompt format and option permutation.

        Args:
            choices: choice markers; when None the defaults
                ``['(A)', '(B)', '(C)', '(D)']`` are used together with
                ``choice_pivot_pos=1`` and ``trigger_token=' ('``.
            choice_pivot_pos: stored verbatim when ``choices`` is given.
            trigger_token: text appended after ``"Answer:"`` in test prompts
                (nothing is appended when it is None).
            question_format: only ``'gpt'`` is implemented by the formatters.
            perm_idx: index into ``self.perms`` selecting the option order.

        Raises:
            AssertionError: if ``perm_idx`` is not a valid index of ``perms``.
        """
        self.choice_content = ["non-toxic.", "toxic."]
        self.question_format = question_format
        if choices is None:
            self.choices = ['(A)', '(B)', '(C)', '(D)']
            self.choice_pivot_pos = 1
            self.trigger_token = ' ('
        else:
            self.choices = choices
            self.choice_pivot_pos = choice_pivot_pos
            self.trigger_token = trigger_token

        self.instruction = None
        self.perms = [
            [0, 1],
            [1, 0],
        ]
        self.perm_idx = perm_idx
        assert 0 <= self.perm_idx < len(self.perms), (
            f"perm_idx must be in [0, {len(self.perms)}), got {perm_idx!r}"
        )

        # Apply permutation: position i shows the text of label perms[perm_idx][i].
        if self.perm_idx != 0:
            self.choice_content = [self.choice_content[i] for i in self.perms[self.perm_idx]]

    def _require_supported_format(self) -> None:
        """Raise NotImplementedError unless the question format is 'gpt'."""
        if self.question_format != 'gpt':
            raise NotImplementedError(
                f"question_format {self.question_format!r} is not implemented"
            )

    def _build_prompt(self, example, choices) -> str:
        """Render the comment, the question and the options, ending in 'Answer:'."""
        options = '\n'.join(
            f"{choice}. {content}" for choice, content in zip(choices, self.choice_content)
        )
        return (
            f"Comment: {example['text'].strip()}\n"
            "Question: The above comment is:\n"
            f"{options}\nAnswer:"
        )

    def _answer_position(self, example, model_answer=None):
        """Return the index of the choice whose displayed text matches the label.

        ``model_answer`` replaces ``example['label']`` when it is not None.
        Under a non-identity permutation the label is translated to the
        position at which its text is displayed.
        """
        label = example['label'] if model_answer is None else model_answer
        if self.perm_idx != 0:
            label = self.perms[self.perm_idx].index(label)
        return label

    # Format reference: https://github.com/stanford-crfm/helm/blob/main/src/helm/benchmark/scenarios/civil_comments_scenario.py
    def format_data_pure(self, example, test=False, model_answer=None):
        """Return a single prompt string for ``example``.

        Args:
            example: mapping with a ``'text'`` entry and, unless ``test`` is
                true or ``model_answer`` is given, a ``'label'`` entry.
            test: when true the prompt ends at ``"Answer:"`` followed by
                ``trigger_token`` (if any); otherwise the answer marker and a
                blank line are appended so the result can serve as a
                demonstration.
            model_answer: label to show instead of ``example['label']``.

        Raises:
            NotImplementedError: for any question format other than 'gpt'.
        """
        self._require_supported_format()
        prompt = self._build_prompt(example, self.choices)
        if test:
            # No answer is rendered, so the label is deliberately not looked up:
            # unlabeled examples can be formatted for inference.
            return prompt if self.trigger_token is None else prompt + self.trigger_token
        position = self._answer_position(example, model_answer)
        return prompt + f" {self.choices[position]}\n\n"

    def format_data_dialog(self, example, model_answer=None):
        """Return ``(prompt, answer_marker)`` for a dialog-style exchange.

        The answer marker honours the option permutation, so it always points
        at the option text matching the label.

        Raises:
            NotImplementedError: for any question format other than 'gpt'.
        """
        self._require_supported_format()
        position = self._answer_position(example, model_answer)
        return self._build_prompt(example, self.choices), f"{self.choices[position]}"

    def format_data_training(self, example, choices):
        """Return ``(prompt, answer_marker)`` using caller-supplied ``choices``.

        The option texts still come from ``self.choice_content``, so the label
        is translated through the permutation exactly as in the other
        formatters.

        Raises:
            NotImplementedError: for any question format other than 'gpt'.
        """
        self._require_supported_format()
        position = self._answer_position(example)
        return self._build_prompt(example, choices), f"{choices[position]}"

    @staticmethod
    def _balanced_subset(dataset, per_class):
        """Label ``dataset`` and keep the first ``per_class`` rows of each class.

        A row is labelled 1 when its ``'toxicity'`` value is >= 0.5, else 0.
        The result lists the label-0 rows first, then the label-1 rows.
        ``select`` is expected to fail if a class has fewer than ``per_class``
        rows; this is intentionally not masked so that an unbalanced sample is
        never returned silently.
        """
        def preprocessor(examples):
            examples['label'] = 1 if examples['toxicity'] >= 0.5 else 0
            return examples

        labelled = dataset.map(preprocessor)
        toxic = labelled.filter(lambda x: x['label'] == 1).select(range(per_class))
        non_toxic = labelled.filter(lambda x: x['label'] == 0).select(range(per_class))
        return concatenate_datasets([non_toxic, toxic])

    def get_dataset_split(self, dataset, prompt_seed=42, icl_config='few-shot', num_test_samples=5000, **kwargs):
        """Build the in-context demonstration pool and the balanced test set.

        Args:
            dataset: mapping with ``"train"`` and ``"test"`` splits supporting
                ``shuffle`` / ``select`` / ``map`` / ``filter`` (presumably a
                Hugging Face ``DatasetDict`` - confirm against the caller).
            prompt_seed: seed used to shuffle the demonstration pool.
            icl_config: ``'zero-shot'`` disables the demonstration pool.
            num_test_samples: total test size; half is drawn from each class.
            **kwargs: accepted and ignored.

        Returns:
            ``(icl_dataset, test_dataset)`` where ``icl_dataset`` is None for
            zero-shot and ``test_dataset`` lists label-0 rows before label-1.
        """
        icl_dataset = None
        if icl_config != 'zero-shot':
            train_split = dataset["train"]
            # Clamp so that splits smaller than the pool size remain usable.
            pool = train_split.select(range(min(self._TRAIN_POOL_SIZE, len(train_split))))
            icl_dataset = self._balanced_subset(pool, self._ICL_SAMPLES_PER_CLASS).shuffle(seed=prompt_seed)

        # Fixed seed keeps the evaluated test subset identical across runs.
        test_dataset = dataset["test"].shuffle(seed=2023)
        test_dataset = self._balanced_subset(test_dataset, num_test_samples // 2)

        return icl_dataset, test_dataset

    def get_training_dataset(self, dataset, num_train_samples=2000):
        """Return a class-balanced training set of ``num_train_samples`` rows.

        The train split is shuffled with a fixed seed (2023), truncated to at
        most 50,000 rows, labelled, and then half of the requested rows are
        taken from each class (label-0 rows first).
        """
        shuffled = dataset["train"].shuffle(seed=2023)
        # Clamp so that splits smaller than the pool size remain usable.
        pool = shuffled.select(range(min(self._TRAIN_POOL_SIZE, len(shuffled))))
        return self._balanced_subset(pool, num_train_samples // 2)