import numpy as np
import torch
import pandas as pd



def _build_blocks(data, label, start, count, d):
    """Stack ``count`` normalized, label-weighted blocks of ``d`` samples.

    Block ``k`` is built from rows ``start + k*d`` up to (not including)
    ``start + (k+1)*d`` of ``data`` and ``label``:

    1. The row slice is transposed, so each column is one sample.
    2. It is rescaled so that its largest singular value becomes 100.
    3. Column ``j`` is multiplied by the label of sample ``j``
       (equivalent to right-multiplying by ``np.diag(y)``).
    4. A ``d x d`` identity is appended underneath.

    Args:
        data: 2-D array, one sample per row.
        label: 1-D array of per-sample labels aligned with ``data`` rows.
        start: first row of the first block.
        count: number of blocks to build.
        d: number of samples (rows) per block.

    Returns:
        A ``torch.zeros``-allocated tensor (default torch dtype) of shape
        ``(count, data.shape[1] + d, d)``.

    Raises:
        ValueError: if a block's largest singular value is 0 (an all-zero
            slice), which would otherwise silently produce NaNs.
    """
    n_features = data.shape[1]
    out = torch.zeros(count, n_features + d, d)
    eye = np.eye(d)
    for k in range(count):
        print(k)  # progress indicator, kept from the original script
        lo = start + k * d
        hi = lo + d
        block = data[lo:hi, :].T
        # Only the largest singular value is needed. compute_uv=False skips
        # building the (n_features x n_features) U matrix that a full SVD
        # would allocate.
        sigma_max = np.linalg.svd(block, compute_uv=False)[0]
        if sigma_max == 0:
            raise ValueError(f"rows {lo}:{hi} are all zero; cannot normalize")
        block = block / (sigma_max / 100)
        y = label[lo:hi]
        # Broadcasting over the last axis scales column j by y[j], which
        # equals block @ np.diag(y).
        out[k] = torch.from_numpy(np.vstack((block * y, eye)))
    return out


# Samples (one per row); the label file name suggests the Gisette training
# set -- NOTE(review): confirm number.npy rows align with the label file.
data = np.load("number.npy")

print(data.shape)

label = np.loadtxt("gisette_train.labels")

print(label.shape)

n = data.shape[1]  # features per sample (was hard-coded as 5000)
d = 30  # samples per block
n_train_blocks = 160
n_test_blocks = 40
# Test blocks start immediately after the training rows (row 4800 for d=30).
test_start = n_train_blocks * d

A_train = _build_blocks(data, label, 0, n_train_blocks, d)
A_test = _build_blocks(data, label, test_start, n_test_blocks, d)

torch.save(A_train, "train160.dat")

torch.save(A_test, "test40.dat")
