
using DynamicPolynomials, JuMP, SumOfSquares, MosekTools, LinearAlgebra

# 
#= 
Check SOS-concavity (with certificate degree at most d) of the polynomial p
with respect to the variables x over the joint strategy space defined by the
inequality constraints $g_j \geq 0$ for $g_j \in gs$ and the equality
constraints $h_j = 0$ for $h_j \in hs$.
=# 
"""
    checkSOSConcave(x, p, gs, hs; specify_degree=false, d=2, specify_basis=false, basis="default")

Build and solve a sum-of-squares program that bounds the curvature of the
polynomial `p` in the variables `x` over the set
`Sx = {x : g ≥ 0 for g in gs, h == 0 for h in hs}`.

With `H(x)` the Hessian of `p` and `z` an auxiliary vector on the unit sphere,
the program is

    minimize t  subject to  t - zᵀ H(x) z ≥ 0  on  Sx ∩ {z : zᵀz = 1}

with nonnegativity certified by SOS. Any feasible `t` is an upper bound on
`zᵀ H(x) z` over that domain. An optimal value `t ≤ 0` (up to solver tolerance)
therefore certifies that `H` is negative semidefinite on `Sx`, i.e. `p` is concave there.

# Arguments
- `x`: vector of polynomial variables.
- `p`: polynomial in `x`.
- `gs`: polynomials `g` giving constraints `g ≥ 0`; may be empty (no constraint).
- `hs`: polynomials `h` giving constraints `h == 0`; may be empty (no constraint).

# Keywords
- `specify_degree`: when `true`, pass `maxdegree = d` to the SOS constraint.
  Otherwise SumOfSquares chooses the certificate degree.
- `d`: maximum certificate degree. Only used when `specify_degree` is `true`,
  and then it must be at least 2.
- `specify_basis`, `basis`: currently unused; accepted for interface compatibility.

# Returns
The optimized JuMP model. Read the bound with `objective_value(model)` after
checking `termination_status(model)`. The SOS constraint is registered as `model[:cref]`.

# Throws
- `ArgumentError` if `specify_degree` is `true` and `d < 2`.
"""
function checkSOSConcave(x, p, gs, hs; specify_degree=false, d::Integer = 2,
                         specify_basis=false, basis="default")
    # `d` is only consulted when the caller opts into an explicit degree,
    # so it is only validated in that case.
    if specify_degree && d < 2
        throw(ArgumentError("Degree must be at least 2."))
    end

    m = size(x, 1)
    hessian = differentiate(differentiate(p, x), x)

    # Feasible set for x: equality constraints (algebraic set) intersected with
    # inequality constraints (basic semialgebraic set). An empty list means "unconstrained".
    Sh = isempty(hs) ? FullSpace() : algebraic_set(hs)
    Sg = isempty(gs) ? FullSpace() : basic_semialgebraic_set(FullSpace(), gs)
    Sx = intersect(Sh, Sg)

    # Restrict the auxiliary direction z to the unit sphere so that the
    # quadratic form z'Hz is normalized.
    @polyvar z[1:m]
    Sz = algebraic_set([1 - z' * z])
    Sxz = intersect(Sx, Sz)

    # Minimize t such that t - z'Hz is SOS-certified nonnegative on Sxz.
    model = SOSModel(Mosek.Optimizer)
    @variable(model, t)
    @objective(model, Min, t)
    if specify_degree
        @constraint(model, cref, -z' * hessian * z + t >= 0, domain = Sxz, maxdegree = d)
    else
        @constraint(model, cref, -z' * hessian * z + t >= 0, domain = Sxz)
    end
    optimize!(model)

    return model
end

#
#= 
Check strict SOS-monotonicity (with certificate degree at most d) of the
payoff functions ps := [p1, p2, ..., pn], where player i's payoff pi is
differentiated with respect to its own variables xi in xs := [x1, x2, ..., xn].
The check is over the joint strategy space defined by the inequality
constraints $g_j \geq 0$ for $g_j \in gs$ and the equality constraints
$h_j = 0$ for $h_j \in hs$.
=# 
"""
    checkSOSMonotone(xs, ps, gs, hs; specify_degree=false, d=2)

Build and solve a sum-of-squares program that bounds `zᵀ J(x) z`, where `J` is
the Jacobian of the game's pseudo-gradient. The check is over the joint strategy set
`Sx = {x : g ≥ 0 for g in gs, h == 0 for h in hs}`.

The pseudo-gradient stacks `∂ps[i]/∂xs[i]` (each player's payoff differentiated
in its own variables). `J` is its Jacobian with respect to all players'
variables stacked. The program is

    minimize t  subject to  t - zᵀ J(x) z ≥ 0  on  Sx ∩ {z : zᵀz = 1}

with nonnegativity certified by SOS. Any feasible `t` is an upper bound on
`zᵀ J(x) z`, which only depends on the symmetric part of `J`. An optimal value
`t < 0` (up to solver tolerance) certifies `zᵀ J(x) z < 0` for every unit `z`
and every `x` in `Sx`.

# Arguments
- `xs`: vector whose `i`-th entry holds player `i`'s polynomial variable(s).
- `ps`: vector of payoff polynomials, one per player (same length as `xs`).
- `gs`: polynomials `g` giving constraints `g ≥ 0`; may be empty (no constraint).
- `hs`: polynomials `h` giving constraints `h == 0`; may be empty (no constraint).

# Keywords
- `specify_degree`: when `true`, pass `maxdegree = d` to the SOS constraint.
  Otherwise SumOfSquares chooses the certificate degree.
- `d`: maximum certificate degree. Only used when `specify_degree` is `true`,
  and then it must be at least 2.

# Returns
The optimized JuMP model. Read the bound with `objective_value(model)` after
checking `termination_status(model)`. The SOS constraint is registered as `model[:cref]`.

# Throws
- `ArgumentError` if `specify_degree` is `true` and `d < 2`.
- `DimensionMismatch` if `ps` and `xs` have different lengths.
"""
function checkSOSMonotone(xs, ps, gs, hs; specify_degree=false, d::Integer = 2)
    # `d` is only consulted when the caller opts into an explicit degree,
    # so it is only validated in that case.
    if specify_degree && d < 2
        throw(ArgumentError("Degree must be at least 2."))
    end

    # One payoff per player is required. Otherwise the pseudo-gradient would not
    # line up with the stacked variables, and the failure would surface later as
    # an unrelated BoundsError or DimensionMismatch.
    n = length(ps)
    length(xs) == n || throw(DimensionMismatch(
        "expected one payoff per player: got $n payoffs and $(length(xs)) variable groups"))

    X = vcat(xs...)   # all players' variables stacked into one vector
    m = length(X)     # total number of decision variables

    # Pseudo-gradient: each payoff differentiated in its owner's variables.
    vs = [differentiate(p_i, x_i) for (p_i, x_i) in zip(ps, xs)]
    jacobian = differentiate(vcat(vs...), X)

    # Feasible set for x: equality constraints (algebraic set) intersected with
    # inequality constraints (basic semialgebraic set). An empty list means "unconstrained".
    Sh = isempty(hs) ? FullSpace() : algebraic_set(hs)
    Sg = isempty(gs) ? FullSpace() : basic_semialgebraic_set(FullSpace(), gs)
    Sx = intersect(Sh, Sg)

    # Restrict the auxiliary direction z to the unit sphere so that the
    # quadratic form z'Jz is normalized.
    @polyvar z[1:m]
    Sz = algebraic_set([1 - z' * z])
    Sxz = intersect(Sx, Sz)

    # Minimize t such that t - z'Jz is SOS-certified nonnegative on Sxz.
    model = SOSModel(Mosek.Optimizer)
    @variable(model, t)
    @objective(model, Min, t)
    if specify_degree
        @constraint(model, cref, - z' * jacobian * z + t >= 0, domain = Sxz, maxdegree = d)
    else
        @constraint(model, cref, - z' * jacobian * z + t >= 0, domain = Sxz)
    end
    optimize!(model)

    return model
end
