using DynamicPolynomials, JuMP, SumOfSquares, Random, Distributions, Combinatorics

# 
#= |

Given x = (x_1, x_2, ..., x_n) and y = (y_1, y_2, ..., y_m), p, q >= 1, 

    and p, q are odd, the function creates a homogeneous psd polynomial 

    b(x, y) of degree (2p + 2q) in the following form:

    b(x, y) = \sum_{i, j = 1, i <= j}^n \sum_{k, l = 1, k <= l}^m 

                \alpha_{ijkl} x_i^p x_j^p y_k^q y_l^q.

    To make sure that b(x, y) is psd, b(x, y) is constructed as follows:

    b(x, y) = (x^p)' M(y)' M(y) x^p,

    where x^p = (x_1^p, x_2^p, ..., x_n^p) and M(y) is a matrix of the 

    following form: 

    M(y)_{ij} = \sum_{k = 1}^m \beta_{ijk} y_k^q.

    Equivalently, we can write

    M(y) = \sum_{k = 1}^m M_k y_k^q,

    where M_k is an n-by-n real matrix for k = 1, 2, ..., m.

=# 

"""
    create_psd_b(x, y, p, q; seed=nothing)

Build a random polynomial of the form `(M(y) * x^p)' * (M(y) * x^p)`, where `x^p` is
the element-wise `p`-th power of `x` and `M(y)` is the sum over `k` of `M_k * y[k]^q`.
Each `M_k` is a `length(x)`-by-`length(x)` matrix with entries drawn from
`Uniform(-1, 1)`.

# Arguments
- `x`, `y`: indexable collections of variables.
- `p`, `q`: exponents applied element-wise to `x` and `y`.

# Keywords
- `seed`: when not `nothing`, the global random number generator is reseeded with this
  value before any matrix is drawn.

# Returns
The single entry of the product `x'.^p * M(y)' * M(y) * x.^p`.

A message saying whether the generator was reseeded is printed to standard output.
"""
function create_psd_b(x, y, p, q; seed::Union{Nothing, Integer}=nothing)
    n, m = length(x), length(y)

    # Optionally reseed the global RNG so the random coefficients are reproducible
    if seed === nothing
        println("No seed provided, using default random number generator.")
    else
        Random.seed!(seed)
        println("Random number generator reseeded with seed: $seed")
    end

    # M(y) = sum_k M_k * y_k^q; the M_k are drawn in the order k = 1, ..., m
    weighted_blocks = map(1:m) do k
        rand(Uniform(-1, 1), n, n) * y[k]^q
    end
    M_y = sum(weighted_blocks)

    # Row and column forms of the element-wise power x^p
    x_pow_row = x' .^ p
    x_pow_col = x .^ p

    # Kept as one chained product so the factors are multiplied exactly as written;
    # the result is a one-element container, which `only` unwraps.
    b = only(x_pow_row * M_y' * M_y * x_pow_col)

    return b
end

# 
#= 

Create a convex function f(x, y) following the theorems in Appendix D of the manuscript, where

    x = (x_1, x_2, ..., x_n) and y = (y_1, y_2, ..., y_m) for n, m >= 1.

    p and q are odd integers >= 1.
    
=# 

# Sum of the products of all unordered pairs (repetition allowed) of entries of `v`.
_pair_product_sum(v) = sum(prod.(collect(with_replacement_combinations(v, 2))))

# Largest absolute value among all coefficients of the polynomials in `polys`.
# A polynomial without coefficients (identically zero) contributes nothing, so the
# result is 0.0 when every polynomial is zero.
function _max_abs_coefficient(polys)
    bound = 0.0
    for poly in polys, c in coefficients(poly)
        bound = max(bound, abs(c))
    end
    return bound
end

"""
    highDegreeConvexPolynomial(x, y, p, q; seed=nothing)

Construct the polynomial `f = b + g + h + w`, where `b` comes from `create_psd_b` and
`g`, `h`, `w` are correction terms scaled by the largest absolute coefficients of the
second derivatives of `b` (per the original file comment, this follows the theorems in
Appendix D of the manuscript and is intended to make `f` convex).

# Arguments
- `x::Vector`, `y::Vector`: non-empty vectors of variables.
- `p::Integer`, `q::Integer`: odd integers greater than or equal to 1.

# Keywords
- `seed`: forwarded to `create_psd_b`.

# Returns
The polynomial `f`. It is also printed to standard output.

# Throws
- `ArgumentError` if `x` or `y` is empty, if `p < 1` or `q < 1`, or if `p` or `q` is even.
"""
function highDegreeConvexPolynomial(x::Vector, y::Vector, p::Integer, q::Integer; seed::Union{Nothing, Integer}=nothing)
    # Get the dimensions of x and y
    n = length(x)
    m = length(y)

    # Check the input
    if n < 1 || m < 1
        throw(ArgumentError("n and m must be greater than or equal to 1."))
    end

    if p < 1 || q < 1
        throw(ArgumentError("p and q must be greater than or equal to 1."))
    end

    if p % 2 == 0 || q % 2 == 0
        throw(ArgumentError("p and q must be odd integers."))
    end

    # Create a psd polynomial b(x, y)
    b = create_psd_b(x, y, p, q; seed=seed)

    # λ: largest absolute coefficient among the mixed second derivatives of b.
    # max(max c, -min c) over the coefficients equals max |c|, so each derivative
    # array only needs to be computed once.
    lambda = _max_abs_coefficient(differentiate(differentiate(b, x), y))

    # g(x, y) is always present
    g = m^2 * lambda / (2 * p * (2*p - 1)) * _pair_product_sum(x.^(2*p)) +
        n^2 * lambda / (2 * q * (2*q - 1)) * _pair_product_sum(y.^(2*q))

    # h and w start as zero polynomials; the x-part is added only when p > 1 and the
    # y-part only when q > 1. The parts have pairwise distinct monomials, so adding
    # them one after the other gives the same coefficients as a single long sum.
    h = zero(b)
    w = zero(b)

    if p > 1
        # μ: largest absolute coefficient among the second derivatives of b in x
        mu = _max_abs_coefficient(differentiate(b, x, 2))
        h += m * (m+1) * mu / (4 * p * (p+1)) * _pair_product_sum(x.^(p+1)) +
             mu / (2 * (p-1) * (p-2)) * _pair_product_sum(x.^(p-1)) * _pair_product_sum(y.^(2*q))
        w += 2 * m^2 * q * mu / ((p-1) * (p-2) * (2*p-3)) * _pair_product_sum(x.^(2*p-2)) +
             n^2 * mu / ((p-2) * (4*q-1)) * _pair_product_sum(y.^(4*q))
    end

    if q > 1
        # ν: largest absolute coefficient among the second derivatives of b in y
        nu = _max_abs_coefficient(differentiate(b, y, 2))
        h += n * (n+1) * nu / (4 * q * (q+1)) * _pair_product_sum(y.^(q+1)) +
             nu / (2 * (q-1) * (q-2)) * _pair_product_sum(y.^(q-1)) * _pair_product_sum(x.^(2*p))
        w += 2 * n^2 * p * nu / ((q-1) * (q-2) * (2*q-3)) * _pair_product_sum(y.^(2*q-2)) +
             m^2 * nu / ((q-2) * (4*p-1)) * _pair_product_sum(x.^(4*p))
    end

    # Construct f(x, y)
    f = b + g + h + w
    println("f(x, y) = ", f)

    return f
end

