{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4-element Vector{Int64}:\n",
       " 2\n",
       " 3\n",
       " 4\n",
       " 5"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "using Distributed\n",
    "addprocs(4)  # add workers to run parallel restarts if needed\n",
    "#Remove @everywhere if no parallelization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@everywhere using Random\n",
    "@everywhere using LinearAlgebra\n",
    "@everywhere using BARON, JuMP #To model and solve inner problem\n",
    "@everywhere using Ipopt #Suboptimal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    # Inner optimization (Evaluation of G) (using BARON)\n",
    "    @everywhere function inner_max_baron(X, f_list)\n",
    "        m, n = size(X) #Size of the matrix of centers\n",
    "        model = Model(BARON.Optimizer)\n",
    "        set_silent(model)\n",
    "        @variable(model, t >= 0)\n",
    "        @variable(model, y[1:m])\n",
    "        for f in f_list #Set constraints\n",
    "            @constraint(model, f(y) <= 0)\n",
    "        end\n",
    "        for i in 1:n\n",
    "            @constraint(model, t <= sum((y[j] - X[j, i])^2 for j in 1:m)) #Radius constraints\n",
    "        end\n",
    "        @objective(model, Max, t)\n",
    "        JuMP.optimize!(model)\n",
    "        return value(t)\n",
    "    end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@everywhere function ball_sample(m, n, r, center,rng) #Function to sample n points from an m dimensional ball\n",
    "    X = randn(rng,m, n)                      \n",
    "    X ./= sqrt.(sum(X.^2, dims=1))        \n",
    "    scales = (rand(rng,n).^(1/m)) * r         \n",
    "    X .*= scales'                         \n",
    "    X .+= center                          \n",
    "    return X\n",
    "end\n",
    "@everywhere function ellipse(y) #ellipse example\n",
    "    return ((y[1]*y[1])/9)+((y[2]*y[2])/4)-1\n",
    "end\n",
    "    @everywhere function p1(y) #non convex assymetric example\n",
    "        return 2*y[1]*y[1]-y[2]\n",
    "    end\n",
    "    @everywhere function p2(y)\n",
    "        return y[2]-2*(y[1]-1)*(y[1]-1)\n",
    "    end\n",
    "    @everywhere function p3(y)\n",
    "        return y[2]-5*(y[1]+0.1)*(y[1]+0.1)\n",
    "    end\n",
    "    @everywhere function p4(y)\n",
    "        return -0.1-y[1]\n",
    "    end\n",
    "    # @everywhere function p5(y)\n",
    "    #     return y[1]-0.6\n",
    "    # end\n",
    "@everywhere function h1(y) #3D example\n",
    "    return y[3] .- (1 .- y[1].^2 .- 0.3 .* y[2].^2)\n",
    "end\n",
    "\n",
    "@everywhere function h2(y)\n",
    "    return .-(y[3] .+ 1)\n",
    "end\n",
    "\n",
    "@everywhere function h3(y)\n",
    "    return y[2] .- (0.5 .* (y[1] .- 1).^2 .+ 0.2 .* y[3].^2)\n",
    "end\n",
    "#@everywhere f_list = [p1,p2,p3,p4,p5]\n",
    "@everywhere f_list = [ellipse]\n",
    "#@everywhere f_list = [h1,h2,h3,h4]\n",
    "@everywhere function f_inner(X)\n",
    "    return inner_max_baron(X, f_list)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "using BenchmarkTools #For benchmarking if required"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@everywhere begin           #Implementation of algorithm 1\n",
    "    function outermin(\n",
    "        G,                  # objective: G(x::AbstractMatrix) → Real\n",
    "        x0::AbstractMatrix, # initial point N×n\n",
    "        γ::Real,            # stepsize\n",
    "        δ::Real,            # smoothing parameter\n",
    "        T::Int,             # total iter count\n",
    "        b1::Int,            # batch size on refresh steps\n",
    "        b2::Int,            # batch size on variance‐reduction steps\n",
    "        q::Int,              # refresh period\n",
    "        rng\n",
    "    )\n",
    "        N, n = size(x0)\n",
    "        x_curr = copy(x0)\n",
    "        x_prev = copy(x0)\n",
    "        v_prev = zeros(eltype(x0), N, n)\n",
    "        best_x   = copy(x_curr)\n",
    "        best_val = G(best_x)\n",
    "        objective_evol = []\n",
    "        push!(objective_evol,best_val)\n",
    "        for t in 1:T-1\n",
    "            if (t-1) % q == 0\n",
    "                Ws = [randn(rng,N,n) for _ in 1:b1] \n",
    "                for W in Ws; W ./= norm(W); end\n",
    "                #Ws   = basis\n",
    "                gs = [ ((N*n)/(2*δ))*(G(x_curr .+ δ*W) - G(x_curr .- δ*W)) .* W\n",
    "                       for W in Ws ]\n",
    "                v = sum(gs) ./ b1\n",
    "                println((t-1),\" iterations reached, val=\",best_val)\n",
    "            else\n",
    "                # variance reduction\n",
    "                Ws = [randn(rng,N,n) for _ in 1:b2] \n",
    "                for W in Ws; W ./= norm(W); end\n",
    "                diffs = Matrix{Float64}[]\n",
    "                for W in Ws\n",
    "                    g1 = ((N*n)/(2*δ))*( G(x_curr .+ δ*W) - G(x_curr .- δ*W) )\n",
    "                    g0 = ((N*n)/(2*δ))*( G(x_prev .+ δ*W) - G(x_prev .- δ*W) )\n",
    "                    push!(diffs, (g1 .- g0) .* W)\n",
    "                end\n",
    "                v = sum(diffs) ./ b2 .+ v_prev\n",
    "            end\n",
    "            x_prev .= x_curr\n",
    "            x_curr .-= γ .* v\n",
    "            v_prev = v    #gradient step\n",
    "            val = G(x_curr)\n",
    "            push!(objective_evol, val)\n",
    "            if val < best_val\n",
    "                best_val = val #track best value and optimizer\n",
    "                best_x   .= x_curr\n",
    "            end\n",
    "        end\n",
    "\n",
    "        return best_val, best_x, objective_evol #return best value, optimizer and evolution of objective\n",
    "    end\n",
    "    @everywhere function run_restart(idx::Int, G, N::Int, n::Int,  #To run restarts\n",
    "                                      r0, c0, algparms)\n",
    "        #Random.seed!(idx)\n",
    "        rng = MersenneTwister(1234+idx)\n",
    "        # initialize via ball_sample:\n",
    "        x0 = ball_sample(N, n, r0, c0,rng )\n",
    "\n",
    "        @info \"Restart $idx ▶ G(x₀) = $(G(x0))\"\n",
    "        best_val, best_x, objective_evol = outermin(G, x0,\n",
    "                                    algparms.γ, algparms.δ, algparms.T,\n",
    "                                    algparms.b1, algparms.b2, algparms.q,rng)\n",
    "        @info \"Restart $idx ✔ best G = $best_val\"\n",
    "        return (best_val, best_x, objective_evol)\n",
    "    end\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "G = x -> f_inner(x)\n",
    "num_restarts = 10\n",
    "N, n         = 2,4           # problem dims\n",
    "r0, c0       = 3, [0.0,0.0]    # ball_sample params\n",
    "algparms = (γ = 5e-2,\n",
    "            δ = 1e-2,\n",
    "            T = 150,\n",
    "            b1 = 24,\n",
    "            b2 = 4,\n",
    "            q  = 1)\n",
    "\n",
    "# Build the argument‐list for each restart\n",
    "restart_args = [(i, G, N, n, r0, c0, algparms) for i in 1:num_restarts]\n",
    "\n",
    "# Launch them in parallel\n",
    "results = pmap(tup -> run_restart(tup...), restart_args)\n",
    "\n",
    "# Pick the global best\n",
    "best_val, best_x = reduce((a,b) -> a[1]<b[1] ? a : b, results)\n",
    "T = algparms.T\n",
    "objective_matrix = [res[3] for res in results]  # vector of vectors\n",
    "objective_matrix = reshape(vcat(objective_matrix...), num_restarts, T)\n",
    "println(\"\\n🏆 Overall best G = $best_val\")\n",
    "println(\"best_radius = \",sqrt(best_val))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.9.0",
   "language": "julia",
   "name": "julia-1.9"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
