"""Dataset utilities file for large LR experiments on Cloud."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from jax import random
import numpy as np
import jax.numpy as jnp
from jax import lax
import tensorflow_datasets as tfds
import time

def flatten(x):
  """Collapse every axis after the leading one into a single axis.

  Args:
    x: array with at least one dimension.

  Returns:
    Array of shape (x.shape[0], prod(x.shape[1:])).
  """
  leading = x.shape[0]
  target_shape = (leading, -1)
  return np.reshape(x, target_shape)


def _one_hot(x, k, dtype=np.float64):
  """Return a one-hot matrix of width k for a 1-D array of integer labels.

  Labels outside [0, k) produce an all-zero row.

  Args:
    x: 1-D array of integer labels.
    k: number of classes (columns).
    dtype: dtype of the returned NumPy array.
  """
  class_ids = np.arange(k)
  # Broadcast (n, 1) against (k,) -> (n, k) boolean matches.
  matches = x[:, np.newaxis] == class_ids
  return np.array(matches, dtype)


def _standardize(x):
  """Standardize each sample to zero mean and unit variance.

  Mean and standard deviation are computed per sample over every axis
  except the leading (sample) axis.

  Samples with zero variance (e.g. a constant image) are only
  mean-centred. Dividing them by their zero standard deviation would
  turn the whole sample into NaNs.

  Args:
    x: array whose leading axis indexes samples.

  Returns:
    Floating-point array with the same shape as x.
  """
  axes = tuple(range(1, len(x.shape)))
  mean = np.mean(x, axis=axes, keepdims=True)
  std_dev = np.std(x, axis=axes, keepdims=True)
  # Avoid 0/0 -> NaN for constant samples; (x - mean) is already 0 there.
  std_dev = np.where(std_dev == 0, 1.0, std_dev)
  return (x - mean) / std_dev


def _random_crop(x, pixels, rng):
  """Randomly crop every image after reflect-padding it by `pixels` per side.

  Args:
    x: batch of square images, shape [batch, img_dim, img_dim, channels].
    pixels: padding added to each spatial side. The crop offset along each
      spatial axis is drawn uniformly from [0, 2 * pixels], inclusive.
    rng: JAX PRNG key.

  Returns:
    NumPy array with the same shape as x.
  """
  assert x.shape[1] == x.shape[2]
  img_size = x.shape[1]
  zero = (0, 0)
  pixpad = (pixels, pixels)
  padded = np.pad(x, (zero, pixpad, pixpad, zero), 'reflect')
  # jax.random.randint's maxval is exclusive. The +1 makes the right/bottom-most
  # crop (offset 2 * pixels) reachable. Convert to NumPy once instead of
  # pulling JAX scalars per element.
  corners = np.asarray(
      random.randint(rng, (x.shape[0], 2), 0, 2 * pixels + 1))
  crops = [img[row:row + img_size, col:col + img_size]
           for img, (row, col) in zip(padded, corners)]
  if not crops:
    # Empty batch: np.stack would fail on an empty list.
    return padded[:, :img_size, :img_size]
  return np.stack(crops)


def _random_horizontal_flip(x, prob, rng):
  """Mirror each image along axis 2 independently with probability prob.

  Args:
    x: 4-D batch; axis 2 is reversed for the flipped examples.
    prob: flip probability, compared against a uniform [0, 1) draw.
    rng: JAX PRNG key.
  """
  draws = random.uniform(rng, shape=(len(x), 1, 1, 1))
  mirrored = x[:, :, ::-1, :]
  return np.where(draws < prob, mirrored, x)


def _keep_first_classes(x, y, num_classes):
  """Keep only the samples of (x, y) whose integer label is < num_classes."""
  keep = y < num_classes
  return x[keep], y[keep]


def load_and_process_dataset(string_name, output_dim, n_train=None,
                             n_test=None, permute_train=False, data_dir=None,
                             random_labels=False):
  """Loads a TFDS dataset and returns standardized images and one-hot labels.

  Args:
    string_name: one of 'cifar10', 'fashion_mnist', 'mnist'.
    output_dim: number of classes, i.e. the width of the one-hot labels.
      Samples whose label is >= output_dim are dropped from both splits.
    n_train, n_test: optional integer sizes. The train/test sets are
      truncated to these sizes after the optional permutation. None (or 0)
      keeps the full split.
    permute_train: if True, permutes the train set with a fixed seed (0)
      before truncation.
    data_dir: forwarded to `tfds.load`.
    random_labels: if True, replaces every train and test label with one
      drawn uniformly from range(output_dim) using NumPy's global RNG.

  Returns:
    ((train_x, train_y), (test_x, test_y)). Images are standardized per
    sample; labels are one-hot arrays of width output_dim.

  Raises:
    ValueError: if string_name is not a supported dataset.
  """
  if string_name not in ['cifar10', 'mnist', 'fashion_mnist']:
    raise ValueError("Invalid dataset name.")
  ds_train, ds_test = tfds.as_numpy(tfds.load(
      string_name, data_dir=data_dir, split=["train", "test"], batch_size=-1,
      as_dataset_kwargs={"shuffle_files": False}))
  train_x, train_y = ds_train["image"], ds_train["label"]
  test_x, test_y = ds_test["image"], ds_test["label"]

  # A label >= output_dim has no one-hot slot and would become an all-zero
  # row, so filter whenever any label reaches output_dim.
  if max(train_y.max(), test_y.max()) >= output_dim:
    print('Picking only',output_dim)
    train_x, train_y = _keep_first_classes(train_x, train_y, output_dim)
    test_x, test_y = _keep_first_classes(test_x, test_y, output_dim)

  train_x, test_x = _standardize(train_x), _standardize(test_x)
  if random_labels:
    train_y = np.random.choice(output_dim, size=train_y.shape[0])
    test_y = np.random.choice(output_dim, size=test_y.shape[0])
  train_y, test_y = _one_hot(train_y, output_dim), _one_hot(test_y, output_dim)
  if permute_train:
    # Fixed seed so the permutation is reproducible across runs.
    perm = np.random.RandomState(0).permutation(train_x.shape[0])
    train_x = train_x[perm]
    train_y = train_y[perm]
  if n_train:
    train_x = train_x[:n_train]
    train_y = train_y[:n_train]
  if n_test:
    test_x = test_x[:n_test]
    test_y = test_y[:n_test]

  return (train_x, train_y), (test_x, test_y)


def batch(data_tuple, batch_size):
  """Yield consecutive (x, y) slices of at most batch_size rows, in order.

  The last batch may be smaller than batch_size. No shuffling is done.
  """
  inputs, targets = data_tuple
  for lo in range(0, inputs.shape[0], batch_size):
    hi = lo + batch_size
    yield (inputs[lo:hi], targets[lo:hi])


def crop(key, image_and_label):
  """Randomly crop one zero-padded image back to its original size.

  The first two axes are zero-padded by 4 pixels per side. The last axis
  is left untouched. A window of the original shape is then taken at a
  random offset drawn uniformly from [0, 8] along each padded axis.

  This is the per-example function. The module-level name `crop` is
  rebound below to its `vmap`-ed version, so callers pass a batch of keys
  and a batch of (image, label) pairs.

  Args:
    key: PRNG key for this example.
    image_and_label: (image, label) with a rank-3 image. The label is
      returned unchanged.

  Returns:
    (cropped_image, label) with cropped_image.shape == image.shape.
  """
  image, label = image_and_label

  pixels = 4
  pixpad = (pixels, pixels)
  zero = (0, 0)
  # constant_values must be passed by keyword: jnp.pad (like np.pad) takes
  # extra options as **kwargs.
  padded_image = jnp.pad(image, (pixpad, pixpad, zero), 'constant',
                         constant_values=0.0)
  # randint's maxval is exclusive. The +1 makes the right/bottom-most crop
  # reachable, so shifts are symmetric in [-pixels, +pixels].
  corner = random.randint(key, (2,), 0, 2 * pixels + 1)
  corner = jnp.concatenate((corner, jnp.zeros((1,), corner.dtype)))
  # image.shape is static under vmap, so it is a valid slice size.
  cropped_image = lax.dynamic_slice(padded_image, corner, image.shape)

  return cropped_image, label

from jax import vmap

crop = vmap(crop, 0, 0)


def mixup(key, alpha, image_and_label):
  """Blend each example with its mirror in the batch (mixup).

  Example i is combined with example N-1-i using a per-example weight
  drawn from Beta(alpha, alpha). Images and labels share the same weight.

  Args:
    key: JAX PRNG key for the Beta draws.
    alpha: Beta distribution concentration parameter.
    image_and_label: (images, labels). Images are 4-D (the weights are
      reshaped to (N, 1, 1, 1)). Labels presumably are (N, num_classes),
      e.g. one-hot; confirm against callers.

  Returns:
    (mixed_images, mixed_labels).
  """
  images, labels = image_and_label
  batch_size = images.shape[0]

  lam = random.beta(key, alpha, alpha, (batch_size, 1))
  mixed_labels = lam * labels + (1.0 - lam) * labels[::-1]

  lam_img = jnp.reshape(lam, (batch_size, 1, 1, 1))
  mixed_images = lam_img * images + (1.0 - lam_img) * images[::-1]

  return mixed_images, mixed_labels


def transform(key, image_and_label):
  """Augment one batch: random horizontal flip, random crop, then mixup.

  Args:
    key: JAX PRNG key; split so the flip, crop and mixup stages each get
      their own randomness.
    image_and_label: (images, labels). Each example must contain
      32 * 32 * 3 values, since the batch is reshaped to (N, 32, 32, 3).

  Returns:
    The (images, labels) tuple produced by `mixup` with alpha = 1.0.
  """
  images, labels = image_and_label
  batch_size = images.shape[0]
  images = jnp.reshape(images, (batch_size, 32, 32, 3))

  key, flip_key = random.split(key)
  # Per-example coin flip; reverses axis 2 (width, assuming NHWC layout).
  flip_mask = random.uniform(flip_key, (batch_size, 1, 1, 1)) < 0.5
  images = jnp.where(flip_mask, images[:, :, ::-1], images)

  key, crop_key = random.split(key)
  per_example_keys = random.split(crop_key, batch_size)
  images, labels = crop(per_example_keys, (images, labels))

  # The remaining key drives the mixup weights.
  return mixup(key, 1.0, (images, labels))


def minibatcher(data_tuple, batch_size, seed=None, augment=False):
  """Yield (x, y) minibatches forever, reshuffling between epochs.

  Args:
    data_tuple: (images, labels) tuple of arrays sharing the leading axis.
    batch_size: rows per batch. Choose it to divide the data size: the
      trailing partial batch of every epoch is dropped.
    seed: integer seed for the JAX PRNG used for reshuffling and
      augmentation. If None, a seed is derived from the wall clock.
    augment: if True, each batch is passed through `transform`. Adjust the
      augmentation parameters there if needed.

  Yields:
    (x_batch, y_batch) tuples.

  Raises:
    ValueError: (on the first `next`) if batch_size is not in
      [1, number of samples]. Such a batch could never be filled, and the
      loop would otherwise spin forever reshuffling.

  Notes:
    With fewer than 32 samples, the whole dataset is yielded on every step,
    without shuffling or augmentation and regardless of batch_size.
    The first epoch is served in the given order. Shuffling starts from the
    second epoch.
  """
  x, y = data_tuple
  num_samples = x.shape[0]
  if num_samples < 32:
    while True:
      yield (x, y)

  if batch_size <= 0 or batch_size > num_samples:
    raise ValueError('batch_size must be in [1, %d], got %r'
                     % (num_samples, batch_size))

  if seed is None:
    # time.process_time() is CPU time used so far (close to 0 at start-up),
    # so int() of it barely varies between runs. Use wall-clock
    # nanoseconds instead, folded into a safe PRNG seed range.
    seed = time.time_ns() % (2 ** 31)
  key = random.PRNGKey(seed)
  start = 0

  while True:
    end = start + batch_size
    if end > num_samples:
      # Epoch finished (any remainder is dropped): reshuffle and restart.
      key, split = random.split(key)
      # random.shuffle is deprecated in JAX; permutation(key, n) returns a
      # shuffled arange(n).
      permutation = np.asarray(random.permutation(split, num_samples))
      x = x[permutation]
      y = y[permutation]
      start = 0
      continue
    x_batch, y_batch = x[start:end], y[start:end]
    if augment:
      key, subkey = random.split(key)
      x_batch, y_batch = transform(subkey, (x_batch, y_batch))
    yield (x_batch, y_batch)
    start = end






