# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library providing untrained models for EMNIST classification."""

import tensorflow as tf


def _get_emnist_classifier_model_v1(num_classes=10):
  """Model architecture for classifying EMNIST into `num_classes` classes.

  Args:
    num_classes: The number of classes to classify into. E.g., this would be set
      to 10 if only classifying digits, to 36 if classifying digits and letters
      (case agnostic), or to 62 if classifying digits, uppercase, and lowercase
      letters.

  Returns:
    A Keras model that takes 28x28x1 images and returns logit vectors.

  This model is based on the one in the TF 2.0 Alpha page (Beginner Tutorials >
  Images > Convolutional NNs).
  https://www.tensorflow.org/alpha/tutorials/images/intro_to_cnn
  """
  inputs = tf.keras.Input(shape=(28, 28, 1))  # Returns a placeholder tensor

  x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu')(inputs)
  x = tf.keras.layers.MaxPooling2D((2, 2))(x)
  x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu')(x)
  x = tf.keras.layers.MaxPooling2D((2, 2))(x)
  x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu')(x)
  x = tf.keras.layers.Flatten()(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)

  logits = tf.keras.layers.Dense(num_classes, activation='linear')(x)

  return tf.keras.Model(inputs=inputs, outputs=logits)


def get_emnist_classifier_model(num_classes=10):
  """Returns an untrained model for EMNIST classification.

  Args:
    num_classes: Number of output classes. It is forwarded unchanged to the
      underlying versioned architecture builder.

  Returns:
    The Keras model built by `_get_emnist_classifier_model_v1`.
  """
  # This public entry point pins which architecture version is exposed.
  # Switching it to a different architecture means a new checkpoint must be
  # checked in.
  model = _get_emnist_classifier_model_v1(num_classes=num_classes)
  return model
