

using OrdinaryDiffEq
using Plots
using Random
using Lux, Optimization, OptimizationOptimJL, ComponentArrays
using SciMLSensitivity
using LinearAlgebra
using OptimizationOptimisers
using Statistics
using Printf

# Seed the task-default RNG once, then keep a handle to it; every random draw
# in the script (noise, network initialization) goes through this `rng`.
Random.seed!(91)
rng = Random.default_rng()

# ============================================================================
# CLIMATE SYSTEM DEFINITION
# ============================================================================
"""
    climate_system!(du, u, p, t)

In-place right-hand side of the reference three-variable climate model.

State `u = (T, O, C)`: surface temperature anomaly, ocean temperature anomaly
and CO₂ concentration. Parameters `p = (α, C0, λ, κ, CT, CO, γ, β)`.
Emissions grow linearly in time (`8.0 + 0.1t`), and the CO₂ sink contains the
nonlinear `β·T·C` term.
"""
function climate_system!(du, u, p, t)
    T, O, C = u
    α, C0, λ, κ, CT, CO, γ, β = p

    radiative_forcing = α * log(C / C0)     # logarithmic CO₂ forcing
    emissions = 8.0 + 0.1 * t
    heat_exchange = κ * (T - O)             # surface → ocean heat flux

    du[1] = (1 / CT) * (radiative_forcing - λ * T - heat_exchange)
    du[2] = (1 / CO) * κ * (T - O)
    du[3] = emissions - β * T * C - γ * C
end

# Physical parameters of the reference model; the order in `p` matches the
# destructuring inside `climate_system!`.
α, C0, λ, κ = 5.35, 280.0, 1.0, 0.69
CT, CO = 8.0, 80.0
γ, β = 0.01, 0.001
p = [α, C0, λ, κ, CT, CO, γ, β]

# Initial state: zero temperature anomalies, CO₂ at 280 ppm.
T0, O0, C0_init = 0.0, 0.0, 280.0
u0 = [T0, O0, C0_init]

# Training window and full (training + extrapolation) window, in years.
tspan_train = (0.0, 15.0)
tspan_total = (0.0, 50.0)

# Ground-truth trajectories on the training and full horizons, sampled yearly
# with very tight tolerances.
prob_train = ODEProblem(climate_system!, u0, tspan_train, p)
solution_train = solve(prob_train, Tsit5(); abstol=1e-12, reltol=1e-12, saveat=1.0)

prob_total = ODEProblem(climate_system!, u0, tspan_total, p)
solution_total = solve(prob_total, Tsit5(); abstol=1e-12, reltol=1e-12, saveat=1.0)

tsdata_train = Array(solution_train)   # 3 × N_train (rows: T, O, C)

# Observation noise: Gaussian with std = noise_percentage × per-variable std
# of the clean training series. Rows are perturbed in order T, O, C.
T_train_std, O_train_std, C_train_std = (std(tsdata_train[i, :]) for i in 1:3)

noise_percentage = 0.01
noisy_data = copy(tsdata_train)
for (row, sd) in zip(1:3, (T_train_std, O_train_std, C_train_std))
    # the noise scale is rounded to Float32 before it multiplies the draws
    noisy_data[row, :] .+= Float32(noise_percentage * sd) .*
                           randn(rng, size(tsdata_train[row, :]))
end

# Normalization statistics, computed on the clean training trajectory.
# Standard deviations are floored at 1e-8 so the divisions in
# `normalize_variables` stay finite for (near-)constant series.
T_mean = mean(tsdata_train[1, :])
O_mean = mean(tsdata_train[2, :])
C_mean = mean(tsdata_train[3, :])

T_std = max(std(tsdata_train[1, :]), 1e-8)
O_std = max(std(tsdata_train[2, :]), 1e-8)
C_std = max(std(tsdata_train[3, :]), 1e-8)

"""
    normalize_variables(T, O, C)

Z-score the three state variables using the global training statistics
(`T_mean`/`T_std`, `O_mean`/`O_std`, `C_mean`/`C_std`). Returns a 3-tuple
`(T_norm, O_norm, C_norm)`.
"""
function normalize_variables(T, O, C)
    zscore(x, μ, σ) = (x - μ) / σ
    return (zscore(T, T_mean, T_std),
            zscore(O, O_mean, O_std),
            zscore(C, C_mean, C_std))
end

"""
    denormalize_derivatives(dT_norm, dO_norm, dC_norm)

Map time derivatives of the z-scored variables back to physical units.
Only the scale factor applies, because the mean shift drops out of a time
derivative. Returns a 3-tuple `(dT, dO, dC)`.
"""
function denormalize_derivatives(dT_norm, dO_norm, dC_norm)
    return (dT_norm * T_std, dO_norm * O_std, dC_norm * C_std)
end

# ============================================================================
# NEURAL ODE MODEL
# ============================================================================
let rule = "="^80
    println("\n", rule)
    println("TRAINING NEURAL ODE MODEL")
    println(rule)
end

# Black-box vector field: normalized (T, O, C) in, normalized derivatives out.
# Four softplus hidden layers of width 64, linear output layer.
ann_node = Lux.Chain(
    Lux.Dense(3, 64, softplus),
    Lux.Dense(64, 64, softplus),
    Lux.Dense(64, 64, softplus),
    Lux.Dense(64, 64, softplus),
    Lux.Dense(64, 3),
)

# `rng` is the task-default RNG defined and seeded at the top of the script,
# so there is no need to fetch it again here.
p1, st1 = Lux.setup(rng, ann_node)
p_ann = ComponentArray(p1)

# Number of trainable parameters (reported in the resource comparison).
n_params_node = length(p_ann)

"""
    dudt_node(du, u, p, t)

In-place neural-ODE right-hand side. The state is z-scored, passed through
`ann_node` with parameters `p` (and fixed layer state `st1`), and the network
output is rescaled to physical derivatives.
"""
function dudt_node(du, u, p, t)
    x_norm = normalize_variables(u[1], u[2], u[3])
    y = first(ann_node([x_norm[1], x_norm[2], x_norm[3]], p, st1))
    dT, dO, dC = denormalize_derivatives(y[1], y[2], y[3])
    du[1] = dT
    du[2] = dO
    du[3] = dC
end

prob_node = ODEProblem{true}(dudt_node, u0, tspan_train, p_ann)

"""
    predict_node(θ)

Integrate the neural ODE over the training window with parameters `θ`,
sampled every 1.0 time unit, and return the trajectory as an array (one row
per state variable). Gradients use an interpolating adjoint with ReverseDiff
vector–Jacobian products.
"""
function predict_node(θ)
    adjoint_method = InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))
    sol = solve(prob_node, Tsit5(); p=θ, saveat=1.0,
                abstol=1e-6, reltol=1e-6, sensealg=adjoint_method)
    return Array(sol)
end

"""
    loss_node(θ)

Sum of squared errors between the neural-ODE trajectory and the noisy data.
"""
loss_node(θ) = sum(abs2, noisy_data .- predict_node(θ))

iter_node = 0
# Loss history; typed so that `push!` stores unboxed Float64 values.
losses_node = Float64[]

"""
    callback_node(θ, l)

Optimization callback for the neural ODE. It increments `iter_node`, records
the loss `l` in `losses_node`, prints progress every 50 calls, and returns
`false` so optimization never stops early. `θ` is ignored.
"""
function callback_node(θ, l)
    global iter_node
    iter_node += 1
    push!(losses_node, l)
    if iter_node % 50 == 0
        println("Neural ODE Iteration: $iter_node, Loss: $l")
    end
    return false
end

# Two-stage optimization: Adam, then AdamW warm-started from the Adam result.
# The optimization problem has no hyper-parameters, hence the ignored argument.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss_node(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, p_ann)

iters_adam = 20000
iters_adamw = 20000

println("\nStarting Neural ODE Adam training ($iters_adam iterations)...")
time_node_start = time()
res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(1e-5);
                          callback=callback_node, maxiters=iters_adam)

optprob2 = remake(optprob; u0=res1.u)
println("Starting Neural ODE AdamW training ($iters_adamw iterations)...")
res2 = Optimization.solve(optprob2,
                          OptimizationOptimisers.AdamW(1e-5, (0.9, 0.999), 1e-4);
                          callback=callback_node, maxiters=iters_adamw)
time_node_end = time()
time_node_total = time_node_end - time_node_start

println("Neural ODE training complete!")

# Neural ODE predictions over the full 0–50 horizon, sampled on the reference
# solution's time grid. This is a forward-only solve, so no adjoint `sensealg`
# is needed.
prob_node_extrapolate = ODEProblem{true}(dudt_node, u0, tspan_total, res2.u)
sol_node = solve(prob_node_extrapolate, Tsit5();
                 saveat=solution_total.t, abstol=1e-12, reltol=1e-12) |> Array

# ============================================================================
# UDE MODEL
# ============================================================================
let rule = "="^80
    println("\n", rule)
    println("TRAINING UDE MODEL")
    println(rule)
end

# Network that stands in for the β·T·C sink term of the CO₂ equation. It
# receives normalized (T, C) and returns a single scalar.
nn_beta_tc_model = Lux.Chain(
    Lux.Dense(2, 64, softplus),
    Lux.Dense(64, 64, softplus),
    Lux.Dense(64, 64, softplus),
    Lux.Dense(64, 1),
)

p_nn_beta_tc_initial, st_nn_beta_tc_model = let
    ps, st = Lux.setup(rng, nn_beta_tc_model)
    (ComponentArray(ps), st)
end

# Number of trainable parameters (reported in the resource comparison).
n_params_ude = length(p_nn_beta_tc_initial)

"""
    ude_climate_system!(du, u, p_nn, t)

Universal-differential-equation right-hand side. It keeps the known physics
of `climate_system!` (global α, C0, λ, κ, CT, CO, γ) and replaces the
`β·T·C` sink with `nn_beta_tc_model` evaluated on the z-scored (T, C), using
network parameters `p_nn`.
"""
function ude_climate_system!(du, u, p_nn, t)
    T, O, C = u

    # known physics
    forcing = α * log(C / C0)
    emissions = 8.0 + 0.1 * t

    # learned sink term
    features = [(T - T_mean) / T_std, (C - C_mean) / C_std]
    learned_sink = nn_beta_tc_model(features, p_nn, st_nn_beta_tc_model)[1][1]

    du[1] = (1/CT) * (forcing - λ*T - κ*(T - O))
    du[2] = (1/CO) * κ*(T - O)
    du[3] = emissions - learned_sink - γ*C
end

prob_ude_train = ODEProblem{true}(ude_climate_system!, u0, tspan_train,
                                  p_nn_beta_tc_initial)

"""
    predict_ude(θ)

Integrate the UDE over the training window with network parameters `θ`,
sampled at `solution_train.t`, and return the trajectory as an array.

If the solve does not finish with `ReturnCode.Success`, or the result does
not have the same type and size as `noisy_data`, an array of `+Inf` with
`noisy_data`'s size and element type is returned instead, so the loss
becomes `Inf`.
"""
function predict_ude(θ)
    sol = solve(prob_ude_train, Tsit5(); p=θ, saveat=solution_train.t,
                abstol=1e-6, reltol=1e-6,
                sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)))

    # Materialize the solution once instead of once per check.
    pred = Array(sol)

    if sol.retcode != ReturnCode.Success || typeof(pred) != typeof(noisy_data) ||
       size(pred) != size(noisy_data)
        # The sentinel uses the data's element type so both branches return
        # the same array type.
        return fill(convert(eltype(noisy_data), Inf), size(noisy_data))
    end
    return pred
end

"""
    loss_ude(θ)

Sum of squared errors between the UDE trajectory for parameters `θ` and the
noisy training observations. This is `Inf` when `predict_ude` returns its
infinite failure sentinel.
"""
loss_ude(θ) = sum(abs2, noisy_data .- predict_ude(θ))

iter_ude = 0
# Loss history; typed so that `push!` stores unboxed Float64 values.
losses_ude = Float64[]

"""
    callback_ude(θ, l)

Optimization callback for the UDE. It increments `iter_ude`, records the loss
`l` in `losses_ude`, prints progress every 50 calls, and returns `false` so
optimization never stops early. `θ` is ignored.
"""
function callback_ude(θ, l)
    global iter_ude
    iter_ude += 1
    push!(losses_ude, l)
    if iter_ude % 50 == 0
        println("UDE Iteration: $iter_ude, Loss: $l")
    end
    return false
end

# Two-stage optimization, as for the neural ODE: Adam, then AdamW
# warm-started from the Adam solution.
adtype_ude = Optimization.AutoZygote()
optf_ude = Optimization.OptimizationFunction((θ, _) -> loss_ude(θ), adtype_ude)

iters_adam_ude = 5500
iters_adamw_ude = 5500
optprob_ude_adam = Optimization.OptimizationProblem(optf_ude, p_nn_beta_tc_initial)

println("\nStarting UDE Adam training ($iters_adam_ude iterations)...")
time_ude_start = time()
res_ude1 = Optimization.solve(optprob_ude_adam, OptimizationOptimisers.Adam(1e-3);
                              callback=callback_ude, maxiters=iters_adam_ude)

optprob_ude_adamw = remake(optprob_ude_adam; u0=res_ude1.u)
println("Starting UDE AdamW training ($iters_adamw_ude iterations)...")
res_ude2 = Optimization.solve(optprob_ude_adamw,
                              OptimizationOptimisers.AdamW(1e-4, (0.9, 0.999), 1e-4);
                              callback=callback_ude, maxiters=iters_adamw_ude)
time_ude_end = time()
time_ude_total = time_ude_end - time_ude_start

println("UDE training complete!")

# UDE predictions over the full 0–50 horizon, sampled on the reference
# solution's time grid.
prob_ude_extrapolate = ODEProblem{true}(ude_climate_system!, u0, tspan_total,
                                        res_ude2.u)
sol_ude = solve(prob_ude_extrapolate, Tsit5();
                saveat=solution_total.t, abstol=1e-12, reltol=1e-12) |> Array

# ============================================================================
# VAR MODEL WITH MANUAL IMPLEMENTATION
# ============================================================================
let rule = "="^80
    println("\n", rule)
    println("TRAINING VAR MODELS WITH LAG SELECTION")
    println(rule)
end

# Manual VAR implementation functions
"""
    create_var_matrix(Y, p)

Build the regression design for a VAR(`p`) model with an intercept.

`Y` is a T×K matrix (rows are time steps, columns are variables). Returns
`(X, Y_trim)`:
- `X` is (T-p)×(K*p+1). Column 1 is all ones (the intercept), followed by the
  lag blocks `[Y_{t-1} | Y_{t-2} | … | Y_{t-p}]`, each K columns wide.
- `Y_trim = Y[p+1:end, :]` is the response aligned with the rows of `X`.
"""
function create_var_matrix(Y, p)
    n_obs, n_vars = size(Y)
    X = ones(n_obs - p, n_vars * p + 1)

    for lag in 1:p
        cols = (lag - 1) * n_vars + 1 .+ (1:n_vars)   # this lag's column block
        X[:, cols] = Y[(p + 1 - lag):(n_obs - lag), :]
    end

    return X, Y[(p + 1):end, :]
end

"""
    estimate_var(Y, p)

Estimate a VAR(`p`) model with an intercept by equation-by-equation OLS,
stabilized with a tiny ridge (`1e-8·I`) on the normal matrix.

`Y` is a T×K matrix (rows are time steps). Returns
`(B, residuals, Sigma, X, Y_trim)`:
- `B` is (K*p+1)×K, one column per equation. Row 1 holds the intercepts; the
  remaining rows follow the lag-block column order of `create_var_matrix`.
- `residuals = Y_trim - X*B`.
- `Sigma = residuals'residuals / (T_eff - (K*p+1)) + 1e-6·I`. This is the
  degrees-of-freedom-corrected covariance, padded to stay positive definite.
- `X`, `Y_trim` are from `create_var_matrix`.

Throws an `ErrorException` when `T_eff <= K*p + 1 + K`.
"""
function estimate_var(Y, p)
    X, Y_trim = create_var_matrix(Y, p)
    T_eff, K = size(Y_trim)

    # Check for sufficient degrees of freedom
    n_params = K * p + 1
    if T_eff <= n_params + K
        error("Insufficient degrees of freedom: T_eff=$T_eff, parameters=$n_params")
    end

    # Every equation shares the same regressors, so the regularized normal
    # matrix is built once and all K equations are solved together:
    #   B = (X'X + ridge·I) \ X'Y
    # The ridge only guards against a singular X'X. It is named `ridge` so it
    # does not shadow the global climate parameter λ.
    ridge = 1e-8
    XtX_reg = X' * X + ridge * I(n_params)
    B = XtX_reg \ (X' * Y_trim)
    residuals = Y_trim - X * B

    # Degrees-of-freedom-corrected covariance plus a small diagonal term to
    # keep it positive definite.
    Sigma = (residuals' * residuals) / (T_eff - n_params) + 1e-6 * I(K)

    return B, residuals, Sigma, X, Y_trim
end

"""
    safe_logdet(Sigma)

Return `logdet(Sigma)` when that call succeeds with a finite result.

Otherwise (the call throws, e.g. for a negative determinant, or returns a
non-finite value) fall back to a pseudo-log-determinant: the sum of `log` over
the eigenvalues of `Symmetric(Sigma)` that exceed `1e-10`. Returns `NaN`
when no eigenvalue exceeds that threshold.
"""
function safe_logdet(Sigma)
    ld = try
        logdet(Sigma)
    catch err
        # Numerical failures fall through to the eigenvalue path (deliberate
        # best effort), but a user interrupt is never swallowed.
        err isa InterruptException && rethrow()
        nothing
    end
    if ld !== nothing && isfinite(ld)
        return ld
    end

    # Fallback: ignore (near-)null and negative directions.
    positive = filter(>(1e-10), eigvals(Symmetric(Sigma)))
    isempty(positive) && return NaN   # matrix is essentially zero
    return sum(log, positive)
end

"""
    calculate_information_criteria(Y, p_max)

Fit VAR(p) models for `p = 1:p_max` on the T×K data matrix `Y` and compute
the lag-selection criteria in Lütkepohl's formulation:

    AIC(p) = ln|Σ̃(p)| + 2·pK²/T_eff
    BIC(p) = ln|Σ̃(p)| + ln(T_eff)·pK²/T_eff
    HQC(p) = ln|Σ̃(p)| + 2·ln(ln T_eff)·pK²/T_eff
    FPE(p) = ((T_eff + Kp + 1)/(T_eff - Kp - 1))^K · |Σ̃(p)|

Here `T_eff = T - p` and `Σ̃(p) = ÛᵀÛ / T_eff` is the maximum-likelihood
residual covariance.

A lag order is skipped, with a printed message, when it has fewer than
`2(Kp + 1)` effective observations, a non-finite log-determinant, or an
estimation error.

Returns `(AICs, BICs, HQCs, FPEs, valid_lags)`. The i-th entry of each
criterion vector belongs to lag `valid_lags[i]`.
"""
function calculate_information_criteria(Y, p_max)
    T, K = size(Y)

    AICs = Float64[]
    BICs = Float64[]
    HQCs = Float64[]
    FPEs = Float64[]
    valid_lags = Int[]

    for p in 1:p_max
        T_eff = T - p
        n_params_per_eq = K * p + 1

        # Need at least 2*n_params observations for reliable estimation
        if T_eff < 2 * n_params_per_eq
            println("  Skipping lag $p: insufficient observations (T_eff=$T_eff < $(2*n_params_per_eq))")
            continue
        end

        try
            _, residuals, _, _, _ = estimate_var(Y, p)

            # The criteria assume the ML covariance: no degrees-of-freedom
            # correction and no ridge. The df-corrected, 1e-6-padded `Sigma`
            # from `estimate_var` would penalize longer lags twice, and its
            # padding may be comparable to the residual variance of
            # low-noise series.
            Sigma_ml = (residuals' * residuals) / T_eff
            log_det_Sigma = safe_logdet(Sigma_ml)

            if !isfinite(log_det_Sigma)
                println("  Skipping lag $p: numerical issues with covariance matrix")
                continue
            end

            # Lag coefficients across all K equations (intercepts excluded)
            total_params = K^2 * p

            AIC = log_det_Sigma + (2 * total_params) / T_eff
            BIC = log_det_Sigma + (log(T_eff) * total_params) / T_eff
            HQC = log_det_Sigma + (2 * log(log(T_eff)) * total_params) / T_eff

            # FPE's correction factor uses the per-equation regressor count
            # (K*p + 1), not the total number of lag coefficients.
            FPE = if T_eff > n_params_per_eq
                ((T_eff + n_params_per_eq) / (T_eff - n_params_per_eq))^K *
                exp(log_det_Sigma)
            else
                Inf
            end

            push!(AICs, AIC)
            push!(BICs, BIC)
            push!(HQCs, HQC)
            push!(FPEs, FPE)
            push!(valid_lags, p)
        catch e
            # Best effort: report and skip this lag, but let interrupts through.
            e isa InterruptException && rethrow()
            println("  Skipping lag $p due to error: $e")
        end
    end

    return AICs, BICs, HQCs, FPEs, valid_lags
end

"""
    forecast_var(B, Y_last, n_ahead, p)

Produce an iterated multi-step forecast from a fitted VAR(`p`).

# Arguments
- `B`: (K*p+1)×K coefficient matrix. Row 1 holds the intercepts, followed by
  the lag-1 block, the lag-2 block, and so on (the layout `estimate_var`
  returns).
- `Y_last`: recent observations, one row per time step, oldest first. Only
  its last `p` rows are used, so it must have at least `p` rows.
- `n_ahead`: number of steps to forecast.
- `p`: lag order.

Returns an `n_ahead`×K matrix whose row `h` is the `h`-step-ahead forecast.
Each forecast is fed back as an observation for the following steps.
"""
function forecast_var(B, Y_last, n_ahead, p)
    K = size(B, 2)
    n_hist = size(Y_last, 1)
    forecasts = zeros(n_ahead, K)

    # The history buffer is preallocated for all forecasts. Growing it with
    # `vcat` every step would re-copy the whole history each time (O(n_ahead²)).
    history = vcat(Y_last, zeros(n_ahead, K))
    x_h = ones(1, K * p + 1)    # column 1 stays 1: the intercept regressor

    for h in 1:n_ahead
        latest = n_hist + h - 1    # row of the most recent available value
        for lag in 1:p, j in 1:K
            x_h[1, (lag - 1) * K + j + 1] = history[latest - lag + 1, j]
        end

        y_hat = x_h * B            # 1×K forecast for all variables
        forecasts[h, :] = y_hat
        history[n_hist + h, :] = y_hat
    end

    return forecasts
end

# Prepare data for VAR
var_train_data = noisy_data'  # Transpose to T×K format
var_true_data = Array(solution_total)'

# Grid search for optimal lag length
println("\n--- VAR Lag Selection ---")
# Cap the lag order at 6 and at a quarter of the training sample length.
p_max = min(6, div(length(solution_train.t), 4))

println("Maximum lag to test: $p_max")
println("Sample size: $(size(var_train_data, 1))")

# Start the VAR clock before the lag search. The resource table counts each
# fitted lag order as one VAR "training iteration", so the search must be
# inside the timed region for the per-iteration time to be meaningful.
time_var_start = time()
AICs, BICs, HQCs, FPEs, valid_lags = calculate_information_criteria(var_train_data, p_max)

if isempty(valid_lags)
    error("No valid VAR models could be estimated. Try reducing p_max or checking your data.")
end

# Find optimal lags among valid models
p_aic = valid_lags[argmin(AICs)]
p_bic = valid_lags[argmin(BICs)]
p_hqc = valid_lags[argmin(HQCs)]

println("\nInformation Criteria Results:")
println("─"^40)
for (i, p) in enumerate(valid_lags)
    @printf("Lag %d: AIC=%.4f  BIC=%.4f  HQC=%.4f\n", 
            p, AICs[i], BICs[i], HQCs[i])
end
println("\nOptimal lag selection:")
println("  AIC selects p = $p_aic")
println("  BIC selects p = $p_bic")
println("  HQC selects p = $p_hqc")

# Use BIC selection (tends to be more parsimonious)
p_optimal = p_bic
println("\n✓ Selected VAR($p_optimal) based on BIC")

# Estimate VAR with optimal lag
B_var, residuals_var, Sigma_var, _, _ = estimate_var(var_train_data, p_optimal)

println("\nVAR($p_optimal) Estimation Results:")
println("─"^40)
println("Coefficient matrix dimensions: $(size(B_var))")
println("Residual covariance matrix:")
for i in 1:3
    @printf("  Σ[%d,:] = [%.6f, %.6f, %.6f]\n", 
            i, Sigma_var[i,1], Sigma_var[i,2], Sigma_var[i,3])
end

# Check for stationarity (companion form eigenvalues)
"""
    check_var_stability(B, p, K)

Build the companion matrix of a VAR(`p`) with `K` variables from the
coefficient matrix `B` (row 1 holds the intercepts and is skipped). Returns
`(max_modulus, eigenvalues)` of that companion matrix. The model is stable
when `max_modulus < 1`. For `p == 0` this returns `(0.0, [])`.
"""
function check_var_stability(B, p, K)
    if p == 0
        return 0.0, []
    end

    m = K * p
    companion = zeros(m, m)

    # Top K rows: the lag coefficients, transposed (B is regressors × equations).
    n_coef = min(m, size(B, 1) - 1)
    for eq in 1:K, col in 1:n_coef
        companion[eq, col] = B[col + 1, eq]
    end

    # Identity blocks below the diagonal shift each lagged state down by one lag.
    for row in (K + 1):m
        companion[row, row - K] = 1.0
    end

    eigenvals = eigvals(companion)
    return maximum(abs, eigenvals), eigenvals
end

K = 3  # Number of variables
max_eig, all_eigs = check_var_stability(B_var, p_optimal, K)
println("\n✓ VAR Stability Check:")
println("  Maximum eigenvalue modulus: $(round(max_eig, digits=4))")
stability_msg = max_eig < 1 ?
    "  → Model is stable (all eigenvalues inside unit circle)" :
    "  ⚠ Warning: Model may be unstable"
println(stability_msg)

# Forecast horizon = number of reference samples beyond the training window.
n_forecast = length(solution_total.t) - length(solution_train.t)

# Seed the recursion with the final p_optimal training observations.
Y_init = var_train_data[(end - p_optimal + 1):end, :]
var_forecasts = forecast_var(B_var, Y_init, n_forecast, p_optimal)

time_var_end = time()
time_var_total = time_var_end - time_var_start

# Full VAR series per variable: (noisy) training observations followed by
# the forecasts.
var_full_T, var_full_O, var_full_C =
    (vcat(var_train_data[:, j], var_forecasts[:, j]) for j in 1:3)

# VAR(p) parameter count: K equations × (K·p lag coefficients + 1 intercept).
n_params_var_total = K * (K * p_optimal + 1)

# VAR "iterations" = number of lag orders actually fitted in the grid search.
total_iters_var = length(valid_lags)

# ============================================================================
# COMPUTATIONAL RESOURCE USAGE COMPARISON
# ============================================================================
println("\n" * "="^80)
println("COMPUTATIONAL RESOURCE USAGE COMPARISON")
println("="^80)

total_iters_node = iters_adam + iters_adamw
total_iters_ude = iters_adam_ude + iters_adamw_ude

# Every data row uses "│ %-30s │ %13s │ %13s │ %13s │", i.e. cells that are
# 32/15/15/15 characters wide (80 between the outer bars). The rules are
# built from the same widths so the borders line up with the rows.
_resource_rule(l, m, r) = l * join(("─"^w for w in (32, 15, 15, 15)), m) * r

println("\n", _resource_rule("┌", "─", "┐"))
println("│", " "^23, "COMPUTATIONAL RESOURCES COMPARISON", " "^23, "│")
println(_resource_rule("├", "┬", "┤"))
@printf("│ %-30s │ %13s │ %13s │ %13s │\n", "Metric", "Neural ODE", "UDE", "VAR")
println(_resource_rule("├", "┼", "┤"))
@printf("│ %-30s │ %13d │ %13d │ %13d │\n", "Number of Parameters", n_params_node, n_params_ude, n_params_var_total)
@printf("│ %-30s │ %13d │ %13d │ %13d │\n", "Training Iterations", total_iters_node, total_iters_ude, total_iters_var)
# "%11.2f s" and "%10.2f ms" are both 13 characters, matching the %13 cells.
@printf("│ %-30s │ %11.2f s │ %11.2f s │ %11.2f s │\n", "Training Time", time_node_total, time_ude_total, time_var_total)
@printf("│ %-30s │ %10.2f ms │ %10.2f ms │ %10.2f ms │\n", "Time per Iteration", 
        time_node_total/total_iters_node*1000, time_ude_total/total_iters_ude*1000, time_var_total/total_iters_var*1000)
println(_resource_rule("├", "┼", "┤"))
@printf("│ %-30s │ %13s │ %13s │ %13s │\n", "Model Complexity", "High", "Medium", "Low")
@printf("│ %-30s │ %13s │ %13s │ %13s │\n", "Physical Knowledge", "None", "Partial", "None")
@printf("│ %-30s │ %13s │ %13s │ %13s │\n", "Interpretability", "Low", "High", "Medium")
println(_resource_rule("└", "┴", "┘"))

# ============================================================================
# CALCULATE COMPARISON METRICS
# ============================================================================
let rule = "="^80
    println("\n", rule)
    println("FORECASTING COMPARISON METRICS")
    println(rule)
end

# Reference solution on the full horizon and its time grid.
true_data_total = Array(solution_total)
time_points_total = solution_total.t

# The training samples come first; everything after them is the
# extrapolation window.
train_idx = 1:length(solution_train.t)
test_idx = (last(train_idx) + 1):lastindex(time_points_total)

"""
    calculate_metrics(true_vals, pred_vals, var_name, model_name)

Compute error metrics of `pred_vals` against `true_vals` (arrays of the same
shape). `var_name` and `model_name` are accepted for call-site readability
and are not used.

Returns a named tuple `(mae, rmse, mape, max_error, r2)`:
- `mape` is in percent. Each denominator is `|true| + 1e-10`, so negative and
  zero-valued observations are treated consistently.
- `r2 = 1 - SS_res / SS_tot`. It is `NaN` when `true_vals` is constant
  (`SS_tot == 0`), because R² is undefined there.
"""
function calculate_metrics(true_vals, pred_vals, var_name, model_name)
    err = true_vals .- pred_vals       # computed once, reused below
    abs_err = abs.(err)

    mae = mean(abs_err)
    rmse = sqrt(mean(err .^ 2))
    mape = mean(abs_err ./ (abs.(true_vals) .+ 1e-10)) * 100
    max_error = maximum(abs_err)

    ss_res = sum(err .^ 2)
    ss_tot = sum((true_vals .- mean(true_vals)) .^ 2)
    r2 = iszero(ss_tot) ? NaN : 1 - ss_res / ss_tot

    return (mae=mae, rmse=rmse, mape=mape, max_error=max_error, r2=r2)
end

# Metrics on the extrapolation (test) window only.
println("\n", "─"^80)
println("EXTRAPOLATION PERIOD METRICS (t = 15 to 50 years)")
println("─"^80)

# Per-variable metrics keyed by state index (1 = T, 2 = O, 3 = C); the plots
# below read from this dictionary.
metrics_dict = Dict()

for (i, var_name, unit) in [(1, "Surface Temperature", "°C"), 
                             (2, "Ocean Temperature", "°C"), 
                             (3, "CO₂ Concentration", "ppm")]
    println("\n$(var_name) ($(unit)):")
    println("="^60)

    truth = true_data_total[i, test_idx]
    var_series = (var_full_T, var_full_O, var_full_C)[i]

    metrics_node = calculate_metrics(truth, sol_node[i, test_idx], var_name, "Neural ODE")
    metrics_ude = calculate_metrics(truth, sol_ude[i, test_idx], var_name, "UDE")
    metrics_var = calculate_metrics(truth, var_series[test_idx], var_name, "VAR")
    metrics_dict[i] = (node=metrics_node, ude=metrics_ude, var=metrics_var)

    println("\n                   Neural ODE       UDE           VAR")
    println("─"^60)
    @printf("MAE:              %.6f      %.6f      %.6f\n", metrics_node.mae, metrics_ude.mae, metrics_var.mae)
    @printf("RMSE:             %.6f      %.6f      %.6f\n", metrics_node.rmse, metrics_ude.rmse, metrics_var.rmse)
    @printf("MAPE (%%):         %.4f       %.4f       %.4f\n", metrics_node.mape, metrics_ude.mape, metrics_var.mape)
    @printf("Max Error:        %.6f      %.6f      %.6f\n", metrics_node.max_error, metrics_ude.max_error, metrics_var.max_error)
    @printf("R²:               %.6f      %.6f      %.6f\n", metrics_node.r2, metrics_ude.r2, metrics_var.r2)

    # Best model by RMSE. Neural ODE must strictly beat both others;
    # otherwise UDE wins if it strictly beats VAR; otherwise VAR.
    best_model = if metrics_node.rmse < metrics_ude.rmse && metrics_node.rmse < metrics_var.rmse
        "Neural ODE"
    elseif metrics_ude.rmse < metrics_var.rmse
        "UDE"
    else
        "VAR"
    end
    println("\n✓ Best Model (by RMSE): $(best_model)")
end

# Compare the last sample (t = 50) of each model against the reference.
println("\n", "─"^80)
println("FINAL VALUES AT t = 50 years")
println("─"^80)

for (i, var_name, unit) in [(1, "Surface Temperature", "°C"), 
                             (2, "Ocean Temperature", "°C"), 
                             (3, "CO₂ Concentration", "ppm")]
    println("\n$(var_name):")

    true_final = true_data_total[i, end]
    node_final = sol_node[i, end]
    ude_final = sol_ude[i, end]
    var_final = last((var_full_T, var_full_O, var_full_C)[i])

    node_err = abs(true_final - node_final)
    ude_err = abs(true_final - ude_final)
    var_err = abs(true_final - var_final)
    ref_mag = abs(true_final)

    @printf("  True Value:       %.6f %s\n", true_final, unit)
    @printf("  Neural ODE:       %.6f %s (error: %.6f, %.2f%%)\n", 
            node_final, unit, node_err, node_err / ref_mag * 100)
    @printf("  UDE:              %.6f %s (error: %.6f, %.2f%%)\n", 
            ude_final, unit, ude_err, ude_err / ref_mag * 100)
    @printf("  VAR:              %.6f %s (error: %.6f, %.2f%%)\n", 
            var_final, unit, var_err, var_err / ref_mag * 100)
end

# ============================================================================
# CREATE COMPARISON PLOTS
# ============================================================================
let rule = "="^80
    println("\n", rule)
    println("GENERATING COMPARISON PLOTS")
    println(rule)
end

"""
    _climate_panel(ttl, ylab, row, var_series)

Build one trajectory panel: the reference solution (scatter), the Neural ODE,
UDE and VAR predictions for state `row`, and a vertical line at the end of
training (t = 15).
"""
function _climate_panel(ttl, ylab, row, var_series)
    pl = plot(title=ttl,
              xlabel="Time (years)",
              ylabel=ylab,
              legend=:topleft,
              titlefontsize=8,
              legendfontsize=4,
              guidefontsize=7,
              tickfontsize=6)
    scatter!(pl, time_points_total, true_data_total[row, :],
             label="True", markersize=2, color=:black, alpha=0.6)
    plot!(pl, time_points_total, sol_node[row, :],
          label="Neural ODE", lw=1.5, color=:blue)
    plot!(pl, time_points_total, sol_ude[row, :],
          label="UDE", lw=1.5, color=:red, linestyle=:dash)
    plot!(pl, time_points_total, var_series,
          label="VAR", lw=1.5, color=:green, linestyle=:dot)
    vline!(pl, [15.0], label="Train End", color=:black, lw=1.5, linestyle=:dashdot)
    return pl
end

# Plot 1: Combined Climate Variables Comparison (3 subplots)
p1 = _climate_panel("Surface Temperature Anomaly", "Temperature (°C)", 1, var_full_T)
p2 = _climate_panel("Deep Ocean Temperature Anomaly", "Temperature (°C)", 2, var_full_O)
p3 = _climate_panel("CO₂ Concentration", "Concentration (ppm)", 3, var_full_C)

p_combined = plot(p1, p2, p3, layout=(3,1), size=(593, 371))
display(p_combined)
savefig(p_combined, "random_seed_91_1p_noise_VAR_vs_all_methods_comparison_combined_climate_variables.png")

# Plot 2: Absolute Error Comparison with Proper Scaling
mae_T = [metrics_dict[1].node.mae, metrics_dict[1].ude.mae, metrics_dict[1].var.mae]
mae_O = [metrics_dict[2].node.mae, metrics_dict[2].ude.mae, metrics_dict[2].var.mae]
mae_C = [metrics_dict[3].node.mae, metrics_dict[3].ude.mae, metrics_dict[3].var.mae]

# Find max values for scaling
max_mae_T = maximum(mae_T)
max_mae_O = maximum(mae_O)
max_mae_C = maximum(mae_C)

"""
    _mae_panel(vals, tick, ylab, ttl, ymax; legend_kw...)

Draw one group of three side-by-side bars (Neural ODE, UDE, VAR) at x = 1.
The bars share a total group width of 0.7. They are drawn with plain
`Plots.bar!` because `groupedbar` is defined by StatsPlots, which this script
does not load. Extra keywords (legend settings) are passed to the base
`plot` call.
"""
function _mae_panel(vals, tick, ylab, ttl, ymax; legend_kw...)
    pl = plot(; xticks=([1], [tick]),
              ylabel=ylab,
              title=ttl,
              titlefontsize=8,
              guidefontsize=7,
              tickfontsize=6,
              ylims=(0, ymax),
              legend_kw...)
    n = length(vals)
    w = 0.7 / n                                   # width of each bar
    for (j, (lab, col)) in enumerate(zip(("Neural ODE", "UDE", "VAR"),
                                         (:blue, :red, :green)))
        center = 1 + (j - (n + 1) / 2) * w        # dodge bars around x = 1
        bar!(pl, [center], [vals[j]]; bar_width=w, label=lab, color=col)
    end
    return pl
end

# Create separate plots for each variable with appropriate scaling
p_abs1 = _mae_panel(mae_T, "Surface\nTemp", "MAE (°C)", "Surface Temperature",
                    max_mae_T * 1.15; legend=false)
p_abs2 = _mae_panel(mae_O, "Ocean\nTemp", "MAE (°C)", "Ocean Temperature",
                    max_mae_O * 1.15; legend=false)
p_abs3 = _mae_panel(mae_C, "CO₂\nConc", "MAE (ppm)", "CO₂ Concentration",
                    max_mae_C * 1.15; legend=:topright, legendfontsize=6)

p_absolute = plot(p_abs1, p_abs2, p_abs3, 
                 layout=(1,3), 
                 size=(800, 400),
                 plot_title="Absolute Error Comparison (MAE) - Extrapolation Period",
                 plot_titlefontsize=10)

display(p_absolute)
savefig(p_absolute, "random_seed_91_1p_noise_VAR_vs_all_methods_comparison_absolute_errors.png")

# Plot 3: Percentage Error Comparison with Proper Scaling
mape_T = [metrics_dict[1].node.mape, metrics_dict[1].ude.mape, metrics_dict[1].var.mape]
mape_O = [metrics_dict[2].node.mape, metrics_dict[2].ude.mape, metrics_dict[2].var.mape]
mape_C = [metrics_dict[3].node.mape, metrics_dict[3].ude.mape, metrics_dict[3].var.mape]

# Find max values for scaling
max_mape_T = maximum(mape_T)
max_mape_O = maximum(mape_O)
max_mape_C = maximum(mape_C)

"""
    _mape_panel(vals, tick, ylab, ttl, ymax; legend_kw...)

Draw one group of three side-by-side bars (Neural ODE, UDE, VAR) at x = 1.
The bars share a total group width of 0.7. They are drawn with plain
`Plots.bar!` because `groupedbar` is defined by StatsPlots, which this script
does not load. Extra keywords (legend settings) are passed to the base
`plot` call.
"""
function _mape_panel(vals, tick, ylab, ttl, ymax; legend_kw...)
    pl = plot(; xticks=([1], [tick]),
              ylabel=ylab,
              title=ttl,
              titlefontsize=8,
              guidefontsize=7,
              tickfontsize=6,
              ylims=(0, ymax),
              legend_kw...)
    n = length(vals)
    w = 0.7 / n                                   # width of each bar
    for (j, (lab, col)) in enumerate(zip(("Neural ODE", "UDE", "VAR"),
                                         (:blue, :red, :green)))
        center = 1 + (j - (n + 1) / 2) * w        # dodge bars around x = 1
        bar!(pl, [center], [vals[j]]; bar_width=w, label=lab, color=col)
    end
    return pl
end

# Create separate plots for each variable with appropriate scaling
p_pct1 = _mape_panel(mape_T, "Surface\nTemp", "MAPE (%)", "Surface Temperature",
                     max_mape_T * 1.15; legend=false)
p_pct2 = _mape_panel(mape_O, "Ocean\nTemp", "MAPE (%)", "Ocean Temperature",
                     max_mape_O * 1.15; legend=false)
p_pct3 = _mape_panel(mape_C, "CO₂\nConc", "MAPE (%)", "CO₂ Concentration",
                     max_mape_C * 1.15; legend=:topright, legendfontsize=6)

p_percentage = plot(p_pct1, p_pct2, p_pct3, 
                   layout=(1,3), 
                   size=(800, 400),
                   plot_title="Percentage Error Comparison (MAPE) - Extrapolation Period",
                   plot_titlefontsize=10)

display(p_percentage)
savefig(p_percentage, "random_seed_91_1p_noise_VAR_vs_all_methods_comparison_percentage_errors.png")

# Plot 4: Overall RMSE Comparison
# Extrapolation-window RMSEs were already computed by `calculate_metrics`
# (sqrt(mean((truth - prediction)^2)) over test_idx) and stored in
# metrics_dict, so they are reused here instead of being recomputed.
rmse_T_node, rmse_T_ude, rmse_T_var =
    metrics_dict[1].node.rmse, metrics_dict[1].ude.rmse, metrics_dict[1].var.rmse
rmse_O_node, rmse_O_ude, rmse_O_var =
    metrics_dict[2].node.rmse, metrics_dict[2].ude.rmse, metrics_dict[2].var.rmse
rmse_C_node, rmse_C_ude, rmse_C_var =
    metrics_dict[3].node.rmse, metrics_dict[3].ude.rmse, metrics_dict[3].var.rmse

variables = ["Surface\nTemp", "Ocean\nTemp", "CO₂\nConc"]
x_pos = 1:3

# Rows = variables (x groups), columns = models (bar series).
rmse_by_model = [rmse_T_node rmse_T_ude rmse_T_var;
                 rmse_O_node rmse_O_ude rmse_O_var;
                 rmse_C_node rmse_C_ude rmse_C_var]

# `groupedbar` is defined by StatsPlots, which this script does not load, so
# the dodged groups are drawn with plain `Plots.bar!`. Each group has a total
# width of 0.25, split evenly between the three models.
p4 = plot(; xticks=(x_pos, variables),
          xlabel="Variables",
          ylabel="RMSE (Extrapolation Period)",
          title="RMSE Comparison Across All Models",
          legend=:topright,
          size=(800, 500))
let group_width = 0.25, n_models = size(rmse_by_model, 2)
    w = group_width / n_models
    for (j, (lab, col)) in enumerate(zip(("Neural ODE", "UDE", "VAR"),
                                         (:blue, :red, :green)))
        offset = (j - (n_models + 1) / 2) * w
        bar!(p4, x_pos .+ offset, rmse_by_model[:, j]; bar_width=w, label=lab, color=col)
    end
end

display(p4)
savefig(p4, "random_seed_91_1p_noise_VAR_vs_all_methods_comparison_rmse_all.png")

let rule = "="^80
    println("\n", rule)
    println("FORECASTING COMPARISON COMPLETE")
    println(rule)
end

# List the saved figure files and the selected VAR configuration.
println("\nPlots saved:")
println("  1. random_seed_91_1p_noise_VAR_vs_all_methods_comparison_combined_climate_variables.png")
println("  2. random_seed_91_1p_noise_VAR_vs_all_methods_comparison_absolute_errors.png")
println("  3. random_seed_91_1p_noise_VAR_vs_all_methods_comparison_percentage_errors.png")
println("  4. random_seed_91_1p_noise_VAR_vs_all_methods_comparison_rmse_all.png")
println("\nVAR Model Parameters:")
println("  Optimal lag order: VAR($p_optimal)")
println("  Selected by: BIC criterion")