import torch
import numpy as np
from grokking_experiments import addition_function

def generate_data(num_samples, seq_length, learn_function, relevant_coords=None):
    """Generate random binary inputs labelled by the given boolean function.

    Args:
        num_samples: Number of rows to draw.
        seq_length: Number of bits in each row.
        learn_function: Callable that maps one row of ``X`` to a label. If its
            ``__name__`` is ``"junta"`` it is called as
            ``learn_function(x, relevant_coords)``; otherwise it is called as
            ``learn_function(x)``.
        relevant_coords: Extra argument forwarded only to a ``junta`` function;
            ignored for every other function.

    Returns:
        ``(X, y)`` where ``X`` is an integer tensor of shape
        ``(num_samples, seq_length)`` with entries in {0, 1}, and ``y`` is a
        ``torch.long`` tensor with one label per row of ``X`` (assuming
        ``learn_function`` returns a scalar).
    """
    # Uniformly random bits in {0, 1}.
    X = torch.randint(0, 2, (num_samples, seq_length))

    # Dispatch on the function's name. Callables without a ``__name__``
    # (functools.partial objects, callable instances) are treated as
    # non-junta functions instead of raising AttributeError.
    if getattr(learn_function, "__name__", None) == "junta":
        labels = [learn_function(x, relevant_coords) for x in X]
    else:
        labels = [learn_function(x) for x in X]

    y = torch.tensor(labels).long()
    return X, y

def generate_data_addition(num_samples, K=113):
    """Sample random operand pairs together with their sum modulo ``K``.

    Args:
        num_samples: Number of ``(a, b)`` pairs to draw.
        K: Modulus. Both operands are drawn uniformly from ``[0, K - 1]``.

    Returns:
        ``(X, y)`` where ``X`` has shape ``(num_samples, 2)`` and holds the
        operand pairs, and ``y`` has shape ``(num_samples,)`` and holds
        ``(a + b) % K`` for each row, so every entry is in ``[0, K - 1]``.
    """
    operands = torch.randint(0, K, (num_samples, 2))
    # Summing across the pair axis is the same as operands[:, 0] + operands[:, 1].
    sums_mod_k = operands.sum(dim=1).remainder(K)
    return operands, sums_mod_k
