import sys
import numpy as np
import h5py
import torch
import torch.nn as nn
import torch.optim as optim 

# Load the frame sequences and keep only the first 20 entries along axis 1
# (the axis the training loop below indexes with `s`), then move them to the GPU.
# NOTE(review): presumably laid out as (frame, sequence, height, width), as the
# later `X[t, s, :, :]` indexing suggests -- confirm against the data file.
X = torch.tensor(np.load('mnist_test_seq.npy')[:, :20]).cuda()

# Map values through 2*sign(x) - 1: zero becomes -1 and positive values
# become +1. Assumes non-negative pixel intensities (a negative value would
# map to -3) -- TODO confirm.
X = 2 * torch.sign(X.float()) - 1

class Net(nn.Module):
    """Two-layer, bias-free network with a tanh after each linear map.

    Computes ``tanh(V(tanh(U(x))))``, where ``U`` projects from
    ``input_size`` to ``M`` features and ``V`` projects back to
    ``input_size``.

    Args:
        input_size: Number of features of the input (and of the output).
        M: Number of hidden features.
    """

    def __init__(self, input_size, M):
        super().__init__()
        # Neither layer has a bias term; U is registered before V.
        self.U = nn.Linear(in_features=input_size, out_features=M, bias=False)
        self.V = nn.Linear(in_features=M, out_features=input_size, bias=False)

    def forward(self, x):
        """Return the network output for ``x`` (last dim ``input_size``)."""
        hidden = self.U(x).tanh()
        return self.V(hidden).tanh()
    
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
input_size = 64*64   # length of the flattened vector each frame is reshaped to
hidden_size = 1000   # width of the hidden layer (the `M` argument of Net)
epoches = 2000       # number of full passes over the training pairs

# The learning rate is taken from the *second* command-line argument.
# Fail with a readable message instead of a bare IndexError when it is absent.
if len(sys.argv) < 3:
    raise SystemExit(
        'missing argument: the learning rate is read from sys.argv[2] '
        '(the second command-line argument)'
    )
lr = float(sys.argv[2])

# ---------------------------------------------------------------------------
# Model and optimizer
# ---------------------------------------------------------------------------
model = Net(input_size, hidden_size)
model = torch.compile(model)
model.cuda()
# NOTE(review): plain SGD is the active optimizer, although the output file
# written at the end of this script is named 'mnist_tanh_adam.h5'. Swap in
# optim.Adam(model.parameters(), lr=lr) here if Adam is what is intended.
optimizer = optim.SGD(model.parameters(), lr=lr)

# Summed training loss of every epoch, saved alongside the weights.
train_loss_epoch = np.zeros(epoches)

# Loop bounds used to index X as X[t, s]: 20 values of s, and 19 values of t
# because every step also reads the following entry X[t + 1, s] as its target.
n_sequences = 20
n_transitions = 19

# ---------------------------------------------------------------------------
# Training: one optimizer step per (input, next-entry) pair
# ---------------------------------------------------------------------------
for e in range(epoches):

    train_loss = 0.0

    for s in range(n_sequences):

        for t in range(n_transitions):

            optimizer.zero_grad()

            # Flatten the current entry (input) and the next one (target)
            # into row vectors of length input_size.
            x = X[t, s, :, :].reshape(1, input_size)
            y = X[t + 1, s, :, :].reshape(1, input_size)

            z = model(x)

            # Sum of squared errors between target and prediction.
            loss = torch.pow(y - z, 2).sum()
            loss.backward()

            optimizer.step()

            train_loss += loss.item()

    print(e, train_loss)

    train_loss_epoch[e] = train_loss

# ---------------------------------------------------------------------------
# Save the learned weights and the loss curve
# ---------------------------------------------------------------------------
# The context manager guarantees the HDF5 file is flushed and closed, even if
# writing one of the datasets raises.
with h5py.File('mnist_tanh_adam.h5', 'w') as f:
    f.create_dataset('U', data=model.U.weight.detach().cpu().numpy())
    f.create_dataset('V', data=model.V.weight.detach().cpu().numpy())
    f.create_dataset('e', data=train_loss_epoch)