"""Code for creating the synthetic datasets used in the paper."""
from typing import Optional, Tuple

import numpy as np


def make_dataset(
    d: int,
    m: int,
    N: int,
    rng: Optional[np.random.Generator] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Create a synthetic regression dataset.

    The dataset is created by sampling examples from a standard Gaussian. A
    random one-hidden-layer ReLU network is then generated and used to label
    the examples:

        Y = relu(X @ W.T + b) @ v + c

    where ``W`` and ``b`` are standard-normal hidden weights and biases,
    ``v`` holds the signs (+1/-1) of standard-normal draws, and ``c`` is a
    single standard-normal scalar output bias.

    Args:
        d: The dimension of the examples.
        m: The number of hidden units in the randomly generated ReLU
            network used to label the examples.
        N: The number of examples to generate.
        rng: Optional source of randomness. Any object exposing a
            ``normal(size=...)`` method works (e.g. ``np.random.Generator``
            or the legacy ``np.random.RandomState``). When ``None`` (the
            default), the global ``np.random`` state is used, so results
            are still controlled by ``np.random.seed``.

    Returns:
        X: A float64 array with shape [N, d] containing the dataset examples.
        Y: A float64 array with shape [N] containing the labels.
    """
    # Fall back to the module-level functions so that callers relying on
    # ``np.random.seed`` keep getting exactly the same numbers as before.
    source = np.random if rng is None else rng

    # The sampling order (X, W, b, v, c) is part of the contract: changing it
    # would change the dataset produced for a given seed.
    X = source.normal(size=(N, d))
    W = source.normal(size=(m, d))
    b = source.normal(size=(m,))
    v = np.sign(source.normal(size=(m,)))
    c = source.normal(size=())

    preacts = X @ W.T + b          # [N, m] hidden pre-activations
    acts = np.maximum(preacts, 0)  # ReLU
    Y = acts @ v + c               # [N] labels
    return X, Y
