# %%
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

from functools import partial
from flax import nnx

from symo.notebooks.attention import TransformerModel, Encoding
from symo.notebooks.plot_utils import default_rcparams, orange_blue, blue_white_orange
from symo.model_factor import FactorTable
from symo.factor import I, S, O, B

from flax import nnx
from flax.training import train_state
import optax

from typing import Callable
import numpy as np
import matplotlib.pyplot as plt


# Apply the project's default Matplotlib style. ``RcParams.update`` routes every
# entry through ``RcParams.__setitem__`` (validation / type conversion); the
# previous ``plt.rcParams |= ...`` resolved to ``dict.__ior__``, which writes
# straight into the underlying dict and skips that validation.
plt.rcParams.update(default_rcparams())

# Path-entry type yielded by ``jax.tree_util.tree_flatten_with_path`` for dict keys.
DictKey = jax.tree_util.DictKey

# Report which devices JAX is running on.
print(jax.devices())

# %%


def identity(x):
    """Return ``x`` unchanged; selectable as the linear "activation"."""
    result = x
    return result


# Network depth (number of nnx.Linear layers, passed as ``nlayers``) and widths.
L = 4
din = 5
dhidden = 5
dout = 5
use_bias = True
activation = jax.nn.tanh

# SGD learning rate.
lr = 0.8

# Keyword arguments shared by every MLP built in this notebook.
mlp_args = dict(
    din=din,
    dhidden=dhidden,
    dout=dout,
    nlayers=L,
    use_bias=use_bias,
    activation=activation,
)

num_train_data = 5000
num_test_data = 5000


# Input covariance Sigma = Q diag(D) Q^T with eigenvalues D log-spaced from
# 1 down to 1e-5, i.e. strongly anisotropic inputs.
D = np.logspace(0, -5, din)

# QR of a standard-normal matrix gives an orthogonal factor Q (stored in
# ``Sigma`` for now); ``R`` is unused. ``(Q * D) @ Q.T`` scales Q's columns by D.
# NOTE(review): the global np.random state is not seeded, so the inputs (and
# everything downstream) differ between runs.
Sigma, R = np.linalg.qr(np.random.normal(np.zeros((din, din)), 1))
Sigma = (Sigma * D) @ Sigma.T
# Zero-mean Gaussian inputs with covariance Sigma, for training and testing.
inputs_train = jnp.array(
    np.random.multivariate_normal(np.zeros((din,)), Sigma, size=(num_train_data,))
)
inputs_test = jnp.array(
    np.random.multivariate_normal(np.zeros((din,)), Sigma, size=(num_test_data,))
)


def custom_kernel_init(key, shape, dtype=jnp.float32):
    """Sample a kernel with i.i.d. N(0, 1/fan_in) entries.

    ``nnx.Linear`` kernels have shape ``(in_features, out_features)``, so the
    fan-in is ``shape[0]``.
    """
    fan_in = shape[0]
    scale = 1.0 / jnp.sqrt(fan_in)
    sample = jax.random.normal(key, shape, dtype)
    return sample * scale


def custom_bias_init(key, shape, dtype=jnp.float32):
    """Sample biases i.i.d. from N(0, (1/3)^2), independent of layer width."""
    noise = jax.random.normal(key, shape, dtype)
    return noise * (1 / 3)


class MLP(nnx.Module):
    """Fully connected network made of ``nlayers`` ``nnx.Linear`` layers.

    Layer widths are ``din -> dhidden -> ... -> dhidden -> dout``. ``activation``
    is applied after every layer except the last one. Kernels and biases are
    drawn with ``custom_kernel_init`` / ``custom_bias_init``.

    Args:
        din: Input dimension.
        dhidden: Width of every hidden layer.
        dout: Output dimension.
        nlayers: Total number of linear layers; must be at least 2.
        use_bias: Whether every layer has a bias term.
        rngs: RNG streams used to initialise the layers.
        activation: Element-wise nonlinearity applied between layers.

    Raises:
        ValueError: If ``nlayers < 2``. The previous implementation silently
            built a two-layer network in that case, which disagreed with the
            requested depth.
    """

    def __init__(
        self,
        din,
        dhidden,
        dout,
        nlayers,
        use_bias,
        *,
        rngs: nnx.Rngs,
        activation: Callable,
    ):
        if nlayers < 2:
            raise ValueError(
                f"MLP needs nlayers >= 2 (input and output layer), got {nlayers}"
            )
        self.nlayers = nlayers
        # Consecutive (in, out) widths. Layers are created first -> last so the
        # RNG stream is consumed in the same order as before.
        widths = [din] + [dhidden] * (nlayers - 1) + [dout]
        self.linear_arrs = [
            nnx.Linear(
                d_in,
                d_out,
                use_bias=use_bias,
                kernel_init=custom_kernel_init,
                bias_init=custom_bias_init,
                rngs=rngs,
            )
            for d_in, d_out in zip(widths[:-1], widths[1:])
        ]
        self.activation = activation

    def __call__(self, x):
        """Apply the network to ``x``; the final layer has no activation."""
        *hidden_layers, output_layer = self.linear_arrs
        for layer in hidden_layers:
            x = self.activation(layer(x))
        return output_layer(x)


class TrainState(train_state.TrainState):
    """Flax ``TrainState`` that additionally carries the nnx ``GraphDef``.

    The jitted steps rebuild the model with ``nnx.merge(state.graphdef,
    params)`` rather than calling ``apply_fn``.
    """

    graphdef: nnx.GraphDef


# Student and teacher share the architecture but use different init seeds.
student_model = MLP(**mlp_args, rngs=nnx.Rngs(0))
teacher_model = MLP(**mlp_args, rngs=nnx.Rngs(1))

# Regression targets are the teacher's outputs.
outputs_train = teacher_model(inputs_train)
outputs_test = teacher_model(inputs_test)


# Split the student into its static graph and its nnx.Param state. Only the
# params go into the TrainState; apply_fn is unused (steps call nnx.merge).
graphdef, params = nnx.split(student_model, nnx.Param)

state = TrainState.create(
    apply_fn=None, graphdef=graphdef, params=params, tx=optax.sgd(lr)
)
# The TrainState now holds the params; drop the standalone name.
del params


@jax.jit
def train_step(state: TrainState, batch):
    """Take one full-batch gradient step on ``batch = (x, y)``.

    The objective is the mean over all elements of ``0.5 * (y - model(x)) ** 2``.
    The update is applied with the optimizer stored in ``state.tx``.

    Returns:
        The updated ``TrainState``.
    """
    inputs, targets = batch

    def half_mse(params):
        predictions = nnx.merge(state.graphdef, params)(inputs)
        return jnp.mean(0.5 * (targets - predictions) ** 2)

    # Gradient (SGD) update of the params held by the state.
    return state.apply_gradients(grads=jax.grad(half_mse)(state.params))


@jax.jit
def test_step(state: TrainState, batch):
    """Evaluate the half-MSE loss of ``state``'s parameters on ``batch = (x, y)``.

    ``state`` is the locally defined ``TrainState`` produced by ``train_step``.
    The previous annotation ``nnx.TrainState[MLP]`` referred to a different
    class, which may not exist or be subscriptable depending on the Flax version.

    Returns:
        ``{"loss": mean(0.5 * (y - model(x)) ** 2)}``.
    """
    x, y = batch
    model = nnx.merge(state.graphdef, state.params)
    y_pred = model(x)
    loss = jnp.mean(0.5 * (y - y_pred) ** 2)
    return {"loss": loss}


# Full-batch gradient descent of the student towards the teacher. The tracked
# loss is measured on the held-out test set: ``inputs_test``/``outputs_test``
# were generated for this purpose (and the curve is plotted as "Test Loss"),
# but the loop previously evaluated on the training data.
total_steps = 10_000
losses = []
test_batch = (inputs_test, outputs_test)
for step in range(total_steps):
    state = train_step(state, (inputs_train, outputs_train))
    logs = test_step(state, test_batch)
    losses.append(logs["loss"])
    if step % 100 == 0:
        print(f"step: {step}, loss: {logs['loss']}")

# Trained student as a regular nnx module.
model = nnx.merge(state.graphdef, state.params)


# %%
# Learning curve of the single student run on a log scale.
plt.plot(losses)
plt.ylabel("Test Loss")
plt.xlabel("Training Iterations")
plt.yscale("log")
plt.show()


# %%
def configure_mlp(
    key: jax.Array,
    lr: float,
    din: int,
    dhidden: int,
    dout: int,
    nlayers: int,
    use_bias: bool = False,
    activation: Callable = jax.nn.tanh,
) -> tuple[MLP, nnx.Optimizer]:
    """Create one ``MLP`` plus a plain-SGD ``nnx.Optimizer`` for it.

    Takes a single PRNG key, so it can be mapped over a batch of keys with
    ``nnx.vmap`` to build an ensemble.

    Args:
        key: PRNG key used to initialise the model parameters.
        lr: SGD learning rate.
        din, dhidden, dout, nlayers, use_bias, activation: Forwarded to ``MLP``.

    Returns:
        ``(model, optimizer)``.
    """
    architecture = dict(
        din=din,
        dhidden=dhidden,
        dout=dout,
        nlayers=nlayers,
        use_bias=use_bias,
        activation=activation,
    )
    model = MLP(**architecture, rngs=nnx.Rngs(key))
    return model, nnx.Optimizer(model, optax.sgd(lr))


def param_dict(params):
    """Flatten a parameter pytree into ``{dotted.name: array}``.

    Only ``DictKey`` entries of each leaf path contribute to the name; other
    path entries are skipped. Each leaf is promoted with ``jnp.atleast_2d``.
    """
    leaves_with_paths, _treedef = jax.tree_util.tree_flatten_with_path(params)
    return {
        ".".join(
            str(entry.key) for entry in path if isinstance(entry, DictKey)
        ): jnp.atleast_2d(leaf)
        for path, leaf in leaves_with_paths
    }


def get_params(model):
    """Return the sub-state of ``model`` holding its ``nnx.Param`` variables."""
    param_filter = nnx.Param
    return nnx.state(model, param_filter)


def param_count(model, batched: bool = False):
    """Count the scalar entries of every named parameter of ``model``.

    Args:
        model: nnx module, or a vmapped ensemble of modules.
        batched: If True, index 0 of the leading axis is taken before counting,
            so the size of a single ensemble member is reported.

    Returns:
        Dict mapping the dotted names produced by ``param_dict`` to sizes.
    """
    named_arrays = param_dict(get_params(model))
    return {
        name: (array[0] if batched else array).size
        for name, array in named_arrays.items()
    }


def vecop(tensor: jax.Array):
    """Vectorise ``tensor`` into a 1-D array.

    Arrays with at least one axis are flattened in row-major order; a 0-d
    scalar becomes a length-1 vector.
    """
    if tensor.ndim == 0:
        return tensor.reshape(1)
    return tensor.flatten()


# Ensemble of independently initialised MLPs used to estimate gradient statistics.
seed = 2001
rnd_key = jax.random.PRNGKey(seed)
# NOTE(review): data_key is not used anywhere in the visible notebook.
model_key, data_key = jax.random.split(rnd_key, 2)
num_models = 5000


# One PRNG key per ensemble member; nnx.vmap vectorises construction over them.
batch_model_key = jax.random.split(model_key, num_models)
# Use the notebook-wide ``lr`` instead of a duplicated literal, so the ensemble
# optimizers stay in sync with the single-model run if ``lr`` is changed.
configure_mlps = partial(configure_mlp, lr=lr, **mlp_args)
batch_models, batch_optimizers = nnx.vmap(configure_mlps)(batch_model_key)
batch_model_params = param_dict(get_params(batch_models))
# Flatten each member's tensor per parameter: leaves become (num_models, size).
param_vecs = jax.tree.map(jax.vmap(vecop), batch_model_params)
param_vecs

# %%
# (dotted parameter name, LaTeX label) per tensor: kernel then optional bias,
# layer by layer. This order defines the layout of every flattened vector below.
param_aliases = []
for l in range(L):
    param_aliases.append((f"linear_arrs.{l}.kernel", f"$W_{l}$"))
    if use_bias:
        param_aliases.append((f"linear_arrs.{l}.bias", f"$b_{l}$"))


# %%
print(param_aliases)

# %%
# Per-member element count of each parameter, and all parameters of every
# ensemble member concatenated in ``param_aliases`` order.
counts = param_count(batch_models, batched=True)
flat_params = jnp.concat([param_vecs[name] for name, _label in param_aliases], axis=-1)


# %%
def loss_fn(model, batch):
    """Return half the mean squared error of ``model`` on ``batch = (inputs, targets)``."""
    inputs, targets = batch
    residual = targets - model(inputs)
    return 0.5 * jnp.mean(residual**2)


# %%


@nnx.jit
def train_step(
    model: MLP,
    optimizer: nnx.Optimizer,
    data: tuple[jax.Array, jax.Array],
):
    """Run one optimizer step of ``model`` on ``data`` and return the loss.

    NOTE: this rebinds the module-level name ``train_step``. The earlier
    ``jax.jit`` version, which operates on a ``TrainState``, is no longer
    reachable once this cell has run. The annotation is ``MLP``, the model
    type actually trained here (previously mis-annotated as ``TransformerModel``).

    Args:
        model: The network being trained.
        optimizer: ``nnx.Optimizer`` wrapping ``model``; updated via ``update``.
        data: ``(inputs, targets)`` batch passed to ``loss_fn``.

    Returns:
        ``loss_fn(model, data)`` evaluated before the update.
    """
    loss, grads = nnx.value_and_grad(loss_fn)(model, data)
    optimizer.update(grads)
    return loss


# Vectorised over an ensemble of (model, optimizer) pairs sharing one batch.
batch_train_step_fn = nnx.vmap(train_step, in_axes=(0, 0, None))

# %%
# Full training set as a single batch, shared by every ensemble member below.
inputs = inputs_train
targets = outputs_train
data = (inputs, targets)


# %%
def compute_grads(model, data):
    """Return the gradient of ``loss_fn(model, data)`` with respect to ``model``."""
    return nnx.grad(loss_fn)(model, data)


def flatten(tensor: jax.Array):
    """Return ``tensor`` flattened to 1-D in row-major order."""
    flat = tensor.flatten()
    return flat


# Per-member gradient / loss over the ensemble: in_axes=(0, None) maps over
# the stacked models and broadcasts the shared data batch to all of them.
batch_compute_grads_fn = nnx.vmap(
    compute_grads, in_axes=(0, None), axis_size=num_models
)
batch_compute_loss_fn = nnx.vmap(loss_fn, in_axes=(0, None), axis_size=num_models)

# Gradient of every ensemble member at its current (freshly initialised) params.
batch_grads = batch_compute_grads_fn(batch_models, data)


# %%

# Per-member gradients keyed by dotted parameter name, then concatenated into
# one (num_models, total_params) matrix in ``param_aliases`` order.
batch_grads_dict = param_dict(batch_grads)
grad_vecs = jax.tree.map(jax.vmap(flatten), batch_grads_dict)
flat_grads = jnp.concat([grad_vecs[key] for key, _label in param_aliases], axis=-1)
# Same tensors, un-flattened, in the list form passed to factors.from_grads.
batch_grads_list = [batch_grads_dict[key] for key, _label in param_aliases]


# %%
@jax.jit
def covariance(x):
    """Population (biased) covariance of the rows of ``x``.

    Args:
        x: 2-D array of shape ``(num_samples, num_features)``; each row is one
            sample.

    Returns:
        ``(x_cov, x_outer, mean_outer)`` where ``x_outer = mean_n x_n x_n^T``,
        ``mean_outer = mu mu^T`` with ``mu = mean_n x_n``, and
        ``x_cov = x_outer - mean_outer``.
    """
    num_samples = x.shape[0]
    # x.T @ x sums the per-sample outer products directly. The previous
    # vmap(jnp.outer) version materialised a (num_samples, F, F) stack first.
    # HIGHEST precision keeps float32 accuracy on backends whose default
    # matmul precision is reduced.
    x_outer = jnp.dot(x.T, x, precision=jax.lax.Precision.HIGHEST) / num_samples
    x_mean = x.mean(axis=0)
    # Element-wise outer product: no reduced-precision matmul for a rank-1 term.
    mean_outer = jnp.outer(x_mean, x_mean)
    x_cov = x_outer - mean_outer
    return x_cov, x_outer, mean_outer


# Empirical covariance of the per-member gradient vectors (rows of flat_grads).
# ``covariance`` already returns a tuple of arrays; the old ``jnp.array(...)``
# wrapper only stacked them into an extra (3, P, P) copy before unpacking.
# NOTE: a later cell rebinds ``covariance`` to an array, so re-running this
# cell after that one fails.
grad_cov, _, _ = covariance(flat_grads)


# Normalise to a correlation matrix: D^{-1/2} Sigma D^{-1/2} with D = diag(Sigma).
inv_sqrt_scale = 1 / np.sqrt(grad_cov.diagonal())
inv_sqrt_scale_diag = np.diag(inv_sqrt_scale)

cov_est_scaled = inv_sqrt_scale_diag @ grad_cov @ inv_sqrt_scale_diag

# %%

# Group description of every parameter tensor, consumed by ``FactorTable``
# below. Entries follow the same order as ``param_aliases`` (kernel, then bias
# if ``use_bias``, for each layer), which is also the order of
# ``batch_grads_list``.
# NOTE(review): ``I``, ``O``, ``B``, ``S`` come from ``symo.factor``; presumably
# ``G[label](n)`` builds a labelled group acting on an axis of size ``n`` —
# confirm there.
Gi = I["i"](din)
# Hidden-unit group depends on the activation: O for the identity (linear
# network), B for tanh, S for anything else.
if activation == identity:
    Gh = O
elif activation == jax.nn.tanh:
    Gh = B
else:
    Gh = S

Go = I["o"](dout)

# First layer: kernel spans (input axis, hidden layer 1).
groups_layers = [[Gi, Gh["h", 1](dhidden)]]

if use_bias:

    # First-layer bias spans hidden layer 1 only.
    groups_layers += [
        [
            Gh["h", 1](dhidden),
        ],
    ]

if L > 2:
    # Hidden-to-hidden layers l = 2 .. L-1: kernel, then optional bias.
    for l in range(2, L):

        groups_layers += [
            [Gh["h", l - 1](dhidden), Gh["h", l](dhidden)],
        ]
        if use_bias:
            groups_layers += [
                [
                    Gh["h", l](dhidden),
                ],
            ]
else:
    # No hidden-to-hidden layers: the last hidden layer is layer 1.
    l = 1

# Output layer: kernel spans (last hidden layer ``l``, output axis).
groups_layers += [
    [Gh["h", l](dhidden), Go],
]

if use_bias:
    # Output bias spans the output axis only.
    groups_layers += [
        [
            Go,
        ],
    ]


print(groups_layers)

# %%
# NOTE(review): rnd_key was already split once (into model_key, data_key);
# splitting it again with the default num=2 reproduces those two keys, so
# subkey == data_key. This is harmless while data_key is unused, but fragile.
rnd_key, subkey = jax.random.split(rnd_key)
factors = FactorTable(groups_layers, subkey)
# %%
# Alternative (disabled): presumably fits from the scaled empirical covariance.
# factors.from_cov(cov_est_scaled)
factors.from_grads(batch_grads_list)


# %%
# Model ("theory") covariance of the flattened gradient vector.
# NOTE(review): assumes factors.to_cov() returns the uncentred second moment
# E[g g^T], hence the subtraction of the mean outer product — confirm in symo.
means = flat_grads.mean(axis=0)[..., None]
# WARNING: this rebinds ``covariance`` and shadows the ``covariance`` function
# defined earlier; re-running the empirical-covariance cell afterwards fails.
covariance = factors.to_cov() - means @ means.T


# %%
def matshow_named_axes(
    ax,
    param_names: list[tuple[str, str, int, int]],
    shift: float = -3,
    width: float = 0.5,
    fontsize: float = 6,
):
    """Draw parameter-block separators and labels on a ``matshow`` image.

    For every block, a thin gray line is drawn on the edge before its first
    row and before its first column. Its alias is written centred on the block,
    once at x = ``shift`` (row label) and once at y = ``shift`` (column label).

    Args:
        ax: Matplotlib axes containing a ``matshow`` image.
        param_names: ``(name, alias, size, end)`` tuples; each block covers
            indices ``[end - size, end)``. ``name`` is ignored.
        shift: Data coordinate of the labels. Negative values place them
            outside the matrix.
        width: Line width of the separators.
        fontsize: Font size of the labels.
    """
    # matshow centres cell i at coordinate i, so cell edges are at i - 0.5.
    # This offset is now independent of ``width``; previously the line width
    # doubled as the offset, which misplaced lines for any width other than 0.5.
    half_cell = 0.5
    text_style = dict(
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=fontsize,
    )
    for _, alias, size, end in param_names:
        start = end - size
        edge = start - half_cell
        # Cells start .. end-1 span [start - 0.5, end - 0.5]; label their centre.
        middle = (start + end - 1) / 2

        ax.axhline(y=edge, color="gray", linestyle="-", linewidth=width)
        ax.axvline(x=edge, color="gray", linestyle="-", linewidth=width)

        ax.text(shift, middle, alias, **text_style)  # row label
        ax.text(middle, shift, alias, **text_style)  # column label


# Aliases and per-member sizes of each parameter, in ``param_aliases`` order.
param_alias, param_counts = zip(
    *[(label, counts[key]) for key, label in param_aliases]
)
# Exclusive end index of each parameter block in the flattened vector.
param_cum_counts = np.cumsum(param_counts)
names_alias_counts = [
    (key, label, size, end)
    for (key, label), size, end in zip(param_aliases, param_counts, param_cum_counts)
]

# %%
# Empirical gradient correlation matrix, with parameter blocks labelled.
fig, ax = plt.subplots()

mat = cov_est_scaled

# Diverging colormap centred on zero; the limits are set from the data below.
cmap = orange_blue().reversed()
norm = mcolors.TwoSlopeNorm(vcenter=0)

im = ax.matshow(mat, cmap=cmap, norm=norm)
cbar1 = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
matshow_named_axes(ax, names_alias_counts)

ax.set_title(
    r"Empirical $\Sigma$ MLP",
    y=1.05,
)
ax.set_xticks([])
ax.set_yticks([])

# TwoSlopeNorm requires vmin < 0 < vmax. The unit diagonal of a correlation
# matrix guarantees vmax > 0; vmin < 0 needs at least one negative entry.
clim_min = np.min(mat)
clim_max = np.max(mat)
im.set_clim(clim_min, clim_max)

fig.tight_layout()
fig.show()

# %%

# Normalise the model covariance to a correlation matrix, D^{-1/2} C D^{-1/2}
# with D = diag(C), so it is comparable with ``cov_est_scaled``.
inv_sqrt_factor_cov = 1 / np.sqrt(covariance.diagonal())
inv_sqrt_factor_cov_diag = np.diag(inv_sqrt_factor_cov)

covariance_scaled = inv_sqrt_factor_cov_diag @ covariance @ inv_sqrt_factor_cov_diag


# %%

# %%

# Factor-model ("theory") gradient correlation matrix, with blocks labelled.
fig, ax = plt.subplots()

mat = covariance_scaled

# Diverging colormap centred on zero; the limits are set from the data below.
cmap = orange_blue().reversed()
norm = mcolors.TwoSlopeNorm(vcenter=0)

im = ax.matshow(mat, cmap=cmap, norm=norm)
cbar1 = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
matshow_named_axes(ax, names_alias_counts)

ax.set_title(
    r"Theory $\Sigma$ MLP",
    y=1.05,
)
ax.set_xticks([])
ax.set_yticks([])

# TwoSlopeNorm requires vmin < 0 < vmax. The unit diagonal guarantees
# vmax > 0; vmin < 0 needs at least one negative entry.
clim_min = np.min(mat)
clim_max = np.max(mat)
im.set_clim(clim_min, clim_max)

fig.tight_layout()
fig.show()
