function [U,D,err] = UDU_factorization_visualize(A,AT,b,U0,D0,ss,maxit)
%UDU_FACTORIZATION_VISUALIZE Gradient iteration on X = U*D*U' with snapshot plots
%   Runs maxit simultaneous gradient steps on the factors U and D of
%   X = U*D*U' for the least-squares residual A(X) - b, and plots the
%   column norms of U and the diagonal of D at a fixed set of iterations.
%
%   Inputs:
%   - A: handle for the linear measurement map, called as A(X)
%   - AT: handle for the adjoint of the measurement map, called as AT(r)
%   - b: observations
%   - U0: initial estimate for U
%   - D0: initial estimate for D
%   - ss: stepsize
%   - maxit: number of iterations
%   Outputs:
%   - U: factor U (rescaled so that norm(U,'fro') <= 1 after every step)
%   - D: factor D (diagonal with nonnegative entries after the first step)
%   - err: maxit-by-1 vector; err(t) = norm(A(X) - b)^2 evaluated at the
%          iterate BEFORE the update of iteration t
%
%   Side effects: opens two figures ('udu_normsu_noisy' and
%   'udu_diagd_noisy') and adds one subplot to each at every snapshot
%   iteration that is <= maxit.

U = U0; D = D0;
err = nan(maxit,1);

% Iterations at which the state of U and D is plotted.
snapshots = [25,50,75,100,250,500,750,1000,2500,5000,7500,10000, ...
             25000,50000,75000,100000,250000,500000,750000,1000000];

% Size the subplot grid from the number of snapshots that will actually be
% drawn. The layout stays 4x4 for up to 16 snapshots and grows to 5x4 when
% maxit reaches the later ones, so subplot never gets an out-of-range index.
ncols = 4;
nrows = max(4, ceil(nnz(snapshots <= maxit)/ncols));

hfig1 = figure('Position',[100,100,900,1200]);
set(hfig1,'name','udu_normsu_noisy','numbertitle','off');
hfig2 = figure('Position',[600,100,900,1200]);
set(hfig2,'name','udu_diagd_noisy','numbertitle','off');
kk = 0;   % number of snapshots drawn so far (index of the next subplot - 1)

for t = 1:maxit

    % Current estimate, symmetrized to remove round-off asymmetry.
    X = U*D*U'; X = 0.5*(X+X');
    AXb = A(X) - b;               % residual
    err(t) = norm(AXb)^2;
    AtXb = AT(AXb);               % adjoint applied to the residual

    % Both gradients are evaluated at the same (old) U and D.
    gradU = (AtXb + AtXb')*U*D;
    gradD = U'*AtXb*U; % also try non-simultaneous updates

    % Gradient step on U, then rescale back onto the unit Frobenius ball
    % if the step left it.
    U = U - ss*gradU;
    nUfro = norm(U,'fro');
    if nUfro > 1
        U = U./nUfro;
    end

    % Gradient step on the diagonal of D only, clipped at zero; any
    % off-diagonal entries of D are discarded here.
    D = diag(max(diag(D) - ss*diag(gradD),0));

    if any(t == snapshots)
        kk = kk+1;

        % Euclidean norm of each column of U (core-MATLAB replacement for
        % the CVX helper norms(U,[],1)).
        normU = sqrt(sum(abs(U).^2,1));
        figure(hfig1)
        subplot(nrows,ncols,kk)
        plot(normU)
        ylim([0,1])
        title(['k = ',num2str(t)])
        drawnow

        diagD = diag(D);
        figure(hfig2)
        subplot(nrows,ncols,kk)
        plot(diagD,'r')
        ylim([0,10])
        title(['k = ',num2str(t)])
        drawnow

    end

end

end

