if !(@isdefined(run_once) && run_once)

  using Revise

  # Load the shared utility modules only on the first run in a session;
  # `includet` (from Revise) tracks the files so later edits are picked up.
  for utility_path ∈ (
      "../Utilities/AlgorithmTools.jl",
      "../Utilities/GradientAscent.jl",
      "../Utilities/PlottingTools.jl",
    )
    includet(utility_path)
  end

  run_once = true

end



using LinearAlgebra
using StaticArrays
using Random
using DynamicPolynomials
using Flux
using ProgressLogging

import StaticPolynomials as SP


using .GradientAscent
using .PlottingTools



const PLAYERS = 2
const ACTIONS = 3

"""
    ShapleyGame(β)

Two-player (`PLAYERS`), three-action (`ACTIONS`) game whose payoff matrices are
parameterized by the scalar `β` (see `pseudoGradientOfExpectedPayoffsFunction`).

The field type is a parameter so that `game.β` has a concrete type
(e.g. `ShapleyGame(0.2) isa ShapleyGame{Float64}`); an abstract `β::Real` field
would make every access to it type-unstable.
"""
struct ShapleyGame{T<:Real}
  β::T
end

"""
    pseudoGradientOfExpectedPayoffsFunction(game::ShapleyGame)

Return a closure mapping a strategy profile `state = (x₁, x₂)` to the pair of
per-player payoff gradients `(A * x₂, B * x₁)`. These are the gradients of
`x₁ ⋅ (A * x₂)` with respect to `x₁` and of `x₂ ⋅ (B * x₁)` with respect to `x₂`.
`A` and `B` are the 3 × 3 matrices below, parameterized by `game.β`.

Both gradients are returned as `MVector`s inside an `SVector`, so callers may
update them in place.

NOTE(review): player 2's gradient is `B * x₁`, i.e. `B` is indexed as
(player-2 action, player-1 action). Confirm this orientation, and that of `A`,
matches the intended Shapley-game payoff convention.
"""
function pseudoGradientOfExpectedPayoffsFunction(game::ShapleyGame)

  β = game.β

  # Written row by row: row k holds the coefficients of gradient component k.
  A = [
    1  β  0
    0  1  β
    β  0  1
  ]

  B = [
    -β  1  0
     0 -β  1
     1  0 -β
  ]

  @polyvar x[1:ACTIONS]

  # Wrap each linear map x ↦ M * x as a StaticPolynomials system for evaluation.
  system_A = SP.PolynomialSystem(A * x)
  system_B = SP.PolynomialSystem(B * x)

  # Each player's gradient depends only on the opponent's strategy.
  return @inline function(state)
    SVector{PLAYERS, MVector{ACTIONS}}(
      system_A(state[2]),
      system_B(state[1])
    )
  end

end

"""
    expectedPayoffsFunction(game::ShapleyGame)

Return a closure mapping a strategy profile `state` to an `SVector` of expected
payoffs. Entry `i` is `state[i]` dotted with the `i`-th component of
`pseudoGradientOfExpectedPayoffsFunction(game)(state)`.
"""
function expectedPayoffsFunction(game::ShapleyGame)

  gradients = pseudoGradientOfExpectedPayoffsFunction(game)

  return @inline function(state)
    g = gradients(state)
    return SVector{PLAYERS}(ntuple(i -> dot(state[i], g[i]), PLAYERS))
  end

end


"""
    gradientFunction(game::ShapleyGame, models, inputs;
                     with_preconditioning=false, convexity_factor=0)

Return a closure mapping a state (one `inputs`-long vector `state[i]` per player)
to an `SVector` of per-player ascent directions in input space.

Player `i`'s strategy is `models[i](state[i])`. The strategy-space pseudo-gradient
from `pseudoGradientOfExpectedPayoffsFunction` is pulled back to input space
through the model Jacobian `Jᵢ`:
- `Jᵢ'` (plain chain rule) by default;
- `pinv(Jᵢ)` when `with_preconditioning` is true.

When `convexity_factor > 0`, `convexity_factor * (xᵢ .- 1/ACTIONS)` is subtracted
from each strategy-space gradient. This is the gradient of an added
`-convexity_factor/2 * ‖xᵢ .- 1/ACTIONS‖²` term in each player's payoff.
"""
function gradientFunction(
    game::ShapleyGame, models, inputs;
    with_preconditioning=false,
    convexity_factor=0
  )

  pseudo_gradient_function = pseudoGradientOfExpectedPayoffsFunction(game)

  # Strategy of player `i` at input `xᵢ`, together with the (inputs × ACTIONS)
  # matrix that maps a strategy-space gradient back to input space.
  function evaluatePlayer(i, xᵢ)
    outputᵢ, jacobianᵢ = Flux.withjacobian(models[i], xᵢ)
    jacobianᵢ = only(jacobianᵢ)
    pulled_back = with_preconditioning ? pinv(jacobianᵢ) : jacobianᵢ'
    return SVector{ACTIONS}(outputᵢ), MMatrix{inputs, ACTIONS}(pulled_back)
  end

  @inline function(state)

    # A comprehension gives the containers concrete element types, unlike
    # push!-ing into `SVector{ACTIONS}[]`, whose element type is a UnionAll.
    evaluations = [evaluatePlayer(i, state[i]) for i ∈ 1:PLAYERS]
    output = first.(evaluations)
    jacobians = last.(evaluations)

    # MVector components, so they can be adjusted in place below.
    pseudo_gradient = pseudo_gradient_function(output)

    if convexity_factor > 0
      for i ∈ 1:PLAYERS
        pseudo_gradient[i] .-= convexity_factor .* (output[i] .- 1 / ACTIONS)
      end
    end

    SVector{PLAYERS}(
      SVector{inputs}(jacobians[i] * pseudo_gradient[i])
      for i ∈ 1:PLAYERS
    )

  end

end



"""
    realizeTrajectories(game::ShapleyGame, models, inputs; initial_states, ...)

Run `gradientAscent` once per entry of `initial_states`, ascending along the
directions from `gradientFunction`. Each player's rounded strategy is printed
before and after every run. Returns the trajectories in the order of
`initial_states`.

# Keywords
- `initial_states`: starting states, each holding one `state[i]` per player.
- `iterations`, `step`, `with_progress`: forwarded to `gradientAscent`.
- `convexity_factor`, `with_preconditioning`: forwarded to `gradientFunction`.
- `process_name` (via `kwargs`): progress label, default `"Gradient Ascent"`.
- `prompt` (via `kwargs`): prefix of the per-run summary line, default `"GD"`.

Any other extra keywords are ignored.

NOTE(review): callers in this file pass `step` as a function of time, while the
default here is a number. Confirm that `gradientAscent` accepts both.
"""
function realizeTrajectories(
    game::ShapleyGame,
    models, inputs;
    initial_states,
    iterations=Int(1e3),
    convexity_factor=0,
    step=1e-2,
    with_preconditioning=false,
    with_progress=true,
    kwargs...
  )

  total_trajectories = length(initial_states)

  # None of these depend on the initial state, so build them once instead of once
  # per trajectory (gradientFunction rebuilds the polynomial systems each call).
  gradient_function = gradientFunction(
    game, models, inputs;
    with_preconditioning=with_preconditioning,
    convexity_factor=convexity_factor
  )
  progress_name = get(kwargs, :process_name, "Gradient Ascent")
  prompt = get(kwargs, :prompt, "GD")

  # Every player's strategy at `state`, rounded to 3 digits for logging.
  roundedStrategies(state) = [
    round.(models[i](state[i]); digits=3)
    for i ∈ 1:PLAYERS
  ]

  trajectories = []
  sizehint!(trajectories, total_trajectories)

  for k ∈ axes(initial_states, 1)

    initial_state = initial_states[k]

    println("[$k / $total_trajectories] x₀ ≈ $(repr(roundedStrategies(initial_state)))")

    trajectory = gradientAscent(
      initial_state,
      gradient_function;
      iterations=iterations,
      step=step,
      with_progress=with_progress,
      progress_id=@progressid,
      # Fraction of the overall progress bar covered by this run.
      progress_partition=[k - 1, k] ./ total_trajectories,
      progress_name=progress_name
    )

    println("$prompt: xₑ ≈ $(repr(roundedStrategies(trajectory[end].state)))")

    push!(trajectories, trajectory)

  end

  @info ProgressLogging.Progress(@progressid; done=true)

  trajectories

end



rng = MersenneTwister(4)

"""
    randomInitialization(sz...; gain=1.0)

Draw an array of size `sz` from the global `rng` (seeded with 4). The entries are
uniform on `[-gain, gain)` for positive `gain`.

When called without sizes, return the curried initializer
`(sz...) -> randomInitialization(sz...; gain=gain)`, e.g. for an `init=` keyword.
"""
function randomInitialization(sz...; gain=1.0)

  # No sizes yet: remember `gain` and defer the draw until sizes are supplied.
  isempty(sz) && return (dims...) -> randomInitialization(dims...; gain=gain)

  # Shift uniform [0, 1) samples to [-0.5, 0.5), then scale by 2 * gain.
  centered = rand(rng, sz...) .- 0.5
  return 2 * gain * centered

end


game = ShapleyGame(0.2)

# Each player is a small network mapping an `inputs`-dimensional vector to a
# mixed strategy over ACTIONS actions (softmax output).
inputs = 5
models = SVector{PLAYERS}(
  f64(Chain(
    Dense(inputs => 4, celu; init=randomInitialization(; gain=1.0)),
    Dense(4 => ACTIONS; init=randomInitialization(; gain=1.0)),
    softmax
  ))
  for _ ∈ 1:PLAYERS
)

# Euclidean distance between the players' stacked strategies and the uniform
# profile (every action played with probability 1 / ACTIONS).
metric = @inline function(state)

  norm(reduce(vcat,
    models[i](state[i]) .- 1 / ACTIONS
    for i ∈ 1:PLAYERS
  ))

end


total_trajectories = 100
# Fraction of `total_trajectories` that is oversampled, and then dropped, at each
# end of the `metric` ordering.
drop_out_rate = 0.15
additional_trajectories = round(Int, total_trajectories * drop_out_rate)

# Sample extra states, sort them by `metric` (largest first), and keep the
# middle `total_trajectories`. This discards the `additional_trajectories`
# samples with the largest and the smallest metric.
initial_states = sort(
  collect(
    SVector{PLAYERS}(
      MVector{inputs}(randomInitialization(inputs; gain=0.75))
      for _ ∈ 1:PLAYERS
    )
    for _ ∈ 1:(total_trajectories + 2 * additional_trajectories)
  );
  by=metric,
  rev=true
)[
  (begin + additional_trajectories):(end - additional_trajectories)
]



iterations = Int(1e4)
convexity_factor = 0.5

selected_index = Int(ceil(total_trajectories * 0.1))
y_limits = (0.0, (ACTIONS - 1) / ACTIONS)
plot_type = PlottingTools.PlotTypes.ScatterPlot


# Write into figure `filename`:
#   1. the mean of `metric` over `trajectories`;
#   2. the path of each player's strategy along `trajectory`
#      (player 1 blue, player 2 red).
function plotExperiment(filename, trajectories, trajectory)

  newFigure(filename) do matlab_session
    plotTrajectoryOfMean(
      matlab_session, trajectories;
      transform=metric,
      y_limits=y_limits
    )
  end

  for (i, color) ∈ ((1, "'Blue'"), (2, "'Red'"))
    newFigure(filename; append=true) do matlab_session
      plotTrajectoryInSimplex(
        matlab_session, trajectory;
        # The first ACTIONS - 1 probabilities of player i's softmax strategy; the
        # last one is implied by the sum-to-one constraint.
        # NOTE(review): assumes plotTrajectoryInSimplex expects exactly these
        # coordinates — confirm.
        transform=state -> models[i](state[i])[1:(ACTIONS - 1)],
        color=color,
        plot_type=plot_type
      )
    end
  end

end


trajectories_GD = realizeTrajectories(
  game, models, inputs;
  initial_states=initial_states,
  iterations=iterations,
  convexity_factor=convexity_factor,
  step=t -> 1e-2
)

filename = "ShapleyGame_GD"
trajectory = trajectories_GD[selected_index]
plotExperiment(filename, trajectories_GD, trajectory)


trajectories_PGD = realizeTrajectories(
  game, models, inputs;
  initial_states=initial_states,
  iterations=iterations,
  convexity_factor=convexity_factor,
  step=t -> 1e-2,
  with_preconditioning=true,
  process_name="Preconditioning Gradient Ascent",
  prompt="PGD"
)

filename = "ShapleyGame_PGD"
trajectory = trajectories_PGD[selected_index]
plotExperiment(filename, trajectories_PGD, trajectory)


# Both methods' mean metric on one log-scaled figure.
filename = "ShapleyGame_GDVsPGD"
newFigure(filename) do matlab_session

  plotTrajectoryOfMean(
    matlab_session, trajectories_PGD;
    transform=metric,
    y_limits=(1e-8, 1e3),
    log_scale=true,
    grid=true,
    color="[$(97 / 255), $(142 / 255), $(46 / 255)]",
    name="'PGD'",
    legends=true
  )

  plotTrajectoryOfMean(
    matlab_session, trajectories_GD;
    transform=metric,
    new_figure=false,
    line_style="'--'",
    color="[$(206 / 255), $(44 / 255), $(64 / 255)]",
    name="'GD'",
    legends=true
  )

end


nothing;