# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training DDPM++ on CIFAR-10 with PFGM."""

from configs.default_svhn_configs import get_default_configs
import math

def get_config():
  """Build this experiment's config on top of the SVHN defaults.

  Starts from ``get_default_configs()`` and overwrites selected fields in its
  ``training``, ``model``, ``data`` and ``sampling`` sections, in that order.

  Returns:
    The object returned by ``get_default_configs()``, with the overrides
    below applied to it in place.
  """
  config = get_default_configs()

  # (section name, {field: value}) pairs, applied in order.
  section_overrides = (
      ('training', dict(
          sde='homotopy',
          continuous=True,
          batch_size=200,
          small_batch_size=200,
          num_particles=1,
          sample_size=128,
          sample_freq=0,
          mean_prior=0,
          sigma_prior=1,
          divisor=1,
          sigma_min=0.01,
          sigma_max=0.01,
          eps_sigma=1,
          invert_sigma=False,
          sigma_anneal=0,
          sigma_clip=0.01,
          t_end=1,
          eps_max=1e-4,
          eps_min=1e-4,
          invert_t=False,
          augment_t=True,
          z_max=1,
          eps_z=1,
          augment_z=False,
          snapshot_freq=5000,
          snapshot_freq_for_preemption=1000,
          model='ddpmpp',
      )),
      ('model', dict(
          name='resmlpt',
          act='gelu',
          norm='affine',
          bias=False,
          patch_size=4,
          dim=768,
          temb_dim=0,
          depth=12,
          sqr_freq=12,
          expansion_factor=(0, 4),
          sqr_scale=1,
          skipmul_power=1,
          wrn_depth=28,
          wrn_width=10,
          scale_by_sigma=False,
          ema_rate=0.9999,
          nf=128,
          ch_mult=(1, 2, 2, 2),
          num_res_blocks=4,
          attn_resolutions=(16,),
          resamp_with_conv=True,
          conditional=True,
          fir=False,
          fir_kernel=[1, 3, 3, 1],
          skip_rescale=True,
          resblock_type='biggan',
          progressive='none',
          progressive_input='none',
          progressive_combine='sum',
          attention_type='ddpm',
          init_scale=0.,
          fourier_scale=16,
          embedding_type='positional',
          conv_size=3,
          sigma_end=0.01,
      )),
      ('data', dict(
          channels=3,
          centered=True,
          scale=1,
          size=60000,
      )),
      ('sampling', dict(
          method='ode',
          # Other solver names previously tried here: 'forward_euler',
          # 'improved_euler'.
          ode_solver='rk45',
          eps_z=1e-4,
          N=100,
          z_max=40,
          z_min=1e-3,
          upper_norm=3000,
          vs=False,  # verbose
      )),
  )

  for section_name, overrides in section_overrides:
    section = getattr(config, section_name)
    for field_name, value in overrides.items():
      setattr(section, field_name, value)

  return config