import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp

# ========== Objective Functions ==========
def quadratic(x):
    """Return the quadratic objective 10*x[0]^2 + x[1]^2.

    Only the first two entries of ``x`` are read.
    """
    steep_coeff, shallow_coeff = 10.0, 1.0
    steep_term = steep_coeff * x[0]**2
    shallow_term = shallow_coeff * x[1]**2
    return steep_term + shallow_term

def quadratic_grad(x):
    """Return the gradient [20*x[0], 2*x[1]] of ``quadratic`` as an array."""
    d_first = 20.0 * x[0]
    d_second = 2.0 * x[1]
    return np.array([d_first, d_second])

def f2(x):
    """Return 4*(x[1] - sin(x[0]))^2 + 0.5*(x[0] - 2)^2."""
    curve_residual = x[1] - np.sin(x[0])
    offset = x[0] - 2
    return 4 * curve_residual**2 + 0.5 * offset**2

def f2_grad(x):
    """Return the gradient of ``f2`` with respect to x[0] and x[1] as an array."""
    # Both partial derivatives share the residual x[1] - sin(x[0]).
    curve_residual = x[1] - np.sin(x[0])
    return np.array([
        -8 * curve_residual * np.cos(x[0]) + 1.0 * (x[0] - 2),
        8 * curve_residual,
    ])

# ========== Continuous-Time Dynamics ==========
def heavy_ball_ode(t, y, grad_func, beta=0.99, kappa_m = 1):
    """Right-hand side of the heavy-ball dynamics.

    The state ``y`` is split into position ``y[:2]`` and velocity ``y[2:]``.
    The position derivative is the velocity; the velocity derivative is
    ``kappa_m * (-grad_func(position) - beta * velocity)``.

    Args:
        t: Time argument; not used by this right-hand side.
        y: State vector (position followed by velocity).
        grad_func: Callable returning the gradient at a position.
        beta: Coefficient multiplying the velocity in the damping term.
        kappa_m: Overall scale applied to the velocity derivative.

    Returns:
        Concatenation of the position and velocity derivatives.
    """
    position, velocity = y[:2], y[2:]
    acceleration = kappa_m * (-grad_func(position) - beta * velocity)
    return np.concatenate([velocity, acceleration])

def normalized_gd_ode(t, y, grad_func):
    """Right-hand side of normalized gradient descent.

    The position ``y[:2]`` moves along the negative gradient scaled to unit
    Euclidean norm. The last two output entries are always zero because this
    flow has no velocity dynamics. When the gradient norm is exactly zero the
    whole derivative is zero.

    Args:
        t: Time argument; not used by this right-hand side.
        y: State vector whose first two entries are the position.
        grad_func: Callable returning the gradient at a position.

    Returns:
        Array holding the position derivative followed by two zeros, or an
        all-zero array shaped like ``y`` when the gradient vanishes.
    """
    gradient = grad_func(y[:2])
    magnitude = np.linalg.norm(gradient)
    if magnitude != 0:
        unit_descent = -1.0*gradient / magnitude
        return np.concatenate([unit_descent, [0.0, 0.0]])  # no velocity dynamics
    # Stationary point: avoid dividing by a zero norm.
    return np.zeros_like(y)

def homm_ode(t, y, grad_func, kappa=15.0, alpha=-0.5, beta=0.5, gamma=0.9):
    """Right-hand side of the HomM dynamics with per-coordinate scaling.

    The state ``y`` is split into position ``x = y[:2]`` and velocity
    ``v = y[2:]``. With ``g = grad_func(x)`` and, for each coordinate ``i``,
    ``z_i = sqrt(g_i^2 + v_i^2)``, the derivatives are::

        dv_i/dt = kappa * z_i**alpha * (-gamma * g_i - (1 - gamma) * v_i)
        dx_i/dt =         z_i**alpha * (-(1 - beta) * g_i + beta * v_i)

    Args:
        t: Time argument; not used by this right-hand side.
        y: State vector (position followed by velocity).
        grad_func: Callable returning the gradient at a position.
        kappa: Scale applied to the velocity derivative.
        alpha: Exponent applied to the per-coordinate norm ``z_i``.
        beta: Mixing weight between gradient and velocity in ``dx/dt``.
        gamma: Mixing weight between gradient and velocity in ``dv/dt``.

    Returns:
        Concatenation of the position and velocity derivatives (float array).
    """
    x = y[:2]
    # Work in floating point so integer-valued states are not truncated.
    v = np.asarray(y[2:], dtype=float)
    g = np.asarray(grad_func(x), dtype=float)

    norm_z = np.sqrt(g**2 + v**2)

    # A coordinate with z_i == 0 has g_i == v_i == 0, so its derivative is 0.
    # Evaluating 0**alpha with a negative alpha would give inf and then
    # inf * 0 = NaN, so the scale is only computed where the norm is positive.
    scale = np.zeros_like(norm_z)
    active = norm_z > 0
    scale[active] = norm_z[active] ** alpha

    dvdt = kappa * scale * (-gamma * g - (1 - gamma) * v)
    dxdt = scale * (-(1 - beta) * g + beta * v)

    return np.concatenate([dxdt, dvdt])

# ========== Run Solver ==========
def run_solver(ode_func, grad_func, y0, t_span=(0, 25), t_eval=None):
    """Integrate ``ode_func`` from ``y0`` over ``t_span`` with RK45.

    Args:
        ode_func: Callable ``ode_func(t, y, grad_func)`` returning dy/dt.
        grad_func: Gradient callable forwarded to ``ode_func`` on every call.
        y0: Initial state.
        t_span: Integration interval ``(t0, tf)``.
        t_eval: Optional times at which the solution is stored.

    Returns:
        The result object returned by ``scipy.integrate.solve_ivp``.
    """
    def rhs(t, y):
        # Bind the gradient so solve_ivp sees the plain (t, y) signature.
        return ode_func(t, y, grad_func)

    return solve_ivp(rhs, t_span, y0, method='RK45', t_eval=t_eval, rtol=1e-4)

# Compute derivatives along trajectory
def compute_derivatives(ode_func, sol, grad_func):
    """Re-evaluate the ODE right-hand side at every stored solution sample.

    Args:
        ode_func: Callable ``ode_func(t, y, grad_func)`` returning dy/dt.
        sol: Object exposing sample times ``sol.t`` and states ``sol.y``,
            where column ``k`` of ``sol.y`` is the state at ``sol.t[k]``.
        grad_func: Gradient callable forwarded to ``ode_func``.

    Returns:
        Two arrays with one row per sample: the first two components of
        dy/dt, and the remaining components.
    """
    rhs_samples = [
        ode_func(t_k, sol.y[:, k], grad_func) for k, t_k in enumerate(sol.t)
    ]
    dtheta = np.array([rhs[:2] for rhs in rhs_samples])
    dv = np.array([rhs[2:] for rhs in rhs_samples])
    return dtheta, dv

# ========== Main ==========
if __name__ == "__main__":
    import os  # local import: only the script entry point touches the filesystem

    x_init = np.array([-3, 3])
    v_init = np.array([0.0, 0.0])
    y0 = np.concatenate([x_init, v_init])
    t_eval = np.linspace(0, 25, 1000)

    objectives = {
        "f_1": (quadratic, quadratic_grad),
        "f_2": (f2, f2_grad),
    }
    # Label -> ODE right-hand side; insertion order fixes the plotting order.
    optimizers = {
        "Heavy Ball": heavy_ball_ode,
        "Normalized GD": normalized_gd_ode,
        "HomM": homm_ode,
    }
    n_rows = len(objectives)

    # savefig does not create missing directories.
    os.makedirs("figures", exist_ok=True)

    # ========== Solve ODEs ==========
    # Each (objective, optimizer) pair is solved once and stored under its
    # objective name, so every subplot below draws the trajectories of its
    # own objective and both figures reuse the same solutions.
    solutions = {}
    for name, (_, grad_f) in objectives.items():
        solutions[name] = {
            label: run_solver(ode_func, grad_f, y0, t_eval=t_eval)
            for label, ode_func in optimizers.items()
        }

    # ========== Plot loss, one subplot per objective ==========
    plt.figure(figsize=(10, 8))  # taller figure for stacked subplots

    for idx, (name, (f, _)) in enumerate(objectives.items(), 1):
        plt.subplot(n_rows, 1, idx)
        for label, sol in solutions[name].items():
            # Use the solver's own sample times so a run that stopped early
            # still plots instead of raising an IndexError.
            losses = [f(state[:2]) for state in sol.y.T]
            plt.plot(sol.t, losses, label=label, linewidth=3.5)
        plt.yscale('log')
        if idx == n_rows:  # x-label only on the bottom subplot
            plt.xlabel('Time', fontsize=18, fontweight='bold')
        plt.ylabel(f'{name} Value (log scale)', fontsize=18, fontweight='bold')

        plt.grid(True)
        plt.legend(fontsize=18)
        plt.xlim(0, 25)
        plt.tick_params(axis='both', which='major', labelsize=14)

    plt.tight_layout()
    plt.savefig('figures/ode_comparison_all_loss.pdf')
    plt.show()

    # ========== Plot parameter speed, one subplot per objective ==========
    plt.figure(figsize=(10, 8))  # taller figure for stacked subplots

    for idx, (name, (_, grad_f)) in enumerate(objectives.items(), 1):
        plt.subplot(n_rows, 1, idx)
        for label, sol in solutions[name].items():
            # Re-evaluate each optimizer's own right-hand side along its path.
            dtheta, _ = compute_derivatives(optimizers[label], sol, grad_f)
            norm_dtheta = np.linalg.norm(dtheta, axis=1)
            plt.plot(sol.t, norm_dtheta, label=label, linewidth=3.5)
        if idx == n_rows:  # x-label only on the bottom subplot
            plt.xlabel("Time", fontsize=18, fontweight='bold')
        plt.ylabel(r"$\|\dot{\theta}\|$", fontsize=18, fontweight='bold')
        plt.tick_params(axis='both', which='major', labelsize=14)
        plt.grid(True)
        plt.xlim(0, 25)
        plt.legend(fontsize=18)

    plt.tight_layout()
    plt.savefig("figures/dxdt_comparison_all.pdf")
    plt.show()
