%% learn SVD on image rotations and estimate rotation speed
%% circular patch of radius par.dcut in a szpatch x szpatch frame (currently 12-px diameter in 14x14)
%%  run on raw van Hateren data
clearvars; % close all  (left disabled on purpose)
%% sampling switches
downsample = 1;    % decimation factor; 1 means patches are used at full resolution
dosmooth   = 0;    % only when downsample>1: 1 = block-average, 0 = plain decimation
par.rndseed = 2021;
rng(par.rndseed);  % fixed seed so patch sampling and initialisation are reproducible
%% patch-extraction and whitening parameters (a-la Cadieu)
szpatch = 14;      % side length of the (post-downsampling) patch in pixels
p = struct( ...
    'data_root',  '~/Documents/MATLAB/invariance/video/twolayer-master/data/vid075-chunks', ...
    'num_chunks', 50, ...        % number of movie chunks
    'imsz',       128, ...       % frame side length (pixels)
    'chunk',      64, ...        % size of each movie chunk (frames)
    'whitening',  struct( ...
        'pixel_noise_fractional_variance',   .01, ...
        'pixel_noise_variance_cutoff_ratio', 1.25), ... % 1 + var(signal)/var(noise)
    'BUFF',       4, ...         % left-right limit for patches
    'topmargin',  15, ...
    'patch_sz',   szpatch*downsample);
disp(p)
%% Loading data from raw van Hateren
% Concatenate p.num_chunks movie chunks, each p.imsz x p.imsz x p.chunk,
% into one array along the third (frame) dimension. Sizes now come from p
% instead of duplicated literals (128/64/50), so changing p is enough.
% NOTE(review): assumes read_chunk's 3rd/4th arguments are frame size and
% chunk length, as implied by the 128/64 it was called with before.
tic;
full_video=zeros(p.imsz,p.imsz,p.chunk*p.num_chunks);
for iter_im_idx=1:p.num_chunks
    frames=(iter_im_idx-1)*p.chunk+1 : iter_im_idx*p.chunk;
    full_video(:,:,frames)=read_chunk(p.data_root,iter_im_idx,p.imsz,p.chunk);
end
toc
%% prep data: sample T random patches, normalise them, then ZCA-whiten
tic;
T = 1e6;
X = zeros(szpatch^2, T);
% All random positions/frames are drawn up front, in this order
% (rows, cols, frames), so the RNG stream is consumed as before.
randRows  = rand(1,T);
randCols  = rand(1,T);
randFrame = randi(size(full_video,3), 1, T);
useSmoothing = downsample>1 && dosmooth;
for k = 1:T
    patch = crop2d(full_video, p, randFrame(k), randRows(k), randCols(k));
    if useSmoothing
        patch = smooth(patch, downsample);                 % block-average downsampling
    else
        patch = patch(1:downsample:end, 1:downsample:end); % plain decimation (no-op for 1)
    end
    X(:,k) = patch(:);
end
X = X ./ vecnorm(X);   % every patch column scaled to unit norm
[zcaMatrix, xmean] = whiten(X, p.whitening.pixel_noise_variance_cutoff_ratio, ...
                               p.whitening.pixel_noise_fractional_variance);
X  = zcaMatrix * (X - xmean);
Cx = X*X'/T;           % second-moment matrix of the whitened patches
toc
%% the raw movie is no longer needed once X has been built
clear full_video
%% learning params
par.len      = szpatch;   % patch side length (pixels)
par.dcut     = 6;         % mask radius (pixels)
par.kk       = 20;        % number of (u,v) component pairs; alternative: floor(n/2)
par.angleMax = 45;        % maximum rotation angle (degrees)
par.iter     = T;         % one learning step per sampled patch
par.etaSVD   = 0.01;      % learning rate passed to skewGhaXYStep
par.etaPCA   = 0.0002;    % NOTE(review): not referenced anywhere else in this file
disp(par)
%% image params: circular mask of radius par.dcut inside the par.len x par.len frame
center = (par.len+1)/2;
offs   = (1:par.len) - center;
dist   = sqrt(offs.^2 + offs'.^2);   % distance of each pixel from the frame centre
mask   = dist <= par.dcut;
%figure,imagesc(mask)
remask = find(mask);                 % linear indices used to put vectors back into images
n      = numel(remask);              % number of active pixels = vector length
[r,c]  = find(mask);
cutout = min(c):max(c);              % column span of the mask (rows are identical by symmetry)
diam   = numel(cutout);              % width of the active region in pixels
%% GHA prep: random rotation angles and random unit-norm initial bases
angles = par.angleMax*rand(1,par.iter);  % one angle (degrees, in (0,angleMax)) per step
uinit  = randn(n,par.kk);
u      = uinit./vecnorm(uinit);          % columns scaled to unit norm
vinit  = randn(n,par.kk);
v      = vinit./vecnorm(vinit);
aveChi = zeros(n);                       % n x n accumulator for y*x'-x*y'
du     = nan(1,par.iter);                % per-step size of the [u,v] update
%% SVD loop: online skew-GHA updates on (original, rotated) patch pairs
tic;
% Integer reporting interval: rem(i,par.iter/20) with a non-integer divisor
% almost never hits 0 (or hits every step when par.iter<20). For
% par.iter = 1e6 this is the same 50000 as before.
reportEvery = max(1, round(par.iter/20));
for i=1:par.iter

    [x,y] = rotate(X(:,i),par.len,mask,angles(i));
    [u1,v1] = skewGhaXYStep(x,y,u,v,par.etaSVD);
    du(i)=norm([u,v]-[u1,v1],'fro');   % Frobenius size of this update
    u=u1; v=v1;

    % exponentially weighted accumulation (rate 1/par.iter) of y*x'-x*y'
    aveChi = aveChi*(1-1/par.iter) + (y*x'-x*y')/par.iter;

    if any(isnan(u(:))) || any(isnan(v(:))), error('nan'); end

    if rem(i,reportEvery)==0
        fprintf('%d    %f  %f\n', i, toc, max(abs(u(:))));
    end
end
disp(['SVD time ',num2str(toc),' sec'])
%figure,plot(du)
%% Speed estimation loop
% For every (original, rotated) pair, project both vectors onto each plane
% span{u(:,k), v(:,k)} and take asin of the signed sine of the angle between
% the two projections as the per-component rotation estimate (radians).
% Arrays are sized by par.iter, the loop length, so no stray NaN rows remain
% if par.iter ~= T.
speedest=nan(par.iter,par.kk);
x=nan(n,par.iter);
xplus=nan(n,par.iter);
tic;
for i=1:par.iter

    [ximg,yimg] = rotate(X(:,i),par.len,mask,angles(i));
    x(:,i)=ximg;
    xplus(:,i)=yimg;

    ay=u'*yimg; by=v'*yimg; ylen=sqrt(ay.^2+by.^2);
    ax=u'*ximg; bx=v'*ximg; xlen=sqrt(ax.^2+bx.^2);
    sinang=(ay.*bx-by.*ax)./ylen./xlen;
    % Rounding can push |sinang| just above 1, where asin returns complex
    % values (making the whole speedest matrix complex). Clamp with
    % comparisons, which leave NaN (zero-length projections) untouched.
    sinang(sinang>1)=1;
    sinang(sinang<-1)=-1;
    speedest(i,:)=asin(sinang);

end
disp(['Estimation time ',num2str(toc),' sec'])
corr(angles',speedest)
corr(angles',mean(speedest,2))
mdl=fitlm(speedest,angles(:));   % response as a column, one row per step
disp(mdl.Rsquared)
%% display the learned u and v components as image tiles
tilefig(par, cutout, remask, u, v);
%% write the estimation inputs/outputs to disk
outVars = {'remask','szpatch','angles','x','xplus','speedest'};
save('rot_d12.mat', outVars{:});
%% --------------------------------
%%
function X = crop2d(F,p, i, rowSeed,colSeed) % adapted from Cadieu
% CROP2D  Cut a p.patch_sz x p.patch_sz patch out of frame i of F.
%   rowSeed/colSeed (expected in (0,1], e.g. from rand) pick the position.
%   The first p.topmargin+p.BUFF rows and first p.BUFF columns are never
%   used, and at least p.BUFF rows/columns are kept clear at the far edges.
rowSlack = p.imsz - p.patch_sz - (p.topmargin + 2*p.BUFF);
colSlack = p.imsz - p.patch_sz - 2*p.BUFF;
row  = p.topmargin + p.BUFF + ceil(rowSlack*rowSeed);
col  = p.BUFF + ceil(colSlack*colSeed);
rows = row : row+p.patch_sz-1;
cols = col : col+p.patch_sz-1;
X = F(rows, cols, i);
end
%%
function c=smooth(a,s)
% SMOOTH  Downsample by averaging non-overlapping s x s blocks.
%   c = smooth(a,s) averages each s x s block of every 2-D slice of a
%   (size p x q x r, with p and q multiples of s) and returns a
%   (p/s) x (q/s) x r array. Explicit reshapes replace squeeze, which used
%   to drop or move singleton output dimensions (p==s or q==s) and return
%   the wrong orientation.
[p,q,r]=size(a);
b=mean(reshape(a,s,p/s,q*r),1);   % average over rows inside each block: 1 x p/s x q*r
b=reshape(b,p/s,s,q/s,r);         % split columns into (within-block, block index)
c=reshape(mean(b,2),p/s,q/s,r);   % average over columns inside each block
end
%%
function [x,y] = rotate(imgcol,len,mask,angle)
% ROTATE  Make a unit-norm (original, rotated) pair of masked pixel vectors.
%   imgcol is a column-stacked len x len image. It is rotated by angle
%   degrees with imrotate (bilinear, output cropped to the input size).
%   x and y are the mask-selected pixels of the original and rotated image,
%   each divided by its Euclidean norm.
    src = reshape(imgcol, len, len);
    dst = imrotate(src, angle, 'bilinear', 'crop');
    x = src(mask);
    x = x / norm(x);
    y = dst(mask);
    y = y / norm(y);
end
%%
function tilefig(par,cutout,remask,uu,vv)
% TILEFIG  Show each learned component pair as gray-scale image tiles.
%   Column i of uu and vv (values on the masked pixels) is scattered into a
%   par.len x par.len image through the linear indices remask, cropped to
%   rows/columns cutout, and drawn in its own tile. The u tiles fill the
%   first half of the grid and the v tiles the second half, so u(:,i) and
%   v(:,i) sit at the same position in their halves. The grid has
%   ceil(par.kk/4) columns (the original 8x5 layout for par.kk = 20), but
%   par.kk no longer has to be a multiple of 4 (tiledlayout used to error).
    ncols = max(1, ceil(par.kk/4));
    nrowsHalf = ceil(par.kk/ncols);   % rows needed for one set (u or v)
    vstart = nrowsHalf*ncols;         % v tiles start after a whole number of rows
    figure;
    %pos=get(gcf,'Position');pos(3)=(par.kk/2)*pos(4);set(gcf,'Position',pos);
    tiledlayout(2*nrowsHalf,ncols,'TileSpacing','none','Padding','none');
    for i=1:par.kk
        tilefigShowComponent(i, uu(:,i), par.len, remask, cutout);
        tilefigShowComponent(vstart+i, vv(:,i), par.len, remask, cutout);
    end
end

function tilefigShowComponent(tileIdx, vec, len, remask, cutout)
% Draw one component vector as a cropped gray-scale image in tile tileIdx.
    img = zeros(len);
    img(remask) = vec;
    nexttile(tileIdx),imagesc(img(cutout,cutout)),colormap('gray'),xticklabels([]),yticklabels([]);axis equal tight;
end
