using JuMP, GLPK, MathOptInterface
const MOI = MathOptInterface
using Random
using BenchmarkTools # Add BenchmarkTools package

# Module-level tally of accepted simulated-annealing moves (plays the role of NumStep in
# the Mathematica original). `simulated_annealing` zeroes it and increments it per run.
NumStep = 0

# Define the inner optimization function for the new problem.
# Given parameter t1_val (which is t[1]), solve:
# minimize d_var 
# subject to:
#   8*t1_val*p_var - 8*t1_val^2 - 1 - d_var <= 0  (from tcons1)
#  -8*t1_val*p_var + 8*t1_val^2 + 1 - d_var <= 0  (from tcons2)
# and p_var, d_var in [-50, 50] (consistent with original example's x,y,z bounds)

function solve_inner_new_problem(t1_val::Float64)
    # Create a model with GLPK (a linear programming solver)
    model = Model(GLPK.Optimizer)
    set_silent(model) # Suppress solver output

    # Define inner variables (p[1] and d from the new problem)
    # Bounds are kept consistent with the example's x,y,z variables.
    @variable(model, -50 <= p_var <= 50) # Corresponds to p[1]
    @variable(model, -50 <= d_var <= 50) # Corresponds to d
                                         # Note: d >= 0 is implied by problem constraints if p_var is unbounded.
                                         # The problem structure d >= |expr| means optimal d is non-negative.

    # Objective: minimize d_var
    @objective(model, Min, d_var)

    # Coefficients and terms based on the input parameter t1_val
    coeff_p_term = 8 * t1_val
    const_term_t_sq_related = 8 * t1_val^2 # This is 8*t[1]^2 term

    # Constraints derived from the new problem's tcons1 and tcons2:
    # tcons1: -(1-(8*p[1]*t[1]-8*t[1]^2))-d <=0
    # Rewritten: 8*p[1]*t[1] - 8*t[1]^2 - 1 - d <= 0
    # JuMP form: 8*t1_val*p_var - d_var <= 1 + 8*t1_val^2
    @constraint(model, coeff_p_term * p_var - d_var <= 1 + const_term_t_sq_related)

    # tcons2: (1-(8*p[1]*t[1]-8*t[1]^2))-d <=0
    # Rewritten: -8*p[1]*t[1] + 8*t[1]^2 + 1 - d <= 0
    # JuMP form: -8*t1_val*p_var - d_var <= -1 - 8*t1_val^2
    @constraint(model, -coeff_p_term * p_var - d_var <= -1 - const_term_t_sq_related)
    
    optimize!(model)
    status = termination_status(model)

    if status == MOI.OPTIMAL || status == MOI.LOCALLY_SOLVED # GLPK should return OPTIMAL for LPs
        return objective_value(model)
    else
        # If the solver fails, return a very poor value (negative infinity, as SA maximizes).
        # This LP should generally be feasible given the variable bounds.
        return -Inf 
    end
end

# Define the outer function "fun_new_problem" that SA will optimize.
# It takes a vector of parameters (only t1_val for this problem)
# and returns the result of the inner minimization (the value of d).
function fun_new_problem(params::Vector{Float64})
    t1_val = params[1] # The SA parameters vector will contain one element: t[1]
    return solve_inner_new_problem(t1_val)
end

# A simple simulated annealing algorithm for maximization.
# This function is kept identical to the one in the provided example.
function simulated_annealing(obj::Function, lower::Vector{Float64}, upper::Vector{Float64};
                             max_iters::Int = 10000, T0::Float64 = 1.0, α::Float64 = 0.995)
    # Start at a random point within the bounds.
    current = lower .+ rand(length(lower)) .* (upper .- lower)
    current_val = obj(current)
    best = current
    best_val = current_val
    T = T0
    global NumStep = 0  # Reset counter at start of SA run

    for iter in 1:max_iters
        # Generate a candidate by perturbing the current solution.
        candidate = current .+ (rand(length(lower)) .- 0.5) .* (upper .- lower) .* 0.1
        candidate = clamp.(candidate, lower, upper) # Ensure candidate stays within bounds
        
        candidate_val = obj(candidate)
        Δ = candidate_val - current_val

        # Accept candidate if it improves the objective or with a probability.
        if Δ > 0 || (T > 1e-9 && exp(Δ / T) > rand()) # Added T > 1e-9 to prevent issues if T is extremely small
            current = candidate
            current_val = candidate_val
            NumStep += 1  # Count each accepted step

            if current_val > best_val
                best = current
                best_val = current_val
            end
        end

        # Cool down the temperature.
        T *= α
    end
    return best_val, best
end

# Set up the outer optimization bounds for the new problem.
# t[1] is bounded between 0 and 0.75.
lower_bounds_new_problem = [0.0]
upper_bounds_new_problem = [0.75]

# Define a function that wraps the whole optimization process for the new problem.
function run_optimization_new_problem()
    # Run the simulated annealing optimization.
    # SA maximizes fun_new_problem, which returns the minimized 'd' from the inner LP.
    res_val, res_params = simulated_annealing(fun_new_problem, 
                                              lower_bounds_new_problem, 
                                              upper_bounds_new_problem; 
                                              max_iters=10000, T0=1.0, α=0.995)
    return res_val, res_params
end

# --- Main execution ---

# Benchmark the whole SA + inner-LP pipeline. BenchmarkTools calls the function many
# times, so afterwards NumStep holds the count from whichever of those runs finished last.
println("Running benchmark for the new problem...")
benchmark_result_new = @benchmark run_optimization_new_problem()

# One ordinary run supplies the reported optimum and leaves NumStep describing that run.
println("\nRunning optimization once for final results...")
res_val_new, res_params_new = run_optimization_new_problem()

println("\n--- Results for the New Problem ---")
println("Best value (SA maximized (inner minimized d)): ", res_val_new)
if length(res_params_new) != 1
    println("Optimal parameters (t_values): ", res_params_new)
else
    println("Optimal parameter (t1): ", res_params_new[1])
end
println("Number of accepted steps in the last run: ", NumStep)

# Trial sample times are stored in nanoseconds; dividing by 1e6 reports milliseconds.
let trial = benchmark_result_new, ns_per_ms = 1_000_000
    println("\nBenchmark Results (New Problem):")
    println("Median time: ", median(trial.times) / ns_per_ms, " ms")
    println("Mean time: ", mean(trial.times) / ns_per_ms, " ms")
    println("Memory: ", trial.memory, " bytes")
    println("Allocations: ", trial.allocs)
    display(trial)
end