# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library providing untrained models for CelebA attribute binary classification."""

import tensorflow as tf


def _get_celeba_attribute_binary_classifier_model_v1():
  """Builds the v1 model for classifying CelebA images by one attribute.

  The architecture follows the model used in the 'TF Constrained Optimization
  Example Using CelebA Dataset' page (TensorFlow > Resources > Responsible AI >
  Tutorials):
  https://www.tensorflow.org/responsible_ai/fairness_indicators/tutorials/Fairness_Indicators_TFCO_CelebA_Case_Study

  Returns:
    An untrained `tf.keras.Sequential` model that takes 84x84x3 RGB images,
    passes them through a 64-unit ReLU hidden layer, and returns a single
    logit per example.
  """
  layer_stack = [
      # Flattens each 84x84x3 input image into one feature vector.
      tf.keras.layers.Flatten(input_shape=(84, 84, 3), name='image'),
      # The single hidden layer.
      tf.keras.layers.Dense(64, activation='relu'),
      # No activation on the output: it is a raw logit, not a probability.
      tf.keras.layers.Dense(1, activation=None),
  ]
  return tf.keras.Sequential(layer_stack)


def get_celeba_attribute_binary_classifier_model():
  """Returns an untrained model for CelebA attribute binary classification.

  This is the public entry point. It currently delegates to the v1
  architecture builder, so each call constructs a fresh, untrained model.
  """
  # Changing which architecture is exposed here requires checking in a new
  # checkpoint.
  build_model = _get_celeba_attribute_binary_classifier_model_v1
  return build_model()
