using JuMP, GLPK, MathOptInterface
const MOI = MathOptInterface
using Random
using BenchmarkTools # Add BenchmarkTools package

# Parameter d of the problem: the inner polynomial S contains every monomial
# t1^(i-j) * t2^j with total degree i in 0:D_PARAM.
const D_PARAM = 2

# Script-level count of accepted simulated-annealing moves (plays the role of NumStep
# in the Mathematica original). `simulated_annealing` resets it and updates it.
NumStep = 0

"""
    solve_inner(t1_val::Float64, t2_val::Float64; d::Integer = D_PARAM, x_bound::Real = 50.0)

Solve the inner linear program for fixed outer parameters `(t1_val, t2_val)` and return
the minimised `dist`.

The inputs are defined as follows:

- `Q = (1 + t1_val)^t2_val`
- `S = Σ_{i=0..d} Σ_{j=0..i} x[i,j] * t1_val^(i-j) * t2_val^j`

The LP solved with GLPK is

    minimise dist  subject to  |S - Q| <= dist,  -x_bound <= x[i,j] <= x_bound,  dist >= 0

`S` is linear in the decision variables `x[i,j]` because `t1_val` and `t2_val` are fixed
numbers here.

# Keywords
- `d`: largest total degree `i` appearing in `S` (defaults to the script-wide `D_PARAM`).
- `x_bound`: symmetric box bound applied to every `x[i,j]` (defaults to 50).

# Returns
The optimal `dist` (a `Float64`, `>= 0`) when GLPK reports `OPTIMAL` or `LOCALLY_SOLVED`.
Otherwise a warning is logged and `-Inf` is returned, a value the outer maximiser never
prefers over a valid one.

# Throws
`ArgumentError` if `d < 0` or `x_bound < 0`.
"""
function solve_inner(t1_val::Float64, t2_val::Float64;
                     d::Integer = D_PARAM, x_bound::Real = 50.0)
    d >= 0 || throw(ArgumentError("d must be non-negative, got $d"))
    x_bound >= 0 || throw(ArgumentError("x_bound must be non-negative, got $x_bound"))

    # Target value Q = (1 + t1)^t2. On the outer search box (0 <= t1 <= 1) the base is
    # >= 1, so the real power is well defined.
    q_val = (1.0 + t1_val)^t2_val

    model = Model(GLPK.Optimizer)
    set_silent(model)

    # x[i,j] for 0 <= j <= i <= d, keyed by the 0-based (i, j) pair of the formulation.
    x_vars = Dict{Tuple{Int,Int},VariableRef}()
    for i in 0:d, j in 0:i
        x_vars[(i, j)] = @variable(model, base_name="x_$(i)_$(j)",
                                   lower_bound=-x_bound, upper_bound=x_bound)
    end

    @variable(model, dist >= 0)

    # S = Σ x[i,j] * t1^(i-j) * t2^j. Julia defines 0.0^0 == 1.0, so t1 == 0 or t2 == 0
    # still gives the pure-power terms a coefficient of 1.
    s_expr = AffExpr(0.0)
    for i in 0:d, j in 0:i
        add_to_expression!(s_expr, t1_val^(i - j) * t2_val^j, x_vars[(i, j)])
    end

    @objective(model, Min, dist)

    # |S - Q| <= dist, linearised as two one-sided inequalities.
    @constraint(model, s_expr - q_val - dist <= 0)
    @constraint(model, q_val - s_expr - dist <= 0)

    optimize!(model)
    status = termination_status(model)

    if status == MOI.OPTIMAL || status == MOI.LOCALLY_SOLVED
        return objective_value(model)  # the minimised dist
    end

    # No usable optimum: signal failure with -Inf so the maximising outer search discards
    # this point.
    @warn "Inner solver failed" status t1=t1_val t2=t2_val
    return -Inf
end

"""
    fun(params::Vector{Float64})

Outer objective for the annealing search. Read `t1 = params[1]` and `t2 = params[2]`
(any further entries are ignored) and return the inner minimum `dist` from `solve_inner`.
"""
function fun(params::Vector{Float64})
    # Destructuring takes the first two entries; a shorter vector raises BoundsError.
    t1, t2 = params
    return solve_inner(t1, t2)
end

"""
    simulated_annealing(obj_fun, lower, upper; max_iters = 10000, T0 = 1.0, α = 0.995,
                        step_size_factor = 0.1, rng = Random.default_rng(),
                        progress_every = 1000)

Maximise `obj_fun` over the box `lower .<= x .<= upper` with simulated annealing.

The algorithm runs as follows:

1. Start from a uniformly random point in the box.
2. Each iteration proposes `x .+ (u .- 0.5) .* (upper .- lower) .* step_size_factor`,
   with `u` drawn uniformly from `[0, 1)^n`, and clamps the proposal into the box.
3. Accept the proposal if it improves the objective, or with Metropolis probability
   `exp(Δ / T)`.
4. Multiply the temperature by `α` after every iteration; it starts at `T0`.

Failed evaluations are handled as follows:

- A candidate value of `-Inf` (e.g. a failed inner solve) is never accepted while the
  current value is not `-Inf`.
- When both the current and candidate values are `-Inf`, `Δ` is `NaN` and the move is
  rejected.

The number of accepted moves is stored in the script-level global `NumStep`. It is reset to
0 when the call starts and set to the final count before returning.

# Keywords
- `max_iters`: number of proposals.
- `T0`, `α`: initial temperature and geometric cooling factor.
- `step_size_factor`: proposal width as a fraction of each coordinate's range.
- `rng`: random number generator; pass a seeded one for reproducible runs.
- `progress_every`: print a progress line every this many iterations; `<= 0` disables it.

# Returns
`(best_val, best_sol)`: the largest objective value seen and the point where it was seen.

# Throws
`ArgumentError` if `lower` and `upper` differ in length or some `lower[k] > upper[k]`.
"""
function simulated_annealing(obj_fun, lower::AbstractVector{<:Real},
                             upper::AbstractVector{<:Real};
                             max_iters::Integer = 10000, T0::Real = 1.0, α::Real = 0.995,
                             step_size_factor::Real = 0.1,
                             rng::AbstractRNG = Random.default_rng(),
                             progress_every::Integer = 1000)
    length(lower) == length(upper) || throw(ArgumentError(
        "lower and upper must have the same length, got $(length(lower)) and $(length(upper))"))
    all(l <= u for (l, u) in zip(lower, upper)) ||
        throw(ArgumentError("every lower bound must be <= the matching upper bound"))

    # Reset the script-level counter. This declaration makes `NumStep` global for the whole
    # function, so the assignment just before `return` also writes the global.
    global NumStep = 0

    n = length(lower)
    span = upper .- lower

    # Start at a random point within the bounds.
    current_sol = lower .+ rand(rng, n) .* span
    current_val = obj_fun(current_sol)
    best_sol = copy(current_sol)
    best_val = current_val

    T = float(T0)
    n_accepted = 0  # local counter keeps the hot loop free of untyped-global updates

    for iter in 1:max_iters
        # Perturb relative to each coordinate's range, then pull back into the box.
        perturbation = (rand(rng, n) .- 0.5) .* span .* step_size_factor
        candidate_sol = clamp.(current_sol .+ perturbation, lower, upper)
        candidate_val = obj_fun(candidate_sol)

        # Never trade a usable point for a failed evaluation (-Inf).
        rejected_failure = candidate_val == -Inf && current_val != -Inf
        if !rejected_failure
            Δ = candidate_val - current_val
            # Below T = 1e-8 only improving moves are accepted (avoids exp(Δ / ~0)).
            if Δ > 0 || (T > 1e-8 && exp(Δ / T) > rand(rng))
                current_sol = candidate_sol
                current_val = candidate_val
                n_accepted += 1
                if current_val > best_val
                    best_sol = copy(current_sol)
                    best_val = current_val
                end
            end
        end

        T *= α  # geometric cooling
        if progress_every > 0 && iter % progress_every == 0
            println("Iter: $iter, Temp: $T, Best Val: $best_val, Current Val: $current_val")
        end
    end

    NumStep = n_accepted  # writes the global declared above
    return best_val, best_sol
end

# Search box for the outer parameters (t1, t2):
#   0 <= t[1] <= 1
#   1 <= t[2] <= 2.5
# Declared `const` so `run_optimization` reads them through a concretely typed binding
# rather than an untyped (Any) global.
const lower_bounds_t = [0.0, 1.0]
const upper_bounds_t = [1.0, 2.5]

"""
    run_optimization(; kwargs...)

Run the simulated-annealing search for the `(t1, t2)` pair that maximises the inner
minimum `dist` returned by `fun`, over the box given by `lower_bounds_t` and
`upper_bounds_t`.

The annealing settings default to `max_iters = 10000`, `T0 = 1.0`, `α = 0.995` and
`step_size_factor = 0.1`. Any keyword passed here is forwarded to `simulated_annealing`
and overrides the default of the same name. For example,
`run_optimization(max_iters = 1000)` gives a quicker test or benchmark run.

Returns `(best_val, best_params)` as produced by `simulated_annealing`.
"""
function run_optimization(; kwargs...)
    # In a Julia call the rightmost occurrence of a keyword wins, so caller-supplied
    # `kwargs` override the defaults listed before them.
    res_val, res_params_t = simulated_annealing(fun, lower_bounds_t, upper_bounds_t;
                                                max_iters = 10000, T0 = 1.0, α = 0.995,
                                                step_size_factor = 0.1, kwargs...)
    return res_val, res_params_t
end

# --- Main Execution ---

# Run the search once and report the results. Each SA iteration solves one LP, so for a
# quicker test run lower `max_iters` in `run_optimization`.
res_val, res_params_t = run_optimization()

println("\n--- Results ---")
println("Best value (maximum of minimized dist): ", res_val)
println("Optimal parameters (t1, t2): ", res_params_t)
println("Number of accepted SA steps: ", NumStep)

# Optional timing with BenchmarkTools. Every timed call repeats the whole search, so it
# is opt-in: set the environment variable RUN_BENCHMARK=1 (or "true"/"yes") to enable it.
# It runs after the results are printed because `run_optimization` resets `NumStep`.
if lowercase(get(ENV, "RUN_BENCHMARK", "")) in ("1", "true", "yes")
    println("\nTiming run_optimization() with @btime (this might take a while)...")
    @btime run_optimization()
end