import numpy as np
import torch




n = 500  # default number of rows of A and B
d = 7    # default number of columns of A and B (X is d x d)


def gen(n_rows=n, dim=d, rank=3, noise_std=0.001, rng=None):
    """Generate one noisy low-rank linear system ``B = A @ X + noise``.

    Draws a Gaussian matrix ``A`` of shape ``(n_rows, dim)`` and a square
    ``dim x dim`` matrix ``X = P @ Q`` built from Gaussian factors ``P``
    (``dim x rank``) and ``Q`` (``rank x dim``), so ``rank(X) <= rank``.
    Returns ``A`` and ``B = A @ X + noise_std * W`` where ``W`` is standard
    Gaussian. As a side effect, prints the nuclear norm of ``X``.

    Args:
        n_rows: Number of rows of ``A`` and ``B``. Defaults to module ``n``.
        dim: Number of columns of ``A`` and ``B`` and side of ``X``.
            Defaults to module ``d``.
        rank: Inner dimension of the factorization ``P @ Q``; upper bound
            on the rank of ``X``.
        noise_std: Multiplier applied to the Gaussian noise ``W``.
        rng: Object exposing NumPy's ``normal(loc, scale, size)``, e.g.
            ``np.random.default_rng(seed)`` for reproducible draws. ``None``
            uses the legacy global ``np.random`` state.

    Returns:
        Tuple ``(A, B)`` of float64 arrays, each of shape ``(n_rows, dim)``.
    """
    if rng is None:
        rng = np.random

    # Draw order (A, P, Q, W) is kept fixed so a given RNG state always
    # yields the same instance.
    A = rng.normal(0, 1, (n_rows, dim))
    P = rng.normal(0, 1, (dim, rank))
    Q = rng.normal(0, 1, (rank, dim))
    X = P @ Q

    W = rng.normal(0, 1, (n_rows, dim))

    B = A @ X + noise_std * W
    print(np.linalg.norm(X, 'nuc'))

    return A, B



# Number of independent instances in each split; also used in file names.
n_train = 270
n_test = 30

# Preallocate (count, n, d) stacks. torch.zeros uses torch's default dtype
# (float32 unless changed via torch.set_default_dtype), so the float64
# arrays returned by gen() are cast on assignment.
A_train = torch.zeros(n_train, n, d)
B_train = torch.zeros(n_train, n, d)
A_test = torch.zeros(n_test, n, d)
B_test = torch.zeros(n_test, n, d)

for i in range(n_train):
    print(i)  # progress indicator
    A, B = gen()
    A_train[i] = torch.from_numpy(A)
    B_train[i] = torch.from_numpy(B)

# Test instances are drawn with their own gen() calls, independent of the
# training instances.
for i in range(n_test):
    A, B = gen()
    A_test[i] = torch.from_numpy(A)
    B_test[i] = torch.from_numpy(B)

# Each file holds a [A_stack, B_stack] list.
torch.save([A_train, B_train], f"train{n_train}.dat")

torch.save([A_test, B_test], f"test{n_test}.dat")

