import numpy as np

from jax_privacy.experiments import image_data
from jax_privacy.experiments.image_classification import config_base
from jax_privacy.experiments.image_classification.models import models
from jax_privacy.src.training import averaging
from jax_privacy.src.training import experiment_config
from jax_privacy.src.training import optimizer_config
import ml_collections

from src import image_data as image_data_canaries
from src.image_classification import ExperimentConfig

# --- Data parameters --------------------------------------------------------
data_seed = 0  # seed forwarded to every data loader in get_config()
n = 49500  # forwarded to the loaders as `num_samples`
m = 1000  # forwarded to the loaders as `num_canaries`
r = 49000  # forwarded to the loaders as `num_noncanaries`
mislabel_canaries = False  # whether to flip the canaries' labels
craft_type = 'sgd'  # canary crafting method, forwarded to the loaders

# --- DP-SGD parameters ------------------------------------------------------
epsilon = 8.0  # target epsilon; the noise multiplier is auto-tuned to reach it
delta = 1e-5
clipping_norm = 1
augmult = None  # augmult_config of the training loader (matches `no_aug`)

# --- Optimization parameters ------------------------------------------------
batch_size = 4096  # total batch size per update
lr = 4.0  # constant learning rate
num_steps = 1000  # number of training updates

# Experiment name. It **must match the name of this python file**.
# Label-flipped canaries without crafting are reported as 'mislabel'.
if craft_type == 'none' and mislabel_canaries:
  craft_type_name = 'mislabel'
else:
  craft_type_name = craft_type
name = (
    f'cifar10_crafted_{craft_type_name}_wrn_16_4_no_aug'
    f'_eps{epsilon}_n{n}_m{m}_r{r}_seed{data_seed}'
)

def get_config() -> ml_collections.ConfigDict:
  """Builds the jaxline config for this CIFAR-10 crafted-canary experiment.

  The experiment trains a WideResNet (depth 16, width 4) with constant-rate
  SGD under DP. The noise multiplier is auto-tuned to reach the module-level
  `epsilon` target. Training uses the crafted CIFAR-10 train/valid split.
  Evaluation uses the crafted test config, plus an additional pass over the
  train/valid split filtered to the canaries.

  Returns:
    Whatever `config_base.build_jaxline_config` produces for the assembled
    `ExperimentConfig`.
  """
  # Dataset arguments common to all three loaders below.
  shared_data_kwargs = dict(
      preprocess_name='standardise',
      num_samples=n,
      num_canaries=m,
      num_noncanaries=r,
      seed=data_seed,
  )
  # Crafting arguments common to the two train/valid-split loaders.
  crafting_kwargs = dict(
      mislabel_canaries=mislabel_canaries,
      craft_type=craft_type,
  )

  sgd = optimizer_config.sgd_config(
      lr=optimizer_config.constant_lr_config(lr),
  )
  wide_resnet = models.WideResNetConfig(depth=16, width=4)

  training = experiment_config.TrainingConfig(
      num_updates=num_steps,
      batch_size=experiment_config.BatchSizeTrainConfig(
          total=batch_size,
          per_device_per_step=64,
      ),
      weight_decay=0.0,  # L2 regularization coefficient (disabled).
      train_only_layer=None,
      dp=experiment_config.DPConfig(
          delta=delta,
          clipping_norm=clipping_norm,
          # Leave the noise multiplier unset so that it is auto-tuned to hit
          # the target epsilon.
          auto_tune_target_epsilon=epsilon,
          noise_multiplier=None,
          rescale_to_unit_norm=True,
          auto_tune_field='noise_multiplier',
      ),
      logging=experiment_config.LoggingConfig(
          grad_clipping=True,
          snr_global=True,  # Signal-to-noise ratio across all layers.
          snr_per_layer=False,  # Signal-to-noise ratio for each layer.
      ),
  )
  weight_averaging = {
      'ema': averaging.ExponentialMovingAveragingConfig(decay=0.999),
  }

  train_loader = image_data_canaries.Cifar10CraftedLoader(
      config=image_data_canaries.Cifar10CraftedTrainValidConfig(
          **shared_data_kwargs,
          **crafting_kwargs,
          filter_include=True,
          filter_canary=False,
      ),
      augmult_config=augmult,
  )
  test_loader = image_data_canaries.Cifar10CraftedLoader(
      config=image_data_canaries.Cifar10CraftedTestConfig(
          **shared_data_kwargs,
      ),
  )
  # Additional evaluation pass over all canaries.
  canary_loader = image_data_canaries.Cifar10CraftedLoader(
      config=image_data_canaries.Cifar10CraftedTrainValidConfig(
          **shared_data_kwargs,
          **crafting_kwargs,
          filter_canary=True,
          filter_include=False,
      ),
  )
  evaluation = experiment_config.EvaluationConfig(batch_size=100)

  experiment = ExperimentConfig(
      name=name,
      optimizer=sgd,
      model=wide_resnet,
      training=training,
      averaging=weight_averaging,
      data_train=train_loader,
      data_eval=test_loader,
      data_eval_additional=canary_loader,
      evaluation=evaluation,
  )
  return config_base.build_jaxline_config(experiment_config=experiment)
