# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library of methods for various federated/datacenter splits of CelebA data."""

from typing import Dict, List, Tuple

import tensorflow as tf
import tensorflow_federated as tff


def _get_federated_datasets():
  """Loads the federated CelebA data through TFF.

  Returns:
    The result of `tff.simulation.datasets.celeba.load_data()` called with its
    default arguments; callers in this module unpack it as a
    `(train_client_data, eval_client_data)` pair.
  """
  celeba = tff.simulation.datasets.celeba
  return celeba.load_data()


def _has_facial_hair(example):
  """Returns whether a CelebA example is labeled as having facial hair.

  An example has facial hair if it is labeled with a mustache or a goatee, or
  if it is NOT labeled 'no_beard'.

  Args:
    example: A mapping with 'mustache', 'goatee' and 'no_beard' entries
      (assumed to be boolean scalars - TODO confirm against the dataset spec).

  Returns:
    A boolean tensor.
  """
  # Explicit TF logical ops instead of Python `or`/`not`: on symbolic tensors
  # the Python operators only work if AutoGraph successfully rewrites this
  # function when it is traced by `tf.data.Dataset.filter`.
  return tf.logical_or(
      tf.logical_or(example['mustache'], example['goatee']),
      tf.logical_not(example['no_beard']))


def _filter_no_facial_hair(dataset):
  """Keeps only the examples of `dataset` that have no facial hair.

  Args:
    dataset: A `tf.data.Dataset` of CelebA example mappings.

  Returns:
    The filtered dataset.
  """
  # `tf.logical_not` rather than Python `not`, which fails on a symbolic tensor
  # unless AutoGraph manages to convert this lambda.
  return dataset.filter(lambda x: tf.logical_not(_has_facial_hair(x)))


def _filter_only_facial_hair(dataset):
  """Keeps only the examples of `dataset` that have facial hair."""
  keep_example = _has_facial_hair
  return dataset.filter(keep_example)


def _is_smiling(example):
  """Returns the 'smiling' attribute of a CelebA example."""
  smiling_flag = example['smiling']
  return smiling_flag


def _filter_no_smiling(dataset):
  """Keeps only the examples of `dataset` that are not smiling.

  Args:
    dataset: A `tf.data.Dataset` of CelebA example mappings.

  Returns:
    The filtered dataset.
  """
  # `tf.logical_not` rather than Python `not`, which fails on a symbolic tensor
  # unless AutoGraph manages to convert this lambda.
  return dataset.filter(lambda x: tf.logical_not(_is_smiling(x)))


def _filter_only_smiling(dataset):
  """Keeps only the examples of `dataset` that are smiling."""
  keep_example = _is_smiling
  return dataset.filter(keep_example)


def _has_lipstick(example):
  """Returns the 'wearing_lipstick' attribute of a CelebA example."""
  lipstick_flag = example['wearing_lipstick']
  return lipstick_flag


def _filter_no_lipstick(dataset):
  """Keeps only the examples of `dataset` not wearing lipstick.

  Args:
    dataset: A `tf.data.Dataset` of CelebA example mappings.

  Returns:
    The filtered dataset.
  """
  # `tf.logical_not` rather than Python `not`, which fails on a symbolic tensor
  # unless AutoGraph manages to convert this lambda.
  return dataset.filter(lambda x: tf.logical_not(_has_lipstick(x)))


def _filter_only_lipstick(dataset):
  """Keeps only the examples of `dataset` wearing lipstick."""
  keep_example = _has_lipstick
  return dataset.filter(keep_example)


def _has_mouth_slightly_open(example):
  """Returns the 'mouth_slightly_open' attribute of a CelebA example."""
  mouth_open_flag = example['mouth_slightly_open']
  return mouth_open_flag


def _filter_no_mouth_slightly_open(dataset):
  """Keeps only the examples of `dataset` whose mouth is not slightly open.

  Args:
    dataset: A `tf.data.Dataset` of CelebA example mappings.

  Returns:
    The filtered dataset.
  """
  # `tf.logical_not` rather than Python `not`, which fails on a symbolic tensor
  # unless AutoGraph manages to convert this lambda.
  return dataset.filter(
      lambda x: tf.logical_not(_has_mouth_slightly_open(x)))


def _filter_only_mouth_slightly_open(dataset):
  """Keeps only the examples of `dataset` whose mouth is slightly open."""
  keep_example = _has_mouth_slightly_open
  return dataset.filter(keep_example)


def _blurry(example):
  """Returns the 'blurry' attribute of a CelebA example."""
  blurry_flag = example['blurry']
  return blurry_flag


def _filter_not_blurry(dataset):
  """Keeps only the examples of `dataset` that are not blurry.

  Args:
    dataset: A `tf.data.Dataset` of CelebA example mappings.

  Returns:
    The filtered dataset.
  """
  # `tf.logical_not` rather than Python `not`, which fails on a symbolic tensor
  # unless AutoGraph manages to convert this lambda.
  return dataset.filter(lambda x: tf.logical_not(_blurry(x)))


def _filter_only_blurry(dataset):
  """Keeps only the examples of `dataset` that are blurry."""
  keep_example = _blurry
  return dataset.filter(keep_example)


def _get_federated_all():
  """Returns the unfiltered federated CelebA `(train, eval)` client data."""
  train_and_eval = _get_federated_datasets()
  return train_and_eval


def _get_federated_filtered(filter_fn):
  """Applies a dataset-level filter to every client of both federated splits.

  Args:
    filter_fn: A callable taking and returning a `tf.data.Dataset`, e.g. one of
      the `_filter_*` helpers in this module.

  Returns:
    A `(train_client_data, eval_client_data)` pair in which every client's
    dataset has been passed through `filter_fn` via `preprocess`.
  """
  train_client_data, eval_client_data = _get_federated_all()
  return (train_client_data.preprocess(filter_fn),
          eval_client_data.preprocess(filter_fn))


def _get_federated_no_facial_hair():
  """Federated (train, eval) data restricted to examples without facial hair."""
  return _get_federated_filtered(_filter_no_facial_hair)


def _get_federated_only_facial_hair():
  """Federated (train, eval) data restricted to examples with facial hair."""
  return _get_federated_filtered(_filter_only_facial_hair)


def _get_federated_no_smiling():
  """Federated (train, eval) data restricted to non-smiling examples."""
  return _get_federated_filtered(_filter_no_smiling)


def _get_federated_only_smiling():
  """Federated (train, eval) data restricted to smiling examples."""
  return _get_federated_filtered(_filter_only_smiling)


def _get_federated_no_lipstick():
  """Federated (train, eval) data restricted to examples without lipstick."""
  return _get_federated_filtered(_filter_no_lipstick)


def _get_federated_only_lipstick():
  """Federated (train, eval) data restricted to examples with lipstick."""
  return _get_federated_filtered(_filter_only_lipstick)


def _get_federated_no_mouth_slightly_open():
  """Federated (train, eval) data without 'mouth_slightly_open' examples."""
  return _get_federated_filtered(_filter_no_mouth_slightly_open)


def _get_federated_only_mouth_slightly_open():
  """Federated (train, eval) data with only 'mouth_slightly_open' examples."""
  return _get_federated_filtered(_filter_only_mouth_slightly_open)


def _get_federated_not_blurry():
  """Federated (train, eval) data restricted to non-blurry examples."""
  return _get_federated_filtered(_filter_not_blurry)


def _get_federated_only_blurry():
  """Federated (train, eval) data restricted to blurry examples."""
  return _get_federated_filtered(_filter_only_blurry)


def _get_datacenter_train_all():
  """Returns every federated train example pooled into a single dataset."""
  train_client_data, _ = _get_federated_all()
  return train_client_data.create_tf_dataset_from_all_clients()


# Each variant below pools all train clients (via `_get_datacenter_train_all`)
# and then applies one of the example-level `_filter_*` helpers.


def _get_datacenter_train_no_facial_hair():
  """Pooled train examples without facial hair."""
  return _filter_no_facial_hair(_get_datacenter_train_all())


def _get_datacenter_train_only_facial_hair():
  """Pooled train examples with facial hair."""
  return _filter_only_facial_hair(_get_datacenter_train_all())


def _get_datacenter_train_no_smiling():
  """Pooled train examples that are not smiling."""
  return _filter_no_smiling(_get_datacenter_train_all())


def _get_datacenter_train_only_smiling():
  """Pooled train examples that are smiling."""
  return _filter_only_smiling(_get_datacenter_train_all())


def _get_datacenter_train_no_lipstick():
  """Pooled train examples not wearing lipstick."""
  return _filter_no_lipstick(_get_datacenter_train_all())


def _get_datacenter_train_only_lipstick():
  """Pooled train examples wearing lipstick."""
  return _filter_only_lipstick(_get_datacenter_train_all())


def _get_datacenter_train_no_mouth_slightly_open():
  """Pooled train examples whose mouth is not slightly open."""
  return _filter_no_mouth_slightly_open(_get_datacenter_train_all())


def _get_datacenter_train_only_mouth_slightly_open():
  """Pooled train examples whose mouth is slightly open."""
  return _filter_only_mouth_slightly_open(_get_datacenter_train_all())


def _get_datacenter_train_not_blurry():
  """Pooled train examples that are not blurry."""
  return _filter_not_blurry(_get_datacenter_train_all())


def _get_datacenter_train_only_blurry():
  """Pooled train examples that are blurry."""
  return _filter_only_blurry(_get_datacenter_train_all())


def _get_datacenter_eval_all():
  """Returns every federated eval example pooled into a single dataset."""
  _, eval_client_data = _get_federated_all()
  return eval_client_data.create_tf_dataset_from_all_clients()


# Each variant below pools all eval clients (via `_get_datacenter_eval_all`)
# and then applies one of the example-level `_filter_*` helpers.


def _get_datacenter_eval_no_facial_hair():
  """Pooled eval examples without facial hair."""
  return _filter_no_facial_hair(_get_datacenter_eval_all())


def _get_datacenter_eval_only_facial_hair():
  """Pooled eval examples with facial hair."""
  return _filter_only_facial_hair(_get_datacenter_eval_all())


def _get_datacenter_eval_no_smiling():
  """Pooled eval examples that are not smiling."""
  return _filter_no_smiling(_get_datacenter_eval_all())


def _get_datacenter_eval_only_smiling():
  """Pooled eval examples that are smiling."""
  return _filter_only_smiling(_get_datacenter_eval_all())


def _get_datacenter_eval_no_lipstick():
  """Pooled eval examples not wearing lipstick."""
  return _filter_no_lipstick(_get_datacenter_eval_all())


def _get_datacenter_eval_only_lipstick():
  """Pooled eval examples wearing lipstick."""
  return _filter_only_lipstick(_get_datacenter_eval_all())


def _get_datacenter_eval_no_mouth_slightly_open():
  """Pooled eval examples whose mouth is not slightly open."""
  return _filter_no_mouth_slightly_open(_get_datacenter_eval_all())


def _get_datacenter_eval_only_mouth_slightly_open():
  """Pooled eval examples whose mouth is slightly open."""
  return _filter_only_mouth_slightly_open(_get_datacenter_eval_all())


def _get_datacenter_eval_not_blurry():
  """Pooled eval examples that are not blurry."""
  return _filter_not_blurry(_get_datacenter_eval_all())


def _get_datacenter_eval_only_blurry():
  """Pooled eval examples that are blurry."""
  return _filter_only_blurry(_get_datacenter_eval_all())


# Names of the supported splits; each is a valid `split` argument for
# `get_federated`, `get_datacenter_train` and `get_datacenter_eval`.
DATASET_SPLITS = [
    'all', 'no_facial_hair', 'only_facial_hair', 'no_smiling', 'only_smiling',
    'no_lipstick', 'only_lipstick', 'no_mouth_slightly_open',
    'only_mouth_slightly_open', 'not_blurry', 'only_blurry'
]
# These lists will each be zipped with the above list of strings, so the lists
# should all be kept of equal length with entries in 1:1 correspondence.
FEDERATED_DATA_GETTER_FNS = [
    _get_federated_all, _get_federated_no_facial_hair,
    _get_federated_only_facial_hair, _get_federated_no_smiling,
    _get_federated_only_smiling, _get_federated_no_lipstick,
    _get_federated_only_lipstick, _get_federated_no_mouth_slightly_open,
    _get_federated_only_mouth_slightly_open, _get_federated_not_blurry,
    _get_federated_only_blurry
]
DATACENTER_TRAIN_DATA_GETTER_FNS = [
    _get_datacenter_train_all, _get_datacenter_train_no_facial_hair,
    _get_datacenter_train_only_facial_hair, _get_datacenter_train_no_smiling,
    _get_datacenter_train_only_smiling, _get_datacenter_train_no_lipstick,
    _get_datacenter_train_only_lipstick,
    _get_datacenter_train_no_mouth_slightly_open,
    _get_datacenter_train_only_mouth_slightly_open,
    _get_datacenter_train_not_blurry, _get_datacenter_train_only_blurry
]
DATACENTER_EVAL_DATA_GETTER_FNS = [
    _get_datacenter_eval_all, _get_datacenter_eval_no_facial_hair,
    _get_datacenter_eval_only_facial_hair, _get_datacenter_eval_no_smiling,
    _get_datacenter_eval_only_smiling, _get_datacenter_eval_no_lipstick,
    _get_datacenter_eval_only_lipstick,
    _get_datacenter_eval_no_mouth_slightly_open,
    _get_datacenter_eval_only_mouth_slightly_open,
    _get_datacenter_eval_not_blurry, _get_datacenter_eval_only_blurry
]

# `zip` silently truncates to its shortest input and `dict` silently keeps only
# the last of any duplicate keys. Enforce the 1:1 correspondence at import time
# rather than letting a split name map to the wrong getter (or to none).
if len(set(DATASET_SPLITS)) != len(DATASET_SPLITS):
  raise ValueError('DATASET_SPLITS contains duplicate split names.')
if not (len(DATASET_SPLITS) == len(FEDERATED_DATA_GETTER_FNS) ==
        len(DATACENTER_TRAIN_DATA_GETTER_FNS) ==
        len(DATACENTER_EVAL_DATA_GETTER_FNS)):
  raise ValueError(
      'DATASET_SPLITS and the *_DATA_GETTER_FNS lists must all have the same '
      'length.')


def get_possible_dataset_splits():
  """The list of valid arguments to the other public methods in this library.

  Returns:
    A new list of split names. A copy is returned so that a caller mutating the
    result cannot corrupt the module-level `DATASET_SPLITS`, which is used for
    argument validation and getter lookup.
  """
  return list(DATASET_SPLITS)


def _validate_split_arg(split):
  """Raises ValueError unless `split` is one of `DATASET_SPLITS`.

  Args:
    split: The candidate split name.

  Raises:
    ValueError: If `split` is not in `DATASET_SPLITS`.
  """
  if split in DATASET_SPLITS:
    return
  raise ValueError('Specified split was %s but must be one of %s.' %
                   (split, DATASET_SPLITS))


def _merge_three_clients_into_one(
    unmerged_data
):
  """Convert to federated dataset with fewer, larger clients.

  This will consider the clients in a federated dataset in sequences of 3,
  create a new client id, and create a dataset for the new client id out of the
  combined (concatenated and shuffled) datasets of the 3 clients. This method
  will drop the last 0, 1, or 2 clients, depending on if the total number of
  original clients is or is not divisible by 3.

  A merged client id is the three original ids joined with '_'. It is split on
  '_' again to recover them, so the original ids being merged must not contain
  '_' themselves.

  Args:
    unmerged_data: The ClientData object with clients to be merged.

  Returns:
    A ClientData object with merged clients.

  Raises:
    ValueError: If fewer than 3 clients are available, or if any client id to
      be merged contains the '_' separator.
  """
  # `zip` over three references to one iterator yields consecutive,
  # non-overlapping triples and drops any trailing remainder.
  merged_client_ids = [
      '%s_%s_%s' % (a, b, c)
      for a, b, c in zip(*[iter(unmerged_data.client_ids)] * 3)
  ]

  if not merged_client_ids:
    raise ValueError(
        'Attempting to merge clients by groups of 3, but no merged clients '
        'were created; does the unmerged ClientData object contain <3 clients?')

  # Every merged id must split back into exactly its three source ids.
  # Otherwise `tf.unstack(..., num=3)` below fails at trace time with an
  # unhelpful error, so reject such ids up front.
  malformed_ids = [
      merged_id for merged_id in merged_client_ids
      if len(merged_id.split('_')) != 3
  ]
  if malformed_ids:
    raise ValueError(
        'Cannot merge clients whose ids contain "_" (the separator used to '
        'build merged ids); offending merged ids include: %s' %
        malformed_ids[:5])

  @tf.function
  def _serializable_dataset_fn(merged_client_id):
    client_ids = tf.unstack(tf.strings.split(merged_client_id, sep='_'), num=3)
    dataset0 = unmerged_data.serializable_dataset_fn(client_ids[0])
    dataset1 = unmerged_data.serializable_dataset_fn(client_ids[1])
    dataset2 = unmerged_data.serializable_dataset_fn(client_ids[2])
    # Note there is likely additional shuffling that will go on in data
    # processing; this shuffle is to ensure that the contents of the 3 original
    # clients' datasets are mixed together.
    return dataset0.concatenate(dataset1).concatenate(dataset2).shuffle(100)

  return tff.simulation.datasets.ClientData.from_clients_and_tf_fn(
      client_ids=merged_client_ids,
      serializable_dataset_fn=_serializable_dataset_fn)


def get_federated(split):
  """Get the federated train and eval data for the split specified.

  Args:
    split: One of the names in `DATASET_SPLITS`.

  Returns:
    A `(train, eval)` pair of ClientData objects whose clients have been merged
    in groups of three by `_merge_three_clients_into_one`.

  Raises:
    ValueError: If `split` is unknown, or if merging fails.
  """
  _validate_split_arg(split)
  getter_by_split = dict(zip(DATASET_SPLITS, FEDERATED_DATA_GETTER_FNS))
  unmerged_train, unmerged_eval = getter_by_split[split]()
  # Fewer clients with more examples each allows more training steps on unique
  # examples per round.
  train_result = _merge_three_clients_into_one(unmerged_train)
  eval_result = _merge_three_clients_into_one(unmerged_eval)
  return train_result, eval_result


def get_datacenter_train(split):
  """Get the datacenter train data for the split specified.

  Args:
    split: One of the names in `DATASET_SPLITS`.

  Returns:
    All train examples for `split`, pooled across clients into one dataset.

  Raises:
    ValueError: If `split` is unknown.
  """
  _validate_split_arg(split)
  getter_by_split = dict(
      zip(DATASET_SPLITS, DATACENTER_TRAIN_DATA_GETTER_FNS))
  getter = getter_by_split[split]
  return getter()


def get_datacenter_eval(split):
  """Get the datacenter eval data for the split specified.

  Args:
    split: One of the names in `DATASET_SPLITS`.

  Returns:
    All eval examples for `split`, pooled across clients into one dataset.

  Raises:
    ValueError: If `split` is unknown.
  """
  _validate_split_arg(split)
  getter_by_split = dict(
      zip(DATASET_SPLITS, DATACENTER_EVAL_DATA_GETTER_FNS))
  getter = getter_by_split[split]
  return getter()
