function [] = PRGDA_bilevel()
    % PRGDA_bilevel  Run the PRGDA experiment on a synthetic matrix-sensing problem.
    %
    % Builds a random rank-r target M = us * us.' with n random sensing
    % matrices A(:, :, i) and measurements b0(i) = trace(A(:, :, i) * M).
    % The first 30% of the samples form the validation split and the rest the
    % training split. The factor u (d x r) is split into x (first r-1 columns)
    % and y (last column):
    %   * y is updated K times per epoch with a recursive stochastic gradient
    %     estimate vy computed on the validation split;
    %   * x is updated once per epoch with vx = fx - jx * v, where fx, fy are
    %     training-split gradient estimates, jx comes from jacob(), and v is
    %     driven towards the solution of hy * v = fy by N fixed-point steps.
    % Every q epochs the estimates are recomputed on the full splits.
    % When the norm of vx drops to ds or below, x is perturbed and an
    % "escape phase" of at most Ts steps with step size lr_h is entered.
    %
    % Every freq epochs one row [epoch, fval, gval, dist] is stored in result,
    % which is saved to './result/PRGDA.mat' (the directory is created when
    % missing). No inputs, no outputs.

    % problem setting
    d = 50;
    r = 3;
    n = 20 * d;
    % 30% validation, 70% training. The split sizes are rounded once so that
    % index ranges and randi always receive exact integers.
    n_valid = round(0.3 * n);
    n_train = n - n_valid;
    valid = 1:n_valid;
    train = (n_valid + 1):n;

    us = normrnd(0, 1/d, d, r);
    M = us * us.';
    Mscale = trace(M * M.');   % squared Frobenius norm of M, used to normalise dist
    A = normrnd(0, 1, d, d, n);
    b0 = sensing(pagemtimes(A, M), n);
    u0 = normrnd(0, 1, d, 1);
    % rescale u0 so that its Euclidean norm is 0.99 * max(eig(M))
    alpha = 0.99 * max(eig(M)) / sqrt(u0.' * u0);
    u0 = alpha * u0;

    % parameter setting
    T = 2000;    % number of outer epochs
    bs = 40;     % mini-batch size
    freq = 10;   % record metrics every freq epochs

    % only the first column of u starts non-zero; y and v start at zero
    u = zeros(d, r);
    u(:, 1) = u0;
    x = u(:, 1:(r-1));
    y = u(:, r);
    v = u(:, r);

    x_old = x;
    y_old = y;
    vx = x;
    vy = y;

    result = zeros(1 + floor(T / freq), 4);
    fval = funcval(A(:, :, train), u, b0(train), n_train);
    gval = funcval(A(:, :, valid), u, b0(valid), n_valid);
    err_mat = u * u.' - M;
    dist = trace(err_mat * err_mat.') / Mscale;
    result(1, :) = [0, fval, gval, dist];

    lr_x1 = 0.1;     % upper bound on the x step size
    lr_x2 = 0.001;   % x step length when the normalised step is used
    lr_y = 0.01;     % y step size
    lr_v = 0.1;      % step size of the v iteration
    lr_h = 0.1;      % x step size during the escape phase
    N = 10;          % v iterations per refresh / inner step
    K = 20;          % y updates per epoch
    q = 25;          % epochs between full recomputations

    D = 0.01; % distance in escape phase
    ds = 0.0005; % threshold of gradient
    ra = 0.01; % radius of perturbation
    Ts = 20; % maximum iterations in escape phase

    fval_min = 1;
    gval_min = 1;
    dist_min = 1;

    escape = 0;   % 0: normal phase; k > 0: k-th step of the escape phase

    for epoch = 1:T
        % Calculate vx and vy
        if mod(epoch - 1, q) == 0
            % Full recomputation on the complete validation / training splits.
            % NOTE(review): x_old / y_old are not reset here, so the first
            % inner iteration below still adds gy(x, y) - gy(x_old, y_old)
            % on top of the freshly computed vy, fx, fy; confirm intended.
            [~, gy] = grad(A(:, :, valid), x, y, b0(valid));
            vy = gy;
            [fx, fy] = grad(A(:, :, train), x, y, b0(train));
            hy = hessian(A(:, :, valid), x, y, b0(valid));
            for i = 1:N
                v = v - lr_v * (hy * v - fy);
            end
            jx = jacob(A(:, :, valid), x, y, b0(valid));
            vx = fx - reshape(jx * v, size(x));
        end

        for loop = 1:K
            % recursive update of vy on a validation mini-batch
            sample = randi(n_valid, 1, bs);
            [~, gy] = grad(A(:, :, sample), x, y, b0(sample));
            [~, gy_old] = grad(A(:, :, sample), x_old, y_old, b0(sample));
            vy = vy + gy - gy_old;
            % recursive update of fx, fy on a training mini-batch
            sample = randi(n_train, 1, bs) + n_valid;
            [gx, gy] = grad(A(:, :, sample), x, y, b0(sample));
            [gx_old, gy_old] = grad(A(:, :, sample), x_old, y_old, b0(sample));
            fx = fx + gx - gx_old;
            fy = fy + gy - gy_old;
            % N steps on v, each with a fresh validation mini-batch
            for i = 1:N
                sample = randi(n_valid, 1, bs);
                hy = hessian(A(:, :, sample), x, y, b0(sample));
                v = v - lr_v * (hy * v - fy);
            end
            sample = randi(n_valid, 1, bs);
            jx = jacob(A(:, :, sample), x, y, b0(sample));
            vx = fx - reshape(jx * v, size(x));
            % Update y
            x_old = x;
            y_old = y;
            y = y - lr_y * vy;
        end

        % Update x
        vx_norm = sqrt(trace(vx.' * vx));   % Frobenius norm of vx
        x_old = x;
        if escape > 0
            % escape phase: accum tracks the sum of squared step lengths
            accum_new = accum + (lr_h * vx_norm) ^ 2;
            if accum_new > escape * D
                % budget exceeded: take a shortened step and leave the phase
                lr = sqrt((escape * D - accum)) / vx_norm;
                x = x - lr * vx;
                escape = 0;
            else
                accum = accum_new;
                x = x - lr_h * vx;
                % wraps to 0 (leaving the phase) after Ts - 1 escape steps
                escape = mod(escape + 1, Ts);
            end
        else
            if vx_norm <= ds
                % small vx: perturb x along a random unit-norm direction
                % NOTE(review): rand draws entries in (0, 1), so the
                % direction is never sign-symmetric; confirm randn was not
                % intended.
                pert = rand(d, r - 1);
                pert = pert / sqrt(trace(pert.' * pert));
                x = x + ra * pert;
                escape = escape + 1;
                accum = 0;
            else
                % step size lr_x1, capped so the step length is at most lr_x2
                x = x - min([(lr_x2 / vx_norm), lr_x1]) * vx;
            end
        end

        if mod(epoch, freq) == 0
            u(:, 1:(r-1)) = x;
            u(:, r) = y;
            fval = funcval(A(:, :, train), u, b0(train), n_train);
            gval = funcval(A(:, :, valid), u, b0(valid), n_valid);
            err_mat = u * u.' - M;
            dist = trace(err_mat * err_mat.') / Mscale;
            if epoch >= 1400
                % from epoch 1400 on, the running minima are recorded instead
                % of the raw values
                fval_min = min(fval_min, fval);
                gval_min = min(gval_min, gval);
                dist_min = min(dist_min, dist);
                fval = fval_min;
                gval = gval_min;
                dist = dist_min;
            end
            result(1 + epoch / freq, :) = [epoch, fval, gval, dist];
        end
    end

    savefile = './result/PRGDA.mat';
    % create the output directory first so the finished run is not lost
    outdir = fileparts(savefile);
    if ~isempty(outdir) && ~exist(outdir, 'dir')
        mkdir(outdir);
    end
    save(savefile, 'result');

end

function [b] = sensing(A, n)
    % sensing  Trace of each of the first n pages of a 3-D array.
    %
    %   A  array whose pages A(:, :, i) are passed to trace
    %   n  number of pages to process
    %
    % Returns b, an n-by-1 double column with b(i) = trace(A(:, :, i)).
    page_trace = @(k) trace(A(:, :, k));
    % reshape + double keep the n-by-1 double output also for n == 0 or
    % non-double A
    b = double(reshape(arrayfun(page_trace, 1:n), n, 1));
end

function [dist] = funcval(A, x, b0, n)
    % funcval  Half the sum of squared measurement residuals of the factor x.
    %
    %   A   stack of sensing matrices, one per page
    %   x   factor matrix; the measured matrix is x * x.'
    %   b0  reference measurements
    %   n   number of pages of A to evaluate
    %
    % Returns dist = 0.5 * sum_i (trace(A(:, :, i) * x * x.') - b0(i))^2.
    gram = x * x.';
    residual = sensing(pagemtimes(A, gram), n) - b0;
    dist = 0.5 * sum(residual .* residual);
end

function [gx, gy] = grad(A, x, y, b0)
    % grad  Batch-averaged gradients with respect to the blocks x and y.
    %
    %   A   stack of sensing matrices, one page per sample
    %   x   first factor block
    %   y   second factor block
    %   b0  measurements of the batch; its length is the batch size bs
    %
    % With residuals
    %   res(i) = trace(A_i * x * x.') + trace(A_i * y * y.') - b0(i)
    % the outputs are
    %   gx = (2 / bs) * sum_i res(i) * (A_i + A_i.') * x   (size of x)
    %   gy = (2 / bs) * sum_i res(i) * (A_i + A_i.') * y   (size of y)
    bs = length(b0);
    S = A + pagetranspose(A);   % symmetrised sensing matrices

    res_x = sensing(pagemtimes(A, x * x.'), bs);
    res_y = sensing(pagemtimes(A, y * y.'), bs);
    weights = 2 * transpose(res_x + res_y - b0);   % 1-by-bs row of 2 * res

    % one row per sample: the vectorised (A_i + A_i.') * x, resp. * y
    Sx = reshape(permute(pagemtimes(S, x), [3, 1, 2]), bs, [], 1);
    Sy = reshape(permute(pagemtimes(S, y), [3, 1, 2]), bs, [], 1);

    gx = reshape((weights * Sx) / bs, size(x));
    gy = reshape((weights * Sy) / bs, size(y));
end

function [j] = jacob(A, x, y, b0)
    % jacob  Batch-averaged cross term between the x and y blocks.
    %
    %   A   stack of sensing matrices, one page per sample
    %   x   first factor block
    %   y   second factor block
    %   b0  measurements of the batch; only its length (bs) is used
    %
    % Returns j = (1 / bs) * sum_i vec((A_i + A_i.') * x) * ((A_i + A_i.') * y).',
    % i.e. a numel(x)-by-d matrix when y is a d-by-1 column.
    bs = length(b0);
    S = A + pagetranspose(A);   % symmetrised sensing matrices

    % one row per sample
    Sx = reshape(permute(pagemtimes(S, x), [3, 1, 2]), bs, [], 1);
    Sy = permute(pagemtimes(S, y), [3, 1, 2]);

    j = Sx.' * Sy / bs;
end

function [g] = hessian(A, x, y, b0)
    % hessian  Batch-averaged second-order matrix with respect to y.
    %
    %   A   stack of sensing matrices, one page per sample
    %   x   first factor block
    %   y   second factor block (column vector)
    %   b0  measurements of the batch; its length is the batch size bs
    %
    % With S_i = A_i + A_i.' and res(i) = trace(A_i * (x*x.' + y*y.')) - b0(i):
    %   g = (1 / bs) * ( 2 * sum_i res(i) * S_i + sum_i (S_i*y) * (S_i*y).' )
    %
    % NOTE(review): the residual term carries a factor 2 while the
    % outer-product term does not (jacob has no factor 2 either); confirm
    % this scaling is intended.
    bs = length(b0);
    S = A + pagetranspose(A);   % symmetrised sensing matrices

    res = sensing(pagemtimes(A, x * x.' + y * y.'), bs) - b0;

    % residual-weighted sum of the S_i, computed on the flattened pages
    S_rows = reshape(permute(S, [3, 1, 2]), bs, [], 1);
    g1 = reshape(2 * transpose(res) * S_rows, size(A(:, :, 1)));

    % sum of outer products of the columns S_i * y
    Sy = permute(pagemtimes(S, y), [1, 3, 2]);
    g2 = Sy * Sy.';

    g = (g1 + g2) / bs;
end