# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Methods to retrieve data, for federated train or eval or datacenter train."""

import itertools
from typing import Callable, List, Set

import tensorflow as tf
import tensorflow_federated as tff

from mixed_fl.experiments.celeba import datasets as celeba_datasets
from mixed_fl.experiments.emnist import datasets as emnist_datasets
from mixed_fl.experiments.next_char_prediction import datasets as next_char_prediction_datasets

# Maps each dataset name to the splits it supports, as reported by that
# dataset module's `get_possible_dataset_splits()`. The calls are made once,
# when this module is imported.
ALLOWED_DATASET_SPLITS_MAP = {
    'emnist': emnist_datasets.get_possible_dataset_splits(),
    'celeba': celeba_datasets.get_possible_dataset_splits(),
    'ncp': next_char_prediction_datasets.get_possible_dataset_splits(),
}
# Maps each dataset name to its federated data getter. Each getter is called
# with a split, and its result is unpacked as a two-element (train, eval) pair
# by the federated accessors in this module.
FEDERATED_DATA_GETTER_FNS_MAP = {
    'emnist': emnist_datasets.get_federated,
    'celeba': celeba_datasets.get_federated,
    'ncp': next_char_prediction_datasets.get_federated,
}
# Maps each dataset name to the getter for its datacenter training data. Each
# getter is called with the requested split.
DATACENTER_TRAIN_DATA_GETTER_FNS_MAP = {
    'emnist': emnist_datasets.get_datacenter_train,
    'celeba': celeba_datasets.get_datacenter_train,
    'ncp': next_char_prediction_datasets.get_datacenter_train,
}
# Maps each dataset name to the getter for its datacenter evaluation data.
# Each getter is called with the requested split.
DATACENTER_EVAL_DATA_GETTER_FNS_MAP = {
    'emnist': emnist_datasets.get_datacenter_eval,
    'celeba': celeba_datasets.get_datacenter_eval,
    'ncp': next_char_prediction_datasets.get_datacenter_eval,
}


def get_possible_datasets():
  """Returns the names of all supported datasets.

  Returns:
    A new list holding the keys of `ALLOWED_DATASET_SPLITS_MAP`, in the map's
    iteration order.
  """
  # Iterating a dict yields its keys, so no explicit `.keys()` view is needed.
  dataset_names = list(ALLOWED_DATASET_SPLITS_MAP)
  return dataset_names


def get_all_possible_splits():
  """Returns the union of the splits supported by any dataset.

  Returns:
    A set containing every split that appears in at least one value of
    `ALLOWED_DATASET_SPLITS_MAP`.
  """
  all_splits = set()
  for dataset_splits in ALLOWED_DATASET_SPLITS_MAP.values():
    all_splits.update(dataset_splits)
  return all_splits


def _validate_dataset_and_split(dataset, split):
  """Checks that `dataset` and `split` form a supported combination.

  Args:
    dataset: Name of the dataset; must be a key of
      `ALLOWED_DATASET_SPLITS_MAP`.
    split: The data split to use; must be one of the splits listed for
      `dataset` in `ALLOWED_DATASET_SPLITS_MAP`.

  Raises:
    ValueError: If `dataset` or `split` is `None`, if `dataset` is not a known
      dataset, or if `split` is not allowed for `dataset`.
  """
  if dataset is None:
    raise ValueError('No `dataset` was specified (cannot be left `None`).')

  if split is None:
    raise ValueError('No `split` was specified (cannot be left `None`).')

  # Reject unknown dataset names explicitly; otherwise the lookup below would
  # surface as a bare `KeyError` instead of an actionable validation error.
  if dataset not in ALLOWED_DATASET_SPLITS_MAP:
    raise ValueError(
        'The `dataset` argument cannot be %s; allowed values: %s.' %
        (dataset, list(ALLOWED_DATASET_SPLITS_MAP)))

  allowed_splits = ALLOWED_DATASET_SPLITS_MAP[dataset]
  if split not in allowed_splits:
    raise ValueError(
        'The `split` argument cannot be %s if dataset is %s; allowed values: '
        '%s.' % (split, dataset, allowed_splits))


def get_federated_train_data(dataset, split):
  """Returns the federated training data for `dataset` and `split`.

  Args:
    dataset: Name of the dataset; used as the key into
      `FEDERATED_DATA_GETTER_FNS_MAP`.
    split: The data split, passed through to the dataset's federated getter.

  Returns:
    The first element of the two-element result produced by the dataset's
    federated getter.

  Raises:
    ValueError: If `dataset` or `split` is `None`, or if `split` is not allowed
      for `dataset`.
  """
  _validate_dataset_and_split(dataset, split)
  federated_getter = FEDERATED_DATA_GETTER_FNS_MAP[dataset]
  # The getter yields a pair; only the first (training) element is needed.
  train_client_data, _ = federated_getter(split)
  return train_client_data


def get_federated_eval_data(dataset, split):
  """Returns the federated evaluation data for `dataset` and `split`.

  Args:
    dataset: Name of the dataset; used as the key into
      `FEDERATED_DATA_GETTER_FNS_MAP`.
    split: The data split, passed through to the dataset's federated getter.

  Returns:
    The second element of the two-element result produced by the dataset's
    federated getter.

  Raises:
    ValueError: If `dataset` or `split` is `None`, or if `split` is not allowed
      for `dataset`.
  """
  _validate_dataset_and_split(dataset, split)
  federated_getter = FEDERATED_DATA_GETTER_FNS_MAP[dataset]
  # The getter yields a pair; only the second (evaluation) element is needed.
  _, eval_client_data = federated_getter(split)
  return eval_client_data


def get_datacenter_train_dataset_fn(dataset, split):
  """Builds a no-arg callable that loads the datacenter training data.

  The arguments are validated immediately, but the getter lookup and the data
  retrieval are deferred until the returned callable is invoked.

  Args:
    dataset: Name of the dataset; used as the key into
      `DATACENTER_TRAIN_DATA_GETTER_FNS_MAP`.
    split: The data split, passed through to the dataset's getter.

  Returns:
    A zero-argument callable returning whatever the dataset's datacenter
    training getter returns for `split`.

  Raises:
    ValueError: If `dataset` or `split` is `None`, or if `split` is not allowed
      for `dataset`.
  """
  _validate_dataset_and_split(dataset, split)

  def _load_train_dataset():
    # Resolved lazily: the getter is looked up and run on each invocation.
    train_getter = DATACENTER_TRAIN_DATA_GETTER_FNS_MAP[dataset]
    return train_getter(split)

  return _load_train_dataset


def get_datacenter_eval_dataset_fn(dataset, split):
  """Builds a no-arg callable that loads the datacenter evaluation data.

  The arguments are validated immediately, but the getter lookup and the data
  retrieval are deferred until the returned callable is invoked.

  Args:
    dataset: Name of the dataset; used as the key into
      `DATACENTER_EVAL_DATA_GETTER_FNS_MAP`.
    split: The data split, passed through to the dataset's getter.

  Returns:
    A zero-argument callable returning whatever the dataset's datacenter
    evaluation getter returns for `split`.

  Raises:
    ValueError: If `dataset` or `split` is `None`, or if `split` is not allowed
      for `dataset`.
  """
  _validate_dataset_and_split(dataset, split)

  def _load_eval_dataset():
    # Resolved lazily: the getter is looked up and run on each invocation.
    eval_getter = DATACENTER_EVAL_DATA_GETTER_FNS_MAP[dataset]
    return eval_getter(split)

  return _load_eval_dataset
