import sys
import numpy as np
import h5py
import torch
import torch.nn as nn
import torch.optim as optim 
import matplotlib.pyplot as plt

# Sequence uses silhouette frames 34..136 inclusive.
T = len(range(34, 136 + 1))   # number of frames (103)
N = 128 * 88                  # values per flattened frame
M = 200                       # hidden units of the network

# Learning rate is taken from the *second* command-line argument.
lr = float(sys.argv[2])

epochs = 2000

# Build X with one column per frame: column c holds frame 34 + c, flattened.
X = torch.zeros(N, T)
for col, frame in enumerate(range(34, 136 + 1)):
    pixels = torch.tensor(plt.imread('OULP_SilhouetteSample/%08d.png' % frame))
    # Transpose before flattening, so pixels are laid out column by column.
    # sign() sends 0 -> 0 and positive values -> 1. Presumably each PNG is a
    # single-channel 128x88 (or 88x128) image so it flattens to exactly N
    # values -- TODO confirm.
    X[:, col] = pixels.T.float().sign().reshape(1, N)
# Move to the GPU and rescale {0, 1} to {-1, +1} to match the tanh output range.
X = 2 * X.cuda() - 1
    
def H(x):
    """Return the elementwise Heaviside step of ``x``, with H(0) = 0.

    The tie value passed to ``torch.heaviside`` is built with ``x``'s own
    dtype and device. ``torch.heaviside`` rejects a ``values`` tensor of a
    different dtype or device, so the helper works for CPU or GPU tensors of
    any floating dtype. For CUDA float32 inputs the result is identical to
    the previous hard-coded CUDA float32 value.
    """
    return torch.heaviside(x, x.new_zeros(1))

class Net(nn.Module):
    """Bias-free two-layer network computing tanh(V(tanh(U(x)))).

    Args:
        input_size: width of both the input and the output.
        M: number of hidden units.

    Attributes:
        U: linear map input_size -> M, no bias.
        V: linear map M -> input_size, no bias.
    """

    def __init__(self, input_size, M):
        super().__init__()
        # Keep the creation order (U first, then V): default weight init
        # draws from the global RNG, so the order fixes the initial weights.
        self.U = nn.Linear(input_size, M, bias=False)
        self.V = nn.Linear(M, input_size, bias=False)

    def forward(self, x):
        """Encode with U, decode with V, applying tanh after each layer."""
        return torch.tanh(self.V(torch.tanh(self.U(x))))

# N -> M -> N tanh network, wrapped by torch.compile, parameters moved to the GPU.
model = torch.compile(Net(N, M))
model.cuda()
# Adam is used; the alternative left in the original was
# optim.SGD(model.parameters(), lr=lr).
optimizer = optim.Adam(model.parameters(), lr=lr)


# Total training loss per epoch, saved to the output file after training.
train_loss_epoch = np.zeros(epochs)

for epoch in range(epochs):
    running = 0.0
    # One optimizer step per consecutive frame pair: predict frame t+1 from t.
    for t in range(T - 1):
        inp = X[:, t].reshape(1, N)
        target = X[:, t + 1].reshape(1, N)

        optimizer.zero_grad()
        pred = model(inp)
        # Sum (not mean) of squared errors over all N values.
        loss = (target - pred).pow(2).sum()
        loss.backward()
        optimizer.step()

        running += loss.item()

    print('%d %0.2f' % (epoch, running))
    train_loss_epoch[epoch] = running


# Persist the learned weights and the per-epoch loss curve. The context
# manager guarantees the HDF5 file is flushed and closed, even if a write
# raises; the original left the handle open.
with h5py.File('silhouette_tanh_adam.h5', 'w') as f:
    # `model` is the torch.compile wrapper. This relies on it forwarding
    # attribute access to the wrapped Net, so .U / .V reach its layers.
    f.create_dataset('U', data=model.U.weight.detach().cpu().numpy())
    f.create_dataset('V', data=model.V.weight.detach().cpu().numpy())
    f.create_dataset('e', data=train_loss_epoch)