using OrdinaryDiffEq
using Plots
using Random
using Lux, Optimization, OptimizationOptimJL, ComponentArrays
using SciMLSensitivity
using LinearAlgebra
using OptimizationOptimisers
using Statistics
using Printf

# Seed the default RNG so that the synthetic noise and the network initialization are
# reproducible, and keep a handle to that generator for explicit use below.
Random.seed!(91)
rng = Random.default_rng()

"""
    climate_system!(du, u, p, t)

In-place right-hand side of the reference model used to generate training data.

State `u = (T, O, C)`: surface temperature anomaly, deep-ocean temperature anomaly and
CO₂ concentration. Parameters `p = (α, C0, λ, κ, CT, CO, γ, β)`. Writes into `du`:

- `dT/dt = (1/CT) * (α*log(C/C0) - λ*T - κ*(T - O))`
- `dO/dt = (1/CO) * κ * (T - O)`
- `dC/dt = (8.0 + 0.1*t) - β*T*C - γ*C`
"""
function climate_system!(du, u, p, t)
    surf, deep, co2 = u
    α, C0, λ, κ, CT, CO, γ, β = p

    # Logarithmic CO₂ forcing term relative to the reference concentration C0.
    forcing = α * log(co2 / C0)
    # Prescribed source term for C (presumably emissions), growing linearly in time.
    emissions = 8.0 + 0.1 * t

    du[1] = (1 / CT) * (forcing - λ * surf - κ * (surf - deep))
    du[2] = (1 / CO) * κ * (surf - deep)
    du[3] = emissions - β * surf * co2 - γ * co2
end

# Parameters of the reference model, in the order `climate_system!` unpacks them.
α, C0, λ, κ = 5.35, 280.0, 1.0, 0.69
CT, CO = 8.0, 80.0
γ, β = 0.01, 0.001
p = [α, C0, λ, κ, CT, CO, γ, β]

# Initial state (T, O, C): zero temperature anomalies, CO₂ at 280.0.
T0, O0, C0_init = 0.0, 0.0, 280.0
u0 = [T0, O0, C0_init]

# Training window and the full horizon used to judge extrapolation.
tspan_train = (0.0, 15.0)
tspan_total = (0.0, 50.0)

# Reference trajectories, solved to tight tolerances and sampled every 1.0 time unit.
prob_train = ODEProblem(climate_system!, u0, tspan_train, p)
prob_total = ODEProblem(climate_system!, u0, tspan_total, p)
solution_train = solve(prob_train, Tsit5(); abstol = 1e-12, reltol = 1e-12, saveat = 1.0)
solution_total = solve(prob_total, Tsit5(); abstol = 1e-12, reltol = 1e-12, saveat = 1.0)

# Noise-free training trajectory: row 1 = T, row 2 = O, row 3 = C; one column per
# saved time of `solution_train`.
tsdata_train = Array(solution_train)

# Per-variable statistics of the clean training trajectory, computed once. The
# standard deviations set the synthetic-noise amplitude below and, after clamping,
# the scaling used by normalize_variables / denormalize_derivatives.
# NOTE(review): normalization uses the clean data, which a real experiment would not
# have; using `noisy_data` statistics instead would avoid this leakage.
T_mean, T_std = mean(tsdata_train[1, :]), std(tsdata_train[1, :])
O_mean, O_std = mean(tsdata_train[2, :]), std(tsdata_train[2, :])
C_mean, C_std = mean(tsdata_train[3, :]), std(tsdata_train[3, :])
T_train_std, O_train_std, C_train_std = T_std, O_std, C_std

# Add zero-mean Gaussian noise whose standard deviation is `noise_percentage` times
# each variable's std. The amplitude is kept in Float64, the data's precision, with no
# Float32 narrowing. Noise is drawn from `rng` in the order T, O, C.
noise_percentage = 0.01
noisy_data = copy(tsdata_train)
for (row, σ) in enumerate((T_train_std, O_train_std, C_train_std))
    noisy_data[row, :] .+= (noise_percentage * σ) .* randn(rng, size(noisy_data, 2))
end

# Clamp the scales so normalization never divides by (nearly) zero for a flat series.
T_std = max(T_std, 1e-8)
O_std = max(O_std, 1e-8)
C_std = max(C_std, 1e-8)

"""
    normalize_variables(T, O, C)

Standardize one state `(T, O, C)` using the global training statistics
(`T_mean`/`T_std`, `O_mean`/`O_std`, `C_mean`/`C_std`).
Returns the tuple `(T_norm, O_norm, C_norm)`.
"""
normalize_variables(T, O, C) = ((T - T_mean) / T_std,
                                (O - O_mean) / O_std,
                                (C - C_mean) / C_std)

"""
    denormalize_derivatives(dT_norm, dO_norm, dC_norm)

Map time derivatives of the standardized variables back to physical units. For
`z = (x - μ) / σ` we have `dx/dt = σ * dz/dt`, so only the global standard
deviations (`T_std`, `O_std`, `C_std`) are applied; the means do not enter.
Returns the tuple `(dT, dO, dC)`.
"""
denormalize_derivatives(dT_norm, dO_norm, dC_norm) =
    (dT_norm * T_std, dO_norm * O_std, dC_norm * C_std)

# Network standing in for the unknown right-hand side. It maps the normalized state
# (T, O, C) to normalized derivatives through four softplus hidden layers of width 64
# and a linear output layer.
ann_node = Lux.Chain(
    Lux.Dense(3 => 64, softplus),
    Lux.Dense(64 => 64, softplus),
    Lux.Dense(64 => 64, softplus),
    Lux.Dense(64 => 64, softplus),
    Lux.Dense(64 => 3),
)

# Initialize weights (p1) and layer states (st1) from the script's seeded default RNG
# `rng`. Flatten the weights into a ComponentArray so optimizers and adjoints can treat
# them as one vector.
p1, st1 = Lux.setup(rng, ann_node)
p_ann = ComponentArray(p1)

"""
    dudt_node(du, u, p, t)

In-place Neural ODE right-hand side. The state `u = (T, O, C)` is standardized with
`normalize_variables` and evaluated by `ann_node` with weights `p` and layer state
`st1`. The network output is rescaled with `denormalize_derivatives` and written into
`du`. `t` is unused, so the learned dynamics are autonomous.

NOTE(review): `ann_node`, `st1` and the normalization statistics are non-`const`
globals, which blocks type inference in this hot function. Consider making them
`const` or capturing them in a closure.
"""
function dudt_node(du, u, p, t)
    zT, zO, zC = normalize_variables(u[1], u[2], u[3])
    out, _ = ann_node([zT, zO, zC], p, st1)
    dT, dO, dC = denormalize_derivatives(out[1], out[2], out[3])
    du[1] = dT
    du[2] = dO
    du[3] = dC
end

prob_node = ODEProblem{true}(dudt_node, u0, tspan_train, p_ann)

"""
    predict(θ)

Solve the Neural ODE over the training window with network weights `θ`. Return the
trajectory as a matrix with one row per state variable, sampled every 1.0 time unit.
Gradients flow through the solve via `InterpolatingAdjoint` with a compiled
ReverseDiff vector-Jacobian product.
"""
function predict(θ)
    adjoint_alg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))
    sol = solve(prob_node, Tsit5(); p = θ, saveat = 1.0,
                abstol = 1e-6, reltol = 1e-6, sensealg = adjoint_alg)
    return Array(sol)
end

"""
    loss(θ)

Sum of squared errors between the Neural ODE trajectory for weights `θ` (from
`predict`) and the noisy training observations `noisy_data`.

Returns `Inf` when the prediction does not cover every observation time, e.g. when the
solver stopped early because the learned dynamics blew up. Without this check the
broadcast would throw a `DimensionMismatch` and abort the whole optimization on a
single bad parameter update.
"""
function loss(θ)
    pred = predict(θ)
    size(pred) == size(noisy_data) || return Inf
    return sum(abs2, noisy_data .- pred)
end

# Loss of the untrained network, kept as the baseline for the training summary.
initial_loss = loss(p_ann)

# Training-progress state mutated by `callback` (declared `global` there).
iter = 0
losses = Float64[]  # concrete element type instead of an untyped `[]` (Vector{Any})
min_loss = Inf

"""
    callback(θ, l)

Optimization callback. It increments the global counter `iter`, appends the current
loss `l` to `losses`, tracks the running minimum in `min_loss`, and prints progress
every 50 calls. `θ` is unused. Returning `false` tells the optimizer to continue.
"""
function callback(θ, l)
    global iter, losses, min_loss
    iter += 1
    push!(losses, l)
    min_loss = min(min_loss, l)
    iszero(iter % 50) && println("Neural ODE Iteration: $iter, Loss: $l")
    return false
end

# Zygote provides reverse-mode gradients of `loss`; the ODE solve inside it is
# differentiated through the adjoint configured in `predict`.
adtype = Optimization.AutoZygote()
# The second argument is Optimization.jl's hyperparameter slot, which `loss` ignores.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, p_ann)

iters_adam = 20000
iters_adamw = 20000

# Phase 1: Adam (learning rate 1e-5) starting from the initial network weights.
println("\nStarting Neural ODE Adam training ($iters_adam iterations)...")
res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(1e-5);
                          callback = callback, maxiters = iters_adam)
println("Neural ODE Adam training finished.")

# Phase 2: AdamW (learning rate 1e-5, betas (0.9, 0.999), third argument 1e-4),
# warm-started from the Adam result.
optprob2 = remake(optprob; u0 = res1.u)
println("\nStarting Neural ODE AdamW training ($iters_adamw iterations)...")
res2 = Optimization.solve(optprob2,
                          OptimizationOptimisers.AdamW(1e-5, (0.9, 0.999), 1e-4);
                          callback = callback, maxiters = iters_adamw)
println("Neural ODE AdamW training finished.")

# Loss of the final (post-AdamW) weights and its relative improvement over the
# untrained baseline.
final_loss = loss(res2.u)
loss_reduction_percentage = 100 * ((initial_loss - final_loss) / initial_loss)

println("\n--- Neural ODE Training Summary ---")
for (label, value) in (("Total Neural ODE training iterations: ", iter),
                       ("Initial Neural ODE Training Loss: ", initial_loss),
                       ("Final Neural ODE Training Loss: ", final_loss),
                       ("Minimum Neural ODE Training Loss obtained: ", min_loss))
    println(label, value)
end
println("Loss reduction: ", loss_reduction_percentage, "%")
println("------------------------------------")

total_iterations = length(losses)

# Log-log training-loss curve. The first `iters_adam` recorded losses are drawn as the
# Adam phase and any remaining ones as the AdamW phase.
# NOTE(review): this split assumes `callback` ran exactly `iters_adam` times during the
# Adam solve. Recording `length(losses)` right after that solve would make it exact.
p_loss_curve = plot(; xlabel = "Iterations (log scale)", ylabel = "Loss (log scale)",
                    legend = :topright, yaxis = :log, xaxis = :log, size = (598, 369))

if !isempty(losses)
    first_phase_end = min(iters_adam, total_iterations)
    plot!(p_loss_curve, 1:first_phase_end, losses[1:first_phase_end];
          color = :blue, lw = 2, label = "Adam Optimizer")

    if first_phase_end < total_iterations
        plot!(p_loss_curve, (first_phase_end + 1):total_iterations,
              losses[(first_phase_end + 1):end];
              color = :red, lw = 2, label = "AdamW Optimizer")
    end
end

display(p_loss_curve)
savefig(p_loss_curve, "random_seed_91_1p_noise_NODE_climate_node_training_loss_curve.png")

# Reference trajectory over the full horizon (training + extrapolation windows).
true_data_total = Array(solution_total)
time_points_total = solution_total.t

# Integrate the trained Neural ODE (weights `res2.u`) over the full horizon at the same
# save times as the reference. Nothing is differentiated here, so no `sensealg` is
# passed. An unstable learned vector field can make the solver stop early and return
# fewer columns. That case is reported with an explicit error rather than `@assert`,
# which is meant for internal invariants and may be disabled.
prob_node_extrapolate = ODEProblem{true}(dudt_node, u0, tspan_total, res2.u)
sol_node = let sol = solve(prob_node_extrapolate, Tsit5(); saveat = time_points_total,
                           abstol = 1e-12, reltol = 1e-12)
    node_traj = Array(sol)
    if size(node_traj) != size(true_data_total)
        throw(DimensionMismatch("Neural ODE extrapolation produced $(size(node_traj, 2)) of " *
                                "$(size(true_data_total, 2)) time points " *
                                "(retcode = $(sol.retcode))"))
    end
    node_traj
end

# Euclidean norm of the (T, O, C) error vector at each saved time. Variables are not
# rescaled, so those with larger magnitude (e.g. CO₂ in ppm) can dominate this metric.
l2_errors_over_time = [norm(truth_col - pred_col)
                       for (truth_col, pred_col) in zip(eachcol(true_data_total),
                                                        eachcol(sol_node))]

println("\n--- Final Values Comparison (at t=50 years) ---")

# For each state variable, compare the reference and Neural ODE values at the last
# saved time. Report the absolute difference and that difference as a percentage of
# the reference magnitude.

# Surface temperature (row 1).
T_actual_final, T_predicted_final = true_data_total[1, end], sol_node[1, end]
T_abs_error = abs(T_actual_final - T_predicted_final)
T_percent_error = 100 * (T_abs_error / abs(T_actual_final))

@printf("Surface Temperature:\n")
@printf("  Actual: %.6f °C\n", T_actual_final)
@printf("  Predicted: %.6f °C\n", T_predicted_final)
@printf("  Absolute Error: %.6f °C\n", T_abs_error)
@printf("  Percentage Error: %.2f%%\n", T_percent_error)

# Deep-ocean temperature (row 2).
O_actual_final, O_predicted_final = true_data_total[2, end], sol_node[2, end]
O_abs_error = abs(O_actual_final - O_predicted_final)
O_percent_error = 100 * (O_abs_error / abs(O_actual_final))

@printf("\nOcean Temperature:\n")
@printf("  Actual: %.6f °C\n", O_actual_final)
@printf("  Predicted: %.6f °C\n", O_predicted_final)
@printf("  Absolute Error: %.6f °C\n", O_abs_error)
@printf("  Percentage Error: %.2f%%\n", O_percent_error)

# CO₂ concentration (row 3).
C_actual_final, C_predicted_final = true_data_total[3, end], sol_node[3, end]
C_abs_error = abs(C_actual_final - C_predicted_final)
C_percent_error = 100 * (C_abs_error / abs(C_actual_final))

@printf("\nCO2 Concentration:\n")
@printf("  Actual: %.6f ppm\n", C_actual_final)
@printf("  Predicted: %.6f ppm\n", C_predicted_final)
@printf("  Absolute Error: %.6f ppm\n", C_abs_error)
@printf("  Percentage Error: %.2f%%\n", C_percent_error)
println("------------------------------------------------")

variable_names = ["Surface Temperature Anomaly (T)", "Deep Ocean Temperature Anomaly (O)", "CO₂ Concentration (C)"]
variable_units = ["°C", "°C", "ppm"]

# One panel per state variable: reference data (markers), Neural ODE prediction (line)
# and a dashed line at the end of the training window. Only panel 1 shows a legend.
# Panels are built inside a `map` do-block and bound to the local `panel`, so nothing
# here rebinds the global ODE parameter vector `p`. A top-level loop assigning to `p`
# is an ambiguous soft-scope assignment: it warns in scripts and clobbers `p` in the
# REPL.
plots = map(1:3) do i
    first_panel = i == 1
    panel = plot(title=variable_names[i],
                 xlabel="Years",
                 ylabel=variable_units[i],
                 legend = first_panel ? :topleft : false,
                 tickfontsize=7, legendfontsize=6, guidefontsize=8, titlefontsize=7,
                 xticks=0:10:50)

    scatter!(panel, solution_total.t, true_data_total[i, :],
             label= first_panel ? "True Data" : "",
             markersize=2, markeralpha=0.7)

    plot!(panel, solution_total.t, sol_node[i, :],
          lw=2,
          label= first_panel ? "Neural ODE Prediction" : "")

    vline!(panel, [tspan_train[2]], color=:black, lw=1.5, linestyle=:dash,
           label= first_panel ? "Training End" : "")

    panel
end

p_combined = plot(plots...,
                  layout=(1,3),
                  size=(593,371),
                  margin=3Plots.mm)

display(p_combined)
savefig(p_combined, "random_seed_91_1p_noise_NODE_climate_node_pred_params.png")

# Per-time-step L2 error between the reference and Neural ODE trajectories, with the
# end of the training window marked. The y-range is padded by 10%, with a fixed range
# as fallback when every error is zero.
max_l2_err = maximum(l2_errors_over_time)
p_error = plot(; title = "", xlabel = "Years", ylabel = "L2 Error", legend = :topleft,
               ylims = (0, max_l2_err > 0 ? 1.1 * max_l2_err : 0.1))

plot!(p_error, time_points_total, l2_errors_over_time;
      label = "L2 Error ||True - Predicted||", lw = 2)

vline!(p_error, [tspan_train[2]]; color = :black, linestyle = :dash, lw = 1.5,
       label = "Training End Time")

display(p_error)
savefig(p_error, "random_seed_91_1p_noise_NODE_climate_neural_ode_l2_error_over_time.png")

println("\n--- Climate System Neural ODE Analysis Complete ---")
foreach(println, ("Training period: 0-15 years", "Prediction period: 15-50 years"))
@printf("Final temperature anomaly at 50 years: %.6f °C\n", sol_node[1,end])
@printf("Final CO2 concentration at 50 years: %.6f ppm\n", sol_node[3,end])