"""
 Copyright (c) 2022 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

from abc import ABC, abstractmethod

import tensorflow_datasets as tfds

from examples.tensorflow.common.logger import logger
from examples.tensorflow.common.utils import set_hard_limit_num_open_files


class BaseDatasetBuilder(ABC):
    """Abstract dataset loader and input processing.

    Subclasses implement `num_examples`, `num_classes` and `_pipeline`.
    `build()` loads the raw dataset through TFDS or TFRecords, depending on
    the `dataset_type` config entry, and passes it through `_pipeline`.
    """
    def __init__(self, config, is_train, num_devices):
        """Stores the dataset settings taken from `config`.

        :param config: configuration object. `batch_size` and `dataset_dir`
            are read as attributes; `dataset` and `dataset_type` are read
            through `config.get()`.
        :param is_train: `True` to use the 'train' split, `False` to use the
            'validation' split.
        :param num_devices: number of devices the global batch is divided
            between; used as a divisor in `batch_size`.
        """
        self._config = config

        self._is_train = is_train
        self._num_devices = num_devices
        self._global_batch_size = config.batch_size

        # Dataset params
        self._dataset_dir = config.dataset_dir
        # May stay `None` when the config has no 'dataset' entry.
        self._dataset_name = config.get('dataset', None)
        self._dataset_type = config.get('dataset_type', 'tfds')
        self._as_supervised = False

        # Dataset loader, assigned by `_load_tfds` / `_load_tfrecords`
        self._dataset_loader = None

        # TFDS params: when True, `_load_tfds` requests the 'image' feature
        # with `tfds.decode.SkipDecoding()`.
        self._skip_decoding = False

        # Dict with TFRecordDatasets: dataset key -> loader factory.
        # Empty here; presumably populated by subclasses.
        self._tfrecord_datasets = {}

        self._split = 'train' if self._is_train else 'validation'

    @property
    def is_train(self):
        """Returns a `bool` flag which specifies whether it is a training or evaluation dataset."""
        return self._is_train

    @property
    def batch_size(self):
        """Returns per replica batch size (configured batch size floor-divided by the device count)."""
        return self._global_batch_size // self._num_devices

    @property
    def global_batch_size(self):
        """Returns global batch size.

        Computed as per replica batch size times the device count, so it is
        smaller than the configured batch size when the latter is not
        divisible by the number of devices.
        """
        return self.batch_size * self._num_devices

    @property
    def steps_per_epoch(self):
        """Returns steps per epoch (number of whole global batches in `num_examples`)."""
        return self.num_examples // self.global_batch_size

    @property
    @abstractmethod
    def num_examples(self):
        """Returns number of examples in the current dataset."""

    @property
    @abstractmethod
    def num_classes(self):
        """Returns number of classes in the current dataset."""

    @abstractmethod
    def _pipeline(self, dataset):
        """The pipeline which decodes and preprocesses the input data for model."""

    def build(self):
        """Loads the dataset and applies the preprocessing pipeline to it.

        :return: the result of `_pipeline` applied to the loaded dataset.
        :raises ValueError: if the configured dataset type is neither 'tfds'
            nor 'tfrecords'.
        """
        dataset_builders = {
            'tfds': self._load_tfds,
            'tfrecords': self._load_tfrecords,
        }

        builder = dataset_builders.get(self._dataset_type)
        if builder is None:
            raise ValueError('Unknown dataset type {}'.format(self._dataset_type))

        return self._pipeline(builder())

    def _load_tfds(self):
        """Loads the current split with TensorFlow Datasets.

        The TFDS builder is kept in `self._dataset_loader`.

        :return: the dataset produced by the TFDS builder's `as_dataset`.
        """
        logger.info('Using TFDS to load {} data.'.format(self._split))

        # NOTE(review): judging by its name, this adjusts the process limit on
        # open files before TFDS reads the data - confirm against the helper.
        set_hard_limit_num_open_files()

        self._dataset_loader = tfds.builder(self._dataset_name,
                                            data_dir=self._dataset_dir)

        self._dataset_loader.download_and_prepare()

        decoders = {'image': tfds.decode.SkipDecoding()} \
            if self._skip_decoding else None

        read_config = tfds.ReadConfig(
            interleave_cycle_length=64,
            interleave_block_length=1)

        dataset = self._dataset_loader.as_dataset(
            split=self._split,
            as_supervised=self._as_supervised,
            shuffle_files=self._is_train,
            decoders=decoders,
            read_config=read_config)

        return dataset

    def _load_tfrecords(self):
        """Loads the dataset through a loader registered in `_tfrecord_datasets`.

        The lookup key is the dataset name with every '/' removed. The
        matching entry is called with `config` and `is_train` keyword
        arguments, and the resulting loader is kept in `self._dataset_loader`.

        :return: the result of the loader's `as_dataset()`.
        :raises ValueError: if the dataset name is not set in the config or
            has no entry in `_tfrecord_datasets`.
        """
        # The name comes from `config.get('dataset', None)`, so it may be
        # missing; fail with a clear message rather than an AttributeError
        # from `None.replace`.
        if self._dataset_name is None:
            raise ValueError('Dataset name is not specified: '
                             'the `dataset` config entry is required to load TFRecords.')

        logger.info('Using TFRecords to load {} data.'.format(self._split))

        dataset_key = self._dataset_name.replace('/', '')
        if dataset_key not in self._tfrecord_datasets:
            raise ValueError('Unknown dataset name: {}'.format(self._dataset_name))

        self._dataset_loader = self._tfrecord_datasets[dataset_key](
            config=self._config, is_train=self._is_train
        )

        return self._dataset_loader.as_dataset()
