import logging
import os

from datasets import load_dataset

class SNLITransformation:
    """
    Parent class for transforming the SNLI input to extract some particular input
    attribute (e.g., just the hypothesis by leaving out the premise). Also needed
    to reformat the input into a single string. The transformed data is saved as a CSV.

    Subclasses implement `transformation`; `transform` applies it to every
    example of the train and test splits and writes the results to disk.
    """

    def __init__(self, name, output_dir, train_size=1.0):
        """
        Load the SNLI train and test splits, dropping examples whose label is -1.

        Args:
            name: Transformation name; also used in the output file names
            output_dir: where to save the CSV with the transformed attribute
            train_size: fraction of the training data to use; any value of 1
                or more keeps the entire training split
        """
        self.train_data = load_dataset('snli', split='train').filter(lambda x: x['label'] != -1)
        self.test_data = load_dataset('snli', split='test').filter(lambda x: x['label'] != -1)
        self.name = name
        self.output_dir = output_dir
        self.train_size = train_size

    def transformation(self, example):
        """
        Transform a single example; must be overridden by subclasses.

        Args:
            example: one SNLI example

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def transform(self):
        """
        Apply `transformation` to the train and test splits and save each as a CSV.

        Writes `snli_train_{name}.csv` (with `_{train_size}` inserted before the
        extension when `train_size` is below 1) and `snli_test_{name}.csv` into
        `output_dir`. The directory is created first if it does not exist.
        """
        logging.info('Applying %s to SNLI', self.name)

        use_subset = self.train_size < 1
        if use_subset:
            # NOTE(review): no seed is passed here, so the selected subset may
            # differ between runs -- confirm whether reproducibility matters.
            train_data = self.train_data.train_test_split(train_size=self.train_size)['train']
        else:
            train_data = self.train_data

        # Make sure the target directory exists before writing. An empty
        # output_dir means "current directory" and needs no creation
        # (os.makedirs('') would raise).
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)

        size_tag = f'_{self.train_size}' if use_subset else ''
        train_path = os.path.join(self.output_dir, f'snli_train_{self.name}{size_tag}.csv')
        test_path = os.path.join(self.output_dir, f'snli_test_{self.name}.csv')

        train_data.map(self.transformation).to_csv(train_path)
        self.test_data.map(self.transformation).to_csv(test_path)
        

class SNLIStandardTransformation(SNLITransformation):
    """
    Transformation that keeps both the premise and the hypothesis, joining
    them into a single labelled string.
    """

    def __init__(self, output_dir, train_size=1, suffix=''):
        """
        Args:
            output_dir: where to save the CSV with the transformed attribute
            train_size: fraction of the training data to use
            suffix: text appended to the transformation name 'std'
        """
        transformation_name = f'std{suffix}'
        super().__init__(transformation_name, output_dir, train_size=train_size)

    def transformation(self, example):
        """
        Store the combined premise/hypothesis string under 'sentence1' and
        return the (modified) example.
        """
        premise = example['premise']
        hypothesis = example['hypothesis']
        example['sentence1'] = f"PREMISE: {premise} HYPOTHESIS: {hypothesis}"
        return example


class SNLINullTransformation(SNLITransformation):
    """
    Transformation that discards the input text, replacing it with a
    single-space placeholder.
    """

    def __init__(self, output_dir, train_size=1, suffix=''):
        """
        Args:
            output_dir: where to save the CSV with the transformed attribute
            train_size: fraction of the training data to use
            suffix: text appended to the transformation name 'null'
        """
        transformation_name = f'null{suffix}'
        super().__init__(transformation_name, output_dir, train_size=train_size)

    def transformation(self, example):
        """
        Overwrite 'sentence1' with a single space and return the (modified)
        example.
        """
        # A lone space is used on purpose: a truly empty string can yield problems.
        placeholder = " "
        example['sentence1'] = placeholder
        return example