% Train a next-frame predictor with sign activations on binarized image
% sequences read from 'mnist_test_seq.h5', then save the learned weights.
% NOTE(review): assumes h5read returns /data laid out as
% rows x cols x sequences x frames -- confirm against the file.
clear        % 'clear all' also flushes cached functions/MEX files; unnecessary here
close all

T = 20;          % frames per sequence used for training and evaluation
K = 64;          % frame side length in pixels
N = K^2;         % pixels per frame (visible layer size)
% N = 64*64;
M = 1000;        % hidden layer size (rows of U and W)
S = 20;          % number of sequences used for training
iters = 500;     % number of passes over the S training sequences

X = h5read('mnist_test_seq.h5','/data');

% Fail early with a readable message instead of an opaque indexing/reshape
% error deep inside the training or evaluation loops.
if size(X,1)*size(X,2) ~= N
    error('Frames are %dx%d (%d pixels) but N = K^2 = %d.', ...
          size(X,1), size(X,2), size(X,1)*size(X,2), N);
end
if size(X,3) < S
    error('S = %d sequences requested but /data holds only %d.', S, size(X,3));
end
if size(X,4) < T
    error('T = %d frames requested but /data holds only %d per sequence.', T, size(X,4));
end

X = X(:,:,1:S,:);
size(X)

% i = randi(10000);
% for t = 1:20    
% %     imshow(sign(X(1:2:end,1:2:end,i,t)')*255)
%     imshow(sign(X(:,:,i,t)')*255)
%     pause(0.2)
% end

% Binarize to {-1,+1}. Assuming nonnegative pixel intensities (e.g. uint8 --
% confirm), sign() yields 0 or 1, so background maps to -1 and lit pixels to +1.
% X = 2*single(sign(X(1:2:end,1:2:end,:,:))) - 1;
X = 2*single(sign(X(:,:,:,:))) - 1;

% X = sign(randn(size(X)));
% X = repmat(sign(randn(32,32,1,1)), [1 1 1 20]);

% Learning hyperparameters: k1/k2 are the margins subtracted inside
% threshold(k - ...) for the U and V updates; r1/r2 are their step sizes.
k1 = 1;   r1 = 0.001;
k2 = 1;   r2 = 0.001;

% Small random single-precision weights. Draw order (U, V, W) is kept so the
% random stream is consumed exactly as before.
init_scale = 0.001;
U = init_scale*single(randn(M,N));   % pixels (N) -> hidden (M), learned
V = init_scale*single(randn(N,M));   % hidden (M) -> pixels (N), learned
% U = single(zeros(M,N));
% V = single(zeros(N,M));

W = init_scale*single(randn(M,N));   % random projection; never reassigned in this script

% Per-pass sums of mean(u) and mean(v), filled in during training.
[err_u, err_v] = deal(zeros(1,iters));

% Target hidden codes: sign(W*x) for every frame that serves as a "next"
% frame (t = 2..T). W is fixed, so these are computed once here instead of
% iters*S*(T-1) times inside the training loop. Memory cost: M*S*T singles.
Wx_sign = zeros(M, S, T, 'single');
for s = 1:S
    for t = 2:T
        Wx_sign(:,s,t) = sign(W*reshape(X(:,:,s,t),[N 1]));
    end
end

% Train U (frame t -> hidden code of frame t+1) and V (hidden code -> frame t+1)
% with margin-driven updates over all consecutive frame pairs.
% NOTE(review): threshold() is defined outside this file -- presumably a
% hinge/step nonlinearity that is nonzero only when the margin is violated; confirm.
for iter = 1:iters
    
    for s = 1:S
        for t = 1:T-1
            x1 = reshape(X(:,:,s,t),[N 1]);     % current frame, {-1,+1}^N
            x2 = reshape(X(:,:,s,t+1),[N 1]);   % next frame, {-1,+1}^N

            % Earlier variant, kept for reference:
%             u = threshold(k1-y.*(U*x1));
%             v = threshold(k2-x2.*(V*y));
%             z = (V'*(x2.*v));
%             u = 1-tanh((U*x1)).^2;             
%             U = U + r1*(z.*u)*x1';
%             V = V + r2*(v.*x2)*y';
%             y = sign(U*x1);

            % U update: push sign(U*x1) toward the target code of the next frame.
            z = Wx_sign(:,s,t+1);
            u = threshold(k1 - z .* (U*x1));
            U = U + r1 * (u.*z) * x1';
 
            % V update: decode the (freshly updated) hidden code into the next frame.
            y = sign(U*x1);
            v = threshold(k2 - x2 .*(V*y));
            V = V + r2 * (v.*x2) * y';            
            
            % Accumulate per-pass totals of the average threshold outputs.
            err_u(iter) = err_u(iter) + mean(u);
            err_v(iter) = err_v(iter) + mean(v);
        end
    end

    fprintf('%d %0.2f %0.2f\n', iter, err_u(iter), err_v(iter));
end

% Closed-loop rollout on sequence 1: seed with its first frame, then feed each
% prediction back in through U (encode) and V (decode). Z holds the true frames.
Z = reshape(X(:,:,1,1:T), [N T]);   % ground truth, one frame per column
Y = zeros(N, T, 'single');          % predicted frames
Y(:,1) = Z(:,1);

for t = 1:T-1
    hid      = sign(U*Y(:,t));      % hidden code of the current prediction
    Y(:,t+1) = sign(V*hid);         % predicted next frame
end

% Displayed as ans. For a matrix argument, norm() is the spectral 2-norm
% (largest singular value), not the Frobenius norm.
norm(Y-Z)

% Save learned weights and per-pass training statistics to HDF5.
out_file = '64_1000_20_mnist_two.h5';

% h5create errors when a dataset already exists, so a rerun of this script
% would crash here; remove the output file left over from a previous run.
if ~isempty(dir(out_file))
    delete(out_file);
end

% Datatype follows each variable's class: single for U/V, double for the
% error traces (matching h5create's default used previously).
ds_names  = {'/U', '/V', '/err_u', '/err_v'};
ds_values = {U, V, err_u, err_v};
for d = 1:numel(ds_names)
    h5create(out_file, ds_names{d}, size(ds_values{d}), 'Datatype', class(ds_values{d}));
    h5write(out_file, ds_names{d}, ds_values{d});
end