# Lint as: python3
# Copyright 2020 The SPL Authors.
#
# All rights reserved.
#
# This is the code for reproducing results of the paper.
"""Code to create datasets."""

import collections
import math

from absl import logging

from datasets import weaklabel_datasets


# Immutable bundle describing a train/eval input pipeline. The fields, as they
# are populated by `_make_datasets_namedtuple` in this module:
#   train_dataset:   training dataset after distribution by the strategy.
#   eval_dataset:    evaluation dataset after distribution by the strategy.
#   num_classes:     `num_classes` of the training dataset object.
#   batch_shape:     `batch_shape` of the training dataset object.
#   num_samples:     `num_images` of the training dataset object.
#   steps_per_epoch: training images divided by the total train batch size.
#   steps_per_eval:  eval images divided by the total eval batch size,
#                    rounded up.
_DATASETS_FIELDS = (
    'train_dataset',
    'eval_dataset',
    'num_classes',
    'batch_shape',
    'num_samples',
    'steps_per_epoch',
    'steps_per_eval',
)
Datasets = collections.namedtuple('Datasets', _DATASETS_FIELDS)


def _num_workers(distribution_strategy):
  """Returns number of workers."""
  is_tpu_pod = distribution_strategy.extended._input_workers.num_workers > 1 
  if is_tpu_pod:
    return len(distribution_strategy.extended.worker_devices)
  else:
    return 1


def _make_datasets_namedtuple(strategy,
                              train_dataset_object,
                              eval_dataset_object):
  """Builds a `Datasets` namedtuple from objects exposing `input_fn`.

  With more than one worker the `input_fn` callables themselves are handed to
  the strategy, which builds the datasets from the function. With a single
  worker `input_fn()` is called here and the resulting dataset is distributed.

  Args:
    strategy: distribution strategy.
    train_dataset_object: object which provides input_fn for training data. Its
      `batch_size`, `num_images`, `num_classes` and `batch_shape` attributes
      are also read.
    eval_dataset_object: object which provides input_fn for evaluation data.
      Its `batch_size` and `num_images` attributes are also read.

  Returns:
    Datasets namedtuple with train and eval datasets.
  """
  worker_count = _num_workers(strategy)

  if worker_count > 1:
    distribute = strategy.experimental_distribute_datasets_from_function
    train_dataset = distribute(train_dataset_object.input_fn)
    eval_dataset = distribute(eval_dataset_object.input_fn)
  else:
    distribute = strategy.experimental_distribute_dataset
    train_dataset = distribute(train_dataset_object.input_fn())
    eval_dataset = distribute(eval_dataset_object.input_fn())

  # The dataset objects' batch sizes are multiplied by the worker count to get
  # the total number of examples consumed per step.
  total_train_batch = train_dataset_object.batch_size * worker_count
  total_eval_batch = eval_dataset_object.batch_size * worker_count

  # NOTE(review): true division, so steps_per_epoch can be fractional, whereas
  # steps_per_eval below is rounded up to an int -- confirm callers expect this.
  train_steps = train_dataset_object.num_images / total_train_batch
  eval_steps = int(math.ceil(eval_dataset_object.num_images / total_eval_batch))

  return Datasets(
      train_dataset=train_dataset,
      eval_dataset=eval_dataset,
      num_classes=train_dataset_object.num_classes,
      batch_shape=train_dataset_object.batch_shape,
      num_samples=train_dataset_object.num_images,
      steps_per_epoch=train_steps,
      steps_per_eval=eval_steps)


def _parse_clothing1m_name(dataset_name):
  """Splits a lowercased Clothing1M dataset name into its two split names.

  Args:
    dataset_name: lowercased name of the form '<train_part>@<eval_split>',
      where <train_part> starts with 'clothing1m'.

  Returns:
    Tuple `(supervised_split, eval_set_name)`. `supervised_split` is 'train'
    followed by whatever remains of <train_part> after dropping its first
    `len('clothing1m_')` characters.

  Raises:
    ValueError: if the name does not contain exactly one '@' separator.
  """
  name_parts = dataset_name.split('@')
  if len(name_parts) != 2:
    raise ValueError(
        'Clothing1M dataset name must look like '
        '"clothing1m_<train_suffix>@<eval_split>", got: {0}'.format(
            dataset_name))
  train_part, eval_set_name = name_parts
  supervised_split = 'train' + train_part[len('clothing1m_'):]
  return supervised_split, eval_set_name


def make_train_eval_datasets(distribution_strategy,
                             dataset_name,
                             batch_size,
                             train_augmentation,
                             dataset_kwargs,
                             eval_batch_size=None):
  """Creates Datasets object with training and eval datasets.

  Args:
    distribution_strategy: distribution strategy.
    dataset_name: name of the dataset (case-insensitive). Only names starting
      with 'clothing1m' are supported, and they must have the form
      '<train_part>@<eval_split>'.
    batch_size: total batch size across all workers.
    train_augmentation: optional callable which performs augmentation of
      training data.
    dataset_kwargs: dictionary with additional arguments which are passed to the
      dataset init function.
    eval_batch_size: optional batch size for evaluation data. Falls back to
      `batch_size` when it is None or 0.

  Returns:
    Datasets namedtuple with train and eval datasets.

  Raises:
    ValueError: if the dataset name is unsupported or malformed.
  """
  # Validate the name before touching the strategy so that a bad name fails
  # fast with a clear message.
  dataset_name = dataset_name.lower()
  if not dataset_name.startswith('clothing1m'):
    raise ValueError('Unsupported dataset name: {0}'.format(dataset_name))
  supervised_split, eval_set_name = _parse_clothing1m_name(dataset_name)

  if not eval_batch_size:
    eval_batch_size = batch_size

  # Each worker gets an equal share of the total batch; any remainder is
  # dropped by the int() truncation.
  num_workers = _num_workers(distribution_strategy)
  per_worker_batch_size = int(batch_size / num_workers)
  per_worker_eval_batch_size = int(eval_batch_size / num_workers)
  logging.info('Distributing input pipeline across %d workers '
               'with %d train and %d eval per worker batch size',
               num_workers, per_worker_batch_size, per_worker_eval_batch_size)

  train_dataset_object = weaklabel_datasets.Clothing1M(
      split=supervised_split,
      is_training=True,
      batch_size=per_worker_batch_size,
      augmentation=train_augmentation,
      **dataset_kwargs)
  # The eval dataset is built without the training augmentation.
  eval_dataset_object = weaklabel_datasets.Clothing1M(
      split=eval_set_name,
      is_training=False,
      batch_size=per_worker_eval_batch_size,
      **dataset_kwargs)

  return _make_datasets_namedtuple(distribution_strategy,
                                   train_dataset_object,
                                   eval_dataset_object)
