using LinearAlgebra
using Random
using DynamicPolynomials
using JuMP
using GDIAC
using Logging
using MAT

import MosekTools: Mosek

# Short aliases for modules used below (presumably exported by GDIAC — confirm).
# NOTE(review): SH is not referenced anywhere else in this script; confirm before removing.
const NFG = NormalFormGames
const SI = SideInformation
const SH = GameDynamics.Presets.SH

# Seed the global RNG so the random draws below (initial states, parameter values)
# are reproducible from run to run.
Random.seed!(0)

# Only emit log records at level Warn or above; Info/Debug messages are hidden.
global_logger(ConsoleLogger(Warn))

"""
    extract_data_from_polynomial(p)

Return a tuple `(coeffs, terms)` for the polynomial `p`: `coeffs` is `coefficients(p)`
and `terms` is each monomial of `p` rendered as a `String`. The string form is used so
MATLAB can read the terms back after the MAT-file export. The two vectors are presumably
aligned element-by-element, per the MultivariatePolynomials API — confirm.
"""
function extract_data_from_polynomial(p)
  return coefficients(p), string.(monomials(p))
end

"""
    rand_distribution([rng::AbstractRNG,] n::Integer) -> Vector{Float64}

Draw a point uniformly at random from the probability simplex in ℝⁿ and return only its
first `n - 1` coordinates. The omitted last coordinate equals `1 - sum(result)`, which
matches the reduced, one-coordinate-dropped state representation used in this script.

The point is built as the gaps between `n - 1` sorted `Uniform(0, 1)` draws (the
standard spacing construction of a flat Dirichlet sample).

# Arguments
- `rng`: random number generator; defaults to `Random.default_rng()`, i.e. the same
  stream `rand(n - 1)` would use, so seeded results are unchanged.
- `n`: number of simplex components; must be at least 1 (`n == 1` returns `Float64[]`).

# Throws
- `ArgumentError` if `n < 1`.
"""
function rand_distribution(rng::AbstractRNG, n::Integer)
  n ≥ 1 || throw(ArgumentError("n must be at least 1, got $n"))
  cuts = sort!(rand(rng, n - 1))
  # Gaps between consecutive cut points, starting from 0.0. Using a Float64 origin
  # keeps the computation type-stable (the old `zᵢ = 0` started as an Int).
  return diff([0.0; cuts])
end

rand_distribution(n::Integer) = rand_distribution(Random.default_rng(), n)

"""
    maxbasis_with_linear_control_inputs(x, w, max)

Build a monomial basis for sparse regression. It contains:
- every monomial in the variables `x` of total degree `0` through `max`, and
- every monomial in `x` of total degree strictly below `max`, multiplied by each
  variable in `w`.

The basis is therefore at most linear in `w`, and every element has total degree at most
`max`. The result is passed through `monomial_vector` (which presumably sorts and
deduplicates — confirm against the MultivariatePolynomials docs).
"""
function maxbasis_with_linear_control_inputs(x, w, max)
  basis = convert(Vector, monomials(x, 0:max))
  # Select the degree-< max monomials by their degree rather than by position. The
  # previous `basis[begin:end - n]` slice silently assumed that `monomials` lists the
  # top-degree monomials last, and it aliased a view into a vector grown below.
  lower = filter(m -> degree(m) < max, basis)
  sizehint!(basis, length(basis) + length(lower) * length(w))
  for wᵢ ∈ w
    append!(basis, lower * wᵢ)
  end
  return monomial_vector(basis)
end

# Polynomial variables: one state variable per player (x₁, x₂) and three payoff
# parameters w[1], w[2], w[3].
@polyvar x₁ x₂ w[1:3]
x = [[x₁], [x₂]]  # state variables grouped per player

# 2×2 payoff matrix: a fixed base game plus a parameter-dependent perturbation.
# Player 2's matrix is the transpose of player 1's, so the game is symmetric.
A₁ = [4 1; 3 3] + [w[1] w[2]; w[3] 0]
A₂ = A₁'

v = NFG.marginal_payoffs([A₁, A₂], x)

# Replicator dynamics as a polynomial system in the states and the parameters;
# this is the vector field used to generate the data below.
f = replicator_as_polynomial_system(x, v, w)

"""
    g(u)

Alternative vector field (the "log-barrier dynamics" option in the sampling calls).

`u[1:2]` are auxiliary states `y` and `u[3:5]` are the parameters. The states are mapped
through `x = y² / (y² + (1 - y)²)`, `f` is evaluated at `[x; u[3:5]]`, and the result is
scaled by the same denominator `y² + (1 - y)²`. Relies on the global `f`.
"""
function g(u)
  y = u[1:2]
  denom = @. y^2 + (1 - y)^2
  ẋ = f([y .^ 2 ./ denom; u[3:5]])
  return ẋ .* denom
end

# NOTE(review): this random initial state is overwritten two lines below. The call is
# kept only because it advances the seeded global RNG; removing it would change the
# `randn` draws in `w_train_data` and hence every downstream result.
x₀_train     = [rand_distribution(2); rand_distribution(2)]
t_train_data = collect(LinRange(0.0, 0.3, 4))
x₀_train     = [0.4; 0.3]  # fixed initial state actually used for simulation

## Parameter values at each training time: half-normal draws |0.3 · N(0, 1)| per
## component, clamped to [0, 2] (the lower bound is already implied by `abs`).
w_train_data = Dict(t => clamp.(abs.(0.3*randn(length(w))), 0, 2) for t ∈ t_train_data)


# Training trajectory at the 4 sample times. Judging by the destructured names, this
# also yields the parameters seen by the sampler and the state velocities.
@show x_train_data, w_inferred_data, ẋ_train_data = sample_with_velocities(
  f, 
  # g, #uncomment this and comment the previous line for log-barrier dynamics
  x₀_train, 
  t_train_data,
  w_train_data
)

# Denser sampling (31 points) over the same time horizon and initial state.
t_all_data = collect(LinRange(0.0, 0.3, 31))
x_all_data, w_all_data = sample(
  f,
  # g, #uncomment this and comment the previous line for log-barrier dynamics
  x₀_train, 
  t_all_data,
  w_train_data
)

# Regression library: monomials in all state variables up to degree 4, augmented with
# lower-degree monomials times each parameter wᵢ.
basis = maxbasis_with_linear_control_inputs(reduce(vcat, x), w, 4)

# Baseline fit with SINDySLS (presumably sequentially thresholded least squares —
# confirm); coefficients below `zero_threshold` are treated as zero.
zero_threshold = 1e-3
p = SINDySLS(
  x_train_data,
  w_inferred_data,
  ẋ_train_data,
  basis;
  zero_threshold = zero_threshold,
  max_iters = 1
)

# Coefficients and monomial strings of each recovered polynomial, for the MAT export.
all_sindy_coeffs = []
all_sindy_terms = []
for pᵢ in p
  coeffs, terms = extract_data_from_polynomial(pᵢ)
  push!(all_sindy_coeffs, coeffs)
  push!(all_sindy_terms, terms)
end

# Round printed coefficients to the number of decimals implied by `zero_threshold`.
significant_digits = ceil(Int, -log10(zero_threshold))
println("SINDySLS Results:")
println("p₁₁ ≈ $(repr(round(p[1]; digits=significant_digits)))")
println("p₂₂ ≈ $(repr(round(p[2]; digits=significant_digits)))")

# State domain built from the per-player variable groups in `x`.
domain = Domains.product_of_simplices(x)

# Mosek with tight (1e-12) interior-point feasibility and gap tolerances.
SOCP_optimizer = optimizer_with_attributes(
  Mosek.Optimizer,
  "MSK_DPAR_INTPNT_CO_TOL_PFEAS"   => 1e-12,
  "MSK_DPAR_INTPNT_CO_TOL_DFEAS"   => 1e-12,
  "MSK_DPAR_INTPNT_CO_TOL_REL_GAP" => 1e-12
)

# Fit each player's block of the dynamics separately with SIAR. Two side-information
# constraints are imposed on each candidate polynomial `q`: invariance of player i's
# simplex, and positive correlation with player i's marginal payoff `v[i]`.
p = separated(
  x_train_data,
  w_inferred_data,
  ẋ_train_data;
  sizes = length.(x)
) do i, x_dataᵢ, w_dataᵢ, ẋ_dataᵢ
  constraints = [
    (model, q) -> SI.simplex_invariance(model, x[i], q; domain = domain),
    (model, q) -> SI.positive_correlation(model, q, v[i]; domain = domain, maxdegree = 7)
  ]
  SIAR(
    x_dataᵢ,
    w_dataᵢ,
    ẋ_dataᵢ,
    basis;
    SOCP_optimizer = SOCP_optimizer,
    SI_constraints = constraints,
    η = 0.0,
    zero_threshold = zero_threshold,
    verbose = true
  )
end

# Coefficients and monomial strings of each SIAR polynomial, for the MAT export.
all_siar_coeffs = []
all_siar_terms = []
for pᵢ in p
  coeffs, terms = extract_data_from_polynomial(pᵢ)
  push!(all_siar_coeffs, coeffs)
  push!(all_siar_terms, terms)
end

println("SIAR Results:")
println("p₁₁ ≈ $(repr(round(p[1]; digits=significant_digits)))")
println("p₂₂ ≈ $(repr(round(p[2]; digits=significant_digits)))")

# Save the sampled data and both sets of recovered polynomials to a .mat file.
# Set `folder_path` to an existing directory (or "" for the current directory) first.
folder_path = "YOUR_FOLDER_PATH"
file_name = "data_JULIA_sh.mat"
full_path = joinpath(folder_path, file_name)
isempty(folder_path) || isdir(folder_path) || throw(ArgumentError(
  "output folder \"$folder_path\" does not exist; set `folder_path` to an existing directory"
))
# The do-block form of `matopen` closes the file even if one of the writes throws,
# which the previous open/write/close sequence did not guarantee.
matopen(full_path, "w") do matfile
  write(matfile, "x_train_data", x_train_data)
  write(matfile, "w_inferred_data", w_inferred_data)
  write(matfile, "xdot_train_data", ẋ_train_data)
  write(matfile, "x_all_data", x_all_data)
  write(matfile, "w_all_data", w_all_data)
  write(matfile, "all_sindy_coeffs", all_sindy_coeffs)
  write(matfile, "all_sindy_terms", all_sindy_terms)
  write(matfile, "all_siar_coeffs", all_siar_coeffs)
  write(matfile, "all_siar_terms", all_siar_terms)
end