import sys

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Learning rate is read from the second command-line argument. It is parsed
# first so that a missing or malformed value fails before any data is loaded.
lr = float(sys.argv[2])

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the array, keep the first 20 entries along axis 1, and binarise it.
# sign() gives 0 for zeros and 1 for positive values, so 2*sign - 1 is -1 / +1.
# NOTE(review): assumes the stored values are non-negative; a negative value
# would map to -3 - confirm against the data file.
X = np.load('mnist_test_seq.npy')
X = X[:, :20, :, :]
X = torch.tensor(X).to(device)
X = 2 * X.float().sign() - 1

# Each 64x64 frame is flattened to one vector of this length.
input_size = 64 * 64

# Hidden-layer width; `M` and `hidden_size` name the same quantity.
hidden_size = 1000
M = hidden_size

# Randomly initialised weights: U maps input -> hidden, V maps hidden -> output.
U = torch.randn(M, input_size).to(device)
V = torch.randn(input_size, M).to(device)


# Online training of a two-layer sign network to predict frame t+1 from frame t.
# The forward pass uses hard sign() activations; the weight updates use the
# sigmoid derivative s * (1 - s) of the pre-activation as a surrogate gradient.
# (Earlier variants kept in this file used the tanh derivative
# 1 - tanh(a)**2, or no derivative factor at all.)
for e in range(2000):

    train_loss = 0.0

    # Only sequence 0 is used for training.
    for s in range(1):

        for t in range(19):

            # Current frame and next frame as column vectors (input_size, 1).
            x = X[t, s, :, :].reshape(input_size, 1)
            y = X[t + 1, s, :, :].reshape(input_size, 1)

            # Forward pass; pre-activations are kept so each matrix product
            # is computed once and reused by the update rules below.
            hidden_pre = U @ x
            h = torch.sign(hidden_pre)
            output_pre = V @ h
            z = torch.sign(output_pre)

            # Output layer: error times surrogate derivative, then an
            # outer-product update (column * row broadcasts to V's shape).
            output_sig = torch.sigmoid(output_pre)
            delta = (y - z) * (1 - output_sig) * output_sig
            V += lr * delta * h.T

            # Hidden layer: propagate the output delta back through V.
            # NOTE(review): V has already been updated in place at this point,
            # so the error is propagated through the new weights, not the
            # ones used in the forward pass - confirm this is intended.
            hidden_sig = torch.sigmoid(hidden_pre)
            delta = (V.T @ delta) * (1 - hidden_sig) * hidden_sig
            U += lr * delta * x.T

            # Sum of absolute differences between target and prediction.
            loss = torch.abs(y - z).sum()
            train_loss += loss.item()

    print(e, train_loss)


# Evaluate next-frame prediction over 20 sequences, processing all 19
# frame transitions of a sequence in one batched matrix product.
train_loss = 0.0
for s in range(20):

    # Frames are laid out as columns, shape (input_size, 19), so that
    # U (M x input_size) can multiply them exactly as it multiplies the
    # single-frame column vectors in the training loop. Without the
    # transpose, U @ x raises a shape-mismatch RuntimeError.
    x = X[:19, s, :, :].reshape(19, input_size).T
    y = X[1:, s, :, :].reshape(19, input_size).T

    h = torch.sign(U @ x)
    z = torch.sign(V @ h)

    # Sum of absolute differences between targets and predictions.
    loss = torch.abs(y - z).sum()

    train_loss += loss.item()

print(train_loss)

# Write the weight matrices to an HDF5 file. The context manager closes the
# file even if writing one of the datasets fails.
with h5py.File('mnist_tanh.h5', 'w') as f:
    # Tensors are moved to host memory before conversion to NumPy arrays.
    f.create_dataset('U', data=U.cpu().detach().numpy())
    f.create_dataset('V', data=V.cpu().detach().numpy())