% Reset the workspace and close any open figures. Plain `clear` removes all
% workspace variables, which is all this script needs; `clear all` would also
% flush compiled functions/MEX files from memory and slow down the next calls.
clear
close all

% Run parameters
T = 20;      % frames per sequence used for training and for the recall test
K = 64;      % frame side length in pixels (assumes square K x K frames)
N = K^2;     % pixels per frame = length of a flattened frame vector
M = 1000;    % number of hidden units (rows of the projection U)
S = 20;      % number of sequences used for Hebbian training

% Load only the first S sequences of the test set. Passing start/count to
% h5read reads just that hyperslab instead of loading the whole dataset and
% discarding most of it (Inf in count = read to the end of that dimension).
% NOTE(review): assumes '/data' is 4-D, ordered (row, col, sequence, frame)
% as returned by h5read -- this matches how X(:,:,s,t) is indexed below; confirm.
X = h5read('mnist_test_seq.h5', '/data', [1 1 1 1], [Inf Inf S Inf]);
size(X)

% Fail early with a clear message instead of inside reshape/indexing later.
assert(size(X,1)*size(X,2) == N, ...
    'Frame size %dx%d does not match N = %d.', size(X,1), size(X,2), N);
assert(size(X,4) >= T, ...
    'Need at least T = %d frames per sequence, got %d.', T, size(X,4));

% i = randi(10000);
% for t = 1:20    
% %     imshow(sign(X(1:2:end,1:2:end,i,t)')*255)
%     imshow(sign(X(:,:,i,t)')*255)
%     pause(0.2)
% end

% X = 2*single(sign(X(1:2:end,1:2:end,:,:))) - 1;
% Binarize every pixel to +/-1 in single precision: positive -> +1,
% zero (or negative) -> -1. Unlike 2*sign(X)-1, this stays in {-1,+1} even
% if the input contains negative values, and it avoids copying X first.
X = 2*single(X > 0) - 1;

% X = sign(randn(size(X)));
% X = repmat(sign(randn(32,32,1,1)), [1 1 1 20]);


% Random projection U (M x N) maps an N-pixel frame to M hidden units via
% sign(U*x). The readout V (N x M) is learned below, so it starts at zero.
% NOTE(review): no rng seed is set, so U (and therefore the saved weights)
% depends on the current random-number-generator state; call rng(...) first
% if runs must be reproducible.
U = 0.001*single(randn(M,N));
V = zeros(N, M, 'single');

% Hebbian learning of the one-step transition: for every consecutive frame
% pair (x_t, x_{t+1}) of the first S sequences, add x_{t+1} * sign(U*x_t)'.
% All T-1 pairs of one sequence are handled with a single matrix product.
for s = 1:S
    % Column t of F is frame t of sequence s, flattened column-major to N x 1
    % (the same layout as reshape(X(:,:,s,t),[N 1])).
    F = reshape(X(:,:,s,1:T), [N T]);
    H = sign(U * F(:, 1:T-1));      % hidden codes of frames 1..T-1
    V = V + F(:, 2:T) * H';         % equals sum over t of x_{t+1} * h_t'
end

% Recall test on sequence 1: seed with its first frame, then apply the learned
% one-step map x -> sign(V*sign(U*x)) T-1 times, feeding each prediction back
% in. Z holds the true frames of the same sequence for comparison.
Z = single(reshape(X(:,:,1,1:T), [N T]));   % ground truth, one frame per column
Y = zeros(N, T, 'single');                  % predicted frames
Y(:,1) = Z(:,1);
for k = 2:T
    code   = sign(U * Y(:,k-1));
    % sign(0) is 0, so predicted entries may be 0 as well as +/-1.
    Y(:,k) = sign(V * code);
end

% Display the prediction error; for a matrix argument MATLAB's norm is the
% largest singular value.
norm(Y-Z)

% Save the learned weights to an HDF5 file whose name encodes the run
% parameters (for K=64, M=1000, S=20 this is '64_1000_20_mnist_heb.h5').
% NOTE(review): the trailing number is assumed to be S (T is also 20 in this
% script, so the original name is ambiguous) -- confirm.
outFile = sprintf('%d_%d_%d_mnist_heb.h5', K, M, S);

% h5create errors if the dataset already exists, which made reruns fail;
% remove the output file from a previous run before recreating /U and /V.
if exist(outFile, 'file') == 2
    delete(outFile);
end

h5create(outFile, '/U', size(U), 'Datatype', 'single');
h5create(outFile, '/V', size(V), 'Datatype', 'single');
h5write(outFile, '/U', U);
h5write(outFile, '/V', V);