# %%
from functools import partial
import numpy as np
from typing import Literal
from dataclasses import dataclass, asdict
import jax
import jax.numpy as jnp
import jax.random as jrnd
import matplotlib.pyplot as plt
import optax
from flax import nnx
from folx import ForwardLaplacianOperator, forward_laplacian

from symo.group import I, B, S
from symo.metrics import compute_metrics
import symo.optim as optim
import symo.optim2 as optim2
from symo.experiments.models import Activation, MLP
from symo.notebooks.plot_utils import default_rcparams
from symo.experiments.mlp_groups import group_config
from symo.experiments.utils import inverse_step_schedule, train_loop
from symo.experiments.pinns_config import (
    MLP64Config,
    MLP64_64_48_48Config,
    Muon64Config,
    Symo64Config,
    Symo2_64Config,
    Adam64Config,
    SGD64Config,
)
from symo.experiments.poisson_data import ArrayPair, create_poisson_data

# %%

# Apply the project's plotting defaults. `update` is used instead of `|=` so
# that every entry goes through `RcParams.__setitem__` and is validated; the
# in-place `|=` resolves to `dict.__ior__`, which stores the values directly
# and skips that validation.
plt.rcParams.update(default_rcparams())
# Turn on JAX's 64-bit mode before any arrays are created.
jax.config.update("jax_enable_x64", True)

# %%

# Training batch layout: two `ArrayPair`s, unpacked by `train_step` as
# ((x_bnd, y_bnd), (x_in, y_in)), i.e. boundary pair first, interior pair second.
Data4 = tuple[ArrayPair, ArrayPair]

# %%


@dataclass(frozen=True)
class ExperimentConfig:
    """Immutable experiment-level settings: device name, RNG seed, dataset sizes."""

    # Common
    # NOTE(review): `device` is not read anywhere in the visible code.
    device: str = "cpu"
    # Seeds both the data PRNG key and the model's `nnx.Rngs`.
    seed: int = 2025

    # Data
    # The three sizes below are passed positionally, in this order, to
    # `create_poisson_data`.
    num_train_interior: int = 900
    num_train_boundary: int = 120
    num_test: int = 9000


# %%

# Default epoch count for every `*ConfigExp` optimizer config in this file.
# The value is captured when those classes are defined, so rebinding this name
# afterwards does not change their defaults.
num_epochs = 6000


@dataclass(frozen=True)
class Symo64ConfigExp(Symo64Config):
    """`Symo64Config` with the epoch count set to the module-level `num_epochs`.

    Only referenced by the commented-out "Symo 1" entry of `opt_configs`.
    """

    num_epochs: int = num_epochs


@dataclass(frozen=True)
class Symo2_64ConfigExp1(Symo2_64Config):
    """`Symo2_64Config` variant with unit learning rate and `sigma_g_bias_corr` on.

    Only referenced by the commented-out "Symo 2.2" entry of `opt_configs`.
    """

    num_epochs: int = num_epochs
    # Initial value handed to the learning-rate schedule in `symo_run`.
    lr: float = 1
    # presumably enables bias correction of the sigma_g statistic — TODO confirm
    sigma_g_bias_corr: bool = True


@dataclass(frozen=True)
class Symo2_64ConfigExp2(Symo2_64Config):
    """`Symo2_64Config` variant used by the active "Symo 2.3" entry of `opt_configs`.

    Every field except `num_epochs` and `lr` is forwarded as a keyword argument
    to the optimizer factory by `symo_run`.
    """

    num_epochs: int = num_epochs
    # Initial value handed to the learning-rate schedule in `symo_run`.
    lr: float = 2
    # presumably moving-average coefficients for the gradient and sigma_g
    # statistics — TODO confirm against the optimizer implementation
    grad_beta: float = 0.9
    sigma_g_beta: float = 0.9
    # presumably a numerical-stability term — TODO confirm
    damping: float = 1e-8
    # presumably enable bias correction of the respective statistics — TODO confirm
    grad_bias_corr: bool = True
    sigma_g_bias_corr: bool = True


@dataclass(frozen=True)
class Adam64ConfigExp(Adam64Config):
    """`Adam64Config` with the epoch count set to the module-level `num_epochs`."""

    num_epochs: int = num_epochs


@dataclass(frozen=True)
class SGD64ConfigExp(SGD64Config):
    """`SGD64Config` with the epoch count set to the module-level `num_epochs`."""

    num_epochs: int = num_epochs


@dataclass(frozen=True)
class Muon64ConfigExp(Muon64Config):
    """`Muon64Config` with the epoch count set to the module-level `num_epochs`.

    NOTE(review): no entry in `opt_configs` currently uses this config.
    """

    num_epochs: int = num_epochs


# %% [markdown]
# ## Configs

# %%

# Experiment-wide settings and the MLP architecture shared by every run.
cfg = ExperimentConfig()
model_cfg = MLP64Config()

# %% [markdown]
# ## Dataset

# %%

# Seed the PRNG and build the training and validation sets. The first value
# returned by `create_poisson_data` is stored back into `key`.
key = jrnd.PRNGKey(cfg.seed)
key, train_data, val_data = create_poisson_data(
    key,
    2,  # presumably the spatial dimension of the problem — TODO confirm
    cfg.num_train_interior,
    cfg.num_train_boundary,
    cfg.num_test,
)


# %%


def record_train_step(model: MLP, optimizer: nnx.Optimizer, data: Data4):
    """Run one `train_step` and report metrics about the update it applied.

    Args:
        model: Network being trained; `train_step` updates it in place.
        optimizer: Optimizer forwarded to `train_step`.
        data: Training batch forwarded to `train_step`.

    Returns:
        `(loss, metrics)`, where `metrics` is the result of `compute_metrics`
        called with the loss, the post-step state, the gradients and the
        per-leaf difference between post-step and pre-step state.
    """
    # NOTE(review): this relies on `nnx.state` returning a snapshot that does
    # not alias the live variables; otherwise the delta below would be zero.
    # Confirm for the installed Flax version.
    state_before = nnx.state(model)
    loss, grads = train_step(model, optimizer, data)
    state_after = nnx.state(model)

    delta = jax.tree.map(
        lambda after, before: after - before,
        state_after,
        state_before,
    )
    return loss, compute_metrics(loss, state_after, grads, delta)


@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, data: Data4):
    """Perform one optimization step on the combined interior/boundary loss.

    Args:
        model: Network to train; updated in place via `optimizer.update`.
        optimizer: Optimizer that applies the gradients to `model`.
        data: `((x_bnd, y_bnd), (x_in, y_in))`, boundary pair first.

    Returns:
        `(loss, grads)`: the scalar loss evaluated before the update and the
        gradients that were handed to the optimizer.
    """
    boundary, interior = data
    x_bnd, y_bnd = boundary
    x_in, y_in = interior

    def loss_fn(net: MLP):
        # Per-sample Laplacian of the network, batched over interior points.
        # The operator's second output is discarded (presumably the Jacobian
        # — TODO confirm against folx).
        pointwise_laplacian = ForwardLaplacianOperator(0)(net)
        laplacian, _ = jax.vmap(pointwise_laplacian)(x_in)

        boundary_pred = net(x_bnd)

        # Interior residual is (Laplacian + target); boundary residual is
        # (prediction - target).
        # NOTE(review): assumes predictions and targets have matching shapes,
        # otherwise broadcasting would silently change the mean — confirm.
        in_loss = jnp.mean((laplacian + y_in) ** 2)
        bnd_loss = jnp.mean((boundary_pred - y_bnd) ** 2)
        return 0.5 * (in_loss + bnd_loss)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss, grads


@nnx.jit
def eval_loss(model: MLP, batch) -> jax.Array:
    """Return the root-mean-square error of `model` on `batch`.

    Args:
        model: Network to evaluate.
        batch: Pair `(x, y)` of inputs and targets.

    Returns:
        Scalar `sqrt(mean((y - model(x)) ** 2))`.
    """
    inputs, targets = batch
    residual = targets - model(inputs)
    return jnp.sqrt(jnp.mean(residual**2))


# %%


def symo_run(seed: int, train_data, val_data, model_config, opt_cfg, opt_fn):
    """Train a fresh MLP with a group-aware optimizer built by `opt_fn`.

    Args:
        seed: Seed for the model's `nnx.Rngs`.
        train_data: Training data forwarded to `train_loop`.
        val_data: Validation data forwarded to `train_loop`.
        model_config: Dataclass whose fields become `MLP` keyword arguments.
        opt_cfg: Dataclass with `lr` and `num_epochs`; all remaining fields
            are forwarded to `opt_fn` as keyword arguments.
        opt_fn: Optimizer factory called as
            `opt_fn(group_spec, learning_rate=..., **remaining_fields)`.

    Returns:
        Whatever `train_loop` returns (the plotting cell unpacks it as a
        3-tuple `(losses, vals, _)`).
    """
    # Custom initializers (kernel_init=mlp_kernel_init, bias_init=mlp_bias_init)
    # were tried here and are currently disabled.
    model_kwargs = {**asdict(model_config), "rngs": nnx.Rngs(seed)}
    model = MLP(**model_kwargs)
    nnx.display(model)

    # Alternative tried: group_config(model) with its default arguments.
    group_spec = group_config(model, hid_group=S, same=True)

    # Alternatives tried for the learning rate:
    #   optax.cosine_decay_schedule(opt_cfg.lr, opt_cfg.num_epochs)
    #   optax.linear_schedule(opt_cfg.lr, 1e-6, transition_steps=1000)
    #   the constant opt_cfg.lr
    learning_rate = inverse_step_schedule(opt_cfg.lr, transition_begin=0)

    # `num_epochs` and `lr` are consumed here and must not reach the optimizer
    # factory; `pop` without a default raises KeyError if either is missing.
    opt_kwargs = asdict(opt_cfg)
    for consumed in ("num_epochs", "lr"):
        opt_kwargs.pop(consumed)

    opt = opt_fn(group_spec, learning_rate=learning_rate, **opt_kwargs)
    optimizer = nnx.Optimizer(model, opt, wrt=nnx.Param)

    # `record_train_step` is used so each step also yields update metrics.
    return train_loop(
        model,
        train_data,
        val_data,
        optimizer,
        record_train_step,
        eval_loss,
        num_epochs=opt_cfg.num_epochs,
    )


# %%


def opt_run(seed, train_data, val_data, model_config, opt_cfg, opt_fn):
    """Train a fresh MLP with a standard optimizer built by `opt_fn`.

    Args:
        seed: Seed for the model's `nnx.Rngs`.
        train_data: Training data forwarded to `train_loop`.
        val_data: Validation data forwarded to `train_loop`.
        model_config: Dataclass whose fields become `MLP` keyword arguments.
        opt_cfg: Config object; only `lr` and `num_epochs` are read.
        opt_fn: Optimizer factory called as `opt_fn(learning_rate=opt_cfg.lr)`.

    Returns:
        Whatever `train_loop` returns (the plotting cell unpacks it as a
        3-tuple `(losses, vals, _)`).
    """
    # Custom initializers (kernel_init=mlp_kernel_init, bias_init=mlp_bias_init)
    # were tried here and are currently disabled.
    model_kwargs = {**asdict(model_config), "rngs": nnx.Rngs(seed)}
    model = MLP(**model_kwargs)
    nnx.display(model)

    optimizer = nnx.Optimizer(
        model,
        opt_fn(learning_rate=opt_cfg.lr),
        wrt=nnx.Param,
    )

    # NOTE(review): this passes the bare `train_step`, which returns
    # (loss, grads), whereas `symo_run` passes `record_train_step`, which
    # returns (loss, metrics) — confirm `train_loop` handles both.
    return train_loop(
        model,
        train_data,
        val_data,
        optimizer,
        train_step,
        eval_loss,
        num_epochs=opt_cfg.num_epochs,
    )


# %%

# Bind each optimizer factory to the runner that knows how to call it.
# NOTE(review): `symo2_run` takes `symo2` from `symo.optim`, although
# `symo.optim2` is also imported at the top of the file — confirm which module
# is intended.
symo1_run = partial(symo_run, opt_fn=optim.symo)
symo2_run = partial(symo_run, opt_fn=optim.symo2)
adam_run = partial(opt_run, opt_fn=optax.adam)
sgd_run = partial(opt_run, opt_fn=optax.sgd)
muon_run = partial(opt_run, opt_fn=optax.contrib.muon)

# Positional arguments shared by every run: (seed, train_data, val_data, model_config).
common_args = [cfg.seed, train_data, val_data, model_cfg]

# Experiments to execute, as (label, (runner, positional_args)) entries.
# Commented-out entries are disabled; `muon_run` currently has no entry.
opt_configs = (
    # ("Symo 1", (symo1_run, [*common_args, Symo64ConfigExp()])),
    # ("Symo 2.2", (symo2_run, [*common_args, Symo2_64ConfigExp1()])),
    ("Symo 2.3", (symo2_run, [*common_args, Symo2_64ConfigExp2()])),
    ("Adam", (adam_run, [*common_args, Adam64ConfigExp()])),
    ("SGD", (sgd_run, [*common_args, SGD64ConfigExp()])),
)

# %%

# Execute every configured experiment in order and collect (label, result)
# pairs for plotting.
outputs = []
for name, (run_fn, args) in opt_configs:
    print(f">>> Running: {name}")
    out = run_fn(*args)
    outputs.append((name, out))


# %%

# Plot training loss (top) and validation RMSE (bottom) for every run on a
# shared x axis with logarithmic y axes.

# u = int(1e3)
# NOTE(review): `u` is not used anywhere in this cell.
u = -1

fontsize = 10
ax_args = {"linewidth": 0.5, "alpha": 0.5}
fig, axes = plt.subplots(2, 1, sharex=True)

# Each run result is unpacked as (train losses, validation values, ignored).
for name, (losses, vals, _) in outputs:
    for ax, series in zip(axes, (losses, vals)):
        ax.plot(series, label=name, **ax_args)

for ax in axes:
    ax.set_yscale("log")
    # ax.set_xscale("log")

axes[0].set(
    ylabel="Loss (train)",
    title=rf"PINNs {len(model_cfg.hidden_dims)}-layer $\text{{{model_cfg.activation}}}$",
)
axes[0].legend(fontsize=fontsize)

axes[1].set(xlabel="Iteration", ylabel="RMSE (validation)")

fig.show()

# %%
