import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec
from scipy.integrate import solve_ivp

# ========== Objective Functions ==========
def quadratic(x):
    """Return the quadratic objective 10*x[0]**2 + x[1]**2.

    Only the first two entries of ``x`` are read.
    """
    steep_term = 10.0 * x[0]**2
    shallow_term = 1.0 * x[1]**2
    return steep_term + shallow_term

def quadratic_grad(x):
    """Return the gradient [20*x[0], 2*x[1]] of the quadratic objective as an array."""
    d_first = 20.0 * x[0]
    d_second = 2.0 * x[1]
    return np.array([d_first, d_second])

def f2(x):
    """Return 4*(x[1] - sin(x[0]))**2 + 0.5*(x[0] - 2)**2."""
    # Distance of x[1] from the sine curve evaluated at x[0].
    residual = x[1] - np.sin(x[0])
    # Offset of x[0] from 2.
    shift = x[0] - 2
    return 4*residual**2 + 0.5*shift**2

def f2_grad(x):
    """Return the gradient of ``f2`` at ``x`` as a length-2 array.

    With r = x[1] - sin(x[0]):
        d/dx[0] = -8*r*cos(x[0]) + (x[0] - 2)
        d/dx[1] =  8*r
    """
    residual = x[1] - np.sin(x[0])
    d_first = -8*residual*np.cos(x[0]) + 1.0*(x[0]-2)
    d_second = 8*residual
    return np.array([d_first, d_second])

# ========== Continuous-Time Dynamics ==========
def heavy_ball_ode(t, y, grad_func, beta=0.9, kappa_m=1):
    """Right-hand side of the heavy-ball dynamics.

    The state ``y`` is split into position ``y[:2]`` and velocity ``y[2:]``:
        dx/dt = v
        dv/dt = kappa_m * (-grad_func(x) - beta * v)

    ``t`` is accepted for the ODE-solver call signature but is not used.
    Returns the concatenation [dx/dt, dv/dt].
    """
    position = y[:2]
    velocity = y[2:]
    acceleration = kappa_m * (-grad_func(position) - beta * velocity)
    return np.concatenate([velocity, acceleration])

def normalized_gd_ode(t, y, grad_func):
    """Right-hand side of normalized gradient flow.

    Position ``y[:2]`` moves along the unit-length negative gradient; when the
    gradient norm is not positive the position derivative is zero. The last
    two state entries are not used and their derivative is always zero.

    ``t`` is accepted for the ODE-solver call signature but is not used.
    """
    grad = grad_func(y[:2])
    magnitude = np.linalg.norm(grad)
    if magnitude > 0:
        direction = -grad / magnitude
    else:
        # Avoid dividing by zero at a stationary point.
        direction = np.zeros_like(grad)
    return np.concatenate([direction, [0.0, 0.0]])

def homm_ode(t, y, grad_func, kappa=15.0, alpha=-0.5, beta=0.5, gamma=0.9):
    """Right-hand side of the HomM dynamics.

    The state ``y`` is split into position ``x = y[:2]`` and velocity
    ``v = y[2:]``. With ``g = grad_func(x)`` and, per coordinate ``i``,
    ``s_i = sqrt(g_i**2 + v_i**2) ** alpha``:

        dx_i/dt = s_i * (-(1 - beta) * g_i + beta * v_i)
        dv_i/dt = kappa * s_i * (-gamma * g_i - (1 - gamma) * v_i)

    Coordinates where both ``g_i`` and ``v_i`` are zero get a zero derivative
    instead of ``0 ** alpha * 0`` (which is NaN for negative ``alpha``), so a
    stationary state stays stationary.

    ``t`` is accepted for the ODE-solver call signature but is not used.
    Returns the concatenation [dx/dt, dv/dt].
    """
    x = np.asarray(y[:2], dtype=float)
    v = np.asarray(y[2:], dtype=float)
    g = np.asarray(grad_func(x), dtype=float)

    norm_z = np.sqrt(g**2 + v**2)
    # Evaluate norm_z**alpha only where norm_z > 0; elsewhere keep scale at 0
    # so that a negative alpha never produces inf (and then inf * 0 = nan).
    scale = np.zeros_like(norm_z)
    np.power(norm_z, alpha, out=scale, where=norm_z > 0)

    dvdt = kappa * scale * (-gamma * g - (1 - gamma) * v)
    dxdt = scale * (-(1 - beta) * g + beta * v)
    return np.concatenate([dxdt, dvdt])

# ========== Solver ==========
def run_solver(ode_func, grad_func, y0, t_eval, rtol=1e-4):
    """Integrate ``ode_func`` with RK45 from ``y0`` over the span of ``t_eval``.

    ``ode_func`` is called as ``ode_func(t, y, grad_func)``. The integration
    interval runs from ``t_eval[0]`` to ``t_eval[-1]`` and the solution is
    sampled at ``t_eval``. Returns the ``solve_ivp`` result object.
    """
    def rhs(t, y):
        # Bind the gradient so the solver sees the standard f(t, y) signature.
        return ode_func(t, y, grad_func)

    t_span = [t_eval[0], t_eval[-1]]
    return solve_ivp(rhs, t_span, y0, method='RK45', t_eval=t_eval, rtol=rtol)

# ========== Compute derivatives ==========
def compute_derivatives(ode_func, sol, grad_func):
    """Evaluate the first two derivative components at every stored sample.

    For each time ``sol.t[k]`` and state column ``sol.y[:, k]`` this calls
    ``ode_func(t, y, grad_func)`` and keeps entries ``[:2]`` of the result.
    Returns an array with one row per sample.
    """
    rates = []
    for k, t_k in enumerate(sol.t):
        state_k = sol.y[:, k]
        rates.append(ode_func(t_k, state_k, grad_func)[:2])
    return np.array(rates)

# ========== Main ==========
# Shared experiment configuration.
FONTSIZE = 20  # font size for titles and axis labels
t_eval = np.linspace(0, 25, num=1000)  # time grid on which solutions are sampled
x_init, v_init = np.array([-3, 3]), np.array([0.0, 0.0])  # start point, zero initial velocity
y0 = np.concatenate((x_init, v_init))  # stacked state [x, v]

# Objective name -> (function, gradient).
objectives = dict(
    f1=(quadratic, quadratic_grad),
    f2=(f2, f2_grad),
)

# --- Figure setup with GridSpec ---
colors = ['#32b897', '#2878b5', '#ffbe7a', '#8983bf', '#ff8884', '#9e9e9e']

# 2 x 3 grid: one row per objective in the first two columns, and a single
# panel spanning both rows in the (slightly wider) last column.
fig = plt.figure(figsize=(19, 6))
gs = GridSpec(2, 3, figure=fig, width_ratios=[1, 1, 1.2])

# Left column (loss)
ax_loss_f1 = fig.add_subplot(gs[0, 0])
ax_loss_f2 = fig.add_subplot(gs[1, 0])

# Middle column (||dθ||)
ax_dxdt_f1 = fig.add_subplot(gs[0, 1])
ax_dxdt_f2 = fig.add_subplot(gs[1, 1])

# Right column (rtol effect spans both rows)
ax_rtol = fig.add_subplot(gs[:, 2])

# Manual margins and spacing (used instead of tight_layout).
fig.subplots_adjust(left=0.035, right=0.99, bottom=0.1, top=0.93,
                    wspace=0.12, hspace=0.15)

# FONTSIZE, t_eval, x_init, v_init and y0 are defined once in the main
# configuration section; they are intentionally not redefined here.

# ========== Plot Loss and ||dθ|| ==========
# Each ODE is integrated once per objective and the trajectory is reused for
# both the loss panel (left column) and the ||dθ|| panel (middle column).
methods = [
    # (legend label, ODE right-hand side, line color)
    ("HB", heavy_ball_ode, colors[1]),
    ("NGD", normalized_gd_ode, colors[0]),
    ("HomM", homm_ode, colors[2]),
]
panel_rows = [
    # (loss axis, ||dθ|| axis, objective, gradient)
    (ax_loss_f1, ax_dxdt_f1, quadratic, quadratic_grad),
    (ax_loss_f2, ax_dxdt_f2, f2, f2_grad),
]

for ax_loss, ax_dxdt, f, grad_f in panel_rows:
    for label, ode_func, color in methods:
        sol = run_solver(ode_func, grad_f, y0, t_eval)

        # Plot against sol.t rather than t_eval: solve_ivp can return fewer
        # samples than requested if the integration terminates early, and
        # indexing sol.y by len(t_eval) would then fail.
        loss = [f(state[:2]) for state in sol.y.T]
        ax_loss.plot(sol.t, loss, label=label, linewidth=3, color=color)

        dx = compute_derivatives(ode_func, sol, grad_f)
        ax_dxdt.plot(sol.t, np.linalg.norm(dx, axis=1), label=label,
                     linewidth=3, color=color)

    ax_loss.set_yscale('log')
    ax_loss.grid(True)
    ax_dxdt.grid(True)

# Shared axis cosmetics for the four small panels.
for ax in (ax_loss_f1, ax_loss_f2, ax_dxdt_f1, ax_dxdt_f2):
    ax.set_xlim([0, 25])
    ax.tick_params(axis='both', which='major', labelsize=14)

# Only the bottom row carries the x-label; only the top row carries the
# column title and legend.
ax_loss_f2.set_xlabel('Time', fontsize=FONTSIZE)
ax_dxdt_f2.set_xlabel("Time", fontsize=FONTSIZE)
ax_loss_f1.set_title(r"(a). Loss Comparison: $\log(f_i)$", fontsize=FONTSIZE, fontweight='bold')
ax_dxdt_f1.set_title(r"(b). Parameter Updates $\|\dot{\theta}\|$", fontsize=FONTSIZE, fontweight='bold')
ax_loss_f1.legend(fontsize=16)
ax_dxdt_f1.legend(fontsize=16)

# ========== rtol effect ==========
# colors = matplotlib.colormaps["tab10"]

# Solve f2 with each optimizer at several relative tolerances and overlay the
# resulting loss curves in the right-hand panel.
rtol_list = [1e-2, 1e-4, 1e-6]
optimizers = {"NGD": normalized_gd_ode, "HomM": homm_ode}

for j, (opt_name, ode_func) in enumerate(optimizers.items()):
    for i, rtol in enumerate(rtol_list):
        sol = run_solver(ode_func, f2_grad, y0, t_eval, rtol=rtol)

        # Use the samples the solver actually returned (sol.t / sol.y): they
        # can be fewer than len(t_eval) if the integration terminates early.
        f_values = [f2(state[:2]) for state in sol.y.T]

        # One palette entry per (optimizer, rtol) pair; the modulo wraps
        # around if there are ever more curves than colors.
        color = colors[(j * len(rtol_list) + i) % len(colors)]
        ax_rtol.plot(
            sol.t, f_values,
            label=f"{opt_name}, rtol={rtol}",
            color=color, linewidth=3
        )

ax_rtol.set_yscale("log")
ax_rtol.grid(True)
ax_rtol.set_xlabel("Time", fontsize=FONTSIZE)
ax_rtol.set_title(r"(c). Effect of rtol:$\log(f_2)$", fontsize=FONTSIZE,fontweight='bold')
ax_rtol.legend(fontsize=16, loc='lower left')
ax_rtol.set_xlim([0, 25])
ax_rtol.tick_params(axis='both', which='major', labelsize=14)

# plt.tight_layout()

# savefig does not create missing directories, so make sure the output
# folder exists before writing the PDF.
output_path = 'figures/full_comparison.pdf'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path)
plt.show()