using JuMP, Ipopt, MathOptInterface
const MOI = MathOptInterface
using Random
using BenchmarkTools  # make sure you have added BenchmarkTools

# Number of proposals accepted during the most recent `simulated_annealing` call.
# That function resets this value and records its acceptance count here.
NumStep = 0

# Problem parameters
const nt = 2          # number of t‐variables
const np = 4          # degree of polynomial
# Precompute the list of (i,j) pairs for p[i,j], in the same order we will pack them into a vector
const coeff_indices = [(i, j) for i in 0:np for j in 0:i]
const ncoeff = length(coeff_indices)  # should be (np+1)*(np+2)/2 = 15

"""
    solve_inner(pvec::Vector{Float64})

Estimate the worst-case approximation error of the polynomial with coefficient
vector `pvec`. The vector has length `ncoeff` and is ordered like `coeff_indices`.

    d(pvec) = max_{1 ≤ t1, t2 ≤ 2} |t1^t2 − P(t1, t2; pvec)|
    P(t1, t2; pvec) = Σ_k pvec[k] * t1^j * t2^(i − j),   (i, j) = coeff_indices[k]

The absolute value is non-smooth. It is therefore split into two smooth
subproblems, `max e(t)` and `max −e(t)` with `e(t) = t1^t2 − P(t1, t2; pvec)`.
The larger of the two optima is returned.

Ipopt is a local solver and these subproblems are nonconvex in general. The
result is therefore a local estimate: it never exceeds the true maximum, but
it may fall short of it.

Returns `Inf` when either subproblem fails to reach an (locally) optimal
termination status. A failed solve is thus scored as the worst possible error.

Throws `DimensionMismatch` if `length(pvec) != ncoeff`.
"""
function solve_inner(pvec::Vector{Float64})
    length(pvec) == ncoeff ||
        throw(DimensionMismatch("expected $ncoeff coefficients, got $(length(pvec))"))

    # max |e| == max(max e, max −e); each half is a smooth maximisation.
    overshoot = _max_signed_error(pvec, 1.0)
    undershoot = _max_signed_error(pvec, -1.0)
    return max(overshoot, undershoot)
end

# Maximise `direction * (t1^t2 − P(t1, t2; pvec))` over [1, 2]^2 with Ipopt.
# Returns the optimal objective value, or `Inf` if Ipopt reports no optimum.
function _max_signed_error(pvec::Vector{Float64}, direction::Float64)
    model = Model(Ipopt.Optimizer)
    set_silent(model)

    # Start from the centre of the box.
    @variable(model, 1.0 <= t1 <= 2.0, start = 1.5)
    @variable(model, 1.0 <= t2 <= 2.0, start = 1.5)

    # Exponents of t1 and t2 for each coefficient, precomputed so the nonlinear
    # macro only has to index plain data vectors.
    t1_pow = [j for (i, j) in coeff_indices]
    t2_pow = [i - j for (i, j) in coeff_indices]

    @NLobjective(model, Max,
        direction * (
            t1^t2 - sum(pvec[k] * t1^(t1_pow[k]) * t2^(t2_pow[k]) for k in 1:ncoeff)
        )
    )

    optimize!(model)
    status = termination_status(model)
    if status == MOI.OPTIMAL || status == MOI.LOCALLY_SOLVED
        return objective_value(model)
    end
    # A failed solve must look like a huge error, never like a small one.
    return Inf
end

"""
    fun(pvec)

Objective passed to `simulated_annealing`. That routine only maximises, so the
worst-case error from `solve_inner(pvec)` is negated: maximising `fun` is the
same as minimising the error `d`.
"""
function fun(pvec)
    worst_error = solve_inner(pvec)
    return -worst_error
end

"""
    simulated_annealing(obj, lower, upper; max_iters=10_000, T0=1.0, α=0.995,
                        step_frac=0.1, rng=Random.default_rng())

Maximise `obj` over the box `[lower, upper]` with a basic simulated-annealing loop.
Returns `(best_obj, best_params)`.

# Algorithm
- Start from a uniform random point in the box.
- At each iteration, perturb every coordinate by a uniform amount in
  `±step_frac/2 * (upper - lower)`, then clamp the result back into the box.
- Accept an improvement unconditionally, and a worsening move of size `Δ < 0`
  with probability `exp(Δ / T)`.
- The temperature starts at `T0` and is multiplied by `α` after every iteration.

# Keywords
- `max_iters`: number of proposals (≥ 0).
- `T0`: initial temperature (> 0).
- `α`: geometric cooling factor, in (0, 1].
- `step_frac`: proposal width as a fraction of the box width.
- `rng`: random number generator; pass a seeded one for reproducible runs.

# Side effects
Resets the global `NumStep` to 0 on entry and sets it to the number of accepted
proposals on return.

# Throws
- `DimensionMismatch` if `lower` and `upper` differ in length.
- `ArgumentError` for `lower > upper` componentwise, `max_iters < 0`, `T0 ≤ 0`,
  or `α ∉ (0, 1]`.
"""
function simulated_annealing(obj, lower::AbstractVector{<:Real}, upper::AbstractVector{<:Real};
                             max_iters::Integer = 10_000,
                             T0::Real = 1.0,
                             α::Real = 0.995,
                             step_frac::Real = 0.1,
                             rng::AbstractRNG = Random.default_rng())
    n = length(lower)
    length(upper) == n || throw(DimensionMismatch(
        "lower and upper must have the same length, got $n and $(length(upper))"))
    all(lower .<= upper) ||
        throw(ArgumentError("every lower bound must be ≤ its upper bound"))
    max_iters >= 0 || throw(ArgumentError("max_iters must be non-negative, got $max_iters"))
    T0 > 0 || throw(ArgumentError("T0 must be positive, got $T0"))
    0 < α <= 1 || throw(ArgumentError("α must lie in (0, 1], got $α"))

    global NumStep
    NumStep = 0

    width = upper .- lower
    current = lower .+ rand(rng, n) .* width
    current_val = obj(current)
    best, best_val = copy(current), current_val
    T = float(T0)
    # Count locally; the untyped global is only written once, after the loop.
    naccepted = 0

    for _ in 1:max_iters
        candidate = current .+ (rand(rng, n) .- 0.5) .* width .* step_frac
        candidate = clamp.(candidate, lower, upper)
        candidate_val = obj(candidate)
        Δ = candidate_val - current_val

        # Metropolis criterion (for maximisation).
        if Δ > 0 || exp(Δ / T) > rand(rng)
            current, current_val = candidate, candidate_val
            naccepted += 1
            if current_val > best_val
                best, best_val = copy(current), current_val
            end
        end

        T *= α
    end

    NumStep = naccepted
    return best_val, best
end

# Outer search box for the polynomial coefficients: every coefficient in [-10, 10].
lower_bounds = fill(-10.0, ncoeff)
upper_bounds = fill(10.0, ncoeff)

# Benchmark the whole outer search. One evaluation solves thousands of Ipopt
# problems. `evals=1` fixes one evaluation per sample, which also makes
# BenchmarkTools skip its tuning pass (that pass would rerun the entire search).
# The non-const globals are `$`-interpolated so their lookup is not timed.
benchmark_result = @benchmark simulated_annealing($fun, $lower_bounds, $upper_bounds;
                                                  max_iters=10_000) evals=1

# Seed the global RNG so the run whose result is reported below is reproducible.
Random.seed!(1234)
best_val, best_params = simulated_annealing(fun, lower_bounds, upper_bounds; max_iters=10_000)

# fun returns -d, so negate to recover the worst-case error.
best_error = -best_val

println("Best worst‐case error (d*): ", best_error)
println("Optimal polynomial coefficients p[i,j] in the order (i=0,j=0), (i=1,j=0),(i=1,j=1), ... :")
println(best_params)
println("Number of accepted outer steps: ", NumStep)

println("\nBenchmark Results:")
println("  Median time: ", median(benchmark_result.times) / 1e6, " ms")
println("  Mean time:   ", mean(benchmark_result.times) / 1e6, " ms")
println("  Memory:      ", benchmark_result.memory, " bytes")
println("  Allocations: ", benchmark_result.allocs)
display(benchmark_result)
