# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library of methods for various federated/datacenter data splits."""

from typing import List, Tuple

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_federated as tff

# Example counts for each corpus/split, as indicated by the constant names.
# The eval counts are used in this file as sampling weights when the Stack
# Overflow and Wikipedia datasets are mixed.
NUM_EXAMPLES_STACKOVERFLOW_TRAIN = 135_818_730.0
NUM_EXAMPLES_STACKOVERFLOW_EVAL = 16_586_035.0
NUM_EXAMPLES_WIKIPEDIA_TRAIN = 6_210_110.0
# Wikipedia's 'train' split is also used for eval in this file, so the eval
# count is defined to be the train count.
NUM_EXAMPLES_WIKIPEDIA_EVAL = NUM_EXAMPLES_WIKIPEDIA_TRAIN

# NOTE(review): neither prefix is referenced in the visible part of this file;
# confirm whether they are still needed.
_PREFIX_1 = '1'
_PREFIX_2 = '2'


def _preprocess_stackoverflow_to_common_format(dataset):
  """Reduces every Stack Overflow example to its 'tokens' entry.

  Args:
    dataset: A dataset whose elements can be indexed with the key 'tokens'.

  Returns:
    The result of `dataset.map`, where each element is the 'tokens' value of
    the corresponding input element.
  """

  def _select_tokens(example):
    return example['tokens']

  return dataset.map(_select_tokens)


def _preprocess_wikipedia_to_common_format(dataset):
  """Reduces every Wikipedia example to its 'text' entry.

  Args:
    dataset: A dataset whose elements can be indexed with the key 'text'.

  Returns:
    The result of `dataset.map`, where each element is the 'text' value of the
    corresponding input element.
  """

  def _select_text(example):
    return example['text']

  return dataset.map(_select_text)


def _get_federated_all():
  """Placeholder getter for the federated 'all' split; always raises.

  Raises:
    NotImplementedError: On every call, because the combined dataset cannot be
      provided in federated form.
  """
  reason = (
      'Since Wikipedia is not available as a federated dataset, it is also not '
      'possible to get Wikipedia + Stack Overflow as a federated dataset.')
  raise NotImplementedError(reason)


def _get_federated_stackoverflow():
  """Loads federated Stack Overflow data in the common example format.

  Returns:
    A (train, test) tuple built from the first and third values returned by
    TFF's Stack Overflow loader (the second value is ignored), each
    preprocessed with `_preprocess_stackoverflow_to_common_format`.
  """
  loaded_splits = tff.simulation.datasets.stackoverflow.load_data()
  train_split, _, test_split = loaded_splits
  to_common_format = _preprocess_stackoverflow_to_common_format
  return (train_split.preprocess(to_common_format),
          test_split.preprocess(to_common_format))


def _get_federated_wikipedia():
  """Placeholder getter for the federated 'wikipedia' split; always raises.

  Raises:
    NotImplementedError: On every call.
  """
  reason = 'Wikipedia is not available as a federated dataset.'
  raise NotImplementedError(reason)


def _get_datacenter_train_all():
  """Builds datacenter training data mixing Stack Overflow and Wikipedia.

  Returns:
    A dataset that samples from the Stack Overflow and Wikipedia datacenter
    training datasets, weighted by the eval example counts.
  """
  source_datasets = [
      _get_datacenter_train_stackoverflow(),
      _get_datacenter_train_wikipedia(),
  ]
  # The eval counts (not the train counts) are used on purpose: when training
  # on this centralized mix of all examples, the sources should appear in the
  # same proportions as they do in the evaluation split.
  sampling_weights = [
      NUM_EXAMPLES_STACKOVERFLOW_EVAL,
      NUM_EXAMPLES_WIKIPEDIA_EVAL,
  ]
  return tf.data.Dataset.sample_from_datasets(
      datasets=source_datasets, weights=sampling_weights)


def _get_datacenter_train_stackoverflow():
  """Builds one training dataset from all federated Stack Overflow clients.

  Returns:
    The dataset produced by `create_tf_dataset_from_all_clients` on the
    federated Stack Overflow training data.
  """
  federated_splits = _get_federated_stackoverflow()
  federated_train = federated_splits[0]
  return federated_train.create_tf_dataset_from_all_clients()


def _get_datacenter_train_wikipedia():
  """Loads the Wikipedia 'train' split in the common example format.

  Returns:
    The 'wikipedia/20201201.en' TFDS 'train' split (loaded with file
    shuffling enabled), with each example reduced to its 'text' entry.
  """
  raw_wikipedia = tfds.load(
      'wikipedia/20201201.en', split='train', shuffle_files=True)
  return _preprocess_wikipedia_to_common_format(raw_wikipedia)


def _get_datacenter_eval_all():
  """Builds datacenter eval data mixing Stack Overflow and Wikipedia.

  Returns:
    A dataset that samples from the Stack Overflow and Wikipedia datacenter
    eval datasets, weighted by their eval example counts.
  """
  source_datasets = [
      _get_datacenter_eval_stackoverflow(),
      _get_datacenter_eval_wikipedia(),
  ]
  sampling_weights = [
      NUM_EXAMPLES_STACKOVERFLOW_EVAL,
      NUM_EXAMPLES_WIKIPEDIA_EVAL,
  ]
  return tf.data.Dataset.sample_from_datasets(
      datasets=source_datasets, weights=sampling_weights)


def _get_datacenter_eval_stackoverflow():
  """Builds one eval dataset from all federated Stack Overflow eval clients.

  Returns:
    The dataset produced by `create_tf_dataset_from_all_clients` on the
    second element returned by `_get_federated_stackoverflow`.
  """
  federated_splits = _get_federated_stackoverflow()
  federated_eval = federated_splits[1]
  return federated_eval.create_tf_dataset_from_all_clients()


def _get_datacenter_eval_wikipedia():
  """Returns the Wikipedia eval data, which is the same as its training data.

  The Wikipedia dataset has no separate 'eval' split, so the 'train' split is
  reused for evaluation.

  Returns:
    The dataset returned by `_get_datacenter_train_wikipedia`.
  """
  eval_dataset = _get_datacenter_train_wikipedia()
  return eval_dataset


# Names of the supported dataset splits.
DATASET_SPLITS = [
    'all',
    'stackoverflow',
    'wikipedia',
]
# Each getter list below is zipped with DATASET_SPLITS to look up the getter
# for a split name. Every list must therefore have the same length as
# DATASET_SPLITS, with entries in the same order (1:1 correspondence).
FEDERATED_DATA_GETTER_FNS = [
    _get_federated_all,
    _get_federated_stackoverflow,
    _get_federated_wikipedia,
]
DATACENTER_TRAIN_DATA_GETTER_FNS = [
    _get_datacenter_train_all,
    _get_datacenter_train_stackoverflow,
    _get_datacenter_train_wikipedia,
]
DATACENTER_EVAL_DATA_GETTER_FNS = [
    _get_datacenter_eval_all,
    _get_datacenter_eval_stackoverflow,
    _get_datacenter_eval_wikipedia,
]


def get_possible_dataset_splits():
  """The list of valid arguments to the other public methods in this library.

  Returns:
    A new list containing the split names in `DATASET_SPLITS`. A copy is
    returned so that a caller mutating the result cannot alter the
    module-level list used for argument validation and for looking up the
    getter function of a split.
  """
  return list(DATASET_SPLITS)


def _validate_split_arg(split):
  """Checks that `split` names a supported dataset split.

  Args:
    split: The value to look up in `DATASET_SPLITS`.

  Raises:
    ValueError: If `split` is not an entry of `DATASET_SPLITS`.
  """
  if split in DATASET_SPLITS:
    return
  message = 'Specified split was %s but must be one of %s.' % (
      split, DATASET_SPLITS)
  raise ValueError(message)


def _merge_two_clients_into_one(
    unmerged_data
):
  """Convert to federated dataset with fewer, larger clients.

  Consecutive clients of the federated dataset are considered in pairs. Each
  pair gets a new client id (the two original ids joined by a single space),
  and the dataset of the new client is the combined (concatenated and
  shuffled) datasets of the 2 original clients. The last original client is
  dropped if the total number of original clients is odd.

  Args:
    unmerged_data: The ClientData object with clients to be merged.

  Returns:
    A ClientData object with merged clients.

  Raises:
    ValueError: If no pair of clients could be formed (fewer than 2 clients),
      or if an original client id contains the separator used to build merged
      ids, since such a merged id could not be split back into exactly 2
      original ids.
  """
  separator = ' '
  original_ids = list(unmerged_data.client_ids)
  # Pair up ids (0, 1), (2, 3), ...; zip stops at the shorter slice, which
  # drops the trailing id when the count is odd.
  id_pairs = list(zip(original_ids[0::2], original_ids[1::2]))

  if not id_pairs:
    raise ValueError(
        'Attempting to merge clients by groups of 2, but no merged clients '
        'were created; does the unmerged ClientData object contain <2 clients?')

  # The merged id is split on the separator below to recover the 2 original
  # ids. Fail fast here, with a clear message, rather than later inside
  # TensorFlow when the split yields more than 2 pieces.
  for id_pair in id_pairs:
    for client_id in id_pair:
      if separator in str(client_id):
        raise ValueError(
            'Cannot merge clients: client id %r contains the separator %r '
            'used to build merged client ids.' % (client_id, separator))

  merged_client_ids = [
      '%s%s%s' % (first_id, separator, second_id)
      for first_id, second_id in id_pairs
  ]

  @tf.function
  def _serializable_dataset_fn(merged_client_id):
    client_ids = tf.unstack(
        tf.strings.split(merged_client_id, sep=separator), num=2)
    dataset0 = unmerged_data.serializable_dataset_fn(client_ids[0])
    dataset1 = unmerged_data.serializable_dataset_fn(client_ids[1])
    # Note there is likely additional shuffling that will go on in data
    # processing; this shuffle is to ensure that the contents of the 2 original
    # clients' datasets are mixed together.
    return dataset0.concatenate(dataset1).shuffle(1000)

  return tff.simulation.datasets.ClientData.from_clients_and_tf_fn(
      client_ids=merged_client_ids,
      serializable_dataset_fn=_serializable_dataset_fn)


def get_federated(split):
  """Get the federated train and eval data for the split specified.

  Args:
    split: The name of the split; must be an entry of `DATASET_SPLITS`.

  Returns:
    A (train, eval) tuple of federated datasets in which consecutive pairs of
    the original clients have been merged into single clients.

  Raises:
    ValueError: If `split` is not a valid split name.
    NotImplementedError: If the getter for `split` raises it; the federated
      getters for 'all' and 'wikipedia' always do.
  """
  _validate_split_arg(split)
  getter_by_split = dict(zip(DATASET_SPLITS, FEDERATED_DATA_GETTER_FNS))
  load_split = getter_by_split[split]
  train_data, eval_data = load_split()
  # Fewer clients with more examples each allows more training steps (on
  # unique examples) per round.
  return (_merge_two_clients_into_one(train_data),
          _merge_two_clients_into_one(eval_data))


def get_datacenter_train(split):
  """Get the datacenter train data for the split specified.

  Args:
    split: The name of the split; must be an entry of `DATASET_SPLITS`.

  Returns:
    Whatever the datacenter train getter registered for `split` returns.

  Raises:
    ValueError: If `split` is not a valid split name.
  """
  _validate_split_arg(split)
  getter_by_split = dict(zip(DATASET_SPLITS, DATACENTER_TRAIN_DATA_GETTER_FNS))
  load_split = getter_by_split[split]
  return load_split()


def get_datacenter_eval(split):
  """Get the datacenter eval data for the split specified.

  Args:
    split: The name of the split; must be an entry of `DATASET_SPLITS`.

  Returns:
    Whatever the datacenter eval getter registered for `split` returns.

  Raises:
    ValueError: If `split` is not a valid split name.
  """
  _validate_split_arg(split)
  getter_by_split = dict(zip(DATASET_SPLITS, DATACENTER_EVAL_DATA_GETTER_FNS))
  load_split = getter_by_split[split]
  return load_split()

