% Sequence-recall experiment on binarised frames read from
% 'mnist_test_seq.h5'.
%
% The T columns of X are treated as an ordered list of +/-1 patterns.
% Starting from a corrupted copy of the first pattern, the state S(:,t) is
% repeatedly mapped to S(:,t+1) by weighting every "successor" pattern
% X(:,mu+1) with a power of the (decorrelated) overlap between the current
% state and X(:,mu).  The script prints the summed absolute difference
% between the final state and the last pattern.

% clearvars only removes workspace variables; 'clear all' would also flush
% compiled functions, breakpoints and globals, which is not needed here.
clearvars
close all

% rng(0)    % uncomment for a reproducible choice of flipped pixels

% ---- Parameters ---------------------------------------------------------
K      = 64;        % side length used to define the pattern size
N      = K^2;       % number of units (rows of X and S)
T      = 20*20;     % number of patterns / time steps (columns of X and S)
nFlip  = 200;       % entries of the initial state whose sign is flipped
degree = 20;        % exponent applied to the overlaps (even: discards sign)

% ---- Load and binarise the data -----------------------------------------
X = h5read('mnist_test_seq.h5','/data');
% NOTE(review): the reshape below needs exactly N*T values, so the slice is
% presumably 64 x 64 x 20 x 20 -- confirm the on-disk layout and which of
% the last two dimensions indexes time.
X = X(:,:,1:20,:);
% sign() maps zero pixels to 0 and positive pixels to 1, so the result is a
% single-precision array with background -1 and foreground +1.
X = 2*single(sign(X)) - 1;
X = reshape(X, [N T]);

% ---- Initial state: first pattern with nFlip randomly chosen flips ------
S = zeros(N,T);
S(:,1) = X(:,1);
rp = randperm(N);
flipIdx = rp(1:nFlip);              % unique indices, so one assignment suffices
S(flipIdx,1) = -S(flipIdx,1);

% Pseudo-inverse of the T-by-T pattern overlap matrix (1/N) * X' * X,
% used below to decorrelate the raw overlaps.
O_pinv = pinv(1/N*X'*X);

% ---- Dynamics -----------------------------------------------------------
for t = 1:T-1

    s = S(:,t);

    % Leave-one-out overlaps, N-by-T:
    %   m(i,mu) = ( sum_j X(j,mu)*s(j) - X(i,mu)*s(i) ) / (N-1)
    % The row vector (X'*s).' is expanded across rows by bsxfun instead of
    % materialising a T-by-N repmat and transposing it every iteration.
    m = 1/(N-1) * bsxfun(@minus, (X' * s).', bsxfun(@times, X, s));
    m = m * O_pinv;

    % Interaction weights for patterns 1..T-1; each one votes for its
    % successor pattern X(:,mu+1).
    f = m(:,1:T-1).^degree;
%     f = exp(m(:,1:T-1));

    % NOTE(review): sign() returns 0 on an exact tie, which would leave a
    % 0 entry in the state; left unchanged to preserve the dynamics.
    S(:,t+1) = sign( sum(X(:,2:T) .* f, 2) );
end

% ---- Report -------------------------------------------------------------
% Summed absolute difference between the final state and the last pattern
% (each mismatched +/-1 entry contributes 2).  No trailing semicolon, so
% the value is displayed.
% sum(sum(abs(S-X)))
sum(sum(abs(S(:,end)-X(:,end))))