"""Utilities for file download and caching."""
import csv
import logging
import logging.handlers
import multiprocessing
import multiprocessing.pool
import os
import shutil
import sys
import tempfile
import traceback

from typing import Dict
from typing import List
from typing import Set

from utils import get_file

FED_GLD_SPLIT_FILE_BUNDLE = 'landmarks-user-160k'
FED_GLD_SPLIT_FILE_DOWNLOAD_URL = 'http://storage.googleapis.com/gresearch/federated-vision-datasets/%s.zip' % FED_GLD_SPLIT_FILE_BUNDLE
FED_GLD_SPLIT_FILE_BUNDLE_MD5_CHECKSUM = '53c36bd7d5fc12f927af2820b7e4a57c'
FED_GLD_TRAIN_SPLIT_FILE = 'federated_train.csv'
FED_GLD_TEST_SPLIT_FILE = 'test.csv'
GLD_SHARD_BASE_URL = 'https://s3.amazonaws.com/google-landmark'
NUM_SHARD_TRAIN = 500
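# Each training shard i is fetched from
# '%s/train/images_%03d.tar' % (GLD_SHARD_BASE_URL, i); e.g. shard 0 comes
# from https://s3.amazonaws.com/google-landmark/train/images_000.tar.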
MINI_GLD_TRAIN_DOWNLOAD_URL = 'https://storage.googleapis.com/tff-datasets-public/mini_gld_train_split.csv'
MINI_GLD_TRAIN_SPLIT_FILE = 'mini_gld_train_split.csv'
MINI_GLD_TEST_DOWNLOAD_URL = 'https://storage.googleapis.com/tff-datasets-public/mini_gld_test.csv'
MINI_GLD_TEST_SPLIT_FILE = 'mini_gld_test.csv'
MINI_GLD_TRAIN_SPLIT_FILE_MD5_CHECKSUM = '9fd62cf79a67046fdd673d3a0ac52841'
MINI_GLD_TEST_SPLIT_FILE_MD5_CHECKSUM = '298e9d19d66357236f66fe8e22920933'
FED_GLD_CACHE = 'gld160k'
MINI_GLD_CACHE = 'gld23k'
TRAIN_SUB_DIR = 'train'
TEST_FILE_NAME = 'test.tfRecord'
LOGGER = 'gldv2'
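# The KEY_* constants below are presumably feature keys for serializing
# examples (e.g. tf.train.Example records) downstream; they are unused in
# this file.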
KEY_IMAGE_BYTES = 'image/encoded_jpeg'
KEY_IMAGE_DECODED = 'image/decoded'
KEY_CLASS = 'class'

def _listener_process(queue: multiprocessing.Queue, log_file: str):
  """Handles logging messages in a separate listener process.

  This setup is required because, without it, logging messages are duplicated
  when multiple processes are created for downloading the GLD dataset.

  Args:
    queue: The queue from which to receive logging messages.
    log_file: The file to which the messages will be written.
  """
  root = logging.getLogger()
  h = logging.FileHandler(log_file)
  fmt = logging.Formatter(
      fmt='%(asctime)s %(levelname)-8s %(message)s',
      datefmt='%Y-%m-%d %H:%M:%S')
  h.setFormatter(fmt)
  root.addHandler(h)
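  # Consume records until the sentinel (None) arrives; each record is
  # re-emitted through the file handler attached to the root logger above.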
  while True:
    try:
      record = queue.get()
      # None is sent as a sentinel to stop the listener.
      if record is None:
        break
      logger = logging.getLogger(record.name)
      logger.handle(record)
    except Exception:  # pylint: disable=broad-except
      print('Something went wrong:', file=sys.stderr)
      traceback.print_exc(file=sys.stderr)


def _read_csv(path: str) -> List[Dict[str, str]]:
  """Reads a csv file and returns its content as a list of dictionaries.

  Args:
    path: The path to the csv file.

  Returns:
    A list of dictionaries, one per row in the csv file, keyed by the column
    names.
  """
  with open(path, 'r') as f:
    return list(csv.DictReader(f))


def _filter_images(shard: int, all_images: Set[str], image_dir: str,
                   base_url: str):
  """Downloads one shard of the GLDv2 dataset and filters it.

  Only images that are included in the federated GLDv2 dataset are kept; they
  are copied to image_dir and the rest of the shard is discarded.

  Args:
    shard: The shard index of the GLDv2 dataset.
    all_images: A set containing the ids of all images included in the
      federated GLD dataset.
    image_dir: The directory in which to keep the filtered images.
    base_url: The base url for downloading GLDv2 dataset images.
  """
  shard_str = '%03d' % shard
  images_tar_url = '%s/train/images_%s.tar' % (base_url, shard_str)
  with tempfile.TemporaryDirectory() as tmp_dir:
    logger = logging.getLogger(LOGGER)
    logger.info('Start to download data for shard %s', shard_str)
    get_file(
        'images_%s.tar' % shard_str,
        origin=images_tar_url,
        extract=True,
        cache_dir=tmp_dir)
    logger.info('Data for shard %s was downloaded successfully.', shard_str)
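    # Walk the extracted shard and copy only images in the federated split;
    # the TemporaryDirectory (holding the full shard) is deleted on exit.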
    count = 0
    for root, _, files in os.walk(tmp_dir):
      for filename in files:
        name, extension = os.path.splitext(filename)
        if extension == '.jpg' and name in all_images:
          count += 1
          shutil.copyfile(
              os.path.join(root, filename), os.path.join(image_dir, filename))
    logger.info('Copied %d images from shard %s to %s', count, shard_str,
                image_dir)


def _download_data(num_worker: int, cache_dir: str, base_url: str):
  """Downloads the images belonging to the federated GLDv2 dataset.

  Downloads the split files for the federated gld160k and mini gld23k
  datasets, then downloads the GLDv2 training shards and keeps only the
  images referenced by those splits.

  Args:
    num_worker: The number of threads for downloading the GLDv2 dataset.
    cache_dir: The directory for caching temporary results.
    base_url: The base url for downloading GLD images.
  """
  logger = logging.getLogger(LOGGER)
  logger.info('Start to download fed gldv2 mapping files')

  path = get_file(
      '%s.zip' % FED_GLD_SPLIT_FILE_BUNDLE,
      origin=FED_GLD_SPLIT_FILE_DOWNLOAD_URL,
      extract=True,
      archive_format='zip',
      cache_dir=cache_dir)

  get_file(
      MINI_GLD_TRAIN_SPLIT_FILE,
      origin=MINI_GLD_TRAIN_DOWNLOAD_URL,
      cache_dir=cache_dir)
  get_file(
      MINI_GLD_TEST_SPLIT_FILE,
      origin=MINI_GLD_TEST_DOWNLOAD_URL,
      cache_dir=cache_dir)

  logger.info('Fed gldv2 mapping files are downloaded successfully.')
  base_path = os.path.dirname(path)
  train_path = os.path.join(base_path, FED_GLD_SPLIT_FILE_BUNDLE,
                            FED_GLD_TRAIN_SPLIT_FILE)
  test_path = os.path.join(base_path, FED_GLD_SPLIT_FILE_BUNDLE,
                           FED_GLD_TEST_SPLIT_FILE)
  train_mapping = _read_csv(train_path)
  test_mapping = _read_csv(test_path)
  all_images = set()
  all_images.update([row['image_id'] for row in train_mapping],
                    [row['image_id'] for row in test_mapping])
  image_dir = os.path.join(cache_dir, 'images')
  os.makedirs(image_dir, exist_ok=True)
  logger.info('Start to download GLDv2 dataset.')
  with multiprocessing.pool.ThreadPool(num_worker) as pool:
    train_args = [
        (i, all_images, image_dir, base_url) for i in range(NUM_SHARD_TRAIN)
    ]
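    # Fan out shard downloads across the thread pool; each call downloads one
    # tar shard and copies its matching images into image_dir.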
    pool.starmap(_filter_images, train_args)

  logger.info('Finished downloading the GLDv2 dataset.')


def load_data(num_worker: int = 1,
              cache_dir: str = 'cache',
              gld23k: bool = False,
              base_url: str = GLD_SHARD_BASE_URL):
  """Downloads the federated GLDv2 dataset to cache_dir.

  Args:
    num_worker: The number of threads for downloading the GLDv2 shards.
    cache_dir: The directory for caching the downloaded files and the
      download log.
    gld23k: Unused in this implementation.
    base_url: The base url for downloading GLD images.
  """
  if not os.path.exists(cache_dir):
    os.mkdir(cache_dir)
  q = multiprocessing.Queue(-1)
  listener = multiprocessing.Process(
      target=_listener_process,
      args=(q, os.path.join(cache_dir, 'load_data.log')))
  listener.start()
  logger = logging.getLogger(LOGGER)
  qh = logging.handlers.QueueHandler(q)
  logger.addHandler(qh)
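  # Route records for the 'gldv2' logger through the queue so the listener
  # process writes them once, avoiding duplicated log lines from workers.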
  logger.info('Start to download the data.')

  _download_data(num_worker, cache_dir, base_url)

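  # Send the sentinel to stop the logging listener and wait for it to exit.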
  q.put_nowait(None)
  listener.join()


if __name__ == '__main__':
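  # Note: this downloads all NUM_SHARD_TRAIN (500) training shards, which is
  # a very large download; reduce NUM_SHARD_TRAIN for a quick smoke test.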
  load_data(4, 'cache')