module NormalFormGames

using LinearAlgebra
using DynamicPolynomials
using PolynomialTools

export marginal_payoffs, marginal_payoffs_as_polynomial_systems, variables

"""
    marginal_payoffs(A, x)

Compute the marginal payoffs of the normal-form game defined by `A` at state `x`.

`A` holds one payoff array per player. In player `i`'s array `A[i]`, dimension `i` runs
over player `i`'s own actions, and the remaining dimensions run over the actions of the
other players, in player order.

`x` holds one state per player. The state `x[i]` lists the weights of all but the last
action of player `i`; the last weight is implied as `1 - sum(x[i])`. The entries may be
plain numbers or polynomial variables.

The result contains one vector per player. Its `j`-th entry is player `i`'s payoff for
action `j`, averaged over the other players' (completed) states, i.e. the expected payoff
when the states are read as mixed strategies.

See also [`marginal_payoffs_as_polynomial_systems`](@ref).

# Examples
```jldoctest
using NormalFormGames
using DynamicPolynomials

@polyvar x₁ x₂

A₁ = [
  -1  1;
   1 -1
]
A₂ = - A₁

marginal_payoffs([A₁, A₂], [[x₁], [x₂]])

# output

2-element Vector{Vector{Polynomial{DynamicPolynomials.Commutative{DynamicPolynomials.CreationOrder}, Graded{LexOrder}, Int64}}}:
 [1 - 2x₂, -1 + 2x₂]
 [-1 + 2x₁, 1 - 2x₁]
```

```jldoctest
using NormalFormGames

A₁ = [
  -1  1;
   1 -1
]
A₂ = - A₁

marginal_payoffs([A₁, A₂], [[0.5], [0.5]])

# output

2-element Vector{Vector{Float64}}:
 [0.0, 0.0]
 [0.0, 0.0]
```
"""
function marginal_payoffs(A, x)
  # Restore the implied weight of each player's last action.
  states = [vcat(xᵢ, 1 - sum(xᵢ)) for xᵢ ∈ x]

  # Marginal payoffs of player `i`, whose payoff array is `Aᵢ`.
  function payoffs_of(i, Aᵢ)
    opponents = (states[begin:i - 1]..., states[i + 1:end]...)
    # `Iterators.product` varies its first factor fastest, which matches the
    # column-major order in which `vec` flattens a slice of `Aᵢ` along dimension `i`.
    weights = vec(prod.(Iterators.product(opponents...)))
    return [dot(vec(slice), weights) for slice ∈ eachslice(Aᵢ; dims=i)]
  end

  return [payoffs_of(i, Aᵢ) for (i, Aᵢ) ∈ pairs(A)]
end

"""
    marginal_payoffs_as_polynomial_systems(A; <keyword arguments>)

Return the marginal payoffs of the normal-form game defined by `A` as polynomial systems.

For each player `i` with `mᵢ` actions, a vector of `mᵢ - 1` polynomial variables is
created with `@polyvar x[i, Base.OneTo(mᵢ - 1)]`. That vector is used as the player's
state in [`marginal_payoffs`](@ref), and the resulting polynomials are passed to
`polynomial_systems`. As the example below shows, each returned system can be evaluated
at a point.

See also [`marginal_payoffs`](@ref).

# Arguments
- `size=size(first(A))`: the game size, i.e. the number of actions `mᵢ` of each player `i`

# Examples
```jldoctest
using NormalFormGames

A₁ = [
  -1  1;
   1 -1
]
A₂ = - A₁

v = marginal_payoffs_as_polynomial_systems([A₁, A₂])

v[1]([0.5, 0.5]), v[2]([0.5, 0.5])

# output

([0.0, 0.0], [0.0, 0.0])
```
"""
function marginal_payoffs_as_polynomial_systems(A; size=size(first(A)))
  # Inside this body, `size` is the keyword argument (one entry per player), not
  # `Base.size`.
  # The states are bound to `states` rather than `x`: `@polyvar x[...]` assigns `x`
  # inside the generator. A function-level local with the same name would be captured
  # by the generator closure and boxed, making the function type-unstable.
  states = collect(
    only(@polyvar x[i, Base.OneTo(mᵢ - 1)])
    for (i, mᵢ) ∈ pairs(size)
  )
  return polynomial_systems(marginal_payoffs(A, states))
end

"""
    variables(v)

Group the polynomial variables of `v` by player.

`v` is expected to hold one vector per player (for instance, that player's marginal
payoffs), so that `length(v[i])` is the number of actions `mᵢ` of player `i`. The
variables returned by `DynamicPolynomials.variables(v)` are consumed in order, `mᵢ - 1`
at a time, giving one vector of variables per player. This is the same per-player count
used for the states in [`marginal_payoffs`](@ref) and
[`marginal_payoffs_as_polynomial_systems`](@ref).

# Throws
- `ArgumentError`: if some `v[i]` has fewer than two entries (a singleton action set).
"""
function variables(v)
  # NOTE(review): the grouping is only correct if `DynamicPolynomials.variables(v)` lists
  # the variables player by player. If fewer variables are available than needed,
  # `first(remaining, n)` returns a shorter group without an error. Confirm both points
  # against the intended inputs.
  remaining = Iterators.Stateful(DynamicPolynomials.variables(v))
  return [
    begin
      mᵢ = length(vᵢ)
      # Input validation uses a real exception: `@assert` may be disabled.
      mᵢ ≥ 2 || throw(ArgumentError("Games where the action sets of one or more players are singletons are not supported. Please verify that each vector of marginal payoffs has at least two entries."))
      # Consumes the next `mᵢ - 1` variables from the shared stateful iterator.
      first(remaining, mᵢ - 1)
    end
    for vᵢ ∈ v
  ]
end

module Presets

  # Exported short aliases for the preset games. The full names (`StagHunt`,
  # `MatchingPennies`, `RockPaperScissors`) remain reachable as `Presets.<Name>`.
  export SH, MP, RPS

  # Each preset presumably defines a submodule of the same name in its own file under
  # `presets/` (inferred from the `include` + `import .Name` pairs; confirm in those files).
  include("presets/stag_hunt_game.jl")
  import .StagHunt
  
  include("presets/matching_pennies_game.jl")
  import .MatchingPennies
  
  include("presets/rock_paper_scissors_game.jl")
  import .RockPaperScissors

  # Bind the exported abbreviations to the submodules above.
  const SH  = StagHunt
  const MP  = MatchingPennies
  const RPS = RockPaperScissors

end # module Presets

import .Presets

# NOTE(review): `marginal_payoffs` indexes player 1's actions by the rows of A₁ and
# player 2's actions by the columns of A₂. Under that convention the matrices below make
# action 1 strictly dominant for both players, which is not a Stag Hunt; the usual form
# has A₁ and A₂ swapped. Confirm against presets/stag_hunt_game.jl.
@doc raw"""
    Presets.SH = Presets.StagHunt

Presets for a Stag Hunt game given by the payoff matrices
```math
  A_1 = \begin{pmatrix}
    4 & 3 \\
    0 & 1
  \end{pmatrix}
  \ \text{and}\ 
  A_2 = \begin{pmatrix}
    4 & 0 \\
    3 & 1
  \end{pmatrix}
```

See also [`Presets.MatchingPennies`](@ref) and [`Presets.RockPaperScissors`](@ref).
"""
Presets.StagHunt

@doc raw"""
    Presets.MP = Presets.MatchingPennies

Presets for a Matching Pennies game given by the payoff matrices
```math
  A_1 = - A_2 = \begin{pmatrix*}[r]
     1 & -1 \\
    -1 &  1
  \end{pmatrix*}
```

See also [`Presets.StagHunt`](@ref) and [`Presets.RockPaperScissors`](@ref).
"""
Presets.MatchingPennies

@doc raw"""
    Presets.RPS = Presets.RockPaperScissors

Presets for a Rock-Paper-Scissors game given by the payoff matrices
```math
  A_1 = - A_2 = \begin{pmatrix*}[r]
     0 & -1 &  1 \\
     1 &  0 & -1 \\
    -1 &  1 &  0
  \end{pmatrix*}
```

See also [`Presets.StagHunt`](@ref) and [`Presets.MatchingPennies`](@ref).
"""
Presets.RockPaperScissors

end # module NormalFormGames