using JuMP, GLPK, MathOptInterface
const MOI = MathOptInterface
using Random
using BenchmarkTools # Add BenchmarkTools package
using Statistics     # For median and mean

# Global problem parameters
# Number of inner decision variables x[1], ..., x[nx] (nx in the problem statement).
const nx_param = 50
# The outer vector t has a single component (nt = 1), so only t[1] appears below.

# Accepted-move counter of the most recent simulated-annealing run
# (the analogue of Mathematica's NumStep); reset at the start of each SA run.
NumStep = 0

# Define the inner optimization function.
# Given t1_val (which is t[1] from the problem), solve:
# minimize (sum {i in 1..nx_param} (x[i]/i))
# subject to
# (sum {i in 1..nx_param} (-t1_val^(i-1)*x[i])) + sin(t1_val) <= 0
# x[i] >= 0
"""
    solve_inner(t1_val::Real; nx::Integer = nx_param)

Solve the inner LP for a fixed outer parameter `t1_val` (= t[1]):

    minimize    sum_{i=1..nx} x[i] / i
    subject to  sum_{i=1..nx} (-t1_val^(i-1) * x[i]) + sin(t1_val) <= 0
                x[i] >= 0

Return the optimal objective value when GLPK reports `OPTIMAL` (or `LOCALLY_SOLVED`),
and `Inf` otherwise. The caller `fun` negates the result, so a failed solve looks
like `-Inf` to the maximizing simulated annealing and is avoided.

`nx` defaults to `nx_param`, so `solve_inner(t)` behaves as before.
Throws `ArgumentError` if `nx < 1` (the LP would have no variables).
"""
function solve_inner(t1_val::Real; nx::Integer = nx_param)
    nx >= 1 || throw(ArgumentError("nx must be at least 1, got $nx"))
    # Work in floating point: for integer inputs t^(i-1) could otherwise overflow.
    t = float(t1_val)

    model = Model(GLPK.Optimizer)
    set_silent(model) # Suppress solver output

    @variable(model, x[1:nx] >= 0)
    @objective(model, Min, sum(x[i] / i for i in 1:nx))

    # sum_i (-t^(i-1) * x[i]) + sin(t) <= 0, with sin(t) moved to the right-hand side.
    # `^` binds tighter than unary minus, so -t^(i-1) == -(t^(i-1)), as in the problem.
    # 0.0^0 == 1.0 in Julia, so the i = 1 coefficient is always -1.
    @constraint(model, sum(-t^(i - 1) * x[i] for i in 1:nx) <= -sin(t))

    optimize!(model)
    status = termination_status(model)

    if status == MOI.OPTIMAL || status == MOI.LOCALLY_SOLVED # For LP, OPTIMAL is expected
        return objective_value(model)
    end
    # Infeasible / unbounded / error: report the worst possible value for minimization.
    return Inf
end

# Define the outer function "fun" that SA will call.
# It takes a vector of parameters (here, just t[1]).
# It returns the NEGATIVE of the inner minimization result because
# the provided simulated_annealing function is designed for MAXIMIZATION.
"""
    fun(params::Vector{Float64})

Outer objective handed to `simulated_annealing` (which maximizes). `params[1]` is t[1].
Returns the negated optimum of the inner LP at that t[1], so maximizing `fun`
minimizes the inner optimum; an inner failure (`Inf` from `solve_inner`) appears as `-Inf`.
"""
function fun(params::Vector{Float64})
    # SA clamps every candidate into [0, 1], and sin is defined for every real
    # input anyway, so no extra range guard is needed here.
    t1 = params[1]
    inner_min = solve_inner(t1)
    return -inner_min
end

# A simple simulated annealing algorithm for MAXIMIZATION.
# It searches over a dim-dimensional box defined by lower and upper bounds.
"""
    simulated_annealing(obj_fun, lower, upper; max_iters=10000, T0=1.0, α=0.995,
                        step_size_factor=0.1, rng=Random.default_rng())

Maximize `obj_fun(p)` over the box `lower .<= p .<= upper` by simulated annealing.

Each iteration perturbs the current point by a uniform step in
`±step_size_factor/2 * (upper - lower)` per coordinate, then clamps it into the box.
Improvements are always accepted; a worsening by `Δ < 0` is accepted with probability
`exp(Δ / T)`. The temperature starts at `T0` and is multiplied by `α` after every iteration.

Pass an explicit `rng` (e.g. `MersenneTwister(seed)`) for reproducible runs.

Returns `(best_val, best_sol)`: the largest objective value seen and a copy of the
point that produced it.

Side effect: resets the global `NumStep` to 0 and increments it on every accepted move.

Throws `DimensionMismatch` if `lower` and `upper` differ in length, and `ArgumentError`
if some `lower[i] > upper[i]` or `max_iters < 0`.
"""
function simulated_annealing(obj_fun, lower::AbstractVector{<:Real},
                             upper::AbstractVector{<:Real};
                             max_iters::Integer = 10000, T0::Real = 1.0, α::Real = 0.995,
                             step_size_factor::Real = 0.1, # relative to each parameter's range
                             rng::AbstractRNG = Random.default_rng())
    # Without this check, a length-1 `lower` would silently broadcast against a longer `upper`.
    length(lower) == length(upper) || throw(DimensionMismatch(
        "lower and upper must have the same length, got $(length(lower)) and $(length(upper))"))
    all(lower .<= upper) ||
        throw(ArgumentError("every lower bound must be <= its upper bound"))
    max_iters >= 0 || throw(ArgumentError("max_iters must be non-negative, got $max_iters"))

    dim = length(lower)
    span = upper .- lower # loop-invariant width of the box in each dimension

    # Start at a uniformly random point within the bounds.
    current_sol = lower .+ rand(rng, dim) .* span
    current_val = obj_fun(current_sol)

    # Best solution found so far.
    best_sol = copy(current_sol)
    best_val = current_val

    T = float(T0)
    global NumStep = 0  # Reset counter at the start of each SA run

    for _ in 1:max_iters
        # Random step scaled to the domain, clamped back into the box.
        perturbation = (rand(rng, dim) .- 0.5) .* span .* step_size_factor
        candidate_sol = clamp.(current_sol .+ perturbation, lower, upper)
        candidate_val = obj_fun(candidate_sol)

        # Positive Δ is an improvement (maximization). A candidate worth -Inf gives
        # Δ = -Inf, so exp(Δ / T) == 0 and it is effectively never accepted.
        Δ = candidate_val - current_val

        # Metropolis acceptance criterion.
        if Δ > 0 || (T > 0 && exp(Δ / T) > rand(rng))
            current_sol = candidate_sol
            current_val = candidate_val
            NumStep += 1  # Count each accepted step

            if current_val > best_val
                best_sol = copy(current_sol)
                best_val = current_val
            end
        end

        # Cool down the temperature.
        T *= α
    end
    return best_val, best_sol
end

# --- Main execution flow ---

# Set up the outer optimization bounds for t[1]: t[1] is in [0, 1].
# The 'params' vector for SA will contain one element, representing t[1].
# Box for the single SA parameter t[1]: lower bound 0.0, upper bound 1.0.
lower_bounds_sa, upper_bounds_sa = [0.0], [1.0]

# Define a function that wraps the whole optimization process
"""
    run_optimization(; lower=lower_bounds_sa, upper=upper_bounds_sa, max_iters=10000,
                     T0=1.0, α=0.995, step_size_factor=0.1)

Run simulated annealing on `fun` over the box `[lower, upper]` (by default t[1] ∈ [0, 1])
and return `(min_objective, best_params)`:

- `min_objective` is `-best_val` of the SA run. SA maximizes `fun`, the negated
  inner optimum, so this is the smallest inner-LP optimum it found.
- `best_params` is the parameter vector (here `[t1]`) at which it was found.

All keywords default to the script's original settings, so `run_optimization()`
behaves exactly as before; override them, e.g. for a shorter benchmark run.
Side effect: `simulated_annealing` resets and updates the global `NumStep`.
"""
function run_optimization(; lower::Vector{Float64} = lower_bounds_sa,
                          upper::Vector{Float64} = upper_bounds_sa,
                          max_iters::Int = 10000,
                          T0::Float64 = 1.0,
                          α::Float64 = 0.995,
                          step_size_factor::Float64 = 0.1)
    sa_best_val, sa_best_params = simulated_annealing(
        fun,                                 # Objective function for SA to maximize
        lower,                               # Lower bounds for t[1]
        upper,                               # Upper bounds for t[1]
        max_iters = max_iters,               # Iteration budget
        T0 = T0,                             # Initial temperature
        α = α,                               # Geometric cooling rate
        step_size_factor = step_size_factor  # Relative step size for candidates
    )

    # Undo the negation done in `fun` to recover the minimum of the original objective.
    actual_min_original_obj = -sa_best_val
    return actual_min_original_obj, sa_best_params
end

# --- Benchmarking and Execution ---

# Benchmark the complete SA + LP pipeline.
# BenchmarkTools calls run_optimization several times; each call goes through
# simulated_annealing, which resets NumStep, so after benchmarking NumStep only
# reflects the last benchmarked call. The separate run further down refreshes it,
# making the NumStep printed with the results belong to that final run.
println("Starting benchmark (this might take a moment)...")
# evals/samples are kept deliberately small (the defaults are usually higher),
# because each call performs thousands of LP solves.
benchmark_result = @benchmark run_optimization() evals=5 samples=2 seconds=120

# One ordinary, un-benchmarked run whose outputs are reported below.
println("\nStarting final run for results...")
(res_min_obj, res_optimal_t_params_vec) = run_optimization()

# Report the outcome of the final (non-benchmarked) run.
println("\n--- Results ---")
println("Best value (minimum of original objective sum(x[i]/i)): ", res_min_obj)
if isempty(res_optimal_t_params_vec)
    # Not expected in practice: SA returns a vector with one entry per bound.
    println("Optimal t[1] parameter: Not found") 
else
    t1_best = res_optimal_t_params_vec[1]
    println("Optimal t[1] parameter found by SA: ", t1_best)
end
println("Number of accepted SA steps (for the final run): ", NumStep)

# Print benchmark results.
# `benchmark_result.times` holds the recorded sample times in nanoseconds (always a
# Vector, never `nothing`), so dividing by 1_000_000 converts them to milliseconds.
println("\n--- Benchmark Results ---")
sample_times_ns = benchmark_result.times
if isempty(sample_times_ns)
    println("Median time: N/A (benchmark times possibly empty or not recorded properly)")
    println("Mean time: N/A (benchmark times possibly empty or not recorded properly)")
else
    println("Median time: ", median(sample_times_ns) / 1_000_000, " ms")
    println("Mean time: ", mean(sample_times_ns) / 1_000_000, " ms")
end
println("Memory: ", benchmark_result.memory, " bytes")
println("Allocations: ", benchmark_result.allocs)
println("\nFull benchmark object:")
display(benchmark_result)


# Optional: Verification step
# Rebuild the inner LP from scratch at the t[1] found by SA and check that its
# optimum matches the minimum objective value SA reported.
if isfinite(res_min_obj) && !isempty(res_optimal_t_params_vec)
    optimal_t1_found = res_optimal_t_params_vec[1]
    println("\nVerifying inner solution with optimal t1 = ", optimal_t1_found, "...")

    model_verify = Model(GLPK.Optimizer)
    set_silent(model_verify)
    @variable(model_verify, x_verify[1:nx_param] >= 0)
    @objective(model_verify, Min, sum(x_verify[i] / i for i in 1:nx_param))
    @constraint(model_verify, sum(x_verify[i] * (-optimal_t1_found^(i - 1)) for i in 1:nx_param) <= -sin(optimal_t1_found))

    optimize!(model_verify)
    verify_status = termination_status(model_verify)

    # Accept the same statuses as `solve_inner`, so both solves are judged identically.
    if verify_status == MOI.OPTIMAL || verify_status == MOI.LOCALLY_SOLVED
        verified_obj_val = objective_value(model_verify)
        println("Verified inner objective value: ", verified_obj_val)

        # Compare with res_min_obj from SA (absolute tolerance for floating point noise).
        discrepancy = abs(verified_obj_val - res_min_obj)
        if discrepancy < 1e-6
            println("Verified value matches reported minimum objective. Consistency check passed.")
        else
            println("WARNING: Verified value (", verified_obj_val, ") does not closely match SA's reported minimum (", res_min_obj, ").")
            println("Discrepancy: ", discrepancy)
            println("This can happen due to SA's heuristic nature or floating point precision.")
        end
    else
        println("Verification LP failed. Solver status for t1=", optimal_t1_found, ": ", verify_status)
        println("This might indicate issues with the t1 value found or the LP formulation at that specific t1.")
    end
else
    println("\nSkipping verification step as results were non-finite or t_params vector was empty.")
end

# Sanity checks: solve_inner at the two endpoints of the t[1] range, where the LP
# optimum can be computed by hand. For a single constraint sum_i a_i*x[i] >= b with
# every a_i > 0, b >= 0 and x >= 0, the optimum puts all weight on the variable with
# the smallest cost ratio c_i / a_i, giving b * min_i(c_i / a_i).
# Here c_i = 1/i, a_i = t1^(i-1) and b = sin(t1).
println("\n--- Sanity Checks for solve_inner ---")
obj_at_t_zero = solve_inner(0.0)
println("Objective from solve_inner(0.0): $obj_at_t_zero") # Expected: 0.0 (sin(0) = 0, so x = 0 is feasible)

obj_at_t_one = solve_inner(1.0)
println("Objective from solve_inner(1.0): $obj_at_t_one")
# At t1 = 1 every coefficient -t1^(i-1) equals -1, so the constraint reads
#   -x[1] - x[2] - ... - x[nx_param] <= -sin(1.0),  i.e.  sum(x) >= sin(1.0).
# The cheapest variable per unit of sum(x) is x[nx_param] (cost 1/nx_param). So the minimum
# is x[nx_param] = sin(1.0), all other x[i] = 0, with objective sin(1.0)/nx_param.
# Putting the weight on x[1] instead is also feasible, but costs the larger value sin(1.0).
expected_at_t_one = sin(1.0) / nx_param
println("Expected value sin(1.0)/nx_param for comparison: $expected_at_t_one")
println("solve_inner(1.0) matches expected value: ", isapprox(obj_at_t_one, expected_at_t_one; atol = 1e-6))