# %%

from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as jrnd

from symo.factor import Eq, S, factor_from_param, factor_from_cov
from symo.notebooks.plot_utils import default_rcparams, plot_matrix
import symo.special_matrix as sm

import matplotlib.pyplot as plt

# Merge the project's default plot settings (requested at 500 dpi) into
# matplotlib's global rcParams.
# NOTE(review): `|=` on RcParams may bypass per-key validation that
# `rcParams.update(...)` performs -- confirm this is intended.
rc_defaults = default_rcparams(dpi=500)
plt.rcParams |= rc_defaults

# %%

# Side length of the square matrices; the operators built below have
# shape (n * n, n * n).
n = 5

# # %%

# I = jnp.eye(n * n)
# K = sm.commutation_matrix(n, n)
# U = jnp.outer(jnp.eye(n).reshape(-1), jnp.eye(n).reshape(-1))
# J = jnp.ones((n * n, n * n))

# A = jnp.kron(jnp.eye(n), jnp.ones((n, n)))
# B = jnp.kron(jnp.ones((n, n)), jnp.eye(n))
# AK = A @ K
# BK = B @ K  # K @ A
# AU = A @ U
# UB = U @ B  # U @ A
# A_UB = A * UB  # A * UA
# A_AK = A * AK
# UB_AK = UB * AK  # UA * AK
# BK_AU = BK * AU  # KA * AU
# A_BK_AU = A * BK * AU  # I * U

# %%

# Elementary (n*n) x (n*n) matrices.
I = jnp.eye(n**2)  # identity
K = sm.commutation_matrix(n, n)  # presumably the vec-transpose commutation matrix
U = jnp.outer(jnp.eye(n).ravel(), jnp.eye(n).ravel())  # flattened I_n, outer with itself
J = jnp.ones((n**2, n**2))  # all-ones

# Alternative
# Kronecker products of the n x n identity with the n x n all-ones matrix.
A = jnp.kron(jnp.eye(n), jnp.ones((n, n)))
B = jnp.kron(jnp.ones((n, n)), jnp.eye(n))

# Matrix products of A with K and with U, on either side.
AK, KA = A @ K, K @ A
AU, UA = A @ U, U @ A

# Element-wise (Hadamard) products.
A_UA = A * UA
A_AK = A * AK
UA_AK = UA * AK
KA_AU = KA * AU
I_U = I * U

# %%

# Ordered mapping from a LaTeX label to the matrix it names; insertion
# order is the order in which the matrices are iterated later on.
basis = dict(
    [
        ("$I$", I),
        ("$K$", K),
        ("$U$", U),
        ("$J$", J),
        ("$A$", A),
        ("$B$", B),
        ("$AK$", AK),
        ("$KA$", KA),
        ("$AU$", AU),
        ("$UA$", UA),
        (r"$A \cdot UA$", A_UA),
        (r"$A \cdot AK$", A_AK),
        (r"$UA \cdot AK$", UA_AK),
        (r"$KA \cdot AU$", KA_AU),
        (r"$I \cdot U$", I_U),
    ]
)

# %%

# Lay the basis matrices out on the smallest square grid that holds them
# all, instead of a hand-maintained panel count.
d = max(1, int(np.ceil(np.sqrt(len(basis)))))
k = d * d
# squeeze=False keeps `axes` a 2-D array even for a 1 x 1 grid, so
# `.flatten()` always works.
fig, axes = plt.subplots(nrows=d, ncols=d, squeeze=False)
axes = axes.flatten()

# Running element-wise sum of all basis matrices (kept for inspection).
sum_matrix = jnp.zeros_like(I)

for ax, (name, mat) in zip(axes, basis.items()):
    sum_matrix += mat
    plot_matrix(
        fig,
        ax,
        mat,
        name,
    )

# Hide every panel that did not receive a matrix.
for ax in axes[len(basis) :]:
    ax.axis("off")

fig.tight_layout()

# %%

# Show one matrix on its own (untitled figure).
P = AU
fig, ax = plt.subplots(nrows=1, ncols=1)
plot_matrix(fig, ax, P, "")
fig.tight_layout()
fig.show()

# %%

a0 = 1

# T is a left-to-right sum of scaled basis matrices, starting from I.
# The commented entries are candidates currently left out.
T = sum(
    (
        K,
        J / (n**2),
        A / n,
        B / n,
        AK / n,
        KA / n,
        # U / n,
        # AU,
        # UB,
        # A_UB,
        # A_AK,
        # UB_AK,
        # BK_AU,
        # A_BK_AU,
    ),
    I,
)

# %%

# Each cell below builds a candidate matrix P and passes it to
# `sm.is_projector` (presumably an idempotency test -- confirm in
# symo.special_matrix).
P = UA
sm.is_projector(P)

# %%

# Idempotency residual of T. The absolute value is summed so that
# positive and negative entries of T @ T - T cannot cancel out.
jnp.sum(jnp.abs(T @ T - T))

# %%

P = AK / n
sm.is_projector(P)

# %%

P = A / n
sm.is_projector(P)

# %%

P = 0.5 * (I + K)
sm.is_projector(P)

# %%

P = 0.5 * (I - K)
sm.is_projector(P)

# %%

# Algebraically this difference equals U / n.
P = 0.5 * (I + U / n) - 0.5 * (I - U / n)
sm.is_projector(P)

# %%

# Entry-wise idempotency residual of T.
T @ T - T

# %%

# Eigenvalues of T (eigvalsh treats T as symmetric).
jnp.linalg.eigvalsh(T)

# %% [markdown]
# # Projector matrix
# Case when (S ⊗ S ⊗ S ⊗ S) vec(T) = vec(T)

# %%

# Draw T on a single-panel figure.
fig, ax = plt.subplots(nrows=1, ncols=1)
plot_matrix(fig, ax, T, "Intertwiner")

# %%

# One PRNG key per Monte-Carlo draw, all derived from a fixed seed.
seed = 222
batch_size = 10_000
key = jrnd.PRNGKey(seed)
keys = jrnd.split(key, num=batch_size)

# %%

# Matrix side length read by `kron4` below.
n = 5


def kron4(k):
    """Return the four-fold Kronecker power of one permutation matrix.

    Args:
        k: PRNG key passed to ``sm.permutation_matrix`` together with the
            module-level ``n`` (presumably it draws a random n x n
            permutation matrix -- confirm in symo.special_matrix).

    Returns:
        ``kron(kron(s, s), kron(s, s))`` for the drawn matrix ``s``.
    """
    perm = sm.permutation_matrix(k, n)
    # Both factors of the outer product are the same Kronecker square.
    # (A two-fold variant would return `perm_sq` alone.)
    perm_sq = jnp.kron(perm, perm)
    return jnp.kron(perm_sq, perm_sq)


# Average of kron4 over all keys (one draw per key, batched with vmap).
Ss = jax.vmap(kron4)(keys).mean(axis=0)

# %%

fig, ax = plt.subplots(nrows=1, ncols=1)
plot_matrix(fig, ax, Ss, "")

# %%

# Symmetric eigendecomposition of the averaged operator.
s_evals, s_evecs = jnp.linalg.eigh(Ss)

# %%

# Total absolute idempotency residual of Ss.
jnp.abs(Ss @ Ss - Ss).sum()

# %%

# eigh returns eigenvalues in ascending order, so the last k columns of
# s_evecs belong to the k largest eigenvalues of Ss.
k = 15
v = s_evecs[:, -k:]
transpose = False
if not transpose:
    # Fold each eigenvector into an (n*n) x (n*n) matrix and add them up.
    w = v.reshape((n * n, n * n, k))
    wp = w.sum(axis=-1)
else:
    # Same sum, with the two matrix axes swapped.
    v = v.reshape([n**2, n**2, k])
    w = jnp.transpose(v, (2, 1, 0))
    wp = w.sum(axis=0)

# %%

wp_evals, wp_evecs = jnp.linalg.eigh(wp)
wp_evals

# %%

fig, ax = plt.subplots(nrows=1, ncols=1)
plot_matrix(fig, ax, wp, "Projector S ⊗ S ⊗ S ⊗ S")

# %% [markdown]
# # PSD and rank-1 matrices

# %%

# Matrices in this section are d x d with d = p**2.
p = 8
d = p**2

# seed = 846
# seed = 276
seed = 786
# seed = 876
# seed = np.random.random_integers(0, 1000)
key, other_key = jrnd.split(jrnd.PRNGKey(seed), 2)

# Split one extra key so that `other_key` stays fresh for later cells
# after `num_samples` sampling keys have been taken.
num_samples = 100000
other_keys = jrnd.split(other_key, num_samples + 1)
sample_keys, other_key = other_keys[:-1], other_keys[-1]

# %% [markdown]
# PSD case

# %%

# Keep only the eigenvectors of the sampled matrix; its spectrum is
# replaced by a geometric grid from 1e-7 to 1.5.
A = sm.positive_definite(key, d)
evecs = jnp.linalg.eigh(A)[1]
evals = jnp.geomspace(start=1e-7, stop=1.5, num=d)

# %%

# Reassemble: evecs @ diag(evals) @ evecs.T
A = evecs @ (evals.reshape(-1, 1) * evecs.T)

# %%

fig, ax = plt.subplots(nrows=1, ncols=1)
plot_matrix(fig, ax, A, "PSD matrix")

# %%


def avg_step(key, mat):
    """Conjugate ``mat`` by the Kronecker square of one permutation matrix.

    Args:
        key: PRNG key passed to ``sm.permutation_matrix`` (presumably it
            draws a random n x n permutation matrix -- confirm in
            symo.special_matrix).
        mat: Square matrix whose side length is a perfect square ``n * n``.

    Returns:
        ``(s ⊗ s) @ mat @ (s ⊗ s).T`` for the drawn matrix ``s``.

    Raises:
        ValueError: If the leading dimension of ``mat`` is not a perfect
            square.
    """
    import math

    nn = mat.shape[0]
    # The shape is a static Python int, so the root is taken with exact
    # integer arithmetic rather than a floating-point JAX op.
    n = math.isqrt(nn)
    if n * n != nn:
        raise ValueError(
            f"mat must have a perfect-square leading dimension, got {nn}"
        )
    s = sm.permutation_matrix(key, n)
    ss = jnp.kron(s, s)
    return ss @ mat @ ss.T


def is_psd(key, mat):
    """Evaluate the quadratic form ``v.T @ mat @ v`` for one random probe.

    For a symmetric ``mat`` this value is non-negative for every ``v``
    exactly when ``mat`` is positive semi-definite, so a negative result
    disproves it. The previous version returned ``||mat @ v||**2``, which
    is non-negative for any matrix and therefore tested nothing.

    Args:
        key: PRNG key used to draw the probe vector.
        mat: Matrix to probe; the probe length is ``mat.shape[-1]``.

    Returns:
        Scalar array holding ``v.T @ mat @ v``. This is a value, not a
        boolean: callers compare it against zero themselves.
    """
    # The 1/100 scaling only shrinks the magnitude; the sign is unchanged.
    v = jrnd.normal(key, shape=(mat.shape[-1], 1)) / 100
    mat_v = mat @ v
    inprod = jnp.sum(v * mat_v)
    return inprod


# %%

# Batched version of avg_step with the matrix fixed to A; it maps over keys.
avg_mat_map = jax.vmap(partial(avg_step, mat=A))

# %%

# Average of the conjugated copies of A over all sample keys.
A_avg = avg_mat_map(sample_keys).mean(axis=0)

# %%

avg_A_evals, avg_A_evecs = jnp.linalg.eigh(A_avg)
avg_A_evals

# %%

fig, ax = plt.subplots(nrows=1, ncols=1)
plot_matrix(fig, ax, A_avg, "Group avg PSD matrix")

# %% [markdown]
# Factors of PSD matrix

# %%

# Build a factor from the averaged matrix and read back its full and
# surrogate covariances.
Sg = (S["N", p], S["N", p])
eq = Eq[Sg, Sg]()

factor_A = factor_from_cov(eq, A_avg)
cov_A = factor_A.cov()
surr_A = factor_A.cov(surrogate=True)

# %%

# Eigenvalues of each matrix, then a matching vector of ones used as the
# y-coordinate template when plotting.
A_evals, cov_A_evals, surr_A_evals = (
    jnp.linalg.eigvalsh(m) for m in (A_avg, cov_A, surr_A)
)
A_y, cov_A_y, surr_A_y = (
    jnp.ones_like(e) for e in (A_evals, cov_A_evals, surr_A_evals)
)


# (label, (eigenvalues, ones)) for each source.
evals_y = (
    ("emp", (A_evals, A_y)),
    ("Full", (cov_A_evals, cov_A_y)),
    ("Surr", (surr_A_evals, surr_A_y)),
)

# %%

fig, ax = plt.subplots()

# Reversed so that the last entry of evals_y is drawn on the bottom row.
evals_y_rev = list(reversed(evals_y))

# Thin vertical guide line at every eigenvalue of the first reversed series.
ax_kwargs = dict(linewidth=0.2, alpha=0.8)
for _, (xs, _) in evals_y_rev[:1]:
    for x in xs:
        ax.axvline(x, **ax_kwargs)

# One horizontal row of markers per series; row i sits at height i.
ax_kwargs = dict(s=10, marker="o", alpha=0.5)
for i, (label, (x, y)) in enumerate(evals_y_rev):
    ax.scatter(x, y * i, label=label, **ax_kwargs)

# ax.set_xscale("log")
ax.set_xlim(left=0)
# The y position is only a row index, so its ticks carry no information.
ax.set_yticks([])
ax.set_ylabel("Matrix Source")
ax.set_xlabel("Eigenvalues")
ax.set_title("Eigenvalues of group averaged PSD matrix")
ax.legend()

fig.show()

# %%


# %% [markdown]
# Rank-1 matrices

# %%

# Fresh key pair: `key` draws the vector, `other_key` is kept for later.
key, other_key = jrnd.split(other_key)
v = jrnd.normal(key, (d, 1))
# Outer product of the column vector with itself.
R = v @ v.T

# %%

fig, ax = plt.subplots(nrows=1, ncols=1)
plot_matrix(fig, ax, R, "Rank-1 matrix")

# %%

# Batched version of avg_step with the matrix fixed to R; it maps over keys.
avg_rank1_map = jax.vmap(partial(avg_step, mat=R))

# %%

# Average of the conjugated copies of R over all sample keys.
rank1_avg = avg_rank1_map(sample_keys).mean(axis=0)

# %%

avg_rank1_evals, avg_rank1_evecs = jnp.linalg.eigh(rank1_avg)
avg_rank1_evals

# %%

fig, ax = plt.subplots(nrows=1, ncols=1)
plot_matrix(fig, ax, rank1_avg, "Group avg rank-1 matrix")

# %%

# Probe function with the matrix fixed to rank1_avg, batched over keys.
is_rank1_psd = partial(is_psd, mat=rank1_avg)
is_rank1_psd_map = jax.vmap(is_rank1_psd)

# %%

num_vecs = int(1e6)
# Split one extra key: the last one is peeled off as the next
# `other_key`, which leaves exactly `num_vecs` probe keys.
vec_keys = jrnd.split(other_key, num_vecs + 1)
vec_keys, other_key = vec_keys[:-1], vec_keys[-1]

# True when every probe returned a strictly positive value.
jnp.all(is_rank1_psd_map(vec_keys) > 0)

# %% [markdown]
# Factors of rank-1 matrix, $v v^T$

# %%

# v reshaped to p x p with a leading singleton axis; the same array is
# supplied for both entries of the parameter pair.
v_grid = v.reshape(p, p)[None]
factor_from_v = factor_from_param(eq, (v_grid, v_grid))
factor_from_R = factor_from_cov(eq, R)

# %%

avg_cov_R = factor_from_R.cov()
avg_cov_rank1 = factor_from_v.cov()

# %%

# Side-by-side comparison: empirical average vs. the two factor covariances.
fig, (ax1, ax2, ax3) = plt.subplots(ncols=3)
for axis, mat, title in (
    (ax1, rank1_avg, "Empirical G-invariant rank-1 matrix of $v v^T$"),
    (ax2, avg_cov_R, "G-invariant rank-1 matrix from $v v^T$"),
    (ax3, avg_cov_rank1, "G-invariant rank-1 matrix from pair of $v$"),
):
    plot_matrix(fig, axis, mat, title)

fig.tight_layout()
fig.show()

# %%

# Entry-wise difference between the two factor covariances.
avg_cov_R - avg_cov_rank1

# %%

avg_rank1_evals

# %%

# Surrogate covariances of both factors, requested with the same shape.
shape = (4, 4, 4, 4)
avg_surr_R = factor_from_R.cov(shape=shape, surrogate=True)
avg_surr_rank1 = factor_from_v.cov(shape=shape, surrogate=True)

# %%

fig, (ax1, ax2) = plt.subplots(ncols=2)
for axis, mat, title in (
    (ax1, avg_surr_R, "G-invariant rank-1 surrogate matrix from $v v^T$"),
    (ax2, avg_surr_rank1, "G-invariant rank-1 surrogate matrix from pair of $v$"),
):
    plot_matrix(fig, axis, mat, title)


# %%

# Eigenvalues of every matrix under comparison.
(
    rank1_evals,
    surr_R_evals,
    surr_rank1_evals,
    cov_R_evals,
    cov_rank1_evals,
) = (
    jnp.linalg.eigvalsh(m)
    for m in (rank1_avg, avg_surr_R, avg_surr_rank1, avg_cov_R, avg_cov_rank1)
)

# Matching vectors of ones, used as y-coordinate templates when plotting.
rank1_y, surr_R_y, surr_rank1_y, cov_R_y, cov_rank1_y = (
    jnp.ones_like(e)
    for e in (
        rank1_evals,
        surr_R_evals,
        surr_rank1_evals,
        cov_R_evals,
        cov_rank1_evals,
    )
)

# (label, (eigenvalues, ones)) for each source.
evals_y = (
    ("emp", (rank1_evals, rank1_y)),
    ("Full (source: $v v^T$)", (cov_R_evals, cov_R_y)),
    ("Full (source: pairs $v$)", (cov_rank1_evals, cov_rank1_y)),
    ("Surr (source: $v v^T$)", (surr_R_evals, surr_R_y)),
    ("Surr (source: pairs $v$)", (surr_rank1_evals, surr_rank1_y)),
)

# %%

fig, ax = plt.subplots()

# Reversed so that the last entry of evals_y is drawn on the bottom row.
evals_y_rev = list(reversed(evals_y))

# Thin vertical guide line at every eigenvalue of the first two reversed
# series.
ax_kwargs = dict(linewidth=0.2, alpha=0.5)
for _, (xs, _) in evals_y_rev[:2]:
    for x in xs:
        ax.axvline(x, **ax_kwargs)

# One horizontal row of markers per series; row i sits at height i.
ax_kwargs = dict(s=10, marker="o", alpha=0.5)
for i, (label, (x, y)) in enumerate(evals_y_rev):
    ax.scatter(x, y * i, label=label, **ax_kwargs)

# ax.set_xscale("log")
ax.set_xlim(left=0)
# The y position is only a row index, so its ticks carry no information.
ax.set_yticks([])
ax.set_ylabel("Matrix Source")
ax.set_xlabel("Eigenvalues")
ax.set_title("Eigenvalues of group averaged matrix")
ax.legend()

fig.show()

# %%
