'''80m Tiny Images dataset.'''

import numpy as np
import torch.utils.data as data
import os
import os.path
from PIL import Image
import struct


def fetch_img(filename, start_id, nimages):
  '''
  fetch_img - Read a block of images from 80m Tiny Images dataset.
  Args:
    filename: Path to the tiny_images.bin file.
    start_id: The ID of the first image. (Starting from 0)
    nimages:  Number of images to read.
  Returns:
    A flat, read-only numpy.uint8 array holding nimages * 3072 values between
    0 and 255 (3072 consecutive bytes per image). It is shorter if the
    requested range runs past the end of the file, and empty if start_id is
    at or beyond the end.
  Raises:
    ValueError: if start_id or nimages is negative.
  Note:
    The original docstring stated 73,302,017 images; the dataset is widely
    cited as 79,302,017. NOTE(review): confirm against the file size
    (bytes / 3072).
  Metadata Format (tiny_metadata.bin; parsed by TinyImg, not by this function):
    1. keyword
    2. filename
    3. width
    4. height
    5. colors
    6. date
    7. engine
    8. thumb_url
    9. source_url
    10. page
    11. ind_page
    12. ind_engine
    13. ind_overall
    14. label (1 = correct, 0 = incorrect, -1 = unlabelled)
  '''
  if start_id < 0 or nimages < 0:
    # A negative size passed to f.read() means "read until EOF", which would
    # silently load the rest of the file instead of failing.
    raise ValueError(
        'start_id and nimages must be non-negative, got {} and {}'.format(
            start_id, nimages))
  bytes_per_img = 3072  # 32 x 32 pixels x 3 channels, one byte each
  with open(filename, 'rb') as f:
    f.seek(start_id * bytes_per_img)
    img_bytes = f.read(nimages * bytes_per_img)
  # frombuffer wraps the bytes object without copying, hence read-only.
  return np.frombuffer(img_bytes, dtype=np.uint8)


class TinyImg(data.Dataset):
  """80 million tiny images dataset.

  Loads a contiguous slice of 32x32 colour images from tiny_images.bin into
  memory and serves them as PIL images.

  Args:
    root(string):  Root directory of dataset containing the file tiny_images.bin
    nimages(int):  Number of images to load. Fewer are loaded if the slice
                   runs past the end of the file.
    start_id(int): Index of the first image (0-based).
    transform:     Optional callable applied to each PIL image (e.g. data
                   augmentation).
    metadata:      True if you need metadata. Make sure tiny_metadata.bin is
                   in root. The parsed records are stored in self.metadata;
                   when False, self.metadata is None.
  Raises:
    RuntimeError: if tiny_images.bin is not found in root.
  Note:
    Download url: <http://horatio.cs.nyu.edu/mit/tiny/data/index.html>
  """
  # Big-endian layout of one tiny_metadata.bin record. Field byte widths are
  # 80, 95, 2, 2, 1, 32, 10, 200, 328, 4, 4, 4, 4, 2 (768 bytes in total) for:
  # keyword, filename, width, height, colors, date, engine, thumb_url,
  # source_url, page, ind_page, ind_engine, ind_overall, label.
  _METADATA_FMT = '>80s95s2s2sc32s10s200s328s4s4s4s4s2s'

  def __init__(self, root, nimages=50000, start_id=0, transform=None, metadata=False):
    self.root = os.path.expanduser(root)
    self.transform = transform
    filename = os.path.join(self.root, 'tiny_images.bin')
    if not os.path.isfile(filename):
      raise RuntimeError('Bin file does not exist.')
    self.data = fetch_img(filename, start_id, nimages)
    # After this reshape the axes are (N, C, W, H); transposing with
    # (0, 3, 2, 1) yields (N, H, W, C), the HWC layout PIL expects.
    self.data = self.data.reshape(-1, 3, 32, 32)
    self.data = self.data.transpose((0, 3, 2, 1))  # Convert to HWC

    # Always define the attribute so callers can test it without
    # hitting AttributeError when metadata was not requested.
    self.metadata = None
    if metadata:
      # Derive the record size from the format so the two cannot drift.
      record_size = struct.calcsize(self._METADATA_FMT)  # 768 bytes
      md_filename = os.path.join(self.root, 'tiny_metadata.bin')
      with open(md_filename, 'rb') as f:
        f.seek(start_id * record_size)
        md_bytes = f.read(nimages * record_size)
      # One tuple of 14 raw bytes fields per record, aligned with self.data.
      self.metadata = list(struct.iter_unpack(self._METADATA_FMT, md_bytes))

  def __getitem__(self, index):
    """Return (image, 0) for the image at ``index``.

    The HWC uint8 array is converted to a PIL image and passed through
    ``transform`` when one is set. The target is always 0 because no class
    labels are loaded by this dataset.
    """
    img = Image.fromarray(self.data[index])
    if self.transform is not None:
      img = self.transform(img)
    return img, 0

  def __len__(self):
    """Return the number of images actually loaded."""
    return len(self.data)

  def __repr__(self):
    """Return a multi-line summary: size, root and transform."""
    fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
    fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
    fmt_str += '    Root Location: {}\n'.format(self.root)
    tmp = '    Transforms (if any): '
    # Indent continuation lines of a multi-line transform repr under the label.
    fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
    return fmt_str