# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for template class defining pieces of a federated model and task."""

import abc
import collections
from typing import List

import tensorflow as tf
import tensorflow_federated as tff

from mixed_fl.experiments.celeba import classifier_model as celeba_classifier_model
from mixed_fl.experiments.celeba import data_utils as celeba_data_utils
from mixed_fl.experiments.emnist import classifier_model as emnist_classifier_model
from mixed_fl.experiments.emnist import data_utils as emnist_data_utils
from mixed_fl.experiments.next_char_prediction import data_utils as next_char_prediction_data_utils
from mixed_fl.experiments.next_char_prediction import prediction_model as next_char_prediction_model


class _NumExamplesCounter(tf.keras.metrics.Sum):
  """Keras `Sum` metric that accumulates the batch size of every update.

  Each call to `update_state` adds the size of the leading dimension of the
  predictions to the running sum.
  """

  def __init__(self, name='num_examples', dtype=tf.int64):
    super().__init__(name=name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Adds the leading-dimension size of `y_pred` to the running sum.

    Args:
      y_true: Unused; the count is derived from the predictions only.
      y_pred: The predictions, or a list of predictions, in which case only
        the first entry of the list is inspected.
      sample_weight: Forwarded unchanged to the parent `update_state`.

    Returns:
      Whatever the parent `update_state` returns.
    """
    del y_true
    # For a list of outputs, the first output determines the batch size.
    predictions = y_pred[0] if isinstance(y_pred, list) else y_pred
    batch_size = tf.shape(predictions)[0]
    return super().update_state(batch_size, sample_weight)


class _FederatedTaskDefinition(metaclass=abc.ABCMeta):
  """Abstract bundle of the pieces needed for a federated model and task.

  Subclasses supply the Keras model, input spec, loss, metrics and the
  client-side dataset processing. This base class assembles them into a TFF
  model and into a client dataset processing function.

  Attributes:
    client_batch_size: Batch size passed to the client dataset processing.
  """

  def __init__(self, client_batch_size):
    self.client_batch_size = client_batch_size

  @abc.abstractmethod
  def get_keras_model(self):
    """Returns the Keras model to wrap in the TFF learning model."""
    raise NotImplementedError

  @abc.abstractmethod
  def get_input_spec(self):
    """Returns the input spec to use for the TFF learning model."""
    raise NotImplementedError

  @abc.abstractmethod
  def get_loss(self, reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE):
    """Returns the Keras loss to use for the TFF learning model.

    Args:
      reduction: The `tf.keras.losses.Reduction` the loss should be built with.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def get_metrics(self):
    """Returns the Keras metrics to use for the TFF learning model."""
    raise NotImplementedError

  def get_tff_model(self):
    """Creates a `tff.learning.Model` from the pieces of this definition.

    The loss is requested with its default reduction.
    """
    keras_model = self.get_keras_model()
    input_spec = self.get_input_spec()
    loss = self.get_loss()
    metrics = self.get_metrics()
    return tff.learning.from_keras_model(
        keras_model=keras_model,
        input_spec=input_spec,
        loss=loss,
        metrics=metrics)

  def get_client_dataset_processing_fn(self, limit_num_batches=-1,
                                       num_epochs=1):
    """Returns a function that processes (batches, etc.) a client dataset.

    Args:
      limit_num_batches: If positive, the processed dataset is truncated with
        `take(limit_num_batches)`. Zero or a negative value applies no limit.
      num_epochs: Forwarded to `_client_dataset_processing`.

    Returns:
      A callable that takes an unprocessed dataset and returns the processed
      (and possibly truncated) dataset.
    """

    def client_dataset_processing_fn(unprocessed_dataset):
      batches = self._client_dataset_processing(
          unprocessed_dataset, self.client_batch_size, num_epochs)
      # Only a strictly positive limit truncates the dataset.
      return (batches.take(limit_num_batches)
              if limit_num_batches > 0 else batches)

    return client_dataset_processing_fn

  @abc.abstractmethod
  def _client_dataset_processing(self, unprocessed_dataset, batch_size,
                                 num_epochs=1):
    """Processes (batches, etc.) a dataset the way it is done at the client.

    Args:
      unprocessed_dataset: The dataset to process.
      batch_size: The batch size to use.
      num_epochs: The number of epochs to use.
    """
    raise NotImplementedError


class MixingTaskDefinition(_FederatedTaskDefinition, metaclass=abc.ABCMeta):
  """Abstract task definition that also covers datacenter data processing.

  On top of the federated pieces of the base class, subclasses supply
  `_datacenter_dataset_processing`.
  """

  def get_datacenter_dataset_processing_fn(self, datacenter_batch_size,
                                           limit_num_batches=-1):
    """Returns a function that processes (batches, etc.) a datacenter dataset.

    Args:
      datacenter_batch_size: Batch size passed to
        `_datacenter_dataset_processing`.
      limit_num_batches: If positive, the processed dataset is truncated with
        `take(limit_num_batches)`. Zero or a negative value applies no limit.

    Returns:
      A callable that takes an unprocessed dataset and returns the processed
      (and possibly truncated) dataset.
    """

    def datacenter_dataset_processing_fn(unprocessed_dataset):
      batches = self._datacenter_dataset_processing(unprocessed_dataset,
                                                    datacenter_batch_size)
      # Only a strictly positive limit truncates the dataset.
      return (batches.take(limit_num_batches)
              if limit_num_batches > 0 else batches)

    return datacenter_dataset_processing_fn

  @abc.abstractmethod
  def _datacenter_dataset_processing(self, unprocessed_dataset, batch_size):
    """Processes (batches, etc.) a dataset the way it is done at datacenter.

    Args:
      unprocessed_dataset: The dataset to process.
      batch_size: The batch size to use.
    """
    raise NotImplementedError


class EmnistMixingTaskDefinition(MixingTaskDefinition):
  """Task definition for EMNIST categorical classification."""

  def get_keras_model(self):
    """Returns the EMNIST classifier model, built for 36 classes."""
    build_model = emnist_classifier_model.get_emnist_classifier_model
    return build_model(num_classes=36)

  def get_input_spec(self):
    """Returns the spec of one batch: float32 images `x`, int32 labels `y`."""
    images = tf.TensorSpec(shape=[None, 28, 28, 1], dtype=tf.float32)
    labels = tf.TensorSpec(shape=[None], dtype=tf.int32)
    return collections.OrderedDict([('x', images), ('y', labels)])

  def get_loss(self, reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE):
    """Returns sparse categorical cross-entropy computed from logits.

    Args:
      reduction: The `tf.keras.losses.Reduction` to build the loss with.
    """
    loss = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=reduction)
    return loss

  def get_metrics(self):
    """Returns a sparse categorical accuracy metric and an example counter."""
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    num_examples = _NumExamplesCounter()
    return [accuracy, num_examples]

  def _client_dataset_processing(self, unprocessed_dataset, batch_size,
                                 num_epochs=1):
    """Preprocesses a client dataset with shuffling enabled.

    Args:
      unprocessed_dataset: The dataset to preprocess.
      batch_size: Forwarded to the EMNIST preprocessing.
      num_epochs: Forwarded to the EMNIST preprocessing.

    Returns:
      The result of `emnist_data_utils.preprocess_img_dataset`.
    """
    preprocess = emnist_data_utils.preprocess_img_dataset
    return preprocess(
        unprocessed_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_epochs=num_epochs)

  def _datacenter_dataset_processing(self, unprocessed_dataset, batch_size):
    """Preprocesses a datacenter dataset with shuffling enabled.

    Args:
      unprocessed_dataset: The dataset to preprocess.
      batch_size: Forwarded to the EMNIST preprocessing.

    Returns:
      The preprocessed dataset, passed through `tf.data.Dataset.zip`.
    """
    preprocessed = emnist_data_utils.preprocess_img_dataset(
        unprocessed_dataset, batch_size=batch_size, shuffle=True)
    return tf.data.Dataset.zip(preprocessed)


# Which attribute in the CelebA dataset to use as the label. It is passed as
# `label_attribute` to the CelebA preprocessing of both the client and the
# datacenter datasets.
CELEBA_LABEL_ATTRIBUTE = 'smiling'


class CelebAMixingTaskDefinition(MixingTaskDefinition):
  """Task definition for CelebA binary classification."""

  def get_keras_model(self):
    """Returns the CelebA classifier with a sigmoid applied to its output.

    The underlying classifier is wrapped in a new Keras model that takes
    84x84x3 inputs and outputs the sigmoid of the classifier's output.
    """
    classifier = (
        celeba_classifier_model.get_celeba_attribute_binary_classifier_model())
    image_input = tf.keras.Input(shape=(84, 84, 3))  # Placeholder tensor.
    probabilities = tf.math.sigmoid(classifier(image_input))
    return tf.keras.Model(inputs=image_input, outputs=probabilities)

  def get_input_spec(self):
    """Returns the spec of one batch: float32 images `x`, float32 labels `y`."""
    images = tf.TensorSpec(shape=[None, 84, 84, 3], dtype=tf.float32)
    labels = tf.TensorSpec(shape=[None, 1], dtype=tf.float32)
    return collections.OrderedDict([('x', images), ('y', labels)])

  def get_loss(self, reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE):
    """Returns binary cross-entropy.

    Args:
      reduction: The `tf.keras.losses.Reduction` to build the loss with.
    """
    loss = tf.keras.losses.BinaryCrossentropy(reduction=reduction)
    return loss

  def get_metrics(self):
    """Returns accuracy, AUC, FP and FN metrics plus an example counter."""
    accuracy = tf.keras.metrics.BinaryAccuracy(name='accuracy')
    auc = tf.keras.metrics.AUC(name='auc')
    false_positives = tf.keras.metrics.FalsePositives(name='fp')
    false_negatives = tf.keras.metrics.FalseNegatives(name='fn')
    num_examples = _NumExamplesCounter()
    return [accuracy, auc, false_positives, false_negatives, num_examples]

  def _client_dataset_processing(self, unprocessed_dataset, batch_size,
                                 num_epochs=1):
    """Preprocesses a client dataset with shuffling enabled.

    Args:
      unprocessed_dataset: The dataset to preprocess.
      batch_size: Forwarded to the CelebA preprocessing.
      num_epochs: Forwarded to the CelebA preprocessing.

    Returns:
      The result of `celeba_data_utils.preprocess_img_dataset`, using
      `CELEBA_LABEL_ATTRIBUTE` as the label attribute.
    """
    preprocess = celeba_data_utils.preprocess_img_dataset
    return preprocess(
        unprocessed_dataset,
        label_attribute=CELEBA_LABEL_ATTRIBUTE,
        batch_size=batch_size,
        shuffle=True,
        num_epochs=num_epochs)

  def _datacenter_dataset_processing(self, unprocessed_dataset, batch_size):
    """Preprocesses a datacenter dataset with shuffling enabled.

    Args:
      unprocessed_dataset: The dataset to preprocess.
      batch_size: Forwarded to the CelebA preprocessing.

    Returns:
      The preprocessed dataset, passed through `tf.data.Dataset.zip`.
    """
    preprocessed = celeba_data_utils.preprocess_img_dataset(
        unprocessed_dataset,
        label_attribute=CELEBA_LABEL_ATTRIBUTE,
        batch_size=batch_size,
        shuffle=True)
    return tf.data.Dataset.zip(preprocessed)


class NcpMixingTaskDefinition(MixingTaskDefinition):
  """Task definition for next character prediction."""

  def get_keras_model(self):
    """Returns the next character prediction model."""
    build_model = next_char_prediction_model.get_next_char_prediction_model
    return build_model()

  def get_input_spec(self):
    """Returns the spec of one batch: int64 `x` and `y`, 100 wide each."""
    inputs = tf.TensorSpec(shape=[None, 100], dtype=tf.int64)
    targets = tf.TensorSpec(shape=[None, 100], dtype=tf.int64)
    return collections.OrderedDict([('x', inputs), ('y', targets)])

  def get_loss(self, reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE):
    """Returns sparse categorical cross-entropy computed from logits.

    Args:
      reduction: The `tf.keras.losses.Reduction` to build the loss with.
    """
    loss = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=reduction)
    return loss

  def get_metrics(self):
    """Returns a flattened categorical accuracy and an example counter."""
    accuracy = next_char_prediction_model.FlattenedCategoricalAccuracy()
    num_examples = _NumExamplesCounter()
    return [accuracy, num_examples]

  def _client_dataset_processing(self, unprocessed_dataset, batch_size,
                                 num_epochs=1):
    """Preprocesses a client dataset with shuffling enabled.

    Args:
      unprocessed_dataset: The dataset to preprocess.
      batch_size: Forwarded to the text preprocessing.
      num_epochs: Forwarded to the text preprocessing.

    Returns:
      The result of `next_char_prediction_data_utils.preprocess_text_dataset`.
    """
    preprocess = next_char_prediction_data_utils.preprocess_text_dataset
    return preprocess(
        unprocessed_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_epochs=num_epochs)

  def _datacenter_dataset_processing(self, unprocessed_dataset, batch_size):
    """Preprocesses a datacenter dataset with shuffling enabled.

    Args:
      unprocessed_dataset: The dataset to preprocess.
      batch_size: Forwarded to the text preprocessing.

    Returns:
      The preprocessed dataset, passed through `tf.data.Dataset.zip`.
    """
    preprocessed = next_char_prediction_data_utils.preprocess_text_dataset(
        unprocessed_dataset, batch_size=batch_size, shuffle=True)
    return tf.data.Dataset.zip(preprocessed)


# Maps each supported task name to the class that defines the task.
TASK_TO_DEFINITION_MAP = {
    'emnist': EmnistMixingTaskDefinition,
    'celeba': CelebAMixingTaskDefinition,
    'ncp': NcpMixingTaskDefinition,
}
# The supported task names, in the insertion order of the map above.
TASKS = list(TASK_TO_DEFINITION_MAP)


def get_possible_tasks():
  """Returns the names of the tasks that have definitions in this library.

  Returns:
    A new list of task name strings on every call. A copy is returned, rather
    than the module-level `TASKS` list itself, so that a caller mutating the
    result cannot change which task names this module treats as valid.
  """
  return list(TASKS)


def _validate_task_arg(task):
  """Raises a `ValueError` unless `task` is one of the names in `TASKS`."""
  if task in TASKS:
    return
  raise ValueError('Specified task was %s but must be one of %s.' %
                   (task, TASKS))


def get_mixing_task_definition(task, client_batch_size):
  """Builds the `MixingTaskDefinition` registered under the given task name.

  Args:
    task: Name of the task; must be one of `TASKS`.
    client_batch_size: Batch size handed to the task definition's constructor.

  Returns:
    A new instance of the class that `TASK_TO_DEFINITION_MAP` maps `task` to.

  Raises:
    ValueError: If `task` is not one of the known task names.
  """
  _validate_task_arg(task)
  definition_cls = TASK_TO_DEFINITION_MAP[task]
  return definition_cls(client_batch_size=client_batch_size)
