import os
import jax
import time
import copy
import jax.nn as nn
import jax.lax as lax
import jax.numpy as jnp
from lmc_model import print_model
from collections import defaultdict
from typing import NamedTuple
from flax.core import freeze, unfreeze
from jax import random, tree_util, jit, grad, value_and_grad
from scipy.optimize import linear_sum_assignment, minimize, minimize_scalar
from scipy.linalg import block_diag
from math import sqrt, cos, sin, atan2
import numpy as np
import matplotlib.pyplot as plt
def compute_objective(A, X, X_prime, Y, Y_prime):
    """Return the two-sided alignment loss for an invertible map ``A``.

    loss(A) = ||X - X' A^T||_F^2 + ||Y - Y' A^{-1}||_F^2

    ``X'`` is mapped through ``A^T`` and ``Y'`` through ``A^{-1}``. So a pair
    (X, Y) whose product X Y^T must be preserved is penalised on both sides.

    Raises:
        numpy.linalg.LinAlgError: if ``A`` is singular.
    """
    A_inverse = np.linalg.inv(A)
    fit_x = np.square(X - X_prime @ A.T).sum()
    fit_y = np.square(Y - Y_prime @ A_inverse).sum()
    return fit_x + fit_y

def compute_gradient(A, X, X_prime, Y, Y_prime):
    """Analytic gradient of ``compute_objective`` with respect to ``A``.

    d/dA ||X - X' A^T||^2  = 2 (A X'^T X' - X^T X')
    d/dA ||Y - Y' A^{-1}||^2 = 2 A^{-T} Y'^T (Y - Y' A^{-1}) A^{-T}

    Raises:
        numpy.linalg.LinAlgError: if ``A`` is singular.
    """
    A_inv = np.linalg.inv(A)
    A_inv_T = A_inv.T
    first_grad = 2 * (A @ X_prime.T @ X_prime - X.T @ X_prime)
    second_grad = 2 * A_inv_T @ Y_prime.T @ (Y - Y_prime @ A_inv) @ A_inv_T
    return first_grad + second_grad

def line_search(A, grad, X, X_prime, Y, Y_prime, max_step=1, tau=0.5, c1=1e-4):
    """Backtracking (Armijo) line search along the direction ``-grad``.

    Starting from ``max_step``, the step is multiplied by ``tau`` until the
    candidate ``A - step * grad`` has full rank and satisfies
    ``f(candidate) <= f(A) - c1 * step * ||grad||_F^2``,
    where ``f`` is compute_objective.

    Returns:
        The accepted step size, or ``0`` if no step larger than ``1e-10`` works.
    """
    baseline = compute_objective(A, X, X_prime, Y, Y_prime)
    grad_sq_norm = np.sum(grad**2)
    full_rank = A.shape[0]
    step = max_step
    while step > 1e-10:
        trial = A - step * grad
        # Singular candidates cannot be inverted by compute_objective; just shrink.
        if np.linalg.matrix_rank(trial) >= full_rank:
            trial_value = compute_objective(trial, X, X_prime, Y, Y_prime)
            if trial_value <= baseline - c1 * step * grad_sq_norm:
                return step
        step *= tau
    return 0

@jax.jit
def compute_objective_jax(A, X, X_prime, Y, Y_prime, cond_threshold=1e6):
    """JAX version of the two-sided alignment loss, with an ill-conditioning guard.

    Returns ``||X - X' A^T||^2 + ||Y - Y' A^{-1}||^2``. If the condition number
    of ``A`` exceeds ``cond_threshold``, returns ``+inf`` instead, so optimisers
    back away from near-singular maps.
    """
    def _finite_objective():
        fit_q = jnp.sum((X - X_prime @ A.T) ** 2)
        fit_k = jnp.sum((Y - Y_prime @ jnp.linalg.inv(A)) ** 2)
        return fit_q + fit_k

    def _rejected():
        return jnp.inf

    too_ill_conditioned = jnp.linalg.cond(A) > cond_threshold
    return lax.cond(too_ill_conditioned, _rejected, _finite_objective)
# JIT-compiled (objective, gradient) pair for compute_objective_jax; value_and_grad
# differentiates with respect to the first positional argument (the map A) only.
compute_value_and_grad_jax = jit(value_and_grad(compute_objective_jax))
def solve_orthogonal(X, X_prime, Y, Y_prime):
    """Orthogonal Procrustes initialisation for the two-sided alignment loss.

    For orthogonal ``A`` we have ``A^{-1} = A^T``, so minimising
    ``||X - X' A^T||^2 + ||Y - Y' A^T||^2`` reduces to maximising
    ``tr(A^T M)`` with ``M = X^T X' + Y^T Y'``. The maximiser is ``U V^T``,
    where ``M = U S V^T`` is the SVD.
    """
    cross_cov = X.T @ X_prime + Y.T @ Y_prime
    left, _singular_values, right_t = np.linalg.svd(cross_cov)
    return left @ right_t

def solve_rope_qk_alignment(W_Q_a_i, b_Q_a_i, W_K_a_i, b_K_a_i,
                            W_Q_b_i, b_Q_b_i, W_K_b_i, b_K_b_i):
    """
    Solves for the G_RoPE alignment matrix U for a single head's QK weights.

    U is restricted to a block-diagonal matrix of 2x2 scaled rotations
    ``U_j = r_j * [[cos t_j, -sin t_j], [sin t_j, cos t_j]]``, one per
    consecutive column pair of the (bias-extended) head weights. Such blocks
    commute with 2x2 rotations, so the QK map stays compatible with RoPE.

    For each pair j the code minimises
        ||Q_a_j - Q_b_j U_j^T||^2 + ||K_a_j - K_b_j U_j^{-1}||^2.
    Writing x = r^2 and eliminating the angle analytically gives (up to
    constants)
        g(x) = x*N_Q + N_K/x - 4*sqrt(A*x + B/x + C),
    which is minimised numerically. The optimal angle is then
    t = -arg(sqrt(x)*c_q + c_k/sqrt(x)).

    Returns:
        np.ndarray of shape (D_k, D_k): the block-diagonal alignment matrix.
    """
    tilde_W_Q_a = compute_extended_weights(W_Q_a_i, b_Q_a_i)
    tilde_W_K_a = compute_extended_weights(W_K_a_i, b_K_a_i)
    tilde_W_Q_b = compute_extended_weights(W_Q_b_i, b_Q_b_i)
    tilde_W_K_b = compute_extended_weights(W_K_b_i, b_K_b_i)

    D_k = tilde_W_Q_a.shape[1]
    assert D_k % 2 == 0, "Head dimension must be even for RoPE."

    U_blocks = []
    # J is the 90-degree rotation generator: a rotation by t is cos(t) I + sin(t) J.
    J = jnp.array([[0, -1], [1, 0]])

    for j in range(D_k // 2):
        # 1. Slice submatrices for the j-th 2D subspace (columns 2j, 2j+1)
        sl = slice(2 * j, 2 * j + 2)
        Q_a_j, Q_b_j = tilde_W_Q_a[:, sl], tilde_W_Q_b[:, sl]
        K_a_j, K_b_j = tilde_W_K_a[:, sl], tilde_W_K_b[:, sl]

        # 2. Precompute constants: squared norms of B's slices and 2x2 cross-covariances
        N_Q = jnp.sum(Q_b_j**2)
        N_K = jnp.sum(K_b_j**2)
        C_Q = Q_a_j.T @ Q_b_j
        C_K = K_a_j.T @ K_b_j

        # 3. Complex encoding: tr(C R_{-t}) == 2 * Re(c * e^{i t}) for c defined below
        c_q = 0.5 * (jnp.trace(C_Q) + 1j * jnp.trace(C_Q @ J))
        c_k = 0.5 * (jnp.trace(C_K) + 1j * jnp.trace(C_K @ J))

        # |sqrt(x)*c_q + c_k/sqrt(x)|^2 == A*x + B/x + C
        A = jnp.abs(c_q)**2
        B = jnp.abs(c_k)**2
        C = 2 * jnp.real(c_q * jnp.conj(c_k))

        # Convert all JAX/Device arrays to native Python/NumPy types for SciPy
        N_Q_f, N_K_f = float(N_Q), float(N_K)
        A_f, B_f, C_f = float(A), float(B), float(C)
        c_q_f = complex(c_q)
        c_k_f = complex(c_k)

        # Define the 1D scalar objective g(x) over x = r^2, using native floats
        def g_objective(x):
            x = float(x)
            # Protect the sqrt argument from tiny negative values due to roundoff
            inner_term = A_f * x + (B_f / x) + C_f
            safe_inner = max(inner_term, 1e-20)
            return x * N_Q_f + N_K_f / x - 4.0 * sqrt(safe_inner)

        # 4. Find the minimizer x* using robust bounds
        # NOTE(review): if this subspace is all zeros (N_Q = N_K = A = B = C = 0),
        # g is constant and x* is an arbitrary point in the bounds.
        res = minimize_scalar(g_objective, bounds=(1e-8, 1e8), method='bounded')
        x_star = res.x

        # 5. Reconstruct the optimal 2x2 alignment matrix U_j using NumPy/math
        r_star = sqrt(x_star)
        combined_c = r_star * c_q_f + (1 / r_star) * c_k_f
        # The angle is undefined when the combined coefficient vanishes; use 0.
        if abs(combined_c) < 1e-30:
            theta_star = 0.0
        else:
            theta_star = -atan2(combined_c.imag, combined_c.real)
        a = r_star * cos(theta_star)
        b = r_star * sin(theta_star)

        U_j = np.array([[a, -b], [b, a]])
        U_blocks.append(U_j)

    # 6. Assemble the full block-diagonal matrix U
    U_opt = block_diag(*U_blocks)
    condU = np.linalg.cond(U_opt)
    if condU > 1e12:
        # Fallback for near-singular U: add a small multiple of the identity.
        # Each block becomes [[a+eps, -b], [b, a+eps]], so it is still a scaled rotation.
        eps = 1e-6
        U_opt = U_opt + eps * np.eye(U_opt.shape[0])
    return U_opt

def optimize_alignment(A_init, X, X_prime, Y, Y_prime, max_iter=5000):
    """Refine an alignment map ``A`` by minimising compute_objective_jax with L-BFGS-B.

    Args:
        A_init: initial square map (e.g. from solve_orthogonal).
        X, X_prime, Y, Y_prime: data of the two-sided loss
            ``||X - X' A^T||^2 + ||Y - Y' A^{-1}||^2``.
        max_iter: L-BFGS-B iteration cap.

    Returns:
        ``(A_opt, objective_values, grad_norms, condition_nums)``. The three
        lists hold, for each optimiser iteration (recorded in the SciPy
        callback), the loss, the Frobenius norm of its gradient, and cond(A).
    """
    objective_values = []
    grad_norms = []
    condition_nums = []
    shape = A_init.shape
    # The data matrices never change; convert them to device arrays once
    # instead of on every objective/callback evaluation.
    X_j, X_prime_j, Y_j, Y_prime_j = (jnp.asarray(m) for m in (X, X_prime, Y, Y_prime))

    def _value_and_grad(flat_A):
        A = jnp.asarray(flat_A.reshape(shape))
        return compute_value_and_grad_jax(A, X_j, X_prime_j, Y_j, Y_prime_j)

    def obj_fn(flat_A):
        obj, grad_val = _value_and_grad(flat_A)
        # JAX defaults to float32; SciPy's L-BFGS-B expects a float64 gradient.
        return float(obj), np.asarray(grad_val, dtype=np.float64).ravel()

    def callback(flat_A):
        obj, grad_val = _value_and_grad(flat_A)
        objective_values.append(float(obj))
        grad_norms.append(float(jnp.linalg.norm(grad_val, 'fro')))
        condition_nums.append(float(jnp.linalg.cond(jnp.asarray(flat_A.reshape(shape)))))

    res = minimize(obj_fn, np.asarray(A_init, dtype=np.float64).ravel(), jac=True,
                   method='L-BFGS-B', options={'maxiter': max_iter}, callback=callback)
    A_opt = res.x.reshape(shape)
    return A_opt, objective_values, grad_norms, condition_nums

def extract_attention_params(attn):
    """Copy one layer's attention projection weights into NumPy arrays.

    ``attn`` must provide ``attn['attention'][name]['kernel'|'bias']`` for
    ``name`` in query/key/value, and ``attn['output']['dense']['kernel'|'bias']``.

    Returns:
        (query, key, value, query_bias, key_bias, value_bias, out, out_bias)
    """
    self_attention = attn['attention']

    def kernel_and_bias(module):
        return np.array(module['kernel']), np.array(module['bias'])

    key, key_bias = kernel_and_bias(self_attention['key'])
    query, query_bias = kernel_and_bias(self_attention['query'])
    value, value_bias = kernel_and_bias(self_attention['value'])
    out, out_bias = kernel_and_bias(attn['output']['dense'])
    return query, key, value, query_bias, key_bias, value_bias, out, out_bias
def reshape_attention_weights(query, key, value, query_bias, key_bias, value_bias, out_kernel, num_heads):
    """Split fused Q/K/V/O projections into per-head arrays.

    Heads partition the *output* columns of the Q/K/V kernels (and the rows of
    the output kernel). The per-head width is therefore
    ``query.shape[1] // num_heads``. This is the same width that
    merge_aligned_params uses when re-fusing with ``D = query.shape[1]``.

    Args:
        query, key, value: kernels whose columns are split per head.
        query_bias, key_bias, value_bias: 1-D biases split per head.
        out_kernel: output kernel whose rows are split per head.
        num_heads: number of heads.

    Returns:
        (W_Q, b_Q, W_K, b_K, W_V, b_V, W_O) with leading axis ``num_heads``;
        e.g. ``W_Q[i] == query[:, i*D_k:(i+1)*D_k]`` and
        ``W_O[i] == out_kernel[i*D_k:(i+1)*D_k]``.

    Raises:
        ValueError: if the projection width is not divisible by ``num_heads``.
    """
    width = query.shape[1]
    if width % num_heads != 0:
        raise ValueError(
            f"projection width {width} is not divisible by num_heads={num_heads}"
        )
    D_k = width // num_heads
    head_slices = [slice(i * D_k, (i + 1) * D_k) for i in range(num_heads)]

    def split_columns(kernel):
        return np.stack([kernel[:, s] for s in head_slices])

    def split_rows(tensor):
        return np.stack([tensor[s] for s in head_slices])

    W_Q = split_columns(query)
    W_K = split_columns(key)
    W_V = split_columns(value)
    W_O = split_rows(out_kernel)
    b_Q = split_rows(query_bias)
    b_K = split_rows(key_bias)
    b_V = split_rows(value_bias)
    return W_Q, b_Q, W_K, b_K, W_V, b_V, W_O

def compute_extended_weights(W, b):
    """Return ``W`` with the bias ``b`` appended as one extra final row.

    With inputs augmented by a trailing 1, ``[x, 1] @ result == x @ W + b``.
    """
    bias_row = b.reshape(1, -1)
    return np.vstack((W, bias_row))

def compute_cost_matrix(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a,
                        W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b,
                        h, activations, alpha=0.5):
    """Weight-space cost of pairing head i of model A with head j of model B.

    For each head, with bias-extended weights ``W~ = [W; b]``:
      * ``QK`` = row-centred ``Q~ K~^T``. Row centring removes per-query
        offsets, to which softmax is invariant.
      * ``VO`` = ``V~ W_O``.

    C[i, j] = alpha * ||QK_a_i - QK_b_j||^2 + (1 - alpha) * ||VO_a_i - VO_b_j||^2

    Args:
        W_*_a, b_*_a / W_*_b, b_*_b: per-head weights/biases, indexable by head.
        h: number of heads.
        activations: unused; kept for interface compatibility.
        alpha: weight of the QK term (default 0.5 weights both terms equally).

    Returns:
        np.ndarray of shape (h, h).
    """
    def head_summary(W_Q, b_Q, W_K, b_K, W_V, b_V, W_O, idx):
        # Returns (row-centred QK^T, VO) for head idx of one model.
        tilde_Q = np.vstack([W_Q[idx], b_Q[idx].reshape(1, -1)])
        tilde_K = np.vstack([W_K[idx], b_K[idx].reshape(1, -1)])
        tilde_V = np.vstack([W_V[idx], b_V[idx].reshape(1, -1)])
        qkt = tilde_Q @ tilde_K.T
        centered = qkt - np.mean(qkt, axis=1, keepdims=True)
        return centered, tilde_V @ W_O[idx]

    # Compute each head's summary once instead of once per (i, j) pair.
    heads_a = [head_summary(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a, i) for i in range(h)]
    heads_b = [head_summary(W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b, j) for j in range(h)]

    C = np.zeros((h, h))
    for i, (qk_a, vo_a) in enumerate(heads_a):
        for j, (qk_b, vo_b) in enumerate(heads_b):
            C[i, j] = (alpha * np.sum((qk_a - qk_b) ** 2)
                       + (1 - alpha) * np.sum((vo_a - vo_b) ** 2))
    return C

def additive_align_single_head(W_Q_a_i, b_Q_a_i, W_K_a_i, b_K_a_i, W_V_a_i, b_V_a_i, W_O_a_i,
                      W_Q_b_i, b_Q_b_i, W_K_b_i, b_K_b_i, W_V_b_i, b_V_b_i, W_O_b_i, optimize):
    """Align one attention head of model B to the corresponding head of model A.

    Two invertible maps are fitted on the bias-extended weights ``[W; b]``:
      * ``A`` for query/key:     Q_b -> Q_b A^T,  K_b -> K_b A^{-1}
      * ``B`` for value/output:  V_b -> V_b B^{-1}, W_O_b -> B W_O_b
    Both start from the orthogonal Procrustes solution (solve_orthogonal). When
    ``optimize`` is true, both are refined with optimize_alignment.

    Returns:
        ``{'aligned_params': {...}}``. When ``optimize`` is true, the dict also
        has ``'metrics_A'`` and ``'metrics_B'``, each holding the optimiser's
        ``objective_values``, ``grad_norms`` and ``condition_nums`` lists.
    """
    Q_a = compute_extended_weights(W_Q_a_i, b_Q_a_i)
    K_a = compute_extended_weights(W_K_a_i, b_K_a_i)
    V_a = compute_extended_weights(W_V_a_i, b_V_a_i)
    Q_b = compute_extended_weights(W_Q_b_i, b_Q_b_i)
    K_b = compute_extended_weights(W_K_b_i, b_K_b_i)
    V_b = compute_extended_weights(W_V_b_i, b_V_b_i)

    # Argument tuples in solve_orthogonal / optimize_alignment order: (X, X', Y, Y').
    qk_problem = (Q_a, Q_b, K_a, K_b)
    vo_problem = (W_O_a_i.T, W_O_b_i.T, V_a, V_b)

    qk_map = solve_orthogonal(*qk_problem)
    vo_map = solve_orthogonal(*vo_problem)
    if optimize:
        qk_map, *qk_metrics = optimize_alignment(qk_map, *qk_problem)
        vo_map, *vo_metrics = optimize_alignment(vo_map, *vo_problem)

    qk_map_inv = np.linalg.inv(qk_map)
    vo_map_inv = np.linalg.inv(vo_map)
    result = {
        'aligned_params': {
            'query': {'kernel': W_Q_b_i @ qk_map.T, 'bias': b_Q_b_i @ qk_map.T},
            'key': {'kernel': W_K_b_i @ qk_map_inv, 'bias': b_K_b_i @ qk_map_inv},
            'value': {'kernel': W_V_b_i @ vo_map_inv, 'bias': b_V_b_i @ vo_map_inv},
            'out': {'kernel': vo_map @ W_O_b_i},
        }
    }
    if optimize:
        for label, (objective_values, grad_norms, condition_nums) in (
            ('metrics_A', qk_metrics), ('metrics_B', vo_metrics)
        ):
            result[label] = {
                'objective_values': objective_values,
                'grad_norms': grad_norms,
                'condition_nums': condition_nums,
            }
    return result

def get_rope_matrix(seq_len, d_head):
    """Build per-position RoPE rotation matrices.

    Returns an array ``R`` of shape (seq_len, d_head/2, 2, 2) with
    ``R[m, k] = [[cos(m*w_k), -sin(m*w_k)], [sin(m*w_k), cos(m*w_k)]]``
    and ``w_k = 10000 ** (-2k / d_head)``.
    """
    assert d_head % 2 == 0, "d_head must be even"
    inv_freq = 1.0 / (10000 ** (jnp.arange(0, d_head, 2) / d_head))
    positions = jnp.arange(seq_len)
    angles = positions[:, None] * inv_freq[None, :]   # (seq_len, d_head/2)
    cos_a = jnp.cos(angles)
    sin_a = jnp.sin(angles)
    first_row = jnp.stack([cos_a, -sin_a], axis=-1)   # (seq_len, d_head/2, 2)
    second_row = jnp.stack([sin_a, cos_a], axis=-1)   # (seq_len, d_head/2, 2)
    return jnp.stack([first_row, second_row], axis=-2)

@jit
def apply_rope(x, R):
    """Apply per-position 2x2 matrices ``R`` to consecutive feature pairs of ``x``.

    Args:
        x: array of shape (B, L, D_k) with even D_k.
        R: array of shape (L, D_k/2, 2, 2), e.g. from get_rope_matrix.

    Returns:
        Array of shape (B, L, D_k). Each pair ``(x[..., 2p], x[..., 2p+1])`` is
        treated as a row vector and right-multiplied by ``R[l, p]``, which
        applies ``R[l, p]^T`` to the pair.
        NOTE(review): with R = [[cos, -sin], [sin, cos]] this rotates by the
        negative angle; confirm it matches the model's RoPE convention.
    """
    batch, length, dk = x.shape
    assert dk % 2 == 0
    pairs = x.reshape((batch, length, dk // 2, 2))
    rotated = jnp.einsum('bnpi,npij->bnpj', pairs, R)
    return rotated.reshape((batch, length, dk))
def compute_cost_matrix_presoftmax_rope(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a,
                                        W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b,
                                        num_heads, activations, alpha=0.5, epsilon=1e-8):
    """
    Computes the cost matrix for attention head permutation with RoPE.

    Heads are compared on real activations. With ``X~`` = ``activations`` plus
    an appended column of ones, and ``W~`` = kernel with its bias stacked
    underneath, each head yields:
      * ``S``: row-centred pre-softmax scores
        ``rope(X~ W_Q~) rope(X~ W_K~)^T / sqrt(d_head)``;
      * ``V``: the value-output contribution ``(X~ W_V~) W_O``.

    Args:
        W_*_a, b_*_a / W_*_b, b_*_b: per-head weights/biases, indexable by head.
        num_heads: number of heads.
        activations: (batch, seq_len, D) inputs to the attention layer.
        alpha: weight of the score term; ``1 - alpha`` weights the value term.
        epsilon: added to norm products to avoid division by zero.

    Returns:
        np.ndarray (num_heads, num_heads) with
        ``C[i, j] = alpha*(1 - cos(S_a_i, S_b_j)) + (1 - alpha)*(1 - cos(V_a_i, V_b_j))``.
    """
    batch, seq_len, _ = activations.shape
    d_head = W_Q_a[0].shape[1]
    rope_matrices = get_rope_matrix(seq_len, d_head)
    X_tilde = jnp.concatenate([activations, jnp.ones((batch, seq_len, 1))], axis=-1)
    sqrt_d = jnp.sqrt(float(d_head))

    def _head_features(W_Q, b_Q, W_K, b_K, W_V, b_V, W_O):
        # Flattened per-head (scores, values) for one model, each stacked to (num_heads, -1).
        scores, values = [], []
        for i in range(num_heads):
            Q = apply_rope(X_tilde @ compute_extended_weights(W_Q[i], b_Q[i]), rope_matrices)
            K = apply_rope(X_tilde @ compute_extended_weights(W_K[i], b_K[i]), rope_matrices)
            S = jnp.einsum('bld,bmd->blm', Q, K) / sqrt_d
            scores.append((S - jnp.mean(S, axis=2, keepdims=True)).ravel())
            V_tilde = X_tilde @ compute_extended_weights(W_V[i], b_V[i])
            values.append((V_tilde @ W_O[i]).ravel())
        return jnp.stack(scores), jnp.stack(values)

    def _cosine_distance(F_a, F_b):
        # 1 - cosine similarity between every row of F_a and every row of F_b;
        # norms are computed once per row instead of once per (i, j) pair.
        norm_products = jnp.linalg.norm(F_a, axis=1)[:, None] * jnp.linalg.norm(F_b, axis=1)[None, :]
        return 1.0 - (F_a @ F_b.T) / (norm_products + epsilon)

    S_a, V_a = _head_features(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a)
    S_b, V_b = _head_features(W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b)
    C = alpha * _cosine_distance(S_a, S_b) + (1 - alpha) * _cosine_distance(V_a, V_b)
    return np.asarray(C, dtype=np.float64)
def multiplytive_align_single_head(W_Q_a_i, b_Q_a_i, W_K_a_i, b_K_a_i, W_V_a_i, b_V_a_i, W_O_a_i,
                      W_Q_b_i, b_Q_b_i, W_K_b_i, b_K_b_i, W_V_b_i, b_V_b_i, W_O_b_i, optimize):
    """Align one head of model B to the matching head of model A (RoPE models).

    The query/key map ``U`` comes from solve_rope_qk_alignment. It is
    block-diagonal with 2x2 scaled-rotation blocks, so it commutes with the
    RoPE rotations. The value/output map ``B`` is a general invertible matrix,
    initialised with the orthogonal Procrustes solution.

    Args:
        *_a_i / *_b_i: per-head weights and biases of models A and B.
        optimize: if True, refine ``B`` with optimize_alignment. If False, use
            the orthogonal initialisation, as additive_align_single_head does.
            (Previously this flag was ignored and ``B`` was always optimised.)

    Returns:
        ``{'aligned_params': {...}}`` with B's transformed query/key/value/out
        parameters for this head.
    """
    U = solve_rope_qk_alignment(W_Q_a_i, b_Q_a_i, W_K_a_i, b_K_a_i,
                                W_Q_b_i, b_Q_b_i, W_K_b_i, b_K_b_i)
    U_inv = np.linalg.inv(U)
    # Q -> Q U^T and K -> K U^{-1} leave Q K^T unchanged.
    W_Q_aligned = W_Q_b_i @ U.T
    b_Q_aligned = b_Q_b_i @ U.T
    W_K_aligned = W_K_b_i @ U_inv
    b_K_aligned = b_K_b_i @ U_inv

    tilde_W_V_a_i = compute_extended_weights(W_V_a_i, b_V_a_i)
    tilde_W_V_b_i = compute_extended_weights(W_V_b_i, b_V_b_i)
    Y_O_a_i = W_O_a_i.T
    Y_O_b_i = W_O_b_i.T
    B_init = solve_orthogonal(Y_O_a_i, Y_O_b_i, tilde_W_V_a_i, tilde_W_V_b_i)
    if optimize:
        B, _, _, _ = optimize_alignment(B_init, Y_O_a_i, Y_O_b_i, tilde_W_V_a_i, tilde_W_V_b_i)
    else:
        B = B_init
    B_inv = np.linalg.inv(B)
    # V -> V B^{-1} and W_O -> B W_O leave V W_O unchanged.
    W_V_aligned = W_V_b_i @ B_inv
    b_V_aligned = b_V_b_i @ B_inv
    W_O_aligned = B @ W_O_b_i

    aligned_params = {
        'query': {'kernel': W_Q_aligned, 'bias': b_Q_aligned},
        'key': {'kernel': W_K_aligned, 'bias': b_K_aligned},
        'value': {'kernel': W_V_aligned, 'bias': b_V_aligned},
        'out': {'kernel': W_O_aligned},
    }
    return {'aligned_params': aligned_params}
def apply_alignment(W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b, W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a, h):
    """Align every head of model B to model A using orthogonal (Procrustes) maps only.

    Note the argument order: model B's per-head arrays come first, then model A's.

    For head i, an orthogonal map ``A_i`` is fitted for query/key and ``B_i``
    for value/output. B's parameters are transformed as
    Q -> Q A_i^T, K -> K A_i^{-1}, V -> V B_i^{-1}, W_O -> B_i W_O.

    Returns:
        dict ``{'head_{i}': {'query'|'key'|'value': {'kernel', 'bias'}, 'out': {'kernel'}}}``.
    """
    def with_bias_row(W, b):
        return np.vstack([W, b.reshape(1, -1)])

    aligned_params = {}
    for head in range(h):
        print(f"Aligning Heads {head}")
        qk_map = solve_orthogonal(
            with_bias_row(W_Q_a[head], b_Q_a[head]), with_bias_row(W_Q_b[head], b_Q_b[head]),
            with_bias_row(W_K_a[head], b_K_a[head]), with_bias_row(W_K_b[head], b_K_b[head]),
        )
        vo_map = solve_orthogonal(
            with_bias_row(W_V_a[head], b_V_a[head]), with_bias_row(W_V_b[head], b_V_b[head]),
            W_O_a[head].T, W_O_b[head].T,
        )
        qk_map_inv = np.linalg.inv(qk_map)
        vo_map_inv = np.linalg.inv(vo_map)
        aligned_params[f'head_{head}'] = {
            'query': {'kernel': W_Q_b[head] @ qk_map.T, 'bias': b_Q_b[head] @ qk_map.T},
            'key': {'kernel': W_K_b[head] @ qk_map_inv, 'bias': b_K_b[head] @ qk_map_inv},
            'value': {'kernel': W_V_b[head] @ vo_map_inv, 'bias': b_V_b[head] @ vo_map_inv},
            'out': {'kernel': vo_map @ W_O_b[head]},
        }
    return aligned_params

def merge_aligned_params(aligned_params, h, D, out_bias_b):
    """Fuse per-head aligned parameters back into one attention parameter tree.

    Args:
        aligned_params: dict with keys ``'head_0'`` ... ``'head_{h-1}'``. Each
            holds ``query``/``key``/``value`` dicts with ``kernel`` and ``bias``,
            plus an ``out`` dict with ``kernel``.
        h: number of heads.
        D: fused projection width used for reshaping (callers pass ``query.shape[1]``).
        out_bias_b: output dense bias, copied through unchanged.

    Returns:
        ``{'attention': {'query'|'value'|'key': {'kernel', 'bias'}},
        'output': {'dense': {'kernel', 'bias'}}}`` with jax.numpy leaves. When
        ``D == h * D_k``: per-head Q/K/V kernels end up side by side along the
        last axis, biases end to end, and out kernels stacked along the first axis.
    """
    per_head = [aligned_params[f'head_{idx}'] for idx in range(h)]

    def collect(proj, field, axis):
        return np.stack([head[proj][field] for head in per_head], axis=axis)

    attention = {}
    for proj in ('query', 'value', 'key'):
        kernel = collect(proj, 'kernel', axis=1).reshape(-1, D)
        bias = collect(proj, 'bias', axis=0).reshape(-1)
        attention[proj] = {'kernel': jnp.array(kernel), 'bias': jnp.array(bias)}
    out_kernel = collect('out', 'kernel', axis=0).reshape(D, -1)
    return {
        'attention': attention,
        'output': {
            'dense': {'kernel': jnp.array(out_kernel), 'bias': jnp.array(out_bias_b)},
        },
    }

def align_attention_params(rng, params_a, params_b, layer_idx, config, activation, permute_heads=True, optimize=False, alpha=0.5):
    """Align the self-attention block of one encoder layer of model B to model A.

    Steps:
      1. Split both models' fused Q/K/V/O projections into per-head arrays.
      2. If ``permute_heads``, reorder B's heads with the Hungarian assignment
         over compute_cost_matrix.
      3. Align each head pair: additive_align_single_head for "learnable" or
         "sinusoidal" position embeddings, multiplytive_align_single_head for "rope".
      4. Re-fuse the per-head parameters; B's output bias is kept unchanged.

    Args:
        rng: unused; kept for interface compatibility.
        params_a, params_b: parameter trees containing
            ``['vit']['encoder']['layer'][str(layer_idx)]['attention']``.
        layer_idx: encoder layer to align.
        config: provides ``lmc_config.num_attention_heads`` and ``position_embeddings``.
        activation: forwarded to compute_cost_matrix.
        permute_heads: whether to match heads before aligning.
        optimize: whether the per-head maps are refined numerically.
        alpha: weight of the QK term in the head-matching cost.

    Returns:
        ``{'aligned_params': merged attention parameter tree for model B}``.

    Raises:
        ValueError: if ``config.position_embeddings`` is not supported.
    """
    position_embeddings = config.position_embeddings
    if position_embeddings in ["learnable", "sinusoidal"]:
        align_single_head = additive_align_single_head
    elif position_embeddings in ["rope"]:
        align_single_head = multiplytive_align_single_head
    else:
        # Fail early with a clear message instead of an opaque KeyError during merging.
        raise ValueError(
            f"Unsupported position_embeddings for attention alignment: {position_embeddings!r}"
        )

    num_heads = config.lmc_config.num_attention_heads
    attn_a = params_a['vit']['encoder']["layer"][str(layer_idx)]['attention']
    attn_b = params_b['vit']['encoder']["layer"][str(layer_idx)]['attention']
    query_a, key_a, value_a, query_bias_a, key_bias_a, value_bias_a, out_a, _ = extract_attention_params(attn_a)
    query_b, key_b, value_b, query_bias_b, key_bias_b, value_bias_b, out_b, out_bias_b = extract_attention_params(attn_b)
    # Each is (W_Q, b_Q, W_K, b_K, W_V, b_V, W_O), indexed by head.
    heads_a = reshape_attention_weights(query_a, key_a, value_a, query_bias_a, key_bias_a, value_bias_a, out_a, num_heads)
    heads_b = reshape_attention_weights(query_b, key_b, value_b, query_bias_b, key_bias_b, value_bias_b, out_b, num_heads)

    if permute_heads:
        C = compute_cost_matrix(*heads_a, *heads_b, num_heads, activation, alpha)
        _, col_ind = linear_sum_assignment(C)
        print("Best Permutation Heads:", col_ind)
        # Head i of A is paired with head col_ind[i] of B; reorder B accordingly.
        heads_b = tuple([per_head[j] for j in col_ind] for per_head in heads_b)

    aligned_params = {}
    for i in range(num_heads):
        result = align_single_head(*(t[i] for t in heads_a), *(t[i] for t in heads_b), optimize)
        aligned_params[f'head_{i}'] = result['aligned_params']

    return {'aligned_params': merge_aligned_params(aligned_params, num_heads, query_a.shape[1], out_bias_b)}

def weight_matching_attn(rng, params_a, params_b, activation, config):
    """Align the attention layers of ``params_b`` to ``params_a`` for each enabled variant.

    For every configuration, a deep copy of ``params_b`` has each layer in
    ``config.lmc_layer_indices`` replaced by the output of
    align_attention_params. The sum of all parameters is printed as a sanity check.

    Args:
        rng: forwarded to align_attention_params.
        params_a, params_b: parameter trees of the two models.
        activation: per-layer activations indexable by layer index, or None.
        config: provides ``lmc_layer_indices`` (and what align_attention_params reads).

    Returns:
        dict mapping configuration name -> aligned copy of ``params_b``.
    """
    # (name, init_method, permute_heads, optimize). Other variants that can be enabled:
    #   ("permu_head_init_ortho_no_opt", 'ortho', True, False)
    #   ("naive_head_init_ortho_no_opt", 'ortho', False, False)
    #   ("naive_head_init_ortho_opt", 'ortho', False, True)
    configurations = [
        ("permu_head_init_ortho_opt", 'ortho', True, True),
    ]
    results = {}
    for name, _init_method, permute_heads, optimize in configurations:
        candidate = copy.deepcopy(params_b)
        for layer_idx in config.lmc_layer_indices:
            layer_activation = None if activation is None else activation[layer_idx]
            result = align_attention_params(
                rng, params_a, candidate, layer_idx, config,
                layer_activation, permute_heads=permute_heads, optimize=optimize,
            )
            candidate['vit']['encoder']["layer"][str(layer_idx)]['attention'] = result['aligned_params']
        total_sum = tree_util.tree_reduce(lambda acc, x: acc + jnp.sum(x), candidate, initializer=0)
        print(f"{name}: {total_sum}, sanity check")
        results[name] = candidate
    return results