import os
import copy
import time
import random
import itertools
import numpy as np
import jax.lax as lax
import jax.numpy as jnp
from utils import rngmix
import matplotlib.pyplot as plt
from typing import NamedTuple
from collections import defaultdict
from flax.core import freeze, unfreeze
from scipy.optimize import linear_sum_assignment, minimize, minimize_scalar
from jax import random, tree_util, jit, grad, value_and_grad
from scipy.linalg import block_diag
from math import sqrt, cos, sin, atan2
import numpy as np

def l2_dist(A, B):
    """Return the squared Euclidean (Frobenius) distance sum((A - B)^2)."""
    return np.square(A - B).sum()
def l1_dist(A, B):
    """Return the entrywise L1 distance sum(|A - B|)."""
    return np.abs(A - B).sum()
def cosine_dist(A, B, eps=1e-8):
    """Return 1 - cosine similarity of A and B, both flattened to vectors.

    `eps` is added to the norm product so all-zero inputs do not divide by zero.
    """
    flat_a = np.ravel(A)
    flat_b = np.ravel(B)
    denom = np.linalg.norm(flat_a) * np.linalg.norm(flat_b) + eps
    return 1.0 - np.dot(flat_a, flat_b) / denom
def corr_dist(A, B, eps=1e-8):
    """Return 1 - Pearson correlation between all entries of A and of B.

    Each input is centred by its global mean and flattened; `eps` guards the
    division for constant inputs.
    """
    centred_a = np.ravel(A - A.mean())
    centred_b = np.ravel(B - B.mean())
    denom = np.linalg.norm(centred_a) * np.linalg.norm(centred_b) + eps
    return 1.0 - np.dot(centred_a, centred_b) / denom
def spectral_dist(A, B):
    """Return |‖A‖₂ - ‖B‖₂|, using NumPy's ord=2 norm (spectral norm for 2-D inputs)."""
    norm_a = np.linalg.norm(A, 2)
    norm_b = np.linalg.norm(B, 2)
    return abs(norm_a - norm_b)
def compute_objective(A, X, X_prime, Y, Y_prime):
    """Return ||X - X' A^T||_F^2 + ||Y - Y' A^{-1}||_F^2 (NumPy version).

    Raises numpy.linalg.LinAlgError if A is singular.
    """
    inv_A = np.linalg.inv(A)
    resid_x = X - X_prime @ A.T
    resid_y = Y - Y_prime @ inv_A
    return np.sum(resid_x ** 2) + np.sum(resid_y ** 2)

def compute_gradient(A, X, X_prime, Y, Y_prime):
    """Analytic gradient of compute_objective with respect to A.

    d/dA ||X - X' A^T||^2    = -2 X^T X' + 2 A X'^T X'
    d/dA ||Y - Y' A^{-1}||^2 =  2 A^{-T} Y'^T (Y - Y' A^{-1}) A^{-T}
    """
    inv_A = np.linalg.inv(A)
    grad_x = -2 * X.T @ X_prime + 2 * A @ X_prime.T @ X_prime
    resid_y = Y - Y_prime @ inv_A
    grad_y = 2 * inv_A.T @ Y_prime.T @ resid_y @ inv_A.T
    return grad_x + grad_y

def line_search(A, grad, X, X_prime, Y, Y_prime, max_step=1, tau=0.5, c1=1e-4):
    """Backtracking (Armijo) line search along -grad for compute_objective.

    Starting from eta = max_step, eta is multiplied by `tau` until A - eta*grad
    is full rank and satisfies
        f(A - eta*grad) <= f(A) - c1 * eta * ||grad||^2.

    Returns:
        The accepted step size, or 0 once eta drops to 1e-10 or below.
    """
    f_start = compute_objective(A, X, X_prime, Y, Y_prime)
    sq_norm = np.sum(grad**2)
    dim = A.shape[0]
    step = max_step
    while step > 1e-10:
        candidate = A - step * grad
        # Singular candidates cannot be inverted by the objective: skip them.
        is_full_rank = np.linalg.matrix_rank(candidate) >= dim
        if is_full_rank and compute_objective(candidate, X, X_prime, Y, Y_prime) <= f_start - c1 * step * sq_norm:
            return step
        step *= tau
    return 0
@jit
def compute_objective_jax(A, X, X_prime, Y, Y_prime, cond_threshold=1e6):
    """JAX version of ||X - X' A^T||_F^2 + ||Y - Y' A^{-1}||_F^2.

    Returns +inf instead of evaluating when cond(A) > cond_threshold. This keeps
    the optimizer away from (near-)singular A without inverting it. Because the
    branch is chosen with lax.cond, the gradient on the +inf branch is zero.
    """
    def _finite_branch():
        inv_A = jnp.linalg.inv(A)
        resid_x = X - X_prime @ A.T
        resid_y = Y - Y_prime @ inv_A
        return jnp.sum(resid_x**2) + jnp.sum(resid_y**2)

    def _inf_branch():
        return jnp.inf

    ill_conditioned = jnp.linalg.cond(A) > cond_threshold
    return lax.cond(ill_conditioned, _inf_branch, _finite_branch)


# Jitted (value, d value / dA) pair. value_and_grad differentiates w.r.t. the first argument (A) only.
compute_value_and_grad_jax = jit(value_and_grad(compute_objective_jax))
def solve_orthogonal(X, X_prime, Y, Y_prime):
    """Orthogonal Procrustes initialiser for the alignment objective.

    For orthogonal A we have A^{-1} = A^T. Minimising
    ||X - X' A^T||^2 + ||Y - Y' A^T||^2 is then the same as maximising
    tr(A^T M) with M = X^T X' + Y^T Y'. That maximum is attained at A = U V^T,
    where M = U S V^T is the SVD of M.
    """
    cross = X.T @ X_prime + Y.T @ Y_prime
    left, _singular_values, right_t = np.linalg.svd(cross)
    return left @ right_t

def solve_rope_qk_alignment(W_Q_a_i, b_Q_a_i, W_K_a_i, b_K_a_i,
                            W_Q_b_i, b_Q_b_i, W_K_b_i, b_K_b_i):
    """
    Solve for the RoPE-compatible (G_RoPE) alignment matrix U of one head's QK weights.

    U is block diagonal. It has one 2x2 block per column pair (2j, 2j+1), and each
    block is a scaled rotation
        U_j = r * [[cos t, -sin t], [sin t, cos t]].
    Such a block commutes with the 2x2 rotation RoPE applies to the same pair.

    Per block, with x = r**2, the code minimises over x
        g(x) = x * ||Q_b_j||^2 + ||K_b_j||^2 / x - 4 * sqrt(A x + B / x + C)
    and then picks the angle t in closed form. Working through the algebra,
    this equals (up to the constant ||Q_a_j||^2 + ||K_a_j||^2)
        ||Q_a_j - Q_b_j U_j^T||^2 + ||K_a_j - K_b_j U_j^{-1}||^2.
    That is the form in which multiplytive_align_single_head applies U
    (W_Q_b @ U.T and W_K_b @ inv(U)).

    Args:
        W_Q_a_i, b_Q_a_i, W_K_a_i, b_K_a_i: model a's Q/K weights and biases for this head.
        W_Q_b_i, b_Q_b_i, W_K_b_i, b_K_b_i: the same for model b.
        Biases are appended as an extra row via compute_extended_weights, so
        they are fitted together with the weights.

    Returns:
        NumPy array U of shape (D_k, D_k), where D_k is the head dimension
        (number of columns of the extended weights).
    """
    # Append each bias as an extra row so weights and biases are fitted jointly.
    tilde_W_Q_a = compute_extended_weights(W_Q_a_i, b_Q_a_i)
    tilde_W_K_a = compute_extended_weights(W_K_a_i, b_K_a_i)
    tilde_W_Q_b = compute_extended_weights(W_Q_b_i, b_Q_b_i)
    tilde_W_K_b = compute_extended_weights(W_K_b_i, b_K_b_i)
    
    D_k = tilde_W_Q_a.shape[1]
    assert D_k % 2 == 0, "Head dimension must be even for RoPE."

    U_blocks = []
    # 90-degree rotation generator. For a 2x2 C, trace(C @ J) = C[0,1] - C[1,0].
    J = jnp.array([[0, -1], [1, 0]])

    for j in range(D_k // 2):
        # 1. Slice submatrices for the j-th 2D subspace
        sl = slice(2 * j, 2 * j + 2)
        Q_a_j, Q_b_j = tilde_W_Q_a[:, sl], tilde_W_Q_b[:, sl]
        K_a_j, K_b_j = tilde_W_K_a[:, sl], tilde_W_K_b[:, sl]

        # 2. Precompute constants
        N_Q = jnp.sum(Q_b_j**2)
        N_K = jnp.sum(K_b_j**2)
        C_Q = Q_a_j.T @ Q_b_j
        C_K = K_a_j.T @ K_b_j
        
        # Encode each 2x2 cross-correlation as a complex number c so that, for a
        # rotation angle t, trace(C @ R(-t)) = 2 * Re(c * exp(1j * t)).
        c_q = 0.5 * (jnp.trace(C_Q) + 1j * jnp.trace(C_Q @ J))
        c_k = 0.5 * (jnp.trace(C_K) + 1j * jnp.trace(C_K @ J))

        # 3. |r * c_q + c_k / r|^2 expands to A*x + B/x + C with x = r**2.
        A = jnp.abs(c_q)**2
        B = jnp.abs(c_k)**2
        C = 2 * jnp.real(c_q * jnp.conj(c_k))

        # Convert all JAX/Device arrays to native Python/NumPy types for SciPy
        N_Q_f, N_K_f = float(N_Q), float(N_K)
        A_f, B_f, C_f = float(A), float(B), float(C)
        c_q_f = complex(c_q)
        c_k_f = complex(c_k)

        # Define the 1D scalar objective function using native floats.
        # This is the block objective after the angle has been maximised out.
        def g_objective(x):
            x = float(x)
            # Protect the sqrt argument from tiny negative values due to roundoff
            inner_term = A_f * x + (B_f / x) + C_f
            safe_inner = max(inner_term, 1e-20)
            return x * N_Q_f + N_K_f / x - 4.0 * sqrt(safe_inner)

        # 4. Find the minimizer x* using robust bounds
        res = minimize_scalar(g_objective, bounds=(1e-8, 1e8), method='bounded')
        x_star = res.x

        # 5. Reconstruct the optimal 2x2 alignment matrix U_j using NumPy/math.
        # theta = -arg(r c_q + c_k / r) maximises Re((r c_q + c_k / r) * exp(1j*theta)).
        r_star = sqrt(x_star)
        combined_c = r_star * c_q_f + (1 / r_star) * c_k_f
        if abs(combined_c) < 1e-30:
            # Angle is irrelevant when the cross term vanishes; use no rotation.
            theta_star = 0.0
        else:
            theta_star = -atan2(combined_c.imag, combined_c.real)
        a = r_star * cos(theta_star)
        b = r_star * sin(theta_star)
        
        # U_j = r * [[cos t, -sin t], [sin t, cos t]]
        U_j = np.array([[a, -b], [b, a]])
        U_blocks.append(U_j)
        
    # 6. Assemble the full block-diagonal matrix U
    U_opt = block_diag(*U_blocks)
    # NOTE(review): x is bounded to [1e-8, 1e8], so every block scale r lies in
    # [1e-4, 1e4] and cond(U) <= 1e8. With those bounds this fallback does not
    # trigger.
    condU = np.linalg.cond(U_opt)
    if condU > 1e12:
        # fallback: scale blocks to have minimum magnitude, or add small diag:
        eps = 1e-6
        U_opt = U_opt + eps * np.eye(U_opt.shape[0])
    return U_opt

def solve_rope(X, X_prime, Y, Y_prime,max_iters=200, tol=1e-16):
    """Block-wise RoPE-compatible alignment found by 1-D bracketing + bisection.

    For every column pair (2j, 2j+1) it finds a scale rho and an angle theta and
    sets P[2j:2j+2, 2j:2j+2] = rho * rot(theta). All other entries of P are zero.

    Per block, with t = rho**2, it minimises
        gval(t) = a*t + ah/t + c_const - 2*sqrt(phi(t)),
        phi(t)  = u*t + v/t + 2*w,
    where a = ||X_blk||^2 and ah = ||Y_blk||^2.

    NOTE(review): because t multiplies ||X_blk||^2, the algebra suggests this fits
    ||X P^T - X_prime||^2 + ||Y P^{-1} - Y_prime||^2. That transforms the *first*
    arguments (X, Y), the opposite direction from compute_objective. Confirm
    against callers.

    Args:
        X, X_prime, Y, Y_prime: 2-D arrays that share the same even number of columns d.
        max_iters: cap on bracket-expansion and bisection steps.
        tol: stopping tolerance on |g'(t)| and on the relative bracket width.

    Returns:
        (d, d) NumPy array P.
    """
    d = X.shape[1]
    assert d % 2 == 0, "d must be even."
    assert X.shape[1] == X_prime.shape[1] == Y.shape[1] == Y_prime.shape[1]
    def rot(theta):
        c, s = np.cos(theta), np.sin(theta)
        return np.array([[c, -s], [s, c]])
    def block_cols(j): return [2*j, 2*j+1]
    def solve_block(Q1blk, Q2blk, K1blk, K2blk):
        # Returns (rho, theta) for one 2-column block.
        A, B, Ah, Bh = Q1blk, Q2blk, K1blk, K2blk
        a, ah = np.sum(A*A), np.sum(Ah*Ah)
        # Constant offset: does not affect the argmin, only gval's value.
        c_const = np.sum(B*B) + np.sum(Bh*Bh)
        C, Ch = A.T @ B, Ah.T @ Bh
        # Trace and skew parts of each 2x2 cross-correlation.
        t_tr, s_sk = np.trace(C), C[0,1] - C[1,0]
        th_tr, sh_sk = np.trace(Ch), Ch[0,1] - Ch[1,0]
        # phi(t) = |rho*(t_tr + i*s_sk) + (th_tr + i*sh_sk)/rho|^2 expanded with t = rho^2.
        u, v, w = t_tr**2+s_sk**2, th_tr**2+sh_sk**2, t_tr*th_tr+s_sk*sh_sk
        eps = 1e-18
        def phi(t): return max(u*t + v/max(t,eps) + 2*w, 0.0)
        # Derivative of gval(t). The sqrt term is dropped when phi ~ 0.
        def gprime(t):
            denom = np.sqrt(phi(t))
            if denom < eps: return a - ah/(t*t)
            return (a - ah/(t*t)) - (u - v/(t*t)) / denom
        # Start at the minimiser of a*t + ah/t alone.
        t0 = np.sqrt((ah+eps)/(a+eps))
        t_lo, gp_lo = t0, gprime(t0)
        # Expand geometrically until g' changes sign: up if g'(t0) < 0, else down.
        if gp_lo < 0.0:
            t_hi = t_lo
            for _ in range(max_iters):
                t_hi *= 2.0
                if gprime(t_hi) >= 0.0: break
        else:
            t_hi = t_lo
            for _ in range(max_iters):
                t_lo *= 0.5
                if gprime(t_lo) <= 0.0: break
        def gval(t): return a*t + ah/max(t,eps) + c_const - 2*np.sqrt(phi(t))
        if not (gprime(t_lo) <= 0.0 <= gprime(t_hi)):
            # No sign change found within max_iters: keep the better endpoint.
            t_star = min([(t_lo,gval(t_lo)),(t_hi,gval(t_hi))], key=lambda z:z[1])[0]
        else:
            # Bisection on g'(t) inside the bracket [t_lo, t_hi].
            for _ in range(max_iters):
                t_mid = 0.5*(t_lo+t_hi)
                gp_mid = gprime(t_mid)
                if abs(gp_mid) < tol or (t_hi-t_lo) <= tol*(1+t_mid):
                    t_star = t_mid; break
                if gp_mid < 0.0: t_lo = t_mid
                else: t_hi = t_mid
            else:
                t_star = 0.5*(t_lo+t_hi)
        rho = np.sqrt(max(t_star, eps))
        # theta = arg(alpha + i*beta) maximises the combined cross term for this rho.
        alpha = rho*t_tr + (1/rho)*th_tr
        beta = rho*s_sk + (1/rho)*sh_sk
        theta = np.arctan2(beta, alpha)
        return rho, theta
    P = np.zeros((d, d))
    for j in range(d//2):
        cols = block_cols(j)
        rho, theta = solve_block(X[:,cols], X_prime[:,cols], Y[:,cols], Y_prime[:,cols])
        P[np.ix_(cols, cols)] = rho * rot(theta)
    return P


def optimize_alignment(A_init, X, X_prime, Y, Y_prime, max_iter=5000):
    """Refine an alignment matrix A with SciPy's L-BFGS-B.

    Minimises compute_objective_jax(A, X, X_prime, Y, Y_prime), i.e.
    ||X - X' A^T||^2 + ||Y - Y' A^{-1}||^2, which is +inf when A is
    ill-conditioned. The optimisation starts from `A_init`.

    Args:
        A_init: initial square matrix (e.g. from solve_orthogonal).
        X, X_prime, Y, Y_prime: data matrices of the objective.
        max_iter: L-BFGS-B iteration cap.

    Returns:
        (A_opt, objective_values, grad_norms, condition_nums). The three lists
        hold one entry per L-BFGS-B iteration, recorded by the callback at
        each accepted iterate.
    """
    shape = A_init.shape
    # The data never changes during the optimisation. Convert it to device
    # arrays once instead of on every objective/callback evaluation.
    X_j, Xp_j, Y_j, Yp_j = (jnp.array(M) for M in (X, X_prime, Y, Y_prime))

    objective_values = []
    grad_norms = []
    condition_nums = []
    # Last point evaluated by the objective. The callback is invoked at the
    # accepted iterate, which normally was just evaluated, so it can reuse the
    # values instead of re-running the JAX kernel.
    last = {'x': None, 'obj': None, 'grad': None}

    def evaluate(flat_A):
        if last['x'] is not None and np.array_equal(flat_A, last['x']):
            return last['obj'], last['grad']
        A = jnp.array(flat_A.reshape(shape))
        obj, grad_val = compute_value_and_grad_jax(A, X_j, Xp_j, Y_j, Yp_j)
        last['x'] = np.array(flat_A, copy=True)  # SciPy may reuse/mutate its buffer
        last['obj'], last['grad'] = obj, grad_val
        return obj, grad_val

    def obj_fn(flat_A):
        obj, grad_val = evaluate(flat_A)
        # L-BFGS-B works in float64, while JAX defaults to float32.
        return float(obj), np.asarray(grad_val, dtype=np.float64).ravel()

    def callback(flat_A):
        obj, grad_val = evaluate(flat_A)
        objective_values.append(float(obj))
        grad_norms.append(float(jnp.linalg.norm(grad_val, 'fro')))
        condition_nums.append(float(jnp.linalg.cond(jnp.array(flat_A.reshape(shape)))))

    res = minimize(obj_fn, A_init.flatten(), jac=True, method='L-BFGS-B',
                   options={'maxiter': max_iter}, callback=callback)
    A_opt = res.x.reshape(shape)
    return A_opt, objective_values, grad_norms, condition_nums


def extract_attention_params(attn):
    """Unpack a fused attention param dict into NumPy arrays.

    attn['c_attn'] holds the fused Q/K/V kernel and bias. Both are split into
    three equal parts along axis 0, in the order query, key, value.
    attn['c_proj'] holds the output projection.

    Returns:
        (query, key, value, query_bias, key_bias, value_bias, c_proj_kernel, c_proj_bias)
    """
    fused_kernel = np.array(attn['c_attn']['kernel'])
    fused_bias = np.array(attn['c_attn']['bias'])
    proj = attn['c_proj']
    proj_kernel = np.array(proj['kernel'])
    proj_bias = np.array(proj['bias'])
    q_w, k_w, v_w = np.split(fused_kernel, 3, axis=0)
    q_b, k_b, v_b = np.split(fused_bias, 3, axis=0)
    return q_w, k_w, v_w, q_b, k_b, v_b, proj_kernel, proj_bias

def extract_ffn_params(ffn):
    """Return (c_fc kernel, c_fc bias, c_proj kernel, c_proj bias) from an FFN param dict."""
    fc = ffn['c_fc']
    proj = ffn['c_proj']
    return fc['kernel'], fc['bias'], proj['kernel'], proj['bias']
def reshape_attention_weights(query, key, value, query_bias, key_bias, value_bias, out_kernel, num_heads):
    """Split fused Q/K/V/O matrices into per-head stacks.

    The head size is D_k = query.shape[1] // num_heads. Head i owns rows
    i*D_k:(i+1)*D_k of query/key/value and of the biases, and columns
    i*D_k:(i+1)*D_k of out_kernel. Each per-head weight slice is transposed,
    so W_Q[i] = query[rows_i, :].T and W_O[i] = out_kernel[:, cols_i].T.

    Returns:
        (W_Q, b_Q, W_K, b_K, W_V, b_V, W_O), each stacked along a leading head axis.
    """
    head_dim = query.shape[1] // num_heads
    spans = [slice(n * head_dim, (n + 1) * head_dim) for n in range(num_heads)]

    def split_rows(mat):
        return np.stack([mat[span, :].T for span in spans])

    def split_cols(mat):
        return np.stack([mat[:, span].T for span in spans])

    def split_bias(vec):
        return np.stack([vec[span].T for span in spans])

    W_Q, W_K, W_V = split_rows(query), split_rows(key), split_rows(value)
    W_O = split_cols(out_kernel)
    b_Q, b_K, b_V = split_bias(query_bias), split_bias(key_bias), split_bias(value_bias)
    return W_Q, b_Q, W_K, b_K, W_V, b_V, W_O
def compute_extended_weights(W, b):
    """Append bias `b` (flattened to one row) below weight matrix `W`."""
    bias_row = b.reshape((1, -1))
    return np.vstack((W, bias_row))
def compute_cost_matrix(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a,
                        W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b,
                        h, alpha=0.5, dist = 'L2'):
    """Build the (h, h) head-matching cost matrix between models a and b.

    C[i, j] = d(QK_a_i, QK_b_j) + d(VO_a_i, VO_b_j), where:
      * QK is the bias-extended Q @ K^T with each row's mean subtracted (axis=1);
      * VO is the bias-extended V projection times that head's W_O.

    Args:
        W_*_a, b_*_a / W_*_b, b_*_b: per-head weights of models a and b, indexable by head.
        h: number of heads.
        alpha: currently unused. Both terms are weighted equally; the parameter
            is kept for signature compatibility.
        dist: distance d, one of "L2", "L1", "cosine", "corr", "spectral".

    Raises:
        ValueError: if `dist` is not one of the names above.
    """
    dist_fns = {
        "L2": l2_dist,
        "L1": l1_dist,
        "cosine": cosine_dist,
        "corr": corr_dist,
        "spectral": spectral_dist,
    }
    if dist not in dist_fns:
        raise ValueError(f"Unknown dist: {dist}")
    dist_fn = dist_fns[dist]

    def head_signature(W_Q, b_Q, W_K, b_K, W_V, b_V, W_O, k):
        """Return (row-centred extended Q K^T, extended V @ W_O) for head k."""
        q_ext = np.vstack([W_Q[k], b_Q[k].reshape(1, -1)])
        k_ext = np.vstack([W_K[k], b_K[k].reshape(1, -1)])
        v_ext = np.vstack([W_V[k], b_V[k].reshape(1, -1)])
        qk = q_ext @ k_ext.T
        centered_qk = qk - np.mean(qk, axis=1, keepdims=True)
        return centered_qk, v_ext @ W_O[k]

    # Compute each head's signature once, instead of rebuilding model b's
    # matrices for every head of model a inside the double loop.
    sig_a = [head_signature(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a, i) for i in range(h)]
    sig_b = [head_signature(W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b, j) for j in range(h)]

    C = np.zeros((h, h))
    for i, (qk_a, vo_a) in enumerate(sig_a):
        for j, (qk_b, vo_b) in enumerate(sig_b):
            C[i, j] = dist_fn(qk_a, qk_b) + dist_fn(vo_a, vo_b)
    return C
def additive_align_single_head(W_Q_a_i, b_Q_a_i, W_K_a_i, b_K_a_i, W_V_a_i, b_V_a_i, W_O_a_i,
                      W_Q_b_i, b_Q_b_i, W_K_b_i, b_K_b_i, W_V_b_i, b_V_b_i, W_O_b_i, optimize):
    """Align one head of model b to model a with general invertible maps.

    Q/K use a matrix A applied as Q_b A^T and K_b A^{-1}, so
    (Q_b A^T)(K_b A^{-1})^T = Q_b K_b^T. V/O use B applied as V_b B^{-1} and
    B W_O, so V W_O is unchanged. Both maps start from the orthogonal Procrustes
    solution on the bias-extended weights. If `optimize` is true, they are then
    refined with optimize_alignment.

    Returns:
        {'aligned_params': {'query', 'key', 'value', 'out'}}. When `optimize`
        is true it also has 'metrics_A' and 'metrics_B', each a dict of
        per-iteration objective values, gradient norms and condition numbers.
    """
    q_a = compute_extended_weights(W_Q_a_i, b_Q_a_i)
    k_a = compute_extended_weights(W_K_a_i, b_K_a_i)
    v_a = compute_extended_weights(W_V_a_i, b_V_a_i)
    o_a = W_O_a_i.T
    q_b = compute_extended_weights(W_Q_b_i, b_Q_b_i)
    k_b = compute_extended_weights(W_K_b_i, b_K_b_i)
    v_b = compute_extended_weights(W_V_b_i, b_V_b_i)
    o_b = W_O_b_i.T

    A = solve_orthogonal(q_a, q_b, k_a, k_b)
    B = solve_orthogonal(o_a, o_b, v_a, v_b)

    metrics = {}
    if optimize:
        metric_names = ('objective_values', 'grad_norms', 'condition_nums')
        A, *trace_A = optimize_alignment(A, q_a, q_b, k_a, k_b)
        B, *trace_B = optimize_alignment(B, o_a, o_b, v_a, v_b)
        metrics['metrics_A'] = dict(zip(metric_names, trace_A))
        metrics['metrics_B'] = dict(zip(metric_names, trace_B))

    A_inv = np.linalg.inv(A)
    B_inv = np.linalg.inv(B)
    aligned = {
        'query': {'kernel': W_Q_b_i @ A.T, 'bias': b_Q_b_i @ A.T},
        'key': {'kernel': W_K_b_i @ A_inv, 'bias': b_K_b_i @ A_inv},
        'value': {'kernel': W_V_b_i @ B_inv, 'bias': b_V_b_i @ B_inv},
        'out': {'kernel': B @ W_O_b_i},
    }
    return {'aligned_params': aligned, **metrics}


# ----------------------------- RoPE helpers -----------------------------
def get_rope_matrix(seq_len, d_head):
    """Build RoPE rotation matrices of shape (seq_len, d_head // 2, 2, 2).

    R[m, j] = [[cos(m w_j), -sin(m w_j)], [sin(m w_j), cos(m w_j)]],
    where w_j = 10000 ** (-2j / d_head).
    """
    assert d_head % 2 == 0, "d_head must be even"
    positions = jnp.arange(seq_len)
    inv_freq = 1.0 / (10000 ** (jnp.arange(0, d_head, 2) / d_head))
    angles = jnp.outer(positions, inv_freq)          # (seq_len, d_head // 2)
    cos_a, sin_a = jnp.cos(angles), jnp.sin(angles)
    top_row = jnp.stack([cos_a, -sin_a], axis=-1)    # entries [0, 0], [0, 1]
    bottom_row = jnp.stack([sin_a, cos_a], axis=-1)  # entries [1, 0], [1, 1]
    return jnp.stack([top_row, bottom_row], axis=-2)

@jit
def apply_rope(x, R):
    """Apply per-position 2x2 rotations to consecutive feature pairs of x.

    x has shape (B, L, D_k) and R has shape (L, D_k/2, 2, 2). Feature pair
    (2j, 2j+1) at position l is contracted with R[l, j] over R's *first* index,
    i.e. the result is pair @ R[l, j] = R[l, j]^T pair.

    NOTE(review): with R from get_rope_matrix this rotates by -angle, whereas
    the textbook RoPE applies R @ pair. Relative-position behaviour is the same
    either way (both q and k are rotated the same way); confirm the sign
    convention is intended.
    """
    batch, length, head_dim = x.shape
    assert head_dim % 2 == 0
    pairs = x.reshape((batch, length, head_dim // 2, 2))
    rotated_pairs = jnp.einsum('blhc,lhcr->blhr', pairs, R)
    return rotated_pairs.reshape((batch, length, head_dim))
def multiplytive_align_single_head(W_Q_a_i, b_Q_a_i, W_K_a_i, b_K_a_i, W_V_a_i, b_V_a_i, W_O_a_i,
                      W_Q_b_i, b_Q_b_i, W_K_b_i, b_K_b_i, W_V_b_i, b_V_b_i, W_O_b_i, optimize):
    """Align one head of model b to model a for a RoPE model.

    Q/K: U comes from solve_rope_qk_alignment. U is block diagonal with 2x2
    scaled rotations, so it commutes with RoPE's per-pair rotations. It is
    applied as Q_b U^T and K_b U^{-1}, which leaves Q K^T unchanged.

    V/O: a general invertible B is applied as V_b B^{-1} and B W_O. B starts
    from the orthogonal Procrustes solution and, when `optimize` is true, is
    refined with optimize_alignment. This mirrors additive_align_single_head.

    Returns:
        {'aligned_params': {'query', 'key', 'value', 'out'}}. No optimisation
        metrics are returned.
    """
    U = solve_rope_qk_alignment(W_Q_a_i, b_Q_a_i, W_K_a_i, b_K_a_i,
                                W_Q_b_i, b_Q_b_i, W_K_b_i, b_K_b_i)
    U_inv = np.linalg.inv(U)
    W_Q_aligned = W_Q_b_i @ U.T
    b_Q_aligned = b_Q_b_i @ U.T
    W_K_aligned = W_K_b_i @ U_inv
    b_K_aligned = b_K_b_i @ U_inv

    tilde_W_V_a_i = compute_extended_weights(W_V_a_i, b_V_a_i)
    Y_O_a_i = W_O_a_i.T
    tilde_W_V_b_i = compute_extended_weights(W_V_b_i, b_V_b_i)
    Y_O_b_i = W_O_b_i.T
    B = solve_orthogonal(Y_O_a_i, Y_O_b_i, tilde_W_V_a_i, tilde_W_V_b_i)
    if optimize:
        # Honour `optimize`. It was previously ignored here and B was always
        # refined, unlike in the additive path.
        B, _, _, _ = optimize_alignment(B, Y_O_a_i, Y_O_b_i, tilde_W_V_a_i, tilde_W_V_b_i)
    B_inv = np.linalg.inv(B)
    W_V_aligned = W_V_b_i @ B_inv
    b_V_aligned = b_V_b_i @ B_inv
    W_O_aligned = B @ W_O_b_i

    aligned_params = {
        'query': {'kernel': W_Q_aligned, 'bias': b_Q_aligned},
        'key': {'kernel': W_K_aligned, 'bias': b_K_aligned},
        'value': {'kernel': W_V_aligned, 'bias': b_V_aligned},
        'out': {'kernel': W_O_aligned},
    }
    return {'aligned_params': aligned_params}
def merge_aligned_params(aligned_params, h, D, out_bias_b):
    """Fuse per-head aligned params back into the c_attn / c_proj layout.

    Inverse of reshape_attention_weights. Head i's (D, D_k) query/key/value
    kernels become rows i*D_k:(i+1)*D_k (transposed) of the respective block.
    The three blocks are concatenated along axis 0. Head i's (D_k, D) out
    kernel becomes columns i*D_k:(i+1)*D_k of the (D, h*D_k) c_proj kernel.
    The c_proj bias is taken from `out_bias_b` unchanged.
    """
    heads = [aligned_params[f'head_{i}'] for i in range(h)]

    def fuse_kernel(name):
        stacked = np.stack([head[name]['kernel'] for head in heads], axis=1)  # (D, h, D_k)
        return stacked.transpose(1, 2, 0).reshape(-1, D)

    def fuse_bias(name):
        return np.stack([head[name]['bias'] for head in heads], axis=0).reshape(-1)

    qkv = ('query', 'key', 'value')
    fused_kernel = np.concatenate([fuse_kernel(name) for name in qkv], axis=0)
    fused_bias = np.concatenate([fuse_bias(name) for name in qkv], axis=0)
    out_stack = np.stack([head['out']['kernel'] for head in heads], axis=0)  # (h, D_k, D)
    out_kernel = out_stack.transpose(2, 0, 1).reshape(D, -1)

    return {
        'c_attn': {'kernel': jnp.array(fused_kernel), 'bias': jnp.array(fused_bias)},
        'c_proj': {'kernel': jnp.array(out_kernel), 'bias': jnp.array(out_bias_b)},
    }
def align_attention_params(rng, params_a, params_b, layer_idx, config, permute_heads=True, optimize=False, dist = 'L2', alpha=0.5):
    """Align the attention heads of model b's layer `layer_idx` to model a.

    Steps:
      1. Split both fused attention blocks into per-head matrices.
      2. If `permute_heads`, reorder b's heads by a Hungarian assignment on
         compute_cost_matrix (using `dist` and `alpha`).
      3. Align each head with an invertible map chosen by
         config.position_embeddings: "learnable"/"sinusoidal" use
         additive_align_single_head, "rope" uses multiplytive_align_single_head.
      4. Re-fuse the heads.

    `rng` is unused.

    Returns:
        dict with 'aligned_params' (fused {'c_attn', 'c_proj'}; the c_proj bias
        is model b's, unchanged). On the additive path with optimize=True it
        also has 'metrics_A_all' / 'metrics_B_all': per-head lists of
        per-iteration metrics.

    Raises:
        ValueError: if config.position_embeddings is not supported.
    """
    num_heads = config.lmc_config.n_head
    pos_emb = config.position_embeddings
    if pos_emb in ["learnable", "sinusoidal"]:
        align_head = additive_align_single_head
    elif pos_emb in ["rope"]:
        align_head = multiplytive_align_single_head
    else:
        # Previously this fell through and returned {}. That only surfaced later
        # as a KeyError on 'aligned_params' in the caller.
        raise ValueError(f"Unsupported position_embeddings: {pos_emb!r}")
    # Only the additive path reports optimisation metrics.
    collect_metrics = optimize and align_head is additive_align_single_head

    attn_a = params_a['transformer']['h'][str(layer_idx)]['attn']
    attn_b = params_b['transformer']['h'][str(layer_idx)]['attn']
    query_a, key_a, value_a, query_bias_a, key_bias_a, value_bias_a, out_a, out_bias_a = extract_attention_params(attn_a)
    query_b, key_b, value_b, query_bias_b, key_bias_b, value_bias_b, out_b, out_bias_b = extract_attention_params(attn_b)
    W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a = reshape_attention_weights(query_a, key_a, value_a, query_bias_a, key_bias_a, value_bias_a, out_a, num_heads)
    W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b = reshape_attention_weights(query_b, key_b, value_b, query_bias_b, key_bias_b, value_bias_b, out_b, num_heads)

    if permute_heads:
        C = compute_cost_matrix(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a,
                                W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b, num_heads, alpha, dist)
        _, col_ind = linear_sum_assignment(C)
        print("Best Permutation Heads:", col_ind)
        # Reorder every per-head tensor of model b so b's head i pairs with a's head i.
        W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b = (
            [T[j] for j in col_ind] for T in (W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b)
        )

    metric_keys = ('objective_values', 'grad_norms', 'condition_nums')
    metrics_A_all = {k: [] for k in metric_keys}
    metrics_B_all = {k: [] for k in metric_keys}
    aligned_params = {}
    for i in range(num_heads):
        result = align_head(
            W_Q_a[i], b_Q_a[i], W_K_a[i], b_K_a[i], W_V_a[i], b_V_a[i], W_O_a[i],
            W_Q_b[i], b_Q_b[i], W_K_b[i], b_K_b[i], W_V_b[i], b_V_b[i], W_O_b[i], optimize
        )
        aligned_params[f'head_{i}'] = result['aligned_params']
        if collect_metrics:
            for k in metric_keys:
                metrics_A_all[k].append(result['metrics_A'][k])
                metrics_B_all[k].append(result['metrics_B'][k])

    return_dict = {'aligned_params': merge_aligned_params(aligned_params, num_heads, query_a.shape[1], out_bias_b)}
    if collect_metrics:
        return_dict['metrics_A_all'] = metrics_A_all
        return_dict['metrics_B_all'] = metrics_B_all
    return return_dict


def align_ffn_params(rng, params_a, params_b, layer_idx, config):
    """Permute model b's shared-expert FFN hidden units to match model a.

    Hidden units are matched with a Hungarian assignment. The cost is
        C[i, j] = ||in_a[i] - in_b[j]||^2 + ||out_a[i] - out_b[j]||^2,
    where in_* is a unit's incoming weights followed by its c_fc bias, and
    out_* is its outgoing c_proj weights. Model b's c_proj bias is returned
    unchanged, since output units are not permuted. `rng` and `config` are
    unused.

    Returns:
        {'aligned_params': {'c_fc': {...}, 'c_proj': {...}}}, with kernels in
        the same layout as the input params.
    """
    # Local import keeps the module's top-level imports unchanged.
    from scipy.spatial.distance import cdist

    ffn_a = params_a['transformer']['h'][str(layer_idx)]['moe']['shared_experts']
    ffn_b = params_b['transformer']['h'][str(layer_idx)]['moe']['shared_experts']
    # Model a's c_proj bias is not needed.
    W1_a, b1_a, W2_a, _ = extract_ffn_params(ffn_a)
    W1_b, b1_b, W2_b, b2_b = extract_ffn_params(ffn_b)
    # After transposing, column i of W1 and row i of W2 belong to hidden unit i.
    W1_a, b1_a, W2_a = np.array(W1_a.T), np.array(b1_a), np.array(W2_a.T)
    W1_b, b1_b, W2_b = np.array(W1_b.T), np.array(b1_b), np.array(W2_b.T)

    # Row i = incoming weights of hidden unit i, followed by its bias.
    in_a = np.column_stack([W1_a.T, b1_a])
    in_b = np.column_stack([W1_b.T, b1_b])
    # Vectorised (float64) replacement for the O(D_hidden^2) Python double loop.
    C = cdist(in_a, in_b, 'sqeuclidean') + cdist(W2_a, W2_b, 'sqeuclidean')

    # `col_ind` gives, for each unit of model a, the matching unit of model b.
    _, col_ind = linear_sum_assignment(C)
    W1_aligned = W1_b[:, col_ind].T
    b1_aligned = b1_b[col_ind]
    W2_aligned = W2_b[col_ind, :].T
    return {
        'aligned_params': {'c_fc': {'kernel':W1_aligned,'bias':b1_aligned}, 'c_proj': {'kernel':W2_aligned,'bias':b2_b}}
    }

def weight_matching(rng, params_a, params_b, config, args):
    """Align model b to model a for each configured matching strategy.

    For every configuration, a deep copy of `params_b` has its attention
    aligned (align_attention_params) in every layer listed in
    config.lmc_layer_indices. If config.finetune_mlp is set, its shared-expert
    FFN is aligned too (align_ffn_params). args.dist selects the head-matching
    distance. The sum of all parameters is printed as a sanity check.

    Returns:
        dict mapping configuration name -> aligned parameter tree.
    """
    configurations = [
        ("permu_head_init_ortho_opt", 'ortho', True, True),
    ]
    params_dict = {}
    for name, _init_method, permute_heads, optimize in configurations:
        candidate = copy.deepcopy(params_b)
        for layer_idx in config.lmc_layer_indices:
            layer_key = str(layer_idx)
            attn_out = align_attention_params(
                rng, params_a, candidate, layer_idx, config,
                permute_heads=permute_heads, optimize=optimize, dist=args.dist,
            )
            candidate['transformer']['h'][layer_key]['attn'] = attn_out['aligned_params']
            if getattr(config, "finetune_mlp", False):
                ffn_out = align_ffn_params(rng, params_a, candidate, layer_idx, config)
                candidate['transformer']['h'][layer_key]['moe']['shared_experts'] = ffn_out['aligned_params']
        total_sum = tree_util.tree_reduce(lambda acc, leaf: acc + jnp.sum(leaf), candidate, initializer=0)
        print(f"{name}: {total_sum}, sanity check")
        params_dict[name] = candidate
    return params_dict

