import numpy as np
import h5py
import torch
import torch.nn as nn
import torch.optim as optim 

# Load the dataset and keep only the first 20 entries along axis 1.
# The training loop indexes X[t, s] with t (and t+1) on axis 0 and s on axis 1
# and flattens the last two axes to 64*64 values
# (presumably frames x sequences x height x width -- confirm against the file).
X = np.load('mnist_test_seq.npy')[:, :20, :, :]

# Copy to the GPU, then convert to float32 there.
X = torch.tensor(X).cuda().float()

# Binarise: positive entries become +1, zeros become -1
# (a negative entry, if the data had any, would map to -3).
X = X.sign() * 2 - 1

# Layer sizes: each sample is flattened to input_size values by the training
# loop; M is the number of rows of U and P (and columns of V).
input_size = 64 * 64
hidden_size = 1000
M = hidden_size

# Small random initial weights, drawn on the CPU and then moved to the GPU.
#   U: (M, input_size)  applied to the current sample x1
#   V: (input_size, M)  applied to the sign of U @ x1
#   P: (M, input_size)  applied to the next sample x2 to produce the targets
#                       for U; the training loop in this file never updates it
U = torch.randn(M, input_size).cuda() * 0.001
V = torch.randn(input_size, M).cuda() * 0.001
P = torch.randn(M, input_size).cuda() * 0.001

# Step sizes used for the U and V updates respectively.
r1 = 0.001
r2 = 0.001

# Margin thresholds: a row is updated only while (target * pre-activation)
# stays below k1 (for U) or k2 (for V).
k1 = 1
k2 = 1

# Number of passes over the training data.
epoches = 200

def H(x):
    """Return the element-wise Heaviside step of ``x``, with ``H(0) == 0``.

    Entries greater than zero map to 1; entries less than or equal to zero
    map to 0.

    The "value at zero" tensor is created from ``x`` itself, so it always has
    the same dtype and device as the input. This avoids building a CPU tensor
    and copying it to the GPU on every call, and lets the function work for
    CPU tensors and for dtypes other than float32.

    Args:
        x: Input tensor.

    Returns:
        Tensor of zeros and ones with the dtype and device of ``x``; its shape
        is ``x``'s shape broadcast against ``(1,)``.
    """
    # Shape (1,) keeps the same broadcasting as a one-element literal tensor.
    return torch.heaviside(x, x.new_zeros(1))

# Summed squared prediction error of each epoch (one entry per epoch).
train_loss_epoch = np.zeros(epoches)

for e in range(epoches):

    # Starts as a Python float and becomes a tensor after the first `+=`.
    train_loss = 0.0

    for s in range(20):

        for t in range(19):

            # Entries t and t+1 along axis 0 for index s on axis 1, flattened
            # to column vectors (presumably consecutive frames of sequence s
            # -- confirm against the dataset layout).
            x1 = X[t, s, :, :].reshape(input_size, 1)
            x2 = X[t + 1, s, :, :].reshape(input_size, 1)

            # Target code for the hidden layer: sign of the projection of x2
            # through P.
            z = torch.sign(P @ x2)

            # u is 1 for the rows where z * (U @ x1) is below the margin k1,
            # so only those rows of U are moved towards their target sign.
            u = H(k1 - z * (U @ x1))
            U += r1 * (u * z) @ x1.T

            # Hidden code produced by the freshly updated U.
            y = torch.sign(U @ x1)

            # Same margin rule for V, with x2 itself as the target.
            v = H(k2 - x2 * (V @ y))
            V += r2 * (v * x2) @ y.T

            # U is unchanged since y was computed, so y is reused here instead
            # of evaluating sign(U @ x1) a second time.
            train_loss += torch.pow(x2 - torch.sign(V @ y), 2).sum()

    # Convert to a Python float once per epoch.
    epoch_loss = train_loss.item()
    train_loss_epoch[e] = epoch_loss

    print(e, epoch_loss)


# Save the trained weights and the per-epoch loss curve. The context manager
# closes the HDF5 file even if one of the writes raises.
with h5py.File('mnist_bin.h5', 'w') as f:
    f.create_dataset('U', data=U.cpu().detach().numpy())
    f.create_dataset('V', data=V.cpu().detach().numpy())
    f.create_dataset('e', data=train_loss_epoch)