# %%

from dataclasses import dataclass, asdict
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as jrnd
import flax.nnx as nnx
import optax

from symo.factory import FactorGrid

from symo.notebooks.plot_utils import default_rcparams
from symo.experiments.mlp_groups import group_config
from symo.experiments.models import Activation, MLP
from symo.experiments.utils import sync_order_values
from symo.data import mlp_teacher_data
from symo.notebooks.utils import analyze_repeated_eigvals, align_eigval_groups, align_eigval

import matplotlib.pyplot as plt

# Apply the project plot defaults via RcParams.update, which routes every key
# through RcParams.__setitem__ and matplotlib's validators. `plt.rcParams |= …`
# would instead hit the inherited dict.__ior__ (RcParams subclasses dict), which
# writes into the underlying dict storage and skips that validation.
plt.rcParams.update(default_rcparams(dpi=500))

# %% [markdown]
# # Model setup

# %%


@dataclass(frozen=True)
class ExperimentConfig:
    """Run-level settings: target device name, RNG seed and dataset sizes.

    The dataset sizes carry type annotations so that ``@dataclass`` turns them
    into real fields. Without an annotation they would be plain class
    attributes: they could not be overridden through the constructor, and
    ``dataclasses.asdict`` would silently omit them.
    """

    # Common
    device: str = "cpu"
    seed: int = 2025

    # Data
    num_train_points: int = 5000
    num_test_points: int = 5000


depth: int = 3
skip_every: int | None = 1
if depth == 1:
    skip_every = None


@dataclass(frozen=True)
class MLPConfig:
    """Architecture shared by teacher and student, unpacked into ``MLP(**asdict(...))``."""

    # Model
    input_dim: int = 13
    # One width-35 entry per hidden layer (``depth`` entries in total).
    hidden_dims: tuple[int, ...] = tuple(35 for _ in range(depth))
    output_dim: int = 11
    # Mirrors the module-level ``skip_every`` (None when depth == 1).
    skip_every: int | None = skip_every
    use_bias: bool = False
    activation: Activation = "relu"


# %%

num_epochs = 500


@dataclass(frozen=True)
class SymoConfig:
    """Hyperparameters for the Symo run; ``num_epochs`` defaults to the module value."""

    num_epochs: int = num_epochs
    # Momentum is disabled here. Commented-out alternative:
    # grad_momentum = param_momentum = 0.1
    grad_momentum: float = 0.0
    param_momentum: float = 0.0
    decay: float = 0.98
    damping: float = 1.0e-15
    lr: float = 0.5


@dataclass(frozen=True)
class AdamConfig:
    """Hyperparameters for an Adam run; ``num_epochs`` defaults to the module value."""

    num_epochs: int = num_epochs
    lr: float = 0.01


# %%

# Teacher and student runs share every ExperimentConfig setting except the seed,
# and both networks are built from the same MLPConfig.
exp_teacher_cfg = ExperimentConfig(seed=1)
exp_cfg = ExperimentConfig()
mlp_cfg = MLPConfig()
# Training hyperparameters; not referenced again in the cells shown here.
symo_cfg = SymoConfig()
adam_cfg = AdamConfig()

# %%

# Teacher network, initialised from the teacher seed.
mlp_teacher = MLP(**asdict(mlp_cfg), rngs=nnx.Rngs(exp_teacher_cfg.seed))

# %%

# Train/test splits produced by mlp_teacher_data (presumably inputs labelled by
# the teacher). The data key reuses the teacher seed; the split sizes are read
# from the student config (equal to the teacher's by default).
key_teacher = jrnd.PRNGKey(exp_teacher_cfg.seed)

rnd_key, (train_data, test_data) = mlp_teacher_data(
    key_teacher, mlp_teacher, exp_cfg.num_train_points, exp_cfg.num_test_points
)

# %%

# Student network: same architecture, student seed.
mlp = MLP(**asdict(mlp_cfg), rngs=nnx.Rngs(exp_cfg.seed))

# %%

# Student parameter state; the group spec is synced against it (presumably to
# match its parameter ordering — verify in sync_order_values).
mlp_params = nnx.state(mlp, nnx.Param)
groups_spec = sync_order_values(dict(group_config(mlp)), mlp_params)
# Keep the second element of every pair yielded by groups_spec.
groups = tuple(v for _, v in groups_spec)

# %%

factor_grid = FactorGrid(groups)

# %%


def loss_fn(model: MLP):
    """Return the mean squared error of ``model`` on the global ``train_data``.

    ``train_data`` is unpacked as ``(inputs, targets)``. The element-wise squared
    error is averaged over every entry of the prediction.
    """
    inputs, targets = train_data
    return optax.losses.squared_error(model(inputs), targets).mean()


# Training loss and its gradient w.r.t. the student's parameters.
loss, grad = nnx.value_and_grad(loss_fn)(mlp)

# %%

# Flatten both pytrees to leaf lists. They are assumed to share one tree
# structure so that gradient and parameter leaves line up — TODO confirm.
grad_leaves, param_leaves = (jax.tree.leaves(tree) for tree in (grad, mlp_params))
factor_grad, factor_param = (
    factor_grid.cov_factors_from_vectors(leaves)
    for leaves in (grad_leaves, param_leaves)
)

# %% [markdown]
# # Checking eigendecomposition

# %%

# Covariance from the gradient factors: once as-is, once with surrogate=True.
cov_grad, cov_surr_grad = (
    factor_grid.cov(factor_grad),
    factor_grid.cov(factor_grad, surrogate=True),
)

# %%

# eigvalsh treats its input as symmetric/Hermitian and returns eigenvalues in
# ascending order; np.array copies them into writable NumPy arrays.
# NOTE(review): JAX computes in float32 unless jax_enable_x64 is enabled
# elsewhere — keep that in mind when comparing with tol=1e-5 / damping=1e-15.
evals_full, evals_surr = (
    np.array(jnp.linalg.eigvalsh(cov)) for cov in (cov_grad, cov_surr_grad)
)

# %%

# Settings for grouping (near-)repeated eigenvalues; rel_tol presumably makes
# `tol` relative to the eigenvalue magnitude.
tol, sort, rel_tol = 1e-5, False, True

analysis_full = analyze_repeated_eigvals(evals_full, rel_tol=rel_tol, sort=sort, tol=tol)
# Alternative (note: without rel_tol) — group the surrogate spectrum on its own:
# analysis_surr = analyze_repeated_eigvals(evals_surr, tol=tol, sort=sort)
# Used instead: align the surrogate eigenvalues to the full-spectrum analysis.
analysis_surr = align_eigval(analysis_full, evals_surr)

# %%

# The alignment must yield one surrogate entry per entry of analysis_full.
assert len(analysis_surr) == len(analysis_full)


# %%

# NOTE: this rebinds `groups`, which above held the parameter groups passed to
# FactorGrid; from here on it holds the paired eigenvalue groups being plotted.
groups = align_eigval_groups(analysis_full, analysis_surr)

# %%

# One panel per aligned eigenvalue group on a near-square grid. Each panel
# scatters the "Full" (g1) and "Surr" (g2) eigenvalues against their index
# within the group and is labelled with the group's `idx` pair.
ng = len(groups)
ncols = int(np.ceil(np.sqrt(ng)))
nrows = int(np.ceil(ng / ncols))

ax_args = dict(s=5, alpha=0.5)
# ax_args = dict(linewidth=0.5, s=5, alpha=.5)

# squeeze=False keeps `axes` a 2-D array even for a single panel; otherwise
# plt.subplots returns a bare Axes and .flatten() below would fail.
fig, axes = plt.subplots(ncols=ncols, nrows=nrows, squeeze=False)

axes = axes.flatten()
for i, group in enumerate(groups):
    ax = axes[i]

    full = group.g1
    surr = group.g2
    ii, jj = group.idx
    full_y = full.eigvals
    full_x = np.arange(full_y.shape[0])
    ax.scatter(full_x, full_y, label="Full", **ax_args)

    surr_y = surr.eigvals
    surr_x = np.arange(surr_y.shape[0])
    ax.scatter(surr_x, surr_y, label="Surr", **ax_args)

    ax.set_ylabel(rf"$({ii}, {jj})$")
    ax.xaxis.get_major_locator().set_params(integer=True)

# Grid slots beyond the last group hold no data; hide them.
for ax in axes[ng:]:
    ax.set_axis_off()

# axes[-1] may be one of those empty slots, where a bare legend() finds no
# labelled artists and draws nothing; reuse the first panel's handles instead.
handles, labels = axes[0].get_legend_handles_labels()
axes[-1].legend(handles, labels)
fig.tight_layout()
fig.show()

# %%

# One panel per aligned eigenvalue group: "Full" eigenvalues at x=0 and "Surr"
# eigenvalues at x=1, each overlaid with a box plot. The y-label shows the
# panel index and the multiplicities of both sides.
ng = len(groups)
ncols = int(np.ceil(np.sqrt(ng)))
nrows = int(np.ceil(ng / ncols))

ax_args = dict(s=15, alpha=0.8, marker="_")

# Box-plot styling shared by every panel; loop-invariant, so built once.
pos = (0, 1)
lw = dict(linewidth=0.3, color="gray")
colors = ["lightblue", "lightcoral"]

figsize = (9, 6)
# squeeze=False keeps `axes` a 2-D array even for a single panel.
fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=figsize, squeeze=False)

axes = axes.flatten()
for i, group in enumerate(groups):
    ax = axes[i]

    full = group.g1
    surr = group.g2
    full_y = full.eigvals
    surr_y = surr.eigvals

    full_x = np.zeros(full_y.shape)
    surr_x = np.ones(surr_y.shape)

    ax.scatter(full_x, full_y, label="Full", **ax_args)
    ax.scatter(surr_x, surr_y, label="Surr", **ax_args)

    # `tick_labels` is the matplotlib >= 3.9 name (formerly `labels`).
    bp = ax.boxplot(
        [full_y, surr_y],
        positions=pos,
        patch_artist=True,
        widths=0.3,
        tick_labels=["Full", "Surr"],
        boxprops=lw | dict(alpha=0.5, zorder=1),
        whiskerprops=lw,
        capprops=lw,
        flierprops=dict(marker="", markersize=0, alpha=0),
        medianprops=lw | dict(color="darkred"),
    )

    for patch, color in zip(bp["boxes"], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.3)
        patch.set_zorder(1)
        # patch.set_visible(False)

    ylabel = f"#{i} \n $m_f={full.multiplicity}$, $m_s={surr.multiplicity}$"
    ax.set_ylabel(ylabel)

    # if len(surr_y) <= 2:
    #     ticks = ax.get_yticks()
    #     ax.set_yticks(np.concat([ticks, surr_y]))
    # ax.xaxis.get_major_locator().set_params(integer=True)

    ax.set_xlim(-0.5, 1.5)
    # ax.set_yscale("log")

# Blank out the grid slots beyond the last group.
for ax in axes[ng:]:
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

# axes[-1] may be one of those blank slots, where a bare legend() finds no
# labelled artists and draws nothing; reuse the first panel's handles instead.
handles, labels = axes[0].get_legend_handles_labels()
axes[-1].legend(handles, labels)
fig.tight_layout()
fig.show()

# %%

# Smallest eigenvalue of the full covariance, shown as the cell output.
np.min(evals_full)
# evals_surr.min()

# %%
