# %%
from decimal import Decimal
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

from functools import partial
from flax import nnx

from symo.notebooks.attention import TransformerModel, Encoding
from symo.notebooks.plot_utils import default_rcparams, orange_blue
from symo.model_factor import FactorTable
from symo.group import I, S, O, B, Eq

import optax
import numpy as np

import matplotlib.pyplot as plt

# Merge the project's default matplotlib settings into the active rcParams.
plt.rcParams |= default_rcparams()
# Short aliases for the JAX pytree path-key types (only DictKey is used in
# the visible part of this file, in param_dict).
DictKey = jax.tree_util.DictKey
SequenceKey = jax.tree_util.SequenceKey
GetAttrKey = jax.tree_util.GetAttrKey


# %%

# Experiment hyper-parameters.
# `epochs` is only referenced by the commented-out training loop further down.
epochs = 10
seq_len = 5
# NOTE(review): max_seq_len is not part of transformer_args below, so
# configure_transformer falls back to its own default (50) — confirm intended.
max_seq_len = max(15, seq_len)
vocab_size = 10
embed_dim = 4
batch_size = 100
ff_dim = 10
out_dim = 1
# Number of independently initialised models that are built with nnx.vmap.
num_models = 5000
learning_rate = 0.5
# Diagonal regulariser added before every Cholesky factorisation below.
jitter = 1e-4
encoding: Encoding = "lin"
# encoding: Encoding = "sin"
# encoding: Encoding = "lin"

# Keyword arguments forwarded to configure_transformer for every model.
# Note that out_dim is what gets passed as num_heads.
transformer_args = dict(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_heads=out_dim,
    ff_dim=ff_dim,
    num_layers=1,
    use_bias=False,
    dropout=0.0,
    encoding=encoding,
    regression=False,
)


# %%

# `device` is not referenced anywhere else in the visible part of this file.
device = "cpu"
seed = 2001
# seed = 2003
# Root PRNG key, split once into a key for model init and a key for data.
rnd_key = jax.random.PRNGKey(seed)
model_key, data_key = jax.random.split(rnd_key, 2)


# %%

def configure_transformer(
    key: jax.Array,
    lr: float = 1e-1,
    max_seq_len: int = 50,
    vocab_size: int = 7,
    embed_dim: int = 6,
    num_heads: int = 1,
    ff_dim: int = 10,
    num_layers: int = 1,
    use_bias: bool = False,
    dropout: float = 0.0,
    encoding: Encoding = "lin",
    *,
    regression: bool = False,
) -> tuple[TransformerModel, nnx.Optimizer]:
    """Build one ``TransformerModel`` together with a plain-SGD optimizer.

    Args:
        key: PRNG key wrapped in ``nnx.Rngs`` and handed to the model.
        lr: Learning rate given to ``optax.sgd``.
        max_seq_len, vocab_size, embed_dim, num_heads, ff_dim, num_layers,
            dropout, encoding, regression: Forwarded unchanged to
            ``TransformerModel`` under the same keyword names.
        use_bias: Forwarded to ``TransformerModel`` as its ``bias`` argument.

    Returns:
        The ``(model, optimizer)`` pair; the optimizer wraps the model.
    """
    model_kwargs = dict(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        num_heads=num_heads,
        ff_dim=ff_dim,
        num_layers=num_layers,
        dropout=dropout,
        max_seq_len=max_seq_len,
        encoding=encoding,
        bias=use_bias,  # the model calls this option `bias`
        regression=regression,
    )
    model = TransformerModel(**model_kwargs, rngs=nnx.Rngs(key))
    return model, nnx.Optimizer(model, optax.sgd(lr))


# %%

# One PRNG key per model, i.e. num_models separate initialisations.
batch_model_key = jax.random.split(model_key, num_models)
# Bind the shared hyper-parameters so only the key is left to map over.
configure_transformers = partial(
    configure_transformer, lr=learning_rate, **transformer_args
)

# Construct all models and optimizers at once by mapping over the key batch.
batch_models, batch_optimizers = nnx.vmap(configure_transformers)(batch_model_key)


# %%
def param_dict(params):
    """Flatten a parameter pytree into a ``{dotted_name: array}`` mapping.

    The name of a leaf is built by joining, with ``"."``, the ``key`` of every
    ``DictKey`` entry on its tree path; path entries of other kinds are
    ignored. Each leaf is passed through ``jnp.atleast_2d``. If two leaves end
    up with the same name, the later one replaces the earlier one.
    """
    leaves_with_path, _treedef = jax.tree_util.tree_flatten_with_path(params)

    named = {}
    for path, leaf in leaves_with_path:
        dict_parts = [str(entry.key) for entry in path if isinstance(entry, DictKey)]
        named[".".join(dict_parts)] = jnp.atleast_2d(leaf)
    return named


def get_params(model):
    """Return the state of ``model`` filtered by ``nnx.Param`` (via ``nnx.state``)."""
    param_state = nnx.state(model, nnx.Param)
    return param_state


def param_slices(params, param_order, batched: bool = False):
    """Map each parameter name to its slice in the concatenated flat vector.

    Args:
        params: Mapping from parameter name to array.
        param_order: Names in the order they are concatenated.
        batched: If true, the size of ``params[name][0]`` is used instead of
            the size of ``params[name]`` (i.e. the leading axis is skipped).

    Returns:
        Dict ``{name: slice(begin, end)}`` with consecutive, non-overlapping
        ranges following ``param_order``.
    """
    sizes = []
    for name in param_order:
        entry = params[name][0] if batched else params[name]
        sizes.append(entry.size)

    stops = np.cumsum(np.array(sizes))
    # Every range starts where the previous one stopped; the first starts at 0.
    starts = np.array([0, *stops[:-1]])
    return {
        name: slice(lo, hi) for name, lo, hi in zip(param_order, starts, stops)
    }


def param_count(model, batched: bool = False):
    """Return ``{parameter_name: element_count}`` for ``model``.

    Names and arrays come from ``param_dict(get_params(model))``. With
    ``batched=True`` the count is taken from the first entry along the
    leading axis of each array instead of from the whole array.
    """
    named = param_dict(get_params(model))
    if batched:
        return {name: arr[0].size for name, arr in named.items()}
    return {name: arr.size for name, arr in named.items()}


def vecop(tensor: jax.Array):
    """Return ``tensor`` as a 1-D vector.

    A 0-d input becomes a length-1 vector; any other input is flattened.
    """
    if tensor.ndim == 0:
        return tensor.reshape(1)
    return tensor.flatten()


# %%

# Named parameter arrays of the whole model batch.
batch_model_params = param_dict(get_params(batch_models))
# Apply vecop along the leading axis of every parameter array.
param_vecs = jax.tree.map(lambda x: jax.vmap(vecop)(x), batch_model_params)
# Bare expression: shows the value when the cell is run interactively.
param_vecs
# %%
# (parameter-path template, LaTeX label template) pairs.
# Both strings are later filled with str.format(i); "{0}" is that placeholder
# and doubled braces "{{ }}" are literal braces for LaTeX.
param_candidates = [
    (
        "embedding.embedding",
        r"$W^{0}_e$",
    ),
    (
        "transformer_blocks.{0}.attention.query.kernel",
        r"$W^{0}_q$",
    ),
    (
        "transformer_blocks.{0}.attention.key.kernel",
        r"$W^{0}_k$",
    ),
    (
        "transformer_blocks.{0}.attention.value.kernel",
        r"$W^{0}_v$",
    ),
    (
        "transformer_blocks.{0}.attention.out_proj.kernel",
        r"$P^{0}_a$",
    ),
    (
        "transformer_blocks.{0}.norm1.scale",
        r"$\ell^{0}_{{w,1}}$",
    ),
    (
        "transformer_blocks.{0}.norm1.bias",
        r"$\ell^{0}_{{b,1}}$",
    ),
    (
        "transformer_blocks.{0}.feed_forward.linear1.kernel",
        r"$W^{0}_{{\text{{ff}}_1}}$",
    ),
    (
        "transformer_blocks.{0}.feed_forward.linear2.kernel",
        r"$W^{0}_{{\text{{ff}}_2}}$",
    ),
    (
        "transformer_blocks.{0}.norm2.scale",
        r"$\ell^{0}_{{w,2}}$",
    ),
    (
        "transformer_blocks.{0}.norm2.bias",
        r"$\ell^{0}_{{b,2}}$",
    ),
    (
        "output_proj.kernel",
        r"$P^{0}_o$",
    ),
]

# %%

# Fill the templates with the layer index. range(1) hard-codes a single
# layer, matching num_layers=1 in transformer_args.
param_aliases = [
    (p.format(i), n.format(i)) for p, n in param_candidates for i in range(1)
]

# Parameter names only, in plotting/concatenation order.
param_alias_keys = [k for k, _ in param_aliases]

# Per-model element count of every parameter.
counts = param_count(batch_models, batched=True)
# All parameters of each model concatenated along the last axis, in alias order.
flat_params = jnp.concat([param_vecs[k] for k, _ in param_aliases], axis=-1)

# %%

print(counts)
#
# %%


def loss_fn(model, batch):
    """Mean softmax cross-entropy between the model's logits and integer targets.

    Args:
        model: Callable invoked as ``model(inputs, deterministic=False)``.
        batch: ``(inputs, targets)`` pair.

    Returns:
        The mean of the per-element cross-entropy. Logits are reshaped to
        ``(-1, vocab_size)`` using the module-level ``vocab_size`` and the
        targets are flattened to match.
    """
    tokens, labels = batch
    per_token = optax.softmax_cross_entropy_with_integer_labels(
        model(tokens, deterministic=False).reshape(-1, vocab_size),
        labels.reshape(-1),
    )
    return per_token.mean()


@nnx.jit
def train_step(
    model: TransformerModel,
    optimizer: nnx.Optimizer,
    data: tuple[jax.Array, jax.Array],
):
    """Run one optimisation step and return the loss computed before the update.

    The loss and its gradients come from ``nnx.value_and_grad(loss_fn)``; the
    gradients are then handed to ``optimizer.update``.
    """
    value_and_grad_fn = nnx.value_and_grad(loss_fn)
    loss, grads = value_and_grad_fn(model, data)
    optimizer.update(grads)
    return loss


# Map train_step over axis 0 of the models and optimizers; in_axes=None for
# the third argument means every model receives the same data.
batch_train_step_fn = nnx.vmap(train_step, in_axes=(0, 0, None))
# %%


def generate_toy_language_data(
    key: jax.Array,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
):
    """Draw random integer token inputs and, independently, random targets.

    Args:
        key: PRNG key; it is split in two, one sub-key per returned array.
        batch_size: Number of sequences.
        seq_len: Tokens per sequence.
        vocab_size: Exclusive upper bound of the sampled token ids.

    Returns:
        ``(inputs, targets)``, both of shape ``(batch_size, seq_len)`` with
        values drawn from ``[0, vocab_size)``.
    """
    input_key, target_key = jax.random.split(key, 2)
    shape = (batch_size, seq_len)

    def _draw(subkey):
        return jax.random.randint(subkey, shape, minval=0, maxval=vocab_size)

    return _draw(input_key), _draw(target_key)


# A single fixed batch of random tokens; it is shared by all models.
inputs, targets = generate_toy_language_data(
    data_key,
    batch_size,
    seq_len,
    vocab_size,
)

data = (inputs, targets)

# %%
# losses = []
# for epoch in range(epochs):
#    loss = batch_train_step_fn(batch_models, batch_optimizers, data)
#    losses.append(np.asarray(loss))
#    print(f"Epoch {epoch} loss: {losses[epoch].mean()}")


# %%
def compute_grads(model, data):
    """Return the gradients of ``loss_fn(model, data)`` as given by ``nnx.grad``."""
    return nnx.grad(loss_fn)(model, data)


def flatten(tensor: jax.Array):
    """Return ``tensor`` collapsed into a single 1-D vector."""
    return jnp.ravel(tensor)


# Gradients for every model on the same shared batch (data is not mapped).
batch_compute_grads_fn = nnx.vmap(compute_grads, in_axes=(0, None))
batch_grads = batch_compute_grads_fn(batch_models, data)

# %%

# Same naming scheme as for the parameters, applied to the gradients.
batch_grads_dict = param_dict(batch_grads)
# Flatten each gradient along everything but the leading axis.
grad_vecs = jax.tree.map(lambda x: jax.vmap(flatten)(x), batch_grads_dict)
# Flat per-model gradients, concatenated in param_aliases order.
flat_grads = jnp.concat([grad_vecs[name] for name, _ in param_aliases], axis=-1)
# Un-flattened gradients in the same order; these are fed to FactorTable below.
batch_grads_list = [batch_grads_dict[name] for name, _ in param_aliases]

# %%
flat_grads.shape

# %%


@jax.jit
def covariance(x):
    """Estimate the covariance of the rows of ``x``.

    Args:
        x: 2-D array whose rows are treated as samples.

    Returns:
        ``(cov, second_moment, mean_outer)`` where ``second_moment`` is the
        average of the per-row outer products, ``mean_outer`` is the outer
        product of the row mean with itself, and
        ``cov = second_moment - mean_outer``.
    """
    mu = x.mean(axis=0)
    second_moment = jax.vmap(jnp.outer)(x, x).mean(axis=0)
    mu_outer = mu[:, None] @ mu[None, :]
    return second_moment - mu_outer, second_moment, mu_outer


def diag_scale(mat):
    """Scale ``mat`` symmetrically by the inverse square root of its diagonal.

    Entry ``(i, j)`` is multiplied by ``1 / sqrt(mat[i, i] * mat[j, j])``, so
    a matrix with a positive diagonal ends up with ones on the diagonal.
    """
    inv_sqrt_diag = 1 / np.sqrt(mat.diagonal())
    column_scaled = inv_sqrt_diag[None, :] * mat
    return column_scaled * inv_sqrt_diag[:, None]


# %%

# covariance() returns (centered covariance, uncentered second moment,
# outer product of the mean). The middle entry is the one kept as grad_cov.
# NOTE(review): confirm the uncentered second moment, not the centered
# covariance, is the intended estimate.
_, grad_cov, _ = jnp.array(covariance(flat_grads))

# Same matrix rescaled by its diagonal.
cov_est_scaled = diag_scale(grad_cov)

# %%

# v: vocabulary size
# d: embed dim
# h: hidden dim in ff
# o: output dim

# NOTE(review): the trailing comments on Gx and Ge name the opposite group
# letter from the one the code uses (I vs "S_v_1", S vs "I_d_2") — confirm
# which is intended.

# input group
Gx = I["vocab-size", vocab_size]  # "S_v_1"

# Embedding group
Ge = S["embed-dim", embed_dim]  # "I_d_2"

# Attention
Gq = O["embed-dim-qk", embed_dim]
# Keys share the very same group object as queries.
Gk = Gq
Gv = O["embed-dim-v", embed_dim]

# Feedforward
Gf = S["ff-dim", ff_dim]

# One entry per parameter, presumably in param_candidates order (12 entries
# each) — verify against FactorTable. Rows 2 and 3 list Gk before Gq while
# param_candidates lists query before key; harmless here because Gk is Gq.
groups_layers = [
    [Gx, Ge],
    [Ge, Gk],
    [Ge, Gq],
    [Ge, Gv],
    [Gv, Ge],
    [Ge],
    [Ge],
    [Ge, Gf],
    [Gf, Ge],
    [Ge],
    [Ge],
    [Ge, Gx],
]

# Not referenced again in the visible part of this file.
factors_table = []

# %%

# NOTE(review): rnd_key has not been rebound since it was split into
# (model_key, data_key), so this split yields the same two keys again and
# `subkey` equals `data_key` — confirm the key reuse is acceptable.
rnd_key, subkey = jax.random.split(rnd_key)
factors = FactorTable(groups_layers, subkey)

# factors.from_cov(cov_est_scaled)
factors.from_grads(batch_grads_list)

# Compare the factorised matrix-vector product with the dense one, applied
# to the gradient averaged over models.
avg_grads = [a.mean(axis=0) for a in batch_grads_list]
mvp_out = factors.matvec(avg_grads)
mvp_out = jnp.concat([grad.flatten() for grad in mvp_out])

# NOTE(review): this rebinds the module-level name `covariance`, shadowing
# the jitted covariance() function defined earlier; re-running the cell that
# calls covariance(flat_grads) after this one will fail.
covariance = factors.to_cov()
full_mvp = covariance @ jnp.concat([grad.flatten() for grad in avg_grads])

cmap = orange_blue().reversed()
norm = mcolors.TwoSlopeNorm(
    vcenter=0,
)

fig, axs = plt.subplots(1, 2)
# NOTE(review): (12, 20) is hard-coded and only fits a 240-element vector;
# it must be updated whenever the model sizes above change.
axs[0].matshow(mvp_out.reshape(12, 20), cmap=cmap, norm=norm, aspect="equal")
axs[1].matshow(full_mvp.reshape(12, 20), cmap=cmap, norm=norm, aspect="equal")

# Relative error of the factorised product with respect to the dense one.
diff = jnp.linalg.norm(mvp_out - full_mvp) / jnp.linalg.norm(full_mvp)
print("normalized distance:", diff)

# %%

# Direct inversion of the factor-table covariance via Cholesky.
iden = jnp.eye(covariance.shape[0])

# jitter * I is added to the diagonal before factorising.
L_cov = jnp.linalg.cholesky(
    covariance + iden * jitter,
)
# First solve L X = I, then L^T Y = X, so Y = (L L^T)^{-1}.
inv_cov = jax.lax.linalg.triangular_solve(L_cov, iden, left_side=True, lower=True)
inv_cov_theory = jax.lax.linalg.triangular_solve(
    L_cov.T, inv_cov, left_side=True, lower=False
)

# %%
# Same Cholesky-based inversion, applied to the empirical matrix grad_cov.
# Note: iden, L_cov and inv_cov from the previous cell are overwritten here.
iden = jnp.eye(grad_cov.shape[0])

L_cov = jnp.linalg.cholesky(
    grad_cov + iden * jitter,
)
inv_cov = jax.lax.linalg.triangular_solve(L_cov, iden, left_side=True, lower=True)
inv_cov_emp = jax.lax.linalg.triangular_solve(
    L_cov.T, inv_cov, left_side=True, lower=False
)

# %%
# Inversion through the surrogate matrix produced by the factor table.
surrogate = factors.surrogate()

# Round-trip check: loading the surrogate back into the factor table should
# reproduce the same covariance; the printed value is the max abs deviation.
factors.from_surrogate(surrogate)
refactor_covariance = factors.to_cov()
print(jnp.abs(refactor_covariance - covariance).max())

# %%
# Invert the (jittered) surrogate with the same two-step Cholesky solve.
iden = jnp.eye(surrogate.shape[0])

L = jnp.linalg.cholesky(
    surrogate + iden * jitter,
)
inv_surr = jax.lax.linalg.triangular_solve(L, iden, left_side=True, lower=True)
inv_surr = jax.lax.linalg.triangular_solve(L.T, inv_surr, left_side=True, lower=False)

# Load the inverted surrogate into the factor table and expand it to a dense
# matrix. This mutates `factors`: it no longer holds the gradient factors.
factors.from_surrogate(inv_surr)
inv_cov_surr = factors.to_cov()


# %%


def matshow_named_axes(
    ax,
    param_names: list[tuple[str, str, int, int]],
    shift: float = -10,
    width: float = 0.1,
):
    """Draw block separators and block labels on a matrix plot.

    Args:
        ax: Axes to draw on.
        param_names: ``(name, alias, size, end)`` per parameter block, where
            ``end`` is the cumulative end index and ``size`` the block length.
        shift: Coordinate at which the labels are placed (x for the labels on
            the left, y for the labels on top).
        width: Line width of the separator lines.
    """
    rule_style = dict(color="gray", linestyle="-", linewidth=width)
    label_style = dict(
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )

    for _name, alias, size, end in param_names:
        start = end - size
        centre = (start + end) / 2

        # Separators sit half a cell before the first index of the block.
        ax.axhline(y=start - 0.5, **rule_style)
        ax.axvline(x=start - 0.5, **rule_style)

        # One label beside the rows of the block, one above its columns.
        ax.text(shift, centre, alias, **label_style)
        ax.text(centre, shift, alias, **label_style)


def matshow_norm_diff(
    ax,
    param_names: list[tuple[str, str, int, int]],
    mat,
    true_inv,
):
    """Annotate every parameter block with its relative error.

    For each pair of blocks the value
    ``norm(mat_block - true_block) / norm(true_block)`` is written at the
    centre of that block, formatted as ``%.0E``. Blocks whose ratio is NaN
    (e.g. 0 / 0) are left without a label.

    Args:
        ax: Axes showing the matrix.
        param_names: ``(name, alias, size, end)`` per parameter block, where
            ``end`` is the cumulative end index and ``size`` the block length.
        mat: Matrix being displayed.
        true_inv: Reference matrix of the same shape.
    """
    # (start, end, centre index) of every block, computed once.
    spans = [
        (end - size, end, end - size + size // 2) for _, _, size, end in param_names
    ]

    for row_start, row_end, row_mid in spans:
        for col_start, col_end, col_mid in spans:
            pred = mat[row_start:row_end, col_start:col_end]
            label = true_inv[row_start:row_end, col_start:col_end]

            rel_err = jnp.linalg.norm(pred - label) / jnp.linalg.norm(label)
            if jnp.isnan(rel_err):
                continue

            text = "%.0E" % Decimal(str(rel_err))
            # Axes.text takes (x, y): x runs along columns, y along rows.
            # Passing (row, column) would label the transposed block.
            ax.text(col_mid, row_mid, text, ha="center", va="center")


# Labels and per-model element counts in param_aliases order.
# (`param_alias` is not referenced again in the visible part of this file.)
param_alias, param_counts = zip(
    *[(alias, counts[name]) for (name, alias) in param_aliases]
)
# Cumulative end index of each block in the flat parameter vector.
param_cum_counts = np.cumsum(param_counts)
# (name, alias, size, end) tuples consumed by the matshow_* helpers.
names_alias_counts = [
    (name, alias, param_counts[i], param_cum_counts[i])
    for i, (name, alias) in enumerate(param_aliases)
]

# %%


def plot_matrix(
    ax,
    mat,
    subtitle,
    encoding: Encoding,
    show_axes=True,
    true_inv=None,
    show_bar: bool = True,
    clim_max: float | None = None,
    clim_min: float | None = None,
):
    """Show ``mat`` as a diverging-colour image on ``ax``.

    Args:
        ax: Axes to draw on.
        mat: 2-D matrix to display.
        subtitle: Title prefix; ``", Encoding = <encoding>"`` is appended.
        encoding: Encoding name shown in the title.
        show_axes: Draw the parameter-block separators and labels from the
            module-level ``names_alias_counts``.
        true_inv: Optional reference matrix; if given, every block is
            annotated with its relative error with respect to it.
        show_bar: Attach a colorbar to ``ax``.
        clim_max: Upper colour limit; defaults to ``np.max(mat)``.
        clim_min: Lower colour limit; defaults to ``np.min(mat)``.
    """
    cmap = orange_blue().reversed()
    # Colour scale centred on zero.
    norm = mcolors.TwoSlopeNorm(
        vcenter=0,
    )

    im = ax.matshow(mat, cmap=cmap, norm=norm, aspect="equal")
    if show_bar:
        # Use the figure that owns `ax` rather than the module-level `fig`,
        # which is a different figure when the caller created its own
        # (as plot_mat_pair does).
        ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    if show_axes:
        matshow_named_axes(ax, names_alias_counts)
    if true_inv is not None:
        matshow_norm_diff(ax, names_alias_counts, mat, true_inv)

    ax.set_title(
        subtitle + rf", Encoding = $\text{{{encoding}}}$",
        y=1.05,
    )
    ax.set_xticks([])
    ax.set_yticks([])

    if clim_min is None:
        clim_min = np.min(mat)

    if clim_max is None:
        clim_max = np.max(mat)

    im.set_clim(clim_min, clim_max)


# %%

# Empirical second-moment matrix next to the factor-table covariance.
fig, ax = plt.subplots(1, 2)

plot_matrix(ax[0], grad_cov, "Empirical", encoding=encoding)
plot_matrix(ax[1], covariance, "Theory", encoding=encoding)
fig.tight_layout()
fig.show()

# %%

# Factor-table covariance next to its surrogate. Block labels are switched
# off for the surrogate panel.
fig, ax = plt.subplots(1, 2)

plot_matrix(ax[0], covariance, "Theory", encoding=encoding)
plot_matrix(ax[1], surrogate, "Surrogate", show_axes=False, encoding=encoding)

fig.tight_layout()
fig.show()

# %%

# %%

# Diagonally rescaled empirical matrix next to the rescaled factor-table
# covariance.
fig, ax = plt.subplots(1, 2)


cov_theory_scaled = diag_scale(covariance)

plot_matrix(ax[0], cov_est_scaled, "Empirical", encoding=encoding)
# The right panel shows the factor-table covariance, which every other figure
# in this file titles "Theory"; it was mislabelled "Surrogate" here.
plot_matrix(ax[1], cov_theory_scaled, "Theory", show_axes=False, encoding=encoding)

fig.tight_layout()
fig.show()

# %%


# Factor-table covariance on its own, with an empty subtitle.
fig, ax = plt.subplots(1, 1)

plot_matrix(ax, covariance, "", encoding=encoding)

fig.tight_layout()
fig.show()

# %%

# Factor-table covariance annotated with its per-block relative error with
# respect to the empirical matrix.
# NOTE(review): the title says "Surrogate inv" although neither matrix is an
# inverse or a surrogate — looks copied from the next cell; confirm.
fig, ax = plt.subplots(1, 1)
plot_matrix(
    ax, covariance, "Surrogate inv", true_inv=grad_cov, encoding=encoding
)

# %%


# Surrogate-based inverse annotated with its per-block relative error with
# respect to the direct Cholesky inverse of the factor-table covariance.
fig, ax = plt.subplots(1, 1)
plot_matrix(
    ax, inv_cov_surr, "Surrogate inv", true_inv=inv_cov_theory, encoding=encoding
)

# %%

# 2x3 overview: top row the three matrices, bottom row their inverses.
figsize = (8, 5)
fig, ax = plt.subplots(2, 3, figsize=figsize)

# Unscaled alternative, kept for switching back:
# mat1 = grad_cov
# mat2 = covariance
# mat3 = surrogate
# mat4 = inv_cov_emp
# mat5 = inv_cov_theory
# mat6 = inv_cov_surr

# All six panels are rescaled by their own diagonal before plotting.
mat1 = diag_scale(grad_cov)
mat2 = diag_scale(covariance)
mat3 = diag_scale(surrogate)
mat4 = diag_scale(inv_cov_emp)
mat5 = diag_scale(inv_cov_theory)
mat6 = diag_scale(inv_cov_surr)

plot_matrix(ax[0, 0], mat1, "Empirical", encoding=encoding)
plot_matrix(ax[0, 1], mat2, "Theory", encoding=encoding)
plot_matrix(ax[0, 2], mat3, "Surrogate", show_axes=False, encoding=encoding)
plot_matrix(ax[1, 0], mat4, "Direct inv (Empirical)", encoding=encoding)
plot_matrix(ax[1, 1], mat5, "Direct inv (Theory)", encoding=encoding)
plot_matrix(ax[1, 2], mat6, "Surrogate inv", encoding=encoding)

fig.tight_layout()
# fig.subplots_adjust(hspace=0.0)

fig.show()

# %%

# Unscaled alternative:
# mat_left = np.array(grad_cov)
# mat_right = np.array(covariance)

mat_left = np.array(cov_est_scaled)
mat_right = np.array(cov_theory_scaled)

# Index range of every parameter inside the flat per-model vector.
slices = param_slices(batch_model_params, param_alias_keys, batched=True)

# Collect (title, left block, right block) for every parameter pair in the
# lower triangle (diagonal included), ordered row by row.
blocks = []
for i, (row_name, row_alias) in enumerate(param_aliases):
    for col_name, col_alias in param_aliases[: i + 1]:
        row_slice = slices[row_name]
        col_slice = slices[col_name]

        slice_left = mat_left[row_slice, col_slice]
        slice_right = mat_right[row_slice, col_slice]

        title = f"{row_alias} - {col_alias}"
        blocks.append((title, slice_left, slice_right))


def plot_mat_pair(
    blocks,
    block_index: int,
    encoding: Encoding,
    title_left: str = "",
    title_right: str = "",
    sync_clip: bool = True,
):
    """Plot the left and right matrix of one block side by side.

    Args:
        blocks: Sequence of ``(name, left, right)`` tuples.
        block_index: Index of the block to plot.
        encoding: Encoding name forwarded to ``plot_matrix`` for the title.
        title_left: Title prefix of the left panel.
        title_right: Title prefix of the right panel.
        sync_clip: If true, both panels use the colour limits of the right
            matrix (its min and max); otherwise each panel uses its own range.

    Returns:
        The created ``(fig, axes)``.
    """
    name, left, right = blocks[block_index]
    fig, axes = plt.subplots(1, 2)

    if sync_clip:
        clim_min = right.min()
        # Was right.min(), which collapsed the colour range to a single value.
        clim_max = right.max()
    else:
        clim_min, clim_max = None, None

    panels = ((axes[0], left, title_left), (axes[1], right, title_right))
    for axis, matrix, title in panels:
        plot_matrix(
            axis,
            matrix,
            f"{title}, {name}",
            show_axes=False,
            encoding=encoding,
            clim_max=clim_max,
            clim_min=clim_min,
        )

    return fig, axes


# %%

title_left = "Empirical"
title_right = "Theory"
# Only used by the commented-out stepping cells below.
block_index = 0

# %%
# Index into `blocks` (lower-triangular, row-by-row order) of the pair to plot.
blk = 11
print(f"Plot block {blk}")

# %%


# sync_clip=False: each panel keeps its own colour range.
_ = plot_mat_pair(
    blocks,
    blk,
    encoding,
    title_left=title_left,
    title_right=title_right,
    sync_clip=False,
)


# %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1


# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1


# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1

# # %%

# _ = plot_mat_pair(
#     blocks, block_index, encoding, title_left=title_left, title_right=title_right
# )
# print(f"{block_index=}")
# block_index += 1
