# %%

import json

import flax.nnx as nnx
import jax
import jax.numpy as jnp
import jax.random as jrnd
import matplotlib.pyplot as plt
import numpy as np
import optax

from symo.experiments.models import MLP
from symo.experiments.utils import sync_order_values
from symo.factory import FactorGrid
from symo.group import B, I, S
from symo.notebooks.plot_utils import default_rcparams, diag_scale, plot_matrix
from symo.notebooks.utils import align_eigval_groups, analyze_repeated_eigvals
from symo.utils import flatten_with_string_path, nnx_path_to_string

# Apply the project's default figure settings. `rcParams.update` assigns each
# key through `RcParams.__setitem__`, so values are validated; the in-place dict
# union (`|=`) resolves to `dict.__ior__`, which writes into the underlying dict
# directly and skips that validation.
plt.rcParams.update(default_rcparams(dpi=500))

# %% [markdown]
# # Factory Playground
# Checking covariances built from the parameters and gradients of a small MLP

# %%

# Global run settings: every random quantity below is derived from `seed`.
seed = 2025
device = "cpu"
rnd_key = jrnd.PRNGKey(seed)
# Independent key streams for model init, data sampling and everything else.
model_key, data_key, other_key = jrnd.split(rnd_key, num=3)

# %%

# Problem size.
num_data = 100
input_dim, hidden_dim, output_dim = 5, 11, 4

# Architecture.
depth = 2  # 2 or 3
hidden_dims = tuple([hidden_dim] * depth)
skip_every = None
activation = nnx.relu  # alternative: nnx.tanh

# Set to False to skip the plotting cells guarded by `if plot`.
plot = True

# %%

# Random regression data: standard-normal inputs and targets.
x_data_key, y_data_key = jrnd.split(data_key)
x = jrnd.normal(x_data_key, shape=(num_data, input_dim))
y = jrnd.normal(y_data_key, shape=(num_data, output_dim))
data = x, y

# %%

# Keyword arguments for the `MLP` constructor; its Rngs use the global seed.
mlp_config = {
    "input_dim": input_dim,
    "output_dim": output_dim,
    "hidden_dims": hidden_dims,
    "use_bias": True,
    "skip_every": skip_every,
    "activation": activation,
    "rngs": nnx.Rngs(seed),
}

# %%

# Instantiate the model and run one forward pass on the inputs.
mlp = MLP(**mlp_config)
y_pred = mlp(x)

# %%

# Trainable parameters, flattened to (path, leaf) pairs; the paths are then
# converted to strings with `nnx_path_to_string`.
mlp_params = nnx.state(mlp, nnx.Param)
mlp_state_flatten = jax.tree.flatten_with_path(mlp_params)
path_leaf_pairs = mlp_state_flatten[0]
mlp_state_str = [(nnx_path_to_string(path), leaf) for path, leaf in path_leaf_pairs]

# %%

# Print every parameter's string path together with its shape.
shapes_by_path = [(name, str(leaf.shape)) for name, leaf in mlp_state_str]
print(json.dumps(shapes_by_path, indent=2))

# %%

# Symmetry group applied to the hidden units, chosen by activation: `S` for
# ReLU, `B` for tanh. (Presumably S = permutations and B = signed permutations,
# matching the symmetries of each activation — TODO confirm against symo.group.)
if activation == nnx.relu:
    group = S
elif activation == nnx.tanh:
    group = B
else:
    # Fail here instead of with a NameError on `group` further down.
    raise ValueError(f"No symmetry group configured for activation {activation!r}")

# One group per feature axis: identity groups on the input and output features,
# `group` on each hidden layer. Without skip connections every hidden layer gets
# its own label ("L1", "L2", ...). With skip connections all hidden layers reuse
# the "L1" label (presumably because the skips tie hidden units across layers —
# TODO confirm).
In = I["input", input_dim]
Ou = I["output", output_dim]
if skip_every is None:
    hidden_groups = [group[f"L{i}", hidden_dim] for i in range(1, depth + 1)]
else:
    hidden_groups = [group["L1", hidden_dim] for _ in range(depth)]

# Layer k maps feature axis k to axis k+1, so its kernel carries
# (groups[k], groups[k + 1]). Each hidden layer also lists a bias on its output
# axis; the final (output) layer contributes only its kernel to the spec.
# Works for any `depth`, not just 2 or 3.
layer_groups = [In, *hidden_groups, Ou]
groups_spec = []
for layer, (g_in, g_out) in enumerate(zip(layer_groups[:-1], layer_groups[1:])):
    groups_spec.append((f"layers/#{layer}/kernel/.value", (g_in, g_out)))
    if layer < depth:
        groups_spec.append((f"layers/#{layer}/bias/.value", g_out))
groups_spec = tuple(groups_spec)

# %%

# Align the spec with `mlp_params` (presumably reordering it to the parameter
# leaf order — TODO confirm in sync_order_values), then keep only the groups.
groups_spec = sync_order_values(dict(groups_spec), mlp_params)
groups = tuple(spec for _, spec in groups_spec)

# %%

factor_grid = FactorGrid(groups)

# %% [markdown]
# Initialize factors randomly

# %%

# One fresh PRNG key per parameter leaf (the first split output replaces
# `other_key`), arranged in the same tree structure as the parameters.
values, treedef = jax.tree.flatten(mlp_params)
other_key, *keys = jrnd.split(other_key, len(values) + 1)
tree_keys = jax.tree.unflatten(treedef, keys)

# Standard-normal draw with the shape of each parameter leaf.
new_params = jax.tree.map(
    lambda leaf, key: jrnd.normal(key, leaf.shape), mlp_params, tree_keys
)
new_params_flat = jax.tree.leaves(new_params)

# %%

# Mean factors estimated from the flattened parameter leaves.
params_flat = jax.tree.leaves(mlp_params)
mean_factors = factor_grid.mean_factors_from_vectors(params_flat)

# %%

# Means assembled from those factors.
means = factor_grid.mean(mean_factors)

# %%

# Round trip on random factors: build the covariance and its surrogate, recover
# factors from the surrogate, and rebuild the covariance from those factors.
rnd_factors = factor_grid.factor_from_normal(other_key)
rnd_cov, rnd_surr = (
    factor_grid.cov(rnd_factors),
    factor_grid.cov(rnd_factors, surrogate=True),
)
rnd_surr_factors = factor_grid.factor_from_surrogate(rnd_surr)
# NOTE(review): applies factor_from_surrogate to the full covariance; unused below.
rnd_cov_factors = factor_grid.factor_from_surrogate(rnd_cov)
rnd_recovered_cov = factor_grid.cov(rnd_surr_factors)

# %% [markdown]
# Frobenius norm of the difference between the covariance matrix built from random factors and the covariance rebuilt from factors recovered from that covariance's surrogate matrix.

# %%

# Frobenius norm of the round-trip error (jnp.linalg.norm defaults to the
# Frobenius norm for 2-D input); ~0 means the recovery is exact.
jnp.linalg.norm(rnd_cov - rnd_recovered_cov)

# %%

# Array leaves of the random factors and of the recovered factors, each in
# their tree's leaf order.
weights, surr_weights = (
    [leaf for leaf in jax.tree.leaves(tree) if isinstance(leaf, jax.Array)]
    for tree in (rnd_factors, rnd_surr_factors)
)

# %%

# Per-leaf Euclidean distance between the original and recovered factors.
norms = [
    jnp.linalg.norm(jnp.ravel(orig) - jnp.ravel(rec))
    for orig, rec in zip(weights, surr_weights)
]

norms

# %%

if plot:
    fig, (ax1, ax2, ax3) = plt.subplots(ncols=3)

    # Left to right: diag-scaled covariance from random factors, diag-scaled
    # recovered covariance, and the raw difference between the two.
    panels = (
        (ax1, diag_scale(rnd_cov), "Covariance from random factors"),
        (ax2, diag_scale(rnd_recovered_cov), "Recovered covariance"),
        (ax3, rnd_cov - rnd_recovered_cov, "Diff covariances"),
    )
    for panel_ax, matrix, title in panels:
        plot_matrix(fig, panel_ax, matrix, title)

# %% [markdown]
# Initializing factors from the loss gradients and from the parameters

# %%


def loss_fn(model: MLP):
    """Return the mean squared error of `model` on the global data `(x, y)`."""
    return jnp.mean(optax.losses.squared_error(model(x), y))


loss, grad = nnx.value_and_grad(loss_fn)(mlp)

# %%

# Flatten gradients and parameters into leaf lists, then build covariance
# factors from each.
grad_leaves, param_leaves = jax.tree.leaves(grad), jax.tree.leaves(mlp_params)
factor_grad, factor_param = (
    factor_grid.cov_factors_from_vectors(grad_leaves),
    factor_grid.cov_factors_from_vectors(param_leaves),
)

# %% [markdown]
# ## Checking matrix-matrix products

# %% [markdown]
# Self product $A^2$ ($A$: gradient covariance): factors recovered from the squared surrogate should reproduce $A^2$.

# %%

# A = gradient covariance, in full and surrogate form.
cov_grad = factor_grid.cov(factor_grad)
cov_surr_grad = factor_grid.cov(factor_grad, surrogate=True)

# Square both forms, rebuild a full matrix from the squared surrogate, and
# require it to match the directly squared covariance.
cov_prod, cov_surr_prod = cov_grad @ cov_grad, cov_surr_grad @ cov_surr_grad
factor_prod = factor_grid.factor_from_surrogate(cov_surr_prod)
cov_prod_2 = factor_grid.cov(factor_prod)

np.testing.assert_array_almost_equal(cov_prod, cov_prod_2)

# %% [markdown]
# Symmetrized product $AB + BA$ ($A$: gradient covariance, $B$: parameter covariance)

# %%

# A = gradient covariance, B = parameter covariance: full forms, then surrogates.
cov_grad, cov_param = factor_grid.cov(factor_grad), factor_grid.cov(factor_param)
cov_surr_grad, cov_surr_param = (
    factor_grid.cov(factor_grad, surrogate=True),
    factor_grid.cov(factor_param, surrogate=True),
)

# Symmetrized product AB + BA in both representations.
cov_prod = cov_grad @ cov_param + cov_param @ cov_grad
cov_surr_prod = cov_surr_grad @ cov_surr_param + cov_surr_param @ cov_surr_grad

# Factors recovered from the surrogate product must rebuild the full product.
factor_prod = factor_grid.factor_from_surrogate(cov_surr_prod)
cov_prod_2 = factor_grid.cov(factor_prod)
np.testing.assert_array_almost_equal(cov_prod, cov_prod_2)

# %% [markdown]
# Sandwich product $ABA$

# %%

# A = gradient covariance, B = parameter covariance: full forms, then surrogates.
cov_grad, cov_param = factor_grid.cov(factor_grad), factor_grid.cov(factor_param)
cov_surr_grad, cov_surr_param = (
    factor_grid.cov(factor_grad, surrogate=True),
    factor_grid.cov(factor_param, surrogate=True),
)

# Sandwich product ABA (evaluated left to right) in both representations.
cov_prod = cov_grad @ cov_param @ cov_grad
cov_surr_prod = cov_surr_grad @ cov_surr_param @ cov_surr_grad

# Factors recovered from the surrogate product must rebuild the full product.
factor_prod = factor_grid.factor_from_surrogate(cov_surr_prod)
cov_prod_2 = factor_grid.cov(factor_prod)
np.testing.assert_array_almost_equal(cov_prod, cov_prod_2)

# %% [markdown]
# # Checking eigendecomposition

# %%

# Eigenvalues (ascending, from eigvalsh) of the full and surrogate gradient
# covariances.
cov_grad = factor_grid.cov(factor_grad)
cov_surr_grad = factor_grid.cov(factor_grad, surrogate=True)
evals_full, evals_surr = (jnp.linalg.eigvalsh(m) for m in (cov_grad, cov_surr_grad))

# Detect (numerically) repeated eigenvalues in each spectrum, then align the
# two analyses. NOTE: this rebinds `groups`, which earlier held the factor-grid
# groups.
analysis_full, analysis_surr = (
    analyze_repeated_eigvals(ev, tol=1e-6) for ev in (evals_full, evals_surr)
)
groups = align_eigval_groups(analysis_full, analysis_surr, tol=1e-4)

# %%

if plot and len(groups) > 0:
    # One panel per aligned eigenvalue group, laid out on a near-square grid.
    # Each panel scatters the group's full and surrogate eigenvalues by index.
    ng = len(groups)
    ncols = int(np.ceil(np.sqrt(ng)))
    nrows = int(np.ceil(ng / ncols))

    ax_args = dict(s=5, alpha=0.5)

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so
    # `.flatten()` always works.
    fig, axes = plt.subplots(ncols=ncols, nrows=nrows, squeeze=False)
    axes = axes.flatten()
    # `eig_group` (not `group`) so the symmetry-group variable is not clobbered.
    for ax, eig_group in zip(axes, groups):
        full = eig_group.g1
        surr = eig_group.g2
        ii, jj = eig_group.idx

        full_y = full.eigvals
        ax.scatter(np.arange(full_y.shape[0]), full_y, label="Full", **ax_args)

        surr_y = surr.eigvals
        ax.scatter(np.arange(surr_y.shape[0]), surr_y, label="Surr", **ax_args)

        ax.set_ylabel(rf"$({ii}, {jj})$")
        ax.xaxis.get_major_locator().set_params(integer=True)

    # The grid can have more cells than groups; hide the empty ones.
    for ax in axes[ng:]:
        ax.set_visible(False)

    # Legend on the last populated panel: axes[-1] may be an empty filler cell.
    axes[ng - 1].legend()
    fig.tight_layout()
    fig.show()

# %%

if plot and len(groups) > 0:
    # One panel per aligned eigenvalue group: full eigenvalues as ticks at x=0
    # and surrogate eigenvalues at x=1, over light box plots of each set.
    ng = len(groups)
    ncols = int(np.ceil(np.sqrt(ng)))
    nrows = int(np.ceil(ng / ncols))

    ax_args = dict(s=10, alpha=0.9, marker="_")
    colors = ["lightblue", "lightcoral"]

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(ncols=ncols, nrows=nrows, squeeze=False)
    axes = axes.flatten()
    # `eig_group` (not `group`) so the symmetry-group variable is not clobbered.
    for ax, eig_group in zip(axes, groups):
        full = eig_group.g1
        surr = eig_group.g2
        ii, jj = eig_group.idx
        full_y = full.eigvals
        surr_y = surr.eigvals

        ax.scatter(np.zeros(full_y.shape), full_y, label="Full", **ax_args)
        ax.scatter(np.ones(surr_y.shape), surr_y, label="Surr", **ax_args)

        lw = dict(linewidth=0.3, color="gray")
        bp = ax.boxplot(
            [full_y, surr_y],
            positions=(0, 1),
            patch_artist=True,
            widths=0.3,
            tick_labels=["Full", "Surr"],
            boxprops=lw | dict(alpha=0.5, zorder=1),
            whiskerprops=lw,
            capprops=lw,
            flierprops=dict(marker="", markersize=0, alpha=0),  # hide outliers
            medianprops=lw | dict(color="darkred"),
        )

        for patch, color in zip(bp["boxes"], colors):
            patch.set_facecolor(color)
            patch.set_alpha(0.3)
            patch.set_zorder(1)

        ax.set_ylabel(rf"$({ii}, {jj})$")
        ax.set_xlim(-0.5, 1.5)

    # The grid can have more cells than groups; hide the empty ones.
    for ax in axes[ng:]:
        ax.set_visible(False)

    # Legend on the last populated panel: axes[-1] may be an empty filler cell.
    axes[ng - 1].legend()
    fig.tight_layout()
    fig.show()


# %%

if plot:
    # Full vs. surrogate spectra (from the eigenvalue cell above), each plotted
    # against its eigenvalue index. Dedicated index names keep the training
    # inputs `x`, which `loss_fn` reads, intact.
    idx_full = np.arange(evals_full.shape[0])
    idx_surr = np.arange(evals_surr.shape[0])

    ax_kwargs = dict(s=5, alpha=0.2)
    fig, ax = plt.subplots()
    ax.scatter(idx_full, evals_full, label="True", **ax_kwargs)
    ax.scatter(idx_surr, evals_surr, label="Surrogate", **ax_kwargs)
    ax.legend()
    fig.show()

# %%

# Surrogate built from the gradient factors.
surrogate = factor_grid.cov(factor_grad, surrogate=True)

# %%

if plot:
    fig, (ax1, ax2, ax3) = plt.subplots(ncols=3)

    # Raw analytical covariance, its diag-scaled version, and the diag-scaled
    # surrogate.
    panels = (
        (ax1, cov_grad, "Analytical covariance"),
        (ax2, diag_scale(cov_grad), "Diag-scaled analytical covariance"),
        (ax3, diag_scale(surrogate), "Surrogate of analytical covariance"),
    )
    for panel_ax, matrix, title in panels:
        plot_matrix(fig, panel_ax, matrix, title)

# %%

# Recover factors from the gradient surrogate.
factors_recovered = factor_grid.factor_from_surrogate(surrogate)

# %%

# Covariance rebuilt from the recovered factors.
cov_recovered = factor_grid.cov(factors_recovered)

# %%

# Error against the analytical gradient covariance `cov_grad`.
# Alternative: compare diag_scale(cov_grad) - diag_scale(cov_recovered).
cov_diff = cov_grad - cov_recovered

# %%

# Frobenius norm (jnp.linalg.norm's default for 2-D input) of the error.
diff_frobenius_norm = jnp.linalg.norm(cov_diff)
diff_frobenius_norm

# %%


if plot:
    # Analytical gradient covariance, its recovery from the surrogate, and the
    # difference between the two.
    fig, (ax1, ax2, ax3) = plt.subplots(ncols=3)
    plot_matrix(fig, ax1, cov_grad, "Analytical covariance")
    plot_matrix(
        fig, ax2, cov_recovered, "Recovered from surrogate analytical covariance"
    )
    plot_matrix(fig, ax3, cov_grad - cov_recovered, "Diff of covariances")

# %%

# Gradient leaves ordered like `groups_spec` (the order the factor grid was
# built from), plus their concatenation as a column vector.
grad_dict = dict(flatten_with_string_path(grad)[0])
grad_list = [grad_dict[name] for name, _ in groups_spec]
grad_vec = jnp.concat([g.flatten() for g in grad_list])[:, None]

# %%

# Dense reference product with the analytical covariance vs. the factored
# matrix-vector product using the same (gradient) factors.
true_mvp = cov_grad @ grad_vec
matvec = factor_grid.matvec(factor_grad, grad_list)
mvp_vec = jnp.concat([m.flatten() for m in matvec])[:, None]

# %%

# Squared Euclidean norm of the discrepancy; ~0 when the two products agree.
diff_mvp = true_mvp - mvp_vec
vec_inprod = jnp.sum(diff_mvp**2)
vec_inprod

# %%
