import jax
from typing import Iterator

import dataclasses
from jax_privacy.experiments import image_data as data
from jax_privacy.experiments.image_classification import config_base

@dataclasses.dataclass(kw_only=True, slots=True)
class ExperimentConfig(config_base.ExperimentConfig):
    """Image-classification experiment config with extra fine-tuning fields.

    Extends ``config_base.ExperimentConfig`` with three keyword-only fields.
    How these fields are consumed is not visible in this file; the notes on
    each field below are assumptions to be confirmed against the code that
    reads this config.
    """

    # Required (no default): identifier for the experiment.
    name: str
    # Optional; ``None`` (the default) means no path was provided.
    # NOTE(review): presumably a path to pretrained weights -- confirm.
    pretrain_path: str | None = None
    # NOTE(review): presumably toggles swapping the final layer when starting
    # from pretrained weights -- confirm against the consumer of this config.
    replace_last_layer: bool = False
    
    # def build_train_input(self) -> Iterator[data.DataInputs]:
    #     """Builds the training input pipeline."""
    #     return self._config.data_train.load_dataset(
    #         batch_dims=(
    #             jax.local_device_count(),
    #             self._config.training.batch_size.per_device_per_step,
    #         ),
    #         is_training=True,
    #         shard_data=True,
    #         drop_metadata=False,
    #     )