# Lint as: python3
# Copyright 2020 The SPL Authors.
#
# All rights reserved.
#
# This is the code for reproducing results of the paper.
"""Weakly-supervsied data loader."""


import collections
import csv
import os

from absl import flags
from absl import logging
import gin
import numpy as np

from spl_utils import label_map_assignment

import tensorflow.compat.v2 as tf

# Alias for absl's global flag registry (`flags.FLAGS`).
FLAGS = flags.FLAGS

# Root directory of each dataset on local disk.
PATH = dict(clothing1m='./data/clothing1m')

# TFRecord shard glob for each `<dataset>_<split>` key (read by `get_tfdata`).
PATTERN = {
    split_key: os.path.join(PATH['clothing1m'], shard_dir, shard_glob)
    for split_key, shard_dir, shard_glob in (
        ('clothing1m_trainclean', 'tfrecord/train', 'clean_train*'),
        ('clothing1m_trainnoisy', 'tfrecord/train', 'noisy_train*'),
        ('clothing1m_trainall', 'tfrecord/train', '*'),
        ('clothing1m_val', 'tfrecord/val', '*'),
        ('clothing1m_test', 'tfrecord/test', '*'),
    )
}

# Number of examples in each `<dataset>_<split>`; `trainall` is the sum of
# `trainclean` and `trainnoisy`.
SIZE = dict(
    clothing1m_trainclean=47570,
    clothing1m_test=10526,
    clothing1m_trainnoisy=999993,
    clothing1m_trainall=1047563,
    clothing1m_val=14313,
)

# (dataset, number of examples) pair returned by `get_tfdata`.
DSTuple = collections.namedtuple('tfds', ('ds', 'size'))

# tf.data tuning knobs and image geometry.
SHUFFLE_BUFFER_SIZE = 1024
PREPROCESSING_THREADS = 10
BUFFER_SIZE = 16 << 20  # 16 MiB read buffer for each TFRecordDataset.
IMAGE_SIZE = 224  # Side length of the square preprocessed image.
CROP_PADDING = 32  # Center crop keeps IMAGE_SIZE / (IMAGE_SIZE + CROP_PADDING).


def distorted_bounding_box_crop(image_bytes,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100):
  """Decodes a randomly distorted crop of a JPEG image.

  A crop window is sampled with `tf.image.sample_distorted_bounding_box`, and
  only that window is decoded with `tf.image.decode_and_crop_jpeg`.

  Args:
    image_bytes: scalar string `Tensor` of JPEG-encoded image data.
    bbox: float `Tensor` of bounding boxes arranged `[1, num_boxes, 4]`, where
      each coordinate is in [0, 1] and the coordinates are ordered
      `[ymin, xmin, ymax, xmax]`. If num_boxes is 0, the whole image is used.
    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped area
      of the image must contain at least this fraction of any bounding box
      supplied.
    aspect_ratio_range: An optional list of `float`s. The cropped area of the
      image must have an aspect ratio = width / height within this range.
    area_range: An optional list of `float`s. The cropped area of the image must
      contain a fraction of the supplied image within this range.
    max_attempts: An optional `int`. Number of attempts at generating a cropped
      region of the image of the specified constraints. After `max_attempts`
      failures, the entire image is returned.

  Returns:
    The decoded crop as a uint8 `Tensor` of shape `[height, width, 3]`. Only
    the image is returned; the sampled bounding box is discarded.
  """
  with tf.name_scope('distorted_bounding_box_crop'):
    # Read height/width/channels from the JPEG header without decoding.
    shape = tf.image.extract_jpeg_shape(image_bytes)
    sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
        shape,
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True)
    # The third output (the drawn box) is not needed.
    bbox_begin, bbox_size, _ = sample_distorted_bounding_box

    # Crop the image to the specified bounding box (channel entries dropped).
    offset_y, offset_x, _ = tf.unstack(bbox_begin)
    target_height, target_width, _ = tf.unstack(bbox_size)
    crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
    image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)

    return image


def _at_least_x_are_equal(a, b, x):
  """Returns a boolean scalar that is True iff `a` and `b` agree in >= `x` places."""
  num_matches = tf.reduce_sum(tf.cast(tf.equal(a, b), tf.int32))
  return tf.greater_equal(num_matches, x)


def _decode_and_random_crop(image_bytes):
  """Decodes a random crop of the JPEG, resized to IMAGE_SIZE x IMAGE_SIZE.

  If the sampled crop matches the JPEG header shape in all three dimensions
  (i.e. the sampler returned the whole image), falls back to
  `_decode_and_center_crop` instead.
  """
  whole_image_box = tf.constant(
      [0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
  cropped = distorted_bounding_box_crop(
      image_bytes,
      whole_image_box,
      min_object_covered=0.1,
      aspect_ratio_range=(3. / 4, 4. / 3.),
      area_range=(0.08, 1.0),
      max_attempts=10)

  # The sampler returns the whole image after `max_attempts` failures; treat a
  # crop identical in shape to the source as such a failure.
  crop_is_whole_image = _at_least_x_are_equal(
      tf.image.extract_jpeg_shape(image_bytes), tf.shape(cropped), 3)

  def _center_crop():
    return _decode_and_center_crop(image_bytes)

  def _resize_crop():
    return tf.image.resize(
        [cropped], [IMAGE_SIZE, IMAGE_SIZE],
        method=tf.image.ResizeMethod.BICUBIC)[0]

  return tf.cond(crop_is_whole_image, _center_crop, _resize_crop)


def _decode_and_center_crop(image_bytes):
  """Decodes a padded center crop of the JPEG and resizes it to IMAGE_SIZE.

  The crop is a square whose side is IMAGE_SIZE / (IMAGE_SIZE + CROP_PADDING)
  of the shorter image side, centred in the image.
  """
  jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
  height, width = jpeg_shape[0], jpeg_shape[1]

  keep_fraction = IMAGE_SIZE / (IMAGE_SIZE + CROP_PADDING)
  short_side = tf.cast(tf.minimum(height, width), tf.float32)
  crop_size = tf.cast(keep_fraction * short_side, tf.int32)

  # (d + 1) // 2 is ceil(d / 2): any odd leftover pixel goes to the top/left.
  top = (height - crop_size + 1) // 2
  left = (width - crop_size + 1) // 2
  window = tf.stack([top, left, crop_size, crop_size])

  decoded = tf.image.decode_and_crop_jpeg(image_bytes, window, channels=3)
  resized = tf.image.resize(
      [decoded], [IMAGE_SIZE, IMAGE_SIZE],
      method=tf.image.ResizeMethod.BICUBIC)
  return resized[0]


def preprocess_image(image_bytes,
                     is_training=False,
                     augmentation=None,
                     use_bfloat16=False,
                     saturate_uint8=False,
                     scale_and_center=False,
                     use_default_augment=False):
  """Decodes, crops, optionally augments and casts one JPEG image.

  Args:
    image_bytes: scalar string `Tensor` of JPEG-encoded image data (decoded
      with the `tf.image.*_jpeg` ops).
    is_training: if True, take a random crop; otherwise take a center crop.
    augmentation: optional callable that maps the image `Tensor` to a dict of
      tensors.
    use_bfloat16: if True, cast image entries to bfloat16 instead of float32.
    saturate_uint8: if True, saturate-cast the resized image to uint8 before
      augmentation.
    scale_and_center: if True, rescale image entries from [0, 255] to [-1, 1]
      after augmentation.
    use_default_augment: if True and `is_training`, apply the default
      augmentation (random left-right flip) before `augmentation`.

  Returns:
    A dict of `Tensor`s: `{'image': image}` when `augmentation` is None,
    otherwise the dict returned by `augmentation`. Every entry whose key ends
    with 'image' is cast to the requested dtype (and rescaled if
    `scale_and_center`); other entries are left unchanged.
  """
  if is_training:
    image = _decode_and_random_crop(image_bytes)
  else:
    image = _decode_and_center_crop(image_bytes)
  # Give the image a fully static shape for downstream batching.
  image = tf.reshape(image, [IMAGE_SIZE, IMAGE_SIZE, 3])
  # Both crop helpers end in a bicubic tf.image.resize, so `image` is float32
  # with values roughly in [0, 255] (bicubic interpolation may overshoot).
  if saturate_uint8:
    image = tf.saturate_cast(image, tf.uint8)
  # Do augmentations if necessary.
  if use_default_augment and is_training:
    image = tf.image.random_flip_left_right(image)
  if augmentation is not None:
    tensors_dict = augmentation(image)
  else:
    tensors_dict = {'image': image}

  dtype = tf.bfloat16 if use_bfloat16 else tf.float32

  def _cast_and_rescale(value):
    """Casts an image tensor to `dtype`, optionally mapping [0,255]->[-1,1]."""
    value = tf.cast(value, dtype)
    if scale_and_center:
      value = value / tf.constant(127.5, dtype) - tf.constant(1.0, dtype)
    return value

  # Cast and rescale all image tensors in place (same dict object returned).
  image_keys = [key for key in tensors_dict if key.endswith('image')]
  for key in image_keys:
    tensors_dict[key] = _cast_and_rescale(tensors_dict[key])
  return tensors_dict


def get_tfdata(tfds_name, split, is_training, external_data=None):
  """Reads TFRecord shards into a `tf.data` dataset of serialized records.

  Args:
    tfds_name: dataset prefix of the `PATTERN`/`SIZE` keys, e.g. 'clothing1m'.
    split: split suffix of the `PATTERN`/`SIZE` keys, e.g. 'trainclean'.
    is_training: if True, the shard file list is shuffled.
    external_data: optional file glob that replaces the `PATTERN` entry; its
      size is taken from `SIZE['<tfds_name>_trainnoisy']`.

  Returns:
    A `DSTuple(ds, size)`. `ds` yields unparsed serialized records, and `size`
    is the registered example count.

  Raises:
    ValueError: if the required `PATTERN`/`SIZE` key is not registered.
  """

  def _fetch_dataset_fn(filename):
    return tf.data.TFRecordDataset(filename, buffer_size=BUFFER_SIZE)

  if external_data:
    logging.info('Load from external data {}'.format(external_data))
    size_key = f'{tfds_name}_trainnoisy'
  else:
    size_key = '_'.join([tfds_name, split])
  # Only the registry lookups can fail here; translate them to ValueError
  # while keeping the original KeyError as the cause.
  try:
    path = external_data or PATTERN[size_key]
    nums = SIZE[size_key]
  except KeyError as err:
    raise ValueError(f'Split {split} not exists') from err

  # Shuffle the shard order during training for better randomization.
  dataset = tf.data.Dataset.list_files(path, shuffle=is_training)
  dataset = dataset.interleave(
      _fetch_dataset_fn,
      block_length=16,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  return DSTuple(ds=dataset, size=nums)


def load_external_meta_data(path, superclass, subclass_type):
  """Loads the confusion matrix under `path` and derives sub-class metadata.

  Every (prediction, label) cell of the `superclass` x `superclass` confusion
  matrix starts as its own id. `label_map_assignment` (defined elsewhere) then
  reassigns cell ids to final class ids.

  Args:
    path: directory that contains `confusion_matrix.json`.
    superclass: number of original classes (side of the confusion matrix).
    subclass_type: grouping strategy forwarded to `label_map_assignment`;
      'offdiag' and 'indiag' also enable per-class sanity checks.

  Returns:
    A `(class_weight, label_map)` pair of tensors. `class_weight[k]` is the
    fraction of the confusion-matrix count assigned to class `k` (the values
    sum to 1). `label_map` is the flattened row-major cell -> class-id map.

  Raises:
    AssertionError: if a sanity check on the class assignment fails.
  """
  # One id per cell, row-major: cell (r, c) -> r * superclass + c.
  label_map = np.arange(superclass**2).reshape((superclass, superclass))
  original_label_map = label_map.copy()
  # Swap row 0 with the diagonal so that diagonal cells get ids
  # [0, superclass) and off-diagonal cells keep ids in [superclass,
  # superclass**2).
  for i in range(superclass):
    label_map[0, i], label_map[i, i] = label_map[i, i], label_map[0, i]

  # NOTE(review): despite the .json extension, the file is parsed as CSV.
  csv_path = os.path.join(path, 'confusion_matrix.json')
  with tf.io.gfile.GFile(csv_path, 'r') as fid:
    # Rows are predictions and columns are labels, so confusion_matrix[:, 1]
    # holds the prediction counts of examples labelled 1. Blank lines, which
    # csv.reader yields as empty rows, are skipped to avoid a ragged array.
    rows = [row for row in csv.reader(fid) if row]
  confusion_matrix = np.array(rows).astype(np.int32)
  label_map = label_map_assignment(confusion_matrix, label_map, subclass_type)
  label_map = label_map.ravel()
  original_label_map = original_label_map.ravel()

  num_class = label_map.max() + 1
  half = num_class // 2
  class_weight = np.zeros((num_class,), dtype=np.float32)
  confusion_matrix = confusion_matrix.ravel()
  # Sum the confusion-matrix counts of all cells mapped to each class.
  for i in range(num_class):
    idx = original_label_map[label_map == i]
    weights = confusion_matrix[idx]
    class_weight[i] = np.sum(weights)
    # Sanity check: each class in the grouped half of the ids must collect
    # `num_class // 2 - 1` cells (presumably superclass - 1; confirm).
    if subclass_type == 'offdiag':
      if i >= half:
        assert len(weights) == half - 1
    elif subclass_type == 'indiag':
      if i < half:
        assert len(weights) == half - 1
  assert np.sum(class_weight) == np.sum(
      confusion_matrix), 'Check failed for class weights'
  # Normalize the counts into class frequencies (not inverse frequencies).
  class_weight = class_weight / np.sum(class_weight)
  return tf.convert_to_tensor(class_weight), tf.convert_to_tensor(label_map)


class WeakLabelDataset:
  """Base input pipeline for weakly labelled TFRecord datasets.

  Subclasses must set `image_key` and `label_key` (and `prob_key` when
  external data is used) before the parsers run; this base class never sets
  them.
  """

  def __init__(self,
               tfds_name,
               split,
               is_training,
               batch_size,
               num_classes,
               subclass_type='all',
               augmentation=None,
               use_bfloat16=False,
               saturate_uint8=False,
               scale_and_center=False,
               use_default_augment=False,
               external_data=None,
               return_class_weight=False,
               **kwargs):
    """Stores the pipeline configuration.

    Args:
      tfds_name: dataset prefix of the `SIZE`/`PATTERN` keys.
      split: split suffix of the `SIZE`/`PATTERN` keys.
      is_training: enables random crops, repeat, shuffle, dropping the last
        partial batch, and external data.
      batch_size: examples per batch.
      num_classes: number of original classes; replaced by the number of
        sub-classes when external data is loaded.
      subclass_type: forwarded to `load_external_meta_data`.
      augmentation: forwarded to `preprocess_image`.
      use_bfloat16: forwarded to `preprocess_image`.
      saturate_uint8: forwarded to `preprocess_image`.
      scale_and_center: forwarded to `preprocess_image`.
      use_default_augment: forwarded to `preprocess_image`.
      external_data: glob of externally labelled TFRecords; ignored unless
        `is_training`.
      return_class_weight: if True, external-data examples carry a 'weight'.
      **kwargs: ignored.
    """
    del kwargs  # Extra configuration keys are accepted and ignored.
    self.tfds_name = tfds_name
    self.split = split
    self.is_training = is_training
    self.use_bfloat16 = use_bfloat16
    self.scale_and_center = scale_and_center
    self.use_default_augment = use_default_augment
    self.batch_size = batch_size
    self.augmentation = augmentation
    self.num_classes = num_classes
    self.saturate_uint8 = saturate_uint8  # Forwarded to preprocess_image.
    self.additional_ds = []
    self.num_images = SIZE['_'.join([self.tfds_name, split])]
    self._num_images = self.num_images
    self.external_data = external_data
    self.return_class_weight = return_class_weight
    # External data is only used for training.
    if not self.is_training:
      self.external_data = None

    if self.external_data:
      logging.warning(f'Use external data {self.external_data}')
      # The metadata file lives in the directory of the external data glob.
      self.class_weight, self.label_map = load_external_meta_data(
          os.path.split(external_data)[0], self.num_classes, subclass_type)
      # Remember the original class count; num_classes becomes the number of
      # sub-classes produced by the label map.
      self.superclass = num_classes
      self.num_classes = self.class_weight.shape[0]
      logging.info(f'Classs weights \n{self.class_weight}')
      logging.info(f'Label map \n{self.label_map}')
      logging.info(f'Class size {self.num_classes}')

  def _preprocess_image_bytes(self, image_bytes):
    """Runs `preprocess_image` with this dataset's configuration."""
    return preprocess_image(
        image_bytes=image_bytes,
        is_training=self.is_training,
        augmentation=self.augmentation,
        use_bfloat16=self.use_bfloat16,
        saturate_uint8=self.saturate_uint8,
        # KerasResNet50_v2 uses [-1, 1] range.
        scale_and_center=self.scale_and_center,
        use_default_augment=self.use_default_augment)

  def dataset_parser_external_data(self, value):
    """Parses a serialized example that also carries prediction probabilities.

    The (argmax prediction, original label) pair is mapped through
    `self.label_map` to the sub-class label.
    """
    keys_to_features = {
        self.image_key: tf.io.FixedLenFeature([], tf.string),
        self.label_key: tf.io.FixedLenFeature([], tf.int64),
        self.prob_key: tf.io.FixedLenFeature([self.superclass], tf.float32)
    }

    parsed = tf.io.parse_single_example(value, keys_to_features)
    image_bytes = tf.reshape(parsed[self.image_key], shape=[])
    tensors_dict = self._preprocess_image_bytes(image_bytes)

    # Convert label + prob to the sub-class label.
    origin_label = tf.cast(
        tf.reshape(parsed[self.label_key], shape=()), dtype=tf.int32)
    prob = tf.cast(
        tf.reshape(parsed[self.prob_key], shape=(-1,)), dtype=tf.float32)
    pred = tf.argmax(prob, output_type=tf.int32)
    # Row-major (prediction, label) cell index; this has to be consistent
    # with the confusion matrix generated in the inference stage.
    new_label_index = pred * self.superclass + origin_label
    new_label = tf.gather(self.label_map, new_label_index)

    tensors_dict['label'] = new_label
    if self.return_class_weight:
      tensors_dict['weight'] = tf.gather(self.class_weight, new_label)
    return tensors_dict

  def dataset_parser(self, value):
    """Parses a serialized example into preprocessed image tensors + label."""
    keys_to_features = {
        self.image_key: tf.io.FixedLenFeature([], tf.string),
        self.label_key: tf.io.FixedLenFeature([], tf.int64),
    }

    parsed = tf.io.parse_single_example(value, keys_to_features)
    image_bytes = tf.reshape(parsed[self.image_key], shape=[])
    tensors_dict = self._preprocess_image_bytes(image_bytes)

    label = tf.cast(
        tf.reshape(parsed[self.label_key], shape=()), dtype=tf.int32)
    tensors_dict['label'] = label

    return tensors_dict

  @property
  def batch_shape(self):
    """Static shape of one image batch."""
    return [self.batch_size, IMAGE_SIZE, IMAGE_SIZE, 3]

  def concatenate(self, ds_tuple):
    """Queues `ds_tuple.ds` to be appended after the main dataset."""
    self.additional_ds.append(ds_tuple)
    self.num_images += ds_tuple.size

  def make_parsed_dataset(self, ctx=None):
    """Builds the parsed and batched dataset (before prefetching).

    Args:
      ctx: optional input context; presumably a `tf.distribute.InputContext`
        (only `num_input_pipelines` / `input_pipeline_id` are read).

    Returns:
      A batched `tf.data.Dataset` of parsed example dicts.
    """
    # get_tfdata shuffles the shard filenames when training.
    dataset = get_tfdata(self.tfds_name, self.split, self.is_training,
                         self.external_data)
    # NOTE(review): this overwrites the sizes added by `concatenate`, while
    # the loop below grows the private `_num_images` on every call instead.
    # Confirm which counter callers rely on before changing either.
    self.num_images = dataset.size
    dataset = dataset.ds

    logging.info(f'Create TrainInput {self.tfds_name} ({self.split})')
    logging.info(f'\t Labeled size {self.num_images}')

    for ads in self.additional_ds:
      dataset = dataset.concatenate(ads.ds)
      self._num_images += ads.size
      logging.info('{} concatenates extra {} data'.format(
          self.__class__.__name__, ads.size))

    if ctx and ctx.num_input_pipelines > 1:
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)

    if self.is_training:
      dataset = dataset.repeat()
      dataset = dataset.shuffle(SHUFFLE_BUFFER_SIZE)

    parser = (self.dataset_parser_external_data
              if self.external_data else self.dataset_parser)
    # Parse, pre-process, and batch the data in parallel. tf.data fuses
    # map + batch automatically; `experimental.map_and_batch` is deprecated.
    dataset = dataset.map(
        parser, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(self.batch_size, drop_remainder=self.is_training)

    return dataset

  def input_fn(self, ctx=None):
    """Input function which provides a single batch for train or eval.

    Args:
      ctx: Input context.

    Returns:
      A `tf.data.Dataset` object.
    """
    dataset = self.make_parsed_dataset(ctx)

    # Prefetch overlaps in-feed with training.
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    if self.is_training:
      # Allow out-of-order elements during training for throughput.
      options = tf.data.Options()
      options.experimental_deterministic = False
      dataset = dataset.with_options(options)

    return dataset


@gin.configurable
class Clothing1M(WeakLabelDataset):
  """Clothing1M: a 14-class `WeakLabelDataset` with Clothing1M feature keys."""

  def __init__(
      self,
      external_data=None,
      return_class_weight=False,
      subclass_type='all',  # how to construct sub pseudo labels
      **kwargs):
    super(Clothing1M, self).__init__(
        tfds_name='clothing1m',
        num_classes=14,
        subclass_type=subclass_type,
        external_data=external_data,
        return_class_weight=return_class_weight,
        **kwargs)
    # Feature keys read by the base-class parsers.
    self.image_key = 'image/encoded'
    self.label_key = 'image/label'
    # Probabilities are only parsed for external data; the base class has
    # already cleared `external_data` when not training.
    if self.external_data:
      self.prob_key = 'image/prob'
