# (channels, resolution) for each supported dataset name.
_DATASET_PARAMS = {
  'imagenet': (3, 256),
  'cifar10': (3, 32),
  'mnist': (1, 28),
  'celeba': (3, 64),
}


def get_model_params(dataset: str):
  """Return the ``(channels, resolution)`` pair for a named dataset.

  Args:
    dataset: Dataset name; one of ``'imagenet'``, ``'cifar10'``,
      ``'mnist'`` or ``'celeba'``.

  Returns:
    A ``(channels, resolution)`` tuple of ints, e.g. ``(3, 32)`` for
    ``'cifar10'``.

  Raises:
    ValueError: If ``dataset`` is not a supported name. ValueError is a
      subclass of Exception, so existing ``except Exception`` handlers
      still catch it.
  """
  try:
    return _DATASET_PARAMS[dataset]
  except (KeyError, TypeError):
    # TypeError covers unhashable inputs (e.g. a list), which the old
    # if/elif chain also rejected as "not supported" instead of crashing.
    raise ValueError(
        f'{dataset} setting for args.dataset is not supported.') from None
