# coding=utf-8
# Copyright 2023.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flag definitions and helper utilities for CIFAR-10 and CIFAR-100 experiments."""


from absl import flags
import models


# Uncertainty Aware Information Bottleneck (UAIB) params.
flags.DEFINE_float(
    "beta", 0.001,
    "Uncertainty Aware Information Bottleneck Lagrange multiplier for "
    "regularization.")
flags.DEFINE_float(
    "uaib_tau", 1.0,
    "Temperature of the variational marginal distribution.")
flags.DEFINE_integer("uaib_dim", 8, "Bottleneck dimension.")
flags.DEFINE_integer("codebook_size", 10, "Codebook size.")


# Per-parameter-group L2 regularization coefficients (all default to None).
# Intended for use when FLAGS.l2 is None.
flags.DEFINE_float(name="bn_l2", default=None,
                   help="L2 reg. coefficient for batch-norm layers.")
flags.DEFINE_float(name="input_conv_l2", default=None,
                   help="L2 reg. coefficient for the input conv layer.")
flags.DEFINE_float(name="group_1_conv_l2", default=None,
                   help="L2 reg. coefficient for the 1st group of conv layers.")
flags.DEFINE_float(name="group_2_conv_l2", default=None,
                   help="L2 reg. coefficient for the 2nd group of conv layers.")
flags.DEFINE_float(name="group_3_conv_l2", default=None,
                   help="L2 reg. coefficient for the 3rd group of conv layers.")
flags.DEFINE_float(name="dense_kernel_l2", default=None,
                   help="L2 reg. coefficient for the kernel of the dense layer.")
flags.DEFINE_float(name="dense_bias_l2", default=None,
                   help="L2 reg. coefficient for the bias of the dense layer.")




# Evaluation-mode flags.
flags.DEFINE_bool(
    "eval_only", False,
    "Whether to run only eval and (maybe) OOD steps.")
# Qualitative evaluation flags.
flags.DEFINE_bool(
    "eval_clusters_true_label", True,
    "Whether to evaluate clusters on true labels.")
flags.DEFINE_bool(
    "eval_clusters_predicted_label", True,
    "Whether to evaluate clusters on predicted labels.")
# Calibration flags.
flags.DEFINE_bool(
    "eval_calibration", True,
    "Whether to evaluate model's calibration.")
flags.DEFINE_integer(
    "calibration_num_buckets", 20,
    "Number of quantiles for computing calibration coefficient.")

# Out-of-distribution (OOD) evaluation flags.
flags.DEFINE_bool(name="eval_on_ood", default=False,
                  help="Whether to run OOD evaluation on specified OOD datasets.")
flags.DEFINE_list(name="ood_dataset", default="cifar100,svhn_cropped",
                  help="list of OOD datasets to evaluate on.")
flags.DEFINE_string(name="saved_model_dir", default=None,
                    help="Directory containing the saved model checkpoints.")


# Train flags.
flags.DEFINE_integer('train_epochs', 200, 'Number of training epochs.')
flags.DEFINE_float('cluster_base_learning_rate', 0.1,
                   'Base learning rate for the centroids of the clusters.')
flags.DEFINE_float('base_learning_rate', 0.1,
                   'Base learning rate when total batch size is 128. It is '
                   'scaled by the ratio of the total batch size to 128.')
flags.DEFINE_integer('checkpoint_interval', 25,
                     'Number of epochs between saving checkpoints. Use -1 to '
                     'never save checkpoints.')
# Dataset flags.
# TODO(ghassen): consider adding CIFAR-100-C to TFDS.
flags.DEFINE_enum('dataset', 'cifar10',
                  enum_values=['cifar10', 'cifar100'],
                  help='Dataset.')
# Adjacent string literals are concatenated verbatim, so each fragment that
# ends a sentence needs its own trailing space.
flags.DEFINE_string('data_dir', None,
                    'data_dir to be used for tfds dataset construction. '
                    'It is required when training with cloud TPUs.')
flags.DEFINE_bool('download_data', False,
                  'Whether to download data locally when initializing a '
                  'dataset.')
flags.DEFINE_bool(
    'drop_remainder_for_eval', True,
    'Whether to drop the last batch in the case it has fewer than batch_size '
    'elements. If you use TPU and XLA, which require data to have a '
    'statically known shape, you should use drop_remainder=True.')
    
# Train hyper-parameters.
flags.DEFINE_float('l2', 2e-4, 'L2 regularization coefficient.')
flags.DEFINE_float('lr_decay_ratio', 0.2, 'Amount to decay learning rate.')
flags.DEFINE_list('lr_decay_epochs', ['60', '120', '160'],
                  'Epochs to decay learning rate by.')
flags.DEFINE_integer('lr_warmup_epochs', 1,
                     'Number of epochs for a linear warmup to the initial '
                     'learning rate. Use 0 to do no warmup.')
flags.DEFINE_integer('num_bins', 15, 'Number of bins for ECE.')
# As the name says, this stores 1 - momentum rather than the momentum itself
# (presumably consumed as momentum = 1 - value; confirm in optimizer setup).
flags.DEFINE_float('one_minus_momentum', 0.1, 'One minus optimizer momentum.')
# '<YOUR OUTPUT DIR>' is a placeholder default; override it via --output_dir.
flags.DEFINE_string('output_dir', '<YOUR OUTPUT DIR>', 'Output directory.')
# NOTE(review): this help text refers to ensemble_size, which is not defined
# among the flags in this file -- confirm it still applies.
flags.DEFINE_integer('per_core_batch_size', 64,
                     'Batch size per TPU core/GPU. The number of new '
                     'datapoints gathered per batch is this number divided by '
                     'ensemble_size (we tile the batch by that # of times).')
flags.DEFINE_integer('shuffle_buffer_size', None,
                     'Shuffle buffer size for training dataset.')
flags.DEFINE_integer('seed', 42, 'Random seed.')
flags.DEFINE_float(name='train_proportion', default=1.0,
                   help='only use a proportion of training set.')
# Accept only values in the half-open interval (0, 1].
flags.register_validator('train_proportion',
                         lambda proportion: 0.0 < proportion <= 1.0,
                         message='--train_proportion must be in (0, 1].')

# Hardware / accelerator flags.
flags.DEFINE_bool(name='use_gpu', default=False,
                  help='Whether to run on GPU or otherwise TPU.')
flags.DEFINE_integer(name='num_cores', default=8,
                     help='Number of TPU cores or number of GPUs.')
flags.DEFINE_string(name='tpu', default=None,
                    help='Name of the TPU. Only used if use_gpu is False.')
flags.DEFINE_bool(name='use_bfloat16', default=False,
                  help='Whether to use mixed precision.')

# Shorthand for absl's global FlagValues object holding the flags above.
FLAGS = flags.FLAGS

###########################             utility functions        ###########################

def _extract_hyperparameter_dictionary():
    """Collect the model's hyperparameter values from the parsed flags.

    The keys are whatever `models.get_wide_resnet_hp_keys()` returns; each
    one is looked up in `FLAGS.flag_values_dict()`.

    Returns:
      A dict mapping every hyperparameter key to its current flag value.

    Raises:
      KeyError: if a returned key does not name a defined flag.
    """
    values_by_flag_name = FLAGS.flag_values_dict()
    return {
        name: values_by_flag_name[name]
        for name in models.get_wide_resnet_hp_keys()
    }
    


