from tensorflow_datasets.datasets import qm9

class MoleculeDataset():
  """
  Custom dataset that imports the molecular datasets.
  """
  def __init__(
    self,
    dset_name: str = "QM9",
    data_dir: str = "/mnt/disk/data/",
    elements: list = None,
    split: str = "train",
    small: bool = False,
    atomic_radius: float = .5,
    max_n_atoms: int = 80,
  ):
    """Load one split of a molecular dataset from disk and filter it.

    The data is read with ``torch.load`` from
    ``<data_dir>/<dset_name>/<split>_data.pth``, where ``dset_name`` is the
    lowercased dataset name.

    Args:
      dset_name: Dataset to load, "qm9" or "drugs". Matching is
        case-insensitive, so the default "QM9" is accepted; the lowercase
        form is stored in ``self.dset_name`` and used in the file path.
      data_dir: Root directory that contains one sub-directory per dataset.
      elements: Passed to ``self._filter_by_elements``. Defaults to the
        module-level ``ELEMENTS_HASH`` when None.
      split: One of "train", "val" or "test".
      small: If True, keep only the first 5000 entries of the loaded data.
      atomic_radius: Stored on the instance as ``self.atomic_radius``; not
        otherwise used here.
      max_n_atoms: Stored on the instance. When greater than 0,
        ``self._filter_by_n_atoms()`` is run (presumably dropping molecules
        with more atoms than this limit -- confirm in that method).

    Raises:
      ValueError: If ``dset_name`` or ``split`` is not one of the supported
        values.
    """
    if elements is None:
      elements = ELEMENTS_HASH

    # Normalise the case first: the accepted names (and therefore the
    # directory used in the path below) are lowercase, but the default
    # argument is "QM9" and would otherwise be rejected by the check.
    if isinstance(dset_name, str):
      dset_name = dset_name.lower()

    # Explicit exceptions instead of `assert`, so the validation still runs
    # when Python is started with -O.
    if dset_name not in ("qm9", "drugs"):
      raise ValueError("dset_name must be qm9 or drugs")
    if split not in ("train", "val", "test"):
      raise ValueError("split must be train, val or test")

    self.dset_name = dset_name
    self.data_dir = data_dir
    self.split = split
    self.atomic_radius = atomic_radius
    self.max_n_atoms = max_n_atoms

    # NOTE(review): torch.load unpickles the file; only point data_dir at
    # trusted data.
    self.data = torch.load(os.path.join(data_dir, dset_name, f"{split}_data.pth"))
    if small:
      self.data = self.data[:5000]

    # Add any extra data preprocessing if needed
    if max_n_atoms > 0:
      self._filter_by_n_atoms()
    self._filter_by_elements(elements)