module SetInvariance

using DynamicPolynomials
using SemialgebraicSets
using SumOfSquares

import ...Domains: simplex

export basic_semialgebraic_set_invariance, simplex_invariance

"""
    subs_with_early_termination(S, args...)

Substitute `args` (e.g. `x => 0`) into every equality and inequality defining
the set `S` and return the resulting set.

Constraints that become constant after substitution are resolved immediately:

* a constant equality `c == 0` with `c ≠ 0`, or a constant inequality `c ≥ 0`
  with `c < 0`, cannot be satisfied, so `nothing` is returned;
* every other constant constraint holds trivially and is dropped.

Returns `nothing` for a provably empty set, otherwise `FullSpace()` when no
constraint survives, an algebraic set when only equalities survive, and a
`BasicSemialgebraicSet` when inequalities survive.
"""
function subs_with_early_termination(S, args...)
  # Collect into fresh vectors instead of `empty(equalities(S))`: substitution
  # may promote the coefficient type (e.g. `x => 1//2` into integer
  # polynomials), which a container with the input's element type cannot hold.
  h′ = [subs(hᵢ, args...) for hᵢ ∈ equalities(S)]
  # A nonzero constant equality can never hold: the set is empty.
  any(hᵢ -> isconstant(hᵢ) && leading_coefficient(hᵢ) ≠ 0, h′) && return nothing

  g′ = [subs(gᵢ, args...) for gᵢ ∈ inequalities(S)]
  # A negative constant inequality can never hold: the set is empty.
  any(gᵢ -> isconstant(gᵢ) && leading_coefficient(gᵢ) < 0, g′) && return nothing

  # The remaining constant constraints are trivially satisfied.
  filter!(!isconstant, h′)
  filter!(!isconstant, g′)

  subspace = isempty(h′) ? FullSpace() : algebraic_set(h′)
  return isempty(g′) ? subspace : BasicSemialgebraicSet(subspace, g′)
end

"""
    basic_semialgebraic_set_invariance(model, S, ġ, ḣ; domain=S, maxdegree=:auto)

Add to `model` the sum-of-squares constraints on the boundary of the basic
semialgebraic set `S`. `ġ[i]` is paired with `inequalities(S)[i]` and `ḣ[j]`
with `equalities(S)[j]`. Presumably these are the time derivatives of those
polynomials along the system dynamics; confirm with the callers.

* For each inequality `gᵢ ≥ 0` of `S`, `ġᵢ ∈ SOSCone()` is imposed on the part
  of `domain` where `gᵢ = 0`. The equalities of `domain` are extended with
  `gᵢ`, and `gᵢ` itself is removed from the inequalities of `domain`.
* For each equality `hⱼ = 0` of `S`, `ḣⱼ == 0` is imposed on the algebraic set
  of `domain`'s equalities together with `hⱼ`. The inequalities of `domain`
  are not used for this part.

`maxdegree=:auto` selects `SumOfSquares.default_maxdegree` per constraint; any
other value is passed through to `@constraint` unchanged.

Returns the vector of created constraint references, with the inequality
constraints first. Throws `DimensionMismatch` when `ġ` (resp. `ḣ`) does not
have exactly one entry per inequality (resp. equality) of `S`.
"""
function basic_semialgebraic_set_invariance(model, S, ġ, ḣ; domain=S, maxdegree=:auto)
  g = inequalities(S)
  h = equalities(S)
  # `zip` silently stops at the shorter argument. A mismatch would drop
  # boundary conditions without notice, so reject it explicitly.
  length(ġ) == length(g) || throw(DimensionMismatch(
    "expected $(length(g)) inequality derivatives, got $(length(ġ))"))
  length(ḣ) == length(h) || throw(DimensionMismatch(
    "expected $(length(h)) equality derivatives, got $(length(ḣ))"))

  refs = ConstraintRef[]
  h′ = equalities(domain)
  gdomain = inequalities(domain)  # loop-invariant
  for (gᵢ, ġᵢ) ∈ zip(g, ġ)
    # Restrict `domain` to the boundary piece `gᵢ = 0`.
    subspaceᵢ = algebraic_set(isempty(h′) ? [gᵢ] : [h′; gᵢ])
    # On `gᵢ = 0` the inequality `gᵢ ≥ 0` is redundant, so drop it.
    g′ = filter(≠(gᵢ), gdomain)
    domainᵢ = isempty(g′) ? subspaceᵢ : BasicSemialgebraicSet(subspaceᵢ, g′)
    maxdegreeᵢ = maxdegree == :auto ?
      SumOfSquares.default_maxdegree(ġᵢ, domainᵢ) :
      maxdegree
    push!(refs, @constraint(model, ġᵢ ∈ SOSCone(), domain=domainᵢ, maxdegree=maxdegreeᵢ))
  end
  for (hᵢ, ḣᵢ) ∈ zip(h, ḣ)
    # Treat an empty `h′` like the inequality branch does, rather than
    # concatenating onto an empty vector. Avoid duplicating `hᵢ` when it is
    # already one of `domain`'s equalities.
    eqsᵢ = if isempty(h′)
      [hᵢ]
    elseif hᵢ ∈ h′
      h′
    else
      [h′; hᵢ]
    end
    domainᵢ = algebraic_set(eqsᵢ)
    push!(refs, @constraint(model, ḣᵢ == 0, domain=domainᵢ))
  end
  return refs
end

"""
    simplex_invariance(model, x, ẋ; domain=simplex(x), maxdegree=:auto)

Add to `model` sum-of-squares constraints on the boundary faces of the simplex
in the variables `x`, where `ẋ[i]` is paired with `x[i]`. Presumably `ẋ` is the
vector field and `simplex(x)` is `{x ≥ 0, sum(x) ≤ 1}`; confirm in `Domains`.

* On each face `x[i] = 0`, `ẋ[i]` with `x[i]` substituted by `0` is
  constrained to lie in `SOSCone()` on `domain` with the same substitution.
* With a single variable, the face `x[1] = 1` is handled the same way with
  `-ẋ[1]` and the substitution `x[1] => 1`.
* With several variables, the face `sum(x) = 1` is delegated to
  `basic_semialgebraic_set_invariance` using the inequality `1 - sum(x) ≥ 0`
  and the derivative `-sum(ẋ)`.

Faces on which the substituted `domain` is provably empty are skipped.
`maxdegree=:auto` selects `SumOfSquares.default_maxdegree` per constraint.

Returns the vector of created constraint references. Throws
`DimensionMismatch` when `x` and `ẋ` have different lengths.
"""
function simplex_invariance(model, x, ẋ; domain=simplex(x), maxdegree=:auto)
  # `zip` would silently ignore unmatched trailing entries and skip faces.
  length(x) == length(ẋ) || throw(DimensionMismatch(
    "x has $(length(x)) variables but ẋ has $(length(ẋ)) entries"))
  refs = ConstraintRef[]
  # Faces `xᵢ = 0`: require `ẋᵢ ≥ 0` there (as an SOS certificate).
  for (xᵢ, ẋᵢ) ∈ zip(x, ẋ)
    _add_face_constraint!(refs, model, domain, xᵢ => 0, ẋᵢ, maxdegree)
  end
  if length(x) == 1
    # One variable: the remaining face is `x₁ = 1`, where `1 - x₁` has
    # derivative `-ẋ₁`.
    _add_face_constraint!(refs, model, domain, only(x) => 1, -only(ẋ), maxdegree)
  else
    append!(refs, basic_semialgebraic_set_invariance(
      model,
      BasicSemialgebraicSet(FullSpace(), [1 - sum(x)]),
      [- sum(ẋ)],
      [];
      domain=domain,
      maxdegree=maxdegree
    ))
  end
  return refs
end

# Push onto `refs` the constraint `subs(ẋface, sub) ∈ SOSCone()` on the face of
# `domain` obtained by substituting `sub`. Nothing is added when the
# substitution proves that face empty. Returns `refs`.
function _add_face_constraint!(refs, model, domain, sub, ẋface, maxdegree)
  face = subs_with_early_termination(domain, sub)
  isnothing(face) && return refs
  ġ = subs(ẋface, sub)
  deg = maxdegree == :auto ? SumOfSquares.default_maxdegree(ġ, face) : maxdegree
  push!(refs, @constraint(model, ġ ∈ SOSCone(), domain=face, maxdegree=deg))
  return refs
end

end # module SetInvariance