import torch
from datasets import load_dataset
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from transformers import (
    BertTokenizer, default_data_collator
)
import json
import random

class TrojanDataModule(pl.LightningDataModule):
    def __init__(self, model_name_or_path, train_file, preprocessing_num_workers, overwrite_cache, max_seq_length, mlm_probability):
        super().__init__()