%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% SETUP

% Start from a clean slate: empty workspace, factory search path,
% compact numeric display, and a cleared command window.
clearvars();
restoredefaultpath();
format('compact');
clc();

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% DIRECTORIES

% infimnist directory, given relative to the directory the script is
% started from
id = '../../../Datasets/Handwriting/infimnist';

% working directory; it is put on the search path because the script
% later changes into the infimnist directory (presumably so helpers such
% as load_mnist stay reachable from there -- TODO confirm where they live)
wd = pwd();
addpath(wd);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% TEST/EVALUATION SET

% Load the evaluation images (Ze) and labels (Ye). The file names are
% relative to the infimnist directory, so the script changes into that
% directory only for the duration of the load.
images_filename = 'data/t10k-images-idx3-ubyte';
labels_filename = 'data/t10k-labels-idx1-ubyte';
addOffset = 1;  % forwarded to load_mnist; presumably shifts labels to start at 1 -- TODO confirm
cd(id);
try
  [Ze,Ye] = load_mnist(images_filename,labels_filename,addOffset);
catch err
  % Go back before failing: otherwise the session is left inside the
  % infimnist directory and a rerun would capture the wrong pwd and fail
  % to resolve the relative infimnist path.
  cd(wd);
  rethrow(err);
end
cd(wd);

% One column of Ze per example; the largest label is used as the number
% of classes when sizing the model table and the vote matrices.
[dimInput,numTestExamples] = size(Ze);
maxY = max(Ye);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% TRAINING SETUP

% The training stream is consumed in fixed-size windows.
numTrainExamples = 1e8;                          % total number of training examples
windowSize = 50000;                              % examples generated per window
numWindows = ceil(numTrainExamples/windowSize);  % windows needed to cover them all
models = cell(maxY);                             % maxY-by-maxY table, one slot per class pair

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% DIARY

% Start the overall timer, then print a commented header describing the
% run and naming the columns of the per-window rows printed later.
tStart = tic;
fprintf('%% linear\n');
fprintf('%% window size: %d\n',windowSize);
fprintf('%% window mErr cErr eErr aErr tSec\n');

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% INITIALIZE

% Every class pair with c1 < c2 gets its own model holding a current
% weight vector (v) and a running-average weight vector (va), both
% starting at zero. Slots with c1 >= c2 are left empty.
blankModel = struct('v',zeros(dimInput,1),'va',zeros(dimInput,1));
for c1=1:maxY-1
  for c2=c1+1:maxY
    models{c1,c2} = blankModel;
  end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% TRAIN

% LOOP
% One pass per window: generate the window's examples with the external
% infimnist program, run one online passive-aggressive sweep over them for
% every class pair, then score the current and the averaged models on the
% evaluation set and print one row of statistics.
for window=1:numWindows

  % GENERATE WINDOW
  tWindow = tic;
  % Requested example indices start right after the first numTestExamples
  % indices and advance by windowSize per window.
  idxA = numTestExamples + (window-1)*windowSize;
  idxB = idxA + windowSize - 1;
  addOffset = 1;
  labels_filename = sprintf('labels_%04d.dat',window);
  images_filename = sprintf('patterns_%04d.dat',window);
  % The redirect targets are quoted so that a working directory containing
  % spaces does not split the shell command.
  commandL = sprintf('./infimnist lab %d %d > "%s/%s"',idxA,idxB,wd,labels_filename);
  commandP = sprintf('./infimnist pat %d %d > "%s/%s"',idxA,idxB,wd,images_filename);
  cd(id); statusL = system(commandL); statusP = system(commandP); cd(wd);
  % Fail loudly if the generator did not succeed instead of training on
  % missing or truncated files; remove whatever was partially written.
  if (statusL~=0 || statusP~=0)
    if (exist(labels_filename,'file')==2), delete(labels_filename); end
    if (exist(images_filename,'file')==2), delete(images_filename); end
    error('infimnist:generationFailed',...
      'infimnist failed for window %d (exit status: lab=%d, pat=%d).',...
      window,statusL,statusP);
  end
  [Zt,Yt] = load_mnist(images_filename,labels_filename,addOffset);
  
  % LOOP OVER WINDOW
  % votesT(c,t) counts the pairwise classifiers that voted for class c on
  % example t BEFORE being updated on that example (online error).
  votesT = zeros(maxY,windowSize);
  for t=1:windowSize
    % These depend only on the example, so compute them once rather than
    % once per class pair.
    z = Zt(:,t);
    zz = z'*z;
    yt = Yt(t);
    for c1=1:maxY-1
      for c2=c1+1:maxY
        v = models{c1,c2}.v;
        dp = v'*z;
        % Positive output votes for c1, negative for c2, zero for neither.
        votesT(c1,t) = votesT(c1,t) + (dp>0);
        votesT(c2,t) = votesT(c2,t) + (dp<0);
        % Pairwise target: +1 if the example is c1, -1 if c2, 0 otherwise.
        y = (yt==c1) - (yt==c2);
        score = y*dp;
        % PASSIVE
        % Nothing to do when the example belongs to neither class or the
        % margin already exceeds 1. An all-zero pattern is skipped as well:
        % the update below would divide by zero and fill the model with NaN.
        if (y==0 || score>1 || zz==0)
          continue;
        end  
        % AGGRESSIVE
        % Smallest change to v that brings the margin on this example to 1.
        models{c1,c2}.v = v + y*z*(1-score)/zz;
      end
    end
  end
  
  % TEST ERROR
  % One-vs-one voting on the evaluation set; an output that is not
  % strictly positive counts as a vote for c2.
  votesE = zeros(maxY,numTestExamples);
  for c1=1:maxY-1
    for c2=c1+1:maxY
      vote1 = ((models{c1,c2}.v'*Ze) > 0);
      vote2 = 1-vote1;
      votesE(c1,:) = votesE(c1,:) + vote1;
      votesE(c2,:) = votesE(c2,:) + vote2;
    end
  end

  % TEST ERROR FROM AVERAGED MODEL
  % va is the running mean of v sampled once at the end of each window.
  votesA = zeros(maxY,numTestExamples);
  for c1=1:maxY-1
    for c2=c1+1:maxY
      models{c1,c2}.va = ((window-1)*models{c1,c2}.va + models{c1,c2}.v)/window;
      vote1 = ((models{c1,c2}.va'*Ze) > 0);
      vote2 = 1-vote1;
      votesA(c1,:) = votesA(c1,:) + vote1;
      votesA(c2,:) = votesA(c2,:) + vote2;
    end
  end

  % TABULATE VOTES
  % Predicted class = row with the most votes (max returns the first such
  % row on ties). Both sides are flattened to columns so the error rates
  % stay scalar whether load_mnist returns labels as a row or a column.
  [~,guessT] = max(votesT);
  [~,guessE] = max(votesE);
  [~,guessA] = max(votesA);
  mErr = mean(guessT(:)~=Yt(:));   % online error on this window
  eErr = mean(guessE(:)~=Ye(:));   % evaluation error, current models
  aErr = mean(guessA(:)~=Ye(:));   % evaluation error, averaged models
  % cErr is the running mean of mErr over all windows so far.
  if (window==1)
    cErr = mErr;
  else
    cErr = ((window-1)*cErr + mErr)/window;
  end
  
  % PRINT, CLEAN UP
  fprintf('%04d   %1.4f  %1.4f  %1.4f  %1.4f  %6f\n',...
    window,mErr,cErr,eErr,aErr,toc(tWindow));
  delete(labels_filename);
  delete(images_filename);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% DONE

% Report the time elapsed since the timer started in the DIARY section.
toc(tStart);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

