using LinearAlgebra
using Random
using DynamicPolynomials
using JuMP
using GDIAC
using Logging
using MAT

import MosekTools: Mosek

# Short aliases for modules used below. None of these names is defined in
# this file; presumably they are brought in by `using GDIAC` — confirm.
const NFG = NormalFormGames
# NOTE(review): `RPS` is not referenced anywhere in the code shown here.
const RPS = GameDynamics.Presets.RPS
const SI = SideInformation

# Fix the global RNG so the random inputs drawn later with `randn` are
# reproducible between runs (for a given Julia version).
Random.seed!(0)

# Only show log messages at `Warn` level or above.
global_logger(ConsoleLogger(Warn))

"""
    extract_data_from_polynomial(p)

Decompose the polynomial `p` into two position-aligned vectors: its
coefficients and the string form of each of its monomials.

Returns the tuple `(coeffs, terms)`. The monomials are turned into strings
because they are exported to MATLAB, which cannot store polynomial variables.
"""
function extract_data_from_polynomial(p)
  coefficient_values = coefficients(p)
  monomial_labels = [string(m) for m ∈ monomials(p)]
  return (coefficient_values, monomial_labels)
end

"""
    rand_distribution([rng::AbstractRNG,] n)

Draw a probability vector with `n` entries, uniformly distributed on the
probability simplex, and return only its first `n - 1` entries as a
`Vector{Float64}`.

The entries are the gaps between `0` and `n - 1` sorted uniform draws taken
from `rng` (default: the global RNG). The omitted last entry is
`1 - sum(result)`. Returning only `n - 1` free coordinates presumably matches
the reduced strategy variables used in this script (e.g. `x₁[1:2]` for a
3-action player) — confirm with callers.

Throws `ArgumentError` if `n < 1`.
"""
function rand_distribution(rng::AbstractRNG, n::Integer)
  n ≥ 1 || throw(ArgumentError("n must be at least 1, got $n"))
  cuts = sort!(rand(rng, n - 1))
  # Start the differences from the Float64 `0.0`. Starting from an Int `0`
  # would change the running value's type mid-loop, which is type-unstable.
  return diff(pushfirst!(cuts, 0.0))
end

rand_distribution(n) = rand_distribution(Random.default_rng(), n)

"""
    maxbasis_with_linear_control_inputs(x, w, max)

Build a regression basis over the state variables `x` and the inputs `w`.
Every entry has total degree at most `max` and is at most linear in the
inputs. The basis contains:

- every monomial in `x` of degree `0` through `max`, and
- every monomial in `x` of degree `0` through `max - 1`, multiplied by a
  single input `wᵢ`.

Returns the collection as a `monomial_vector`.
"""
function maxbasis_with_linear_control_inputs(x, w, max)
  basis = convert(Vector, monomials(x, 0:max))
  # Only monomials of degree < max can take one more (input) factor without
  # exceeding total degree `max`. Selecting them by degree does not depend on
  # the order in which `monomials` lists its terms. `filter` also returns a
  # fresh vector, so nothing aliases `basis` while it grows below.
  lower = filter(m -> degree(m) < max, basis)
  sizehint!(basis, length(basis) + length(lower) * length(w))
  for wᵢ ∈ w
    append!(basis, [m * wᵢ for m ∈ lower])
  end
  return monomial_vector(basis)
end

# Polynomial variables: two strategy variables per player (x₁, x₂) and four
# input variables w. The payoff matrices below are 3×3, so presumably each
# player's third strategy probability is implicit (1 minus the other two) —
# confirm against `NFG.marginal_payoffs`.
@polyvar x₁[1:2] x₂[1:2] w[1:4]
# Per-player grouping of the strategy variables; x[i] is used per player below.
x = [x₁, x₂]

# Row player's payoffs: a cyclic rock–paper–scissors-style matrix with 0.25 on
# the diagonal. The inputs perturb its first row (w[1], w[2]) and its first
# column (w[3], w[4]).
A₁ =[
   0.25 -1  1;
   1  0.25 -1;
  -1  1  0.25
]+ [
  0    w[1] w[2];
  w[3] 0    0;
  w[4] 0    0
]

# Column player's payoffs are the negation of A₁, so the game is zero-sum.
A₂ = - A₁
# Per-player marginal payoffs (v[i] is used for player i further down),
# computed by `NFG.marginal_payoffs`, which is defined outside this file.
v = NFG.marginal_payoffs([A₁, A₂], x)

# Dynamics built from the payoffs v and inputs w by
# `replicator_as_polynomial_system` (defined outside this file; replicator
# dynamics by its name).
f            = replicator_as_polynomial_system(x, v, vec(w))
# Initial state with 4 entries, presumably in the flattened order of
# reduce(vcat, x), i.e. [x₁; x₂] — confirm.
x₀_train     = [0.5;0.015;0.5;0.15]
# 11 equally spaced training times on [0, 1].
t_train_data = collect(LinRange(0.0, 1.0, 11))
# One input vector per training time: length(w) Gaussian draws scaled by 0.3
# and clamped to [-1, 1]. These draws come from the global RNG seeded at the
# top of the script, so statement order here affects the generated data.
w_train_data = Dict(t => clamp.(0.3*randn(length(w)), -1, 1) for t ∈ t_train_data)
# Sample states, inputs and state velocities at the training times.
x_train_data, w_inferred_data, ẋ_train_data = sample_with_velocities(
  f, 
  x₀_train, 
  t_train_data,
  w_train_data
)

# Dense trajectory over 101 time points from the same initial state and the
# same input dictionary. It is only exported to the .mat file at the end.
# NOTE(review): w_train_data only has keys at the 11 training times; how
# `sample` handles the intermediate times is decided inside `sample` — confirm.
t_all_data = collect(LinRange(0.0, 1.0, 101))
x_all_data, w_all_data= sample(
  f, 
  x₀_train, 
  t_all_data,
  w_train_data
)

# Candidate library shared by both fits: monomials in the flattened strategy
# variables up to degree 4, plus lower-degree monomials multiplied by each
# single input wᵢ (see `maxbasis_with_linear_control_inputs`).
basis = maxbasis_with_linear_control_inputs(reduce(vcat, x), vec(w), 4)

# Threshold passed to both fits below; by its name, coefficients smaller than
# this are presumably set to zero. It also sets how many decimals are printed
# for the fitted polynomials.
zero_threshold = 1e-3
# Baseline fit with SINDySLS (defined outside this file), limited to a single
# iteration via `max_iters=1`. The result is indexed as p[1]…p[4] when
# printed, presumably one polynomial per state component.
p = SINDySLS(
  x_train_data, 
  w_inferred_data, 
  ẋ_train_data, 
  basis; 
  zero_threshold=zero_threshold,
  max_iters=1
)

# Split each fitted SINDySLS component into coefficients and monomial strings
# for the MATLAB export at the end of the script. The containers stay untyped
# (`Any[]`, same as a bare `[]`) so the exported layout is unchanged.
all_sindy_coeffs = Any[]
all_sindy_terms = Any[]

for pᵢ ∈ p
  coeffs, terms = extract_data_from_polynomial(pᵢ)
  push!(all_sindy_coeffs, coeffs)
  push!(all_sindy_terms, terms)
end

# Print as many decimals as the zero threshold resolves (1e-3 → 3).
significant_digits = ceil(Int, -log10(zero_threshold))

# Label pᵢⱼ = player i, strategy coordinate j. This assumes p follows the
# flattened order [x₁[1], x₁[2], x₂[1], x₂[2]], so p[2] is p₁₂.
println("SINDySLS Results:")
println("p₁₁ ≈ $(repr(round(p[1]; digits=significant_digits)))")
println("p₁₂ ≈ $(repr(round(p[2]; digits=significant_digits)))")
println("p₂₁ ≈ $(repr(round(p[3]; digits=significant_digits)))")
println("p₂₂ ≈ $(repr(round(p[4]; digits=significant_digits)))")

# Domain for the side-information constraints: the product of the players'
# strategy simplices, per the helper's name (defined outside this file).
domain = Domains.product_of_simplices(x) 
# Mosek configured with interior-point primal/dual feasibility and
# relative-gap tolerances tightened to 1e-12.
SOCP_optimizer=optimizer_with_attributes(
  Mosek.Optimizer, 
  "MSK_DPAR_INTPNT_CO_TOL_PFEAS"   => 1e-12,
  "MSK_DPAR_INTPNT_CO_TOL_DFEAS"   => 1e-12,
  "MSK_DPAR_INTPNT_CO_TOL_REL_GAP" => 1e-12
)
# Refit with SIAR under side-information constraints, overwriting the global
# `p` from the SINDySLS fit. `separated` receives `sizes=length.(x)` (2 entries
# per player) and calls the do-block with a player index `i` and that player's
# share of the data. Presumably this splits the fit per player — confirm in
# `separated`.
# Per player i, two constraint builders are passed:
#   - SI.simplex_invariance on player i's variables x[i] over `domain`;
#   - SI.positive_correlation with player i's payoffs v[i] (maxdegree=7).
# Inside these lambdas, `p` is the lambda's own argument (presumably the
# polynomial being constrained), not the global `p`.
# NOTE(review): the meaning of η=0.0 is not visible here — confirm in SIAR.
p = separated(
  x_train_data, 
  w_inferred_data, 
  ẋ_train_data;
  sizes=length.(x)
) do i, x_dataᵢ, w_dataᵢ, ẋ_dataᵢ
  SIAR(
    x_dataᵢ, 
    w_dataᵢ,
    ẋ_dataᵢ, 
    basis;
    SOCP_optimizer=SOCP_optimizer,
    SI_constraints = [
      (model, p) -> SI.simplex_invariance(model, x[i], p; domain=domain),
      (model, p) -> SI.positive_correlation(model, p, v[i]; domain=domain, maxdegree=7)
    ],
    η=0.0,
    zero_threshold=zero_threshold,
    verbose=true
  )
end

# Split each fitted SIAR component into coefficients and monomial strings for
# the MATLAB export at the end of the script. The containers stay untyped
# (`Any[]`, same as a bare `[]`) so the exported layout is unchanged.
all_siar_coeffs = Any[]
all_siar_terms = Any[]

for pᵢ ∈ p
  coeffs, terms = extract_data_from_polynomial(pᵢ)
  push!(all_siar_coeffs, coeffs)
  push!(all_siar_terms, terms)
end

# Label pᵢⱼ = player i, strategy coordinate j. This assumes p follows the
# flattened order [x₁[1], x₁[2], x₂[1], x₂[2]], so p[2] is p₁₂.
# `significant_digits` is the decimal count computed after the SINDySLS fit.
println("SIAR Results:")
println("p₁₁ ≈ $(repr(round(p[1]; digits=significant_digits)))")
println("p₁₂ ≈ $(repr(round(p[2]; digits=significant_digits)))")
println("p₂₁ ≈ $(repr(round(p[3]; digits=significant_digits)))")
println("p₂₂ ≈ $(repr(round(p[4]; digits=significant_digits)))")

# Save the training data, the dense reference trajectory and both fitted
# models (coefficients + monomial strings) to a MATLAB .mat file.
# Replace the placeholder with the directory to write into.
folder_path = "YOUR_FOLDER_PATH"
file_name = "data_JULIA_rps.mat"
full_path = joinpath(folder_path, file_name)

# The do-block form of `matopen` closes the file even if one of the writes
# throws (e.g. on a value type MAT cannot serialize).
matopen(full_path, "w") do matfile
  write(matfile, "x_train_data", x_train_data)
  write(matfile, "w_inferred_data", w_inferred_data)
  write(matfile, "xdot_train_data", ẋ_train_data)
  write(matfile, "x_all_data", x_all_data)
  write(matfile, "w_all_data", w_all_data)
  write(matfile, "all_sindy_coeffs", all_sindy_coeffs)
  write(matfile, "all_sindy_terms", all_sindy_terms)
  write(matfile, "all_siar_coeffs", all_siar_coeffs)
  write(matfile, "all_siar_terms", all_siar_terms)
end