using JuMP
using Ipopt # Changed from GLPK to Ipopt for non-linear inner problem
using MathOptInterface
const MOI = MathOptInterface
using Random
using BenchmarkTools # Needed only by the benchmarking section at the end of this file

# Problem dimensions, taken from the problem statement.
const problem_n = 7              # 'n': number of decision variables x[1]..x[n]
const problem_nx = problem_n + 1 # 'nx' = n + 1: index of the objective variable x[nx]

# Counter of accepted simulated-annealing steps. `simulated_annealing_minimize`
# resets and increments it; `run_full_optimization` prints it.
NumStep = 0

"""
    solve_inner(t1_val::Float64)

Solve the inner problem for a fixed angle `t1_val` (the SA parameter t[1]):

    minimize    ReR^2 + ImR^2        over x[1], ..., x[problem_n]
    subject to  -3.1 <= x[j] <= 3.1

with

    ReR = (cos(t1) - 2) / (5 - 4cos(t1)) - Σ_j x[j] * cos((j - 1) * t1)
    ImR =  sin(t1)      / (4cos(t1) - 5) - Σ_j x[j] * sin((j - 1) * t1)

The optimal objective value plays the role of x[nx]. It is returned when Ipopt
reports `OPTIMAL` or `LOCALLY_SOLVED`. Any other termination status logs a warning
and returns `Inf`, which a minimizing caller treats as a very poor value.
"""
function solve_inner(t1_val::Float64)
    c = cos(t1_val)
    s = sin(t1_val)

    # 5 - 4cos(t1) lies in [1, 9], so neither denominator can vanish.
    # The imaginary-part denominator 4cos(t1) - 5 is simply its negation.
    re_shift = (c - 2.0) / (5.0 - 4.0 * c)
    im_shift = s / (4.0 * c - 5.0)

    model = Model(Ipopt.Optimizer)
    set_optimizer_attribute(model, "print_level", 0) # silence Ipopt's iteration log

    @variable(model, -3.1 <= x_inner[j=1:problem_n] <= 3.1)

    # Weighted sums Σ x[j] cos((j-1) t1) and Σ x[j] sin((j-1) t1).
    cos_sum = @expression(model, sum(x_inner[j] * cos((j-1.0) * t1_val) for j in 1:problem_n))
    sin_sum = @expression(model, sum(x_inner[j] * sin((j-1.0) * t1_val) for j in 1:problem_n))

    re_residual = @expression(model, re_shift - cos_sum)
    im_residual = @expression(model, im_shift - sin_sum)
    @objective(model, Min, re_residual^2 + im_residual^2)

    optimize!(model)
    status = termination_status(model)

    if !(status in (MOI.OPTIMAL, MOI.LOCALLY_SOLVED))
        @warn "Inner optimization failed for t1_val = $t1_val with status $status. Returning Inf."
        return Inf
    end
    return objective_value(model)
end

# Outer objective handed to simulated annealing. `params` is the SA decision vector;
# for this problem only its first entry, t[1], is used, and it is forwarded to
# `solve_inner`. Returns the minimized inner objective (x[nx]), or `Inf` when the
# inner solve fails.
fun(params::Vector{Float64}) = solve_inner(params[1])

# True when an objective value signals a failed evaluation: `Inf` (e.g. the inner
# solver failed) or `NaN` (undefined result). `-Inf` and finite values are valid.
_sa_failed(v) = isnan(v) || v == Inf

"""
    simulated_annealing_minimize(obj_fun, lower_bounds_sa, upper_bounds_sa;
                                 max_iters=10000, T0=1.0, alpha_cooling=0.995,
                                 step_size_factor=0.1) -> (best_val, best_sol)

Minimize `obj_fun` over the box `[lower_bounds_sa, upper_bounds_sa]` with simulated
annealing. `obj_fun` is called with a `Vector{Float64}`. A return value of `Inf` or
`NaN` marks a failed evaluation, and such a point is never accepted as a move.

# Arguments
- `lower_bounds_sa`, `upper_bounds_sa`: per-dimension box bounds. They must have
  equal lengths, with `lower <= upper` elementwise.

# Keywords
- `max_iters`: number of candidate moves.
- `T0`: initial temperature.
- `alpha_cooling`: geometric cooling factor applied after every iteration.
- `step_size_factor`: a candidate is the current point plus a uniform perturbation
  within `±step_size_factor/2` times the box width, clamped back into the box.

# Returns
`(best_val, best_sol)`: the lowest objective value among accepted points, and the
point where it was reached. If no valid starting point is found in 11 random draws
(one initial draw plus 10 retries), returns `(Inf, last_tried_point)`.

# Side effects
Resets the global `NumStep` to 0 on entry, before any early return, so a failed
run never reports a stale count. Increments it on every accepted move.

# Throws
- `DimensionMismatch` if the two bound vectors differ in length.
- `ArgumentError` if some lower bound exceeds its upper bound.
"""
function simulated_annealing_minimize(obj_fun::Function,
                                      lower_bounds_sa::Vector{Float64},
                                      upper_bounds_sa::Vector{Float64};
                                      max_iters::Int = 10000,
                                      T0::Float64 = 1.0,
                                      alpha_cooling::Float64 = 0.995,
                                      step_size_factor::Float64 = 0.1)
    global NumStep
    NumStep = 0 # reset first so the early-abort path also reports a fresh count

    # Mismatched lengths would otherwise silently broadcast into a wrong-sized point.
    length(lower_bounds_sa) == length(upper_bounds_sa) ||
        throw(DimensionMismatch("lower_bounds_sa and upper_bounds_sa must have " *
                                "the same length (got $(length(lower_bounds_sa)) " *
                                "and $(length(upper_bounds_sa)))"))
    all(lower_bounds_sa .<= upper_bounds_sa) ||
        throw(ArgumentError("every lower bound must be <= its upper bound"))

    num_dims = length(lower_bounds_sa)
    width = upper_bounds_sa .- lower_bounds_sa # loop-invariant box width
    random_point() = lower_bounds_sa .+ rand(num_dims) .* width

    current_sol = random_point()
    current_val = obj_fun(current_sol)

    # Retry if the first random point fails to evaluate (Inf or NaN).
    max_initial_attempts = 10
    attempts = 0
    while _sa_failed(current_val) && attempts < max_initial_attempts
        @warn "Initial SA point resulted in Inf objective. Retrying with new random start (attempt $(attempts+1)/$max_initial_attempts)..."
        current_sol = random_point()
        current_val = obj_fun(current_sol)
        attempts += 1
    end

    if _sa_failed(current_val)
        @error "Could not find a valid starting point for SA after $max_initial_attempts attempts. Aborting SA."
        return Inf, current_sol # Inf (never NaN) so callers' `== Inf` checks work
    end

    best_sol = copy(current_sol)
    best_val = current_val
    T = T0

    for _ in 1:max_iters
        step = (rand(num_dims) .- 0.5) .* width .* step_size_factor
        candidate_sol = clamp.(current_sol .+ step, lower_bounds_sa, upper_bounds_sa)
        candidate_val = obj_fun(candidate_sol)
        delta_E = candidate_val - current_val

        # Metropolis rule for minimization: always take an improvement. Take an
        # uphill move with probability exp(-delta_E / T). Failed evaluations
        # (Inf/NaN) are never accepted. Once T <= 1e-9, only improvements count.
        accept = delta_E < 0 ||
                 (!_sa_failed(candidate_val) && T > 1e-9 && exp(-delta_E / T) > rand())

        if accept
            current_sol = candidate_sol
            current_val = candidate_val
            NumStep += 1

            if current_val < best_val
                best_sol = copy(current_sol)
                best_val = current_val
            end
        end

        T *= alpha_cooling
    end
    return best_val, best_sol
end

"""
    get_optimal_xs_and_obj(t1_val::Float64)

Re-solve the inner least-squares problem of `solve_inner` at the fixed angle
`t1_val`, also reporting the minimizer.

Returns `(objective_value, x)`, where `x` holds the optimal x[1], ..., x[problem_n],
when Ipopt terminates with `OPTIMAL` or `LOCALLY_SOLVED`. On any other status it
logs a warning and returns `(Inf, fill(NaN, problem_n))`.
"""
function get_optimal_xs_and_obj(t1_val::Float64)
    ct, st = cos(t1_val), sin(t1_val)
    re_const = (ct - 2.0) / (5.0 - 4.0 * ct) # 5 - 4cos(t) ∈ [1, 9]: never zero
    im_const = st / (4.0 * ct - 5.0)         # denominator is -(5 - 4cos(t))

    # Trigonometric weights of x[j] in the real and imaginary residuals.
    cos_w = [cos((j - 1.0) * t1_val) for j in 1:problem_n]
    sin_w = [sin((j - 1.0) * t1_val) for j in 1:problem_n]

    model = Model(Ipopt.Optimizer)
    set_optimizer_attribute(model, "print_level", 0)
    @variable(model, -3.1 <= x_vars[j=1:problem_n] <= 3.1)

    re_sum = @expression(model, sum(x_vars[j] * cos_w[j] for j in 1:problem_n))
    im_sum = @expression(model, sum(x_vars[j] * sin_w[j] for j in 1:problem_n))
    re_res = @expression(model, re_const - re_sum)
    im_res = @expression(model, im_const - im_sum)
    @objective(model, Min, re_res^2 + im_res^2)

    optimize!(model)
    status = termination_status(model)

    if status in (MOI.OPTIMAL, MOI.LOCALLY_SOLVED)
        return objective_value(model), value.(x_vars)
    end
    @warn "Final optimization to get x_vars failed for t1_val = $t1_val with status $status."
    return Inf, fill(NaN, problem_n)
end

"""
    run_full_optimization()

End-to-end driver. It searches t[1] ∈ [0, 2π] with simulated annealing (objective
`fun`) and re-solves the inner problem at the best t[1] to recover x[1..n]. It then
prints the results together with the global accepted-step counter `NumStep`.
Produces output on stdout only and returns `nothing`.
"""
function run_full_optimization()
    t_lo = [0.0]
    t_hi = [2.0 * pi]

    println("Starting Simulated Annealing for t[1]...")
    # SA settings keep the original values; they may need tuning.
    sa_best_val, sa_best_params = simulated_annealing_minimize(
        fun, t_lo, t_hi;
        max_iters = 10000,
        T0 = 1.0,
        alpha_cooling = 0.995,
        step_size_factor = 0.1,
    )

    if sa_best_val == Inf
        # SA never found a finite inner objective: report the failure and stop.
        println("\nSimulated Annealing did not find a valid solution where x[nx] is finite.")
        println("Minimized x[nx] (x[$(problem_nx)]): Inf")
        println("Optimal t[1]: N/A (SA failed)")
        println("Optimal x[1]...x[$(problem_n)]: N/A")
        println("Number of accepted SA steps: ", NumStep)
        return nothing
    end

    best_t1 = sa_best_params[1]

    println("\nSimulated Annealing finished.")
    println("Best t[1] found by SA: ", best_t1)
    println("Corresponding minimized x[nx] (x[$(problem_nx)]) from SA: ", sa_best_val)

    println("\nRunning final optimization with best t[1] to get optimal x[1]...x[$(problem_n)] values...")
    # Re-solving at the best t[1] also cross-checks the SA objective value.
    final_obj_val, x_opt = get_optimal_xs_and_obj(best_t1)

    println("\n--- Optimization Results ---")
    println("Minimized x[nx] (x[$(problem_nx)]): ", final_obj_val)
    println("Optimal t[1]: ", best_t1)
    println("Optimal x[1]...x[$(problem_n)] (x[j] for j=1 to n): ")
    for j in 1:problem_n
        println("  x[$(j)]: ", x_opt[j])
    end
    println("Number of accepted SA steps during optimization: ", NumStep)
    return nothing
end

# --- Main Execution ---
# Benchmark the full pipeline (requires `using BenchmarkTools` at the top of the file).
# Every benchmark evaluation runs a complete SA search and prints its own results.
println("Starting benchmark...")
benchmark_result_new = @benchmark run_full_optimization()
println("Benchmark finished.")
run_full_optimization() # Run once more to get results for printing clearly after benchmark output
println("\n--- Benchmark Results ---")
println("Median time: ", median(benchmark_result_new.times) / 1_000_000, " ms") # Convert ns to ms
println("Mean time: ", mean(benchmark_result_new.times) / 1_000_000, " ms")
# `memory`/`allocs` read the Trial's memory estimate (bytes) and allocation count.
# (`BenchmarkTools.summary` is Base's `summary`, which returns a String with no such
# fields, so the previous `.memory`/`.allocs` access threw an error.)
println("Memory: ", BenchmarkTools.memory(benchmark_result_new), " bytes")
println("Allocations: ", BenchmarkTools.allocs(benchmark_result_new))
display(benchmark_result_new) # Rich display in environments like Pluto/Jupyter

# Run the optimization once for typical output without benchmarking:
run_full_optimization()