# Load the project-local utility modules only on the first run of this script in a
# session. `includet` is Revise's tracked include; the `run_once` global records that
# this block already executed, so re-running the script skips it. The modules defined
# by these files are brought into scope further down with `using .GradientAscent` and
# `using .PlottingTools`.
if !(@isdefined run_once) || !run_once

  using Revise

  includet("../Utilities/AlgorithmTools.jl")
  includet("../Utilities/GradientAscent.jl")
  includet("../Utilities/PlottingTools.jl")


  # Mark the utilities as loaded for subsequent runs in this session.
  run_once = true

end



using LinearAlgebra
using StaticArrays
using Random
using DynamicPolynomials
using Flux
using ProgressLogging

import StaticPolynomials as SP


using .GradientAscent
using .PlottingTools



# Dimensions of the two-player Matching Pennies setup used throughout this script.
const PLAYERS = 2  # number of players (one network per player)
const ACTIONS = 2  # actions per player; each network has ACTIONS - 1 sigmoid outputs
const INPUTS = 1   # dimension of each player's network input

"""
    pseudoGradientOfExpectedPayoffsFunction()

Return a closure mapping the players' scalar strategies `state = (x₁, x₂)` to the
pseudo-gradient of Matching Pennies, i.e. each player's payoff derivative with respect
to its own strategy. With payoffs `u₁ = (2x₁ - 1)(2x₂ - 1)` and `u₂ = -u₁`:

    (∂u₁/∂x₁, ∂u₂/∂x₂) = (2(2x₂ - 1), -2(2x₁ - 1))

The result is an `MVector{PLAYERS}`, so callers may update it in place.
"""
function pseudoGradientOfExpectedPayoffsFunction()

  @polyvar x

  # Derivative of the bilinear payoff w.r.t. one player's strategy, as a polynomial in
  # the *opponent's* strategy.
  slope = SP.Polynomial(2 * (2x - 1))

  @inline @inbounds function(state)
    x₁, x₂ = state[1], state[2]
    MVector{PLAYERS}(slope([x₂]), -slope([x₁]))
  end

end

"""
    expectedPayoffsFunction()

Return a closure mapping the players' scalar strategies `state = (x₁, x₂)` to their
expected payoffs as an `SVector{PLAYERS}`: `u₁ = (2x₁ - 1)(2x₂ - 1)` and `u₂ = -u₁`
(the game is zero-sum).

The payoffs are recovered from the pseudo-gradient, whose two entries
`2(2x₂ - 1)` and `-2(2x₁ - 1)` multiply to `-4u₁`.
"""
function expectedPayoffsFunction()

  gradient_of = pseudoGradientOfExpectedPayoffsFunction()

  @inline @inbounds function(state)
    u₁ = -prod(gradient_of(state)) / 4
    SVector{PLAYERS}(u₁, -u₁)
  end

end


"""
    gradientFunction(models; with_preconditioning=false, convexity_factor=0)

Return a closure `state -> SVector{PLAYERS}` giving, for each player `i`, the ascent
direction of its own payoff with respect to its network input `state[i]`.

Player `i`'s strategy is the scalar network output `xᵢ = only(models[i](state[i]))`.
The pseudo-gradient entry `gᵢ = ∂uᵢ/∂xᵢ` is pulled back to input space through the
Jacobian `Jᵢ = ∂xᵢ/∂state[i]`:

- `with_preconditioning=false`: `Jᵢ' * gᵢ` (plain chain rule);
- `with_preconditioning=true`:  `pinv(Jᵢ) * gᵢ`, the minimum-norm input step that changes
  `xᵢ` by `gᵢ` to first order (for non-zero `Jᵢ`).

If `convexity_factor > 0`, `convexity_factor * (xᵢ - 1 / ACTIONS)` is subtracted from
each pseudo-gradient entry, pulling the strategies toward the uniform one.
"""
function gradientFunction(models; with_preconditioning=false, convexity_factor=0)

  pseudo_gradient_function = pseudoGradientOfExpectedPayoffsFunction()

  @inline @inbounds function(state)

    players = axes(state, 1)

    # One forward pass + Jacobian per player. `withjacobian` yields (value, jacobians),
    # with one Jacobian per model argument — hence `only` below.
    evaluations = [Flux.withjacobian(models[i], state[i]) for i ∈ players]

    # Typed comprehension (instead of an untyped `[]`) keeps the strategies a concrete
    # numeric vector rather than `Vector{Any}` in this per-iteration closure.
    output = [only(value) for (value, _) ∈ evaluations]

    # `pinv(J)` has the same (inputs × outputs) shape as `J'`, so both branches give a
    # pull-back matrix of the same shape. It is stored flattened as an MVector{INPUTS},
    # which relies on ACTIONS - 1 == 1 output per player (as `only(value)` also does).
    jacobians = MVector{INPUTS}[
      with_preconditioning ? pinv(only(jacobian)) : only(jacobian)'
      for (_, jacobian) ∈ evaluations
    ]

    pseudo_gradient = pseudo_gradient_function(output)

    if convexity_factor > 0
      pseudo_gradient .-= convexity_factor .* (output .- 1 / ACTIONS)
    end

    SVector{PLAYERS}(
      MVector{INPUTS}(jacobians[i] * pseudo_gradient[i])
      for i ∈ players
    )

  end

end



"""
    realizeTrajectories(models; initial_states, iterations=Int(1e3), convexity_factor=0,
                        step=1e-2, with_preconditioning=false, with_progress=true, kwargs...)

Run `gradientAscent` once from each element of `initial_states`, following the direction
built by `gradientFunction(models; with_preconditioning, convexity_factor)`, and return
the trajectories as a `Vector` in the iteration order of `initial_states`.

Before and after each run the players' rounded network outputs are printed.

# Keywords
- `initial_states`: collection of joint states; `state[i]` is player `i`'s network input.
- `iterations`, `step`, `with_progress`: forwarded to `gradientAscent`.
- `convexity_factor`, `with_preconditioning`: forwarded to `gradientFunction`.
- `process_name` (via `kwargs`, default `"Gradient Ascent"`): forwarded to `gradientAscent`.
- `prompt` (via `kwargs`, default `"GD"`): label printed before the final outputs.
- Any other `kwargs` are ignored.
"""
function realizeTrajectories(
    models;
    initial_states,
    iterations=Int(1e3),
    convexity_factor=0,
    step=1e-2, 
    with_preconditioning=false,
    with_progress=true,
    kwargs...
  )

  total_trajectories = length(initial_states)

  process_name = get(kwargs, :process_name, "Gradient Ascent")
  prompt = get(kwargs, :prompt, "GD")

  trajectories = []

  # `n` counts runs from 1 regardless of how `initial_states` is indexed, so run `n`
  # always gets the slice [(n - 1), n] / total of the shared progress bar.
  for (n, initial_state) ∈ enumerate(initial_states)

    println("x₀ ≈ $(repr(_roundedModelOutputs(models, initial_state)))")

    trajectory = gradientAscent(
      initial_state,
      gradientFunction(
        models;
        with_preconditioning=with_preconditioning,
        convexity_factor=convexity_factor
      ); 
      iterations=iterations, 
      step=step,
      with_progress=with_progress,
      progress_id=@progressid,
      progress_partition=[n - 1, n] ./ total_trajectories,
      process_name=process_name
    )

    println("$(prompt): xₑ ≈ $(repr(_roundedModelOutputs(models, trajectory[end].state)))")

    push!(trajectories, trajectory)

  end

  @info ProgressLogging.Progress(@progressid; done=true)

  trajectories

end

# Each player's network output at the joint `state`, rounded to 3 digits for printing.
_roundedModelOutputs(models, state) = [
  round.(only(models[i](state[i])); digits=3)
  for i ∈ 1:PLAYERS
]



# Generator shared by all weight initializations. It is re-created (and thus reseeded)
# every time this script is re-run, so the networks start from the same weights.
rng = MersenneTwister(3)

"""
    randomInitialization(sz...; gain=1.0)

Return an array of size `sz` whose entries are drawn uniformly from `[-gain, gain)`
(for non-negative `gain`) using the global `rng`.

Called without sizes, return the initializer `(sz...) -> randomInitialization(sz...; gain=gain)`
instead, as used for the `init` keyword of Flux layers.
"""
function randomInitialization(sz...; gain=1.0)

  isempty(sz) && return (dims...) -> randomInitialization(dims...; gain=gain)

  centered = rand(rng, sz...) .- 0.5

  return 2 * gain * centered

end


# One independent network per player: INPUTS → 1 hidden `celu` unit → ACTIONS - 1
# `sigmoid` outputs, converted to Float64. The two initializers are reused across
# players; weights are still drawn from the shared `rng` in the same order
# (player 1 hidden, player 1 output, player 2 hidden, ...).
models = let
  hidden_init = randomInitialization(; gain=1.0)
  output_init = randomInitialization(; gain=1.5)

  newPlayerModel() = f64(Chain(
    Dense(INPUTS => 1, celu; init=hidden_init),
    Dense(1 => ACTIONS - 1, sigmoid; init=output_init)
  ))

  SVector{PLAYERS}(newPlayerModel() for _ ∈ 1:PLAYERS)
end

# Euclidean distance between the players' network outputs at the joint `state` and the
# uniform strategy 1 / ACTIONS (the equilibrium of Matching Pennies).
metric = @inline @inbounds function(state)

  deviations = [only(models[i](state[i])) - 1 / ACTIONS for i ∈ 1:PLAYERS]

  norm(deviations)

end


# Joint starting point: one INPUTS-dimensional network input per player.
initial_state = SVector{PLAYERS}(
  MVector{INPUTS, Float64}(coordinate) for coordinate ∈ (1.25, 2.25)
)

# Optimization settings shared by the GD and PGD runs.
iterations = 5_000
convexity_factor = 0.75

# Y-range of the metric-vs-iteration plots.
# NOTE(review): with sigmoid outputs in (0, 1) the metric can approach √2 / 2 ≈ 0.71,
# above this upper bound of 0.5 — confirm that clipping is intended.
y_limits₁ = (0.0, (ACTIONS - 1) / ACTIONS)

# Axis ranges (and heat-map sampling ranges) of the input-space phase portraits.
y_limits₂ = x_limits₂ = (-5, 5)


# Overlay `trajectory` on the current figure of `matlab_session` as a red line plus black
# scatter markers of size `marker_size`, after mapping every state through `transform`.
function _overlayMatchingPenniesTrajectory(matlab_session, trajectory, transform; marker_size)

  plotTrajectoryInSquare(
    matlab_session, trajectory;
    transform=transform,
    new_figure=false,
    plot_type=PlottingTools.PlotTypes.Plot,
    color₁="'Red'",
    line_width=0.75
  )

  plotTrajectoryInSquare(
    matlab_session, trajectory;
    transform=transform,
    new_figure=false,
    plot_type=PlottingTools.PlotTypes.ScatterPlot,
    color₁="'Black'",
    line_width=0.1,
    marker_size=marker_size,
  )

end

# Write the three per-trajectory figures to `filename` (the first one starts the file,
# the others are appended):
#   1. `metric` along the trajectory;
#   2. the trajectory in input space over a heat map of `metric`;
#   3. the trajectory in model-output space over a heat map of the distance to 1 / ACTIONS.
# Shared by the GD and PGD runs so both always produce identical figure layouts.
function _plotMatchingPenniesFigures(filename, trajectory; marker_size)

  newFigure(filename) do matlab_session
    plotTrajectory(
      matlab_session, trajectory;
      transform=metric, 
      y_limits=y_limits₁
    )
  end

  # Heat-map sampling grid: x samples span the x limits, y samples the y limits.
  x_data = LinRange(x_limits₂..., 100)
  y_data = LinRange(y_limits₂..., 100)

  newFigure(filename; append=true) do matlab_session

    plotHeatMap(
      matlab_session,
      [
        metric([[x], [y]])
        for y ∈ y_data, x ∈ x_data
      ],
      x_limits=x_limits₂,
      y_limits=y_limits₂
    )

    # Each player's raw network input.
    _overlayMatchingPenniesTrajectory(
      matlab_session, trajectory,
      state -> [only(state[i]) for i ∈ 1:PLAYERS];
      marker_size=marker_size
    )

  end

  newFigure(filename; append=true) do matlab_session

    # NOTE(review): this grid spans x_limits₂ / y_limits₂, while the overlaid model
    # outputs lie in (0, 1) (sigmoid head) and no limits are passed to plotHeatMap —
    # confirm the intended sampling range of this heat map.
    plotHeatMap(matlab_session, [
      norm(vcat(x, y) .- 1 / ACTIONS)
      for y ∈ y_data, x ∈ x_data
    ])

    # Each player's network output (strategy).
    _overlayMatchingPenniesTrajectory(
      matlab_session, trajectory,
      state -> [only(models[i](state[i])) for i ∈ 1:PLAYERS];
      marker_size=marker_size
    )

  end

end


# Plain gradient ascent.
trajectory_GD = only(realizeTrajectories(
  models;
  initial_states=[initial_state],
  iterations=iterations,
  convexity_factor=convexity_factor,
  step=t -> 1e-2
))

filename = "MatchingPenniesGame_GD"
_plotMatchingPenniesFigures(filename, trajectory_GD; marker_size=1.25)


# Gradient ascent preconditioned by the pseudo-inverse of each network's Jacobian.
trajectory_PGD = only(realizeTrajectories(
  models;
  initial_states=[initial_state],
  iterations=iterations,
  convexity_factor=convexity_factor,
  step=t -> 1e-2,
  with_preconditioning=true,
  process_name="Preconditioning Gradient Ascent",
  prompt="PGD"
))

filename = "MatchingPenniesGame_PGD"
_plotMatchingPenniesFigures(filename, trajectory_PGD; marker_size=2.25)


# Side-by-side convergence of both methods on a log scale.
filename = "MatchingPenniesGame_GDVsPGD"
newFigure(filename) do matlab_session

  plotTrajectory(
    matlab_session, trajectory_PGD;
    transform=metric, 
    y_limits=(1e-16, 1e4),
    log_scale=true,
    grid=true,
    color="[$(97 / 255), $(142 / 255), $(46 / 255)]",
    name="'PGD'",
    legends=true
  )

  plotTrajectory(
    matlab_session, trajectory_GD;
    transform=metric,
    new_figure=false,
    line_style="'--'",
    color="[$(206 / 255), $(44 / 255), $(64 / 255)]",
    name="'GD'",
    legends=true
  )

end


nothing;