# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
"""Build labeled/unlabeled/eval splits of a TFDS dataset and save them to disk."""
import jax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import random
import sys
import os
from tqdm import trange
# from pathlib import Path


def get_data_scaler(config):
  """Build the forward data normalizer.

  Inputs are assumed to already lie in [0, 1].

  Args:
    config: Config object; only `config.data.centered` is read.

  Returns:
    A callable mapping `x` to `2 * x - 1` (i.e. onto [-1, 1]) when
    `config.data.centered` is truthy, otherwise the identity.
  """
  def _to_symmetric_range(x):
    # [0, 1] -> [-1, 1]
    return x * 2. - 1.

  def _identity(x):
    return x

  return _to_symmetric_range if config.data.centered else _identity


def get_data_inverse_scaler(config):
  """Build the inverse data normalizer.

  Args:
    config: Config object; only `config.data.centered` is read.

  Returns:
    A callable mapping `x` to `(x + 1) / 2` (i.e. [-1, 1] back onto [0, 1])
    when `config.data.centered` is truthy, otherwise the identity.
  """
  def _to_unit_range(x):
    # [-1, 1] -> [0, 1]
    return (x + 1.) / 2.

  def _identity(x):
    return x

  return _to_unit_range if config.data.centered else _identity


def crop_resize(image, resolution):
  """Center-crop an image to a square, then resize it to `resolution`.

  Args:
    image: Image tensor whose first two axes are indexed as height and width.
    resolution: Side length of the square output.

  Returns:
    The cropped image, bicubically resized (with antialiasing) to
    `(resolution, resolution)` and cast to `tf.uint8`.
  """
  dims = tf.shape(image)
  height, width = dims[0], dims[1]
  # The largest centered square that fits has the shorter side's length.
  side = tf.minimum(height, width)
  top = (height - side) // 2
  left = (width - side) // 2
  square = image[top:top + side, left:left + side]
  resized = tf.image.resize(
    square,
    size=(resolution, resolution),
    antialias=True,
    method=tf.image.ResizeMethod.BICUBIC)
  return tf.cast(resized, tf.uint8)


def resize_small(image, resolution):
  """Resize an image so that its shorter side equals `resolution`.

  The aspect ratio is kept (up to rounding of the longer side).

  Args:
    image: Image tensor. `image.shape[0]` and `image.shape[1]` are read as
      height and width and fed to Python's `min`, so both must be statically
      known.
    resolution: Target length of the shorter side.

  Returns:
    The result of `tf.image.resize` (antialiased) at the scaled size.
  """
  h, w = image.shape[0], image.shape[1]
  ratio = resolution / min(h, w)
  # `tf.round` takes (x, name): the dtype must be applied with an explicit
  # cast, otherwise it is swallowed as the op name and the sizes stay float.
  new_h = tf.cast(tf.round(h * ratio), tf.int32)
  new_w = tf.cast(tf.round(w * ratio), tf.int32)
  return tf.image.resize(image, [new_h, new_w], antialias=True)


def central_crop(image, size):
  """Cut a centered `size` x `size` window out of an image.

  Args:
    image: Image tensor; `image.shape[0]` / `image.shape[1]` are read as
      height / width.
    size: Side length of the square crop.

  Returns:
    The output of `tf.image.crop_to_bounding_box` for the centered window.
  """
  height, width = image.shape[0], image.shape[1]
  return tf.image.crop_to_bounding_box(
    image,
    offset_height=(height - size) // 2,
    offset_width=(width - size) // 2,
    target_height=size,
    target_width=size)


def get_dataset(config, uniform_dequantization=False, evaluation=False,
                labeled_fraction=None):
  """Split a TFDS dataset into labeled/unlabeled/eval parts and save them.

  For the 'train' split, every class is shuffled and cut into a labeled part
  (the first `labeled_fraction` of the class) and an unlabeled part (the
  rest). Unlabeled examples keep their true class under `ori_label`, while
  `label` is overwritten with the string '-100'. The per-class parts are
  concatenated in class order and written with `tf.data.experimental.save`
  to `~/tensorflow_datasets/<name>_semi_<NNN>/labeled` and `.../unlabeled`,
  where `NNN` is `int(labeled_fraction * 100)` zero-padded to three digits.
  The 'test' split is written as read to `.../eval`.

  Args:
    config: Config object. `config.data.dataset` (one of 'CIFAR10',
      'CIFAR100', 'IMAGENET') and `config.data.num_classes` are read.
    uniform_dequantization: Unused; kept for interface compatibility.
    evaluation: Unused; kept for interface compatibility.
    labeled_fraction: Fraction in [0, 1] of each class that keeps its label.
      When `None`, the module-level global `k` (set by the script entry
      point) is used.

  Returns:
    None. The datasets are only written to disk.

  Raises:
    NotImplementedError: If `config.data.dataset` is not supported.
    ValueError: If no labeled fraction is available or it is outside [0, 1].
  """
  # dataset name -> (TFDS builder name, output directory prefix)
  dataset_specs = {
    'CIFAR10': ('cifar10', 'cifar10_semi'),
    'CIFAR100': ('cifar100', 'cifar100_semi'),
    'IMAGENET': ('imagenet2012', 'imagenet_semi'),
  }
  if config.data.dataset not in dataset_specs:
    raise NotImplementedError(
      f'Dataset {config.data.dataset} not yet supported.')
  tfds_name, dir_prefix = dataset_specs[config.data.dataset]

  if labeled_fraction is None:
    # Backward compatibility: the script entry point publishes the fraction
    # as the module-level global `k`.
    labeled_fraction = globals().get('k')
  if labeled_fraction is None:
    raise ValueError(
      'No labeled fraction given: pass `labeled_fraction` or define the '
      'module-level global `k`.')
  labeled_fraction = float(labeled_fraction)
  if not 0. <= labeled_fraction <= 1.:
    raise ValueError(
      f'labeled_fraction must be within [0, 1], got {labeled_fraction}.')

  # Must be at least the size of the largest class for a full per-class
  # shuffle; reduce it if whole classes do not fit in memory.
  shuffle_buffer_size = 10000
  # Fixed seed so that the labeled and unlabeled pipelines, which iterate the
  # shuffled data independently, both see the same permutation.
  split_seed = 0

  dataset_builder = tfds.builder(tfds_name)
  train_split_name = 'train'
  eval_split_name = 'test'
  # int() (not float formatting) is required by the `03d` format code.
  path = os.path.join(os.path.expanduser('~'), "tensorflow_datasets",
                      f"{dir_prefix}_{int(labeled_fraction * 100):03d}")
  os.makedirs(path, exist_ok=True)

  def split_dataset(dataset_builder, split):
    """Load `split` and save it (eval) or its labeled/unlabeled parts (train)."""
    dataset_options = tf.data.Options()
    dataset_options.experimental_optimization.map_parallelization = True
    dataset_options.experimental_threading.private_threadpool_size = 48
    dataset_options.experimental_threading.max_intra_op_parallelism = 1
    read_config = tfds.ReadConfig(options=dataset_options)
    is_train = split == 'train'
    if isinstance(dataset_builder, tfds.core.DatasetBuilder):
      dataset_builder.download_and_prepare()
      # The train pipeline is iterated several times per class; shard
      # shuffling would change the input order between those passes and
      # break the labeled/unlabeled partition, so it is disabled there.
      ds = dataset_builder.as_dataset(
        split=split, shuffle_files=not is_train, read_config=read_config)
    else:
      ds = dataset_builder.with_options(dataset_options)

    if not is_train:
      tf.data.experimental.save(ds, os.path.join(path, "eval"))
      return

    # Factories bind the per-class values explicitly instead of relying on
    # when the loop variable happens to be read by a closure.
    def label_equals(class_id):
      def predicate(example):
        return example['label'] == class_id
      return predicate

    def index_below(limit):
      def predicate(index, example):
        return index < limit
      return predicate

    def index_at_least(limit):
      def predicate(index, example):
        return index >= limit
      return predicate

    def drop_index(index, example):
      return example

    def hide_label(index, example):
      # Keep the true class under 'ori_label'; 'label' becomes the string
      # placeholder '-100' (so its dtype differs from the labeled set).
      hidden = dict(example)
      hidden['ori_label'] = example['label']
      hidden['label'] = '-100'
      return hidden

    labeled_parts, unlabeled_parts = [], []
    for class_id in trange(config.data.num_classes):
      class_ds = ds.filter(label_equals(class_id))
      # Count by streaming rather than materialising every example.
      class_size = int(
        class_ds.reduce(np.int64(0), lambda count, _: count + 1).numpy())
      num_labeled = int(class_size * labeled_fraction)
      # One seeded, non-reshuffling permutation shared by both subsets makes
      # them an exact partition of the class.
      indexed = class_ds.shuffle(
        shuffle_buffer_size, seed=split_seed,
        reshuffle_each_iteration=False).enumerate()
      labeled_parts.append(
        indexed.filter(index_below(num_labeled)).map(drop_index))
      unlabeled_parts.append(
        indexed.filter(index_at_least(num_labeled)).map(hide_label))

    ds_labeled, ds_unlabeled = labeled_parts[0], unlabeled_parts[0]
    for labeled_part, unlabeled_part in zip(labeled_parts[1:],
                                            unlabeled_parts[1:]):
      ds_labeled = ds_labeled.concatenate(labeled_part)
      ds_unlabeled = ds_unlabeled.concatenate(unlabeled_part)
    tf.data.experimental.save(ds_labeled, os.path.join(path, "labeled"))
    tf.data.experimental.save(ds_unlabeled, os.path.join(path, "unlabeled"))

  split_dataset(dataset_builder, train_split_name)
  split_dataset(dataset_builder, eval_split_name)

if __name__ == '__main__': # for testing
  # Script entry point: `python3 <file> <labeled fraction> <dataset>`.
  import types

  # Both arguments are mandatory; stop here instead of continuing with a
  # half-initialised config.
  if len(sys.argv) != 3:
    print(f"Usage: python3 {sys.argv[0]} labeled_percentage dataset")
    sys.exit(1)

  # `k` must remain a module-level global: get_dataset reads it. Despite the
  # usage string it is used as a fraction (0.1 means 10% labeled).
  k = float(sys.argv[1])
  if not 0. <= k <= 1.:
    print(f"labeled_percentage must be a fraction within [0, 1], got {k}")
    sys.exit(1)

  config = types.SimpleNamespace()
  config.data = types.SimpleNamespace()
  config.data.dataset = sys.argv[2]
  config.data.centered = False

  # dataset name -> (image_size, num_classes)
  dataset_presets = {
    'CIFAR10': (32, 10),
    'CIFAR100': (32, 100),
    'IMAGENET': (128, 1000),
  }
  if config.data.dataset not in dataset_presets:
    raise ValueError(f'Splitting not supported on dataset {config.data.dataset}.')
  config.data.image_size, config.data.num_classes = (
    dataset_presets[config.data.dataset])

  tf.config.run_functions_eagerly(True)
  get_dataset(config)