module GameDynamics

using LinearAlgebra
using DynamicPolynomials
using PolynomialTools

import DifferentialEquations: ODEProblem, solve
import DiffEqCallbacks: PresetTimeCallback

export sample, sample_with_velocities, infer_control_inputs
export replicator, replicator_as_polynomial_system

"""
    sample(f, x₀, t_data; <keyword arguments>)

Generate a sample from the ODE system given by `f` initialized in state `x₀`.

`f` is called with the current state and must return its time derivative. `t_data`
should provide the sampling times in ascending order; the states at these times are
returned.

# Arguments
- `t₀=0.0`: the ODE system's starting time.

# Examples
```jldoctest
using DynamicPolynomials
using GameDynamics

@polyvar x₁ x₂

v₁ = [2 * x₂ - 1, 1 - 2 * x₂]
v₂ = [1 - 2 * x₁, 2 * x₁ - 1]

f = replicator_as_polynomial_system([[x₁], [x₂]], [v₁, v₂])
sample(f, [0.125, 0.625], [0.0, 1.0])

# output

2-element Vector{Vector{Float64}}:
 [0.125, 0.625]
 [0.2883885849629817, 0.8536761765271454]
```
"""
function sample(f, x₀, t_data; t₀=0.0)
  # The parameter and time arguments of the ODE right-hand side are unused: the
  # system is autonomous and has no control input.
  problem = ODEProblem((x, _, _) -> f(x), x₀, (t₀, t_data[end]))
  solution = solve(problem; saveat=t_data)
  return solution.u
end

"""
    sample(f, x₀, t_data, w_data; <keyword arguments>)

Generate a sample from the controlled ODE system given by `f` initialized in state `x₀`.

`f` is called with the state and the current control input stacked into one vector,
`f([x; w])`. `w_data` maps switching times to control inputs (e.g. a `Dict`): the control
input `w_data[s]` is applied from time `s` until the next switching time. At least one
switching time must be smaller or equal to `t₀`.

`t_data` should provide the sampling times in ascending order.

Return the sampled states and the control inputs in effect at the sampling times (as
computed by `infer_control_inputs`).

# Arguments
- `t₀=0.0`: the ODE system's starting time.

Throw an `ArgumentError` if no switching time is smaller or equal to `t₀`.
"""
function sample(f, x₀, t_data, w_data; t₀=0.0)
  # The control input in effect at `t₀` belongs to the latest switching time `≤ t₀`,
  # the same convention `infer_control_inputs` uses for the earliest sampling time.
  initial_times = [τ for τ ∈ keys(w_data) if τ ≤ t₀]
  if isempty(initial_times)
    throw(ArgumentError("The control input at the starting time cannot be determined. Verify that the given control inputs can be indexed for some time smaller or equal to the starting time."))
  end
  w₀ = w_data[maximum(initial_times)]

  # The control input is carried as the problem parameter `w`.
  problem = ODEProblem((x, w, _) -> f([x; w]), x₀, (t₀, t_data[end]), w₀)

  # Switch the control input at every later switching time; the input at `t₀` is
  # already the initial parameter, and earlier times would lie behind the integrator.
  switch_times = sort!([τ for τ ∈ keys(w_data) if τ > t₀])
  callback = if isempty(switch_times)
    nothing
  else
    PresetTimeCallback(
      switch_times,
      integrator -> integrator.p = w_data[integrator.t];
      save_positions=(false, false)
    )
  end

  solution = solve(problem; saveat=t_data, callback=callback)
  return solution.u, infer_control_inputs(t_data, w_data)
end

"""
    sample_with_velocities(f, x₀, t_data; <keyword arguments>)

Like `sample(f, x₀, t_data; ...)`, but additionally evaluate `f` at every sampled state.

Return the sampled states and the corresponding velocities `f(x)`. Keyword arguments are
forwarded to `sample`.
"""
function sample_with_velocities(f, x₀, t_data; kwargs...)
  states = sample(f, x₀, t_data; kwargs...)
  velocities = map(f, states)
  return states, velocities
end

"""
    sample_with_velocities(f, x₀, t_data, w_data; <keyword arguments>)

Like `sample(f, x₀, t_data, w_data; ...)`, but additionally evaluate `f` at every
sampled state together with the control input inferred for its sampling time.

Return the sampled states, the inferred control inputs and the velocities `f([x; w])`.
Keyword arguments are forwarded to `sample`.
"""
function sample_with_velocities(f, x₀, t_data, w_data; kwargs...)
  states, inputs = sample(f, x₀, t_data, w_data; kwargs...)
  velocities = [f(vcat(state, input)) for (state, input) ∈ zip(states, inputs)]
  return states, inputs, velocities
end

"""
    infer_control_inputs(t_data, w_data)

Return the control input in effect at each sampling time in `t_data`.

`w_data` maps switching times to control inputs (e.g. a `Dict`). The control input
`w_data[s]` is in effect from time `s` (inclusive) until the next switching time, which
matches how `sample` applies control inputs during the simulation. Hence the control
input reported for a time `t` is `w_data[s]` for the largest switching time `s ≤ t`.

Throw an `ArgumentError` if `w_data` is empty or if no switching time is smaller or
equal to the earliest time in `t_data`.
"""
function infer_control_inputs(t_data, w_data)
  if isempty(w_data)
    throw(ArgumentError("No control inputs were given."))
  end

  # Sort the (switching time => control input) pairs by time only, so the control
  # inputs themselves never need to be comparable.
  switches = sort!(collect(pairs(w_data)); by=first)
  switch_times = map(first, switches)

  if !isempty(t_data) && first(switch_times) > minimum(t_data)
    throw(ArgumentError("The control input at the earliest given time cannot be inferred. Verify that the given control inputs can be indexed for some time smaller or equal to the earliest given time."))
  end

  # `searchsortedlast` yields the index of the latest switching time `≤ t`, which
  # exists thanks to the check above.
  return [last(switches[searchsortedlast(switch_times, t)]) for t ∈ t_data]
end

"""
    replicator(x, v)

Return the right-hand sides of the replicator equations for the players' states `x` and
marginal payoffs `v`.

Player `i` with `n` strategies is described by `x[i]`, the shares of its first `n - 1`
strategies (the share of the last strategy is `1 - sum(x[i])`), and by `v[i]`, the
payoffs of all its `n` strategies. For each share `x[i][j]` the result contains
`x[i][j] * (v[i][j] - u[i])`, where `u[i]` is the player's average payoff. Entries are
ordered player by player.

Throw a `DimensionMismatch` if `x` and `v` describe different numbers of players, or if
some player does not have at least one share and exactly one more payoff than shares.
"""
function replicator(x, v)
  if length(x) != length(v)
    throw(DimensionMismatch("Got states for $(length(x)) players but payoffs for $(length(v)) players."))
  end
  ẋ = Polynomial[]
  # One equation per explicit share, not per player.
  sizehint!(ẋ, mapreduce(length, +, x; init=0))
  for (i, (xᵢ, vᵢ)) ∈ enumerate(zip(x, v))
    if !(1 ≤ length(xᵢ) == length(vᵢ) - 1)
      throw(DimensionMismatch("Player $i: expected at least one share and exactly one more payoff than shares, got $(length(xᵢ)) shares and $(length(vᵢ)) payoffs."))
    end
    # Average payoff; the last strategy's share is implied by the others.
    uᵢ = xᵢ ⋅ (@view vᵢ[begin:end - 1]) + (1 - sum(xᵢ)) * vᵢ[end]
    # `zip` stops after the explicit shares, so the last strategy gets no equation.
    for (xᵢⱼ, vᵢⱼ) ∈ zip(xᵢ, vᵢ)
      push!(ẋ, xᵢⱼ * (vᵢⱼ - uᵢ))
    end
  end
  return ẋ
end

"""
    replicator_as_polynomial_system(x, v; vars=collect(reduce(vcat, x)))

Return a system of replicator equations for a game defined by the marginal payoffs `v`.

`x` should be a collection of the players' individual states: `x[i]` holds the variables
for the shares of all but the last strategy of player `i`, and `v[i]` holds the marginal
payoffs of all strategies of player `i` (see `replicator`). The system is built over the
variables `vars`, which default to the concatenation of the players' states.

# Examples
```jldoctest
using DynamicPolynomials
using GameDynamics

@polyvar x₁ x₂

v₁ = [2 * x₂ - 1, 1 - 2 * x₂]
v₂ = [1 - 2 * x₁, 2 * x₁ - 1]

f = replicator_as_polynomial_system([[x₁], [x₂]], [v₁, v₂])
f([0.5, 0.5])

# output

2-element Vector{Float64}:
 0.0
 0.0
```
"""
function replicator_as_polynomial_system(x, v; vars=collect(reduce(vcat, x)))
  return polynomial_system(replicator(x, v), vars)
end

"""
    replicator_as_polynomial_system(x, v, w)

Return a system of replicator equations for the players' states `x` and marginal payoffs
`v` (which may depend on the control input variables `w`).

The system's variables are the players' states followed by `w`, matching the `f([x; w])`
calls made by `sample` for controlled systems.
"""
function replicator_as_polynomial_system(x, v, w)
  state_vars = reduce(vcat, x)
  return replicator_as_polynomial_system(x, v; vars=vcat(state_vars, w))
end

"""
Preset games, each defined in its own submodule loaded from a file in `presets/`.

The short aliases `SH` (stag hunt), `MP` (matching pennies) and `RPS`
(rock–paper–scissors) are exported.
"""
module Presets

  export SH, MP, RPS

  # The included file is expected to define module `StagHunt`, imported below.
  include("presets/stag_hunt_game.jl")
  import .StagHunt

  # The included file is expected to define module `MatchingPennies`.
  include("presets/matching_pennies_game.jl")
  import .MatchingPennies

  # The included file is expected to define module `RockPaperScissors`.
  include("presets/rock_paper_scissors_game.jl")
  import .RockPaperScissors

  # Short, exported aliases for the preset modules.
  const SH  = StagHunt
  const MP  = MatchingPennies
  const RPS = RockPaperScissors

end # module Presets

import .Presets

"""
Stag hunt game preset, loaded from `presets/stag_hunt_game.jl` and aliased as
`Presets.SH`.
"""
Presets.StagHunt

"""
Matching pennies game preset, loaded from `presets/matching_pennies_game.jl` and aliased
as `Presets.MP`.
"""
Presets.MatchingPennies

"""
Rock–paper–scissors game preset, loaded from `presets/rock_paper_scissors_game.jl` and
aliased as `Presets.RPS`.
"""
Presets.RockPaperScissors

end # module GameDynamics
