%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% SETUP

% Start from a clean session: drop workspace variables, revert to the
% factory search path, use compact display and clear the command window.
clearvars;
restoredefaultpath;
format('compact');
clc;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% DIRECTORIES

% Folder the script is launched from. It is put on the search path
% explicitly, presumably so helper functions that live next to this script
% stay callable while the script is cd'ed into the infimnist directory.
wd = pwd();
addpath(wd);

% Location of the infimnist distribution, relative to wd.
id = '../../../Datasets/Handwriting/infimnist';

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% TEST/EVALUATION SET

% Read the held-out images and labels. The file names are relative to the
% infimnist directory, so the script steps into it for the duration of the
% load and returns to the working directory right afterwards.
cd(id);
images_filename = 'data/t10k-images-idx3-ubyte';
labels_filename = 'data/t10k-labels-idx1-ubyte';
addOffset = 1;   % forwarded to load_mnist; presumably makes labels 1-based -- confirm
[Ze,Ye] = load_mnist(images_filename,labels_filename,addOffset);
cd(wd);

% Ze is used as a matrix with one example per column (see the model
% products U*Ze further down), so rows give the input dimensionality.
dimInput = size(Ze,1);
numTestExamples = size(Ze,2);

% Largest label value; used below as the number of classes.
maxY = max(Ye);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% HYPERPARAMETERS

% Model size and training schedule.
dimHidden = 1;                 % number of rows in each U and V matrix
windowSize = 50000;            % training examples generated per window
numTrainExamples = 1e8;        % total number of training examples to stream
numWindows = ceil(numTrainExamples/windowSize);

% One slot per class pair; only the entries with c1 < c2 are ever filled.
models = cell(maxY,maxY);

% Options for the bounded 1-D minimisation (fminbnd) in the training loop.
opts = optimset('TolX',1e-12);

% Objective minimised over the step size (first argument) for each
% aggressive update. Arguments are positional: step size, squared norm of
% U*z, squared norm of V*z, and the signed pair label (+1 or -1).
f = @(a,sqU,sqV,s) ...
   sqU./(1-a*s) + sqV./(1+a*s) - a;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% SEED

seed = 1;
    
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% OUTPUT

% Start the overall timer, then print a commented header ('%'-prefixed
% lines) describing the run and naming the columns of the per-window log:
%   window - window number
%   mErr   - online (mistake) error rate on the current window
%   cErr   - running average of mErr over all windows so far
%   eErr   - test error of the current model
%   aErr   - test error of the window-averaged model
%   tSec   - wall-clock seconds spent on the window
tStart = tic;
fprintf('%% DoS hidden vector dimensionality: %d\n', dimHidden);
fprintf('%% seed: %d\n', seed);
fprintf('%% window size: %d\n', windowSize);
fprintf('%% window mErr cErr eErr aErr tSec\n');

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% INITIALIZE

% Seed the generator so every run draws the same initial weights.
rng(seed);

% For every class pair (c1 < c2) draw Gaussian U and V, scaled by
% 1/sqrt(dimHidden*dimInput), and start the averaged copies (Ua, Va) at
% the same values. U is drawn before V for each pair.
for c1=1:maxY-1
  for c2=c1+1:maxY
    initU = randn(dimHidden,dimInput)/sqrt(dimHidden*dimInput);
    initV = randn(dimHidden,dimInput)/sqrt(dimHidden*dimInput);
    models{c1,c2} = struct('U',initU,'V',initV,'Ua',initU,'Va',initV);
  end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% TRAIN

% LOOP
% One iteration per window: generate windowSize fresh examples with the
% infimnist executable, run the online passive-aggressive updates over
% them, fold the resulting model into the running average, and score both
% the current and the averaged model on the test set.
for window=1:numWindows

  % GENERATE WINDOW
  tWindow = tic;
  % Index range handed to infimnist. Starting at numTestExamples presumably
  % keeps the generated examples disjoint from the test set -- confirm
  % against the infimnist indexing convention.
  idxA = numTestExamples + (window-1)*windowSize;
  idxB = idxA + windowSize - 1;
  addOffset = 1;
  labels_filename = sprintf('labels%02d_%04d.dat',dimHidden,window);
  images_filename = sprintf('patterns%02d_%04d.dat',dimHidden,window);
  % The redirect target is quoted so that a working directory containing
  % spaces is not split into several shell words.
  commandL = sprintf('./infimnist lab %d %d > "%s/%s"',idxA,idxB,wd,labels_filename);
  commandP = sprintf('./infimnist pat %d %d > "%s/%s"',idxA,idxB,wd,images_filename);
  % system() reports failure through its return value instead of throwing,
  % so cd(wd) always runs before the status check below.
  cd(id); statusL = system(commandL); statusP = system(commandP); cd(wd);
  if (statusL~=0 || statusP~=0)
    % Stop here rather than feeding empty or partial files to load_mnist.
    error('dos:infimnistFailed', ...
      'infimnist failed for window %d (exit status: lab=%d, pat=%d).', ...
      window,statusL,statusP);
  end
  [Zt,Yt] = load_mnist(images_filename,labels_filename,addOffset);
  
  % LOOP OVER WINDOW
  % votesT(c,t) counts the pairwise contests class c wins on example t,
  % using the model as it stood *before* updating on that example.
  votesT = zeros(maxY,windowSize);
  for t=1:windowSize
    z = Zt(:,t);
    sq_z = z'*z;   % squared norm of the example; same for every class pair
    for c1=1:maxY-1
      for c2=c1+1:maxY
        Uz = models{c1,c2}.U * z;
        Vz = models{c1,c2}.V * z;
        sq_Uz = Uz'*Uz;
        sq_Vz = Vz'*Vz;
        % Positive difference votes for c1, negative for c2, zero for neither.
        guess = sign(sq_Uz-sq_Vz);
        votesT(c1,t) = votesT(c1,t) + (guess>0);
        votesT(c2,t) = votesT(c2,t) + (guess<0);
        % Signed label for this pair: +1 for c1, -1 for c2, 0 if the
        % example belongs to neither class (no update in that case).
        y = (Yt(t)==c1) - (Yt(t)==c2);
        if (y==0)
          continue;
        end
        % PASSIVE
        % Margin already exceeds 1: leave the model untouched.
        score = y*(sq_Uz-sq_Vz);
        if (score>1)
          continue;
        end  
        % AGGRESSIVE
        if (sq_z==0)
          % An all-zero input makes the update below evaluate 0/0 and
          % would fill U and V with NaN; no update can change its score.
          continue;
        end
        % Step size: minimiser of f over (0,1) for this example.
        alpha = fminbnd(@(alpha) f(alpha,sq_Uz,sq_Vz,y),0,1,opts);
        % Rank-one corrections along z, normalised by the squared norm of z.
        models{c1,c2}.U = models{c1,c2}.U + (alpha/(y-alpha))*(Uz*z')/sq_z;
        models{c1,c2}.V = models{c1,c2}.V - (alpha/(y+alpha))*(Vz*z')/sq_z;
      end
    end
  end
  
  % TEST ERROR FROM CURRENT MODEL
  % Here every pair casts exactly one vote: c1 if the difference of squared
  % norms is strictly positive, otherwise c2.
  votesE = zeros(maxY,numTestExamples);
  for c1=1:maxY-1
    for c2=c1+1:maxY
      Uz = models{c1,c2}.U * Ze;
      Vz = models{c1,c2}.V * Ze;
      vote1 = (sum((Uz.*Uz)-(Vz.*Vz),1)>0);
      vote2 = 1-vote1;
      votesE(c1,:) = votesE(c1,:) + vote1;
      votesE(c2,:) = votesE(c2,:) + vote2;
    end
  end
  
  % ALIGNED AVERAGE
  % Fold the current model into a running mean over windows. The align_*
  % helpers are defined elsewhere; presumably they map the current model
  % onto a representative comparable with the stored average before the
  % two are mixed -- confirm against their implementations.
  for c1=1:maxY-1
    for c2=c1+1:maxY
      [U,V] = align_lorentz(models{c1,c2}.U,models{c1,c2}.V); 
      U = align_ortho(U,models{c1,c2}.Ua);                    
      V = align_ortho(V,models{c1,c2}.Va);                    
      % With window==1 the previous average gets weight zero, so the
      % randomly initialised Ua/Va never enter the mean.
      Ua = ((window-1)*models{c1,c2}.Ua + U)/window;       
      Va = ((window-1)*models{c1,c2}.Va + V)/window;      
      [Ua,Va] = align_lorentz(Ua,Va);                  
      models{c1,c2}.Ua = Ua;                          
      models{c1,c2}.Va = Va;                         
    end
  end

  % TEST ERROR FROM AVERAGED MODEL
  votesA = zeros(maxY,numTestExamples);
  for c1=1:maxY-1
    for c2=c1+1:maxY
      Uz = models{c1,c2}.Ua * Ze;
      Vz = models{c1,c2}.Va * Ze;
      vote1 = (sum((Uz.*Uz)-(Vz.*Vz),1)>0);
      vote2 = 1-vote1;
      votesA(c1,:) = votesA(c1,:) + vote1;
      votesA(c2,:) = votesA(c2,:) + vote2;
    end
  end
  
  % TABULATE VOTES
  % Predicted class = row with the most votes (max returns the first such
  % row on ties).
  [~,guessT] = max(votesT);
  [~,guessE] = max(votesE);
  [~,guessA] = max(votesA);
  % (:) puts both sides in column form so the comparison stays elementwise
  % whether load_mnist returns the labels as a row or as a column.
  mErr = mean(guessT(:)~=Yt(:));
  eErr = mean(guessE(:)~=Ye(:));
  aErr = mean(guessA(:)~=Ye(:));
  % cErr is the running mean of mErr over all windows processed so far.
  if (window==1)
    cErr = mErr;
  else
    cErr = ((window-1)*cErr + mErr)/window;
  end

  % PRINT, CLEAN UP
  fprintf('%04d   %1.4f  %1.4f  %1.4f  %1.4f  %6f\n',...
    window,mErr,cErr,eErr,aErr,toc(tWindow));
  % The generated window files are only needed for this iteration.
  delete(labels_filename);
  delete(images_filename);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% DONE

% Report total elapsed time since the header was printed.
toc(tStart);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
