using LinearAlgebra, JuMP, Random, Printf, DelimitedFiles, MATLAB
using NLopt, MosekTools

# Optimizer factory used by every JuMP model built in this file (`Model(solver)`
# in matchingCost and solveSDP). "LOG" => 0 is a Mosek parameter — presumably
# it silences Mosek's solver log; confirm against the Mosek parameter docs.
solver = optimizer_with_attributes(Mosek.Optimizer, "LOG" => 0)

"""
    randnsym(n)

Return an `n×n` `Symmetric` matrix whose upper triangle (diagonal included) is
filled with i.i.d. standard normal samples.
"""
randnsym(n) = Symmetric(randn(n, n))

"""
    randnsym_norm(n)

Like `randnsym(n)`, but rescaled to unit Frobenius norm (`norm(A) == 1`).
"""
function randnsym_norm(n)
    A = randnsym(n)
    return A / norm(A)
end

"""
    mean(v)

Arithmetic mean `sum(v) / length(v)`.
"""
mean(v) = sum(v) / length(v)

"""
    eye(n)

`n×n` identity with `Float64` entries, as a `Symmetric`-wrapped `Diagonal`.
"""
eye(n) = Symmetric(Diagonal(ones(n)))


"""
    smat2vec(A)

Flatten the upper triangle (diagonal included) of the `n×n` matrix `A`,
with `n = size(A, 1)`, into a `Vector{Float64}` of length `n(n+1)/2`.
Entries are taken row by row: `A[1,1], A[1,2], …, A[1,n], A[2,2], …, A[n,n]`.
"""
function smat2vec(A)
    n = size(A, 1)
    # Outer generator walks rows, inner walks columns from the diagonal rightwards;
    # the Float64[...] element type converts every entry like the preallocated buffer would.
    return Float64[A[r, c] for r in 1:n for c in r:n]
end

"""
    matchingCost(n, m, A, X0)

Solve an auxiliary SDP that constructs a cost matrix for which `X0` is certified
optimal. It finds a symmetric `C` with `tr(C) = 1` and multipliers `λ[1:m]` such that
`S = C - Σ_k λ[k] * A[k]` is positive semidefinite and `tr(X0 * S) = 0`.
When `b[k] = tr(A[k]*X0)` (as the callers in this file construct it), `S` is a
dual slack certificate for `X0` in `min tr(C*X) s.t. tr(A[k]*X) = b[k], X ⪰ 0`.

Returns `(C, S, status)`:
- `C` is the solved cost rescaled to unit Frobenius norm.
- `S` is the slack matrix exactly as solved.
  NOTE(review): `S` is not rescaled together with `C`.
- `status` is the JuMP termination status; it is not checked here.
"""
function matchingCost(n,m,A,X0)
    model = Model(solver)
    @variable(model, Cm[1:n, 1:n], Symmetric)
    @variable(model, Sm[1:n, 1:n], PSD)
    @variable(model, λ[1:m])

    # Slack definition, entrywise: S = C - Σ λ_k A_k
    @constraint(model, Sm .== Cm - sum(A[k] .* λ[k] for k in 1:m))
    # Complementary slackness with the target solution X0
    @constraint(model, tr(X0 * Sm) == 0)
    # Normalization that excludes the trivial solution C = 0
    @constraint(model, tr(Cm) == 1)

    JuMP.optimize!(model)

    Cval = JuMP.value.(Cm)
    Sval = JuMP.value.(Sm)
    status = termination_status(model)
    return (Cval / norm(Cval), Sval, status)
end

"""
    solveSDP(A, b, C)

Solve the primal SDP

    min tr(C*X)  s.t.  tr(A[i]*X) = b[i]  (i = 1..m),  X ⪰ 0

with the file-level `solver`. The problem sizes are taken from the data:
`n = size(C, 1)` and `m = length(A)`.

Returns `(X, S, status)`: the primal solution values, the dual of the PSD-cone
constraint wrapped as `Symmetric`, and the JuMP termination status.

Throws `DimensionMismatch` if `A` and `b` have different lengths.
"""
function solveSDP(A,b,C)
    # Derive sizes from the inputs instead of relying on globals `n` / `m`.
    n = size(C, 1)
    m = length(A)
    length(b) == m ||
        throw(DimensionMismatch("length(A) = $m but length(b) = $(length(b))"))

    model = Model(solver)
    @variable(model, X[1:n, 1:n], Symmetric)
    @constraint(model, psd, X in PSDCone())
    for i = 1:m
        @constraint(model, tr(A[i]*X) == b[i])
    end
    @objective(model, Min, tr(C*X))
    JuMP.optimize!(model)

    Xval = JuMP.value.(X)
    S = Symmetric(dual(psd))
    return (Xval, S, termination_status(model))
end

"""
    setParams(; i_xtol_rel=1e-7, i_ftol_abs=0, i_maxeval=999999,
                xtol_rel=1e-4, maxeval=999999, constrtol_abs=1e-8)

Collect the NLopt settings read by `localsolve` into a `Dict` keyed by name:

- `i_*` keys configure the inner local optimizer.
- `xtol_rel` and `maxeval` configure the outer optimizer.
- `constrtol_abs` is the tolerance used for every equality constraint.
"""
function setParams(; i_xtol_rel=1e-7, i_ftol_abs=0, i_maxeval=999999, xtol_rel=1e-4, maxeval=999999, constrtol_abs=1e-8)
    return Dict(
        "i_xtol_rel" => i_xtol_rel,
        "i_ftol_abs" => i_ftol_abs,
        "i_maxeval" => i_maxeval,
        "xtol_rel" => xtol_rel,
        "maxeval" => maxeval,
        "constrtol_abs" => constrtol_abs,
    )
end

"""
    localsolve(n, m, p, A, b, C, Y0, params)

Locally solve the factorized problem

    min_Y  tr(Y'*C*Y)   s.t.   tr(Y'*A[i]*Y) = b[i],  i = 1..m,

over `Y` of size `n×p`, starting from `Y0`. Callers in this file form `X = Y*Y'`
from the result. The outer solver is NLopt `:AUGLAG`, and the inner optimizer is
`:LD_TNEWTON_PRECOND`. Tolerances and evaluation budgets are read from `params`
(see `setParams`).

The gradients are formed as `2*C*Y` and `2*A[i]*Y`. These are exact only when `C`
and every `A[i]` are symmetric.

Returns `(Y1, minf, numevals, ret)`. When `params["maxeval"] == 0`, nothing is
optimized and `(Y0, objective at Y0, 0, "None")` is returned.
"""
function localsolve(n,m,p,A,b,C, Y0, params)
    # NLopt objective: returns tr(Y'CY); fills `g` with its gradient when requested.
    function objective(x::Vector, g::Vector)
        Y = reshape(x, n, p)
        length(g) > 0 && (g[:] = 2*vec(C*Y))
        return tr(Y'*C*Y)
    end

    # NLopt vector equality constraint: res[i] = tr(Y'A[i]Y) - b[i]. Column i of
    # `J` (which has n*p rows) receives the gradient of constraint i when requested.
    function eqcons(res::Vector, x::Vector, J::Matrix)
        Y = reshape(x, n, p)
        needgrad = length(J) > 0
        for i = 1:m
            needgrad && (J[:, i] = 2*vec(A[i]*Y))
            res[i] = tr(Y'*A[i]*Y) - b[i]
        end
        return nothing
    end

    # A zero outer budget means: only evaluate the objective at the starting point.
    if params["maxeval"] == 0
        return (Y0, objective(vec(Y0), []), 0, "None")
    end

    local_opt = Opt(:LD_TNEWTON_PRECOND, n*p)
    local_opt.xtol_rel = params["i_xtol_rel"]
    local_opt.ftol_abs = params["i_ftol_abs"]
    local_opt.maxeval = params["i_maxeval"]

    outer = Opt(:AUGLAG, n*p)
    outer.local_optimizer = local_opt
    outer.xtol_rel = params["xtol_rel"]
    outer.maxeval = params["maxeval"]
    outer.min_objective = objective
    equality_constraint!(outer, eqcons, fill(params["constrtol_abs"], m))

    minf, minx, ret = NLopt.optimize(outer, vec(Y0))
    return (reshape(minx, n, p), minf, outer.numevals, ret)
end

"""
    localsolve_Ninits(n, m, p, A, b, C, Ninits, X0, S0, params)

Run `localsolve` at rank `p` from `Ninits` random starts `Y0 = randn(n, p)`.
For each result, form `X1 = Y1*Y1'` and record:

- `gap[i]`: `norm(X1*S0)`, the complementarity with the slack matrix `S0`;
- `errX[i]`: `norm(X0 - X1)`, the distance to the reference solution `X0`;
- `feas[i]`: the 2-norm of the residuals `tr(A[k]*X1) - b[k]`, `k = 1..m`;
- `nevals[i]`: the evaluation count reported by `localsolve`.

Progress is printed to stdout. Returns `(gap, errX, feas, nevals)`.
"""
function localsolve_Ninits(n,m,p,A,b,C,Ninits,X0,S0,params)
    gap = zeros(Ninits)
    errX = zeros(Ninits)
    feas = zeros(Ninits)
    nevals = zeros(Int,Ninits)
    for i = 1:Ninits
        println("i $i")
        Y0 = randn(n,p)
        (Y1,minf,numevals,ret) = localsolve(n,m,p,A,b,C,Y0,params)
        X1 = Symmetric(Y1*Y1')
        # Compute each metric once and reuse it for both the log line and the result.
        gap[i] = norm(X1*S0)
        errX[i] = norm(X0-X1)
        println("|X1*S| = $(gap[i])   |X0-X1| = $(errX[i])")
        # Distinct index `k` so the constraint loop does not shadow the start index `i`.
        feas[i] = norm([tr(A[k]*X1) - b[k] for k in 1:m])
        nevals[i] = numevals
    end
    return (gap,errX,feas,nevals)
end

"""
    experiment1(p0; n=50, Nsdps=100, Ninits=1)

Generate the data for Figure 1.

Each of the `Nsdps` random instances is built as follows (seeded by its index `k`):
- `m = binomial(p0+1, 2)` unit-norm random symmetric constraint matrices `A`;
- the reference solution `X0 = diag(1,…,1,0,…,0)` of rank `p0`;
- `b[i] = tr(A[i]*X0)`;
- a cost `C` and slack `S0` from `matchingCost`.

For every rank `p` in `p0-3:p0+3` (clamped below at 1, since the factor `Y` needs
at least one column), `localsolve_Ninits` is run and the gap, error and
feasibility arrays are stored at `[p index, k, init]`.

Results are written to `"data/exp1_p<p0>.mat"` every 10 instances and after the
last instance. The `data` directory is created if needed.

Throws `ArgumentError` unless `1 <= p0 <= n`.
"""
function experiment1(p0; n=50, Nsdps=100, Ninits=1)
    1 <= p0 <= n || throw(ArgumentError("need 1 <= p0 <= n, got p0 = $p0, n = $n"))
    m = binomial(p0+1,2)
    X0 = Matrix(Diagonal([ones(p0); zeros(n-p0)]))

    Prange = max(1, p0-3):(p0+3)
    Gap = zeros(length(Prange),Nsdps,Ninits)
    ErrX = zeros(length(Prange),Nsdps,Ninits)
    Feas = zeros(length(Prange),Nsdps,Ninits)
    params = setParams(maxeval=Int(2e6),i_maxeval=Int(20e3))

    # write_matfile does not create missing directories.
    mkpath("data")

    for k = 1:Nsdps
        println("\n##########################\nk $k")
        Random.seed!(k)
        A = [randnsym_norm(n) for i in 1:m]
        b = [tr(A[i]*X0) for i in 1:m]
        (C,S0) = matchingCost(n,m,A,X0)

        for (i,p) in enumerate(Prange)
            println("p $p")
            # Seed per (instance, rank) so each rank's starts are reproducible on their own.
            Random.seed!(hash([k,p]))
            (gap,errX,feas,nevals) = localsolve_Ninits(n,m,p,A,b,C,Ninits,X0,S0,params)
            Gap[i,k,:] = gap
            ErrX[i,k,:] = errX
            Feas[i,k,:] = feas
        end
        # Periodic checkpoint, plus a final save when Nsdps is not a multiple of 10.
        if k % 10 == 0 || k == Nsdps
            println("saving data")
            write_matfile("data/exp1_p$(p0).mat",Gap=Gap,ErrX=ErrX,Feas=Feas,p=Prange)
        end
    end
end

"""
    perturbSDP(n, m, A0, X0, sigma)

Perturb each constraint matrix as `A[i] = A0[i] + sigma * (randnsym(n) / n)`.
Then recompute `b[i] = tr(A[i]*X0)`, so that `X0` stays feasible, and rebuild the
cost with `matchingCost`.

Returns `(A, b, C, S, X0)`. The termination status from `matchingCost` is discarded.
"""
function perturbSDP(n,m,A0,X0,sigma)
    scale = 1/n
    noise = [scale*randnsym(n) for _ in 1:m]
    A = A0 + sigma*noise
    b = map(Ai -> tr(Ai*X0), A)
    C, S, _ = matchingCost(n,m,A,X0)
    return (A,b,C,S,X0)
end

"""
    experiment2(p0, k0, SigmaRange; n=50, nrepeats=100)

Generate the data for Figures 2a/2b.

The base instance is seeded with `k0`. For the same `n` and `p0` it is the same
construction as instance `k0` of `experiment1`: unit-norm random symmetric
`A0`, the rank-`p0` diagonal `X0`, and the cost and slack from `matchingCost`.

For every noise level `sigma` in `SigmaRange` and each of `nrepeats` repeats:
- perturb the instance with `perturbSDP` (`sigma == 0` reuses the base instance);
- run `localsolve` at rank `p = p0` from a random start;
- record `Gap = norm(X1*S)` and `ErrX = norm(X - X1)`.

If `SigmaRange` has exactly two entries, each run's inputs (`Y0`, `S0`, `X0`) are
also written to `"data/exp2-full/in_i<i>_j<j>.mat"`. The result matrices are
written to `"data/exp2.mat"` as variables `Gap_p<p0>` and `ErrX_p<p0>`. Output
directories are created if needed.

Throws `ArgumentError` unless `1 <= p0 <= n`.
"""
function experiment2(p0,k0,SigmaRange; n=50, nrepeats=100)
    1 <= p0 <= n || throw(ArgumentError("need 1 <= p0 <= n, got p0 = $p0, n = $n"))
    p = p0
    m = binomial(p0+1,2)
    X0 = Matrix(Diagonal([ones(p0); zeros(n-p0)]))

    Gap = zeros(length(SigmaRange),nrepeats)
    ErrX = zeros(length(SigmaRange),nrepeats)

    Random.seed!(k0)
    A0 = [randnsym_norm(n) for i in 1:m]
    b0 = [tr(A0[i]*X0) for i in 1:m]
    (C0,S0,ret) = matchingCost(n,m,A0,X0)
    params = setParams(maxeval=50000)

    # Per-run inputs are only dumped for the two-level comparison (Figure 2b).
    dumpinputs = length(SigmaRange) == 2
    # write_matfile does not create missing directories; "data/exp2-full" implies "data".
    mkpath(dumpinputs ? "data/exp2-full" : "data")

    for (i,sigma) in enumerate(SigmaRange)
        for j in 1:nrepeats
            println("\n##########################\n")
            println("i $i\nj $j\n")
            flush(stdout)
            Random.seed!(hash([sigma+1,j]))
            Y0 = randn(n,p)
            if sigma == 0
                A=A0; b=b0; C=C0; S=S0; X=X0
            else
                (A,b,C,S,X) = perturbSDP(n,m,A0,X0,sigma)
            end
            if dumpinputs
                fname = "data/exp2-full/in_i$(i)_j$(j).mat"
                write_matfile(fname,Y0=Y0,S0=S,X0=X)
            end
            (Y1,minf,numevals,ret) = localsolve(n,m,p,A,b,C, Y0, params)
            # Flush C-level stdio so NLopt's native output interleaves with Julia's prints.
            Libc.flush_cstdio()
            X1 = Symmetric(Y1*Y1')
            Gap[i,j] = norm(X1*S)
            ErrX[i,j] = norm(X-X1)
        end
    end

    # Variable names carry the rank actually used (previously hard-coded as "p7").
    return write_matfile("data/exp2.mat"; Symbol("Gap_p$(p0)") => Gap, Symbol("ErrX_p$(p0)") => ErrX)
end

### Generate data for Figure 1
### Call the function experiment1(p0) with p0 = 4,7,12
# experiment1(4)
# experiment1(7)
# experiment1(12)

### Generate data for Figures 2a, 2b
### Call the function experiment2(p0,k0,SigmaRange), where p0 = 7,
### k0 is the index of an experiment with a large optimality gap, and
### SigmaRange is the collection of perturbation levels sigma to sweep.
### We use k0=43 below. The README indicates how to obtain such a k0
# experiment2(7,43,0:.02:.2)   # fig2a
# experiment2(7,43,[0,.2])    # fig2b
