import torch
import torch.nn as nn
import math

def partition(D, L):
    """Randomly split the indices ``0 .. D-1`` into ``L`` equal-sized groups.

    Returns an ``(L, D // L)`` int64 tensor. Its rows are disjoint and together
    contain every index exactly once. ``D`` must be divisible by ``L``;
    otherwise the reshape raises a ``RuntimeError``.
    """
    shuffled = torch.randperm(D)
    return shuffled.reshape(L, -1)

def make_data(N, D, L, d, S, w):
    """Sample a synthetic dataset of ``N`` examples, each made of ``D`` patches.

    Each patch has dimension ``d``. The label ``y[i]`` is the sign of a
    standard normal draw.

    For every example, one of the first ``L`` index groups of ``S`` is chosen
    uniformly at random. Every patch whose index is in that group is
    ``y[i] * w + noise``.

    Every other patch is ``delta * w + noise``, where ``delta`` is:
      * ``+1`` with probability ``q / 4``,
      * ``-1`` with probability ``q / 4``,
      * ``0`` otherwise,
    and ``q = log(d) / D``. Noise is i.i.d. Gaussian with standard deviation
    ``1 / sqrt(d)``.

    Args:
        N: number of examples.
        D: number of patches per example.
        L: number of groups in ``S`` to draw the signal group from.
        d: patch dimension.
        S: indexable holding at least ``L`` groups of patch indices in
            ``[0, D)``, e.g. the ``(L, D // L)`` output of ``partition``.
        w: feature vector. It must broadcast to shape ``(d,)``; presumably it
            is a ``(d,)`` tensor (confirm with callers).

    Returns:
        ``(X, y, w, S)``:
          * ``X`` has shape ``(N, D, d)`` and the default float dtype.
          * ``y`` has shape ``(N,)``.
          * ``w`` and ``S`` are the caller's objects, returned unchanged.
    """
    y = torch.randn(N).sign()
    q = math.log(d) / D
    sigma = 1 / math.sqrt(d)
    noise = torch.randn(N, D, d) * sigma

    # (L, D) table: member[k, j] is True iff patch j belongs to group S[k].
    # It is built once, so the per-example signal mask becomes a single
    # gather. This replaces an O(|S[k]|) `j in R` scan for each of the
    # N * D patches.
    member = torch.zeros(L, D, dtype=torch.bool)
    for k in range(L):
        member[k, torch.as_tensor(S[k], dtype=torch.long)] = True
    groups = torch.randint(L, (N,))
    signal = member[groups]  # (N, D)

    # Same per-patch rule as the original loop. prob ~ U[-1, 1), so each of
    # +1 and -1 is drawn with probability q / 4.
    # NOTE(review): a previous (commented-out) vectorized variant drew each
    # sign with probability q / 2 (nonzero with probability q). Confirm which
    # rate is intended.
    # This draws randomness in a different order than the old per-element
    # loop, so seeded outputs differ from that version.
    prob = 2 * (torch.rand(N, D) - 0.5)
    plus = (prob > 0) & (prob < q / 2)
    minus = (prob < 0) & (prob > -q / 2)
    delta = plus.to(noise.dtype) - minus.to(noise.dtype)

    # Signal patches carry the label; all others carry their delta.
    coef = torch.where(signal, y.unsqueeze(1), delta)  # (N, D)
    feature = torch.as_tensor(w, dtype=noise.dtype, device=noise.device)
    X = coef.unsqueeze(-1) * feature + noise
    return X, y, w, S



