clear all
close all

% ---- Hyper-parameters --------------------------------------------------
T     = 20;      % time steps per sequence used for training/evaluation
K     = 64;      % frame side length; frames are flattened to K^2 vectors
N     = K^2;     % length of a flattened frame (alternative: N = 64*64)
M     = 1000;    % NOTE(review): not used in the computation; presumably the
                 % hidden-layer width (size(U,1)) -- confirm
S     = 20;      % number of sequences kept from the data set
iters = 500;     % number of passes (epochs) over the S sequences

% ---- Moving-MNIST sequences --------------------------------------------
X = h5read('mnist_test_seq.h5','/data');
X = X(:,:,1:S,:);      % keep sequences 1..S (3rd dim indexes the sequence)
size(X)

% Preview a random sequence (uncomment to use):
% i = randi(10000);
% for t = 1:20
% %     imshow(sign(X(1:2:end,1:2:end,i,t)')*255)
%     imshow(sign(X(:,:,i,t)')*255)
%     pause(0.2)
% end

% Binarise every pixel to +/-1: zero -> -1, positive -> +1
% (assumes the stored pixel values are non-negative -- confirm).
% Half-resolution alternative:
%   X = 2*single(sign(X(1:2:end,1:2:end,:,:))) - 1;
X = 2*single(sign(X)) - 1;

% ---- Initial weights ---------------------------------------------------
U = h5read('mnist_tanh.h5', '/U');   % input  -> hidden:  h = tanh(U*x)
V = h5read('mnist_tanh.h5', '/V');   % hidden -> output:  z = tanh(V*h)



% Train the two-layer tanh predictor x_t -> x_{t+1} with per-sample SGD.
% For every sequence s and time step t, the flattened frame x is mapped to
% h = tanh(U*x) and then z = tanh(V*h), which should match the next frame y.
% The weight updates are the gradients of 0.5*||z - y||^2 back-propagated
% through both tanh layers.
%
% Reads:  X, U, V, N, S, T, iters (from the setup section).
% Writes: U, V (updated in place) and err, where err(iter) is the summed
%         absolute prediction error of epoch iter (measured before each
%         sample's update).

% NOTE(review): r1/r2 were never defined, so the script crashed on the first
% update. These defaults are placeholders -- tune them for the data.
if ~exist('r1', 'var'), r1 = 1e-3; end   % learning rate for V
if ~exist('r2', 'var'), r2 = 1e-3; end   % learning rate for U

% The accumulator must exist before err(iter) is read.
err = zeros(1, iters);

for iter = 1:iters
    for s = 1:S
        for t = 1:T-1
            x = reshape(X(:,:,s,t),   [N 1]);
            y = reshape(X(:,:,s,t+1), [N 1]);

            % Forward pass.
            h = tanh(U*x);
            z = tanh(V*h);

            % Backward pass. d/da tanh(a) = 1 - tanh(a)^2, so reuse z and h
            % instead of recomputing tanh.
            delta_out = (z - y) .* (1 - z.^2);
            % The hidden-layer error must use V *before* it is updated.
            delta_hid = (V' * delta_out) .* (1 - h.^2);

            V = V - r1 * delta_out * h';
            U = U - r2 * delta_hid * x';

            err(iter) = err(iter) + sum(abs(z - y));
        end
    end

    fprintf('%d %0.2f\n', iter, err(iter));
end

% Closed-loop check on sequence 1.
% Starting from its first frame, feed each prediction back in as the next
% input for T-1 steps. sign() replaces tanh() so every predicted frame is
% binary.
% Z holds the ground-truth frames (one flattened frame per column).
% Y holds the rollout.
Z = single(reshape(X(:,:,1,1:T), [N T]));
Y = zeros(N, T, 'single');
Y(:,1) = Z(:,1);                 % seed the rollout with the true first frame

for t = 1:T-1
    h        = sign(U*Y(:,t));
    Y(:,t+1) = sign(V*h);
end

% Matrix 2-norm (largest singular value) of the rollout error; displayed
% because the statement has no trailing semicolon.
norm(Y-Z)

% Save the trained weights and the per-epoch training error.
% The file name is built from K, M and S. With the values set at the top of
% the script it is '64_1000_20_mnist_tanh.h5'.
out_file = sprintf('%d_%d_%d_mnist_tanh.h5', K, M, S);

% h5create errors when a dataset already exists, so re-running the script
% would fail on a previous output file. Start from a clean file instead.
if exist(out_file, 'file')
    delete(out_file);
end

h5create(out_file, '/U',   size(U),   'Datatype', 'single');
h5create(out_file, '/V',   size(V),   'Datatype', 'single');
% The training loop accumulates its error in err. err_u/err_v were never
% defined by this script, so writing them always failed.
h5create(out_file, '/err', size(err));
h5write(out_file, '/U',   single(U));
h5write(out_file, '/V',   single(V));
h5write(out_file, '/err', err);