
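% The main entry point. Approximately solves
%   min_x ||Ax-b||_p^p + reg*||Ax-b||_2^2   subject to N*x = vv,
% where the constraint is ignored when N is all zeros.

% Notation: A is n-by-d, so n is the number of rows (residual terms) and
% d is the dimension of x.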
function [final,call]= LewisIRLS2(eps,p,A,b,N,vv,reg)
    % LewisIRLS2  Approximately minimize ||A*x-b||_p^p + reg*||A*x-b||_2^2
    % subject to N*x = vv (the constraint is ignored when N is all zeros).
    %
    %   eps   - absolute-improvement threshold for the outer loop
    %   p     - norm exponent (the lower bound below assumes p >= 2)
    %   A, b  - the residual is A*x - b; A is n-by-d
    %   N, vv - equality constraints N*x = vv
    %   reg   - weight of the squared 2-norm term
    %   final - best iterate found by pNorm_with_initial_vector
    %   call  - solve count accumulated by pNorm_with_initial_vector
    %
    % Starts from the (constrained) least-squares solution. The weights come
    % from LewisWeights(A, T); presumably these are Lewis weights of A and T
    % is an iteration count -- NOTE(review): confirm against LewisWeights.
    n = size(A, 1);
    x = InitialSoln(A, b, N, vv);
    T = round(log(n));
    [w] = LewisWeights(A, T);
    % Alternative (uniform weights): w = ones(n,1) * size(A,2)/n;

    % For p >= 2, ||y||_p^p >= ||y||_2^p / n^(p/2-1). x minimizes the 2-norm
    % residual over the feasible set, so this is a lower bound on the
    % optimal objective value.
    r2 = norm(A*x-b, 2);
    lb = r2^p/n^(p/2-1) + reg*r2^2;
    [final,call] = pNorm_with_initial_vector(eps,p,A,b,N,x,lb,w,reg);

    fprintf("Iteration: %d\n", call);

end

% The main algorithm.
function [final_vec,call] = pNorm_with_initial_vector(eps,p,A,b,N,x,lb,w,reg)
    % pNorm_with_initial_vector  Outer loop. From x, repeatedly:
    %   1. ask LpSolveReg for a direction delta,
    %   2. line-search the step size alpha,
    %   3. move to x - alpha*delta.
    % It keeps going while the objective ||A*x-b||_p^p + reg*||A*x-b||_2^2
    % drops by more than eps per iteration.
    %
    %   eps       - stop once an iteration improves the objective by <= eps
    %   x         - starting point
    %   lb        - lower bound on the optimal objective value
    %   w         - weights forwarded to LpSolveReg
    %   final_vec - best iterate found
    %   call      - sum of the solve counts returned by LpSolveReg
    %               (set to 1 when the starting objective is exactly 0)
    objective = @(y) norm(A*y-b,p)^p + reg*norm(A*y-b,2)^2;

    current = objective(x);
    best = x;
    % A zero objective cannot be improved; return the starting point.
    if current == 0
        disp("Norm = 0");
        call = 1;
        final_vec = x;
        return
    end
    call = 0;

    % Initial scale for LpSolveReg. Since lb <= OPT, 6*(current - lb) is at
    % least 6*(current - OPT), i.e. it over-estimates the initial gap.
    M = (current-lb)*6;
    if M <= 0
        % current <= lb: x already attains the lower bound, so it is optimal.
        % Passing M = 0 to LpSolveReg would divide by zero. A negative M
        % would make its E > M branch run forever.
        final_vec = x;
        return
    end

    while true
        [~, delta, cnt, M] = LpSolveReg(p,A,b,N,M,x,w,reg);
        call = call + cnt;
        alpha = LineSearchObj(A,b,x,p,delta,reg);
        x = x-alpha*delta;
        obj = objective(x);
        % Stop as soon as a step fails to improve by more than eps. That last
        % (non-improving) step is discarded; best holds the previous iterate.
        if ~(obj < current - eps)
            break
        end
        best = x;
        current = obj;
    end
    final_vec = best;
end



function [ E, f, cnt, M ] = LpSolveReg(p,A,b,N,i,x,w,reg)
    % LpSolveReg  Compute a step direction f at the point x.
    %
    % The routine linearizes the objective at x and then calls electric
    % repeatedly with reweighted vectors r.
    %
    %   i   - initial value of the scale M
    %   w   - weights; r is reset to d/n + w
    %   E   - energy returned by the last call to electric
    %   f   - direction from the last call to electric
    %   cnt - number of calls to electric
    %   M   - final scale. Whenever E > M, M is divided by max(2, E/M) and
    %         r is reset.
    %
    % The loop ends when:
    %   - sum(r) reaches 2*(sum(w)+d), or
    %   - max(deltaf) < 11/C, or
    %   - update leaves sum(r) unchanged.
    %
    % C = p*reg^(-1/(p-2)) is Inf for reg = 0 and undefined for p = 2.
    % NOTE(review): the width test uses signed max(deltaf), not abs --
    % confirm this is intended.
    [n, d] = size(A);

    % Quantities that depend only on x (loop-invariant).
    resid = A*x - b;
    k = abs(resid).^(p-2);
    hess = p*(p-1)*k + 2*reg;
    grad = p*k.*resid + 2*reg*resid;

    C = p*reg^(-1/(p-2));
    width = 11*d^(1/3)/C;
    budget = sum(w) + d;
    rInit = ones(n, 1) * d/n + w;

    M = i;
    E = M;
    r = rInit;
    cnt = 0;
    while sum(r) < 2*budget
        g = -grad/M;
        v = 2*budget*hess + M*C^2/2*r;
        u = hess + M*C^2/2*r / sum(r);

        [f, deltaf, E] = electric(v, u, g, A, N);
        cnt = cnt + 1;

        if E > M
            % Energy too large: shrink the scale and start reweighting over.
            M = M / max(2, E/M);
            r = rInit;
            continue;
        end

        if max(deltaf) < 11/C
            break;
        end

        nr = update(r, deltaf, width, C);
        if sum(nr) == sum(r)
            break;
        end
        r = nr;
    end
end
    
function [ r ] = update(cr, cf, S, C)
    % update  Reweighting step used by LpSolveReg.
    %
    %   cr - current weight vector r
    %   cf - values deltaf = A*f from electric (same number of entries as cr)
    %   S  - threshold compared against the smallest entry of cf
    %   C  - scale constant; entries with cf >= 10/C are reweighted
    %   r  - new weights (same shape as cr)
    %
    % If even the smallest entry of cf exceeds S, only that entry's weight
    % is incremented by 1.
    % Otherwise every entry with cf(j) >= 10/C is set to
    % cr(j)*cf(j)^2*C^2/52, and the remaining entries are unchanged.
    % Works for row or column vectors; the previous loop bound size(r,1)
    % visited only the first entry of a row vector.
    % NOTE(review): comparisons use signed cf, not abs(cf) -- confirm intended.
    r = cr;
    [m, i] = min(cf);
    if m > S
        r(i) = r(i) + 1;
        return
    end
    % Work on column copies so mixed row/column inputs cannot broadcast.
    crCol = cr(:);
    cfCol = cf(:);
    sel = cfCol >= 10/C;
    r(sel) = 1/52*crCol(sel).*cfCol(sel).^2*C^2;
end


function [ f ] = InitialSoln(A, b, N, vv)
    % InitialSoln  Least-squares starting point: minimize ||A*f - b||_2
    % subject to N*f = vv.
    %
    % With no constraints (N empty or all zeros) this solves the normal
    % equations (A'A) f = A'b. Otherwise it eliminates the Lagrange
    % multiplier k from the KKT system
    %   (A'A) f = A'b + N'k,   N f = vv.
    % Both branches assume A'A is nonsingular. If it is not, pinv(Calc) can
    % replace the backslash solves.
    Calc = transpose(A)*A;
    t = transpose(A) * b;

    % isempty guards N = [], for which max(abs(N)) == 0 is false and the
    % constrained branch would fail with a dimension error.
    if isempty(N) || nnz(N) == 0
        f = Calc \ t;
    else
        S = N / Calc;                 % N * inv(A'A)
        L = S * transpose(N);         % N * inv(A'A) * N'
        o = vv - S * t;
        k = L \ o;                    % multiplier enforcing N*f = vv
        z = transpose(N) * k + t;
        f = Calc \ z;
    end

end


function [ f, deltaf, E ] = electric(v, u, g, A, N)
    % electric  Weighted least-squares direction used by LpSolveReg.
    %
    % The main steps:
    %   - Solve (A'*diag(v)*A) S = A'*g.
    %   - When N is non-empty and nonzero, project so that N*S = 0 (hence
    %     N*f = 0).
    %   - Normalize f = -S / (g'*A*S), so that g'*(A*f) = -1.
    %
    % Returns:
    %   f      - the normalized direction
    %   deltaf - A*f
    %   E      - energy deltaf' * diag(u) * deltaf
    %
    % The normalization produces NaN/Inf if A'*g is zero.
    n = size(A, 1);
    B = transpose(A) * g;
    Rv = spdiags(v, 0, n, n);
    Calc = transpose(A)*Rv*A;
    % isempty guards N = [], for which max(abs(N)) == 0 is false and the
    % projected branch would fail with a dimension error.
    if isempty(N) || nnz(N) == 0
        S = Calc \ B;
    else
        Z = N / Calc;
        K = Z * transpose(N);
        P = K \ Z;
        % Remove the component of B that would violate N*S = 0.
        V = -transpose(N) * P * B + B;
        S = Calc \ V;
    end
    denom = transpose(B) * S;
    f = - S / denom;
    deltaf = A*f;
    Ru = spdiags(u, 0, n, n);
    E = transpose(deltaf) * Ru * deltaf;
end


% Derivative with respect to scale of
%   ||A(x-scale*delta)-b||_p^p + reg*||A(x-scale*delta)-b||_2^2,
% written in terms of z = A*x - b and w = A*delta (so A, b, x and delta enter
% only through z and w). The next function uses it to find the scale that
% makes the most progress along the step delta from the current solution x.
function obj = GradientScaledObj(scale,p,z,w,reg)
    % GradientScaledObj  d/dscale of sum(|z - scale*w|.^p) + reg*||z - scale*w||_2^2.
    %   Equals -w' * (res .* (p*|res|.^(p-2) + 2*reg)) with res = z - scale*w.
    %   For p < 2, zero entries of res give 0*Inf = NaN.
    res = z - scale*w;
    weighted = res .* (p*abs(res).^(p-2) + 2*reg);
    obj = -(w' * weighted);
end

% This finds, by bracketing and then bisecting on the derivative above, a step
% size alpha that (approximately) minimizes the objective at x - alpha*delta.
function alpha = LineSearchObj(A,b,x,p,delta,reg,tol)
    % LineSearchObj  Step size alpha approximately minimizing
    %   ||A(x-alpha*delta)-b||_p^p + reg*||A(x-alpha*delta)-b||_2^2.
    %
    % Method:
    %   - The objective is convex in alpha (p >= 1, reg >= 0), so its
    %     derivative (GradientScaledObj) is nondecreasing.
    %   - Starting from [-3, 3], the bracket is doubled outward until the
    %     derivative is <= 0 at L and >= 0 at U.
    %   - The bracket is then bisected until U - L <= tol.
    %
    %   tol - optional absolute bracket width (default 1e-1, as before)
    %
    % If A*delta is identically zero the objective does not depend on alpha;
    % this returns 0 instead of failing the bracket asserts.
    if nargin < 7
        tol = 1e-1;
    end
    if ~(tol > 0)
        error('LineSearchObj:tol', 'tol must be positive');
    end

    w = A * delta;
    z = A * x - b;
    if nnz(w) == 0
        alpha = 0;
        return
    end

    L = -3;
    U = 3;
    while GradientScaledObj(U,p,z,w,reg)<0
        L = U;
        U = 2*U;
    end
    while GradientScaledObj(L,p,z,w,reg)>0
        U = L;
        L = 2*L;
    end
    % Non-strict: an exact zero derivative at a bracket end is a valid
    % bracket. These asserts still catch NaN derivatives.
    assert (GradientScaledObj(L,p,z,w,reg) <= 0);
    assert (GradientScaledObj(U,p,z,w,reg) >= 0);
    while abs(U-L)>tol
        mid = (L+U)/2;
        if GradientScaledObj(mid,p,z,w,reg)>0
            U = mid;
        else
            L = mid;
        end
    end
    alpha = (L+U)/2;

end