using OrdinaryDiffEq
using Plots
using Random
using Lux, Optimization, OptimizationOptimJL, ComponentArrays
using SciMLSensitivity
using LinearAlgebra
using OptimizationOptimisers
using Statistics
using Symbolics
using Convex, SCS
using Printf
using NonNegLeastSquares

# Seed the global default RNG once. `rng` is a handle to that same generator and is
# passed explicitly to the noise draws and to the Lux parameter initialisation below.
Random.seed!(92)
rng = Random.default_rng()

"""
    climate_system!(du, u, p, t)

In-place right-hand side of the reference climate/carbon model.

State `u = (T, O, C)`. Parameters are unpacked from `p` in the order
`(α, C0, λ, κ, CT, CO, γ, β)`. Forcing is `α*log(C/C0)`, and the emission rate grows
linearly in time as `8 + 0.1t`. The carbon sink consists of a `β*T*C` term plus a
linear `γ*C` decay.
"""
function climate_system!(du, u, p, t)
    T, O, C = u
    α, C0_param, λ, κ, CT, CO, γ, β_param = p

    forcing = α * log(C / C0_param)
    emissions = 8.0 + 0.1 * t
    exchange = κ * (T - O)   # surface ↔ deep-ocean coupling term

    # Written exactly as `(1/X) * ...` to keep the original floating-point rounding.
    du[1] = (1 / CT) * (forcing - λ * T - exchange)
    du[2] = (1 / CO) * κ * (T - O)
    du[3] = emissions - β_param * T * C - γ * C
end

# Ground-truth parameters. `p_original` follows the unpacking order used by
# `climate_system!`: (α, C0, λ, κ, CT, CO, γ, β).
α_orig, C0_orig, λ_orig, κ_orig = 5.35, 280.0, 1.0, 0.69
CT_orig, CO_orig = 8.0, 80.0
γ_orig, β_orig = 0.01, 0.001

p_original = [α_orig, C0_orig, λ_orig, κ_orig, CT_orig, CO_orig, γ_orig, β_orig]

# Initial state (T, O, C): zero temperature anomalies, CO₂ at the reference level.
T0_init, O0_init, C0_val_init = 0.0, 0.0, 280.0
u0 = [T0_init, O0_init, C0_val_init]

# Training window and full (training + extrapolation) horizon.
tspan_train = (0.0, 15.0)
tspan_total = (0.0, 50.0)

# Reference trajectories of the known model: tight tolerances, sampled every 1.0 time unit.
prob_train_true = ODEProblem(climate_system!, u0, tspan_train, p_original)
prob_total_true = ODEProblem(climate_system!, u0, tspan_total, p_original)
solution_train = solve(prob_train_true, Tsit5(); abstol=1e-12, reltol=1e-12, saveat=1.0)
solution_total = solve(prob_total_true, Tsit5(); abstol=1e-12, reltol=1e-12, saveat=1.0)

true_data_total = Array(solution_total)
time_points_total = solution_total.t
tsdata_train = Array(solution_train)

# Standard deviation of each state row (T, O, C) over the clean training trajectory.
T_train_std, O_train_std, C_train_std = (std(tsdata_train[row, :]) for row in 1:3)

# Observations: add zero-mean Gaussian noise to each state row, scaled to
# `noise_percentage` of that state's std over the clean training trajectory.
# The scale is kept in Float64 (it was previously rounded through Float32 for no reason).
noise_percentage = 0.01
noisy_data = copy(tsdata_train)
for (row, σ) in enumerate((T_train_std, O_train_std, C_train_std))
    # Draw order (T, then O, then C) and draw count match the original RNG usage.
    noisy_data[row, :] .+= (noise_percentage * σ) .* randn(rng, size(noisy_data, 2))
end

# Normalisation statistics for the NN inputs. They are computed from the noisy
# observations (what the UDE is actually fitted to in `loss_ude`) rather than from the
# noise-free ground truth, so no information beyond the observed data leaks into training.
T_train_data = noisy_data[1, :]
O_train_data = noisy_data[2, :]
C_train_data = noisy_data[3, :]

T_mean, T_std = mean(T_train_data), std(T_train_data)
O_mean, O_std = mean(O_train_data), std(O_train_data)
C_mean, C_std = mean(C_train_data), std(C_train_data)

# Floor the stds so z-scoring never divides by (near) zero.
T_std = max(T_std, 1e-8)
O_std = max(O_std, 1e-8)
C_std = max(C_std, 1e-8)

let true_term = p_original[8] .* tsdata_train[1, :] .* tsdata_train[3, :]
    lo, hi = extrema(true_term)
    println("Range of β*T*C in training data: [", lo, ", ", hi, "]")
end

# MLP standing in for the β*T*C term. It has 2 inputs (z-scored T and C), four softplus
# hidden layers and one linear output.
nn_beta_tc_model = Lux.Chain(
    Lux.Dense(2 => 64, softplus),
    Lux.Dense(64 => 64, softplus),
    Lux.Dense(64 => 64, softplus),
    Lux.Dense(64 => 128, softplus),
    Lux.Dense(128 => 1),            # linear (identity) output layer
)

# Initialise from the seeded `rng`. The parameters are wrapped in a ComponentArray so the
# optimiser can treat them as a single array; the Lux state is kept as returned.
p_nn_beta_tc_initial, st_nn_beta_tc_model = let
    ps, st = Lux.setup(rng, nn_beta_tc_model)
    (ComponentArray(ps), st)
end

"""
    ude_climate_system!(du, u, p_nn, t)

In-place right-hand side of the universal differential equation.

Same physics as `climate_system!`, except that the `β*T*C` sink term is replaced by the
scalar output of `nn_beta_tc_model`. The network is evaluated on `(T, C)` z-scored with
the global `T_mean`/`T_std` and `C_mean`/`C_std`. Physical parameters come from the
first seven entries of the global `p_original`; only the NN weights `p_nn` are trainable.

NOTE(review): reads several non-`const` globals, which makes this RHS type-unstable.
"""
function ude_climate_system!(du, u, p_nn, t)
    T, O, C = u

    # Known physics; the 8th entry (β) of `p_original` is intentionally not used.
    α_val, C0_param_val, λ_val, κ_val, CT_val, CO_val, γ_val = p_original

    F_C = α_val * log(C / C0_param_val)
    E_t = 8.0 + 0.1 * t

    nn_input = [(T - T_mean) / T_std, (C - C_mean) / C_std]
    nn_output, _ = nn_beta_tc_model(nn_input, p_nn, st_nn_beta_tc_model)
    learned_beta_T_C_term = nn_output[1]

    du[1] = (1/CT_val) * (F_C - λ_val*T - κ_val*(T - O))
    du[2] = (1/CO_val) * κ_val*(T - O)
    du[3] = E_t - learned_beta_T_C_term - γ_val*C
end

# In-place UDE problem over the training window, seeded with the untrained NN weights.
prob_ude_train = ODEProblem{true}(ude_climate_system!, u0, tspan_train, p_nn_beta_tc_initial)

"""
    predict_ude(θ)

Solve the training-window UDE with NN parameters `θ` and return the states sampled at
`solution_train.t` as a matrix.

Gradients use `InterpolatingAdjoint` with a compiled ReverseDiff VJP.

If the solve does not return `ReturnCode.Success`, or the result does not match
`noisy_data` in type and size, the function returns a same-sized matrix filled with
`Inf`. The loss then becomes infinite instead of erroring on a shape mismatch.
"""
function predict_ude(θ)
    sol = solve(prob_ude_train, Tsit5(); p=θ, saveat=solution_train.t,
                abstol=1e-6, reltol=1e-6,
                sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)))

    # Materialise the solution once (previously done three times).
    pred = Array(sol)
    if sol.retcode != ReturnCode.Success || typeof(pred) != typeof(noisy_data) ||
       size(pred) != size(noisy_data)
        # Use the data's element type (was Float32) so both branches return the same type.
        return fill(convert(eltype(noisy_data), Inf), size(noisy_data))
    end
    return pred
end

"""
    loss_ude(θ)

Sum of squared differences between the noisy training observations and the UDE
prediction for NN parameters `θ`.
"""
loss_ude(θ) = sum(abs2, noisy_data .- predict_ude(θ))

# Baseline loss with the untrained NN; used in the loss-reduction summary.
initial_loss = loss_ude(p_nn_beta_tc_initial)

# Training progress, mutated by `callback_ude`.
losses_ude = Float64[]   # concretely typed (was `[]`, i.e. Vector{Any})
iter_ude = 0
min_loss_ude = Inf

"""
    callback_ude(θ, l)

Optimization callback with the following behaviour:

- Increments the global `iter_ude`.
- Appends the loss `l` to `losses_ude`.
- Tracks the running minimum in `min_loss_ude`.
- Prints progress every 50 calls.

It always returns `false`, so training is never stopped early. The first argument is not
used. NOTE(review): depending on the Optimization.jl version, the first argument may be
an optimisation-state object rather than the raw parameters.
"""
function callback_ude(θ, l)
    global iter_ude, losses_ude, min_loss_ude

    iter_ude += 1
    push!(losses_ude, l)
    min_loss_ude = min(min_loss_ude, l)

    iszero(iter_ude % 50) && println("UDE Iteration: $iter_ude, Loss: $l")
    return false
end

# Outer gradients of the loss w.r.t. the NN parameters are taken with Zygote; the ODE
# adjoint itself is configured inside `predict_ude` via `sensealg`.
adtype_ude = Optimization.AutoZygote()
# The optimisation hyper-parameter argument is ignored; the loss depends only on `x`.
optf_ude = Optimization.OptimizationFunction((x, p_unused) -> loss_ude(x), adtype_ude)

# Phase 1: Adam, starting from the randomly initialised NN weights.
# NOTE(review): a learning rate of 1e-6 is very small for 10_000 steps — confirm intended.
iters_adam = 10000
optprob_ude_adam = Optimization.OptimizationProblem(optf_ude, p_nn_beta_tc_initial)
println("\nStarting UDE Adam training ($iters_adam iterations)...")
res_ude1 = Optimization.solve(optprob_ude_adam, OptimizationOptimisers.Adam(0.000001),
                            callback=callback_ude, maxiters=iters_adam)
println("UDE Adam training finished.")

# Phase 2: AdamW, warm-started from the Adam result (`remake` swaps in the new `u0`).
# NOTE(review): the positional arguments are presumably (learning rate, (β1, β2) moment
# decays, weight decay) — confirm against the Optimisers.jl AdamW signature.
iters_adamw = 10000
optprob_ude_adamw = remake(optprob_ude_adam, u0=res_ude1.u)
println("\nStarting UDE AdamW training ($iters_adamw iterations)...")
res_ude2 = Optimization.solve(optprob_ude_adamw, OptimizationOptimisers.AdamW(0.000001, (0.9999, 0.99999), 1e-7), 
                            callback=callback_ude, maxiters=iters_adamw)
println("UDE AdamW training finished.")

# Parameters after both optimisation phases, and a summary of the training run.
p_nn_trained = res_ude2.u

final_loss = loss_ude(p_nn_trained)
loss_reduction_percentage = ((initial_loss - final_loss) / initial_loss) * 100

println("\n--- UDE Training Summary ---")
for (label, value) in (("Total UDE training iterations: ", iter_ude),
                       ("Initial UDE Training Loss: ", initial_loss),
                       ("Final UDE Training Loss: ", final_loss),
                       ("Minimum UDE Training Loss obtained: ", min_loss_ude))
    println(label, value)
end
println("Loss reduction: ", loss_reduction_percentage, "%")
println("--------------------------")

# Log–log plot of every loss value recorded by `callback_ude`, coloured by optimiser phase.
total_iterations_ude = length(losses_ude)
p_loss_curve_ude = plot(xlabel="Iterations (log scale)", ylabel="Loss (log scale)",
                        legend=:topright, yaxis=:log, xaxis=:log, size=(598, 369))

if total_iterations_ude > 0
    # NOTE(review): assumes the callback fired exactly `iters_adam` times during the Adam
    # phase; if the optimiser invokes it an extra time per solve, the boundary is off by one.
    first_phase_end = min(iters_adam, total_iterations_ude)
    plot!(p_loss_curve_ude, 1:first_phase_end, losses_ude[1:first_phase_end],
          color=:blue, lw=2, label="Adam Optimizer")
    
    # Remaining recorded losses belong to the AdamW phase.
    if total_iterations_ude > first_phase_end
        plot!(p_loss_curve_ude, (first_phase_end+1):total_iterations_ude, 
              losses_ude[(first_phase_end+1):end],
              color=:red, lw=2, label="AdamW Optimizer")
    end
end
display(p_loss_curve_ude)
savefig(p_loss_curve_ude, "random_seed_92_1p_noise_climate_ude_training_loss_curve.png")

# Roll the trained UDE forward over the full horizon (training + extrapolation).
prob_ude_extrapolate = ODEProblem{true}(ude_climate_system!, u0, tspan_total, p_nn_trained)
sol_ude_extrap = solve(prob_ude_extrapolate, Tsit5(); saveat=solution_total.t,
                       abstol=1e-12, reltol=1e-12)
ude_pred_total = Array(sol_ude_extrap)

# Ground-truth β*T*C along the true trajectory.
T_values_from_true_total = true_data_total[1, :]
C_values_from_true_total = true_data_total[3, :]
true_beta_T_C_term_vals = p_original[8] .* T_values_from_true_total .* C_values_from_true_total

# NN estimate of the same term, evaluated on the *true* states (z-scored as in the UDE).
learned_beta_T_C_term_vals = zeros(length(time_points_total))
map!(learned_beta_T_C_term_vals, T_values_from_true_total, C_values_from_true_total) do T_i, C_i
    z = [(T_i - T_mean) / T_std, (C_i - C_mean) / C_std]
    first(first(nn_beta_tc_model(z, p_nn_trained, st_nn_beta_tc_model)))
end

println("\n--- Starting Symbolic Extraction (SINDy) ---")

# Regress the NN-learned term on a small polynomial library in the *unnormalised* T and C
# taken from the clean training trajectory.
T_train = tsdata_train[1, :]  # UNNORMALIZED values for SINDy
C_train = tsdata_train[3, :]  # UNNORMALIZED values for SINDy

train_indices = time_points_total .<= tspan_train[2]
learned_beta_T_C_term_train = learned_beta_T_C_term_vals[train_indices]

@variables T_var C_var   # NOTE(review): declared but not used anywhere below.

basis_funcs = [
    u -> u[1] * u[2],           # T*C
    u -> u[1]^2 * u[2],         # T²*C
    u -> u[1] * u[2]^2,         # T*C²
    u -> u[1]^2 * u[2]^2        # T²*C²
]

# Correct symbols - these are UNNORMALIZED
basis_symbols = [
    "T*C",
    "T²*C",
    "T*C²",
    "T²*C²"
]

# Regression matrix: one row per training time point, one column per library function.
# `eachindex` with both arrays throws if their lengths ever disagree.
inputs = [[T_train[i], C_train[i]] for i in eachindex(T_train, C_train)]
Φ = [Float64(f(u)) for u in inputs, f in basis_funcs]
y = Float64.(learned_beta_T_C_term_train)

# The library columns differ by orders of magnitude (T*C vs T²*C²), which makes the NNLS
# problem badly conditioned. Solve on unit-norm columns and undo the scaling afterwards.
# A positive diagonal rescaling preserves non-negativity and the exact minimiser.
col_scale = [(s = norm(c); iszero(s) ? one(s) : s) for c in eachcol(Φ)]
β = nonneg_lsq(Φ ./ col_scale', y; alg=:fnnls) ./ col_scale

println("\n🔍 Verification: All coefficients >= 0? ", all(β .>= 0))
println("Coefficient values: ", β)

threshold = 0.0000001  # Lower threshold to ensure dominant term is captured

"""
    format_expr(β, symbols, threshold)

Render the linear combination `Σ β[i] * symbols[i]` as a string.

Only terms with `abs(coefficient) > threshold` are kept. Terms are joined with `" + "`,
and the result is `""` when no term survives.

Coefficients use 6 significant digits (`%.6g`). This keeps small but retained terms
(e.g. `2.5e-07`) readable instead of rendering them as `0.000000`.
"""
function format_expr(β, symbols, threshold)
    terms = String[]
    for (coeff, term) in zip(β, symbols)
        abs(coeff) > threshold || continue
        push!(terms, @sprintf("%.6g", coeff) * "*" * term)
    end
    return join(terms, " + ")
end

expr_str = format_expr(β, basis_symbols, threshold)

println("\n📘 Learned symbolic expression for β*T*C term (UNNORMALIZED variables):")
println(isempty(expr_str) ? "β*T*C ≈ 0 (all coefficients below threshold)" : "β*T*C ≈ " * expr_str)

# Every library coefficient, regardless of the threshold.
println("\n📊 Detailed Coefficient Analysis:")
for (coeff, term) in zip(β, basis_symbols)
    println("  $term: ", @sprintf("%.8f", coeff))
end

# Goodness of fit of the library expression against the NN term on the training window.
symbolic_predictions = Φ * β
sindy_predictions_train = vec(symbolic_predictions)

sindy_residuals = learned_beta_T_C_term_train .- sindy_predictions_train
mse = mean(sindy_residuals .^ 2)
rmse = sqrt(mse)
r2 = 1 - sum(sindy_residuals .^ 2) / sum((learned_beta_T_C_term_train .- mean(learned_beta_T_C_term_train)) .^ 2)

println("\nSymbolic Regression Performance:")
println("RMSE: ", rmse)
println("R²: ", r2)

# Evaluate the fitted library expression on the true states over the full horizon.
T_extrap = true_data_total[1, :]
C_extrap = true_data_total[3, :]

inputs_extrap = [[T_extrap[i], C_extrap[i]] for i in 1:length(T_extrap)]
Φ_extrap = [Float64(f(u)) for u in inputs_extrap, f in basis_funcs]
sindy_predictions_total = vec(Φ_extrap * β)

# Time points strictly after the training window.
extrap_indices = time_points_total .> tspan_train[2]

# Plot 1: training window — NN term vs. SINDy expression vs. ground truth.
p_symbolic_train = plot(title="", 
                       xlabel="Time (years)", 
                       ylabel="β*T*C term",
                       legend=:topleft)
plot!(p_symbolic_train, solution_train.t, learned_beta_T_C_term_train, lw=2, label="NN")
plot!(p_symbolic_train, solution_train.t, sindy_predictions_train, lw=2, linestyle=:dash, label="SINDy Expression")
plot!(p_symbolic_train, solution_train.t, true_beta_T_C_term_vals[train_indices], lw=2, label="True β*T*C")
display(p_symbolic_train)
savefig(p_symbolic_train, "random_seed_92_1p_climate_symbolic_vs_nn_train.png")

# Plot 2: extrapolation window only.
# NOTE(review): the series labelled "UDE" here and below is the NN term evaluated on the
# *true* trajectory (`learned_beta_T_C_term_vals`), not on the UDE's own rollout.
p_symbolic_extrap = plot(title="", 
                        xlabel="Time (years)", 
                        ylabel="β*T*C term",
                        legend=:topleft)
plot!(p_symbolic_extrap, time_points_total[extrap_indices], sindy_predictions_total[extrap_indices], lw=2, linestyle=:dash, label="SINDy Expression")
plot!(p_symbolic_extrap, time_points_total[extrap_indices], true_beta_T_C_term_vals[extrap_indices], lw=2, label="True β*T*C")
plot!(p_symbolic_extrap, time_points_total[extrap_indices], learned_beta_T_C_term_vals[extrap_indices], lw=2, linestyle=:dot, label="UDE")
display(p_symbolic_extrap)
savefig(p_symbolic_extrap, "random_seed_92_1p_climate_symbolic_extrap.png")

# Plot 3: full horizon, with a marker at the end of the training window.
p_symbolic_full = plot(title="", 
                      xlabel="Time (years)", 
                      ylabel="β*T*C term",
                      legend=:topleft)
plot!(p_symbolic_full, time_points_total, sindy_predictions_total, lw=2, linestyle=:dash, label="SINDy Expression")
plot!(p_symbolic_full, time_points_total, true_beta_T_C_term_vals, lw=2, label="True β*T*C")
plot!(p_symbolic_full, time_points_total, learned_beta_T_C_term_vals, lw=2, linestyle=:dot, label="UDE")
vline!(p_symbolic_full, [tspan_train[2]], color=:black, lw=1.5, linestyle=:dash, label="Training End")
display(p_symbolic_full)
savefig(p_symbolic_full, "random_seed_92_1p_climate_symbolic_full.png")

# Compare the coefficient of the T*C library term with the true β.
tc_term_idx = findfirst(==("T*C"), basis_symbols)
if isnothing(tc_term_idx)
    println("\n⚠️ Warning: T*C term not found or below threshold")
else
    β_learned = β[tc_term_idx]
    β_true = p_original[8]
    abs_err_tc = abs(β_learned - β_true)
    rel_err_tc = abs_err_tc / β_true
    rule = "="^60

    println("\n" * rule)
    println("📈 DOMINANT T*C TERM ANALYSIS")
    println(rule)
    println("Learned coefficient β (from T*C term): ", @sprintf("%.8f", β_learned))
    println("True coefficient β:                     ", @sprintf("%.8f", β_true))
    println("\nAbsolute Error: ", @sprintf("%.8f", abs_err_tc))
    println("Relative Error: ", @sprintf("%.2f%%", rel_err_tc * 100))

    verdict = rel_err_tc < 0.1  ? "✅ Excellent recovery (< 10% error)" :
              rel_err_tc < 0.25 ? "⚠️  Moderate recovery (10-25% error)" :
                                  "❌ Poor recovery (> 25% error)"
    println(verdict)
    println(rule)
end

"""
    _term_error_stats(true_vals, learned_vals; zero_tol=1e-12)

Return `(abs_errors, mean_abs, max_abs, mean_rel)` comparing `learned_vals` to `true_vals`.

The mean relative error is averaged only over points where `abs(true) > zero_tol`. The
true β*T*C term is exactly zero at t = 0 because T starts at 0, and dividing by it would
turn the mean relative error into Inf/NaN. The result is NaN if every true value is (near)
zero.
"""
function _term_error_stats(true_vals, learned_vals; zero_tol=1e-12)
    abs_errors = abs.(true_vals .- learned_vals)
    nonzero = abs.(true_vals) .> zero_tol
    mean_rel = mean(abs_errors[nonzero] ./ abs.(true_vals[nonzero]))
    return abs_errors, mean(abs_errors), maximum(abs_errors), mean_rel
end

# Analyze prediction errors for β*T*C term
println("\n" * "="^60)
println("📉 β*T*C TERM PREDICTION ERRORS")
println("="^60)

# Training period errors
train_error_abs, train_error_mean, train_error_max, train_error_rel = _term_error_stats(
    true_beta_T_C_term_vals[train_indices], learned_beta_T_C_term_vals[train_indices])

println("\nTraining Period (0-$(tspan_train[2]) years):")
println("  Mean Absolute Error:    ", @sprintf("%.6f", train_error_mean))
println("  Max Absolute Error:     ", @sprintf("%.6f", train_error_max))
println("  Mean Relative Error:    ", @sprintf("%.2f%%", train_error_rel * 100))

# Extrapolation period errors
extrap_error_abs, extrap_error_mean, extrap_error_max, extrap_error_rel = _term_error_stats(
    true_beta_T_C_term_vals[.!train_indices], learned_beta_T_C_term_vals[.!train_indices])

println("\nExtrapolation Period ($(tspan_train[2])-$(tspan_total[2]) years):")
println("  Mean Absolute Error:    ", @sprintf("%.6f", extrap_error_mean))
println("  Max Absolute Error:     ", @sprintf("%.6f", extrap_error_max))
println("  Mean Relative Error:    ", @sprintf("%.2f%%", extrap_error_rel * 100))

println("="^60)

# Pointwise absolute error of the learned term over the whole horizon (plotted later).
error_beta_T_C_term = abs.(true_beta_T_C_term_vals .- learned_beta_T_C_term_vals)
avg_error_term = mean(error_beta_T_C_term)

# End-of-horizon states: ground truth vs. UDE rollout.
final_T_true, final_O_true, final_C_true = true_data_total[:, end]
final_T_pred, final_O_pred, final_C_pred = ude_pred_total[:, end]

abs_error_T, abs_error_O, abs_error_C =
    abs.((final_T_true, final_O_true, final_C_true) .- (final_T_pred, final_O_pred, final_C_pred))

pct_error_T, pct_error_O, pct_error_C = ((abs_error_T / abs(final_T_true)) * 100,
                                         (abs_error_O / abs(final_O_true)) * 100,
                                         (abs_error_C / abs(final_C_true)) * 100)

println("\n--- Final Values Comparison ---")
for (header, unit, v_true, v_pred, e_abs, e_pct) in (
        ("Surface Temperature (T):", "°C", final_T_true, final_T_pred, abs_error_T, pct_error_T),
        ("\nOcean Temperature (O):", "°C", final_O_true, final_O_pred, abs_error_O, pct_error_O),
        ("\nCO₂ Concentration (C):", "ppm", final_C_true, final_C_pred, abs_error_C, pct_error_C))
    println(header)
    println("  True: ", v_true, " ", unit)
    println("  Predicted: ", v_pred, " ", unit)
    println("  Absolute Error: ", e_abs, " ", unit)
    println("  Percentage Error: ", e_pct, "%")
end
println("-------------------------------")

# Per-state panels: true trajectory (scatter) vs. UDE rollout, with training-end marker.
variable_names = ["Surface Temperature Anomaly (T)", "Deep Ocean Temperature Anomaly (O)", "CO₂ Concentration (C)"]
variable_units = ["°C", "°C", "ppm"]

plots_vars_ude = []
for i in 1:3
   p = plot(title=variable_names[i], xlabel="Years", ylabel=variable_units[i], legend=:topleft,
            tickfontsize=7, legendfontsize=6, guidefontsize=8, titlefontsize=7,
            xticks=0:10:50)
   scatter!(p, time_points_total, true_data_total[i,:], label="True Data", markersize=2, markeralpha=0.7)
   plot!(p, time_points_total, ude_pred_total[i,:], lw=2, label="UDE Prediction")
   vline!(p, [tspan_train[2]], color=:black, lw=1.5, linestyle=:dash, label="Training End")
   push!(plots_vars_ude, p)
end
p_combined_vars_ude = plot(plots_vars_ude..., layout=(1,3), size=(593,371), margin=3Plots.mm)
display(p_combined_vars_ude)
savefig(p_combined_vars_ude, "random_seed_92_1p_climate_ude_pred_variables.png")

# True β*T*C vs. the NN term evaluated on the true trajectory.
p_term_beta_tc = plot(title="", xlabel="Years", ylabel="Term Value", legend=:topleft)
plot!(p_term_beta_tc, time_points_total, true_beta_T_C_term_vals, lw=2, label="True β*T*C")
plot!(p_term_beta_tc, time_points_total, learned_beta_T_C_term_vals, lw=2, linestyle=:dash, label="Learned β*T*C (NN)")
vline!(p_term_beta_tc, [tspan_train[2]], color=:black, lw=1.5, linestyle=:dash, label="Training End")
display(p_term_beta_tc)
savefig(p_term_beta_tc, "random_seed_92_1p_climate_ude_beta_tc_term_comparison.png")

# Pointwise absolute error of that term, with its mean as a horizontal reference line.
p_error_term_beta_tc = plot(title="", xlabel="Years", ylabel="Absolute Error", legend=:topright)
plot!(p_error_term_beta_tc, time_points_total, error_beta_T_C_term, lw=2, label="Error |True - Learned Term|")
hline!(p_error_term_beta_tc, [avg_error_term], lw=2, linestyle=:dash, label="Average Error")
vline!(p_error_term_beta_tc, [tspan_train[2]], color=:black, lw=1.5, linestyle=:dash, label="Training End")
display(p_error_term_beta_tc)
savefig(p_error_term_beta_tc, "random_seed_92_1p_climate_ude_beta_tc_term_error.png")

# Euclidean norm over all three states of (true − UDE) at each saved time point.
l2_errors_over_time_ude = [norm(true_data_total[:, j] - ude_pred_total[:, j]) for j in 1:size(true_data_total, 2)]
max_l2_err = maximum(l2_errors_over_time_ude)
# Fall back to a fixed y-limit when the error is identically zero.
p_l2_overall_error_ude = plot(title="",
                             xlabel="Years", ylabel="L2 Error ||True - UDE Pred||",
                             legend=:topleft, ylims=(0, max_l2_err > 0 ? max_l2_err * 1.1 : 0.1))
plot!(p_l2_overall_error_ude, time_points_total, l2_errors_over_time_ude, lw=2, label="L2 Error")
vline!(p_l2_overall_error_ude, [tspan_train[2]], color=:black, lw=1.5, linestyle=:dash, label="Training End")
display(p_l2_overall_error_ude)
savefig(p_l2_overall_error_ude, "random_seed_92_1p_climate_ude_l2_overall_error.png")

println("\n--- Climate System UDE Analysis Complete ---")
println("Training period: 0-$(tspan_train[2]) years")
println("Prediction period: $(tspan_train[2])-$(tspan_total[2]) years")
println("Final UDE-predicted surface temperature anomaly at $(tspan_total[2]) years: ", ude_pred_total[1,end], " °C")
println("Final UDE-predicted CO2 concentration at $(tspan_total[2]) years: ", ude_pred_total[3,end], " ppm")