# %% [markdown]
# # Hessian Playground
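# Checking that the optimization procedure works: Hessian-vector-product approximations are compared against the exact Hessian via cosine similarity.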

# %%

from typing import Literal
from itertools import cycle, islice
import numpy as np
from functools import partial
import operator
from dataclasses import dataclass, asdict
import matplotlib.pyplot as plt
import optax
import jax.random as jrnd
import jax.numpy as jnp
import jax.scipy as jscipy
import jax
import flax.nnx as nnx

from symo.group import I, B, S

from symo.optim import geom_matrix_mean
from symo.factory import FactorGrid
from symo.notebooks.plot_utils import default_rcparams
from symo.experiments.models import Activation, MLP
from symo.experiments.mlp_groups import group_config
from symo.utils import nnx_path_to_string
from symo.data import mlp_teacher_data
from symo.metrics import compute_metrics

import matplotlib.pyplot as plt

# `RcParams.update` routes every entry through matplotlib's validators. The
# in-place `|=` operator resolves to `dict.__ior__`, which merges at the C
# level and silently skips validation/type coercion of the new values.
plt.rcParams.update(default_rcparams(dpi=500))

# An (inputs, targets) pair of arrays.
Data = tuple[jax.Array, jax.Array]


# %%


# Number of hidden layers of the teacher/student MLPs (used by ModelConfig).
depth: int = 3


@dataclass(frozen=True)
class ExperimentConfig:
    """Experiment-wide settings: device, RNG seed and dataset sizes."""

    # Common
    device: str = "cpu"
    seed: int = 2025
    # seed: int = 2026

    # Data
    # Annotated so these are real dataclass fields. Without annotations they
    # were plain class attributes: not reported by `asdict` and not settable
    # through the constructor (e.g. ExperimentConfig(num_train_points=512)).
    num_train_points: int = 5000
    num_test_points: int = 5000

    # num_train_points: int = 512
    # num_test_points: int = 512


@dataclass(frozen=True)
class ModelConfig:
    """Architecture shared by the teacher and the student MLPs.

    The fields are expanded into keyword arguments of ``MLP`` via ``asdict``.
    """

    # Layer sizes: `depth` hidden layers, each 70 units wide.
    input_dim: int = 100
    hidden_dims: tuple[int, ...] = tuple(70 for _ in range(depth))
    output_dim: int = 40

    # Architecture switches.
    skip_every: int | None = None
    use_bias: bool = True
    use_bias_last: bool = True
    # activation: Activation = "relu"
    activation: Activation = "tanh"


# %%

# Experiment and architecture settings.
cfg = ExperimentConfig()
model_cfg = ModelConfig()

# Teacher network, built with an orthogonal kernel initializer (scale 1.5).
kernel_init = nnx.initializers.orthogonal(scale=1.5)
mlp_teacher_config = {**asdict(model_cfg), "rngs": nnx.Rngs(cfg.seed)}
mlp_teacher = MLP(**mlp_teacher_config, kernel_init=kernel_init)

# %%

data_key, other_key = jrnd.split(jrnd.PRNGKey(cfg.seed), 2)

# %%

# Train/validation data generated from the teacher; `other_key` is rebound to
# the first value returned by `mlp_teacher_data`.
other_key, (train_data, val_data) = mlp_teacher_data(
    data_key,
    mlp_teacher,
    num_train_points=cfg.num_train_points,
    num_test_points=cfg.num_test_points,
)

# %%

# Student network: the teacher's config (without its kernel_init) and a
# different seed.
mlp_config = {**mlp_teacher_config, "rngs": nnx.Rngs(cfg.seed + 1)}
mlp = MLP(**mlp_config)

mlp_treedef, _ = nnx.split(mlp, nnx.Param)

# %%

group_spec = group_config(mlp, hid_group=S, same=False)
group_spec_tuple = tuple(group for _, group in group_spec)

# %%

factory = FactorGrid(group_spec_tuple)

# %%


def add_noise(param, key, noise_scale):
    """Return a copy of the pytree `param` with i.i.d. Gaussian noise added.

    Each leaf gets its own subkey split from `key`, in flatten order, and
    receives ``noise_scale * N(0, 1)`` noise of the leaf's shape.

    Args:
        param: Pytree of arrays (e.g. an ``nnx.State`` of parameters).
        key: JAX PRNG key; split once per leaf.
        noise_scale: Multiplier (standard deviation) of the added noise.

    Returns:
        A pytree with the same structure as `param`. Floating/complex leaves
        keep their dtype. Previously the noise was always drawn in JAX's
        default float dtype, which silently promoted e.g. ``bfloat16`` leaves.
    """
    leaves, treedef = jax.tree.flatten(param)
    keys = jax.random.split(key, len(leaves))

    def _noisy(leaf, leaf_key):
        # Sample in the leaf's own dtype when possible so the sum does not
        # promote; other dtypes keep JAX's default sampling dtype.
        if jnp.issubdtype(leaf.dtype, jnp.inexact):
            noise = jax.random.normal(leaf_key, shape=leaf.shape, dtype=leaf.dtype)
        else:
            noise = jax.random.normal(leaf_key, shape=leaf.shape)
        return leaf + noise * noise_scale

    noisy_leaves = [_noisy(leaf, k) for leaf, k in zip(leaves, keys)]
    return jax.tree.unflatten(treedef, noisy_leaves)


# %%


@jax.jit
def loss_fn(model, data=None):
    """Mean squared-error loss ``mean(0.5 * (y - model(x)) ** 2)``.

    Args:
        model: Callable model (here an ``MLP``) mapping inputs to predictions.
        data: Optional ``(x, y)`` pair of arrays. Defaults to the module-level
            ``train_data``, so existing ``loss_fn(model)`` calls are unchanged.
            Pass e.g. ``val_data`` to evaluate another split.

    Returns:
        Scalar loss averaged over every batch and output entry.
    """
    # Resolved at trace time: when `data` is omitted, jit embeds the global
    # train_data as a constant of the first trace.
    x, y = train_data if data is None else data
    y_pred = model(x)
    return jnp.mean(0.5 * (y - y_pred) ** 2)


# @jax.jit
# def loss_fn(model: MLP):
#     x, y = train_data
#     pred = model(x)
#     loss = optax.losses.squared_error(pred, y).mean()
#     return loss


@jax.jit
def grad_fn(model):
    """Return ``(loss, grads)`` of `loss_fn` evaluated at `model`.

    Differentiation goes through ``nnx.value_and_grad`` with its default
    ``wrt`` filter, so `grads` holds gradients of the model's parameters.
    """
    value_and_grad = nnx.value_and_grad(loss_fn)
    return value_and_grad(model)


def cosine_similarity(state1, state2, jitter: float = 0):
    """Cosine similarity between two pytrees treated as flat vectors.

    Both pytrees are raveled into 1-D arrays (so they must contain the same
    total number of elements) and compared as
    ``(<a, b> + jitter) / (|a| |b| + jitter)``.

    Args:
        state1: First pytree of arrays.
        state2: Second pytree of arrays.
        jitter: Added to both numerator and denominator. With the default 0,
            a zero vector yields NaN (0 / 0).

    Returns:
        Scalar array holding the cosine similarity.
    """
    # Import explicitly: `import jax` does not load the `jax.flatten_util`
    # submodule, so `jax.flatten_util.ravel_pytree` only worked if another
    # library had already imported it.
    from jax.flatten_util import ravel_pytree

    flat1, _ = ravel_pytree(state1)
    flat2, _ = ravel_pytree(state2)
    dot_product = jnp.dot(flat1, flat2) + jitter
    norm1 = jnp.linalg.norm(flat1)
    norm2 = jnp.linalg.norm(flat2)
    return dot_product / ((norm1 * norm2) + jitter)


# %%


def hvp_fn(graphdef, params, v):
    """Hessian-vector product of the training loss, ``∇²L(params) @ v``.

    The loss is viewed as a function of the parameter state alone (re-attached
    to `graphdef` with ``nnx.merge``). Its gradient is pushed forward along
    the tangent `v` with ``jax.jvp`` (forward-over-reverse differentiation).

    Args:
        graphdef: Static model structure from ``nnx.split``.
        params: Parameter state at which the Hessian is evaluated.
        v: Tangent pytree with the same structure as `params`.

    Returns:
        The Hessian-vector product, structured like `params`.
    """

    def grad_at(p):
        return jax.grad(lambda q: loss_fn(nnx.merge(graphdef, q)))(p)

    _, hv = jax.jvp(grad_at, (params,), (v,))
    return hv


# %%


def svd(m, hermitian: bool = False):
    """Singular value decomposition ``m = u @ diag(s) @ vh``.

    With ``hermitian=True``, the right factor returned is ``u.T`` rather than
    the ``vh`` computed by ``jnp.linalg.svd``. This reconstructs `m` only when
    `m` is real symmetric positive semi-definite (e.g. a covariance matrix).

    Returns:
        Tuple ``(u, s, vh)``.
    """
    left, singular_values, right_h = jnp.linalg.svd(m, hermitian=hermitian)
    right = left.T if hermitian else right_h
    return left, singular_values, right


def sqrt_svd(u, s, vh, minimum: float = 0.0):
    """Matrix square root from SVD factors: ``u @ diag(sqrt(s)) @ vh``.

    Singular values below `minimum` are discarded (their root is set to 0).
    The result is a square root of the factored matrix when the factors come
    from a symmetric positive semi-definite matrix.
    """
    keep = s >= minimum
    root = jnp.where(keep, jnp.sqrt(s), 0.0)
    scaled_u = u * root[None]
    return scaled_u @ vh


def inv_svd(u, s, vh, minimum: float = 0.0):
    """Moore-Penrose pseudo-inverse from SVD factors of ``m = u @ diag(s) @ vh``.

    Returns ``vh^H @ diag(1 / s) @ u^H``. Singular values ``<= minimum`` are
    treated as zero (their reciprocal is set to 0). Expects square or reduced
    (``full_matrices=False``) factors.

    The previous form ``u @ diag(1 / s) @ vh`` only inverts `m` when ``u`` and
    ``vh^H`` coincide (symmetric PSD input, or the ``hermitian=True`` factors
    from `svd`). For those inputs the result is unchanged up to rounding.
    """
    keep = s > minimum
    # Divide by a harmless value where the entry is discarded, so no inf is
    # ever produced (this also keeps gradients through `jnp.where` finite).
    s_inv = jnp.where(keep, 1 / jnp.where(keep, s, 1.0), 0.0)
    return (vh.conj().T * s_inv[None]) @ u.conj().T


def inv_sqrt_svd(u, s, vh, minimum: float = 0.0):
    """Matrix square root and inverse square root from one set of SVD factors.

    Returns ``(u @ diag(sqrt(s)) @ vh, u @ diag(1 / sqrt(s)) @ vh)``. Singular
    values below `minimum` get a zero square root, and values ``<= minimum`` a
    zero inverse square root. Intended for factors of a symmetric PSD matrix.
    """
    root = jnp.where(s >= minimum, jnp.sqrt(s), 0.0)
    inv_root = jnp.where(s > minimum, 1 / root, 0.0)

    def _rebuild(diag):
        return (u * diag[None]) @ vh

    return _rebuild(root), _rebuild(inv_root)


# %%


def hess_star_vec(graphdef, params, vec):
    """Exact Hessian-vector product of the loss at `params` (thin `hvp_fn` alias)."""
    return hvp_fn(graphdef, params, vec)


# %%


def hess_hat_vec(factory, g, theta, vec):
    """Apply the regression estimate ``Ĥ = Σ_{gθ} Σ_θ⁻¹`` to `vec`.

    ``Σ_{gθ}`` is the factored covariance built from ``(g, theta)`` and
    ``Σ_θ⁻¹`` the SVD pseudo-inverse of the surrogate covariance of `theta`.

    Args:
        factory: ``FactorGrid`` providing the factored covariance operations.
        g: Gradient pytree (samples for the cross-covariance).
        theta: Parameter pytree (samples for both covariances).
        vec: Pytree to multiply.

    Returns:
        ``Ĥ vec``, with the pytree structure of `g`.
    """
    grad_leaves, out_treedef = jax.tree.flatten(g)
    theta_leaves = jax.tree.leaves(theta)
    vec_leaves = jax.tree.leaves(vec)

    cross_factors = factory.cov_factors_from_vectors(grad_leaves, theta_leaves)
    theta_cov_factors = factory.cov_factors_from_vectors(theta_leaves)

    # Invert Σ_θ on its surrogate matrix, then go back to per-factor form.
    theta_cov = factory.cov(theta_cov_factors, surrogate=True)
    theta_cov_inv = inv_svd(*svd(theta_cov))
    theta_inv_factors = factory.factor_from_surrogate(theta_cov_inv)

    # Ĥ v = Σ_{gθ} (Σ_θ⁻¹ v)
    whitened = factory.matvec(theta_inv_factors, vec_leaves)
    result = factory.matvec(cross_factors, whitened)
    return jax.tree.unflatten(out_treedef, result)


# %%


def hess_sigma_sqrt_vec(factory, g, vec):
    """Apply ``Σ_g^{1/2}``, the SVD square root of the covariance of `g`, to `vec`.

    Args:
        factory: ``FactorGrid`` providing the factored covariance operations.
        g: Pytree whose covariance is formed.
        vec: Pytree to multiply.

    Returns:
        ``Σ_g^{1/2} vec``, with the pytree structure of `g`.
    """
    leaves_g, out_treedef = jax.tree.flatten(g)
    factors = factory.cov_factors_from_vectors(leaves_g)
    surrogate = factory.cov(factors, surrogate=True)
    root_factors = factory.factor_from_surrogate(sqrt_svd(*svd(surrogate)))
    product = factory.matvec(root_factors, jax.tree.leaves(vec))
    return jax.tree.unflatten(out_treedef, product)


# %%


def hess_tilde_vec(factory, g, theta, vec):
    """Apply the geometric-mean Hessian estimate to `vec`.

    With ``A`` the surrogate covariance of `theta` and ``B`` the surrogate
    covariance built from ``(g, theta)``, the preconditioner is

        P = A^{-1/2} (A^{1/2} B A^{1/2})^{1/2} A^{-1/2},

    which satisfies ``P A P = B`` for invertible A. All square roots and the
    inverse square root are taken via SVD.

    Returns:
        ``P vec``, with the pytree structure of `vec`.
    """
    g_leaves = jax.tree.leaves(g)
    theta_leaves = jax.tree.leaves(theta)
    vec_leaves, out_treedef = jax.tree.flatten(vec)

    # NOTE(review): B is built from (g, theta), the same call as the
    # cross-covariance in `hess_hat_vec`, while the plot label and the
    # P A P = B derivation use Σ_g -- confirm which covariance is intended.
    b_factors = factory.cov_factors_from_vectors(g_leaves, theta_leaves)
    a_factors = factory.cov_factors_from_vectors(theta_leaves)

    b_mat = factory.cov(b_factors, surrogate=True)
    a_mat = factory.cov(a_factors, surrogate=True)

    # A^{1/2} and A^{-1/2} from a single SVD of A.
    a_sqrt, a_inv_sqrt = inv_sqrt_svd(*svd(a_mat))

    # C = A^{1/2} B A^{1/2};  P = A^{-1/2} C^{1/2} A^{-1/2}
    c_sqrt = sqrt_svd(*svd(a_sqrt @ b_mat @ a_sqrt))
    prec = a_inv_sqrt @ c_sqrt @ a_inv_sqrt

    prec_factors = factory.factor_from_surrogate(prec)
    result = factory.matvec(prec_factors, vec_leaves)
    return jax.tree.unflatten(out_treedef, result)


def hess_tilde2_vec(factory, g, theta, vec):
    """Apply the alternative geometric-mean estimate to `vec`.

    With ``A`` the surrogate covariance of `theta` and ``B`` the surrogate
    covariance built from ``(g, theta)``, the preconditioner is

        P = B^{1/2} (B^{1/2} A B^{1/2})^{-1/2} B^{1/2},

    which also satisfies ``P A P = B``. For invertible A and B it equals the
    matrix of `hess_tilde_vec`; only the matrix being (pseudo-)inverted differs.

    Returns:
        ``P vec``, with the pytree structure of `vec`.
    """
    g_leaves = jax.tree.leaves(g)
    theta_leaves = jax.tree.leaves(theta)
    vec_leaves, out_treedef = jax.tree.flatten(vec)

    b_factors = factory.cov_factors_from_vectors(g_leaves, theta_leaves)
    a_factors = factory.cov_factors_from_vectors(theta_leaves)

    b_mat = factory.cov(b_factors, surrogate=True)
    a_mat = factory.cov(a_factors, surrogate=True)

    b_sqrt = sqrt_svd(*svd(b_mat))

    # C = B^{1/2} A B^{1/2};  P = B^{1/2} C^{-1/2} B^{1/2}
    _, c_inv_sqrt = inv_sqrt_svd(*svd(b_sqrt @ a_mat @ b_sqrt))
    prec = b_sqrt @ c_inv_sqrt @ b_sqrt

    prec_factors = factory.factor_from_surrogate(prec)
    result = factory.matvec(prec_factors, vec_leaves)
    return jax.tree.unflatten(out_treedef, result)


# %%


def star_vectors(graphdef, factory, theta):
    """Loss gradient at `theta` and the factory means of θ and g.

    Args:
        graphdef: Static model structure from ``nnx.split``.
        factory: ``FactorGrid`` providing ``mean_factors_from_vectors``/``mean``.
        theta: Parameter state.

    Returns:
        ``(grad, theta_star, grad_star)``. `grad` is the loss gradient at
        `theta`; the starred pytrees are ``factory.mean`` applied to θ and g
        (factors estimated from each vector itself), laid out like `theta`.
    """
    _, grad = grad_fn(nnx.merge(graphdef, theta))

    theta_leaves, treedef = jax.tree.flatten(theta)
    grad_leaves = jax.tree.leaves(grad)
    sources = (theta_leaves, grad_leaves)

    factors = [factory.mean_factors_from_vectors(leaves) for leaves in sources]
    theta_star, grad_star = (
        jax.tree.unflatten(treedef, factory.mean(fac, leaves))
        for fac, leaves in zip(factors, sources)
    )

    return grad, theta_star, grad_star


def model_vectors(graphdef, factory, theta):
    """Gradient, factory means and their residuals at `theta`.

    Returns:
        ``(grad, (theta_star, grad_star), (theta - theta_star, grad - grad_star))``.
    """
    grad, theta_star, grad_star = star_vectors(graphdef, factory, theta)
    diffs = tuple(
        jax.tree.map(operator.sub, full, star)
        for full, star in ((theta, theta_star), (grad, grad_star))
    )
    return grad, (theta_star, grad_star), diffs


def compute_cossims(
    graphdef,
    factory: FactorGrid,
    theta,
    grad,
    theta_star,
    grad_star,
    theta_vec,
    grad_vec,
):
    """Cosine similarity of `grad_vec` with each Hessian estimate applied to `theta_vec`.

    Residuals ``theta - theta_star`` and ``grad - grad_star`` feed the
    covariance-based estimates. The exact Hessians are evaluated at `theta`
    and at `theta_star`.

    Returns:
        Dict keyed by estimate name ("hess", "hess_star", "hess_hat",
        "hess_tilde", "hess_tilde2", "hess_grad_sqrt", "hess_grad_diff_sqrt")
        with a scalar cosine similarity each.
    """
    theta_diff = jax.tree.map(operator.sub, theta, theta_star)
    grad_diff = jax.tree.map(operator.sub, grad, grad_star)

    products = {
        # H v
        "hess": hess_star_vec(graphdef, theta, theta_vec),
        # H* v
        "hess_star": hess_star_vec(graphdef, theta_star, theta_vec),
        # Ĥ v
        "hess_hat": hess_hat_vec(factory, grad_diff, theta_diff, vec=theta_vec),
        # H̃ v
        "hess_tilde": hess_tilde_vec(factory, grad_diff, theta_diff, vec=theta_vec),
        # H̃' v
        "hess_tilde2": hess_tilde2_vec(factory, grad_diff, theta_diff, vec=theta_vec),
        # Σ_g^{1/2} v
        "hess_grad_sqrt": hess_sigma_sqrt_vec(factory, grad, vec=theta_vec),
        # Σ_{(g - g*)}^{1/2} v
        "hess_grad_diff_sqrt": hess_sigma_sqrt_vec(factory, grad_diff, vec=theta_vec),
    }

    return {name: cosine_similarity(grad_vec, hv) for name, hv in products.items()}


def linear_change(a, b, alpha: float):
    """Leafwise interpolation ``alpha * a + (1 - alpha) * b`` of two pytrees.

    ``alpha = 1`` yields `a` and ``alpha = 0`` yields `b`.
    """
    beta = 1 - alpha
    return jax.tree.map(lambda x, y: alpha * x + beta * y, a, b)


# %%

# Interpolation weights for the parameter path. The endpoints sit just inside
# (0, 1), presumably to avoid exactly-zero difference vectors (and hence NaN
# cosine similarities) at θ and θ*.
alphas = jnp.linspace(1e-7, 1.0 - 1e-7, 500)

# %%

# Fresh student model for this experiment.
# NOTE(review): seeded with cfg.seed like the teacher, whereas the earlier
# student used cfg.seed + 1 -- confirm this is intended.
mlp_config = mlp_teacher_config | dict(rngs=nnx.Rngs(cfg.seed))
mlp = MLP(**mlp_config)

# Reference point for the differences in the interpolation loop:
# "star" -> (θ*, g*), "dagger" -> (θ*, ∇L(θ*)), None -> (θ, g).
use_star: Literal["star", "dagger"] | None = "dagger"

# %%


adam_opt = optax.adam(learning_rate=0.02)
adam_optimizer = nnx.Optimizer(mlp, adam_opt, wrt=nnx.Param)

# num_opt_steps = 0
num_opt_steps = 20

for i in range(num_opt_steps):
    # `grad_fn` differentiates w.r.t. the nnx.Param variables only, matching
    # the optimizer's `wrt=nnx.Param`. `jax.value_and_grad` on the module would
    # differentiate every leaf of the module instead.
    loss, grads = grad_fn(mlp)
    adam_optimizer.update(mlp, grads)
    print(f"iteration {i}:", loss)

# %%

# Snapshot the parameters only AFTER training. Depending on the Flax version,
# `nnx.state` may hold copies of the values, so taking it before the loop
# could analyse the untrained initialisation instead of the trained model.
theta = nnx.state(mlp, nnx.Param)
theta_flat, treedef = jax.tree.flatten(theta)
mlp_graphdef, _ = nnx.split(mlp, nnx.Param)

grad, theta_star, grad_star = star_vectors(mlp_graphdef, factory, theta)


# %%

# `compute_cossims` with everything except the two probe vectors bound.
lcs_fn = partial(
    compute_cossims,
    factory=factory,
    graphdef=mlp_graphdef,
    theta=theta,
    theta_star=theta_star,
    grad=grad,
    grad_star=grad_star,
)


@jax.jit
def lcs(grad_vec, theta_vec):
    """Cosine similarities of `grad_vec` with every Hessian estimate applied to `theta_vec`.

    The arguments bound in `lcs_fn` become trace-time constants of this
    jitted function.
    """
    return lcs_fn(grad_vec=grad_vec, theta_vec=theta_vec)


# %%

# Legend labels, keyed like the dict returned by `compute_cossims`.
# "hess_tilde2" carries a prime: it previously rendered with the same symbol
# as "hess_tilde", making the two legend entries indistinguishable.
labels = {
    "hess": r"$\mathrm{H} = \nabla^2 L(\theta)$",
    "hess_star": r"$\mathrm{H}^{*} = \nabla^2 L(\theta^{*})$",
    "hess_hat": r"$\hat{\mathrm{H}} = \Sigma_{\hat{g}\hat{\theta}} \Sigma_{\hat{\theta}}^{-1}$, $\hat{g} = g - g^{*}$, $\hat{\theta} = \theta - \theta^{*}$",
    "hess_tilde": r"$\tilde{\mathrm{H}} = \Sigma_{\hat{\theta}}^{-\frac{1}{2}}\left(\Sigma_{\hat{\theta}}^{\frac{1}{2}} \Sigma_{\hat{g}} \Sigma_{\hat{\theta}}^{\frac{1}{2}}\right)^{\frac{1}{2}} \Sigma_{\hat{\theta}}^{-\frac{1}{2}}$",
    "hess_tilde2": r"$\tilde{\mathrm{H}}' = \Sigma_{\hat{g}}^{\frac{1}{2}}\left(\Sigma_{\hat{g}}^{-\frac{1}{2}} \Sigma_{\hat{\theta}}^{-1} \Sigma_{\hat{g}}^{-\frac{1}{2}}\right)^{\frac{1}{2}} \Sigma_{\hat{g}}^{\frac{1}{2}}$",
    "hess_grad_sqrt": r"$\Sigma_g^{\frac{1}{2}}$",
    "hess_grad_diff_sqrt": r"$\Sigma_{(g - g^*)}^{\frac{1}{2}}$",
}

# %%

# Fail fast on a typo: previously any unexpected value silently fell through
# to the `None` branch.
if use_star not in ("star", "dagger", None):
    raise ValueError(f"use_star must be 'star', 'dagger' or None, got {use_star!r}")

# Reference point (θ_ref, g_ref) that θ' and g' are differenced against.
# It does not depend on alpha, so it is computed once rather than per step.
if use_star == "star":
    theta_ref, grad_ref = theta_star, grad_star
elif use_star == "dagger":
    # g† = ∇L(θ*): the true gradient at θ* instead of the factory mean g*.
    theta_ref = theta_star
    _, grad_ref = grad_fn(nnx.merge(mlp_graphdef, theta_star))
else:
    theta_ref = theta
    _, grad_ref = grad_fn(nnx.merge(mlp_graphdef, theta))

out = {k: [] for k in labels}

for alpha in alphas:
    # θ' = alpha * θ* + (1 - alpha) * θ, and its loss gradient g'.
    theta_prime = linear_change(theta_star, theta, alpha)
    _, grad_prime = grad_fn(nnx.merge(mlp_graphdef, theta_prime))

    theta_vec = jax.tree.map(operator.sub, theta_prime, theta_ref)
    grad_vec = jax.tree.map(operator.sub, grad_prime, grad_ref)

    for k, v in lcs(grad_vec, theta_vec).items():
        out[k].append(v)

# %%

# %%

# One NumPy array of cosine similarities (one entry per alpha) per estimate.
out = {k: np.array(v) for k, v in out.items()}

# %%

# Cycle line styles and drawing orders so overlapping curves stay visible.
ls = ["-", "--", "-."]
zs = [1, 3, 5]
linestyles = list(islice(cycle(ls), len(labels)))
zorders = list(islice(cycle(zs), len(labels)))

# %%

# args = dict(s=5, alpha=0.5)
args = dict(alpha=0.6, linewidth=1)

fig, ax = plt.subplots(nrows=1, ncols=1)
for linestyle, zorder, (k, label) in zip(linestyles, zorders, labels.items()):
    # `zorder` was previously unpacked but never passed, so the cycled drawing
    # order had no effect.
    ax.plot(alphas, out[k], label=label, linestyle=linestyle, zorder=zorder, **args)

ax.legend()

if use_star == "star":
    cossim_label = r"Cosine similarity $g' - g^{*}$ vs $H (\theta' - \theta^{*})$"
elif use_star == "dagger":
    cossim_label = r"Cosine similarity $g' - g^{\dagger}$ vs $H (\theta' - \theta^{*})$, s.t. $g^{\dagger} = \nabla L(\theta^{*})$"
else:
    cossim_label = r"Cosine similarity $g' - g$ vs $H (\theta' - \theta)$"

ax.set_title(
    cossim_label + f", '{model_cfg.activation}' activation, train_steps={num_opt_steps}"
)

ax.set_xlabel(
    r"$\alpha$ for linearly interpolating parameters from $\theta$ to $\theta^*$"
)
fig.tight_layout()

# %%
