# %%

from dataclasses import dataclass, asdict
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as jrnd
import symo.special_matrix as sm

from symo.notebooks.plot_utils import default_rcparams
import matplotlib.pyplot as plt

# Merge the project's default plotting style into matplotlib's global rcParams.
# Use ``update`` rather than ``|=``: in current matplotlib ``RcParams`` is a
# dict subclass and ``MutableMapping`` defines no ``__ior__``, so ``|=`` falls
# back to ``dict.__ior__``, which stores entries directly and skips RcParams'
# per-key validation. ``update`` assigns each key via ``__setitem__`` instead.
plt.rcParams.update(default_rcparams(dpi=300))

# %%
# Setup cell: sizes and structured matrices shared by every later cell.
# This must be a plain code cell: with a ``# %% [markdown]`` marker, Jupytext
# and VS Code treat the lines below as markdown text, so none of these names
# would be defined and later cells would raise NameError.

n = 4
n2 = n**2

# n x n building blocks. NOTE(review): sm.matones presumably returns the
# all-ones matrix of the given size — confirm in symo.special_matrix.
eye = jnp.eye(n)
matones = sm.matones(n)

# n^2-sized counterparts (eye2 is the n^2 x n^2 identity).
eye2 = jnp.eye(n2)
matones2 = sm.matones(n2)
gridones2 = sm.gridones(n)  # NOTE(review): called with n, not n2 — confirm its shape.
comm = sm.commutation_matrix(n, n)  # presumably the commutation matrix K_{n,n}; confirm.

def proj_diff(p):
    """Return the idempotency residual ``p @ p - p``.

    The result is the zero matrix exactly when ``p`` is idempotent
    (``p @ p == p``), so its entries show how far ``p`` is from being a
    projection.
    """
    squared = p @ p
    return squared - p

# %%

# Kronecker products pairing the n x n identity with matones in both orders,
# then each one composed on the right with the commutation matrix.
eye_ones, ones_eye = jnp.kron(eye, matones), jnp.kron(matones, eye)
eye_ones_comm, ones_eye_comm = (factor @ comm for factor in (eye_ones, ones_eye))

# Elementwise products of eye_ones with each commuted operator.
eye_ones_double = eye_ones * eye_ones_comm
eye_ones_double_twisted = eye_ones * ones_eye_comm

# %%

# Candidate projection: matones2 scaled by 1/n^2.
# The displayed residual is all zeros iff p is idempotent.
m = (1 / n) ** 2
p = matones2 * m
proj_diff(p)

# %%

# Candidate projection: identity minus matones2 scaled by 1/n^2.
# The displayed residual is all zeros iff p is idempotent.
m = (1 / n) ** 2
p = eye2 - matones2 * m
proj_diff(p)

# %%

# Candidate projection: kron(eye, matones) scaled by 1/n.
# The displayed residual is all zeros iff p is idempotent.
m = 1 / n
p = eye_ones * m
proj_diff(p)

# %%

# Candidate projection: kron(eye, matones) conjugated by the commutation
# matrix and scaled by 1/n. `*` and `@` share precedence and associate left,
# so this is (m * comm) @ eye_ones_comm == m * comm @ eye_ones @ comm.
# NOTE(review): if comm is the standard K_{n,n}, p should equal
# kron(matones, eye) / n — confirm.
m = 1 / n
p = m * comm @ eye_ones_comm
# Only the cell's last expression is displayed; the residual is computed once.
proj_diff(p)

# %%

