# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library of methods for various federated/datacenter data splits."""

from typing import Callable, List, Tuple

import tensorflow as tf
import tensorflow_federated as tff


def _get_federated_datasets(only_digits: bool):
  """Loads the federated EMNIST (train, eval) pair via TFF.

  Args:
    only_digits: Forwarded unchanged to
      `tff.simulation.datasets.emnist.load_data`.

  Returns:
    Whatever `load_data` returns; callers in this module unpack it as a
    `(train_client_data, eval_client_data)` pair.
  """
  emnist_module = tff.simulation.datasets.emnist
  return emnist_module.load_data(only_digits=only_digits)


def _filter_out_digits(dataset):
  return dataset.filter(lambda x: x['label'] > 9)


def _filter_out_lowercase(dataset):
  return dataset.filter(lambda x: x['label'] < 36)


def _filter_out_uppercase(dataset):
  """Drops uppercase-letter examples (labels 10-35).

  Args:
    dataset: A dataset of dict examples with a `'label'` entry.

  Returns:
    `dataset.filter(...)` keeping examples whose label is below 10 (digits) or
    above 35 (lowercase letters).
  """
  # The label is a tensor inside the tf.data predicate, so the two comparisons
  # must be combined with `tf.logical_or`. Python's `or` would need a Python
  # bool and only works if AutoGraph happens to rewrite the lambda.
  return dataset.filter(
      lambda x: tf.logical_or(x['label'] < 10, x['label'] > 35))


def _map_to_noise(dataset):
  """Replaces each example's `'pixels'` with uniform random noise.

  The `'label'` and any other entries are left as they are. The noise is
  drawn from `tf.random.uniform` without a seed.

  Args:
    dataset: A dataset of dict examples with a `'pixels'` entry.

  Returns:
    `dataset.map(...)`, where `'pixels'` is float32 noise in `[0, 1)` with the
    same shape as the original pixels.
  """
  def _map_fn(x):
    # `tf.shape` gives the runtime shape, so this also works when the static
    # shape of 'pixels' is only partially known (e.g. after batching).
    x['pixels'] = tf.random.uniform(
        tf.shape(x['pixels']), minval=0.0, maxval=1.0, dtype=tf.float32)
    return x
  return dataset.map(_map_fn)


def _domain_id(client_id):
  """Returns domain id for client id."""
  # These domain ids are based on NIST data source. For more details, see
  # https://s3.amazonaws.com/nist-srd/SD19/sd19_users_guide_edition_2.pdf.
  cid = int(client_id[1:5])
  if 2100 <= cid and cid <= 2599:
    return 0  # HIGH_SCHOOL.
  return 1  # CENSUS_FIELD.


def _get_census_client_ids(dataset) -> List[str]:
  """Returns the client ids of `dataset` in the census-field domain.

  Args:
    dataset: Object exposing a `client_ids` sequence (a `ClientData` here).

  Returns:
    A list of the ids for which `_domain_id` returns 1, in original order.
  """
  return [
      client_id for client_id in dataset.client_ids
      if _domain_id(client_id) == 1
  ]


def _get_high_school_client_ids(dataset) -> List[str]:
  """Returns the client ids of `dataset` in the high-school domain.

  Args:
    dataset: Object exposing a `client_ids` sequence (a `ClientData` here).

  Returns:
    A list of the ids for which `_domain_id` returns 0, in original order.
  """
  return [
      client_id for client_id in dataset.client_ids
      if _domain_id(client_id) == 0
  ]


def _get_client_ids_for_possible_label_noising(
    all_client_ids):
  # There are 3400 total EMNIST clients. When experimenting with label noising,
  # reserve the first 3000 clients to be the larger dataset in which some
  # percentage of the train examples will have their label scrambled.
  return all_client_ids[:3000]


def _get_client_ids_for_no_label_noising(
    all_client_ids):
  # There are 3400 total EMNIST clients. When experimenting with label noising,
  # reserve the last 400 clients to be the smaller dataset with pristine labels
  # (no label scrambling).
  return all_client_ids[3000:]


def _get_possibly_add_label_noise_map_fn(
    wrong_label_probability
):
  """Gets a function performing a mapping to mislabel some % of examples."""
  def _possibly_add_label_noise_map(
      dataset):
    def _map_fn(x):
      x['label'] = tf.cond(
          tf.greater(wrong_label_probability, tf.random.uniform((), 0.0, 1.0)),
          lambda: tf.random.uniform(x['label'].shape, 0, 62, tf.int32),
          lambda: x['label'])
      return x
    return dataset.map(_map_fn)
  return _possibly_add_label_noise_map


def _get_federated_all():
  """Returns federated EMNIST (train, eval) loaded with `only_digits=False`."""
  load_all_classes = False
  return _get_federated_datasets(only_digits=load_all_classes)


def _get_federated_only_digits():
  """Returns federated EMNIST (train, eval) loaded with `only_digits=True`."""
  load_digits_only = True
  return _get_federated_datasets(only_digits=load_digits_only)


def _preprocess_all_federated_emnist(preprocess_fn):
  """Loads all federated EMNIST and applies `preprocess_fn` to train and eval.

  Args:
    preprocess_fn: Dataset-to-dataset function passed to `ClientData.preprocess`.

  Returns:
    A `(train_client_data, eval_client_data)` pair, both preprocessed.
  """
  full_train, full_eval = _get_federated_all()
  preprocessed_train = full_train.preprocess(preprocess_fn)
  preprocessed_eval = full_eval.preprocess(preprocess_fn)
  return preprocessed_train, preprocessed_eval


def _get_federated_only_letters():
  """Federated EMNIST with digit examples removed."""
  return _preprocess_all_federated_emnist(_filter_out_digits)


def _get_federated_only_digits_and_uppercase():
  """Federated EMNIST with lowercase-letter examples removed."""
  return _preprocess_all_federated_emnist(_filter_out_lowercase)


def _get_federated_only_digits_and_lowercase():
  """Federated EMNIST with uppercase-letter examples removed."""
  return _preprocess_all_federated_emnist(_filter_out_uppercase)


def _get_federated_noise():
  """Federated EMNIST whose pixels are all replaced by uniform noise."""
  return _preprocess_all_federated_emnist(_map_to_noise)


def _select_federated_emnist_clients(get_client_ids):
  """Restricts all federated EMNIST train/eval data to selected clients.

  Args:
    get_client_ids: Callable taking a `ClientData` and returning the client
      ids to keep. It is applied to the train and eval data separately.

  Returns:
    A `(train_client_data, eval_client_data)` pair over the selected clients.
  """
  full_train, full_eval = _get_federated_all()
  kept_train_ids = get_client_ids(full_train)
  kept_eval_ids = get_client_ids(full_eval)
  client_data_from_ids = (
      tff.simulation.datasets.ClientData.from_clients_and_tf_fn)
  selected_train = client_data_from_ids(
      client_ids=kept_train_ids,
      serializable_dataset_fn=full_train.serializable_dataset_fn)
  selected_eval = client_data_from_ids(
      client_ids=kept_eval_ids,
      serializable_dataset_fn=full_eval.serializable_dataset_fn)
  return selected_train, selected_eval


def _get_federated_only_census():
  """Returns census train and eval client data."""
  return _select_federated_emnist_clients(_get_census_client_ids)


def _get_federated_only_high_school():
  """Returns high school train and eval client data."""
  return _select_federated_emnist_clients(_get_high_school_client_ids)


def _get_federated_large_dataset_with_some_noised_train_labels(
    wrong_label_probability: float):
  """Large segment of overall data (federated), where some labels are wrong.

  Both train and eval are restricted to the clients chosen by
  `_get_client_ids_for_possible_label_noising`. Only the train data gets label
  noise; the eval data is returned with its original labels.

  Args:
    wrong_label_probability: Per-example probability of resampling a train
      label.

  Returns:
    A `(train_client_data, eval_client_data)` pair.
  """
  full_train, full_eval = _get_federated_all()
  # NOTE(review): the client ids come from the train data and are reused to
  # select eval clients, whereas the datacenter eval getter slices the eval
  # data's own ids. This assumes both list the same clients in the same order
  # (confirm).
  large_segment_ids = _get_client_ids_for_possible_label_noising(
      full_train.client_ids)
  client_data_from_ids = (
      tff.simulation.datasets.ClientData.from_clients_and_tf_fn)
  large_train = client_data_from_ids(
      client_ids=large_segment_ids,
      serializable_dataset_fn=full_train.serializable_dataset_fn)
  large_eval = client_data_from_ids(
      client_ids=large_segment_ids,
      serializable_dataset_fn=full_eval.serializable_dataset_fn)
  add_label_noise = _get_possibly_add_label_noise_map_fn(
      wrong_label_probability=wrong_label_probability)
  noised_train = large_train.preprocess(add_label_noise)
  return noised_train, large_eval


def _get_federated_large_dataset_with_one_percent_noised_train_labels():
  """Large federated segment; each train label resampled with prob. 0.01."""
  return _get_federated_large_dataset_with_some_noised_train_labels(0.01)


def _get_federated_large_dataset_with_twenty_percent_noised_train_labels():
  """Large federated segment; each train label resampled with prob. 0.20."""
  return _get_federated_large_dataset_with_some_noised_train_labels(0.20)


def _get_federated_large_dataset_with_fifty_percent_noised_train_labels():
  """Large federated segment; each train label resampled with prob. 0.50."""
  return _get_federated_large_dataset_with_some_noised_train_labels(0.50)


def _get_federated_small_dataset_with_clean_train_labels():
  """Small segment of overall data (federated), all labels are pristine.

  Train and eval are both restricted to the clients chosen by
  `_get_client_ids_for_no_label_noising` applied to the train client ids.

  Returns:
    A `(train_client_data, eval_client_data)` pair.
  """
  full_train, full_eval = _get_federated_all()
  small_segment_ids = _get_client_ids_for_no_label_noising(
      full_train.client_ids)
  client_data_from_ids = (
      tff.simulation.datasets.ClientData.from_clients_and_tf_fn)
  small_train = client_data_from_ids(
      client_ids=small_segment_ids,
      serializable_dataset_fn=full_train.serializable_dataset_fn)
  small_eval = client_data_from_ids(
      client_ids=small_segment_ids,
      serializable_dataset_fn=full_eval.serializable_dataset_fn)
  return small_train, small_eval


def _pool_federated_train(get_federated_fn):
  """Merges every client's train data from `get_federated_fn()` into one dataset.

  Args:
    get_federated_fn: Zero-argument callable returning a
      `(train_client_data, eval_client_data)` pair.

  Returns:
    The train `ClientData`'s `create_tf_dataset_from_all_clients()`.
  """
  federated_train, _ = get_federated_fn()
  return federated_train.create_tf_dataset_from_all_clients()


def _get_datacenter_train_all():
  """Pooled train data over all EMNIST clients."""
  return _pool_federated_train(_get_federated_all)


def _get_datacenter_train_only_digits():
  """Pooled train data from the digits-only EMNIST load."""
  return _pool_federated_train(_get_federated_only_digits)


def _get_datacenter_train_only_letters():
  """Pooled train data with digit examples removed."""
  return _filter_out_digits(_pool_federated_train(_get_federated_all))


def _get_datacenter_train_only_digits_and_uppercase():
  """Pooled train data with lowercase-letter examples removed."""
  return _filter_out_lowercase(_pool_federated_train(_get_federated_all))


def _get_datacenter_train_only_digits_and_lowercase():
  """Pooled train data with uppercase-letter examples removed."""
  return _filter_out_uppercase(_pool_federated_train(_get_federated_all))


def _get_datacenter_train_noise():
  """Pooled train data whose pixels are replaced by uniform noise."""
  return _map_to_noise(_pool_federated_train(_get_federated_all))


def _get_datacenter_train_only_census():
  """Pooled train data from census-field clients only."""
  return _pool_federated_train(_get_federated_only_census)


def _get_datacenter_train_only_high_school():
  """Pooled train data from high-school clients only."""
  return _pool_federated_train(_get_federated_only_high_school)


def _get_datacenter_train_large_dataset_with_some_noised_train_labels(
    wrong_label_probability: float):
  """Large segment of overall train data, where some % of labels are wrong.

  Args:
    wrong_label_probability: Per-example probability of resampling a label.

  Returns:
    The train data of the clients chosen by
    `_get_client_ids_for_possible_label_noising`, pooled into one dataset
    with label noise applied.
  """
  full_train, _ = _get_federated_all()
  large_segment_ids = _get_client_ids_for_possible_label_noising(
      full_train.client_ids)
  large_train = tff.simulation.datasets.ClientData.from_clients_and_tf_fn(
      client_ids=large_segment_ids,
      serializable_dataset_fn=full_train.serializable_dataset_fn)
  add_label_noise = _get_possibly_add_label_noise_map_fn(
      wrong_label_probability=wrong_label_probability)
  pooled_train = large_train.create_tf_dataset_from_all_clients()
  return add_label_noise(pooled_train)


def _get_datacenter_train_large_dataset_with_one_percent_noised_train_labels(
):
  """Large pooled train segment; each label resampled with prob. 0.01."""
  return _get_datacenter_train_large_dataset_with_some_noised_train_labels(
      wrong_label_probability=0.01)


def _get_datacenter_train_large_dataset_with_twenty_percent_noised_train_labels(
):
  """Large pooled train segment; each label resampled with prob. 0.20."""
  return _get_datacenter_train_large_dataset_with_some_noised_train_labels(
      wrong_label_probability=0.20)


def _get_datacenter_train_large_dataset_with_fifty_percent_noised_train_labels(
):
  """Large pooled train segment; each label resampled with prob. 0.50."""
  return _get_datacenter_train_large_dataset_with_some_noised_train_labels(
      wrong_label_probability=0.50)


def _get_datacenter_train_small_dataset_with_clean_train_labels(
):
  """Small segment of overall train data, where all labels are pristine.

  Returns:
    The train data of the clients chosen by
    `_get_client_ids_for_no_label_noising`, pooled into one dataset.
  """
  full_train, _ = _get_federated_all()
  small_segment_ids = _get_client_ids_for_no_label_noising(
      full_train.client_ids)
  small_train = tff.simulation.datasets.ClientData.from_clients_and_tf_fn(
      client_ids=small_segment_ids,
      serializable_dataset_fn=full_train.serializable_dataset_fn)
  pooled_train = small_train.create_tf_dataset_from_all_clients()
  return pooled_train


def _pool_federated_eval(get_federated_fn):
  """Merges every client's eval data from `get_federated_fn()` into one dataset.

  Args:
    get_federated_fn: Zero-argument callable returning a
      `(train_client_data, eval_client_data)` pair.

  Returns:
    The eval `ClientData`'s `create_tf_dataset_from_all_clients()`.
  """
  _, federated_eval = get_federated_fn()
  return federated_eval.create_tf_dataset_from_all_clients()


def _get_datacenter_eval_all():
  """Pooled eval data over all EMNIST clients."""
  return _pool_federated_eval(_get_federated_all)


def _get_datacenter_eval_only_digits():
  """Pooled eval data from the digits-only EMNIST load."""
  return _pool_federated_eval(_get_federated_only_digits)


def _get_datacenter_eval_only_letters():
  """Pooled eval data with digit examples removed."""
  return _filter_out_digits(_pool_federated_eval(_get_federated_all))


def _get_datacenter_eval_only_digits_and_uppercase():
  """Pooled eval data with lowercase-letter examples removed."""
  return _filter_out_lowercase(_pool_federated_eval(_get_federated_all))


def _get_datacenter_eval_only_digits_and_lowercase():
  """Pooled eval data with uppercase-letter examples removed."""
  return _filter_out_uppercase(_pool_federated_eval(_get_federated_all))


def _get_datacenter_eval_noise():
  """Pooled eval data whose pixels are replaced by uniform noise."""
  return _map_to_noise(_pool_federated_eval(_get_federated_all))


def _get_datacenter_eval_only_census():
  """Pooled eval data from census-field clients only."""
  return _pool_federated_eval(_get_federated_only_census)


def _get_datacenter_eval_only_high_school():
  """Pooled eval data from high-school clients only."""
  return _pool_federated_eval(_get_federated_only_high_school)


def _get_datacenter_eval_large_dataset_with_some_noised_train_labels(
    wrong_label_probability: float):
  """Large segment of overall eval data, where some % of labels are wrong.

  Args:
    wrong_label_probability: Per-example probability of resampling a label.

  Returns:
    The eval data of the clients chosen by
    `_get_client_ids_for_possible_label_noising` (applied to the eval client
    ids), pooled into one dataset with label noise applied.
  """
  # NOTE(review): `_get_federated_large_dataset_with_some_noised_train_labels`
  # returns its eval data with clean labels, but this applies label noise to
  # the eval data as well. Confirm that this is intended for a split named
  # `*_noised_train_labels`.
  _, full_eval = _get_federated_all()
  large_segment_ids = _get_client_ids_for_possible_label_noising(
      full_eval.client_ids)
  large_eval = tff.simulation.datasets.ClientData.from_clients_and_tf_fn(
      client_ids=large_segment_ids,
      serializable_dataset_fn=full_eval.serializable_dataset_fn)
  add_label_noise = _get_possibly_add_label_noise_map_fn(
      wrong_label_probability=wrong_label_probability)
  pooled_eval = large_eval.create_tf_dataset_from_all_clients()
  return add_label_noise(pooled_eval)


def _get_datacenter_eval_large_dataset_with_one_percent_noised_train_labels(
):
  """Large pooled eval segment; each label resampled with prob. 0.01."""
  return _get_datacenter_eval_large_dataset_with_some_noised_train_labels(
      wrong_label_probability=0.01)


def _get_datacenter_eval_large_dataset_with_twenty_percent_noised_train_labels(
):
  """Large pooled eval segment; each label resampled with prob. 0.20."""
  return _get_datacenter_eval_large_dataset_with_some_noised_train_labels(
      wrong_label_probability=0.20)


def _get_datacenter_eval_large_dataset_with_fifty_percent_noised_train_labels(
):
  """Large pooled eval segment; each label resampled with prob. 0.50."""
  return _get_datacenter_eval_large_dataset_with_some_noised_train_labels(
      wrong_label_probability=0.50)


def _get_datacenter_eval_small_dataset_with_clean_train_labels(
):
  """Small segment of overall eval data, where all labels are pristine.

  Returns:
    The eval data of the clients chosen by
    `_get_client_ids_for_no_label_noising` (applied to the eval client ids),
    pooled into one dataset.
  """
  _, full_eval = _get_federated_all()
  small_segment_ids = _get_client_ids_for_no_label_noising(
      full_eval.client_ids)
  small_eval = tff.simulation.datasets.ClientData.from_clients_and_tf_fn(
      client_ids=small_segment_ids,
      serializable_dataset_fn=full_eval.serializable_dataset_fn)
  pooled_eval = small_eval.create_tf_dataset_from_all_clients()
  return pooled_eval


DATASET_SPLITS = [
    'all', 'only_digits', 'only_letters', 'only_digits_and_uppercase',
    'only_digits_and_lowercase', 'noise', 'only_census', 'only_high_school',
    'one_percent_noised_train_labels', 'twenty_percent_noised_train_labels',
    'fifty_percent_noised_train_labels', 'clean_train_labels'
]
# These lists will each be zipped with the above list of strings, so the lists
# should all be kept of equal length with entries in 1:1 correspondence.
FEDERATED_DATA_GETTER_FNS = [
    _get_federated_all, _get_federated_only_digits, _get_federated_only_letters,
    _get_federated_only_digits_and_uppercase,
    _get_federated_only_digits_and_lowercase, _get_federated_noise,
    _get_federated_only_census, _get_federated_only_high_school,
    _get_federated_large_dataset_with_one_percent_noised_train_labels,
    _get_federated_large_dataset_with_twenty_percent_noised_train_labels,
    _get_federated_large_dataset_with_fifty_percent_noised_train_labels,
    _get_federated_small_dataset_with_clean_train_labels
]
DATACENTER_TRAIN_DATA_GETTER_FNS = [
    _get_datacenter_train_all, _get_datacenter_train_only_digits,
    _get_datacenter_train_only_letters,
    _get_datacenter_train_only_digits_and_uppercase,
    _get_datacenter_train_only_digits_and_lowercase,
    _get_datacenter_train_noise,
    _get_datacenter_train_only_census, _get_datacenter_train_only_high_school,
    _get_datacenter_train_large_dataset_with_one_percent_noised_train_labels,
    _get_datacenter_train_large_dataset_with_twenty_percent_noised_train_labels,
    _get_datacenter_train_large_dataset_with_fifty_percent_noised_train_labels,
    _get_datacenter_train_small_dataset_with_clean_train_labels
]
DATACENTER_EVAL_DATA_GETTER_FNS = [
    _get_datacenter_eval_all, _get_datacenter_eval_only_digits,
    _get_datacenter_eval_only_letters,
    _get_datacenter_eval_only_digits_and_uppercase,
    _get_datacenter_eval_only_digits_and_lowercase,
    _get_datacenter_eval_noise,
    _get_datacenter_eval_only_census, _get_datacenter_eval_only_high_school,
    _get_datacenter_eval_large_dataset_with_one_percent_noised_train_labels,
    _get_datacenter_eval_large_dataset_with_twenty_percent_noised_train_labels,
    _get_datacenter_eval_large_dataset_with_fifty_percent_noised_train_labels,
    _get_datacenter_eval_small_dataset_with_clean_train_labels
]

# `zip` silently stops at its shortest input. A length mismatch would leave a
# validated split name without a getter, and it would only surface later as a
# confusing `KeyError`. Enforce the invariant above at import time instead.
if not (len(DATASET_SPLITS) == len(FEDERATED_DATA_GETTER_FNS) ==
        len(DATACENTER_TRAIN_DATA_GETTER_FNS) ==
        len(DATACENTER_EVAL_DATA_GETTER_FNS)):
  raise ValueError(
      'DATASET_SPLITS and every *_DATA_GETTER_FNS list must have the same '
      'length.')


def get_possible_dataset_splits():
  """The list of valid arguments to the other public methods in this library.

  Returns:
    A new list of split names. A copy is returned so that callers mutating
    the result cannot corrupt the module-level `DATASET_SPLITS`, which is
    used for validation and getter dispatch.
  """
  return list(DATASET_SPLITS)


def _validate_split_arg(split: str) -> None:
  """Checks that `split` names a known split.

  Args:
    split: Candidate split name.

  Raises:
    ValueError: If `split` is not in `DATASET_SPLITS`.
  """
  if split in DATASET_SPLITS:
    return
  message = 'Specified split was %s but must be one of %s.' % (
      split, DATASET_SPLITS)
  raise ValueError(message)


def _run_split_getter(split, getter_fns):
  """Validates `split`, then calls the getter paired with it in `getter_fns`.

  Args:
    split: Split name; must be in `DATASET_SPLITS`.
    getter_fns: List of zero-argument getters aligned 1:1 with
      `DATASET_SPLITS`.

  Returns:
    Whatever the selected getter returns.

  Raises:
    ValueError: If `split` is not a known split.
  """
  _validate_split_arg(split)
  getter_by_split = dict(zip(DATASET_SPLITS, getter_fns))
  return getter_by_split[split]()


def get_federated(split):
  """Get the federated train and eval data for the split specified.

  Raises:
    ValueError: If `split` is not one of `get_possible_dataset_splits()`.
  """
  return _run_split_getter(split, FEDERATED_DATA_GETTER_FNS)


def get_datacenter_train(split):
  """Get the datacenter train data for the split specified.

  Raises:
    ValueError: If `split` is not one of `get_possible_dataset_splits()`.
  """
  return _run_split_getter(split, DATACENTER_TRAIN_DATA_GETTER_FNS)


def get_datacenter_eval(split):
  """Get the datacenter eval data for the split specified.

  Raises:
    ValueError: If `split` is not one of `get_possible_dataset_splits()`.
  """
  return _run_split_getter(split, DATACENTER_EVAL_DATA_GETTER_FNS)
