# %%
from typing import Any
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import jax
import jax.random as jrnd
import jax.numpy as jnp
import jax.scipy as jsci
import symo.notebooks.special_matrix as sm
from symo.notebooks.plot_utils import orange_blue, default_rcparams, plot_matrix

plt.rcParams |= default_rcparams()


def matshow(
    mats: tuple[tuple[str, Any], ...] | dict[str, Any],
    xlabels: dict[str, str] | None = None,
    grid: tuple[int, ...] | None = None,
    cmap=orange_blue().reversed(),
    norm=mcolors.TwoSlopeNorm(vcenter=0),
):
    """Draw each matrix in ``mats`` in its own panel of a single figure.

    Args:
        mats: Either a dict mapping a key to a matrix, or a tuple of
            ``(key, matrix)`` pairs. Every matrix is converted with
            ``np.array`` before it is plotted.
        xlabels: Optional mapping from a key in ``mats`` to the x-axis label
            of that panel. Panels whose key is missing stay unlabelled.
        grid: Optional ``(nrows, ncols)`` layout. When omitted, all panels
            are placed in one row. The number of grid cells must equal
            ``len(mats)``.
        cmap: Colormap forwarded to ``Axes.matshow``.
        norm: Normalisation forwarded to ``Axes.matshow``; the same object
            is handed to every panel.

    Returns:
        ``(fig, axes)`` where ``axes`` is the 2-D array produced by
        ``plt.subplots(..., squeeze=False)``.

    Raises:
        AssertionError: If ``nrows * ncols`` differs from ``len(mats)``.
    """
    # NOTE(review): the default ``norm`` object is created once, when the
    # function is defined, so every call relying on the default shares it.
    # Matplotlib norms presumably keep their limits after the first
    # autoscale -- pass a fresh norm if calls must be scaled independently.
    nrows, ncols = (1, len(mats)) if grid is None else grid

    assert nrows * ncols == len(mats), "grid size must match the number of matrices"

    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)

    items = mats.items() if isinstance(mats, dict) else mats
    # ``axes.flat`` walks the grid row by row, so the i-th matrix lands in the
    # i-th cell for any layout (indexing ``axes[0, i]`` only worked for a
    # single-row grid).
    for ax, (key, mat) in zip(axes.flat, items):
        ax.matshow(np.array(mat), cmap=cmap, norm=norm)
        if xlabels and (key in xlabels):
            ax.set_xlabel(xlabels[key])

    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])

    return fig, axes


# # %% [markdown]
# # $(\Pi \otimes I) \Sigma (\Pi \otimes I)$
# # Solution: $\alpha I + \beta (I \otimes A) + \gamma (\mathbb{1} \otimes B)$

# # %%

# d = 5
# d2 = d**2
# e = torch.eye(d)
# # e = torch.rand(d).diag()
# ep = torch.flatten(e)[:, None]
# P = ep @ ep.T
# P

# # %%

# omat = sm.orthogonal_matrix(d)[0]
# pmat = sm.permutation_matrix(d)[0]
# bmat = sm.signed_permutation_matrix(d)[0]

# omat2 = torch.kron(omat, omat)
# pmat2 = torch.kron(pmat, pmat)
# bmat2 = torch.kron(bmat, bmat)

# pbmat2 = torch.kron(pmat, bmat)

# pbmat2 = torch.kron(pmat, bmat)
# bpmat2 = torch.kron(bmat, pmat)

# # %%

# (P @ pmat2 - pmat2 @ P).max()

# # %%

# (P @ omat2 - omat2 @ P).max()

# # %%

# (P @ bmat2 - bmat2 @ P).max()

# # %%

# (P @ pbmat2 - bpmat2 @ P).max()

# # %%

# fig, ax = plt.subplots()

# ones_d = torch.ones(d)[:, None]
# ones_d_mat = ones_d @ ones_d.T
# eye_d = torch.eye(d)
# eye_d2 = torch.eye(d**2)

# a1 = 0.1
# a2 = 0.2
# a3 = 0.3
# a4 = -0.1

# comm_mat = sm.commutation_matrix(d, d)
# mat1 = (
#     a1 * eye_d2
#     + a2 * torch.kron(ones_d_mat, eye_d)
#     + a3 * torch.kron(eye_d, ones_d_mat)
#     + a4 * P
# )
# # mat = P

# b1 = 0.5
# b2 = 0.6
# b3 = 0.7
# b4 = -0.8

# mat2 = (
#     b1 * eye_d2
#     + b2 * torch.kron(ones_d_mat, eye_d)
#     + b3 * torch.kron(eye_d, ones_d_mat)
#     + b4 * P
# )
# mat2 = mat2 @ comm_mat

# mat = mat1 + mat2

# cmap = "seismic"
# cmap = orange_blue().reversed()
# norm = mcolors.TwoSlopeNorm(
#     vcenter=0,
# )
# im = ax.matshow(mat, cmap=cmap, norm=norm)
# cbar1 = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

# ax.set_yticks([])
# fig.tight_layout()
# fig.show()

# # %%


# fig, ax = plt.subplots()

# ones_d = torch.ones(d)[:, None]
# ones_d_mat = ones_d @ ones_d.T
# eye_d = torch.eye(d)
# eye_d2 = torch.eye(d**2)

# mat = P

# cmap = orange_blue().reversed()
# im = ax.matshow(mat, cmap=cmap, norm=norm)
# cbar1 = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

# ax.set_yticks([])
# fig.tight_layout()
# fig.show()

# # %%

# n, m = 3, 3
# kmat = sm.commutation_matrix(n, m).to_dense()
# dtype = kmat.dtype

# a = torch.arange(1, (n * m) ** 2 + 1, dtype=dtype).reshape(n * m, -1)
# a @ kmat

# # %%

# generator = torch.manual_seed(111)
# n = 4
# m = 3
# pmat = sm.permutation_matrix(n, generator=generator)[0]
# iden = torch.eye(m)

# pmat_kron_1 = torch.kron(pmat, iden)
# pmat_kron_2 = torch.kron(iden, pmat)

# blocks_1 = [(torch.arange(m**2) * float(i + 1) + 1).reshape(m, m) for i in range(n)]
# block_diag_1 = torch.block_diag(*blocks_1)

# blocks_2 = [(torch.arange(n**2) * float(i + 1) + 1).reshape(n, n) for i in range(m)]
# block_diag_2 = torch.block_diag(*blocks_2)

# ones_cols = torch.arange((n * m))[None, :].repeat(n * m, 1)
# ones_rows = torch.arange((n * m))[:, None].repeat(1, n * m)

# out1 = block_diag_1 @ pmat_kron_1
# out2 = block_diag_2 @ pmat_kron_2
# out3 = (block_diag_2 + ones_cols) @ pmat_kron_2
# out4 = (block_diag_2 + ones_rows) @ pmat_kron_2

# # %%

# generator = torch.manual_seed(111)
# n = 3
# m = 3
# k = 4
# pmat = sm.permutation_matrix(m, generator=generator)[0]
# iden = torch.eye(k)

# pmat_kron_1 = torch.kron(pmat, iden)
# pmat_kron_2 = torch.kron(iden, pmat)

# mat = torch.arange(n * m * k).reshape(m * k, n).to(torch.get_default_dtype())

# out = pmat_kron_2 @ mat

# # %%

# kmat = sm.commutation_matrix(3, 5)
# kmat.to_dense()

# # %%

# generator = torch.manual_seed(111)
# n = 3
# m = 4
# l = 5

# B = torch.arange(n * m).to(torch.get_default_dtype()) + 1
# B = B.reshape(n, m)

# ones = torch.ones(l)[None, :]

# kron_prod_1 = torch.kron(B, ones)
# kron_prod_2 = torch.kron(ones, B)

# comm_mat = sm.commutation_matrix(l, m)

# kron_prod_1 - kron_prod_2 @ comm_mat

# # %%

# n, l = 3, 5
# g1 = sm.orthogonal_matrix(n)[0]
# g2 = g1.T
# g3 = torch.eye(l)

# wvec = torch.randn((1, l))
# vvec = torch.randn((1, l))

# sol = torch.kron(torch.eye(n), wvec)
# # sol = torch.kron(torch.eye(n), wvec) + torch.kron(torch.ones((n, n)), vvec)

# sol

# # %%

# (g1 @ sol @ torch.kron(g2, g3))

# # %%

# import matplotlib.pyplot as plt
# import jax.numpy as jnp

# n = 4
# l = 6
# m = 2

# a = jnp.kron(jnp.eye(n), jnp.linspace(-m, m, l)[None, :] * 3)
# b = jnp.kron(jnp.ones((n, n)), jnp.linspace(m, -m, l)[None, :])

# mat = a + b

# fig, ax = plt.subplots(nrows=1)

# cmap = orange_blue().reversed()
# norm = mcolors.TwoSlopeNorm(
#     vcenter=0,
# )

# im = ax.matshow(mat, cmap=cmap, norm=norm)
# cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

# clim = 5
# im.set_clim(-clim, clim)

# fig.show()
# # %%

# import jax.random as jrnd

# seed = 111
# key = jrnd.PRNGKey(seed)

# n = 4
# l = 5

# k1, k2 = jrnd.split(key)
# pi1 = sm.permutation_matrix(k1, n)
# pi2 = sm.permutation_matrix(k2, n)

# a = jnp.linspace(-1, 1, n * l).reshape(n, l)
# w = jnp.linspace(-1, 1, n).reshape((n, 1))
# ov = jnp.ones((l, 1))
# o = jnp.ones((n, 1))

# # %%

# # pi1.T @ ((pi1 @ w @ o.T) * (a)) - ((w @ o.T) * a)

# # pi1.T @ ((pi1 @ jnp.diag(w[:, 0]) @ pi2.T @ o @ ov.T) * (pi2 @ a)) - ((w @ ov.T) * (a))

# pi1.T @ ((pi1 @ jnp.diag(w[:, 0]) @ pi2.T) @ pi2 @ a) - ((w @ ov.T) * (a))

# # %%


# seed = 111
# key = jrnd.PRNGKey(seed)
# keys = jrnd.split(key, 3)

# v = 5
# d = 4
# d1s = [1, 2, 4]
# d2s = [4, 2, 1]

# pairs = []
# out = None
# for i, (d1, d2) in enumerate(zip(d1s, d2s)):
#     k = keys[i]
#     k1, k2 = jrnd.split(k, 2)
#     if d1 != d:
#         a = jrnd.normal(k1, (d1, d))
#     else:
#         a = jnp.eye(d)

#     b = jrnd.normal(k2, (d2, 1)).repeat(v, axis=1)
#     # a = jnp.linspace(-1, 1, d1 * d).reshape(d1, d)
#     # b = jnp.linspace(-1, 1, d2)[:, None].repeat(v, axis=1)

#     pairs.append((a, b))
#     c = jnp.kron(a, b)
#     if out is None:
#         out = c
#     else:
#         out += c

# # %%

# fig, ax = plt.subplots(nrows=1)
# mat = np.array(out)

# cmap = orange_blue().reversed()
# norm = mcolors.TwoSlopeNorm(
#     vcenter=0,
# )

# im = ax.matshow(mat, cmap=cmap, norm=norm)
# cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

# clim = 3
# im.set_clim(-clim, clim)

# fig.show()

# # %%

# seed = 111
# key = jrnd.PRNGKey(seed)
# v = 5
# d = 4
# k1, k2 = jrnd.split(key, 2)

# # a = jrnd.normal(k2, (d2, 1)) @ jnp.ones((1, v))
# # b = jrnd.normal(k1, (d1, d))

# a = jnp.linspace(-1, 1, d2)[:, None] @ jnp.ones((1, v))
# b = jnp.linspace(-1, 1, d1 * d).reshape(d1, d)

# # out = jnp.kron(a, b) @ sm.commutation_matrix(v, d)
# out = sm.commutation_matrix(v, d) @ jnp.kron(b, a)

# # %%

# fig, ax = plt.subplots(nrows=1)
# mat = np.array(out)

# cmap = orange_blue().reversed()
# norm = mcolors.TwoSlopeNorm(
#     vcenter=0,
# )

# im = ax.matshow(mat, cmap=cmap, norm=norm)
# cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

# clim = 1
# im.set_clim(-clim, clim)

# fig.show()

# # %%

# %%

# Check how Sigma = vec(I) vec(I)^T changes under conjugation by G = g1 (x) g2.
seed = 333
n = 5
key = jrnd.PRNGKey(seed)
k1, k2 = jrnd.split(key)
# g1 and g2 are drawn with the same key k1 (presumably on purpose, so that G
# is the Kronecker product of one matrix with itself); k2 is only referenced
# by the commented-out alternatives below.
g1 = sm.orthogonal_matrix(k1, n)
g2 = sm.orthogonal_matrix(k1, n)
# g1 = sm.signed_permutation_matrix(k1, n)
# g2 = sm.signed_permutation_matrix(k1, n)
# g1 = sm.permutation_matrix(k1, n)
# g2 = sm.permutation_matrix(k1, n)

g = jnp.kron(g1, g2)
# g = sm.permutation_matrix(k2, d * n)
# g = sm.signed_permutation_matrix(k2, d * n)
# g = sm.orthogonal_matrix(k2, d * n)

v = jnp.eye(n).flatten("F")
s = jnp.outer(v, v)
o = g @ s @ g.T
# Residual shown in the last panel. Computed as "original minus transformed"
# so that its sign agrees with the panel label and with the later cells.
d = s - o

fig, axes = plt.subplots(ncols=4)
ax1, ax2, ax3, ax4 = axes

g_np = np.array(g)
s_np = np.array(s)
o_np = np.array(o)
d_np = np.array(d)

cmap = orange_blue().reversed()
norm = mcolors.TwoSlopeNorm(
    vcenter=0,
)

im1 = ax1.matshow(g_np, cmap=cmap, norm=norm)
im2 = ax2.matshow(s_np, cmap=cmap, norm=norm)
im3 = ax3.matshow(o_np, cmap=cmap, norm=norm)
im4 = ax4.matshow(d_np, cmap=cmap, norm=norm)

ax1.set_xlabel("$G$")
ax2.set_xlabel(r"$\Sigma$")
ax3.set_xlabel(r"$G \Sigma G^T$")
ax4.set_xlabel(r"$\Sigma - G \Sigma G^T$")

for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])

fig.show()

# %%

# Same experiment as for vec(I) vec(I)^T, but with Sigma multiplied on the
# right by the commutation matrix.
seed = 333
n = 6
key = jrnd.PRNGKey(seed)
key, _ = jrnd.split(key)
# Both factors use the same key (presumably on purpose, so that G is the
# Kronecker product of one matrix with itself).
g1 = sm.orthogonal_matrix(key, n)
g2 = sm.orthogonal_matrix(key, n)

# g1 = sm.signed_permutation_matrix(k1, n)
# g2 = sm.signed_permutation_matrix(k1, n)
# g1 = sm.permutation_matrix(k1, n)
# g2 = sm.permutation_matrix(k1, n)

g = jnp.kron(g1, g2)
# g = sm.permutation_matrix(k2, d * n)
# g = sm.signed_permutation_matrix(k2, d * n)
# g = sm.orthogonal_matrix(k2, d * n)

k = sm.commutation_matrix(n, n)
v = jnp.eye(n).flatten("F")
s = jnp.outer(v, v) @ k
o = g @ s @ g.T
# Residual shown in the last panel. Computed as "original minus transformed"
# so that its sign agrees with the panel label and with the later cells.
d = s - o

fig, axes = plt.subplots(ncols=4)
ax1, ax2, ax3, ax4 = axes

g_np = np.array(g)
s_np = np.array(s)
o_np = np.array(o)
d_np = np.array(d)

cmap = orange_blue().reversed()
norm = mcolors.TwoSlopeNorm(
    vcenter=0,
)

im1 = ax1.matshow(g_np, cmap=cmap, norm=norm)
im2 = ax2.matshow(s_np, cmap=cmap, norm=norm)
im3 = ax3.matshow(o_np, cmap=cmap, norm=norm)
im4 = ax4.matshow(d_np, cmap=cmap, norm=norm)

ax1.set_xlabel("$G$")
ax2.set_xlabel(r"$\Sigma$")
ax3.set_xlabel(r"$G \Sigma G^T$")
ax4.set_xlabel(r"$\Sigma - G \Sigma G^T$")

for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])

fig.show()

# %%

# Conjugate the commutation matrix K by one orthogonal matrix of size n^2
# (not a Kronecker product) and plot the residual K - G K G^T.
n = 5
key, _ = jrnd.split(key, 2)
o = sm.orthogonal_matrix(key, n**2)
k = sm.commutation_matrix(n, n)
v = o @ k @ o.T
d = k - v

fig, axes = plt.subplots(ncols=4)
ax1, ax2, ax3, ax4 = axes

o_np = np.array(o)
k_np = np.array(k)
v_np = np.array(v)
d_np = np.array(d)

cmap = orange_blue().reversed()
norm = mcolors.TwoSlopeNorm(
    vcenter=0,
)

im1 = ax1.matshow(o_np, cmap=cmap, norm=norm)
im2 = ax2.matshow(k_np, cmap=cmap, norm=norm)
im3 = ax3.matshow(v_np, cmap=cmap, norm=norm)
im4 = ax4.matshow(d_np, cmap=cmap, norm=norm)

ax1.set_xlabel("$G$")
ax2.set_xlabel(r"$K^{N,N}$")
ax3.set_xlabel(r"$G K^{N,N} G^T$")
# The plotted residual is k - v, so the label names K rather than Sigma.
ax4.set_xlabel(r"$K^{N,N} - G K^{N,N} G^T$")

for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])

fig.show()


# %%

# Conjugate the commutation matrix K by G = o1 (x) o1 and plot the residual
# K - G K G^T.
n = 5
key, _ = jrnd.split(key, 2)
o1 = sm.orthogonal_matrix(key, n)
o = jnp.kron(o1, o1)
k = sm.commutation_matrix(n, n)
v = o @ k @ o.T
d = k - v

fig, axes = plt.subplots(ncols=4)
ax1, ax2, ax3, ax4 = axes

o_np = np.array(o)
k_np = np.array(k)
v_np = np.array(v)
d_np = np.array(d)

cmap = orange_blue().reversed()
norm = mcolors.TwoSlopeNorm(
    vcenter=0,
)

im1 = ax1.matshow(o_np, cmap=cmap, norm=norm)
im2 = ax2.matshow(k_np, cmap=cmap, norm=norm)
im3 = ax3.matshow(v_np, cmap=cmap, norm=norm)
im4 = ax4.matshow(d_np, cmap=cmap, norm=norm)

ax1.set_xlabel("$G$")
ax2.set_xlabel(r"$K^{N,N}$")
ax3.set_xlabel(r"$G K^{N,N} G^T$")
# The plotted residual is k - v, so the label names K rather than Sigma.
ax4.set_xlabel(r"$K^{N,N} - G K^{N,N} G^T$")

for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])

fig.show()

# %%

# Plot G = b1 (x) o1 (top row), four candidate matrices (middle row) and, for
# each candidate S, the residual S - G S G^T (bottom row).
n = 5
key, _ = jrnd.split(key, 2)
# NOTE(review): o1 and b1 are sampled with the same key -- confirm that the
# resulting dependence between them is acceptable.
o1 = sm.orthogonal_matrix(key, n)
b1 = sm.signed_permutation_matrix(key, n)
o = jnp.kron(b1, o1)
k = sm.commutation_matrix(n, n)
bbI = jnp.outer(jnp.eye(n).reshape(-1), jnp.eye(n).reshape(-1))
eye = jnp.eye(n * n)


def diff(s):
    """Return ``s - o @ s @ o.T`` for the module-level matrix ``o``."""
    return s - o @ s @ o.T


fig, axes = plt.subplots(3, 4)
# Only the matrices built in this cell are converted; the earlier version
# also converted ``v`` and ``d`` left over from a previous cell without ever
# using them.
o_np = np.array(o)
k_np = np.array(k)

cmap = orange_blue().reversed()
norm = mcolors.TwoSlopeNorm(
    vcenter=0,
)

axes[0, 0].matshow(o_np, cmap=cmap, norm=norm)
im1 = axes[1, 0].matshow(eye, cmap=cmap, norm=norm)
im2 = axes[1, 1].matshow(k_np, cmap=cmap, norm=norm)
im3 = axes[1, 2].matshow(bbI, cmap=cmap, norm=norm)
im4 = axes[1, 3].matshow(bbI * k, cmap=cmap, norm=norm)
axes[2, 0].matshow(diff(eye), cmap=cmap, norm=norm)
axes[2, 1].matshow(diff(k), cmap=cmap, norm=norm)
axes[2, 2].matshow(diff(bbI), cmap=cmap, norm=norm)
axes[2, 3].matshow(diff(bbI * k), cmap=cmap, norm=norm)

axes[0, 0].set_xlabel("$G$")
axes[1, 0].set_xlabel(r"$\delta_{ij}\delta_{kl}$")
axes[1, 1].set_xlabel(r"$\delta_{il}\delta_{jk}$")
axes[1, 2].set_xlabel(r"$\delta_{ik}\delta_{jl}$")
axes[1, 3].set_xlabel(r"$\delta_{ik}\delta_{jl}\delta_{il}$")

# Clear the ticks of every panel, not just the first row, as the other grid
# cells of this notebook do.
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])

fig.show()

# %%
# [Pi_1, Pi_2] [Pi_2, I]
# Plot two candidate matrices Sigma_1, Sigma_2 and their residuals
# Sigma - (Pi_1 (x) Pi_2) Sigma (Pi_2 (x) I)^T.
cmap = orange_blue().reversed()
norm = mcolors.TwoSlopeNorm(
    vcenter=0,
)

n, m, k = 3, 5, 2
key, subkey1, subkey2 = jax.random.split(key, 3)
w1 = jax.random.normal(subkey1, (1, n))
w2 = jax.random.normal(subkey2, (1, n))
ones = jnp.ones((k, 1))

fig, axs = plt.subplots(3, 2)
# r1 = (1 (x) w1 (x) I_m) K,  r2 = ones(m, m) (x) 1 (x) w2
r1 = jnp.kron(jnp.kron(ones, w1), jnp.eye(m)) @ sm.commutation_matrix(n, m)
r2 = jnp.kron(jnp.ones((m, m)), jnp.kron(ones, w2))
r = r1 + r2
# NOTE(review): ``key`` is not advanced by this split, so a later split of
# ``key`` with the same count yields the same sub-keys again -- confirm.
rng, subkey1, subkey2 = jax.random.split(key, 3)
s1 = sm.permutation_matrix(subkey1, k)
s2 = sm.permutation_matrix(subkey2, m)

g_left = jnp.kron(s1, s2)
g_right = jnp.kron(s2, jnp.eye(n))

axs[0, 0].imshow(g_left, cmap=cmap, norm=norm)
axs[0, 0].set_xlabel(r"$G=\Pi_1 \otimes \Pi_2$")
axs[0, 1].imshow(g_right, cmap=cmap, norm=norm)
axs[0, 1].set_xlabel(r"$G=\Pi_2 \otimes I$")
# The two formula labels were attached to the wrong panels: r1 is the one
# built with the commutation matrix, r2 the one built from the all-ones block.
axs[1, 0].imshow(r1, cmap=cmap, norm=norm)
axs[1, 0].set_xlabel(r"$\Sigma_1=(1_M \otimes \omega_N^T \otimes I_M)K^{N,M}$")
axs[1, 1].imshow(r2, cmap=cmap, norm=norm)
axs[1, 1].set_xlabel(r"$\Sigma_2=\mathbb{1}_M \otimes 1_M \otimes \omega_N^T$")
axs[2, 0].imshow(r1 - g_left @ r1 @ g_right.T, cmap=cmap, norm=norm)
axs[2, 0].set_xlabel(r"$\Sigma_1 - (\Pi_1 \otimes \Pi_2) \Sigma_1 (\Pi_2 \otimes I)^T$")
axs[2, 1].imshow(r2 - g_left @ r2 @ g_right.T, cmap=cmap, norm=norm)
axs[2, 1].set_xlabel(r"$\Sigma_2 - (\Pi_1 \otimes \Pi_2) \Sigma_2 (\Pi_2 \otimes I)^T$")

for ax in axs.flat:
    ax.set_xticks([])
    ax.set_yticks([])

# %%
# [Pi_1, Pi_2] [I, Pi_2]
# Plot two candidate matrices and their residuals
# Sigma - (Pi_1 (x) Pi_2) Sigma (I (x) Pi_2)^T.
# Relies on ``cmap``, ``norm`` and ``key`` defined by earlier cells.
n, m, k = 3, 5, 2

key, subkey1, subkey2 = jax.random.split(key, 3)
w1 = jax.random.normal(subkey1, (1, n))
w2 = jax.random.normal(subkey2, (1, n))
ones = jnp.ones((k, 1))

fig, axs = plt.subplots(3, 2)
r1 = jnp.kron(jnp.kron(ones, w1), jnp.eye(m))
r2 = jnp.kron(jnp.kron(ones, w2), jnp.ones((m, m)))
r = r1 + r2
rng, subkey1, subkey2 = jax.random.split(key, 3)
s1 = sm.permutation_matrix(subkey1, k)
s2 = sm.permutation_matrix(subkey2, m)

# Group elements acting on the output (left) and input (right) side.
g_left = jnp.kron(s1, s2)
g_right = jnp.kron(jnp.eye(n), s2)

# (axes, image, label) in drawing order: group elements, candidates, residuals.
panels = (
    (axs[0, 0], g_left, r"$G=\Pi_1 \otimes \Pi_2$"),
    (axs[0, 1], g_right, r"$G=I \otimes\Pi_2$"),
    (axs[1, 0], r1, r"$\Sigma_1=1_M \otimes \omega_N^T \otimes I_M $"),
    (axs[1, 1], r2, r"$\Sigma_2=(1_M \otimes \omega_N^T \otimes \mathbb{1}_M)$"),
    (
        axs[2, 0],
        r1 - g_left @ r1 @ g_right.T,
        r"$\Sigma_1 - (\Pi_1 \otimes \Pi_2) \Sigma_1 (I \otimes\Pi_2)^T$",
    ),
    (
        axs[2, 1],
        r2 - g_left @ r2 @ g_right.T,
        r"$\Sigma_2 - (\Pi_1 \otimes \Pi_2) \Sigma_2 (I \otimes\Pi_2)^T$",
    ),
)
for panel_ax, image, label in panels:
    panel_ax.imshow(image, cmap=cmap, norm=norm)
    panel_ax.set_xlabel(label)

for panel_ax in axs.flat:
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])
# %%
# [Pi, Pi] [Pi, I]
# Plot five candidate matrices Sigma_1..Sigma_5 and their residuals
# Sigma - (Pi_1 (x) Pi_1) Sigma (Pi_1 (x) I)^T.
# Relies on ``cmap``, ``norm`` and ``key`` defined by earlier cells.
n, m = 4, 3

key, subkey1, subkey2, subkey3, subkey4, subkey5 = jax.random.split(key, 6)
w1 = jax.random.normal(subkey1, (1, n))
w2 = jax.random.normal(subkey2, (1, n))
w3 = jax.random.normal(subkey3, (1, n))
w4 = jax.random.normal(subkey4, (1, n))
w5 = jax.random.normal(subkey5, (1, n))
ones = jnp.ones((m, 1))

fig, axs = plt.subplots(3, 5)
bbI = jnp.outer(jnp.eye(m).reshape(-1), jnp.eye(m).reshape(-1))

all_ones = jnp.kron(jnp.ones((m, m)), ones)
delta_jk = sm.commutation_matrix(m, m) @ jnp.kron(jnp.eye(m), ones)
delta_ij = jnp.kron(jnp.eye(m), ones)
delta_ik = bbI @ jnp.kron(jnp.eye(m), ones)
r1 = jnp.kron(delta_ij, w1)
r2 = jnp.kron(delta_ik, w2)
r3 = jnp.kron(delta_jk, w3)
r4 = jnp.kron(all_ones, w4)
r5 = jnp.kron(delta_ij * delta_ik, w5)

r = r1 + r2 + r3 + r4 + r5
rng, subkey1, subkey2 = jax.random.split(key, 3)
s1 = sm.permutation_matrix(subkey1, m)

g_left = jnp.kron(s1, s1)
g_right = jnp.kron(s1, jnp.eye(n))

axs[0, 0].imshow(g_left, cmap=cmap, norm=norm)
axs[0, 0].set_xlabel(r"$G=\Pi_1 \otimes \Pi_1$")
axs[0, 1].imshow(g_right, cmap=cmap, norm=norm)
axs[0, 1].set_xlabel(r"$G=\Pi_1 \otimes I$")
axs[0, 2].imshow(r, cmap=cmap, norm=norm)

# Each formula is paired with the matrix it actually describes:
#   r1 = (I (x) 1) (x) w,  r2 = bbI (I (x) 1) (x) w,
#   r3 = K (I (x) 1) (x) w = 1 (x) I (x) w,  r4 = ones(m, m) (x) 1 (x) w.
# Previously the formulas of Sigma_1..Sigma_4 were attached to other panels.
components = (r1, r2, r3, r4, r5)
component_labels = (
    r"$\Sigma_1=I_N \otimes 1_N \otimes \omega_L^T$",
    r"$\Sigma_2=\mathbb{I}^{N,N} (I_N \otimes 1_N \otimes \omega_L^T)$",
    r"$\Sigma_3=(1_N \otimes \omega_L^T \otimes I_N)K^{L,N}$",
    r"$\Sigma_4=\mathbb{1}_N \otimes 1_N \otimes \omega_L^T$",
    r"$\Sigma_5=\delta_{ij}\delta_{ik}\omega_l$",
)
for ax, component, label in zip(axs[1], components, component_labels):
    ax.imshow(component, cmap=cmap, norm=norm)
    ax.set_xlabel(label)

for idx, (ax, component) in enumerate(zip(axs[2], components), start=1):
    ax.imshow(component - g_left @ component @ g_right.T, cmap=cmap, norm=norm)
    ax.set_xlabel(
        rf"$\Sigma_{idx} - (\Pi_1 \otimes \Pi_1) \Sigma_{idx} (\Pi_1 \otimes I)^T$"
    )

for ax in axs.flat:
    ax.set_xticks([])
    ax.set_yticks([])

# %%

# Show the left/right products of b = vec(I) vec(I)^T with the commutation
# matrix and with the all-ones matrix.
n = 4
key, key2, key3, key4 = jrnd.split(key, 4)
k = sm.commutation_matrix(n, n)
ones = jnp.ones((n * n, n * n))
vec_eye = jnp.eye(n).flatten()
b = jnp.outer(vec_eye, vec_eye)

mats = {
    "a": b @ k,
    "b": k @ b,
    "c": b @ ones,
    "d": ones @ b,
}
matshow(mats)

# %%

# Outer products between vec(I) and the flattened all-ones matrix, in both
# orders. Uses ``n`` and ``b`` from the previous cell.
v1 = jnp.eye(n).reshape(-1)
v2 = jnp.ones((n, n)).reshape(-1)

mats = {
    "e": jnp.outer(v1, v2),
    "f": jnp.outer(v2, v1),
}
matshow(mats)

b.shape
# %%

# [Pi, Pi] [Pi, Pi]
# Draw fifteen 0/1 "delta" matrices, each scaled by its own random weight.
# Relies on ``rng``, ``cmap`` and ``norm`` defined by earlier cells.
n = 4

eye = jnp.eye(n * n)
k = sm.commutation_matrix(n, n)
vec_eye = jnp.eye(n).reshape(-1)
bb1 = jnp.outer(vec_eye, vec_eye)
ones = jnp.ones((n * n, n * n))

# Single-delta building blocks.
delta_ij = jnp.kron(jnp.eye(n), jnp.ones((n, n)))
delta_kl = jnp.kron(jnp.ones((n, n)), jnp.eye(n))
delta_il = delta_ij @ k
delta_jk = delta_kl @ k
delta_jl = delta_ij @ bb1
delta_ik = bb1 @ delta_kl

# Element-wise products of the blocks above.
delta_ij_delta_ik = delta_ij * delta_jk
delta_ij_delta_il = delta_ij * delta_il
delta_ik_delta_il = delta_il * delta_kl
delta_jk_delta_jl = delta_jk * delta_kl
delta_ij_delta_jk_delta_jl = delta_ij * delta_jk * delta_jl

# (label, matrix) pairs in plotting order.
labelled_factors = [
    ("1", ones),
    (r"$\delta_{ij}\delta_{kl}$", eye),
    (r"$\delta_{il}\delta_{kj}$", k),
    (r"$\delta_{ik}\delta_{jl}$", bb1),
    (r"$\delta_{ij}$", delta_ij),
    (r"$\delta_{kl}$", delta_kl),
    (r"$\delta_{il}$", delta_il),
    (r"$\delta_{jk}$", delta_jk),
    (r"$\delta_{jl}$", delta_jl),
    (r"$\delta_{ik}$", delta_ik),
    (r"$\delta_{ij}\delta_{ik}$", delta_ij_delta_ik),
    (r"$\delta_{ij}\delta_{il}$", delta_ij_delta_il),
    (r"$\delta_{ik}\delta_{il}$", delta_ik_delta_il),
    (r"$\delta_{jk}\delta_{jl}$", delta_jk_delta_jl),
    (r"$\delta_{ij}\delta_{jk}\delta_{kl}$", delta_ij_delta_jk_delta_jl),
]
name = [label for label, _ in labelled_factors]
factors = [matrix for _, matrix in labelled_factors]

w = jax.random.normal(rng, shape=(15,))


fig, axs = plt.subplots(3, 5)
# ``axs.flat`` is row-major, so position ``which`` equals row * 5 + column.
for which, panel_ax in enumerate(axs.flat):
    panel_ax.imshow(factors[which] * w[which], cmap=cmap, norm=norm)
    panel_ax.set_xlabel(name[which])

for panel_ax in axs.flat:
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])

# %%
# [Pi, Pi] [Pi, Pi]
# Plot five candidate matrices (top row) and, below each, its residual
# Sigma - (Pi_1 (x) Pi_1) Sigma (Pi_1 (x) Pi_1)^T.
# Relies on ``key``, ``cmap`` and ``norm`` defined by earlier cells.
n = 4

rng, subkey1, subkey2 = jax.random.split(key, 3)
s1 = sm.permutation_matrix(subkey1, n)

k = sm.commutation_matrix(n, n)
ones = jnp.ones((n, n))
eyes = jnp.eye(n)
vec_eye = jnp.eye(n).reshape(-1)
bbI = jnp.outer(vec_eye, vec_eye)

# Shared building blocks of the five candidates.
eye_ones = jnp.kron(eyes, ones)
ones_eye = jnp.kron(ones, eyes)
eye_ones_k = eye_ones @ k
ones_eye_k = ones_eye @ k
bbI_ones_eye = bbI @ ones_eye

m1 = eye_ones * eye_ones_k
m2 = ones_eye * ones_eye_k
m3 = eye_ones * bbI_ones_eye
m4 = bbI_ones_eye * eye_ones_k
m5 = eye_ones * eye_ones_k * bbI_ones_eye

candidates = (m1, m2, m3, m4, m5)
candidate_labels = (
    r"$\Sigma_1 = \delta_{ij} \delta_{il}$",
    r"$\Sigma_2 = \delta_{kl} \delta_{jk}$",
    r"$\Sigma_3 = \delta_{ij} \delta_{ik}$",
    r"$\Sigma_4 = \delta_{ik} \delta_{il}$",
    r"$\Sigma_5 = \delta_{ij} \delta_{ik} \delta_{il}$",
)

fig, axs = plt.subplots(2, 5)
for panel_ax, candidate, label in zip(axs[0], candidates, candidate_labels):
    panel_ax.imshow(candidate, cmap=cmap, norm=norm)
    panel_ax.set_xlabel(label)

g_pair = jnp.kron(s1, s1)
for idx, (panel_ax, candidate) in enumerate(zip(axs[1], candidates), start=1):
    panel_ax.imshow(candidate - g_pair @ candidate @ g_pair.T, cmap=cmap, norm=norm)
    panel_ax.set_xlabel(
        rf"$\Sigma_{idx} - (\Pi_1 \otimes \Pi_1) \Sigma_{idx} (\Pi_1 \otimes \Pi_1)^T$"
    )

# %%

# Cell-local project imports; ``B`` and ``factor_from_cov`` are not used in
# the code visible in this cell.
from symo.group import O, B, Eq
from symo.factor import factor_from_cov

# Advance the PRNG key: keep the first of three sub-keys, discard the rest.
key, *_ = jrnd.split(key, 3)

n = 4
m = 4

# NOTE(review): presumably ``O["N", n]`` names an orthogonal group of size n
# and ``Eq[param, param]`` an equivariance spec between two copies of
# ``param`` -- confirm against symo.group.
group = O["N", n]
param = (group, group)
eq = Eq[param, param]


# NOTE(review): no figure is created in this cell, so this shows whichever
# ``fig`` was bound by the most recently executed cell -- confirm intended.
fig.show()

# %%

# Print the summed absolute entries of three pairwise commutators between two
# ``so_matrix`` samples and one ``orthogonal_matrix`` sample (n = 2).
key, key1, key2, key3 = jrnd.split(key, 4)
n = 2
# ``sm.orthogonal_matrix`` is used as a full matrix everywhere else in this
# file; indexing its result with ``[0]`` kept only the first row, which
# turned the "commutators" below into matrix-vector expressions.
o1 = sm.orthogonal_matrix(key1, n)
so1 = sm.so_matrix(key2, n)
so2 = sm.so_matrix(key3, n)

a = float(jnp.sum(jnp.abs((so1 @ so2 - so2 @ so1))))
b = float(jnp.sum(jnp.abs((so1 @ o1 - o1 @ so1))))
c = float(jnp.sum(jnp.abs((so2 @ o1 - o1 @ so2))))

print(f"{a=:0.3e}, {b=:0.3e}, {c=:0.3e}")

# %%

# Apply a block-diagonal matrix of 2x2 ``so_matrix`` samples to a sinusoidal
# positional encoding and display the absolute change.
from symo.notebooks.attention import PositionalEncoding

n = 6
s = 5
pe = PositionalEncoding(n, s, encoding="sin").pe[0].T

# Give every 2x2 block its own sub-key. Reusing one key for all blocks (and a
# key already consumed by an earlier cell) makes every block the same draw.
key, *block_keys = jrnd.split(key, n // 2 + 1)
so = [sm.so_matrix(block_key, 2) for block_key in block_keys]
so_blk = jsci.linalg.block_diag(*so)

jnp.abs(so_blk @ pe - pe)

# %%

# Dimensions for the next (not yet written) experiment.
n, m = 5, 3

