import numpy as np

from jax_privacy.experiments import image_data
from jax_privacy.experiments.image_classification import config_base
from jax_privacy.experiments.image_classification.models import models
from jax_privacy.src.training import averaging
from jax_privacy.src.training import experiment_config
from jax_privacy.src.training import optimizer_config
import ml_collections

from src import image_data
from src.image_classification import ExperimentConfig

# Depth and width passed to the WideResNet model config in `get_config`.
depth, width = 16, 4

# Name for the experiment; **should match the name of this python file**.
name = 'nondp_cinic10_without_cifar_wrn_{}_{}'.format(depth, width)

# Training-schedule constants.
# NOTE(review): `num_steps` is a fixed literal; it is not derived from
# `epochs`, `train_set_size` and `batch_size`. Confirm that this is intended.
epochs = 25
batch_size = 128
train_set_size = 210_000
num_steps = 100_000

def get_config() -> ml_collections.ConfigDict:
  """Assembles the experiment configuration for this module.

  Each sub-config (optimizer, model, training, averaging, data, evaluation) is
  built separately, combined into one `ExperimentConfig`, and then passed to
  `config_base.build_jaxline_config`.

  Returns:
    Whatever `config_base.build_jaxline_config` produces for this experiment.
  """
  # SGD with momentum and a constant learning rate.
  # Alternative schedule kept for reference:
  #   optimizer_config.cosine_decay_lr_config(init_value=0.1, alpha=0)
  optimizer_cfg = optimizer_config.sgd_config(
      lr=optimizer_config.constant_lr_config(0.1),
      momentum=0.9,
  )

  model_cfg = models.WideResNetConfig(depth=depth, width=width)

  # The total batch and the per-device-per-step batch use the same value.
  train_batch_cfg = experiment_config.BatchSizeTrainConfig(
      total=batch_size,
      per_device_per_step=batch_size,
  )
  training_cfg = experiment_config.TrainingConfig(
      num_updates=num_steps,
      batch_size=train_batch_cfg,
      weight_decay=5e-4,  # L-2 regularization.
      train_only_layer=None,
      dp=experiment_config.DPConfig.deactivated(),
  )

  averaging_cfgs = {
      'ema': averaging.ExponentialMovingAveragingConfig(decay=0.999),
  }

  # NOTE: `image_data` resolves to `src.image_data` here, because that import
  # comes after (and rebinds) the one from `jax_privacy.experiments`.
  train_data_cfg = image_data.Cinic10WithoutCifarTrainConfig(
      preprocess_name='standardise',
  )
  train_loader = image_data.Cinic10WithoutCifarLoader(config=train_data_cfg)

  # Evaluation uses the Cifar10 loader with its test config.
  eval_data_cfg = image_data.Cifar10TestConfig(
      preprocess_name='standardise',
  )
  eval_loader = image_data.Cifar10Loader(config=eval_data_cfg)

  evaluation_cfg = experiment_config.EvaluationConfig(batch_size=100)

  experiment = ExperimentConfig(
      name=name,
      optimizer=optimizer_cfg,
      model=model_cfg,
      training=training_cfg,
      averaging=averaging_cfgs,
      data_train=train_loader,
      data_eval=eval_loader,
      evaluation=evaluation_cfg,
  )

  return config_base.build_jaxline_config(experiment_config=experiment)
