import numpy as np
import h5py
import torch
import torch.nn as nn
import torch.optim as optim 
import matplotlib.pyplot as plt

# Inclusive range of frame indices to load; frame i is read from
# 'OULP_SilhouetteSample/<i as 8 zero-padded digits>.png'.
FIRST_FRAME = 34
LAST_FRAME = 136

T = LAST_FRAME - FIRST_FRAME + 1   # number of frames (columns of X)
N = 128*88                         # values per flattened frame (rows of X)
M = 200                            # rows of U and P, columns of V
k1 = 1                             # margin threshold in the U update
k2 = 1                             # margin threshold in the V update
r1 = 0.001                         # step size of the U update
r2 = 0.001                         # step size of the V update

epochs = 1000                      # passes over the frame sequence

# Use the GPU when one is available; otherwise keep the data on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# X holds one flattened frame per column.
# NOTE(review): assumes every PNG is a single-channel image with exactly N
# pixels and non-negative values -- confirm against the dataset.
X = torch.zeros(N, T)
for i in range(FIRST_FRAME, LAST_FRAME + 1):
    I = plt.imread('OULP_SilhouetteSample/%08d.png' % i)
    # Transpose, binarise (sign() keeps 0 at 0 and maps positive pixels to 1)
    # and flatten to a length-N vector that matches the destination column.
    X[:, i - FIRST_FRAME] = torch.tensor(I).T.float().sign().reshape(N)
X = X.to(DEVICE)
# Map the {0, 1} pixels to {-1, +1}.
X = 2*X - 1
    
def H(x):
    """Return the element-wise Heaviside step of ``x``, with ``H(0) = 0``.

    Args:
        x: Input tensor, on any device and of any dtype supported by
            ``torch.heaviside``.

    Returns:
        Tensor holding 1 where ``x > 0`` and 0 where ``x <= 0``.
    """
    # torch.heaviside needs `values` to share x's dtype and device, so derive
    # it from x instead of building a float32 tensor and copying it to the
    # GPU on every call.
    return torch.heaviside(x, x.new_zeros(1))

# Weight matrices, initialised with small Gaussian noise (scaled by 0.001).
# They are moved to whatever device X lives on rather than to a hard-coded
# GPU, so the matrix products below always see operands on the same device.
#   U (M x N): applied to the current frame, U @ x1.
#   V (N x M): applied to the sign code of the current frame, V @ y.
#   P (M x N): applied to the next frame, P @ x2.
U = 0.001*torch.randn(M,N).to(X.device)
V = 0.001*torch.randn(N,M).to(X.device)
P = 0.001*torch.randn(M,N).to(X.device)

# Summed training loss of every epoch.
train_loss_epoch = np.zeros(epochs)

# Online training: for every pair of consecutive frames (x1 -> x2) apply one
# margin-based update to U and one to V, and accumulate the prediction error.
for e in range(epochs):

    train_loss = 0.0
    for t in range(T-1):
        # Current and next frame as N x 1 column vectors.
        x1 = X[:,t].reshape(N,1)
        x2 = X[:,t+1].reshape(N,1)

        # Target code for this step: sign of the projection of the next
        # frame. P is not modified anywhere in this loop.
        z = torch.sign(P @ x2)
        # u is 1 for the rows whose margin z * (U @ x1) is below k1; only
        # those rows of U are moved towards their target sign.
        u = H(k1 - z * (U @ x1))
        U += r1 * (u * z) @ x1.T

        # Same margin rule for V, using the code produced by the updated U
        # as input and the next frame x2 as target.
        y = torch.sign(U @ x1)
        v = H(k2 - x2 * (V @ y))
        V += r2 * (v * x2) @ y.T

        # Squared error between the next frame and its prediction. U has not
        # changed since y was computed, so y is reused instead of evaluating
        # sign(U @ x1) a second time.
        train_loss += torch.pow(x2 - torch.sign(V @ y), 2).sum()

    # float() accepts both the initial 0.0 and the accumulated 0-dim tensor
    # and yields a plain Python number, which can be formatted and stored in
    # the NumPy array regardless of the device the tensors live on.
    train_loss = float(train_loss)
    print('%d %0.2f' % (e, train_loss))
    train_loss_epoch[e] = train_loss


# Persist the learned U and V matrices and the per-epoch loss curve.
# The context manager guarantees the HDF5 file is closed (and its contents
# written out) even if one of the dataset writes raises.
with h5py.File('silhouette_bin.h5', 'w') as f:
    f.create_dataset('U', data=U.cpu().detach().numpy())
    f.create_dataset('V', data=V.cpu().detach().numpy())
    f.create_dataset('e', data=train_loss_epoch)