function M = centeredmatrixfactory(m, n, rows_or_cols)
% Linear manifold struct. for optimization over matrices with centered cols
%
% function M = centeredmatrixfactory(m, n)
% function M = centeredmatrixfactory(m, n, 'cols')
% function M = centeredmatrixfactory(m, n, 'rows')
%
% Returns M, a structure for Manopt describing the Euclidean space of
% m-by-n matrices whose columns sum to zero (or whose rows sum to zero,
% if 'rows' is passed as last input).
%
% The metric is the standard Frobenius distance and associated trace inner
% product. Matrices on M, denoted by X, have size mxn and obey
% X*ones(n, 1) = 0 (centered columns) or ones(1, m)*X = 0 (centered rows).
%
% See also: euclideanfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, July 3, 2015.
% Contributors: 
% Change log: 
%
%   Jan. 6, 2017 (NB):
%       M.tangent = M.proj now, instead of being identity. This is notably
%       necessary so that checkgradient will pick up on gradients that do
%       not lie in the appropriate tangent space.

    if ~exist('rows_or_cols', 'var') || isempty(rows_or_cols)
        rows_or_cols = 'cols';
    end
    
    % Define a centering operator: it subtracts the mean column or row.
    switch lower(rows_or_cols)
        case 'cols'
            center = @(X) bsxfun(@minus, X, mean(X, 2));
            M.dim = @() m*n - m;
        case 'rows'
            center = @(X) bsxfun(@minus, X, mean(X, 1));
            M.dim = @() m*n - n;
        otherwise
            error('The third input must be either ''rows'' or ''cols''.');
    end
    
    % This is a non-standard function to have in a Manopt manifold.
    % It is included because it might be helpful in some situations.
    M.center = center;

    M.name = @() sprintf('Space of size %d x %d matrices with centered %s', ...
                         m, n, lower(rows_or_cols));
    
    M.inner = @(x, d1, d2) d1(:).'*d2(:);
    
    M.norm = @(x, d) norm(d, 'fro');
    
    M.dist = @(x, y) norm(x-y, 'fro');
    
    M.typicaldist = @() sqrt(M.dim());
    
    M.proj = @(x, d) center(d);
    
    M.egrad2rgrad = M.proj;
    
    M.ehess2rhess = @(x, eg, eh, d) center(eh);
    
    M.tangent = M.proj;
    
    M.exp = @exp;
    function y = exp(x, d, t)
        if nargin == 3
            y = x + t*d;
        else
            y = x + d;
        end
    end
    
    M.retr = M.exp;
    
    M.log = @(x, y) y-x;

    M.hash = @(x) ['z' hashmd5(x(:))];
    
    M.randvec = @(X) randvec();
    function U = randvec()
        U = center(randn(m, n));
        U = U / norm(U, 'fro');
    end
    
    M.rand = @() center(randn(m, n));
    
    M.lincomb = @matrixlincomb;
    
    M.zerovec = @(x) zeros(m, n);
    
    M.transp = @(x1, x2, d) d;
    
    M.pairmean = @(x1, x2) .5*(x1+x2);
    
    M.vec = @(x, u_mat) u_mat(:);
    M.mat = @(x, u_vec) reshape(u_vec, [m, n]);
    M.vecmatareisometries = @() true;

end
