# %%

from dataclasses import dataclass, asdict
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as jrnd
import flax.nnx as nnx
import optax

from symo.factory import FactorGrid

from symo.notebooks.plot_utils import default_rcparams
from symo.experiments.mlp_groups import group_config
from symo.experiments.models import Activation, MLP
from symo.experiments.utils import sync_order_values
from symo.data import mlp_teacher_data
from symo.notebooks.utils import analyze_repeated_eigvals, align_eigval_groups, align_eigval

import matplotlib.pyplot as plt

# Apply the project's default figure styling. `RcParams.update` routes every
# entry through `RcParams.__setitem__`, so values are validated; the in-place
# `|=` operator would fall back to `dict.__ior__` and store them unvalidated.
plt.rcParams.update(default_rcparams(dpi=500))

# %% [markdown]
# # Model setup

# %%


@dataclass(frozen=True)
class ExperimentConfig:
    """Run-level settings: compute device, RNG seed and dataset sizes."""

    # Common
    device: str = "cpu"
    seed: int = 2025

    # Data
    # Annotated so they are real dataclass fields (overridable through the
    # constructor and included in `asdict`); without annotations they were
    # plain class attributes that the dataclass machinery ignored.
    num_train_points: int = 5000
    num_test_points: int = 5000


depth: int = 3
skip_every: int | None = 1
if depth == 1:
    skip_every = None


@dataclass(frozen=True)
class MLPConfig:
    """Architecture settings, unpacked into `MLP(...)` via `asdict`.

    The `hidden_dims` and `skip_every` defaults are read from the
    module-level `depth` / `skip_every` when the class body executes.
    """

    # Model
    input_dim: int = 13
    hidden_dims: tuple[int, ...] = tuple(35 for _ in range(depth))
    output_dim: int = 11
    skip_every: int | None = skip_every
    use_bias: bool = False
    activation: Activation = "relu"


# %%

num_epochs: int = 500  # default epoch count for both optimizer configs


@dataclass(frozen=True)
class SymoConfig:
    """Hyperparameters for a Symo run.

    NOTE(review): presumably consumed by the symo optimizer, which is not
    used in this file; the exact meaning of `decay` and `damping` is
    defined there.
    """

    num_epochs: int = num_epochs
    # Both momenta are off; 0.1 for each was an earlier alternative setting.
    grad_momentum: float = 0.0
    param_momentum: float = 0.0
    decay: float = 0.98
    damping: float = 1.0e-15
    lr: float = 0.5


@dataclass(frozen=True)
class AdamConfig:
    """Hyperparameters for an Adam run (same epoch default as `SymoConfig`)."""

    num_epochs: int = num_epochs
    lr: float = 0.01


# %%

# The teacher gets its own seed (1) so it is initialised differently from the
# student, which uses the default seed.
exp_teacher_cfg = ExperimentConfig(seed=1)
exp_cfg = ExperimentConfig()

mlp_cfg = MLPConfig()
symo_cfg, adam_cfg = SymoConfig(), AdamConfig()

# %%

# Teacher network: same architecture config as the student, teacher seed.
mlp_teacher = MLP(**asdict(mlp_cfg), rngs=nnx.Rngs(exp_teacher_cfg.seed))

# %%

# Build train/test datasets from the teacher network, keyed by the teacher seed.
# `mlp_teacher_data` returns `(key, (train_data, test_data))`; `rnd_key` is
# presumably the leftover PRNG key (unused below).
key_teacher = jrnd.PRNGKey(exp_teacher_cfg.seed)
rnd_key, (train_data, test_data) = mlp_teacher_data(
    key_teacher, mlp_teacher, exp_cfg.num_train_points, exp_cfg.num_test_points
)

# %%

# Student network: same architecture config as the teacher, student seed.
mlp = MLP(**asdict(mlp_cfg), rngs=nnx.Rngs(exp_cfg.seed))

# %%

# Parameter state of the student and the symmetry-group spec aligned to it.
mlp_params = nnx.state(mlp, nnx.Param)
group_spec = sync_order_values(dict(group_config(mlp)), mlp_params)
# `group_spec` iterates as (name, group) pairs; keep only the groups.
group_spec_tuple = tuple(group for _, group in group_spec)

# %%

# %%


def loss_fn(model: MLP, data: tuple | None = None):
    """Return the mean squared error of `model` on an `(x, y)` dataset.

    Args:
        model: Network called as `model(x)` to produce predictions.
        data: `(x, y)` pair to evaluate on. Defaults to the module-level
            `train_data`, so `nnx.value_and_grad(loss_fn)(model)` still
            differentiates the training loss w.r.t. `model` only. Pass
            `test_data` to get the test loss.

    Returns:
        Scalar mean of the element-wise squared error between predictions and `y`.
    """
    x, y = train_data if data is None else data
    pred = model(x)
    return optax.losses.squared_error(pred, y).mean()


# Training loss of the student and its gradient w.r.t. the model argument.
loss, grad = nnx.value_and_grad(loss_fn, argnums=0)(mlp)

# %% [markdown]
# Compute grid of G-invariant factors

# %%

factor_grid = FactorGrid(group_spec_tuple)  # built from the per-parameter groups

# %%

# Flatten parameters and gradients into leaf lists (assumed to share one pytree
# structure); `treedef` is what later cells use to rebuild states from leaves.
param_flat, treedef = jax.tree.flatten(mlp_params)
grad_flat, _ = jax.tree.flatten(grad)

# %% [markdown]
# Compute G-invariant parameters

# %%

# Graph definition of the student; the parameter state returned with it is unused.
graph, _ = nnx.split(mlp, nnx.Param)

# Map the parameter leaves through `invariant_mean`, rebuild the parameter
# state, and merge it with the graph into a new module `fixed_mlp`.
fixed_param_flat = factor_grid.invariant_mean(param_flat)
fixed_param = jax.tree.unflatten(treedef, fixed_param_flat)
fixed_mlp = nnx.merge(graph, fixed_param)

# %% [markdown]
# Evaluate gradients at G-invariant parameters

# %%

# Loss and gradient evaluated at the G-invariant parameters.
fixed_loss, fixed_grad = nnx.value_and_grad(loss_fn)(fixed_mlp)
fixed_grad_flat, _ = jax.tree.flatten(fixed_grad)

# %%

# Projection (via `invariant_mean`) of the gradient at the ORIGINAL parameters.
# Stored under its own names so it no longer overwrites `fixed_grad`, the
# gradient at `fixed_mlp` computed in the cell above.
# NOTE(review): if the intent was to project the fixed-point gradient instead,
# use `factor_grid.invariant_mean(fixed_grad_flat)` here.
projected_grad_flat = factor_grid.invariant_mean(grad_flat)
projected_grad = jax.tree.unflatten(treedef, projected_grad_flat)

# %%


# %%
