function [spec, topvar, polar] = polarspec(u, n, remask, diam, varargin)
% POLARSPEC Calculate spectrum of filter along the polar coordinate.
%   [spec, topvar] = POLARSPEC(u, n, remask, diam) calculates the Fourier
%   spectrum of the given filter in the polar direction.
%
%   Specifically, the function first converts the vector `u` into an
%   `n x n` matrix `uimg` (`uimg(remask) = u`), then resamples it on a
%   polar grid centered on the middle of the image by averaging all the
%   pixels that fall in each (radius, angle) bin, then performs an SVD
%   and returns `spec`, the absolute value of the FFT of the top right
%   singular vector (corresponding to the angular direction). Only the
%   first `floor((nbins_angular + 1) / 2)` FFT coefficients are returned,
%   starting with the zero-frequency one.
%
%   Binning details:
%     * Radial bin `k` collects the pixels whose distance `r` from the
%       center obeys `(k - 1) * d <= r < k * d`, where
%       `d = diam / (2 * (nbins_radial - 1))`. Pixels farther out are
%       ignored. Note that this means the last radial bin holds the
%       pixels at or just beyond `diam / 2`.
%     * The angular bins have equal width `2 * pi / nbins_angular` and
%       wrap around, so that the samples are uniformly spaced over one
%       full revolution, as the FFT assumes.
%     * Bins that contain no pixel are left at zero.
%
%   The size of the polar-coordinates representation is by default 10 x 16,
%   but can be controlled:
%   spec = POLARSPEC(..., 'bins', [nbins_radial, nbins_angular])
%
%   The function also returns `topvar`, the fraction of total variance
%   explained by the top SVD component. This is NaN when the polar
%   representation is identically zero.
%
%   [spec, topvar, polar] = POLARSPEC(...) also returns the matrix
%   representation of `u` in polar coordinates.
%
%   POLARSPEC('defaults') displays the default values of the options. Any
%   requested outputs are returned empty in this case.

% parse optional arguments
parser = inputParser;
parser.CaseSensitive = true;
parser.FunctionName = mfilename;

% bin counts are used as array sizes, so they must be positive integers
parser.addParameter('bins', [10, 16], ...
    @(v) isnumeric(v) && isvector(v) && length(v) == 2 && ...
         all(isfinite(v)) && all(v >= 1) && all(v == round(v)));

% show defaults if requested
if nargin == 1 && strcmp(u, 'defaults')
    parser.parse;
    disp(parser.Results);
    % only assign outputs when asked for, so that a bare call does not
    % additionally echo `ans`
    if nargout > 0
        spec = [];
        topvar = [];
        polar = [];
    end
    return;
end

% parse options
parser.parse(varargin{:});
params = parser.Results;

% a non-positive (or NaN) diameter would yield invalid radial bin indices
if ~(isnumeric(diam) && isscalar(diam) && diam > 0)
    error('polarspec:baddiam', 'diam must be a positive scalar.');
end

nbins_radial = params.bins(1);
nbins_angular = params.bins(2);

% convert u to cartesian-coordinate matrix
uimg = zeros(n);
uimg(remask) = u;

% polar coordinates of every pixel, measured from the image center
center = (1 + n) / 2;
[row, col] = ndgrid(1:n, 1:n);
drow = row - center;
dcol = col - center;
r = hypot(drow, dcol);
% shift the angle from [-pi, pi] to [0, 2*pi]
theta = atan2(dcol, drow) + pi;

% radial bin of every pixel; pixels beyond the last bin are discarded
r_idx = floor(1 + 2 * r * (nbins_radial - 1) / diam);
keep = r_idx <= nbins_radial;

% angular bin of every pixel: bin k is centered on angle
% (k - 1) * 2 * pi / nbins_angular, and `mod` wraps angles close to 2*pi
% back into the first bin, which is centered on the same direction
theta_idx = mod(round(theta * nbins_angular / (2 * pi)), nbins_angular) + 1;

% convert to polar coordinates: average all the pixels falling in each bin
polar = zeros(nbins_radial, nbins_angular);
if any(keep(:))
    subs = [r_idx(keep), theta_idx(keep)];
    sums = accumarray(subs, uimg(keep), [nbins_radial, nbins_angular]);
    counts = accumarray(subs, 1, [nbins_radial, nbins_angular]);

    filled = counts > 0;
    polar(filled) = sums(filled) ./ counts(filled);
end

% SVD to check that this is rank 1
[~, svd_s, svd_v] = svd(polar);
% take the diagonal of the leading square block: calling `diag` directly
% would build a matrix instead when `svd_s` is a row or column vector
% (which happens when one of the bin counts is 1)
nsing = min(size(svd_s));
svd_s = diag(svd_s(1:nsing, 1:nsing));
topvar = svd_s(1) .^ 2 / sum(svd_s .^ 2);

% FFT of angular part for top principal component
f = fft(svd_v(:, 1));
spec = abs(f(1:floor((length(f) + 1) / 2)));

end