from datasets import concatenate_datasets

class QQPScenario:
    """Prompt formatting and dataset splitting for the QQP duplicate-question task.

    Examples are mappings with ``'text1'`` and ``'text2'`` (the two questions) and
    an integer ``'label'`` that indexes both ``choices`` and ``choice_content``:
    label 0 is paired with ``'Not duplicate.'`` and label 1 with ``'Duplicate.'``.
    """

    def __init__(self, choices=None, choice_pivot_pos=None, trigger_token=None, question_format='gpt') -> None:
        """Configure answer labels and prompt style.

        Args:
            choices: Labels printed before each answer option. When None, defaults
                to ``['(A)', '(B)', '(C)', '(D)']``; only as many labels as there are
                entries in ``choice_content`` (two) are ever shown.
            choice_pivot_pos: Stored but not read by this class. Defaults to 1 when
                ``choices`` is None. Presumably the character index of the letter
                inside each label (``'(A)'[1] == 'A'``); confirm with callers.
            trigger_token: Text appended after ``Answer:`` in test prompts, or None
                to append nothing. Defaults to ``' ('`` when ``choices`` is None.
            question_format: ``'gpt'`` or ``'anthropic'``. Any other value makes the
                formatting methods raise NotImplementedError.

        When ``choices`` is given, ``choice_pivot_pos`` and ``trigger_token`` are
        stored exactly as passed, even if None.
        """
        self.choice_content = ['Not duplicate.', 'Duplicate.']
        self.question_format = question_format
        if choices is None:
            self.choices = ['(A)', '(B)', '(C)', '(D)']
            self.choice_pivot_pos = 1
            self.trigger_token = ' ('
        else:
            self.choices = choices
            self.choice_pivot_pos = choice_pivot_pos
            self.trigger_token = trigger_token

        self.instruction = None

    def _format_question(self, example, choices):
        """Build the prompt shared by all formatters, ending with the ``Answer:`` cue.

        The ``'anthropic'`` style prefixes the options with a ``Choices:`` line and
        renders each one as ``"<label> <content>"``. The ``'gpt'`` style renders each
        option as ``"<label>. <content>"``. The returned text has no trailing space
        after ``Answer:``.

        Raises:
            NotImplementedError: if ``question_format`` is not supported.
        """
        if self.question_format == 'anthropic':
            options = 'Choices:\n' + '\n'.join(
                f"{label} {content}" for label, content in zip(choices, self.choice_content)
            )
        elif self.question_format == 'gpt':
            options = '\n'.join(
                f"{label}. {content}" for label, content in zip(choices, self.choice_content)
            )
        else:
            raise NotImplementedError(f"Unsupported question_format: {self.question_format!r}")
        return (
            f"Question1: {example['text1'].strip()}\n"
            f"Question2: {example['text2'].strip()}\n"
            "These questions are:\n"
            + options
            + "\nAnswer:"
        )

    # From https://github.com/bigscience-workshop/promptsource/blob/main/promptsource/templates/glue/qqp/templates.yaml
    def format_data_pure(self, example, test=False, model_answer=None):
        """Return a single plain-text prompt for ``example``.

        Args:
            example: Mapping with ``'text1'``, ``'text2'`` and ``'label'``. ``'label'``
                is read even when ``test`` is True or ``model_answer`` is given.
            test: If True, return the question up to ``Answer:`` followed by
                ``self.trigger_token`` (when it is not None); no answer is filled in.
                If False, append ``" <label>"`` and a blank line.
            model_answer: Optional index that overrides ``example['label']`` as the
                answer to fill in.

        Raises:
            NotImplementedError: if ``question_format`` is not supported.
        """
        answer = example['label'] if model_answer is None else model_answer
        question = self._format_question(example, self.choices)
        if test:
            if self.trigger_token is None:
                return question
            return question + self.trigger_token
        return question + f" {self.choices[answer]}\n\n"

    def format_data_dialog(self, example, model_answer=None):
        """Return ``(prompt, answer_label)`` using the configured ``self.choices``.

        The prompt ends with ``Answer:``. ``answer_label`` is the choice label for
        ``model_answer`` if given, otherwise for ``example['label']``.

        Raises:
            NotImplementedError: if ``question_format`` is not supported.
        """
        answer = example['label'] if model_answer is None else model_answer
        question = self._format_question(example, self.choices)
        return question, f"{self.choices[answer]}"

    def format_data_training(self, example, choices):
        """Return ``(prompt, answer_label)`` using caller-supplied ``choices`` labels.

        Same layout as ``format_data_dialog``, but the labels come from ``choices``
        instead of ``self.choices``, and the answer is always ``example['label']``.

        Raises:
            NotImplementedError: if ``question_format`` is not supported.
        """
        answer = example['label']
        question = self._format_question(example, choices)
        return question, f"{choices[answer]}"

    def get_dataset_split(self, dataset, prompt_seed=42, icl_config='few-shot', num_test_samples=5000, **kwargs):
        """Build a label-balanced in-context pool and a label-balanced test set.

        Args:
            dataset: Mapping with ``"train"`` and ``"validation"`` splits that support
                ``select``/``filter``/``shuffle``. Presumably a Hugging Face
                ``DatasetDict``; verify against callers.
            prompt_seed: Seed used to shuffle the in-context pool.
            icl_config: ``'zero-shot'`` disables the in-context pool. Any other value
                builds 50 label-0 plus 50 label-1 examples drawn from the first 5000
                training rows, then shuffles them.
            num_test_samples: Total test size, split evenly between labels. The
                ``// 2`` means an odd value yields one fewer row than requested.
            **kwargs: Accepted and ignored.

        Returns:
            ``(icl_dataset_or_None, test_dataset)``. The test set is drawn from the
            ``"validation"`` split after shuffling with fixed seed 2023, and holds
            all label-0 rows followed by all label-1 rows.
        """
        if icl_config == 'zero-shot':
            icl_dataset = None
        else:
            # Assumes the first 5000 training rows contain at least 50 of each label.
            icl_dataset = dataset["train"].select(range(5000))
            icl_dataset_true = icl_dataset.filter(lambda x: x['label'] == 1).select(range(50))
            icl_dataset_false = icl_dataset.filter(lambda x: x['label'] == 0).select(range(50))
            icl_dataset = concatenate_datasets([icl_dataset_false, icl_dataset_true]).shuffle(seed=prompt_seed)

        # Fixed seed so the test subset is the same regardless of prompt_seed.
        test_dataset = dataset["validation"].shuffle(seed=2023)
        half = num_test_samples // 2
        test_dataset_true = test_dataset.filter(lambda x: x['label'] == 1).select(range(half))
        test_dataset_false = test_dataset.filter(lambda x: x['label'] == 0).select(range(half))
        test_dataset = concatenate_datasets([test_dataset_false, test_dataset_true])

        return icl_dataset, test_dataset

    def get_training_dataset(self, dataset, num_train_samples=2000):
        """Return a label-balanced training subset.

        Shuffles ``dataset["train"]`` with fixed seed 2023 and keeps the first 50000
        rows. From those it takes ``num_train_samples // 2`` rows of each label and
        returns the label-0 rows followed by the label-1 rows (unshuffled).
        """
        train_dataset = dataset["train"].shuffle(seed=2023).select(range(50000))
        half = num_train_samples // 2
        train_dataset_true = train_dataset.filter(lambda x: x['label'] == 1).select(range(half))
        train_dataset_false = train_dataset.filter(lambda x: x['label'] == 0).select(range(half))
        train_dataset = concatenate_datasets([train_dataset_false, train_dataset_true])
        return train_dataset