using JuMP, GLPK, MathOptInterface
const MOI = MathOptInterface
using Random
using BenchmarkTools # Add BenchmarkTools package
using Statistics     # For mean and median functions for benchmark results

# Number of candidate moves accepted by the most recent simulated-annealing run
# (plays the same role as NumStep in Mathematica).
NumStep = 0

"""
    solve_inner_lp_for_new_problem(t1_val::Float64, nx::Int)

Solve the inner linear program for a fixed outer parameter `t1_val`.

The LP over `x_vars[1:nx] >= 0` is

    minimize    sum(x_vars[i] / i for i in 1:nx)
    subject to  sum(-t1_val^(i-1) * x_vars[i] for i in 1:nx) <= -exp(t1_val)

i.e. `sum(-t1_val^(i-1) * x_vars[i]) + exp(t1_val) <= 0`. A fresh GLPK model is built
and solved on every call.

# Returns
- the optimal objective value when the solver reports `MOI.OPTIMAL`;
- `-Inf` when it reports `MOI.DUAL_INFEASIBLE` (treated as an unbounded primal, which
  is the best possible outcome for an outer minimization);
- `Inf` for `MOI.INFEASIBLE` and for every other status (solver error, limits, ...),
  so that an outer minimizer treats the point as maximally bad.
"""
function solve_inner_lp_for_new_problem(t1_val::Float64, nx::Int)
    model = Model(GLPK.Optimizer)
    set_silent(model)  # suppress solver output to keep the SA console clean

    @variable(model, x_vars[1:nx] >= 0)
    @objective(model, Min, sum(x_vars[i] / i for i in 1:nx))

    # Constraint coefficients -t1_val^(i-1). No special case is needed for
    # t1_val == 0.0: Julia evaluates 0.0^0 as 1.0, so the first coefficient is -1.0
    # and all higher powers vanish.
    constraint_coeffs = Float64[-(t1_val^(i - 1)) for i in 1:nx]
    rhs_val = -exp(t1_val)
    @constraint(model, sum(constraint_coeffs[i] * x_vars[i] for i in 1:nx) <= rhs_val)

    optimize!(model)
    status = termination_status(model)

    if status == MOI.OPTIMAL
        return objective_value(model)
    elseif status == MOI.DUAL_INFEASIBLE
        return -Inf
    else
        # Infeasible, or the solver did not produce a usable answer.
        return Inf
    end
end

"""
    fun_for_sa(params_sa::Vector{Float64}, nx_val::Int)

Objective minimized by the simulated-annealing driver.

Only the first entry of `params_sa` is used; it is the outer parameter `t[1]`.
`nx_val` is forwarded unchanged as the number of inner LP variables. The return value
is whatever `solve_inner_lp_for_new_problem` returns for that pair.
"""
fun_for_sa(params_sa::Vector{Float64}, nx_val::Int) =
    solve_inner_lp_for_new_problem(params_sa[1], nx_val)

"""
    simulated_annealing_for_minimization(obj_function, lower_bounds_sa, upper_bounds_sa,
                                         fixed_args_for_obj; max_iters=10000, T0=1.0,
                                         α=0.995, step_size_factor=0.1,
                                         rng=Random.default_rng())

Minimize `obj_function` over the box `[lower_bounds_sa, upper_bounds_sa]` with simulated
annealing and a geometric cooling schedule (`T` is multiplied by `α` every iteration).

# Arguments
- `obj_function`: called as `obj_function(params, fixed_args_for_obj...)` with `params`
  a vector of search parameters. It may return `Inf` or `-Inf`.
- `lower_bounds_sa`, `upper_bounds_sa`: per-coordinate bounds of the search box.
- `fixed_args_for_obj`: tuple of extra arguments splatted into every objective call.

# Keywords
- `max_iters`: number of candidate moves to try.
- `T0`: initial temperature.
- `α`: cooling rate.
- `step_size_factor`: each coordinate is perturbed uniformly by at most
  `±step_size_factor / 2` times the width of its bound range.
- `rng`: random number generator; pass a seeded one for reproducible runs.

# Returns
`(best_val, best_params)`: the lowest objective value seen and a copy of the parameters
that produced it.

# Side effects
The global `NumStep` is set to zero on entry and to the number of accepted moves on
return.
"""
function simulated_annealing_for_minimization(
    obj_function::Function,
    lower_bounds_sa::AbstractVector{<:Real},
    upper_bounds_sa::AbstractVector{<:Real},
    fixed_args_for_obj::Tuple;
    max_iters::Integer = 10000,
    T0::Real = 1.0,
    α::Real = 0.995,
    step_size_factor::Real = 0.1,
    rng::AbstractRNG = Random.default_rng(),
)
    global NumStep
    NumStep = 0  # reset the published counter at the start of every run

    num_dims = length(lower_bounds_sa)
    bound_range = upper_bounds_sa .- lower_bounds_sa  # loop-invariant box width

    # Start at a uniformly random point inside the box.
    current_params = lower_bounds_sa .+ rand(rng, num_dims) .* bound_range
    current_val = obj_function(current_params, fixed_args_for_obj...)

    # Best point seen so far; starts at Inf / -Inf if the initial evaluation returns that.
    best_params = copy(current_params)
    best_val = current_val

    T = float(T0)
    # Accepted moves are counted locally and published once at the end, so the loop
    # does not touch the untyped global on every acceptance.
    accepted = 0

    for _ in 1:max_iters
        # Propose a neighbour and pull it back into the box.
        perturbation = (rand(rng, num_dims) .- 0.5) .* bound_range .* step_size_factor
        candidate_params = clamp.(
            current_params .+ perturbation, lower_bounds_sa, upper_bounds_sa
        )
        candidate_val = obj_function(candidate_params, fixed_args_for_obj...)

        delta_obj = candidate_val - current_val

        # Metropolis acceptance:
        #  * strict improvements are always taken;
        #  * otherwise accept with probability exp(-delta_obj / T), but only while T is
        #    above 1e-9 and both values are finite (keeps exp() away from Inf/NaN);
        #  * leaving Inf for a finite value, or reaching -Inf from anything other than
        #    -Inf, is always taken.
        # When both values are the same infinity, delta_obj is NaN, every test below
        # is false, and the move is rejected.
        if delta_obj < 0 ||
           (T > 1e-9 &&
            isfinite(current_val) && isfinite(candidate_val) &&
            rand(rng) < exp(-delta_obj / T)) ||
           (current_val == Inf && isfinite(candidate_val)) ||
           (candidate_val == -Inf && current_val != -Inf)

            current_params = candidate_params
            current_val = candidate_val
            accepted += 1

            if current_val < best_val
                best_val = current_val
                best_params = copy(current_params)
            end
        end

        T *= α  # geometric cooling
    end

    NumStep = accepted
    return best_val, best_params
end

# --- Main part of the script ---

# Problem-specific constants.
# Number of inner LP variables x[1..nx]; passed to the SA objective as its fixed argument.
const PROBLEM_NX = 50
# nt = 1 for this problem: the outer search parameter is just t[1], so no separate
# constant is defined for it.

"""
    run_full_optimization()

Run one complete simulated-annealing search for the outer parameter `t[1]`.

The search box is `[0, 1]`, the objective is `fun_for_sa`, and `PROBLEM_NX` is passed to
it as the only fixed argument. The annealing settings are `max_iters=10000`, `T0=1.0`
and `α=0.995`.

Returns the pair produced by `simulated_annealing_for_minimization`: the best objective
value found and the parameter vector that achieved it.
"""
function run_full_optimization()
    # Closed interval searched for t[1].
    t_lower, t_upper = [0.0], [1.0]

    # Extra arguments for fun_for_sa travel as a tuple; here that is just nx.
    extra_args = (PROBLEM_NX,)

    return simulated_annealing_for_minimization(
        fun_for_sa, t_lower, t_upper, extra_args;
        max_iters=10000, T0=1.0, α=0.995,
    )
end

# --- Benchmarking and Execution ---

# Time the whole optimization with BenchmarkTools. Every sample is a complete SA run of
# 10000 iterations, so this step can take a while; the iteration count is deliberately
# left as-is so timings stay comparable with the original methodology.
println("Starting benchmark for the new problem... (nx=$PROBLEM_NX, SA iters=10000)")
benchmark_result = @benchmark run_full_optimization()

# A separate run supplies the values reported below. Each SA call resets NumStep, so
# the counter printed later belongs to this run only.
println("\nStarting final run to get results for printing...")
final_value, final_t_params = run_full_optimization()

# --- Print Results ---
println("\n--- Results for the New Optimization Problem ---")
if isinf(final_value)
    if final_value < 0
        println("Best value (minimum of objective): -Inf")
        println("This indicates the inner LP was unbounded for the found optimal t[1].")
    else
        println("Best value (minimum of objective): Inf")
        println("This indicates the inner LP was likely always infeasible, or no feasible solution was found.")
    end
else
    println("Best value (minimum of objective): ", final_value)
end
# final_t_params is a vector whose first entry is the best t[1] found.
println("Optimal SA parameter t[1]: ", final_t_params[1])
println("Number of accepted SA steps in the final run: ", NumStep)

# --- Benchmark summary ---
println("\nBenchmark Results (summary):")
# Sample times are in nanoseconds; dividing by 1_000_000 reports milliseconds.
for (label, statistic) in (("Median time: ", median), ("Mean time: ", mean))
    println(label, statistic(benchmark_result.times) / 1_000_000, " ms")
end
println("Memory allocated: ", BenchmarkTools.prettymemory(benchmark_result.memory))
println("Allocations: ", benchmark_result.allocs)

# For the full benchmark report, call:
# display(benchmark_result)