

using OrdinaryDiffEq
using Plots
using Random
using Lux, Optimization, OptimizationOptimJL, ComponentArrays
using SciMLSensitivity
using LinearAlgebra
using OptimizationOptimisers
using Statistics
using Printf
using StateSpaceModels
using StatsPlots

rng = Random.default_rng()
Random.seed!(91)

# ============================================================================
# CLIMATE SYSTEM DEFINITION
# ============================================================================
function climate_system!(du, u, p, t)
    T, O, C = u
    α, C0, λ, κ, CT, CO, γ, β = p
    
    F_C = α * log(C / C0)
    E_t = 8.0 + 0.1 * t 
    
    du[1] = (1/CT) * (F_C - λ*T - κ*(T - O))
    du[2] = (1/CO) * κ*(T - O)
    du[3] = E_t - β*T*C - γ*C
end

α = 5.35
C0 = 280.0
λ = 1.0
κ = 0.69
CT = 8.0
CO = 80.0
γ = 0.01
β = 0.001

p = [α, C0, λ, κ, CT, CO, γ, β]

T0 = 0.0
O0 = 0.0
C0_init = 280.0
u0 = [T0, O0, C0_init]

tspan_train = (0.0, 15.0)
tspan_total = (0.0, 50.0)

# Generate true data
prob_train = ODEProblem(climate_system!, u0, tspan_train, p)
solution_train = solve(prob_train, Tsit5(), abstol=1e-12, reltol=1e-12, saveat=1.0)

prob_total = ODEProblem(climate_system!, u0, tspan_total, p)
solution_total = solve(prob_total, Tsit5(), abstol=1e-12, reltol=1e-12, saveat=1.0)

tsdata_train = Array(solution_train)

# Add noise
T_train_std = std(tsdata_train[1,:])
O_train_std = std(tsdata_train[2,:])
C_train_std = std(tsdata_train[3,:])

noise_percentage = 0.01
noisy_data = copy(tsdata_train)
noisy_data[1,:] .+= Float32(noise_percentage * T_train_std) .* randn(rng, size(tsdata_train[1,:]))
noisy_data[2,:] .+= Float32(noise_percentage * O_train_std) .* randn(rng, size(tsdata_train[2,:]))
noisy_data[3,:] .+= Float32(noise_percentage * C_train_std) .* randn(rng, size(tsdata_train[3,:]))

# Normalization parameters
T_mean, T_std = mean(tsdata_train[1,:]), std(tsdata_train[1,:])
O_mean, O_std = mean(tsdata_train[2,:]), std(tsdata_train[2,:])
C_mean, C_std = mean(tsdata_train[3,:]), std(tsdata_train[3,:])

T_std = T_std < 1e-8 ? 1e-8 : T_std
O_std = O_std < 1e-8 ? 1e-8 : O_std 
C_std = C_std < 1e-8 ? 1e-8 : C_std

function normalize_variables(T, O, C)
    T_norm = (T - T_mean) / T_std
    O_norm = (O - O_mean) / O_std
    C_norm = (C - C_mean) / C_std
    return T_norm, O_norm, C_norm
end

function denormalize_derivatives(dT_norm, dO_norm, dC_norm)
    dT = dT_norm * T_std
    dO = dO_norm * O_std
    dC = dC_norm * C_std
    return dT, dO, dC
end

# ============================================================================
# NEURAL ODE MODEL
# ============================================================================
println("\n" * "="^80)
println("TRAINING NEURAL ODE MODEL")
println("="^80)

ann_node = Lux.Chain(
    Lux.Dense(3, 64, softplus),
    Lux.Dense(64, 64, softplus),
    Lux.Dense(64, 64, softplus),
    Lux.Dense(64, 64, softplus),
    Lux.Dense(64, 3)
)

rng = Random.default_rng()
p1, st1 = Lux.setup(rng, ann_node)
p_ann = ComponentArray(p1)

# Count parameters for Neural ODE
n_params_node = length(p_ann)

function dudt_node(du, u, p, t)
    T, O, C = u
    T_norm, O_norm, C_norm = normalize_variables(T, O, C)
    nn_output = ann_node([T_norm, O_norm, C_norm], p, st1)[1]
    dT, dO, dC = denormalize_derivatives(nn_output[1], nn_output[2], nn_output[3])
    du[1] = dT
    du[2] = dO
    du[3] = dC
end

prob_node = ODEProblem{true}(dudt_node, u0, tspan_train, p_ann)

function predict_node(θ)
    Array(solve(prob_node, Tsit5(), p=θ, saveat=1.0, 
                abstol=1e-6, reltol=1e-6,
                sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end

function loss_node(θ)
    pred = predict_node(θ)
    loss = sum(abs2, (noisy_data .- pred))
    return loss
end

iter_node = 0
losses_node = []
function callback_node(θ, l)
    global iter_node, losses_node
    iter_node += 1
    push!(losses_node, l)
    if iter_node % 50 == 0
        println("Neural ODE Iteration: $iter_node, Loss: $l")
    end
    return false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p_inner) -> loss_node(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p_ann)

iters_adam = 20000
println("\nStarting Neural ODE Adam training ($iters_adam iterations)...")
time_node_start = time()
res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(0.00001), 
                         callback=callback_node, maxiters=iters_adam)

iters_adamw = 20000
optprob2 = remake(optprob, u0=res1.u)
println("Starting Neural ODE AdamW training ($iters_adamw iterations)...")
res2 = Optimization.solve(optprob2, OptimizationOptimisers.AdamW(0.00001, (0.9, 0.999), 1e-4),
                         callback=callback_node, maxiters=iters_adamw)
time_node_end = time()
time_node_total = time_node_end - time_node_start

println("Neural ODE training complete!")

# Neural ODE predictions
prob_node_extrapolate = ODEProblem{true}(dudt_node, u0, tspan_total, res2.u)
sol_node = Array(solve(prob_node_extrapolate, Tsit5(), saveat=solution_total.t,
                       abstol=1e-12, reltol=1e-12,
                       sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))

# ============================================================================
# UDE MODEL
# ============================================================================
println("\n" * "="^80)
println("TRAINING UDE MODEL")
println("="^80)

nn_beta_tc_model = Lux.Chain(
    Lux.Dense(2, 64, softplus),
    Lux.Dense(64, 64, softplus),
    Lux.Dense(64, 64, softplus),
    Lux.Dense(64, 1)
)

p_nn_beta_tc_initial, st_nn_beta_tc_model = Lux.setup(rng, nn_beta_tc_model)
p_nn_beta_tc_initial = ComponentArray(p_nn_beta_tc_initial)

# Count parameters for UDE
n_params_ude = length(p_nn_beta_tc_initial)

function ude_climate_system!(du, u, p_nn, t)
    T, O, C = u
    
    F_C = α * log(C / C0)
    E_t = 8.0 + 0.1 * t
    
    T_norm = (T - T_mean) / T_std
    C_norm = (C - C_mean) / C_std
    
    learned_beta_T_C_term = nn_beta_tc_model([T_norm, C_norm], p_nn, st_nn_beta_tc_model)[1][1]
    
    du[1] = (1/CT) * (F_C - λ*T - κ*(T - O))
    du[2] = (1/CO) * κ*(T - O)
    du[3] = E_t - learned_beta_T_C_term - γ*C
end

prob_ude_train = ODEProblem{true}(ude_climate_system!, u0, tspan_train, p_nn_beta_tc_initial)

function predict_ude(θ)
    sol = solve(prob_ude_train, Tsit5(), p=θ, saveat=solution_train.t,
                abstol=1e-6, reltol=1e-6,
                sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)))
    
    if sol.retcode != ReturnCode.Success || typeof(Array(sol)) != typeof(noisy_data) || size(Array(sol)) != size(noisy_data)
        return fill(Inf32, size(noisy_data))
    end
    return Array(sol)
end

function loss_ude(θ)
    pred = predict_ude(θ)
    loss = sum(abs2, noisy_data .- pred)
    return loss
end

iter_ude = 0
losses_ude = []
function callback_ude(θ, l)
    global iter_ude, losses_ude
    iter_ude += 1
    push!(losses_ude, l)
    if iter_ude % 50 == 0
        println("UDE Iteration: $iter_ude, Loss: $l")
    end
    return false
end

adtype_ude = Optimization.AutoZygote()
optf_ude = Optimization.OptimizationFunction((x, p_unused) -> loss_ude(x), adtype_ude)

iters_adam_ude = 5500
optprob_ude_adam = Optimization.OptimizationProblem(optf_ude, p_nn_beta_tc_initial)
println("\nStarting UDE Adam training ($iters_adam_ude iterations)...")
time_ude_start = time()
res_ude1 = Optimization.solve(optprob_ude_adam, OptimizationOptimisers.Adam(0.001),
                            callback=callback_ude, maxiters=iters_adam_ude)

iters_adamw_ude = 5500
optprob_ude_adamw = remake(optprob_ude_adam, u0=res_ude1.u)
println("Starting UDE AdamW training ($iters_adamw_ude iterations)...")
res_ude2 = Optimization.solve(optprob_ude_adamw, OptimizationOptimisers.AdamW(0.0001, (0.9, 0.999), 1e-4), 
                            callback=callback_ude, maxiters=iters_adamw_ude)
time_ude_end = time()
time_ude_total = time_ude_end - time_ude_start

println("UDE training complete!")

# UDE predictions
prob_ude_extrapolate = ODEProblem{true}(ude_climate_system!, u0, tspan_total, res_ude2.u)
sol_ude = Array(solve(prob_ude_extrapolate, Tsit5(), saveat=solution_total.t,
                      abstol=1e-12, reltol=1e-12))

# ============================================================================
# ARIMA MODEL WITH GRID SEARCH
# ============================================================================
println("\n" * "="^80)
println("TRAINING ARIMA MODELS WITH GRID SEARCH")
println("="^80)

# Grid search parameters
p_range = 0:3
d_range = 0:2
q_range = 0:3

# Function to fit ARIMA and calculate AIC
function fit_arima_model(y_train, p, d, q)
    try
        model = SARIMA(y_train; order=(p, d, q), include_mean=true)
        fit!(model)
        return model, StateSpaceModels.aicc(model)
    catch e
        return nothing, Inf
    end
end

# Function to forecast with ARIMA
function forecast_arima(model, n_ahead)
    try
        forecast_result = forecast(model, n_ahead)
        return forecast_expected_value(forecast_result)
    catch e
        return fill(NaN, n_ahead)
    end
end

# Grid search for each variable
time_arima_start = time()

println("\n--- ARIMA Grid Search for Surface Temperature (T) ---")
best_aic_T = Inf
best_params_T = (0, 0, 0)
best_model_T = nothing

for p_val in p_range, d_val in d_range, q_val in q_range
    model, aic = fit_arima_model(noisy_data[1,:], p_val, d_val, q_val)
    if aic < best_aic_T
        best_aic_T = aic
        best_params_T = (p_val, d_val, q_val)
        best_model_T = model
    end
end

println("Best ARIMA parameters for T: $(best_params_T)")
println("Best AICc for T: $(best_aic_T)")

println("\n--- ARIMA Grid Search for Ocean Temperature (O) ---")
best_aic_O = Inf
best_params_O = (0, 0, 0)
best_model_O = nothing

for p_val in p_range, d_val in d_range, q_val in q_range
    model, aic = fit_arima_model(noisy_data[2,:], p_val, d_val, q_val)
    if aic < best_aic_O
        best_aic_O = aic
        best_params_O = (p_val, d_val, q_val)
        best_model_O = model
    end
end

println("Best ARIMA parameters for O: $(best_params_O)")
println("Best AICc for O: $(best_aic_O)")

println("\n--- ARIMA Grid Search for CO2 Concentration (C) ---")
best_aic_C = Inf
best_params_C = (0, 0, 0)
best_model_C = nothing

for p_val in p_range, d_val in d_range, q_val in q_range
    model, aic = fit_arima_model(noisy_data[3,:], p_val, d_val, q_val)
    if aic < best_aic_C
        best_aic_C = aic
        best_params_C = (p_val, d_val, q_val)
        best_model_C = model
    end
end

println("Best ARIMA parameters for C: $(best_params_C)")
println("Best AICc for C: $(best_aic_C)")

# Generate ARIMA forecasts
n_forecast = length(solution_total.t) - length(solution_train.t)

arima_forecast_T = forecast_arima(best_model_T, n_forecast)
arima_forecast_O = forecast_arima(best_model_O, n_forecast)
arima_forecast_C = forecast_arima(best_model_C, n_forecast)

time_arima_end = time()
time_arima_total = time_arima_end - time_arima_start

# Combine training and forecast
arima_full_T = vcat(noisy_data[1,:], arima_forecast_T)
arima_full_O = vcat(noisy_data[2,:], arima_forecast_O)
arima_full_C = vcat(noisy_data[3,:], arima_forecast_C)

# Count parameters for ARIMA (total for all three models)
n_params_arima_T = sum(best_params_T[1:1]) + sum(best_params_T[3:3]) + 1  # p + q + mean
n_params_arima_O = sum(best_params_O[1:1]) + sum(best_params_O[3:3]) + 1
n_params_arima_C = sum(best_params_C[1:1]) + sum(best_params_C[3:3]) + 1
n_params_arima_total = n_params_arima_T + n_params_arima_O + n_params_arima_C

# ============================================================================
# COMPUTATIONAL RESOURCE USAGE COMPARISON
# ============================================================================
println("\n" * "="^80)
println("COMPUTATIONAL RESOURCE USAGE COMPARISON")
println("="^80)

total_iters_node = iters_adam + iters_adamw
total_iters_ude = iters_adam_ude + iters_adamw_ude
total_iters_arima = length(p_range) * length(d_range) * length(q_range) * 3  # 3 variables

println("\n┌" * "─"^78 * "┐")
println("│" * " "^25 * "COMPUTATIONAL RESOURCES COMPARISON" * " "^19 * "│")
println("├" * "─"^78 * "┤")
@printf("│ %-30s │ %13s │ %13s │ %13s │\n", "Metric", "Neural ODE", "UDE", "ARIMA")
println("├" * "─"^78 * "┤")
@printf("│ %-30s │ %13d │ %13d │ %13d │\n", "Number of Parameters", n_params_node, n_params_ude, n_params_arima_total)
@printf("│ %-30s │ %13d │ %13d │ %13d │\n", "Training Iterations", total_iters_node, total_iters_ude, total_iters_arima)
@printf("│ %-30s │ %10.2f s │ %10.2f s │ %10.2f s │\n", "Training Time", time_node_total, time_ude_total, time_arima_total)
@printf("│ %-30s │ %10.2f ms │ %10.2f ms │ %10.2f ms │\n", "Time per Iteration", 
        time_node_total/total_iters_node*1000, time_ude_total/total_iters_ude*1000, time_arima_total/total_iters_arima*1000)
println("├" * "─"^78 * "┤")
@printf("│ %-30s │ %13s │ %13s │ %13s │\n", "Model Complexity", "High", "Medium", "Low")
@printf("│ %-30s │ %13s │ %13s │ %13s │\n", "Physical Knowledge", "None", "Partial", "None")
@printf("│ %-30s │ %13s │ %13s │ %13s │\n", "Interpretability", "Low", "High", "Medium")
println("└" * "─"^78 * "┘")

# ============================================================================
# CALCULATE COMPARISON METRICS
# ============================================================================
println("\n" * "="^80)
println("FORECASTING COMPARISON METRICS")
println("="^80)

true_data_total = Array(solution_total)
time_points_total = solution_total.t

# Split into training and testing periods
train_idx = 1:length(solution_train.t)
test_idx = (length(solution_train.t)+1):length(time_points_total)

function calculate_metrics(true_vals, pred_vals, var_name, model_name)
    mae = mean(abs.(true_vals .- pred_vals))
    rmse = sqrt(mean((true_vals .- pred_vals).^2))
    mape = mean(abs.((true_vals .- pred_vals) ./ (true_vals .+ 1e-10))) * 100
    max_error = maximum(abs.(true_vals .- pred_vals))
    
    ss_res = sum((true_vals .- pred_vals).^2)
    ss_tot = sum((true_vals .- mean(true_vals)).^2)
    r2 = 1 - ss_res / ss_tot
    
    return (mae=mae, rmse=rmse, mape=mape, max_error=max_error, r2=r2)
end

# Calculate metrics for test period only (extrapolation)
println("\n" * "─"^80)
println("EXTRAPOLATION PERIOD METRICS (t = 15 to 50 years)")
println("─"^80)

# Store metrics for plotting
metrics_dict = Dict()

for (i, var_name, unit) in [(1, "Surface Temperature", "°C"), 
                             (2, "Ocean Temperature", "°C"), 
                             (3, "CO₂ Concentration", "ppm")]
    println("\n$(var_name) ($(unit)):")
    println("="^60)
    
    true_test = true_data_total[i, test_idx]
    
    # Neural ODE
    node_test = sol_node[i, test_idx]
    metrics_node = calculate_metrics(true_test, node_test, var_name, "Neural ODE")
    
    # UDE
    ude_test = sol_ude[i, test_idx]
    metrics_ude = calculate_metrics(true_test, ude_test, var_name, "UDE")
    
    # ARIMA
    if i == 1
        arima_test = arima_full_T[test_idx]
    elseif i == 2
        arima_test = arima_full_O[test_idx]
    else
        arima_test = arima_full_C[test_idx]
    end
    metrics_arima = calculate_metrics(true_test, arima_test, var_name, "ARIMA")
    
    # Store for plotting
    metrics_dict[i] = (node=metrics_node, ude=metrics_ude, arima=metrics_arima)
    
    # Print comparison table
    println("\n                   Neural ODE       UDE           ARIMA")
    println("─"^60)
    @printf("MAE:              %.6f      %.6f      %.6f\n", metrics_node.mae, metrics_ude.mae, metrics_arima.mae)
    @printf("RMSE:             %.6f      %.6f      %.6f\n", metrics_node.rmse, metrics_ude.rmse, metrics_arima.rmse)
    @printf("MAPE (%%):         %.4f       %.4f       %.4f\n", metrics_node.mape, metrics_ude.mape, metrics_arima.mape)
    @printf("Max Error:        %.6f      %.6f      %.6f\n", metrics_node.max_error, metrics_ude.max_error, metrics_arima.max_error)
    @printf("R²:               %.6f      %.6f      %.6f\n", metrics_node.r2, metrics_ude.r2, metrics_arima.r2)
    
    # Determine best model
    best_model = ""
    if metrics_node.rmse < metrics_ude.rmse && metrics_node.rmse < metrics_arima.rmse
        best_model = "Neural ODE"
    elseif metrics_ude.rmse < metrics_arima.rmse
        best_model = "UDE"
    else
        best_model = "ARIMA"
    end
    println("\n✓ Best Model (by RMSE): $(best_model)")
end

# Final values comparison at t=50
println("\n" * "─"^80)
println("FINAL VALUES AT t = 50 years")
println("─"^80)

for (i, var_name, unit) in [(1, "Surface Temperature", "°C"), 
                             (2, "Ocean Temperature", "°C"), 
                             (3, "CO₂ Concentration", "ppm")]
    println("\n$(var_name):")
    
    true_final = true_data_total[i, end]
    node_final = sol_node[i, end]
    ude_final = sol_ude[i, end]
    
    if i == 1
        arima_final = arima_full_T[end]
    elseif i == 2
        arima_final = arima_full_O[end]
    else
        arima_final = arima_full_C[end]
    end
    
    @printf("  True Value:       %.6f %s\n", true_final, unit)
    @printf("  Neural ODE:       %.6f %s (error: %.6f, %.2f%%)\n", 
            node_final, unit, abs(true_final - node_final), 
            abs(true_final - node_final)/abs(true_final)*100)
    @printf("  UDE:              %.6f %s (error: %.6f, %.2f%%)\n", 
            ude_final, unit, abs(true_final - ude_final), 
            abs(true_final - ude_final)/abs(true_final)*100)
    @printf("  ARIMA:            %.6f %s (error: %.6f, %.2f%%)\n", 
            arima_final, unit, abs(true_final - arima_final), 
            abs(true_final - arima_final)/abs(true_final)*100)
end

# ============================================================================
# CREATE COMPARISON PLOTS
# ============================================================================
println("\n" * "="^80)
println("GENERATING COMPARISON PLOTS")
println("="^80)

# Plot 1: Combined Climate Variables Comparison (3 subplots)
p1 = plot(title="Surface Temperature Anomaly",
          xlabel="Time (years)",
          ylabel="Temperature (°C)",
          legend=:topleft,
          titlefontsize=8,
          legendfontsize=4,
          guidefontsize=7,
          tickfontsize=6)

scatter!(p1, time_points_total, true_data_total[1,:], 
         label="True", markersize=2, color=:black, alpha=0.6)
plot!(p1, time_points_total, sol_node[1,:], 
      label="Neural ODE", lw=1.5, color=:blue)
plot!(p1, time_points_total, sol_ude[1,:], 
      label="UDE", lw=1.5, color=:red, linestyle=:dash)
plot!(p1, time_points_total, arima_full_T, 
      label="ARIMA", lw=1.5, color=:green, linestyle=:dot)
vline!(p1, [15.0], label="Train End", color=:black, lw=1.5, linestyle=:dashdot)

p2 = plot(title="Deep Ocean Temperature Anomaly",
          xlabel="Time (years)",
          ylabel="Temperature (°C)",
          legend=:topleft,
          titlefontsize=8,
          legendfontsize=4,
          guidefontsize=7,
          tickfontsize=6)

scatter!(p2, time_points_total, true_data_total[2,:], 
         label="True", markersize=2, color=:black, alpha=0.6)
plot!(p2, time_points_total, sol_node[2,:], 
      label="Neural ODE", lw=1.5, color=:blue)
plot!(p2, time_points_total, sol_ude[2,:], 
      label="UDE", lw=1.5, color=:red, linestyle=:dash)
plot!(p2, time_points_total, arima_full_O, 
      label="ARIMA", lw=1.5, color=:green, linestyle=:dot)
vline!(p2, [15.0], label="Train End", color=:black, lw=1.5, linestyle=:dashdot)

p3 = plot(title="CO₂ Concentration",
          xlabel="Time (years)",
          ylabel="Concentration (ppm)",
          legend=:topleft,
          titlefontsize=8,
          legendfontsize=4,
          guidefontsize=7,
          tickfontsize=6)

scatter!(p3, time_points_total, true_data_total[3,:], 
         label="True", markersize=2, color=:black, alpha=0.6)
plot!(p3, time_points_total, sol_node[3,:], 
      label="Neural ODE", lw=1.5, color=:blue)
plot!(p3, time_points_total, sol_ude[3,:], 
      label="UDE", lw=1.5, color=:red, linestyle=:dash)
plot!(p3, time_points_total, arima_full_C, 
      label="ARIMA", lw=1.5, color=:green, linestyle=:dot)
vline!(p3, [15.0], label="Train End", color=:black, lw=1.5, linestyle=:dashdot)

p_combined = plot(p1, p2, p3, layout=(3,1), size=(593, 371))
display(p_combined)
savefig(p_combined, "random_seed_91_1p_noise_ARIMA_vs_all_methods_comparison_combined_climate_variables.png")

# Plot 2: Absolute Error Comparison with Proper Scaling
mae_T = [metrics_dict[1].node.mae, metrics_dict[1].ude.mae, metrics_dict[1].arima.mae]
mae_O = [metrics_dict[2].node.mae, metrics_dict[2].ude.mae, metrics_dict[2].arima.mae]
mae_C = [metrics_dict[3].node.mae, metrics_dict[3].ude.mae, metrics_dict[3].arima.mae]

# Find max values for scaling
max_mae_T = maximum(mae_T)
max_mae_O = maximum(mae_O)
max_mae_C = maximum(mae_C)

# Create separate plots for each variable with appropriate scaling
p_abs1 = groupedbar([mae_T[1]; mae_T[2]; mae_T[3]]',
                bar_position = :dodge,
                bar_width=0.7,
                xticks=([1], ["Surface\nTemp"]),
                ylabel="MAE (°C)",
                title="Surface Temperature",
                label=["Neural ODE" "UDE" "ARIMA"],
                legend=false,
                titlefontsize=8,
                guidefontsize=7,
                tickfontsize=6,
                color=[:blue :red :green],
                ylims=(0, max_mae_T * 1.15))

p_abs2 = groupedbar([mae_O[1]; mae_O[2]; mae_O[3]]',
                bar_position = :dodge,
                bar_width=0.7,
                xticks=([1], ["Ocean\nTemp"]),
                ylabel="MAE (°C)",
                title="Ocean Temperature",
                label=["Neural ODE" "UDE" "ARIMA"],
                legend=false,
                titlefontsize=8,
                guidefontsize=7,
                tickfontsize=6,
                color=[:blue :red :green],
                ylims=(0, max_mae_O * 1.15))

p_abs3 = groupedbar([mae_C[1]; mae_C[2]; mae_C[3]]',
                bar_position = :dodge,
                bar_width=0.7,
                xticks=([1], ["CO₂\nConc"]),
                ylabel="MAE (ppm)",
                title="CO₂ Concentration",
                label=["Neural ODE" "UDE" "ARIMA"],
                legend=:topright,
                titlefontsize=8,
                legendfontsize=6,
                guidefontsize=7,
                tickfontsize=6,
                color=[:blue :red :green],
                ylims=(0, max_mae_C * 1.15))

p_absolute = plot(p_abs1, p_abs2, p_abs3, 
                 layout=(1,3), 
                 size=(800, 400),
                 plot_title="Absolute Error Comparison (MAE) - Extrapolation Period",
                 plot_titlefontsize=10)

display(p_absolute)
savefig(p_absolute, "random_seed_91_1p_noise_ARIMA_vs_all_methods_comparison_absolute_errors.png")

# Plot 3: Percentage Error Comparison with Proper Scaling
mape_T = [metrics_dict[1].node.mape, metrics_dict[1].ude.mape, metrics_dict[1].arima.mape]
mape_O = [metrics_dict[2].node.mape, metrics_dict[2].ude.mape, metrics_dict[2].arima.mape]
mape_C = [metrics_dict[3].node.mape, metrics_dict[3].ude.mape, metrics_dict[3].arima.mape]

# Find max values for scaling
max_mape_T = maximum(mape_T)
max_mape_O = maximum(mape_O)
max_mape_C = maximum(mape_C)

# Create separate plots for each variable with appropriate scaling
p_pct1 = groupedbar([mape_T[1]; mape_T[2]; mape_T[3]]',
                bar_position = :dodge,
                bar_width=0.7,
                xticks=([1], ["Surface\nTemp"]),
                ylabel="MAPE (%)",
                title="Surface Temperature",
                label=["Neural ODE" "UDE" "ARIMA"],
                legend=false,
                titlefontsize=8,
                guidefontsize=7,
                tickfontsize=6,
                color=[:blue :red :green],
                ylims=(0, max_mape_T * 1.15))

p_pct2 = groupedbar([mape_O[1]; mape_O[2]; mape_O[3]]',
                bar_position = :dodge,
                bar_width=0.7,
                xticks=([1], ["Ocean\nTemp"]),
                ylabel="MAPE (%)",
                title="Ocean Temperature",
                label=["Neural ODE" "UDE" "ARIMA"],
                legend=false,
                titlefontsize=8,
                guidefontsize=7,
                tickfontsize=6,
                color=[:blue :red :green],
                ylims=(0, max_mape_O * 1.15))

p_pct3 = groupedbar([mape_C[1]; mape_C[2]; mape_C[3]]',
                bar_position = :dodge,
                bar_width=0.7,
                xticks=([1], ["CO₂\nConc"]),
                ylabel="MAPE (%)",
                title="CO₂ Concentration",
                label=["Neural ODE" "UDE" "ARIMA"],
                legend=:topright,
                titlefontsize=8,
                legendfontsize=6,
                guidefontsize=7,
                tickfontsize=6,
                color=[:blue :red :green],
                ylims=(0, max_mape_C * 1.15))

p_percentage = plot(p_pct1, p_pct2, p_pct3, 
                   layout=(1,3), 
                   size=(800, 400),
                   plot_title="Percentage Error Comparison (MAPE) - Extrapolation Period",
                   plot_titlefontsize=10)

display(p_percentage)
savefig(p_percentage, "random_seed_91_1p_noise_ARIMA_vs_all_methods_comparison_percentage_errors.png")

# Plot 4: Overall RMSE Comparison
rmse_T_node = sqrt(mean((true_data_total[1, test_idx] .- sol_node[1, test_idx]).^2))
rmse_T_ude = sqrt(mean((true_data_total[1, test_idx] .- sol_ude[1, test_idx]).^2))
rmse_T_arima = sqrt(mean((true_data_total[1, test_idx] .- arima_full_T[test_idx]).^2))

rmse_O_node = sqrt(mean((true_data_total[2, test_idx] .- sol_node[2, test_idx]).^2))
rmse_O_ude = sqrt(mean((true_data_total[2, test_idx] .- sol_ude[2, test_idx]).^2))
rmse_O_arima = sqrt(mean((true_data_total[2, test_idx] .- arima_full_O[test_idx]).^2))

rmse_C_node = sqrt(mean((true_data_total[3, test_idx] .- sol_node[3, test_idx]).^2))
rmse_C_ude = sqrt(mean((true_data_total[3, test_idx] .- sol_ude[3, test_idx]).^2))
rmse_C_arima = sqrt(mean((true_data_total[3, test_idx] .- arima_full_C[test_idx]).^2))

variables = ["Surface\nTemp", "Ocean\nTemp", "CO₂\nConc"]
x_pos = 1:3

p4 = groupedbar([rmse_T_node rmse_O_node rmse_C_node;
                 rmse_T_ude rmse_O_ude rmse_C_ude;
                 rmse_T_arima rmse_O_arima rmse_C_arima]',
                bar_position = :dodge,
                bar_width=0.25,
                xticks=(x_pos, variables),
                xlabel="Variables",
                ylabel="RMSE (Extrapolation Period)",
                title="RMSE Comparison Across All Models",
                label=["Neural ODE" "UDE" "ARIMA"],
                legend=:topright,
                size=(800, 500),
                color=[:blue :red :green])

display(p4)
savefig(p4, "random_seed_91_1p_noise_ARIMA_vs_all_methods_comparison_rmse_all.png")

println("\n" * "="^80)
println("FORECASTING COMPARISON COMPLETE")
println("="^80)
println("\nPlots saved:")
println("  1. comparison_combined_climate_variables.png (Combined 3-panel plot)")
println("  2. comparison_absolute_errors.png")
println("  3. comparison_percentage_errors.png")
println("  4. comparison_rmse_all.png")
println("\nARIMA Model Parameters:")
println("  Surface Temperature: ARIMA$(best_params_T)")
println("  Ocean Temperature:   ARIMA$(best_params_O)")
println("  CO₂ Concentration:   ARIMA$(best_params_C)")