import torch
from torch.utils.data import Dataset
from torch.nn import functional as F


def gen_sparse_linear_classification(seed, n_train, n_test, d, k, device):
    """Generate a synthetic k-sparse linear dataset over the Boolean hypercube.

    Inputs are drawn i.i.d. uniformly from {-1, +1}^d. A hidden weight vector
    w* has random +/-1 entries on its first ``k`` coordinates and zeros
    elsewhere, so each target is ``x[:k] @ w*[:k]``. The targets are the raw
    linear scores (integer-valued floats in [-k, k]); no sign/threshold is
    applied here.

    All randomness comes from a private ``torch.Generator`` seeded with
    ``seed``. The result is therefore reproducible, and the global RNG state
    (CPU and accelerator) is left untouched, even if sampling raises.

    Args:
        seed: Integer seed for the private generator.
        n_train: Number of training examples.
        n_test: Number of test examples.
        d: Input dimension.
        k: Number of relevant (non-zero weight) coordinates; must satisfy
            ``0 <= k <= d``.
        device: Device to create the tensors on (``None`` means torch's
            current default device).

    Returns:
        ``((train_inputs, train_targets), (test_inputs, test_targets))``.
        Inputs are float32 tensors of shape ``(n, d)`` with entries in
        {-1, +1}; targets are float32 tensors of shape ``(n,)``.

    Raises:
        ValueError: if ``k`` is not in ``[0, d]``.
    """
    if not 0 <= k <= d:
        raise ValueError(f"k must satisfy 0 <= k <= d, got k={k}, d={d}")

    # Resolve None / "cuda" (no index) to a concrete device so the generator
    # and the sampled tensors are guaranteed to live on the same device.
    device = torch.empty(0, device=device).device

    # Use a private generator instead of reseeding the global RNG:
    # torch.manual_seed also reseeds accelerator generators, and saving /
    # restoring only the CPU state would leave those clobbered.
    gen = torch.Generator(device=device)
    gen.manual_seed(seed)

    def _rademacher(shape):
        # iid uniform over {-1, +1}, drawn from the private generator.
        bits = torch.randint(
            0, 2, shape, generator=gen, dtype=torch.float32, device=device
        )
        return bits * 2 - 1

    ## z: iid ~ Unif({\pm1}^d)
    # Draw order (train, test, w) is kept fixed so a given seed always
    # yields the same dataset.
    train_inputs = _rademacher((n_train, d))
    test_inputs = _rademacher((n_test, d))

    ## w* is k-sparse: all but the first k elements are zero, so only those
    ## k non-zero (+/-1) entries are materialized.
    w = _rademacher((k,))

    train_targets = train_inputs[:, :k] @ w
    test_targets = test_inputs[:, :k] @ w

    return (train_inputs, train_targets), (test_inputs, test_targets)
    
