
# One-time session setup: the body runs only while `run_once` is undefined or
# false, so evaluating this script again in the same session skips it.
if !(@isdefined run_once) || !run_once

  using Revise

  # Project-local utility files, loaded through Revise's `includet`.
  # The `GradientAscent` and `PlottingTools` modules imported further down
  # presumably come from these files - confirm against ../Utilities.
  includet("../Utilities/AlgorithmTools.jl")
  includet("../Utilities/GradientAscent.jl")
  includet("../Utilities/PlottingTools.jl")


  # Mark the setup as done for subsequent evaluations of this file.
  run_once = true

end



using LinearAlgebra
using StaticArrays
using Random
using Flux
using ProgressLogging


using .GradientAscent
using .PlottingTools



# Number of actions per player (presumably "go to the bar" / "stay at home").
# Each model below ends in a layer with `ACTIONS - 1` outputs.
const ACTIONS = 2

"""
    computeProbabilitiesOfCrowdedBar!(out, state; depth=0, cache=nothing)

Fill `out` in place with a running "at least `i` attend" table.

Reading every entry of `state` as an independent attendance probability, the
column produced after processing entry `j` holds, in row `i`, the probability
that at least `i` of the players counted so far attend. It is obtained from the
previous column `prev` through

    new[i] = state[j] * prev[i - 1] + (1 - state[j]) * prev[i]

with `prev[0]` taken as one, and the second term skipped while fewer than `i`
players had been counted before entry `j`.

# Arguments
- `out`: matrix with one row per attendance threshold (`size(out, 1)` rows).
  When `state` has more entries than `out` has columns, the results of the
  surplus leading entries are kept only in `cache`; the remaining entries fill
  the columns of `out` from left to right.
- `state`: attendance probabilities, processed in iteration order.

# Keywords
- `depth`: number of players already accounted for in `cache` before the first
  entry of `state` is processed.
- `cache`: column vector holding the table column that precedes the first
  entry of `state`. It is overwritten while surplus entries are processed.
  When `nothing`, a zero column of length `size(out, 1)` is used.

# Returns
`nothing`; the results are written to `out` (and `cache`).
"""
function computeProbabilitiesOfCrowdedBar!(
  out, state;
  depth=0,
  cache=nothing
)

  players, capacity = length(state), size(out, 1)

  # Bind the scratch column to a fresh name instead of reassigning `cache`:
  # a variable that is reassigned and then captured by a closure gets boxed.
  # A plain vector is sufficient here because its length is a run-time value.
  scratch = isnothing(cache) ? zeros(capacity) : cache

  # Number of leading entries of `state` whose column lives only in `scratch`.
  offset = players - size(out, 2)

  # Column `j` of the logical table: `scratch` for j <= offset, `out` otherwise.
  previous(i, j) = j - offset < 1 ? scratch[i] : out[i, j - offset]

  function store!(i, j, value)
    if j - offset < 1
      scratch[i] = value
    else
      out[i, j - offset] = value
    end
    return nothing
  end

  # Rows are visited from high to low so that, while a column is updated in
  # place inside `scratch`, row `i - 1` still holds the previous column's value.
  for (j, stateⱼ) ∈ enumerate(state), i ∈ min(j + depth, capacity):-1:1

    value = stateⱼ

    if (i > 1) value *= previous(i - 1, j - 1) end
    if (j + depth > i) value += previous(i, j - 1) * (1 - stateⱼ) end

    store!(i, j, value)

  end

  return nothing

end


"""
    divideAndConquer!(α, β, out, state)

Recursively refine columns `α:β` of `out`, halving the index range each time.

For a range with more than one index the lower half of the columns is
recomputed by `computeProbabilitiesOfCrowdedBar!`, seeded with a copy of column
`α` and fed the state entries of the upper half followed by those of the lower
half except its last index. Both halves are then processed recursively.
Nothing is done once `α ≥ β`.

# Returns
`nothing`; `out` is modified in place.
"""
function divideAndConquer!(α, β, out, state)

  α ≥ β && return nothing

  # Split [α, β] into the lower half [α, mid] and the upper half [upper, β].
  mid = Int(floor((α + β) / 2))
  upper = mid + 1

  # Players lying outside [α, β]; passed on as `depth`.
  outside = length(state) - (β - α + 1)

  # Copy of column α: the recomputation below overwrites that column.
  seed = MVector{size(out, 1)}(out[:, α])

  # Upper-half players first, then the lower half without its last index.
  order = vcat(upper:β, α:(mid - 1))

  computeProbabilitiesOfCrowdedBar!(
    view(out, :, α:mid),
    view(state, order);
    depth=outside,
    cache=seed
  )

  divideAndConquer!(α, mid, out, state)
  divideAndConquer!(upper, β, out, state)

  return nothing

end


"""
    probabilitiesOfCrowdedBar(state, capacity)

Return, for every entry of `state`, the last row (row `capacity`) of the table
built by `computeProbabilitiesOfCrowdedBar!` and `divideAndConquer!`.

Reading the recursion, entry `k` of the result is the probability that at
least `capacity` of the players other than `k` attend, with `state` holding the
individual attendance probabilities.
"""
function probabilitiesOfCrowdedBar(state, capacity)

  n = length(state)

  # One row per attendance threshold, one column per player.
  table = @MMatrix zeros(capacity, n)

  # Prefix pass: columns 2:end are filled from all players but the last one,
  # column 1 keeps its zeros.
  shifted_columns = @view table[:, (begin + 1):end]
  leading_players = @view state[1:(end - 1)]

  computeProbabilitiesOfCrowdedBar!(shifted_columns, leading_players)

  # Refinement pass over the full index range.
  divideAndConquer!(1, n, table, state)

  return table[end, :]

end



"""
    ElFarolBarGame(S, B, G, C)

Parameters of an El Farol Bar game.

# Fields
- `S`: payoff for staying at home (see `stayAtHomePayoff`).
- `B`: payoff for going to a crowded bar (see `crowdedBarPayoff`).
- `G`: payoff for going to an uncrowded bar (see `uncrowdedBarPayoff`).
- `C`: bar capacity, stored as an unsigned integer (see `barCapacity`).

Every field has a concrete type (one type parameter per field), so instances
are not boxed. A capacity that is not already `Unsigned` is converted with
`convert(Unsigned, C)`.
"""
struct ElFarolBarGame{TS<:Real,TB<:Real,TG<:Real,TC<:Unsigned}
  S::TS
  B::TB
  G::TG
  C::TC
end

# Accept any real-valued capacity (e.g. the result of `floor`); the value is
# converted to an unsigned integer before it is stored.
ElFarolBarGame(S::Real, B::Real, G::Real, C::Real) =
  ElFarolBarGame(S, B, G, convert(Unsigned, C))

"Payoff a player receives for staying at home."
stayAtHomePayoff(game::ElFarolBarGame) = game.S

"Payoff a player receives for going to a crowded bar."
crowdedBarPayoff(game::ElFarolBarGame) = game.B

"Payoff a player receives for going to an uncrowded bar."
uncrowdedBarPayoff(game::ElFarolBarGame) = game.G

"Bar capacity as an `Int`."
barCapacity(game::ElFarolBarGame) = Int(game.C)


"""
    pseudoGradientOfExpectedPayoffs(game::ElFarolBarGame, state)

Return `(G - S) .+ p * (B - G)`, where `p` is
`probabilitiesOfCrowdedBar(state, barCapacity(game))` and `S`, `B`, `G` are the
stay-at-home, crowded-bar and uncrowded-bar payoffs of `game`.
"""
function pseudoGradientOfExpectedPayoffs(game::ElFarolBarGame, state)

  home = stayAtHomePayoff(game)
  crowded = crowdedBarPayoff(game)
  uncrowded = uncrowdedBarPayoff(game)

  crowding = probabilitiesOfCrowdedBar(state, barCapacity(game))

  return (uncrowded - home) .+ crowding * (crowded - uncrowded)

end

"""
    expectedPayoffs(game::ElFarolBarGame, state)

Return `S .+ state .* pseudoGradientOfExpectedPayoffs(game, state)`, where `S`
is the stay-at-home payoff of `game`.
"""
function expectedPayoffs(game::ElFarolBarGame, state)

  home = stayAtHomePayoff(game)
  advantage = pseudoGradientOfExpectedPayoffs(game, state)

  return home .+ state .* advantage

end



"""
    gradientFunction(game, models, inputs; with_preconditioning=false, convexity_factor=0)

Build the update-direction closure `state -> direction` for `game`.

For every index `i` of `state` the closure evaluates `models[i]` at `state[i]`
together with its Jacobian (`Flux.withjacobian`), collects the single model
output, and keeps the adjoint of either the Jacobian or, when
`with_preconditioning` is set, its pseudo-inverse. The collected outputs are
passed to `pseudoGradientOfExpectedPayoffs`. If `convexity_factor > 0`, the
term `convexity_factor * (output - barCapacity(game) / (players - 1))` is
subtracted from that pseudo-gradient, with `players = length(models)`.

The closure returns an `SVector{players}` whose `i`-th entry is the
`MVector{inputs}` obtained by scaling the `i`-th kept vector with the `i`-th
pseudo-gradient component.
"""
function gradientFunction(game, models, inputs; with_preconditioning=false, convexity_factor=0)

  players = length(models)

  return @inline @inbounds function(state)

    outputs = []
    directions = MVector{inputs}[]

    for i ∈ axes(state, 1)

      outputᵢ, jacobiansᵢ = Flux.withjacobian(models[i], state[i])
      jacobianᵢ = only(jacobiansᵢ)

      # Adjoint of the pseudo-inverse when preconditioning, plain adjoint otherwise.
      directionᵢ = with_preconditioning ? pinv(jacobianᵢ)' : jacobianᵢ'

      push!(outputs, only(outputᵢ))
      push!(directions, directionᵢ)

    end

    pseudo_gradient = pseudoGradientOfExpectedPayoffs(game, outputs)

    if convexity_factor > 0
      target = barCapacity(game) / (players - 1)
      pseudo_gradient .-= convexity_factor .* (outputs .- target)
    end

    SVector{players}(
      MVector{inputs}(directions[i] * pseudo_gradient[i])
      for i ∈ axes(state, 1)
    )

  end

end



"""
    realizeTrajectories(game, players, models, inputs; initial_states, kwargs...)

Run `gradientAscent` once per entry of `initial_states` and return the
collected trajectories.

Before and after every run the rounded model outputs at the initial and at the
final state are printed. A progress "done" record is logged once all runs have
finished.

# Keywords
- `initial_states`: collection of starting states, one run per entry.
- `iterations`, `step`, `with_progress`: forwarded to `gradientAscent`.
- `with_preconditioning`, `convexity_factor`: forwarded to `gradientFunction`.
- `process_name` (via `kwargs`): progress name, default `"Gradient Ascent"`.
- `prompt` (via `kwargs`): prefix of the final-state line, default `"GD"`.
"""
function realizeTrajectories(
    game::ElFarolBarGame, players,
    models, inputs;
    initial_states,
    iterations=Int(1e3),
    convexity_factor=0,
    step=1e-2, 
    with_preconditioning=false,
    with_progress=true,
    kwargs...
  )

  # Textual form of the model outputs at `state`, rounded to three digits.
  describe(state) = repr([
    round.(only(models[i](state[i])); digits=3)
    for i ∈ 1:players
  ])

  total_trajectories = length(initial_states)
  trajectories = []

  for k ∈ axes(initial_states, 1)

    initial_state = initial_states[k]

    println("[$k / $total_trajectories] x₀ ≈ $(describe(initial_state))")

    direction = gradientFunction(
      game, models, inputs;
      with_preconditioning=with_preconditioning,
      convexity_factor=convexity_factor
    )

    # Run `k` is given the slice [(k - 1) / total, k / total] as its
    # progress partition.
    trajectory = gradientAscent(
      initial_state,
      direction;
      iterations=iterations,
      step=step,
      with_progress=with_progress,
      progress_id=@progressid,
      progress_partition=[k - 1, k] ./ total_trajectories,
      progress_name=get(kwargs, :process_name, "Gradient Ascent")
    )

    println("$(get(kwargs, :prompt, "GD")): xₑ ≈ $(describe(trajectory[end].state))")

    push!(trajectories, trajectory)

  end

  @info ProgressLogging.Progress(@progressid; done=true)

  return trajectories

end

"""
    testTrajectories(trajectories, models; ϵ=1e-3)

Return the indices of the trajectories whose final state is degenerate.

A trajectory is reported when, for at least one model `i`, the single output of
`models[i]` evaluated at `trajectory[end].state[i]` lies outside the interval
`[ϵ, 1 - ϵ]`. A `NaN` output fails that interval check and is therefore
reported as well.

# Arguments
- `trajectories`: collection of trajectories; each must support
  `trajectory[end].state[i]`.
- `models`: indexable collection of callables returning a one-element result.

# Keywords
- `ϵ`: margin that separates an output from the ends of `[0, 1]`.

# Returns
A `Vector{Int}` with the (1-based, enumeration order) indices of the reported
trajectories; each index appears at most once.
"""
function testTrajectories(trajectories, models; ϵ=1e-3)

  # The player count comes from `models` itself rather than from a global.
  # Written as a negated interval test so that NaN is flagged too.
  isdegenerate(final_state) = any(eachindex(models)) do i
    value = only(models[i](final_state[i]))
    !(ϵ ≤ value ≤ 1 - ϵ)
  end

  return Int[
    k
    for (k, trajectoryₖ) ∈ enumerate(trajectories)
    if isdegenerate(trajectoryₖ[end].state)
  ]

end



# Shared, seeded generator used by `randomInitialization`.
rng = MersenneTwister(4)

"""
    randomInitialization(sz...; gain=1.0)

Draw an array of size `sz` with entries `2 * gain * (u - 0.5)`, where `u` comes
from `rand(rng, sz...)` on the shared generator `rng`.

When called without a size, return a function `(sz...) -> ...` that performs
the same draw with the given `gain` (usable as a layer `init` argument).
"""
function randomInitialization(sz...; gain=1.0)

  # No size given: hand back an initializer with `gain` fixed.
  isempty(sz) && return (dims...) -> randomInitialization(dims...; gain=gain)

  centered = rand(rng, sz...) .- 0.5

  return 2 * gain * centered

end


# Game instance: 30 players, payoffs S = 1, B = 0, G = 2 and a bar capacity of
# floor(0.7 * players).
players = 30
game = ElFarolBarGame(1, 0, 2, floor(0.7 * players))

# One small network per player: `inputs` values in, `ACTIONS - 1` sigmoid
# outputs. The weights are drawn from the shared seeded `rng` through
# `randomInitialization`, so the order of these statements affects the result.
inputs = 5
models = SVector{players}(
  f64(Chain(
    Dense(inputs => 4, celu; init=randomInitialization(; gain=0.85)),
    Dense(4 => ACTIONS - 1, sigmoid; init=randomInitialization(; gain=1.0))
  ))
  for i ∈ 1:players
)

# Distance measure used for sorting and plotting: the Euclidean norm of the
# model outputs minus `barCapacity(game) / players`.
metric = @inline @inbounds function(state)

  output = SVector{players}(only(models[i](state[i])) for i ∈ 1:players)

  
  norm(output .- barCapacity(game) / players)

end



# Sampling of the initial states.
# NOTE(review): `additional_trajectotires` is misspelled ("trajectories"); the
# name is left untouched here because it is a global of this script.
total_trajectories = 100
drop_out_rate = 0.15
additional_trajectotires = Int(round(total_trajectories * drop_out_rate))

# Draw `total_trajectories + 2 * additional_trajectotires` random states, sort
# them by `metric` in descending order, and drop `additional_trajectotires`
# states from each end, which leaves `total_trajectories` states.
initial_states = sort(
  collect(
    SVector{players}(
      MVector{inputs}(randomInitialization(inputs; gain=0.5))
      for _ ∈ 1:players
    )
    for _ ∈ 1:(total_trajectories + 2 * additional_trajectotires)
  );
  by=metric, 
  rev=true
)[
  (begin + additional_trajectotires):(end - additional_trajectotires)
]



# Settings shared by both experiments below.
iterations = Int(1e3)
convexity_factor = 0.5

# Vertical axis range of the individual plots.
y_limits = (0.0, sqrt(players) / 2)


# Experiment 1: plain gradient ascent (no preconditioning) with a constant
# step of 1e-2, passed as a function of `t`.
trajectories_GD = realizeTrajectories(
  game, players,
  models, inputs;
  initial_states=initial_states,
  iterations=iterations,
  convexity_factor=convexity_factor,
  step=t -> 1e-2
)

# Warn about trajectories flagged by `testTrajectories`.
failed_indices = testTrajectories(trajectories_GD, models)
if !isempty(failed_indices)
  @warn "Numerical errors detected in trajectories: $(repr(failed_indices))!"
end

# Plot the GD trajectories under `metric`.
# `newFigure` / `plotTrajectoryOfMean` are defined outside this file,
# presumably in PlottingTools.
filename = "ElFarolBarGame_GD"
newFigure(filename) do matlab_session
  plotTrajectoryOfMean(
    matlab_session, trajectories_GD;
    transform=metric, 
    y_limits=y_limits
  )
end


# Experiment 2: same initial states and settings as the first run, but with
# preconditioning enabled. `process_name` and `prompt` only change the progress
# label and the prefix of the printed final-state line.
trajectories_PGD = realizeTrajectories(
  game, players,
  models, inputs;
  initial_states=initial_states,
  iterations=iterations,
  convexity_factor=convexity_factor,
  step=t -> 1e-2,
  with_preconditioning=true,
  process_name="Preconditioning Gradient Ascent",
  prompt="PGD"
)

# Warn about trajectories flagged by `testTrajectories`.
failed_indices = testTrajectories(trajectories_PGD, models)
if !isempty(failed_indices)
  @warn "Numerical errors detected in trajectories: $(repr(failed_indices))!"
end

# Plot the PGD trajectories under `metric`.
filename = "ElFarolBarGame_PGD"
newFigure(filename) do matlab_session
  plotTrajectoryOfMean(
    matlab_session, trajectories_PGD;
    transform=metric, 
    y_limits=y_limits
  )
end


# Comparison figure: both experiments in one log-scaled plot.
filename = "ElFarolBarGame_GDVsPGD"
newFigure(filename) do matlab_session

  # First curve: PGD, with its own axis limits, log scale and grid.
  plotTrajectoryOfMean(
    matlab_session, trajectories_PGD;
    transform=metric, 
    y_limits=(7e-2, 2e0),
    log_scale=true,
    grid=true,
    color="[$(97 / 255), $(142 / 255), $(46 / 255)]",
    name="'PGD'"
  )

  # Second curve: GD, dashed; `new_figure=false` presumably draws into the
  # figure opened above - confirm against PlottingTools.
  plotTrajectoryOfMean(
    matlab_session, trajectories_GD;
    transform=metric,
    new_figure=false,
    line_style="'--'",
    color="[$(206 / 255), $(44 / 255), $(64 / 255)]",
    name="'GD'"
  )

end


# Keep the script from echoing the value of its last expression.
nothing;