# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library providing untrained models for next character prediction."""

import tensorflow as tf


# Character vocabulary: 86 distinct ASCII characters (letters, digits,
# punctuation, space, '\n' and '\r'), all of which happen to occur in
# Shakespeare's works. NOTE(review): the list order presumably defines the
# char <-> id mapping used with the model; do not reorder without confirming.
VOCAB = list(
    'dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#\'/37;?bfjnrvzBFJNRVZ"&*.26:\naeimquyAE'
    'IMQUY]!%)-159\r'
)

# Number of characters in the vocabulary; also the model's output width.
VOCAB_SIZE = len(VOCAB)

# Width of each character embedding vector.
EMBEDDING_DIM = 256

# Hidden-state size of the recurrent layer.
RNN_UNITS = 1024


def _get_next_char_prediction_model_v1():
  """Builds the v1 next-character prediction architecture.

  Character ids are embedded, passed through a single GRU that emits an output
  at every time step, and projected back onto the vocabulary. The model
  therefore scores the next character at each position of the input sequence.

  The design follows the model from the 'Text generation with an RNN'
  tutorial (TensorFlow > Resources > Text > Tutorials):
  https://www.tensorflow.org/text/tutorials/text_generation#build_the_model

  Returns:
    An untrained `tf.keras.Model` that maps a `[batch, time]` tensor of
    character ids (valid range `[0, VOCAB_SIZE)`, per the embedding's input
    dimension) to unnormalized logits of shape `[batch, time, VOCAB_SIZE]`.
  """
  # Symbolic placeholder; the time dimension is left variable.
  char_ids = tf.keras.Input(shape=(None,))
  # Layers are constructed and then applied in the same order as the original
  # inline calls, so auto-generated layer/variable names are unchanged.
  stack = (
      tf.keras.layers.Embedding(VOCAB_SIZE, EMBEDDING_DIM),
      # return_sequences=True keeps one output per time step, not just the last.
      tf.keras.layers.GRU(RNN_UNITS, return_sequences=True),
      # No activation is given, so the outputs are raw logits.
      tf.keras.layers.Dense(VOCAB_SIZE),
  )
  activations = char_ids
  for layer in stack:
    activations = layer(activations)
  return tf.keras.Model(inputs=char_ids, outputs=activations)


def get_next_char_prediction_model():
  """Returns a freshly built, untrained next-character prediction model.

  This is the public entry point; it delegates to the currently pinned
  versioned builder (`_get_next_char_prediction_model_v1`).
  """
  # Changing which architecture is exposed here means a new checkpoint must be
  # checked in as well.
  build_model = _get_next_char_prediction_model_v1
  return build_model()


# Borrowed from the 'Federated Learning for Text Generation' page (TensorFlow >
# Resources > Federated > Tutorials).
# https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation
class FlattenedCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):
  """Sparse categorical accuracy over every position of a sequence batch.

  `SparseCategoricalAccuracy` compares one integer label against one row of
  class scores. This subclass flattens all leading axes, so each character
  position is scored as an independent example:
  - labels are reshaped to `[N, 1]`;
  - logits are reshaped to `[N, VOCAB_SIZE]`;
  - `N` is the total number of positions.
  """

  def __init__(self, name='accuracy', dtype=tf.float32):
    super().__init__(name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Flattens labels and logits, then accumulates per-position matches.

    Args:
      y_true: Integer character-id labels. Presumably `[batch, time]`; any
        shape works as long as its element count equals the number of
        prediction positions.
      y_pred: Scores whose last dimension is `VOCAB_SIZE`, e.g. model logits
        of shape `[batch, time, VOCAB_SIZE]`.
      sample_weight: Optional weights, forwarded unchanged to the base class.

    Returns:
      Whatever the base class `update_state` returns.
    """
    y_true = tf.reshape(y_true, [-1, 1])
    # Keep y_pred at rank 2 with the class axis last. Do NOT add a trailing
    # singleton axis (e.g. `[-1, VOCAB_SIZE, 1]`). That only works when the
    # metric wrapper squeezes it away before argmax, which tf.keras 2.x does
    # but Keras 3 does not. Without the squeeze, argmax over the size-1 axis
    # is always 0 and the accuracy is wrong.
    y_pred = tf.reshape(y_pred, [-1, VOCAB_SIZE])
    return super().update_state(y_true, y_pred, sample_weight)
