function [] = PRGDA_minimax()
% PRGDA_MINIMAX  Run the "Pullback" experiment on a synthetic sensing problem.
%
%   Builds a random d-by-d matrix M = us * us.' (us is d-by-r), draws n
%   random d-by-d matrices A(:, :, i) and the measurements
%   b0(i) = trace(A(:, :, i) * M), then alternates updates of the primal
%   iterate x (d-by-r) and the dual weights y (n-by-1, passed through
%   projection after every ascent step) for T epochs.
%
%   Every `freq` epochs the row [epoch, val, grad_norm, dist] is stored in
%   `result`; at the last recorded epoch `lambda` is computed by
%   lambda_hessian.  Both are saved to ./result/Pullback.mat (the folder is
%   created when missing).
%
%   Takes no inputs and returns no outputs.

    % problem setting
    d = 50;
    r = 3;
    n = 20 * d;
    
    us = normrnd(0, 1/d, d, r);
    M = us * us.';
    Mscale = trace(M * M.');          % squared Frobenius norm of M
    A = normrnd(0, 1, d, d, n);
    b0 = sensing(pagemtimes(A, M), n);
    u0 = normrnd(0, 1, d, 1);
    % Rescale u0 so that its Euclidean norm equals 0.99 * max(eig(M)).
    alpha = 0.99 * max(eig(M)) / sqrt(u0.' * u0);
    u0 = alpha * u0;
    
    % parameter setting
    T = 2000;
    D = 0.01; % distance in escape phase
    ds = 0.0005; % threshold of gradient
    ra = 0.01; % radius of perturbation
    Ts = 20; % maximum iterations in escape phase
    lr_x1 = 0.1;
    lr_x2 = 0.001;
    lr_y = 0.01;
    lr_h = 0.1;
    q = 25;    % full gradient estimators are recomputed every q epochs
    bs = 40;   % mini-batch size of the inner updates
    K = 5;     % inner (y) iterations per epoch
    freq = 10; % metrics are recorded every freq epochs
    
    % Pullback
    x = zeros(d, r);
    x(:, 1) = u0;
    y = ones(n, 1) / n;
    result = zeros(1 + floor(T / freq), 4);
    val = fval(A, x, b0, n);
    nabla = grad_phi(A, x, b0, n);
    grad_norm = sqrt(trace(nabla.' * nabla));
    err = x * x.' - M;
    dist = trace(err * err.') / Mscale;
    result(1, 2:4) = [val, grad_norm, dist];   % row 1 is epoch 0
    
    v = grad_x(A, x, y, b0, n);
    u = grad_y(A, x, y, b0, n);
    escape = 0;   % 0: normal phase; > 0: step counter of the escape phase
    
    for epoch = 1:T
        v_norm = sqrt(trace(v.' * v));   % Frobenius norm of the x-gradient estimator
        x_old = x;
        if escape > 0
            % Escape phase: fixed step size lr_h while the accumulated
            % squared step length stays within the budget escape * D.
            accum_new = accum + (lr_h * v_norm) ^ 2;
            if accum_new > escape * D
                % Shorten the step so the accumulated value hits the
                % budget exactly, then leave the escape phase.
                lr = sqrt((escape * D - accum)) / v_norm;
                x = x - lr * v;
                escape = 0;
            else
                accum = accum_new;
                x = x - lr_h * v;
                % Wraps to 0 (phase ends) once the counter reaches Ts.
                escape = mod(escape + 1, Ts);
            end
        else
            if v_norm <= ds
                % Small gradient: perturb x by a step of Frobenius length ra
                % and enter the escape phase.
                % NOTE(review): rand draws entries in (0, 1), so the
                % direction always has positive entries; randn may have
                % been intended - confirm.
                dir = rand(d, r);
                % Scalar Frobenius normalisation.  The previous
                % dir / sqrt(dir.' * dir) divided by an r-by-r matrix
                % (a linear solve) and did not yield a unit direction.
                dir = dir / sqrt(trace(dir.' * dir));
                x = x + ra * dir;
                escape = escape + 1;
                accum = 0;
            else
                % Step length is capped at lr_x2 (normalised step) and the
                % step size at lr_x1.
                x = x - min([(lr_x2 / v_norm), lr_x1]) * v;
            end
        end
        
        if mod(epoch, q) == 0
            v = grad_x(A, x, y, b0, n);
            u = grad_y(A, x, y, b0, n);
        end
        
        y_old = y;
        for loop = 1:K
            % Recursive mini-batch correction of both estimators: add the
            % batch gradient at the new point, subtract it at the old one.
            sample = randi(n, 1, bs);
            v = v + grad_x(A(:, :, sample), x, y(sample), b0(sample), n) ...
                - grad_x(A(:, :, sample), x_old, y_old(sample), b0(sample), n);
            % NOTE(review): randi samples with replacement; for a repeated
            % index only one assignment into u(sample) takes effect.
            u(sample) = u(sample) + grad_y(A(:, :, sample), x, y(sample), b0(sample), n) ...
                - grad_y(A(:, :, sample), x_old, y_old(sample), b0(sample), n);
            x_old = x;
            y_old = y;
            y = projection(y + lr_y * u, n);
        end
        
        if mod(epoch, freq) == 0
            val = fval(A, x, b0, n);
            nabla = grad_phi(A, x, b0, n);
            grad_norm = sqrt(trace(nabla.' * nabla));
            err = x * x.' - M;
            dist = trace(err * err.') / Mscale;
            result(1 + epoch / freq, :) = [epoch, val, grad_norm, dist];
            if epoch + freq > T
                % Last recorded epoch: evaluate the smallest Hessian eigenvalue.
                lambda = lambda_hessian(A, x, b0, n, d, r);
            end
        end
    end
    
    savefile = './result/Pullback.mat';
    outdir = fileparts(savefile);
    if ~exist(outdir, 'dir')
        mkdir(outdir);   % save() fails when the target folder is missing
    end
    save(savefile, 'result', 'lambda');
end

function [yp] = projection(y, n)
% PROJECTION  Sort-based Euclidean projection of y onto the probability simplex.
%
%   y  : vector to project.
%   n  : number of sorted entries that are scanned.
%   yp : max(y - shift, 0), where shift is the threshold found below.

    z = sort(y, 'descend');
    running = 0;   % partial sum z(1) + ... + z(k)
    shift = 0;
    k = 0;
    while k < n
        k = k + 1;
        running = running + z(k);
        shift = (running - 1) / k;
        % Stop at the first k whose threshold separates z(k) from z(k + 1);
        % if none does, the threshold of the full prefix is used.
        if k < n && z(k + 1) <= shift && shift < z(k)
            break
        end
    end
    yp = max(y - shift, 0);
end

function [b] = sensing(A, n)
% SENSING  Trace of each of the first n pages of a 3-D array.
%
%   A : array whose pages A(:, :, k) are square matrices.
%   n : number of pages to process.
%   b : n-by-1 column with b(k) = trace(A(:, :, k)).

    b = zeros(n, 1);
    for page = 1:n
        slice = A(:, :, page);
        b(page) = trace(slice);
    end
end

function [val] = fval(A, x, b0, n)
% FVAL  Weighted loss at x using the weights returned by projection.
%
%   A  : d-by-d-by-n array of sensing matrices.
%   x  : current iterate.
%   b0 : n-by-1 reference measurements.
%   n  : number of measurements.
%   val: ystar.' * loss, where loss(i) = 0.5 * (trace(A_i * x * x.') - b0(i))^2
%        and ystar = projection(loss + 1/n, n).

    pred = sensing(pagemtimes(A, x * x.'), n);
    res = pred - b0;
    loss = 0.5 * res .* res;
    uniform = ones(n, 1) / n;
    ystar = projection(loss + uniform, n);
    val = ystar.' * loss;
end

function [g] = grad_x(A, x, y, b0, n)
% GRAD_X  Mini-batch gradient estimate with respect to x.
%
%   A  : d-by-d-by-m array holding the sensing matrices of the batch.
%   x  : current iterate; the result has the same size.
%   y  : m-by-1 weights of the batch entries.
%   b0 : m-by-1 reference measurements of the batch (m = length(b0)).
%   n  : total number of measurements; the batch sum is scaled by n / m.
%   g  : (n / m) * sum_i 2 * y(i) * (trace(A_i * x * x.') - b0(i)) * S_i * x,
%        with S_i = (A_i + A_i.') / 2.

    m = length(b0);
    pred = sensing(pagemtimes(A, x * x.'), m);
    weights = y .* (pred - b0);
    symA = (A + pagetranspose(A)) / 2;
    % Move the page index to the front so each row holds one flattened S_i * x.
    Sx = permute(pagemtimes(symA, x), [3, 1, 2]);
    flat = reshape(Sx, m, [], 1);
    g = 2 * transpose(weights) * flat;
    g = reshape(n * g / m, size(x));
end

function [g] = grad_y(A, x, y, b0, n)
% GRAD_Y  Mini-batch gradient estimate with respect to y.
%
%   A  : d-by-d-by-m array holding the sensing matrices of the batch.
%   x  : current iterate.
%   y  : m-by-1 weights of the batch entries.
%   b0 : m-by-1 reference measurements of the batch (m = length(b0)).
%   n  : total number of measurements.
%   g  : n * (loss + 1/n - y) / m, with
%        loss(i) = 0.5 * (trace(A_i * x * x.') - b0(i))^2.

    m = length(b0);
    pred = sensing(pagemtimes(A, x * x.'), m);
    res = pred - b0;
    target = 0.5 * res .* res + ones(m, 1) / n;
    g = n * (target - y) / m;
end

function [g] = grad_phi(A, x, b0, n)
% GRAD_PHI  Full gradient with respect to x at the weights from projection.
%
%   A  : d-by-d-by-n array of sensing matrices.
%   x  : current iterate; the result has the same size.
%   b0 : n-by-1 reference measurements.
%   n  : number of measurements.
%   g  : sum_i 2 * ystar(i) * (trace(A_i * x * x.') - b0(i)) * S_i * x, with
%        S_i = (A_i + A_i.') / 2 and ystar = projection(loss + 1/n, n).

    pred = sensing(pagemtimes(A, x * x.'), n);
    res = pred - b0;
    loss = 0.5 * res .* res;
    ystar = projection(loss + ones(n, 1) / n, n);
    symA = (A + pagetranspose(A)) / 2;
    % Move the page index to the front so each row holds one flattened S_i * x.
    Sx = permute(pagemtimes(symA, x), [3, 1, 2]);
    flat = reshape(Sx, n, [], 1);
    g = 2 * transpose(ystar .* res) * flat;
    g = reshape(g, size(x));
end

function [l] = lambda_hessian(A, x, b0, n, d, r)
% LAMBDA_HESSIAN  Smallest eigenvalue of the Hessian in x at fixed weights.
%
%   A  : d-by-d-by-n array of sensing matrices.
%   x  : d-by-r iterate.
%   b0 : n-by-1 reference measurements.
%   n, d, r : problem dimensions.
%   l  : min(eig(H)) for the (d*r)-by-(d*r) Hessian H of
%        sum_i ystar(i) * 0.5 * res(i)^2 with respect to x(:), where
%        res(i) = trace(A_i * x * x.') - b0(i) and
%        ystar = projection(loss + 1/n, n) is held fixed.
%
%   With S_i = (A_i + A_i.') / 2 the residual gradient is vec(2 * S_i * x)
%   (the same factor used in grad_phi), hence
%       H = sum_i ystar(i) * ( J_i * J_i.' + res(i) * 2 * kron(eye(r), S_i) ),
%   J_i = vec(2 * S_i * x).  The first term previously carried a factor 2
%   where this derivation gives 4.

    C = (A + pagetranspose(A)) / 2;
    b = sensing(pagemtimes(A, x * x.'), n);
    res = b - b0;
    dist = 0.5 * res .* res;
    ystar = projection(dist + ones(n, 1) / n, n);

    % Column i of J is the gradient of residual i with respect to x(:).
    J = 2 * reshape(pagemtimes(C, x), [], n, 1);
    h1 = (ystar.' .* J) * J.';

    % Curvature of the residuals: r diagonal copies of
    % 2 * sum_i ystar(i) * res(i) * S_i.
    h2 = transpose(ystar .* res) * reshape(permute(C, [3, 1, 2]), n, [], 1);
    h2 = 2 * reshape(h2, d, d);
    hc = repmat({h2}, 1, r);
    h2 = blkdiag(hc{:});

    h = h1 + h2;
    % Remove rounding asymmetry so eig returns real eigenvalues; min() on
    % complex values would compare magnitudes.
    h = (h + h.') / 2;
    l = min(eig(h));
end