module SIARMethod

using DynamicPolynomials
using JuMP

import SCS: SCS

import ...DEFAULT_ZERO_THRESHOLD, ...DEFAULT_ABSOLUTE_TOLERANCE
import ...rescale_data!

# This module defines `SIAR`; exporting any other name (previously `SINDy`, which is
# not defined here) leaves the module's actual entry point unexported.
export SIAR

"""
    SIAR(x_data, ẋ_data, b; SOCP_optimizer, SI_constraints=(), zero_threshold,
         η=1e-4, rescale=true, verbose=false)
    SIAR(x_data, w_data, ẋ_data, b; kwargs...)

Fit one polynomial per component of `ẋ_data`, each a linear combination of the
basis `b`, by solving a single second-order-cone program built with JuMP.

Let `Θ[k, j] = b[j](x_data[k])` and let `ẋᵢ` collect the `i`-th component of every
sample in `ẋ_data`. The coefficient vectors `ξᵢ` minimize

    Σᵢ ‖Θ ξᵢ - ẋᵢ‖₂ + η Σᵢ ‖ξᵢ‖₁

subject to any constraints added through `SI_constraints`. The ℓ₁ term is omitted
when `η ≤ 0`, and with `rescale = true` the problem is posed on the rescaled data.
Coefficients whose magnitude is below `zero_threshold` are then set to zero.

# Arguments
- `x_data`: sample points; every basis element is evaluated on them as `bᵢ(xₖ)`.
- `w_data` (second method only): a second per-sample data set; sample `k` becomes
  `vcat(x_data[k], w_data[k])`, so `b` must be written in the stacked variables.
- `ẋ_data`: one indexable derivative sample per element of `x_data`.
- `b`: the basis; results are assembled with `polynomial(coefficients, b)`.

# Keywords
- `SOCP_optimizer`: optimizer passed to `JuMP.Model`. Defaults to SCS with
  `eps_abs = DEFAULT_ABSOLUTE_TOLERANCE`, `eps_rel = 0` and `max_iters = 10^6`.
- `SI_constraints`: iterable of callables, each invoked as `c(model, p)` before the
  residual constraints and the objective are added; `p` holds one polynomial per
  component whose coefficients are JuMP expressions in the same units as the
  returned coefficients.
- `zero_threshold`: returned coefficients `c` with `abs(c) < zero_threshold` are
  set to `0`.
- `η`: weight of the ℓ₁ (LASSO) penalty; non-positive values disable it.
- `rescale`: when `true`, the evaluated basis matrix and the derivative matrix are
  passed through `rescale_data!` before fitting and the coefficients are mapped
  back to the original units afterwards.
- `verbose`: when `false`, solver output is suppressed with `set_silent`.

# Returns
A `Vector` of polynomials, one per component of `ẋ_data`. A warning is logged when
the solver does not report `MOI.OPTIMAL`; the result is always logged with `@info`.

# Throws
- `ArgumentError` if `ẋ_data` is empty.
- `DimensionMismatch` if `x_data`, `w_data` and `ẋ_data` hold different numbers of
  samples.
"""
function SIAR(
  x_data,
  ẋ_data,
  b;
  SOCP_optimizer=optimizer_with_attributes(
    SCS.Optimizer,
    "eps_abs" => DEFAULT_ABSOLUTE_TOLERANCE,
    "eps_rel" => 0.0,
    # SCS's iteration limit is an integer setting.
    "max_iters" => 1_000_000
  ),
  SI_constraints=(),
  zero_threshold=DEFAULT_ZERO_THRESHOLD,
  η=1e-4,
  rescale=true,
  verbose=false
)
  isempty(ẋ_data) && throw(ArgumentError("ẋ_data must contain at least one sample"))

  # Regression data: row k of Θ holds every basis element evaluated at the k-th
  # sample; row k of Ẋ holds the components of the k-th derivative sample.
  ẋ_idx = eachindex(first(ẋ_data))
  Θ = [bᵢ(xₖ) for xₖ ∈ x_data, bᵢ ∈ b]
  Ẋ = [ẋₖ[i] for ẋₖ ∈ ẋ_data, i ∈ ẋ_idx]
  K = length(ẋ_data)
  size(Θ, 1) == K || throw(DimensionMismatch(
    "x_data and ẋ_data must have the same number of samples, got " *
    "$(size(Θ, 1)) and $K"
  ))

  model = Model(SOCP_optimizer)
  if !verbose
    set_silent(model)
  end
  # Column i of ξ holds the basis coefficients for component i of ẋ.
  @variable(model, ξ[eachindex(b), ẋ_idx])
  @variable(model, ϵ_L₂_norm[ẋ_idx])

  if rescale
    # NOTE(review): assumes `rescale_data!` scales the columns of its argument in
    # place and returns one factor per column; confirm against its definition.
    # The solver's coefficient ξ[j, i] corresponds to ξ[j, i] / scaling[j, i] in the
    # original units.
    scaling = rescale_data!(Θ) ./ rescale_data!(Ẋ)'
    p = [
      polynomial(ξᵢ ./ scalingᵢ, b)
      for (ξᵢ, scalingᵢ) ∈ zip(eachcol(ξ), eachcol(scaling))
    ]
  else
    p = [polynomial(ξᵢ, b) for ξᵢ ∈ eachcol(ξ)]
  end

  for SI_constraint ∈ SI_constraints
    SI_constraint(model, p)
  end

  # Epigraph of the residual norms: ϵᵢ ≥ ‖Θ ξᵢ - ẋᵢ‖₂ for every component i.
  for (ϵ_L₂_normᵢ, ξᵢ, ẋᵢ_data) ∈ zip(ϵ_L₂_norm, eachcol(ξ), eachcol(Ẋ))
    @constraint(model,
      [ϵ_L₂_normᵢ; Θ * ξᵢ - ẋᵢ_data] ∈ MOI.SecondOrderCone(1 + K)
    )
  end

  # A non-positive LASSO weight drops the ℓ₁ term and its auxiliary variables.
  if η > 0
    @variable(model, ξ_L₁_norm[ẋ_idx])
    for (ξ_L₁_normᵢ, ξᵢ) ∈ zip(ξ_L₁_norm, eachcol(ξ))
      # ξ_L₁_normᵢ ≥ ‖ξᵢ‖₁
      @constraint(model, [ξ_L₁_normᵢ; ξᵢ] in MOI.NormOneCone(1 + length(b)))
    end
    @objective(model, Min, sum(ϵ_L₂_norm) + η * sum(ξ_L₁_norm))
  else
    @objective(model, Min, sum(ϵ_L₂_norm))
  end

  optimize!(model)
  status = termination_status(model)
  if status ≠ MOI.OPTIMAL
    @warn "SIAR terminated with status $(repr(status))"
  end

  ξ_value = value.(ξ)
  if rescale
    ξ_value ./= scaling
  end
  # Threshold in the units of the returned coefficients, so `zero_threshold` means
  # the same thing whether or not `rescale` is set.
  ξ_value[abs.(ξ_value) .< zero_threshold] .= 0

  p_values = [polynomial(ξᵢ_value, b) for ξᵢ_value ∈ eachcol(ξ_value)]
  if length(p_values) == 1
    @info "SIAR found a polynomial" p=only(p_values)
  else
    @info "SIAR found multiple polynomials" p=p_values
  end

  return p_values
end

# Method taking a second per-sample data set `w_data`: each sample is stacked as
# `vcat(x, w)` and forwarded, with all keywords, to the three-data method above.
function SIAR(x_data, w_data, ẋ_data, b; kwargs...)
  # `zip` stops at the shorter input, which would silently drop samples.
  length(x_data) == length(w_data) || throw(DimensionMismatch(
    "x_data and w_data must have the same number of samples, got " *
    "$(length(x_data)) and $(length(w_data))"
  ))
  xw_data = [vcat(x, w) for (x, w) ∈ zip(x_data, w_data)]
  return SIAR(xw_data, ẋ_data, b; kwargs...)
end

end # module SIARMethod