# %%


import numpy as np
from numpy import kron

from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as jrnd

from symo.factor import Eq, S, factor_from_param, factor_from_cov
from symo.notebooks.plot_utils import default_rcparams, plot_matrix
import symo.special_matrix as sm

import matplotlib.pyplot as plt

# Merge the project's default matplotlib settings (requested with dpi=500) into
# the global rcParams in place. This affects every figure created afterwards.
# NOTE(review): presumably default_rcparams returns a mapping of rcParams
# overrides — confirm against symo.notebooks.plot_utils.
plt.rcParams |= default_rcparams(dpi=500)


# %%


def commutation_matrix_einsum(n: int):
    """
    Build the n² × n² commutation (SWAP) matrix with a single einsum.

    Viewed as a 4-index tensor the result is K[i,j,k,l] = δ_il · δ_jk, i.e.
    the outer product of two n × n identities with the index pairs crossed.
    Flattening (i,j) and (k,l) row-major gives a 0/1 matrix that maps the
    row-major vectorisation of an n × n matrix onto that of its transpose.
    """
    delta = np.eye(n)
    side = n * n

    # 'il,jk->ijkl' places delta[i,l] * delta[j,k] at position (i,j,k,l);
    # the reshape then merges (i,j) into rows and (k,l) into columns.
    return np.einsum("il,jk->ijkl", delta, delta).reshape(side, side)


def diagonal_spreader_outer_product(n: int, s: int | None = None):
    """
    DiagSpread = (1/n) vec(I) vec(I)^T
    """
    I = np.eye(n)
    s = n if s is None else s
    vec_I = I.ravel()
    return (1 / s) * np.outer(vec_I, vec_I)


def diagonal_trace_left_vectorized(n, s: int | None = None):
    """DiagTrLeft: |i,i⟩ → (1/n)Σ_k|k,i⟩"""

    s = n if s is None else s
    n2 = n * n
    i_vals = np.arange(n)
    k_vals = np.arange(n)

    input_indices = i_vals * (n + 1)
    output_indices = k_vals[:, None] * n + i_vals[None, :]

    output_flat = output_indices.T.ravel()
    input_flat = np.repeat(input_indices, n)

    D_L = np.zeros((n2, n2))
    D_L[output_flat, input_flat] = 1 / s
    return D_L


def diagonal_trace_right_vectorized(n: int, s: int | None = None):
    """DiagTrRight: |i,i⟩ → (1/n)Σ_l|i,l⟩"""

    s = n if s is None else s
    n2 = n * n
    i_vals = np.arange(n)
    l_vals = np.arange(n)

    input_indices = i_vals * (n + 1)
    output_indices = i_vals[:, None] * n + l_vals[None, :]

    output_flat = output_indices.ravel()
    input_flat = np.repeat(input_indices, n)

    D_R = np.zeros((n2, n2))
    D_R[output_flat, input_flat] = 1 / s
    return D_R


def fully_vectorized_basis(n: int, s: int | None = None):
    """
    Assemble 15 operators on the n²-dimensional product space, with labels.

    Every matrix is (n², n²) and is built from identities, all-ones blocks,
    the commutation matrix and the diagonal helpers, without explicit Python
    loops. ``s`` is the divisor used in every normalisation factor (1/s or
    1/s²) and defaults to ``n``.

    Returns
    -------
    (basis, labels)
        Two parallel lists of length 15: the matrices and their display names.

    NOTE(review): the 15 matrices are not linearly independent as built.
    ``OffDiag`` is defined as ``I - Diag``, and ``DiagFull + OffDiagFull``
    equals ``(1/s²)·J``, which is ``TrBoth``. Their span is therefore at most
    13-dimensional.
    NOTE(review): the labels of entries 3 and 4 carry the tags "(Tr2)" and
    "(Tr1)" while the matrices are the ones named Tr1 and Tr2 here (and in
    the "K·Tr1" / "K·Tr2" labels) — confirm which convention is intended.
    """
    if s is None:
        s = n
    inv_s = 1 / s
    inv_s2 = 1 / s**2

    # Building blocks on a single factor and on the full product space.
    dim = n * n
    eye_small = np.eye(n)
    ones_small = np.ones((n, n))
    eye_big = np.eye(dim)
    ones_big = np.ones((dim, dim))

    swap = commutation_matrix_einsum(n)

    # Elementwise product keeps only the diagonal entries of the swap matrix.
    diag = swap * eye_big
    off_diag = eye_big - diag

    # Averaging over one factor, or over both.
    tr1 = inv_s * kron(ones_small, eye_small)
    tr2 = inv_s * kron(eye_small, ones_small)
    tr_both = inv_s2 * kron(ones_small, ones_small)

    diag_spread = diagonal_spreader_outer_product(n, s)
    diag_tr_left = diagonal_trace_left_vectorized(n, s)
    diag_tr_right = diagonal_trace_right_vectorized(n, s)

    # Composites; the scalar is applied before the matrix product.
    diag_full = (inv_s2 * diag) @ ones_big
    off_diag_full = (inv_s2 * off_diag) @ ones_big
    mixed = inv_s * diag_spread + off_diag @ swap

    # Keep each label next to its matrix so the two lists cannot drift apart.
    entries = [
        ("I", eye_big),
        ("K (einsum: 'il,jk->ijkl')", swap),
        ("(1/n)J⊗I (Tr2)", tr1),
        ("(1/n)I⊗J (Tr1)", tr2),
        ("(1/n²)J⊗J (Tr⊗Tr)", tr_both),
        ("K∘I (Hadamard)", diag),
        ("I - K∘I", off_diag),
        ("DiagSpread", diag_spread),
        ("K·Tr1", swap @ tr1),
        ("K·Tr2", swap @ tr2),
        ("DiagTrLeft", diag_tr_left),
        ("DiagTrRight", diag_tr_right),
        ("(1/n²)Diag·J", diag_full),
        ("(1/n²)OffDiag·J", off_diag_full),
        ("(1/n)DiagSpread + OffDiag·K", mixed),
    ]

    basis = [matrix for _, matrix in entries]
    labels = [label for label, _ in entries]

    return basis, labels


# %%

# Two problem sizes: a small one (m) and a larger one (n). Each basis matrix
# is size² × size², so n = 20 produces 400 × 400 arrays.
m = 3
n = 20

# %%

# Build the 15 operators at size n and index them by label for plotting.
basis_n, labels_n = fully_vectorized_basis(n)
basis_dict = {k: b for b, k in zip(basis_n, labels_n)}

# %%


# Same construction at the small size m.
basis_m, labels_m = fully_vectorized_basis(m)

# %%

# Lay out a square grid with k panels (4 × 4 for k = 16); the 15 basis
# matrices fill the first 15 panels.
k = 16
d = int(np.sqrt(k))
fig, axes = plt.subplots(nrows=d, ncols=d)
axes = axes.flatten()

# Running sum of all basis matrices, held as a JAX array; `+=` rebinds the
# name to a new array on every iteration.
# NOTE(review): sum_matrix is not read again anywhere in the visible source
# (the next cell rebuilds the sum with NumPy as Tn) — possibly leftover.
I = np.eye(n**2)
sum_matrix = jnp.zeros_like(I)

for i, (name, mat) in enumerate(basis_dict.items()):
    ax = axes[i]
    sum_matrix += mat
    plot_matrix(
        fig,
        ax,
        mat,
        name,
    )

# The 16th panel has no matrix to show, so hide its frame.
axes[-1].axis("off")

fig.tight_layout()

# %%

# Tn / Tm: plain sums of all basis matrices at sizes n and m, accumulated in
# place into zero arrays shaped like the first basis element.
Tn = np.zeros_like(basis_n[0])
for b in basis_n:
    Tn += b

Tm = np.zeros_like(basis_m[0])
for b in basis_m:
    Tm += b

# %%

# NOTE: despite the `_evals` names these are singular values
# (np.linalg.svdvals), not eigenvalues.
Tn_evals = np.linalg.svdvals(Tn)
Tm_evals = np.linalg.svdvals(Tm)

# %%

# Matrix square of the summed operator at size n.
Tn2 = Tn @ Tn

# %%

# Matrix square of the summed operator at size m.
Tm2 = Tm @ Tm

# %%

# Bare expression: displayed as the cell output when run interactively.
Tn_evals

# %%

# Bare expression: displayed as the cell output when run interactively.
Tm_evals

# %%


def recommended_check(basis_matrices, labels=None, tol=1e-10):
    """
    Report whether a collection of equally-shaped matrices is linearly independent.

    Each matrix is flattened to a vector. The set is declared independent
    when the stacked vectors have numerical rank equal to their count. When
    the set is dependent, a Gram–Schmidt sweep in the given order prints every
    matrix that lies (numerically) in the span of the matrices before it.

    Parameters
    ----------
    basis_matrices : iterable of array-likes
        Matrices sharing one shape; each is converted with ``np.asarray``.
    labels : sequence of str, optional
        Names parallel to ``basis_matrices`` used in the redundancy report.
        When omitted or empty, matrices are reported as ``Matrix <index>``.
    tol : float, optional
        Threshold passed to ``np.linalg.matrix_rank`` for the rank test. In
        the sweep, a matrix counts as redundant when the norm of its residual
        is at most ``tol * max(1, norm of the matrix)``.

    Returns
    -------
    bool
        True if the matrices are linearly independent, False otherwise.
    """
    print("=" * 80)
    print("RECOMMENDED INDEPENDENCE CHECK")
    print("=" * 80)

    # Step 1: quick rank check on the stacked, flattened matrices.
    vectors = [np.asarray(M).ravel() for M in basis_matrices]
    n_matrices = len(vectors)
    rank = np.linalg.matrix_rank(np.array(vectors), tol=tol)

    print(f"\nNumber of matrices: {n_matrices}")
    print(f"Rank: {rank}")

    if rank == n_matrices:
        print("✓✓✓ LINEARLY INDEPENDENT ✓✓✓")
        return True

    print(f"✗ DEPENDENT (rank = {rank} < {n_matrices})")

    # Step 2: find which matrices are redundant given the ones before them.
    print("\n" + "-" * 80)
    print("Finding redundant matrices...")
    print("-" * 80)

    # Orthonormal vectors spanning the matrices accepted so far. Subtracting
    # <q, v> q only removes the component along q when q has unit norm and
    # the q's are mutually orthogonal, so the normalised residuals are stored
    # rather than the raw matrices.
    orthonormal = []

    for i, vec in enumerate(vectors):
        # Work in at least double precision, on a copy of the input.
        residual = vec.astype(np.result_type(vec.dtype, np.float64))
        scale = max(1.0, float(np.linalg.norm(residual)))

        for q in orthonormal:
            # vdot conjugates its first argument, so this is also the correct
            # projection coefficient for complex inputs.
            residual = residual - np.vdot(q, residual) * q

        residual_norm = float(np.linalg.norm(residual))
        if residual_norm > tol * scale:
            orthonormal.append(residual / residual_norm)
        elif labels:
            print(f"  ✗ {labels[i]} is redundant")
        else:
            print(f"  ✗ Matrix {i} is redundant")

    return False


# Final test: build an alternative set of 15 candidate matrices with JAX and
# run the linear-independence check on them.
# NOTE: this rebinds the module-level `n` (set to 20 in an earlier cell).
n = 5

# Elementary operators on the n²-dimensional product space.
I = jnp.eye(n * n)
# NOTE(review): presumably the n² × n² commutation (SWAP) matrix — confirm
# against symo.special_matrix.
K = sm.commutation_matrix(n, n)
# Outer product of the flattened identity with itself: vec(I) vec(I)^T.
U = jnp.outer(jnp.eye(n).reshape(-1), jnp.eye(n).reshape(-1))
J = jnp.ones((n * n, n * n))

# Alternative
# A = I ⊗ J and B = J ⊗ I (n × n blocks). `@` below is the matrix product,
# while `*` is the elementwise (Hadamard) product, even though the plot labels
# render it with \cdot.
A = jnp.kron(jnp.eye(n), jnp.ones((n, n)))
B = jnp.kron(jnp.ones((n, n)), jnp.eye(n))
AK = A @ K
KA = K @ A
AU = A @ U
UA = U @ A
A_UA = A * UA
A_AK = A * AK
UA_AK = UA * AK
KA_AU = KA * AU
I_U = I * U

# LaTeX label -> matrix; insertion order fixes the order used by the check.
basis = {
    "$I$": I,
    "$K$": K,
    "$U$": U,
    "$J$": J,
    "$A$": A,
    "$B$": B,
    "$AK$": AK,
    "$KA$": KA,
    "$AU$": AU,
    "$UA$": UA,
    r"$A \cdot UA$": A_UA,
    r"$A \cdot AK$": A_AK,
    r"$UA \cdot AK$": UA_AK,
    r"$KA \cdot AU$": KA_AU,
    r"$I \cdot U$": I_U,
}

# Split the dict into parallel tuples of labels and matrices.
labels_mat, basis_mat = zip(*basis.items())

is_independent = recommended_check(basis_mat, labels_mat)

# %%

import numpy as np
from numpy import kron
import matplotlib.pyplot as plt
from matplotlib.patches import Arc, Circle, FancyBboxPatch
import matplotlib.patches as mpatches


def draw_partition_diagram(ax, connections, title, description=""):
    """
    Render one four-point partition diagram onto ``ax``.

    The axes are fixed to x ∈ [-0.5, 3.5], y ∈ [-0.5, 2.5] with the frame
    hidden. Points 0 and 1 form the bottom (input) row, labelled i and j;
    points 2 and 3 form the top (output) row, labelled k and l.

    Parameters
    ----------
    ax : axes object to draw on.
    connections : iterable whose items are each one of
        - ``"loop_left"`` / ``"loop_right"``: a half-ellipse arc beside the
          left / right column;
        - ``"diag_connect"``: the four sides of the rectangle joining all
          four points;
        - a 2-tuple ``(p1, p2)`` of point numbers: a straight segment between
          the two points (nothing is drawn when ``p1 == p2``).
        Items of any other form are skipped without error.
    title : text drawn centred above the diagram.
    description : optional italic caption drawn below the diagram; omitted
        when empty.
    """
    ax.set_xlim(-0.5, 3.5)
    ax.set_ylim(-0.5, 2.5)
    ax.axis("off")

    # Column / row coordinates shared by the points, labels and connectors.
    left_x, right_x = 0.5, 2.5
    bottom_y, top_y = 0.3, 1.7

    positions = {
        0: (left_x, bottom_y),  # i (bottom left)
        1: (right_x, bottom_y),  # j (bottom right)
        2: (left_x, top_y),  # k (top left)
        3: (right_x, top_y),  # l (top right)
    }

    # Dots sit above everything else (zorder=10).
    for centre in positions.values():
        ax.add_patch(Circle(centre, 0.08, color="black", zorder=10))

    # Index letters under the bottom row and over the top row.
    for x, y, letter in (
        (left_x, -0.1, "i"),
        (right_x, -0.1, "j"),
        (left_x, 2.1, "k"),
        (right_x, 2.1, "l"),
    ):
        ax.text(x, y, letter, ha="center", fontsize=12, weight="bold")

    # Row captions to the left of each row.
    for y, caption in ((bottom_y, "input:"), (top_y, "output:")):
        ax.text(-0.3, y, caption, ha="right", fontsize=9, style="italic")

    def add_half_loop(centre_x, start_deg, end_deg):
        # Half of a 0.6 × 1.4 ellipse centred midway between the two rows.
        ax.add_patch(
            Arc(
                (centre_x, 1.0),
                0.6,
                1.4,
                angle=0,
                theta1=start_deg,
                theta2=end_deg,
                linewidth=2,
                color="blue",
            )
        )

    # Bottom, top, left and right side of the rectangle through all points.
    rectangle_sides = (
        ([left_x, right_x], [bottom_y, bottom_y]),
        ([left_x, right_x], [top_y, top_y]),
        ([left_x, left_x], [bottom_y, top_y]),
        ([right_x, right_x], [bottom_y, top_y]),
    )

    for conn in connections:
        if conn == "loop_left":
            add_half_loop(left_x, 90, 270)
        elif conn == "loop_right":
            add_half_loop(right_x, 270, 90)
        elif conn == "diag_connect":
            for xs, ys in rectangle_sides:
                ax.plot(xs, ys, "b-", linewidth=2)
        elif isinstance(conn, tuple):
            start, end = conn
            (x1, y1), (x2, y2) = positions[start], positions[end]
            # Vertical, horizontal and crossing links are all drawn the same
            # way; only a degenerate self-link is skipped.
            if start != end:
                ax.plot([x1, x2], [y1, y2], "b-", linewidth=2)

    ax.text(1.5, 2.4, title, ha="center", fontsize=11, weight="bold")
    if description:
        ax.text(
            1.5, -0.5, description, ha="center", fontsize=8, style="italic", wrap=True
        )


def visualize_key_partition_diagrams():
    """
    Draw a 3 × 5 grid of labelled partition diagrams and save it as a PNG.

    Prints a banner, fills one panel per entry of the table below via
    ``draw_partition_diagram``, writes the figure to
    'partition_diagrams.png' in the current working directory (dpi=150) and
    returns the figure.

    NOTE(review): several entries reuse the same connection list (for
    example 1, 7 and 14 all use [(0, 2), (1, 3)]; 6, 8 and 13 all use
    ["diag_connect"]), so the 15 panels do not show 15 distinct drawings;
    those panels differ only in title and caption.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("VISUALIZING KEY PARTITION DIAGRAMS")
    print(rule)

    fig, axes = plt.subplots(3, 5, figsize=(15, 9))
    fig.suptitle("The 15 Partition Diagrams of P₂(n)", fontsize=14, weight="bold")

    # Each entry: (connections, panel title, caption).
    basic_row = [
        ([(0, 2), (1, 3)], "1. Identity", "|i,j⟩ → |i,j⟩"),
        ([(0, 3), (1, 2)], "2. SWAP", "|i,j⟩ → |j,i⟩"),
        (["loop_left", (1, 3)], "3. Trace₁⊗I", "(1/n)Σₖ|k,j⟩"),
        ([(0, 2), "loop_right"], "4. I⊗Trace₂", "(1/n)Σₗ|i,l⟩"),
        (["loop_left", "loop_right"], "5. Tr⊗Tr", "(1/n²)Σₖₗ|k,l⟩"),
    ]
    diagonal_row = [
        (["diag_connect"], "6. Diagonal", "δᵢⱼδₖₗδᵢₖ"),
        ([(0, 2), (1, 3)], "7. Off-Diag", "(1-δᵢⱼ)δᵢₖδⱼₗ"),
        (["diag_connect"], "8. DiagSpread", "(1/n)δᵢⱼδₖₗ"),
        ([(0, 3)], "9. SWAP∘Tr₁", "(1/n)δⱼₖ"),
        ([(1, 2)], "10. SWAP∘Tr₂", "(1/n)δᵢₗ"),
    ]
    complex_row = [
        ([(0, 2)], "11. DiagTrLeft", "(1/n)δᵢⱼδⱼₗ"),
        ([(1, 3)], "12. DiagTrRight", "(1/n)δᵢⱼδᵢₖ"),
        (["diag_connect"], "13. Diag→Full", "(1/n²)δᵢⱼ"),
        ([(0, 2), (1, 3)], "14. OffDiag→Full", "(1/n²)(1-δᵢⱼ)"),
        ([(0, 3), (1, 2)], "15. Mixed", "conditional"),
    ]
    diagrams = basic_row + diagonal_row + complex_row

    # axes.flat walks the grid row by row, matching the table order.
    for ax, spec in zip(axes.flat, diagrams):
        draw_partition_diagram(ax, *spec)

    plt.tight_layout()
    plt.savefig("partition_diagrams.png", dpi=150, bbox_inches="tight")
    print("\n✓ Saved diagram to 'partition_diagrams.png'")

    return fig


# Generate the partition-diagram figure (this also writes
# 'partition_diagrams.png') and display it. Comment these two lines out to
# skip the visualization.
fig = visualize_key_partition_diagrams()
plt.show()

# %%
