"""Data generation methods"""

import numpy as np
import tensorflow as tf

def generate_batch_from_dataset(k, batch_size, dataset):
  """Draws a random batch of length-k prompts from a pre-generated dataset.

  A random subset of k in-context examples (in random order, shared by every
  prompt in the batch) is kept, followed by the dataset's final query
  position. `batch_size` distinct prompts are then sampled without
  replacement.

  Args:
    k: Number of in-context examples to keep. Must satisfy
      0 <= k <= k_max, where k_max = dataset[1].shape[1] - 1.
    batch_size: Number of prompts to sample; must not exceed
      len(dataset[0]).
    dataset: Pair (tokens, targets) of arrays. As indexed here, tokens is
      addressed as [prompt, position, feature], with example i occupying
      positions 2*i and 2*i + 1 and the query at position 2*k_max; targets
      is addressed as [prompt, example].

  Returns:
    A tuple (tokens_batch, targets_batch). tokens_batch holds 2*k + 1
    positions per prompt and targets_batch holds k + 1 entries per prompt;
    the last position/entry is always the dataset's final one.

  Raises:
    ValueError: If k lies outside [0, k_max]. np.random.choice also raises
      ValueError when batch_size exceeds the number of prompts.
  """
  tokens, targets = dataset[0], dataset[1]
  k_max = targets.shape[1] - 1
  # Without this check an oversized k is silently truncated by the slice
  # below (and a negative k slices from the end), yielding a prompt of the
  # wrong length.
  if not 0 <= k <= k_max:
    raise ValueError(
        'k must be in [0, {}] for this dataset, got {}'.format(k_max, k)
    )

  example_ids = np.arange(k_max)
  np.random.shuffle(example_ids)
  example_ids = example_ids[:k]

  # Example i lives at token positions (2*i, 2*i + 1); keep each pair
  # adjacent, then append the final query position.
  token_ids = np.stack((2 * example_ids, 2 * example_ids + 1), axis=1).ravel()
  token_ids = np.append(token_ids, [2 * k_max])
  target_ids = np.append(example_ids, [k_max])

  rows = np.random.choice(len(tokens), batch_size, replace=False)
  # Pick the sampled prompts first so the intermediate copy is batch-sized
  # rather than dataset-sized; the result is identical either way.
  return tokens[rows][:, token_ids, :], targets[rows][:, target_ids]

def generate_batch(
    d,
    k,
    batch_size,
    input_size,
    noise_level,
    return_reg=False,
    return_noiseless=False,
    centers=None,
):
  """Samples a fresh batch of noisy linear-regression prompts.

  For each of the batch_size prompts, k + 1 Gaussian inputs of width
  input_size are drawn and every feature from index d onward is zeroed.
  Targets are the inner product of each input with a per-prompt weight
  vector, plus Gaussian noise scaled by sqrt(noise_level). The prompt
  sequence interleaves inputs (even positions) with the noisy targets
  (odd positions, stored in feature 0) and ends on the final input.

  Args:
    d: Number of leading input features left non-zero.
    k: Number of in-context examples; each prompt has 2*k + 1 positions.
    batch_size: Number of prompts to generate.
    input_size: Width of each input / prompt position.
    noise_level: Noise variance; the noise is scaled by its square root.
    return_reg: If true, return the 5-tuple described below instead of the
      default pair.
    return_noiseless: Only consulted when return_reg is true; selects the
      noiseless target for the final input.
    centers: Optional array whose rows are candidate weight vectors. When
      given, each prompt's weights are a row picked uniformly at random
      (with replacement) instead of a Gaussian draw.

  Returns:
    By default, (sequence, noisy_targets) as float32 NumPy arrays, with
    all k + 1 noisy targets per prompt.
    With return_reg, a tuple of (context inputs, noisy context targets,
    final input, final target as a float32 tensor, sequence as a float32
    tensor); the final target is noiseless iff return_noiseless is true.
  """
  inputs = np.random.randn(batch_size, k + 1, input_size)
  # Drawn unconditionally (and discarded when `centers` is given) so the
  # stream of random numbers consumed does not depend on `centers`.
  weights = np.random.randn(batch_size, input_size)
  if centers is not None:
    picks = np.random.randint(np.shape(centers)[0], size=batch_size)
    weights = centers[picks.tolist(), :]
  inputs[:, :, d:] = 0.0

  clean_targets = np.einsum('mij,mj->mi', inputs, weights)
  noise_scale = np.sqrt(noise_level)
  noisy_targets = clean_targets + noise_scale * np.random.randn(
      *clean_targets.shape
  )

  # Even positions carry inputs; odd positions carry the noisy target of the
  # preceding input in feature 0 (the last input has no target slot).
  sequence = np.zeros((batch_size, 2 * k + 1, input_size))
  sequence[:, 0::2, :] = inputs
  sequence[:, 1::2, 0] = noisy_targets[:, :-1]

  if not return_reg:
    return (
        sequence.astype(np.float32),
        noisy_targets.astype(np.float32),
    )

  final_targets = clean_targets if return_noiseless else noisy_targets
  return (
      inputs[:, :-1, :],
      noisy_targets[:, :-1],
      inputs[:, -1, :],
      tf.convert_to_tensor(final_targets[:, -1], dtype=tf.float32),
      tf.convert_to_tensor(sequence, dtype=tf.float32),
  )