from datasets import concatenate_datasets
from misc.utils import load_canary_dataset, load_wikitext_dataset
from argparse import ArgumentParser

class StandardP3:
    """Build a fixed-size dataset mixing background rows with canary rows.

    The resulting dataset has exactly as many rows as the background
    dataset: trailing background rows are dropped to make room for the
    canary ``train`` split.
    """

    def __init__(self, dataset_name: str):
        """Load the canary dataset and the background dataset.

        Args:
            dataset_name: Forwarded to ``load_canary_dataset`` as
                ``dataset_name``.
        """
        self.canary_dataset = load_canary_dataset(dataset_name=dataset_name)
        self.background_dataset = load_wikitext_dataset()

    def create_dataset(self, seed: int = 42):
        """Return the shuffled mix of background rows and canary ``train`` rows.

        The first ``len(background) - len(canary_train)`` background rows are
        kept, the canary ``train`` split is appended, and the result is
        shuffled with ``seed``.

        Args:
            seed: Seed passed to ``shuffle``. Defaults to 42, the value that
                was previously hard-coded.

        Returns:
            The shuffled dataset, with the same number of rows as the
            background dataset.

        Raises:
            ValueError: If the canary ``train`` split has more rows than the
                background dataset, so the target size cannot be met.
        """
        canary_train = self.canary_dataset['train']
        total_rows = len(self.background_dataset)
        num_background = total_rows - len(canary_train)

        # Validate explicitly: a negative count would make range() empty and
        # silently produce a dataset of the wrong size (the trailing assert
        # is stripped under ``python -O``).
        if num_background < 0:
            raise ValueError(
                f"canary train split ({len(canary_train)} rows) is larger than "
                f"the background dataset ({total_rows} rows)"
            )

        kept_background = self.background_dataset.select(range(num_background))
        dataset = concatenate_datasets([kept_background, canary_train])
        dataset = dataset.shuffle(seed=seed)

        # Internal invariant: swapping rows in must not change the total size.
        assert len(dataset) == total_rows
        return dataset

def build_arg_parser():
    """Return the command-line parser for this script."""
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_name",
        type=str,
        help="Name passed to StandardP3 to select the canary dataset.",
    )
    # Required so a missing path is rejected immediately instead of failing
    # in save_to_disk after the whole dataset has been built.
    parser.add_argument(
        "--output_dataset_path",
        type=str,
        required=True,
        help="Directory the resulting dataset is saved to.",
    )
    return parser


def main():
    """Parse arguments, build the mixed dataset and save it to disk."""
    args = build_arg_parser().parse_args()

    standard_p3 = StandardP3(dataset_name=args.dataset_name)
    dataset = standard_p3.create_dataset()
    dataset.save_to_disk(args.output_dataset_path)


if __name__ == "__main__":
    main()