# %%
from typing import Iterator
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.random as jrnd
import jax.numpy as jnp
import flax.nnx as nnx
import optax

from symo.group import I, B, S
import symo.optim as optim

from symo.experiments.utils import (
    load_mnist_dataset,
    unflatten_mnist_image,
)
from symo.notebooks.plot_utils import default_rcparams
from symo.experiments.models import Activation, Autoencoder, autoencoder_loss
from symo.metrics import compute_metrics
from symo.experiments.utils import flatten_with_string_path
import kfac_jax
from symo.experiments.autoencoder_config import (
    ExperimentConfig,
    KFACConfig,
    SymoConfig,
    AdamConfig,
)
import matplotlib.pyplot as plt
from itertools import tee
from dataclasses import asdict

# %%

# Apply the shared plotting defaults. ``RcParams.update`` routes every key through
# matplotlib's per-key validators; ``plt.rcParams |= ...`` instead resolves to
# ``dict.__ior__``, which writes into the underlying dict directly and skips them.
plt.rcParams.update(default_rcparams())

# NOTE(review): this alias is only a loose type hint. The eval helpers index their
# batch as a mapping (``data["image"]``), and ``train_step`` receives a bare image
# array, so neither actually matches ``tuple[jax.Array, jax.Array]``.
Data = tuple[jax.Array, jax.Array]

# %% [markdown]
# # Configuration

# %%


# Default hyper-parameter objects for the experiment and each optimizer.
cfg = ExperimentConfig()
kfac_cfg = KFACConfig()
symo_cfg = SymoConfig()
adam_cfg = AdamConfig()


def _identity(x):
    """Return ``x`` unchanged; used as the "linear" activation."""
    return x


# Activation names accepted by ``cfg.activation`` mapped to their callables.
acts = dict(
    tanh=nnx.tanh,
    relu=nnx.relu,
    linear=_identity,
)

# %%

# Constructor keyword arguments shared by every Autoencoder in this notebook.
config = {
    "input_dim": cfg.input_dim,
    "activation": acts[cfg.activation],
}

# NOTE(review): the same nnx.Rngs instance is reused by every
# ``Autoencoder(**ae_cfg)`` call below, so each model presumably draws a
# different initialisation from one stream — confirm that is intended when
# comparing optimizers.
ae_cfg = {**config, "rngs": nnx.Rngs(cfg.seed)}

# %% [markdown]
# # Data

# %%

# Training/test batch iterators and the number of batches the loops below draw
# from each per epoch. A code-cell marker is required here: without it these
# lines belong to the preceding markdown cell and never run in notebook mode.
ds_train, num_train_batches = load_mnist_dataset(cfg.seed, batch_size=cfg.batch_size)
ds_test, num_test_batches = load_mnist_dataset(
    cfg.seed, train=False, batch_size=cfg.test_batch_size
)


# %%

symo_autoencoder = Autoencoder(**ae_cfg)

# %%

# Pick the symmetry group (from symo.group) that acts on the hidden units,
# based on the activation function. Only ReLU and tanh have one assigned here.
activation = ae_cfg["activation"]
if activation is nnx.relu:
    group = S
elif activation is nnx.tanh:
    group = B
else:
    # Fail early with a clear message rather than a NameError on ``group`` below.
    raise ValueError(
        f"No symmetry group is defined for activation {cfg.activation!r}; "
        "SymO here supports only 'relu' and 'tanh'."
    )


# Input/output are acted on by the identity group; each hidden layer gets its
# own group instance. NOTE(review): the widths must match Autoencoder's layer
# sizes — confirm against symo.experiments.models.
In = I["input", cfg.input_dim]
G1 = group["L1", 1000]
G2 = group["L2", 500]
G3 = group["L3", 250]
G4 = group["L4", 30]
G5 = group["L5", 250]
G6 = group["L6", 500]
G7 = group["L7", 1000]
Ou = I["output", cfg.input_dim]


# Map each flattened parameter path to the group(s) acting on it: a kernel is
# acted on by (input-side group, output-side group), a bias by the output-side
# group. Only even layer indices appear — presumably the odd entries of each
# layer sequence are the parameter-free activations (confirm).
groups_cfg = (
    ("encoder/layers/#0/kernel/.value", (In, G1)),
    ("encoder/layers/#0/bias/.value", G1),
    ("encoder/layers/#2/kernel/.value", (G1, G2)),
    ("encoder/layers/#2/bias/.value", G2),
    ("encoder/layers/#4/kernel/.value", (G2, G3)),
    ("encoder/layers/#4/bias/.value", G3),
    ("encoder/layers/#6/kernel/.value", (G3, G4)),
    ("encoder/layers/#6/bias/.value", G4),
    ("decoder/layers/#0/kernel/.value", (G4, G5)),
    ("decoder/layers/#0/bias/.value", G5),
    ("decoder/layers/#2/kernel/.value", (G5, G6)),
    ("decoder/layers/#2/bias/.value", G6),
    ("decoder/layers/#4/kernel/.value", (G6, G7)),
    ("decoder/layers/#4/bias/.value", G7),
    ("decoder/layers/#6/kernel/.value", (G7, Ou)),
    ("decoder/layers/#6/bias/.value", Ou),
)

# %%


@nnx.jit
def train_step(
    model: Autoencoder,
    optimizer: nnx.Optimizer,
    data: jax.Array,
    l2_reg: float | None = None,
):
    """Take one optimizer step on ``model`` and report the loss and metrics.

    Args:
        model: Autoencoder whose parameters ``optimizer.update`` modifies.
        optimizer: ``nnx.Optimizer`` wrapping ``model``.
        data: image batch handed to ``autoencoder_loss``; callers pass
            ``jnp.squeeze(batch["image"])``.
        l2_reg: L2 penalty coefficient for the loss; ``None`` means 0.

    Returns:
        ``(loss, metrics)``: the pre-step training loss, and the result of
        ``compute_metrics`` on that loss, the post-step parameters, the
        gradients and the parameter change applied by this step.
    """
    # Snapshot the parameters so the applied update can be measured afterwards.
    old_params = jax.tree_util.tree_map(lambda params: params.copy(), nnx.state(model))

    if l2_reg is None:
        l2_reg = 0

    def loss_fn(model: Autoencoder):
        return autoencoder_loss(model, data, l2_reg=l2_reg, is_training=True)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    new_params = nnx.state(model)
    update = jax.tree_util.tree_map(lambda new, old: new - old, new_params, old_params)
    metrics = compute_metrics(loss, new_params, grads, update)

    return loss, metrics


@nnx.jit
def eval_loss_fn(
    model: Autoencoder,
    data: Data,
):
    """Evaluation loss with ``mse_loss=True`` and no L2 penalty.

    ``data`` is indexed as a mapping with an ``"image"`` entry (the ``Data``
    annotation notwithstanding).
    """
    images = jnp.squeeze(data["image"])
    return autoencoder_loss(
        model,
        images,
        l2_reg=0,
        is_training=False,
        mse_loss=True,
    )


@nnx.jit
def eval_bce_loss_fn(
    model: Autoencoder,
    data: Data,
):
    """Evaluation loss with ``autoencoder_loss``'s default loss and no L2 penalty.

    The plots label this value "BCE". ``data`` is indexed as a mapping with an
    ``"image"`` entry.
    """
    images = jnp.squeeze(data["image"])
    return autoencoder_loss(
        model,
        images,
        l2_reg=0,
        is_training=False,
    )


def eval_batch(
    model: Autoencoder,
    data_iterator: Iterator,
    num_test_batches: int,
):
    """Average both evaluation losses over the next ``num_test_batches`` batches.

    Args:
        model: autoencoder to evaluate.
        data_iterator: iterator yielding batches accepted by ``eval_loss_fn``
            and ``eval_bce_loss_fn``; exactly ``num_test_batches`` are consumed.
        num_test_batches: number of batches to average over; must be positive.

    Returns:
        ``(mean_mse_style_loss, mean_bce_loss)`` — the batch means of
        ``eval_loss_fn`` and ``eval_bce_loss_fn`` respectively.

    Raises:
        ValueError: if ``num_test_batches`` is not positive.
        StopIteration: if ``data_iterator`` is exhausted early.
    """
    if num_test_batches <= 0:
        raise ValueError(f"num_test_batches must be positive, got {num_test_batches}")

    sum_val_loss = 0
    sum_val_bce_loss = 0
    for _ in range(num_test_batches):
        test_batch = next(data_iterator)
        sum_val_loss += eval_loss_fn(model, test_batch)
        sum_val_bce_loss += eval_bce_loss_fn(model, test_batch)

    return sum_val_loss / num_test_batches, sum_val_bce_loss / num_test_batches


def metrics_to_dict(metrics):
    """Flatten a metrics record into ``{path_string: numpy array}``.

    ``metrics`` must expose ``_asdict()`` (e.g. a NamedTuple); nested values
    are flattened by ``flatten_with_string_path`` into string-keyed leaves.
    """
    as_mapping = metrics._asdict()
    path_leaf_pairs, _ = flatten_with_string_path(as_mapping)
    flat = {}
    for path, leaf in path_leaf_pairs:
        flat[path] = np.asarray(leaf)
    return flat


# %% [markdown]
# # KFAC

# %%

kfac_autoencoder = Autoencoder(**ae_cfg)
num_epochs = kfac_cfg.num_epochs
kfac_losses = []
kfac_tests = []
kfac_bce_tests = []

# kfac_jax operates on plain parameter pytrees, so split the module into its
# static graph definition and its state; kfac_loss_fn re-merges them per call.
kfac_graphdef, kfac_autoencoder_params = nnx.split(kfac_autoencoder)


def kfac_loss_fn(model_params, batch):
    """Training loss in the ``(params, batch) -> loss`` form kfac_jax expects.

    Rebuilds the module from ``model_params``, computes the regularised training
    loss on ``batch["image"]`` and registers the model output with kfac_jax as
    sigmoid cross-entropy logits against the images.
    """
    images = jnp.squeeze(batch["image"])
    net = nnx.merge(kfac_graphdef, model_params)
    loss, logits = autoencoder_loss(
        net, images, l2_reg=cfg.l2_reg, is_training=True, return_output=True
    )
    kfac_jax.register_sigmoid_cross_entropy_loss(logits, images)
    return loss


# %%
# Forward every KFACConfig field except the epoch count to kfac_jax.Optimizer.
kfac_cfg_dict = {
    field_name: field_value
    for field_name, field_value in asdict(kfac_cfg).items()
    if field_name != "num_epochs"
}


kfac = kfac_jax.Optimizer(
    value_and_grad_func=nnx.value_and_grad(kfac_loss_fn),
    l2_reg=cfg.l2_reg,
    **kfac_cfg_dict,
)

# %%

# Peek one batch for kfac.init (shape inference) without consuming it from the
# training stream.
dummy_iterator, ds_train = tee(ds_train)
dummy_batch = next(dummy_iterator)
# Drop the lagging tee branch: while it is alive, itertools.tee must buffer every
# batch later drawn from ds_train (by all training loops below), so memory would
# grow without bound.
del dummy_iterator

# NOTE(review): hard-coded seed, independent of cfg.seed — confirm intended.
kfac_rng = jrnd.PRNGKey(0)
# Never hand a key to a consumer and also split it later: derive fresh subkeys.
kfac_rng, init_rng = jrnd.split(kfac_rng)
optimizer_kfac = kfac.init(
    params=kfac_autoencoder_params, batch=dummy_batch, rng=init_rng
)

steps = 0

for epoch in range(num_epochs):
    sum_loss = 0

    for _ in range(num_train_batches):
        kfac_rng, step_rng = jrnd.split(kfac_rng)
        # kfac.step pulls its own batch from ds_train.
        kfac_autoencoder_params, optimizer_kfac, kfac_metrics = kfac.step(
            kfac_autoencoder_params,
            optimizer_kfac,
            step_rng,
            data_iterator=ds_train,
            global_step_int=steps,
        )
        steps += 1
        sum_loss += kfac_metrics["loss"]
    kfac_losses.append(sum_loss / num_train_batches)
    # Rebuild a module view of the current parameters for evaluation/plotting.
    kfac_autoencoder = nnx.merge(kfac_graphdef, kfac_autoencoder_params)

    eval_loss, eval_bce_loss = eval_batch(
        kfac_autoencoder,
        ds_test,
        num_test_batches,
    )

    kfac_tests.append(eval_loss)
    kfac_bce_tests.append(eval_bce_loss)

    print(f"KFAC Epoch {epoch} train loss: {kfac_losses[epoch]}")


# %% [markdown]
# # Adam

# %%

adam_autoencoder = Autoencoder(**ae_cfg)
num_epochs = adam_cfg.num_epochs
adam_losses, adam_tests, adam_bce_tests = [], [], []
adam = optax.adam(
    learning_rate=adam_cfg.lr, b1=adam_cfg.b1, b2=adam_cfg.b2, eps=adam_cfg.eps
)
optimizer_adam = nnx.Optimizer(adam_autoencoder, adam, wrt=nnx.Param)

for epoch in range(num_epochs):
    sum_loss = 0

    for _ in range(num_train_batches):
        images = jnp.squeeze(next(ds_train)["image"])
        loss, adam_metrics = train_step(
            adam_autoencoder, optimizer_adam, images, l2_reg=cfg.l2_reg
        )
        sum_loss += loss

    # Mean training loss over this epoch's batches.
    adam_losses.append(sum_loss / num_train_batches)

    eval_loss, eval_bce_loss = eval_batch(
        adam_autoencoder,
        ds_test,
        num_test_batches,
    )
    adam_tests.append(eval_loss)
    adam_bce_tests.append(eval_bce_loss)

    print(f"Adam Epoch {epoch} loss: {adam_losses[epoch]}")


# %% [markdown]
# # Symo

# %%

# Learning rate multiplied by ``symo_cfg.decay`` once per optimizer update
# (transition_steps=1, starting immediately) — i.e. per mini-batch, not per epoch.
learning_rate = optax.exponential_decay(
    init_value=symo_cfg.lr,
    transition_steps=1,
    decay_rate=symo_cfg.decay,
    transition_begin=0,
)

# SymO needs the parameter-path -> group mapping built above.
symo_opt = optim.symo(
    learning_rate=learning_rate,
    momentum=symo_cfg.momentum,
    damping=symo_cfg.damping,
    groups=groups_cfg,
)
optimizer = nnx.Optimizer(symo_autoencoder, symo_opt, wrt=nnx.Param)

# %%
num_epochs = symo_cfg.num_epochs

symo_losses, symo_tests, symo_bce_tests = [], [], []

for epoch in range(num_epochs):
    sum_loss = 0

    for _ in range(num_train_batches):
        images = jnp.squeeze(next(ds_train)["image"])
        loss, symo_metrics = train_step(
            symo_autoencoder, optimizer, images, l2_reg=cfg.l2_reg
        )
        sum_loss += loss

    # Mean training loss over this epoch's batches.
    symo_losses.append(sum_loss / num_train_batches)

    eval_loss, eval_bce_loss = eval_batch(
        symo_autoencoder,
        ds_test,
        num_test_batches,
    )

    symo_tests.append(eval_loss)
    symo_bce_tests.append(eval_bce_loss)

    print(f"Symo Epoch {epoch} loss: {symo_losses[epoch]}")


# %%

# One entry per optimizer; each list holds one value per training epoch.
plotting_data = [
    {
        "losses": symo_losses,
        "tests": symo_tests,
        "bce_tests": symo_bce_tests,
        "label": "SymO Tuned",
    },
    {
        "losses": adam_losses,
        "tests": adam_tests,
        "bce_tests": adam_bce_tests,
        "label": "Adam Tuned",
    },
    {
        "losses": kfac_losses,
        "tests": kfac_tests,
        "bce_tests": kfac_bce_tests,
        "label": "KFAC Tuned",
    },
]

# Panel i plots the plotting_data entry series_keys[i], labelled ylabels[i].
series_keys = ["losses", "tests", "bce_tests"]
ylabels = ["Train Loss (BCE)", "Test Error (MSE)", "Test Error (BCE)"]

fontsize = 15
fig, axs = plt.subplots(
    len(series_keys), 1, figsize=(10, 12), sharex=True, gridspec_kw={"hspace": 0}
)
# NOTE(review): each epoch runs num_train_batches mini-batches; "Full Batch" is
# only accurate if that count is 1 — confirm the title.
fig.suptitle("MNIST Autoencoder Full Batch", fontsize=fontsize + 2)

last_panel = len(axs) - 1
for i, (ax, key, ylabel) in enumerate(zip(axs, series_keys, ylabels)):
    for run in plotting_data:
        ax.plot(run[key], label=run["label"])

    ax.set_yscale("log")
    ax.legend(fontsize=fontsize)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    ax.tick_params(axis="both", which="both", labelsize=fontsize)

    # Panels share the x-axis, so only the bottom one shows x ticks and label.
    if i < last_panel:
        ax.tick_params(
            axis="x", which="both", bottom=False, top=False, labelbottom=False
        )
    else:
        ax.tick_params(axis="x", which="both", bottom=True, top=False, labelbottom=True)
        # The training loops append exactly one value per epoch to every series.
        ax.set_xlabel("Epoch", fontsize=fontsize)

fig.tight_layout()
plt.subplots_adjust(top=0.95, hspace=0)
plt.show()

# %%

test_data = jnp.squeeze(next(ds_test)["image"])

# NOTE(review): the models are called directly here. If Autoencoder.__call__
# returns logits (the KFAC loss registers its output as sigmoid cross-entropy
# logits), these panels show pre-sigmoid values clipped by vmin/vmax — confirm.
true_image = unflatten_mnist_image(test_data)
symo_recon = unflatten_mnist_image(symo_autoencoder(test_data))
adam_recon = unflatten_mnist_image(adam_autoencoder(test_data))
kfac_recon = unflatten_mnist_image(kfac_autoencoder(test_data))

# Indices into the test batch; requires test_batch_size > max(selected_num_ids).
selected_num_ids = [66, 123, 45]

# Column order: ground truth, then each optimizer's reconstruction.
columns = [
    ("True", true_image),
    ("Symo", symo_recon),
    ("Adam", adam_recon),
    ("KFAC", kfac_recon),
]

# squeeze=False keeps ``axes`` 2-D even when only one row is requested.
fig, axes = plt.subplots(
    len(selected_num_ids),
    len(columns),
    figsize=(3 * len(columns), 3 * len(selected_num_ids)),
    squeeze=False,
)

for col, (title, _) in enumerate(columns):
    axes[0, col].set_title(title)

for row, num_id in enumerate(selected_num_ids):
    for col, (_, images) in enumerate(columns):
        ax = axes[row, col]
        ax.imshow(images[num_id, :, :], vmin=0, vmax=1.0)
        ax.axis("off")


plt.tight_layout()

plt.show()


# %%
