using JuMP, Ipopt, BenchmarkTools
using Random

# Number of moves accepted by the most recent `simulated_annealing` call.
# That function resets and updates it; the results printout below reads it.
NumStep = 0

"""
    solve_inner(t1)

Solve the inner problem for a fixed outer parameter `t1` with Ipopt and return
its optimal objective value:

    min  (x1 - 2)^2 + (x2 - 0.2)^2
    s.t. c * x1^2 - x2 <= 0,   -1 <= x1 <= 1,   0 <= x2 <= 0.2
    with c = 5 sin(π √t1) / (1 + t1^2)

Returns `-Inf` when the termination status is neither `OPTIMAL` nor
`LOCALLY_SOLVED`. Because the outer search maximizes, failed solves rank below
any successful one. `t1` must be non-negative because `sqrt(t1)` is taken.
"""
function solve_inner(t1)
    # The coefficient depends only on the outer parameter, so compute it once in
    # Julia. What remains is quadratic in (x1, x2) and needs no nonlinear AD.
    c = 5 * sin(pi * sqrt(t1)) / (1 + t1^2)

    model = Model(Ipopt.Optimizer)
    set_silent(model)

    @variable(model, -1 <= x1 <= 1)
    @variable(model,  0 <= x2 <= 0.2)

    @objective(model, Min, (x1 - 2)^2 + (x2 - 0.2)^2)
    @constraint(model, c * x1^2 - x2 <= 0)

    optimize!(model)
    status = termination_status(model)
    return status in (MOI.OPTIMAL, MOI.LOCALLY_SOLVED) ? objective_value(model) : -Inf
end

"""
    fun(params)

Outer objective: solve the inner problem at `t1 = first(params)` and return the
value produced by `solve_inner`. Any real-valued vector is accepted.
"""
fun(params::AbstractVector{<:Real}) = solve_inner(first(params))

"""
    simulated_annealing(obj, lower, upper; max_iters=10_000, T0=1.0, α=0.995,
                        step_frac=0.1, rng=Random.default_rng())

Maximize `obj` over the box `lower .<= x .<= upper` by simulated annealing.

The search starts from a uniformly random point in the box. Each iteration
proposes a candidate by shifting every coordinate by a uniform offset in
`±step_frac/2` times that coordinate's box width, then clamps it back into the
box. The candidate is accepted if it improves the objective. Otherwise it is
accepted with probability `exp(Δ/T)`. The temperature `T` starts at `T0` and is
multiplied by `α` after every iteration.

The global `NumStep` is reset to 0 on entry. It is set to the number of
accepted moves when the function returns or throws.

# Arguments
- `obj`: callable taking a vector and returning a real score (higher is better).
- `lower`, `upper`: box bounds of equal length with `lower .<= upper`.

# Keywords
- `max_iters`: number of candidate proposals.
- `T0`: initial temperature.
- `α`: geometric cooling factor applied once per iteration.
- `step_frac`: proposal width as a fraction of each coordinate's box width.
- `rng`: random number generator. Pass a seeded one for reproducible runs.
  The default draws from the global RNG, as `rand()` does.

# Returns
`(best_val, best)`: the best objective value seen and a copy of the point
that produced it.

# Throws
- `DimensionMismatch` if `lower` and `upper` differ in length.
- `ArgumentError` if some `lower[i] > upper[i]`.
"""
function simulated_annealing(obj, lower::AbstractVector{<:Real},
                             upper::AbstractVector{<:Real};
                             max_iters::Integer = 10_000, T0::Real = 1.0,
                             α::Real = 0.995, step_frac::Real = 0.1,
                             rng::AbstractRNG = Random.default_rng())
    # Without this check a length-1 bound would silently broadcast against a
    # longer one.
    length(lower) == length(upper) || throw(DimensionMismatch(
        "lower and upper must have equal length, got $(length(lower)) and $(length(upper))"))
    all(lower .<= upper) ||
        throw(ArgumentError("every lower bound must be <= its upper bound"))

    n = length(lower)
    span = upper .- lower
    current = lower .+ rand(rng, n) .* span
    current_val = obj(current)
    best, best_val = copy(current), current_val
    # Promote once so `T` keeps a single concrete type through the loop.
    T, cool = promote(float(T0), α)
    accepted = 0
    global NumStep = 0

    try
        for _ in 1:max_iters
            candidate = clamp.(current .+ (rand(rng, n) .- 0.5) .* span .* step_frac,
                               lower, upper)
            cand_val = obj(candidate)
            Δ = cand_val - current_val
            # Metropolis rule: always take improvements; take worse moves with
            # probability exp(Δ/T), which shrinks as T cools.
            if Δ > 0 || exp(Δ / T) > rand(rng)
                current, current_val = candidate, cand_val
                accepted += 1
                if current_val > best_val
                    best, best_val = copy(current), current_val
                end
            end
            T *= cool
        end
    finally
        # Publish the count even if `obj` throws or the run is interrupted.
        global NumStep = accepted
    end

    return best_val, best
end

# Box for the single outer parameter t, which `solve_inner` passes to sqrt, so it
# must stay non-negative. `const` gives `run_optimization` concretely typed globals.
const lower_bounds = [1e-6]
const upper_bounds = [8.0]

"""
    run_optimization()

Run the outer simulated-annealing search with `fun` as the objective over
`[lower_bounds, upper_bounds]`, using this script's fixed settings. Returns the
`(best_value, best_params)` tuple from `simulated_annealing`.
"""
function run_optimization()
    settings = (max_iters = 10_000, T0 = 1.0, α = 0.995)
    return simulated_annealing(fun, lower_bounds, upper_bounds; settings...)
end

# Benchmark + run.
# Each `run_optimization` call performs 10_000 Ipopt solves, so it is slow.
# Fixing `evals=1` stops `@benchmark` from running its evals-per-sample tuning
# pass, which would otherwise cost at least one extra full optimization. With the
# default time budget there will likely be very few samples, so treat the
# statistics below as rough.
benchmark_result = @benchmark run_optimization() evals=1

# Seed the global RNG so the reported optimum is reproducible between runs.
Random.seed!(1234)
res_val, res_params = run_optimization()

println("Best value: ", res_val)
println("Optimal t: ", res_params)
println("Accepted steps: ", NumStep)

println("\nBenchmark:")
println("Median time: ", median(benchmark_result.times)/1e6, " ms")
println("Mean time:   ", mean(benchmark_result.times)/1e6, " ms")
println("Memory:      ", benchmark_result.memory, " bytes")
println("Allocs:      ", benchmark_result.allocs)
display(benchmark_result)
