import array
import gzip
import os
from os import path
import struct
import urllib.request

import numpy as np
import jax.numpy as jnp
import numpy.random as npr
import pickle

import matplotlib.pyplot as plt

# Cache directory for the downloaded MNIST archives (the JAX example data
# temp directory used by _download below).
_DATA = "/tmp/jax_example_data/"

# NOTE(review): import-time global side effect — with threshold=np.inf NumPy
# never summarizes arrays with "...", so every array printed anywhere in the
# process is shown in full. Confirm this is intended outside of debugging.
np.set_printoptions(threshold=np.inf)

def _download(url, filename):
  """Download a url to a file in the JAX data temp directory.

  Does nothing if ``_DATA/filename`` already exists. The payload is fetched
  into a sibling ``.part`` file and renamed into place only after the transfer
  completes. An interrupted download therefore cannot leave behind a truncated
  file that later runs would mistake for a complete cached copy.

  Args:
    url: source URL passed to ``urllib.request.urlretrieve``.
    filename: basename of the cached file inside ``_DATA``.
  """
  # exist_ok avoids the check-then-create race of a separate exists() test.
  os.makedirs(_DATA, exist_ok=True)
  out_file = path.join(_DATA, filename)
  if path.isfile(out_file):
    return
  partial_file = out_file + ".part"
  urllib.request.urlretrieve(url, partial_file)
  # os.replace overwrites any existing target and is atomic on POSIX when
  # both paths are on the same filesystem (they share a directory here).
  os.replace(partial_file, out_file)
  print("downloaded {} to {}".format(url, _DATA))

def _partial_flatten(x):
  """Flatten all but the first dimension of an ndarray.

  Returns ``x`` reshaped to ``(x.shape[0], prod(x.shape[1:]))``. The trailing
  size is computed explicitly instead of being inferred with ``-1``. NumPy
  cannot infer a ``-1`` dimension for a size-0 array, so an empty batch
  (``x.shape[0] == 0``) would otherwise raise ``ValueError``.
  """
  # prod of an empty shape tuple is 1, so a 1-D input becomes (n, 1),
  # matching the previous -1 behaviour.
  trailing = int(np.prod(x.shape[1:], dtype=np.int64))
  return np.reshape(x, (x.shape[0], trailing))

def _one_hot(x, k, dtype=np.float32):
  """Return the one-hot matrix of the 1-D label vector `x` over `k` classes.

  Entry (i, j) is 1 when x[i] == j and 0 otherwise, so labels outside
  [0, k) yield an all-zero row. The matrix is returned as `dtype`.
  """
  hits = x[:, np.newaxis] == np.arange(k)
  return hits.astype(dtype)

def mnist_raw():
  """Download (if not cached) and parse the raw MNIST dataset.

  Returns:
    A tuple ``(train_images, train_labels, test_images, test_labels)`` of
    writable ``uint8`` NumPy arrays. Image arrays have shape
    ``(num_data, rows, cols)`` as declared in each file's IDX header. Label
    arrays are 1-D with one entry per image.

  Raises:
    ValueError: if a file's IDX magic number is wrong, or if its payload size
      disagrees with its header (e.g. a corrupt or truncated cached archive).
  """
  # CVDF mirror of http://yann.lecun.com/exdb/mnist/
  base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

  # IDX magic numbers: 0x00000801 = unsigned-byte data with 1 dimension
  # (labels), 0x00000803 = unsigned-byte data with 3 dimensions (images).
  label_magic = 0x00000801
  image_magic = 0x00000803

  def parse_labels(filename):
    with gzip.open(filename, "rb") as fh:
      magic, num_data = struct.unpack(">II", fh.read(8))
      payload = fh.read()
    if magic != label_magic:
      raise ValueError(
          "{}: unexpected IDX label magic {:#010x}".format(filename, magic))
    if len(payload) != num_data:
      raise ValueError("{}: header declares {} labels but file holds {}".format(
          filename, num_data, len(payload)))
    # frombuffer views the immutable bytes; copy so callers get a writable array.
    return np.frombuffer(payload, dtype=np.uint8).copy()

  def parse_images(filename):
    with gzip.open(filename, "rb") as fh:
      magic, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
      payload = fh.read()
    if magic != image_magic:
      raise ValueError(
          "{}: unexpected IDX image magic {:#010x}".format(filename, magic))
    expected = num_data * rows * cols
    if len(payload) != expected:
      raise ValueError("{}: header declares {} bytes of pixels but file holds {}"
                       .format(filename, expected, len(payload)))
    pixels = np.frombuffer(payload, dtype=np.uint8).copy()
    return pixels.reshape(num_data, rows, cols)

  for filename in ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz",
                   "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]:
    _download(base_url + filename, filename)

  train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz"))
  train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz"))
  test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz"))
  test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz"))

  return train_images, train_labels, test_images, test_labels

def _compose_triplets(images, labels, r_scale):
  """Tile three disjoint thirds of a dataset side by side into 3-digit samples.

  With n = len(images) // 3, sample i of the result is images[i],
  images[n + i] and images[2n + i] joined along the column axis. Any
  trailing len(images) % 3 examples are dropped.

  Args:
    images: array indexed as (num_data, rows, cols).
    labels: integer digit labels in [0, 10), one per image.
    r_scale: multiplier applied to the 1000-way one-hot label block.

  Returns:
    composed_images: shape (n, rows, 3 * cols, 1).
    composed_labels: shape (n, 30 + 1000). Columns 0-29 hold the three 10-way
      one-hot digit labels (left, middle, right). Columns 30-1029 hold the
      one-hot of the number 100*left + 10*middle + right, scaled by r_scale.
  """
  n = images.shape[0] // 3
  thirds = [slice(0, n), slice(n, 2 * n), slice(2 * n, 3 * n)]

  # np.dstack joins (n, rows, cols) arrays along axis 2, i.e. side by side.
  composed_images = np.dstack([images[s] for s in thirds])
  composed_images = composed_images.reshape(composed_images.shape + (1,))

  # Cast before arithmetic so 100 * digit cannot overflow a uint8 input.
  digits = [np.asarray(labels[s], dtype=np.int64) for s in thirds]
  per_digit = np.hstack([_one_hot(d, 10) for d in digits])
  number = 100 * digits[0] + 10 * digits[1] + digits[2]
  number_one_hot = _one_hot(number, 1000) * r_scale
  return composed_images, np.hstack([per_digit, number_one_hot])


def composed_mnist(batch_size, r_scale, example_image_path='eg_im.png'):
  """Build a 3-digit "composed MNIST" dataset plus a shuffled batch stream.

  Pixels are scaled to [0, 1]. Each split is cut into three equal thirds,
  whose images are placed side by side to form one sample. Each label row
  concatenates the three 10-way one-hot digit labels with a 1000-way one-hot
  of the 3-digit number, and the 1000-way part is multiplied by `r_scale`.

  Args:
    batch_size: number of examples per batch yielded by the stream.
    r_scale: multiplier for the 1000-way one-hot label block.
    example_image_path: where to save the first composed training image as a
      grayscale PNG for visual inspection; pass None to skip writing it.

  Returns:
    (train_images, train_labels, test_images, test_labels, num_batches,
    data_stream). Images have shape (N, rows, 3 * cols, 1); labels have
    shape (N, 1030). data_stream is an infinite generator of
    (images, labels) training batches.
  """
  train_images, train_labels, test_images, test_labels = mnist_raw()

  # Scale uint8 pixel intensities to [0, 1].
  train_images = train_images / np.float32(255.)
  test_images = test_images / np.float32(255.)

  train_images, train_labels = _compose_triplets(
      train_images, train_labels, r_scale)
  test_images, test_labels = _compose_triplets(
      test_images, test_labels, r_scale)

  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, batch_size)
  # A final, smaller batch covers any leftover examples.
  num_batches = num_complete_batches + bool(leftover)

  def data_stream():
    """Yield (images, labels) training batches forever, reshuffled each epoch."""
    rng = npr.RandomState(0)  # fixed seed: the batch order is reproducible
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]

  if example_image_path is not None:
    plt.imsave(example_image_path, train_images[0, :, :, 0], cmap='gray')

  return train_images, train_labels, test_images, test_labels, num_batches, data_stream()
