from collections import defaultdict
from typing import NamedTuple
from flax.core import freeze, unfreeze
import jax.numpy as jnp
from jax import random, tree_util, jit, grad, value_and_grad
from scipy.optimize import linear_sum_assignment, minimize
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import copy
import jax
import jax.lax as lax
import jax.nn as nn
import jax

'''num_heads = 4
    #print all layers
    layer_paths = []
    def collect_layer_paths(path, value):
        # Convert path to a readable string by joining path keys
        path_str = '/'.join([str(p.key) for p in path])
        shape = value.shape
        layer_paths.append((path_str, shape))

    jax.tree_util.tree_map_with_path(collect_layer_paths, pretrained_params)
    print("Layers of the model:")
    for path, shape in layer_paths:  
        print(f"  {path} {shape}")'''
'''Example output of the snippet above for a cifar_vit-style model (3 encoder layers, 4 heads, hidden size 32):
    Layers of the model:
    Conv_0/bias (32,)
    Conv_0/kernel (4, 4, 3, 32)
    Dense_0/bias (10,)
    Dense_0/kernel (32, 10)
    TransformerEncoderLayer_0/Dense_0/bias (128,)
    TransformerEncoderLayer_0/Dense_0/kernel (32, 128)
    TransformerEncoderLayer_0/Dense_1/bias (32,)
    TransformerEncoderLayer_0/Dense_1/kernel (128, 32)
    TransformerEncoderLayer_0/LayerNorm_0/bias (32,)
    TransformerEncoderLayer_0/LayerNorm_0/scale (32,)
    TransformerEncoderLayer_0/LayerNorm_1/bias (32,)
    TransformerEncoderLayer_0/LayerNorm_1/scale (32,)
    TransformerEncoderLayer_0/MultiHeadDotProductAttention_0/key/bias (4, 8)
    TransformerEncoderLayer_0/MultiHeadDotProductAttention_0/key/kernel (32, 4, 8) #Key projections for 4 attention heads, each with a dimension of 8 (4 heads x 8 = 32, matching the model's hidden size).
    TransformerEncoderLayer_0/MultiHeadDotProductAttention_0/out/bias (32,)
    TransformerEncoderLayer_0/MultiHeadDotProductAttention_0/out/kernel (4, 8, 32)
    TransformerEncoderLayer_0/MultiHeadDotProductAttention_0/query/bias (4, 8)
    TransformerEncoderLayer_0/MultiHeadDotProductAttention_0/query/kernel (32, 4, 8)
    TransformerEncoderLayer_0/MultiHeadDotProductAttention_0/value/bias (4, 8)
    TransformerEncoderLayer_0/MultiHeadDotProductAttention_0/value/kernel (32, 4, 8)
    TransformerEncoderLayer_1/Dense_0/bias (128,)
    TransformerEncoderLayer_1/Dense_0/kernel (32, 128)
    TransformerEncoderLayer_1/Dense_1/bias (32,)
    TransformerEncoderLayer_1/Dense_1/kernel (128, 32)
    TransformerEncoderLayer_1/LayerNorm_0/bias (32,)
    TransformerEncoderLayer_1/LayerNorm_0/scale (32,)
    TransformerEncoderLayer_1/LayerNorm_1/bias (32,)
    TransformerEncoderLayer_1/LayerNorm_1/scale (32,)
    TransformerEncoderLayer_1/MultiHeadDotProductAttention_0/key/bias (4, 8)
    TransformerEncoderLayer_1/MultiHeadDotProductAttention_0/key/kernel (32, 4, 8)
    TransformerEncoderLayer_1/MultiHeadDotProductAttention_0/out/bias (32,)
    TransformerEncoderLayer_1/MultiHeadDotProductAttention_0/out/kernel (4, 8, 32)
    TransformerEncoderLayer_1/MultiHeadDotProductAttention_0/query/bias (4, 8)
    TransformerEncoderLayer_1/MultiHeadDotProductAttention_0/query/kernel (32, 4, 8)
    TransformerEncoderLayer_1/MultiHeadDotProductAttention_0/value/bias (4, 8)
    TransformerEncoderLayer_1/MultiHeadDotProductAttention_0/value/kernel (32, 4, 8)
    TransformerEncoderLayer_2/Dense_0/bias (128,)
    TransformerEncoderLayer_2/Dense_0/kernel (32, 128)
    TransformerEncoderLayer_2/Dense_1/bias (32,)
    TransformerEncoderLayer_2/Dense_1/kernel (128, 32)
    TransformerEncoderLayer_2/LayerNorm_0/bias (32,)
    TransformerEncoderLayer_2/LayerNorm_0/scale (32,)
    TransformerEncoderLayer_2/LayerNorm_1/bias (32,)
    TransformerEncoderLayer_2/LayerNorm_1/scale (32,)
    TransformerEncoderLayer_2/MultiHeadDotProductAttention_0/key/bias (4, 8)
    TransformerEncoderLayer_2/MultiHeadDotProductAttention_0/key/kernel (32, 4, 8)
    TransformerEncoderLayer_2/MultiHeadDotProductAttention_0/out/bias (32,)
    TransformerEncoderLayer_2/MultiHeadDotProductAttention_0/out/kernel (4, 8, 32)
    TransformerEncoderLayer_2/MultiHeadDotProductAttention_0/query/bias (4, 8)
    TransformerEncoderLayer_2/MultiHeadDotProductAttention_0/query/kernel (32, 4, 8)
    TransformerEncoderLayer_2/MultiHeadDotProductAttention_0/value/bias (4, 8)
    TransformerEncoderLayer_2/MultiHeadDotProductAttention_0/value/kernel (32, 4, 8)
    cls_token (1, 1, 32)
    pos_embedding (1, 65, 32)
'''

def to_numpy(x):
    """Return ``x`` as a (copied) NumPy array.

    JAX arrays, NumPy arrays, scalars and nested sequences all go through the
    same ``np.array`` conversion, so no type dispatch is needed.
    """
    return np.array(x)

@jit
def compute_objective_jax(A, X, X_prime, Y, Y_prime, cond_threshold=1e6):
    """Alignment objective for an invertible map ``A``.

    Returns ``||X - X_prime @ A.T||_F^2 + ||Y - Y_prime @ inv(A)||_F^2``.
    When the condition number of ``A`` exceeds ``cond_threshold`` (or is NaN),
    ``inf`` is returned instead of inverting ``A``.
    """
    # The condition number only selects a branch; keep it out of autodiff.
    cond = lax.stop_gradient(jnp.linalg.cond(A))
    # `not (cond <= threshold)` also routes NaN to the guarded branch, whereas
    # `cond > threshold` is False for NaN and would invert a broken matrix.
    ill_conditioned = jnp.logical_not(cond <= cond_threshold)
    # Both lax.cond branches must produce the same dtype.
    out_dtype = jnp.result_type(A, X, X_prime, Y, Y_prime)

    def safe_obj():
        A_inv = jnp.linalg.inv(A)
        term1 = X - X_prime @ A.T
        term2 = Y - Y_prime @ A_inv
        return jnp.sum(term1**2) + jnp.sum(term2**2)

    return lax.cond(ill_conditioned,
                    lambda: jnp.full((), jnp.inf, dtype=out_dtype),
                    safe_obj)

# (value, d value / dA) of the objective above, compiled once.
compute_value_and_grad_jax = jit(value_and_grad(compute_objective_jax))

def solve_orthogonal(X, X_prime, Y, Y_prime):
    """Closed-form orthogonal alignment (orthogonal Procrustes).

    If ``A`` is restricted to orthogonal matrices (so ``inv(A) == A.T``), then
    ``||X - X_prime A^T||^2 + ||Y - Y_prime A^{-1}||^2`` is minimized by
    ``A = U V^T``, where ``U S V^T`` is the SVD of
    ``X^T X_prime + Y^T Y_prime``.
    """
    cross = X.T @ X_prime
    cross = cross + Y.T @ Y_prime
    left, _singular_values, right_t = np.linalg.svd(cross)
    return left @ right_t

def optimize_alignment(A_init, X, X_prime, Y, Y_prime, max_iter=5000):
    """Refine an invertible alignment matrix with SciPy's L-BFGS-B.

    Minimizes ``compute_objective_jax(A, X, X_prime, Y, Y_prime)`` starting
    from ``A_init``. The exact JAX gradient is supplied as the Jacobian.

    Args:
        A_init: initial square matrix (array-like).
        X, X_prime, Y, Y_prime: fixed operands of the objective.
        max_iter: L-BFGS-B iteration cap.

    Returns:
        ``(A_opt, objective_values, grad_norms, condition_nums)``. The three
        lists get one entry per call of the SciPy iteration callback.
    """
    shape = np.shape(A_init)
    # The operands never change during optimization; convert them once
    # instead of on every objective/callback evaluation.
    X_j, X_prime_j, Y_j, Y_prime_j = (jnp.asarray(m) for m in (X, X_prime, Y, Y_prime))

    objective_values = []
    grad_norms = []
    condition_nums = []

    def evaluate(flat_A):
        """Return (A as a JAX matrix, (objective, gradient)) at ``flat_A``."""
        A = jnp.asarray(np.reshape(flat_A, shape))
        return A, compute_value_and_grad_jax(A, X_j, X_prime_j, Y_j, Y_prime_j)

    def obj_fn(flat_A):
        _, (obj, grad_val) = evaluate(flat_A)
        # L-BFGS-B works in float64; hand it a flat float64 gradient.
        return float(obj), np.asarray(grad_val, dtype=np.float64).ravel()

    def callback(flat_A):
        A, (obj, grad_val) = evaluate(flat_A)
        objective_values.append(float(obj))
        grad_norms.append(float(jnp.linalg.norm(grad_val, 'fro')))
        condition_nums.append(float(jnp.linalg.cond(A)))

    res = minimize(obj_fn, np.ravel(A_init), jac=True, method='L-BFGS-B',
                   options={'maxiter': max_iter}, callback=callback)
    A_opt = res.x.reshape(shape)
    return A_opt, objective_values, grad_norms, condition_nums

def get_nested_item(d, keys):
    """Walk a nested mapping along ``keys`` and return ``d[k0][k1]...``.

    With an empty ``keys`` sequence, ``d`` itself is returned.
    """
    node = d
    for step in keys:
        node = node[step]
    return node

# Helper functions for parameter extraction and reshaping
def extract_attention_params(params, layer_idx):
    """
    Locate the MHA block of one transformer layer and return its tensors.

    Two parameter layouts are recognized:

    * vit-jax:   ``params['Transformer']['encoderblock_<i>']['MultiHeadDotProductAttention_0']``
    * cifar_vit: ``params['TransformerEncoderLayer_<i>']['MultiHeadDotProductAttention_0']``

    Args:
        params: The Flax parameter tree.
        layer_idx: The integer index of the transformer layer.

    Returns:
        A tuple of
        - the flat tuple (key, key_bias, query, query_bias, value, value_bias, out, out_bias);
        - the nested key path of the MHA block, for later updates.

    Raises:
        KeyError: if neither layout matches ``layer_idx``.
    """
    vit_block = f'encoderblock_{layer_idx}'
    cifar_layer = f'TransformerEncoderLayer_{layer_idx}'

    if 'Transformer' in params and vit_block in params['Transformer']:
        mha_path = ('Transformer', vit_block, 'MultiHeadDotProductAttention_0')
    elif cifar_layer in params:
        mha_path = (cifar_layer, 'MultiHeadDotProductAttention_0')
    else:
        raise KeyError(f"Could not find a known path for attention layer {layer_idx} in the provided params.")

    block = get_nested_item(params, mha_path)

    # kernel then bias, for key, query, value and out (in that order).
    tensors = tuple(
        block[projection][part]
        for projection in ('key', 'query', 'value', 'out')
        for part in ('kernel', 'bias')
    )
    return tensors, mha_path

def reshape_to_per_head(params, num_heads):
    """
    Reshapes batched attention parameters into a list of per-head parameters.

    This function assumes a specific shape convention for the input weight and
    bias tensors, which is common in Flax/Linen implementations.

    Args:
        params (dict): A dictionary containing the attention parameters.
            Expected keys and tensor shapes are:
            - 'query': Weight tensor of shape (D, num_heads, d_k)
            - 'query_bias': Bias tensor of shape (num_heads, d_k)
            - 'key': Weight tensor of shape (D, num_heads, d_k)
            - 'key_bias': Bias tensor of shape (num_heads, d_k)
            - 'value': Weight tensor of shape (D, num_heads, d_v)
            - 'value_bias': Bias tensor of shape (num_heads, d_v)
            - 'out': Weight tensor of shape (num_heads, d_v, D)
        num_heads (int): The number of attention heads.

    Returns:
        A tuple containing lists of per-head parameters:
        (W_Q, b_Q, W_K, b_K, W_V, b_V, W_O)

    Raises:
        ValueError: if any tensor has the wrong rank or does not have
            ``num_heads`` along its head axis. An explicit raise is used
            instead of ``assert`` so the check survives ``python -O``.
    """
    # (name, expected ndim, axis indexing heads), in return order.
    layout = (
        ('query', 3, 1), ('query_bias', 2, 0),
        ('key', 3, 1), ('key_bias', 2, 0),
        ('value', 3, 1), ('value_bias', 2, 0),
        ('out', 3, 0),
    )

    for name, ndim, head_axis in layout:
        tensor = params[name]
        if tensor.ndim != ndim or tensor.shape[head_axis] != num_heads:
            raise ValueError(
                f"Expected '{name}' to be {ndim}-D with num_heads ({num_heads}) on axis "
                f"{head_axis}, but got shape {tensor.shape}. Please verify your model's "
                "parameter shape convention."
            )

    def split(name, head_axis):
        tensor = params[name]
        if head_axis == 1:
            return [tensor[:, h, :] for h in range(num_heads)]
        return [tensor[h] for h in range(num_heads)]

    return tuple(split(name, head_axis) for name, _, head_axis in layout)

def compute_extended_weights(W, b):
    """Append the bias ``b`` as one extra last row of the weight matrix ``W``.

    With ``X_tilde = [X, 1]`` (a ones column appended), the affine map
    ``X @ W + b`` equals ``X_tilde @ compute_extended_weights(W, b)``.
    """
    weight_rows = jnp.atleast_2d(jnp.array(W))
    bias_row = jnp.reshape(jnp.array(b), (1, -1))
    return jnp.concatenate((weight_rows, bias_row), axis=0)

# Helper function to plot multiple curves
def plot_multiple_curves(data_list, title, xlabel, ylabel, labels, save_path):
    """Plot several curves on one axes and save the figure to ``save_path``.

    Each legend entry is suffixed with the curve's final value. Empty curves
    are skipped because they have no final value to report. The figure is
    always closed, even if plotting or saving fails.

    Args:
        data_list: sequence of 1-D sequences, one per curve.
        title, xlabel, ylabel: axis annotations.
        labels: one label per curve (paired with ``data_list`` via ``zip``).
        save_path: output image path.
    """
    fig, ax = plt.subplots()
    try:
        for data, label in zip(data_list, labels):
            if len(data) == 0:
                continue
            ax.plot(data, label=f"{label} ({data[-1]:.4f})")
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        fig.savefig(save_path, bbox_inches='tight')
    finally:
        plt.close(fig)

# Stage 1 Function: Find Heads Permutation (Data-Dependent)
# Version 1: Using post-softmax probabilities
def compute_cost_matrix_postsoftmax(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a,
                               W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b,
                               num_heads, activations_a, activations_b, alpha=0.5, epsilon=1e-8):
    """
    Computes the cost matrix using post-softmax probabilities and model-specific activations.

    Each head is evaluated on its own model's activations ``X`` of shape
    (B, L, D):

    * ``P = softmax((X W_Q + b_Q)(X W_K + b_K)^T / sqrt(d_head))`` over keys;
    * ``V = (X W_V + b_V) W_O``, the head's contribution to the block output.

    Both are flattened and compared with cosine similarity:

        C[i, j] = alpha * (1 - cos(P_a_i, P_b_j)) + (1 - alpha) * (1 - cos(V_a_i, V_b_j))

    ``epsilon`` is added to the product of norms. Rows index heads of model A,
    columns index heads of model B. ``d_head`` is ``W_Q_a[0].shape[1]`` and is
    used for both models.

    Returns:
        np.ndarray of shape (num_heads, num_heads), dtype float64.
    """
    B_a, L_a, D_a = activations_a.shape
    B_b, L_b, D_b = activations_b.shape
    assert B_a == B_b and L_a == L_b and D_a == D_b, "Activations for both models must have the same shape."

    d_head = W_Q_a[0].shape[1]
    sqrt_d = jnp.sqrt(float(d_head))

    def head_features(X, W_Q, b_Q, W_K, b_K, W_V, b_V, W_O):
        """Stacked (num_heads, N) flattened probabilities and value outputs."""
        # X @ W + b is the same affine map as [X, 1] @ [W; b].
        X = jnp.asarray(X)
        probs, values = [], []
        for h in range(num_heads):
            Q = X @ jnp.asarray(W_Q[h]) + jnp.asarray(b_Q[h])
            K = X @ jnp.asarray(W_K[h]) + jnp.asarray(b_K[h])
            S = jnp.einsum('bld,bmd->blm', Q, K) / sqrt_d
            probs.append(nn.softmax(S, axis=-1).ravel())
            V = X @ jnp.asarray(W_V[h]) + jnp.asarray(b_V[h])
            values.append((V @ jnp.asarray(W_O[h])).ravel())
        return jnp.stack(probs), jnp.stack(values)

    def cosine_cost(F_a, F_b):
        """1 - cosine similarity for every (row of F_a, row of F_b) pair."""
        dots = jnp.matmul(F_a, F_b.T, precision=lax.Precision.HIGHEST)
        norms = jnp.linalg.norm(F_a, axis=1)[:, None] * jnp.linalg.norm(F_b, axis=1)[None, :]
        return 1.0 - dots / (norms + epsilon)

    P_a, V_a = head_features(activations_a, W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a)
    P_b, V_b = head_features(activations_b, W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b)

    C = alpha * cosine_cost(P_a, P_b) + (1 - alpha) * cosine_cost(V_a, V_b)
    return np.asarray(C, dtype=np.float64)

# Version 2: Using pre-softmax scores
def compute_cost_matrix_presoftmax(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a,
                               W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b,
                               num_heads, activations_a, activations_b, alpha=0.5, epsilon=1e-8):
    """
    Computes the cost matrix for attention head permutation using model-specific activations.

    Each head is evaluated on its own model's activations ``X`` of shape
    (B, L, D):

    * ``S = (X W_Q + b_Q)(X W_K + b_K)^T / sqrt(d_head)``, centered by
      subtracting each query row's mean over keys (softmax is invariant to
      such a per-query shift);
    * ``V = (X W_V + b_V) W_O``, the head's contribution to the block output.

    Both are flattened and compared with cosine similarity:

        C[i, j] = alpha * (1 - cos(S_a_i, S_b_j)) + (1 - alpha) * (1 - cos(V_a_i, V_b_j))

    ``epsilon`` is added to the product of norms. Rows index heads of model A,
    columns index heads of model B. ``d_head`` is ``W_Q_a[0].shape[1]`` and is
    used for both models.

    Returns:
        np.ndarray of shape (num_heads, num_heads), dtype float64.
    """
    B_a, L_a, D_a = activations_a.shape
    B_b, L_b, D_b = activations_b.shape
    assert B_a == B_b and L_a == L_b and D_a == D_b, "Activations for both models must have the same shape."

    d_head = W_Q_a[0].shape[1]
    sqrt_d = jnp.sqrt(float(d_head))

    def head_features(X, W_Q, b_Q, W_K, b_K, W_V, b_V, W_O):
        """Stacked (num_heads, N) flattened centered scores and value outputs."""
        # X @ W + b is the same affine map as [X, 1] @ [W; b].
        X = jnp.asarray(X)
        scores, values = [], []
        for h in range(num_heads):
            Q = X @ jnp.asarray(W_Q[h]) + jnp.asarray(b_Q[h])
            K = X @ jnp.asarray(W_K[h]) + jnp.asarray(b_K[h])
            S = jnp.einsum('bld,bmd->blm', Q, K) / sqrt_d
            S_bar = S - jnp.mean(S, axis=2, keepdims=True)
            scores.append(S_bar.ravel())
            V = X @ jnp.asarray(W_V[h]) + jnp.asarray(b_V[h])
            values.append((V @ jnp.asarray(W_O[h])).ravel())
        return jnp.stack(scores), jnp.stack(values)

    def cosine_cost(F_a, F_b):
        """1 - cosine similarity for every (row of F_a, row of F_b) pair."""
        dots = jnp.matmul(F_a, F_b.T, precision=lax.Precision.HIGHEST)
        norms = jnp.linalg.norm(F_a, axis=1)[:, None] * jnp.linalg.norm(F_b, axis=1)[None, :]
        return 1.0 - dots / (norms + epsilon)

    S_a, V_a = head_features(activations_a, W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a)
    S_b, V_b = head_features(activations_b, W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b)

    C = alpha * cosine_cost(S_a, S_b) + (1 - alpha) * cosine_cost(V_a, V_b)
    return np.asarray(C, dtype=np.float64)

def compute_cost_matrix_data_independent(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a,
                        W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b,
                        num_heads, alpha=0.5):
    """Weight-only cost matrix for matching heads of model A (rows) to model B (columns).

    For every head, two "signatures" are built from the bias-augmented weights
    (``compute_extended_weights``):

    * ``QK = Q_ext @ K_ext.T``, centered by subtracting each row's mean;
    * ``VO = V_ext @ W_O``.

    Then ``C[i, j] = alpha * ||QK_a_i - QK_b_j||_F^2 + (1 - alpha) * ||VO_a_i - VO_b_j||_F^2``.

    Each head's signatures are computed once (O(num_heads)) instead of
    recomputing model B's inside the pairwise loop (O(num_heads^2)).
    """
    def head_signatures(W_Q, b_Q, W_K, b_K, W_V, b_V, W_O):
        """Per-head lists of centered QK and VO matrices."""
        qk, vo = [], []
        for h in range(num_heads):
            ext_q = compute_extended_weights(W_Q[h], b_Q[h])
            ext_k = compute_extended_weights(W_K[h], b_K[h])
            ext_v = compute_extended_weights(W_V[h], b_V[h])
            scores = ext_q @ ext_k.T
            qk.append(scores - np.mean(scores, axis=1, keepdims=True))
            vo.append(ext_v @ W_O[h])
        return qk, vo

    qk_a, vo_a = head_signatures(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a)
    qk_b, vo_b = head_signatures(W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b)

    C = np.zeros((num_heads, num_heads))
    for i in range(num_heads):
        for j in range(num_heads):
            C[i, j] = (alpha * np.sum((qk_a[i] - qk_b[j]) ** 2)
                       + (1 - alpha) * np.sum((vo_a[i] - vo_b[j]) ** 2))
    return C

def find_heads_permutation(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a,
                           W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b,
                           num_heads, activations_a, activations_b, alpha, data_independent):
    """Match the heads of model B to the heads of model A (Hungarian algorithm).

    The cost matrix comes from ``compute_cost_matrix_data_independent`` when
    ``data_independent`` is true, otherwise from
    ``compute_cost_matrix_presoftmax`` on the given activations. The resulting
    mapping ``{A head: B head}`` is printed.

    Returns:
        ``(row_ind, col_ind)`` from ``scipy.optimize.linear_sum_assignment``:
        head ``col_ind[k]`` of model B is assigned to head ``row_ind[k]`` of model A.
    """
    weights_a = (W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a)
    weights_b = (W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b)

    if data_independent:
        cost = compute_cost_matrix_data_independent(*weights_a, *weights_b,
                                                    num_heads, alpha=alpha)
    else:
        cost = compute_cost_matrix_presoftmax(*weights_a, *weights_b,
                                              num_heads, activations_a, activations_b, alpha)

    row_ind, col_ind = linear_sum_assignment(to_numpy(cost))
    assignment = {int(r): int(c) for r, c in zip(row_ind, col_ind)}
    print(assignment)
    return row_ind, col_ind

# Stage 2 Function: Align Single Head (with optional optimization)
def align_single_head(W_Q_a_i, b_Q_a_i, W_K_a_i, b_K_a_i, W_V_a_i, b_V_a_i, W_O_a_i,
                      W_Q_b_i, b_Q_b_i, W_K_b_i, b_K_b_i, W_V_b_i, b_V_b_i, W_O_b_i,
                      init_method, optimize):
    """Align one attention head of model B to the matching head of model A.

    Two invertible maps are chosen:

    * ``A`` (query/key space). Queries are mapped with ``A.T`` and keys with
      ``inv(A)``, so B's ``Q @ K.T`` is unchanged.
    * ``B`` (value/output space). Values are mapped with ``inv(B)`` and the
      output projection with ``B``, so B's ``W_V @ W_O`` is unchanged.

    Args:
        *_a_i / *_b_i: per-head weights and biases of models A and B.
        init_method: ``'ortho'`` (closed form via ``solve_orthogonal``),
            ``'random'`` (Gaussian, mean 1 / std 1, redrawn until the
            determinant is non-zero) or ``'identity'``.
        optimize: if true, refine both maps with ``optimize_alignment``.

    Returns:
        ``{'aligned_params': {...}}``, plus ``'metrics_A'`` / ``'metrics_B'``
        (objective values, gradient norms, condition numbers) when ``optimize``.

    Raises:
        ValueError: for an unknown ``init_method``.
    """
    # Bias-augmented projections (bias appended as an extra row).
    ext_q_a = compute_extended_weights(W_Q_a_i, b_Q_a_i)
    ext_k_a = compute_extended_weights(W_K_a_i, b_K_a_i)
    ext_v_a = compute_extended_weights(W_V_a_i, b_V_a_i)
    out_t_a = W_O_a_i.T
    ext_q_b = compute_extended_weights(W_Q_b_i, b_Q_b_i)
    ext_k_b = compute_extended_weights(W_K_b_i, b_K_b_i)
    ext_v_b = compute_extended_weights(W_V_b_i, b_V_b_i)
    out_t_b = W_O_b_i.T

    qk_dim = ext_q_a.shape[1]
    vo_dim = ext_v_a.shape[1]

    def draw_invertible(dim):
        while True:
            candidate = np.random.normal(loc=1, scale=1, size=(dim, dim))
            if np.linalg.det(candidate) != 0:
                return candidate

    if init_method == 'ortho':
        qk_init = solve_orthogonal(ext_q_a, ext_q_b, ext_k_a, ext_k_b)
        vo_init = solve_orthogonal(out_t_a, out_t_b, ext_v_a, ext_v_b)
    elif init_method == 'random':
        qk_init = draw_invertible(qk_dim)
        vo_init = draw_invertible(vo_dim)
    elif init_method == 'identity':
        qk_init = np.eye(qk_dim)
        vo_init = np.eye(vo_dim)
    else:
        raise ValueError("Invalid initialization method")

    if optimize:
        qk_map, *qk_metrics = optimize_alignment(qk_init, ext_q_a, ext_q_b, ext_k_a, ext_k_b)
        vo_map, *vo_metrics = optimize_alignment(vo_init, out_t_a, out_t_b, ext_v_a, ext_v_b)
    else:
        qk_map, vo_map = qk_init, vo_init

    qk_map_inv = np.linalg.inv(qk_map)
    vo_map_inv = np.linalg.inv(vo_map)

    aligned_params = {
        'query': {'kernel': W_Q_b_i @ qk_map.T, 'bias': b_Q_b_i @ qk_map.T},
        'key': {'kernel': W_K_b_i @ qk_map_inv, 'bias': b_K_b_i @ qk_map_inv},
        'value': {'kernel': W_V_b_i @ vo_map_inv, 'bias': b_V_b_i @ vo_map_inv},
        'out': {'kernel': vo_map @ W_O_b_i},
    }

    if not optimize:
        return {'aligned_params': aligned_params}

    metric_names = ('objective_values', 'grad_norms', 'condition_nums')
    return {
        'aligned_params': aligned_params,
        'metrics_A': dict(zip(metric_names, qk_metrics)),
        'metrics_B': dict(zip(metric_names, vo_metrics)),
    }

def align_attention_params_main(rng, params_a, params_b, layer_idx, num_heads, 
                                activations_for_layer_a, activations_for_layer_b, plot_path=None, init_method='ortho', permute_heads=True, optimize=True, method_name="", alpha=0.5, data_independent=False):
    """Align the attention heads of one layer of model B to model A.

    Steps:
      1. extract the MHA tensors of ``layer_idx`` from both parameter trees;
      2. if ``permute_heads``, reorder B's heads with ``find_heads_permutation``
         so that slot i of B is matched with head i of A;
      3. align each matched head pair with ``align_single_head``;
      4. re-stack the per-head tensors into the Flax MHA layout.

    ``rng``, ``plot_path`` and ``method_name`` are accepted for interface
    compatibility but are not used here.

    Returns:
        dict with
        - ``'aligned_params'``: query/key/value ``{'kernel', 'bias'}`` and out
          ``{'kernel', 'bias'}``. Every leaf is a JAX array cast to the dtype of
          the corresponding parameter of ``params_b`` (canonicalized for the
          current JAX precision mode). The out bias is copied from ``params_b``.
        - ``'mha_path'``: path of the MHA block inside ``params_b``.
        - ``'metrics_A_all'`` / ``'metrics_B_all'`` (only if ``optimize``):
          metric name -> list with one entry per head.
    """
    names = ('key', 'key_bias', 'query', 'query_bias', 'value', 'value_bias', 'out', 'out_bias')
    tensors_a, _ = extract_attention_params(params_a, layer_idx)
    tensors_b, mha_path_b = extract_attention_params(params_b, layer_idx)
    params_a_np = {k: np.array(v) for k, v in zip(names, tensors_a)}
    params_b_np = {k: np.array(v) for k, v in zip(names, tensors_b)}

    # Each is the 7-tuple (W_Q, b_Q, W_K, b_K, W_V, b_V, W_O) of per-head lists.
    heads_a = reshape_to_per_head(params_a_np, num_heads)
    heads_b = reshape_to_per_head(params_b_np, num_heads)

    if permute_heads:
        _, col_ind = find_heads_permutation(
            *heads_a, *heads_b,
            num_heads, activations_for_layer_a, activations_for_layer_b, alpha, data_independent
        )
        # Put B's head col_ind[i] into slot i so it lines up with A's head i.
        heads_b = tuple([per_head[j] for j in col_ind] for per_head in heads_b)

    metric_names = ('objective_values', 'grad_norms', 'condition_nums')
    metrics_A_all = {key: [] for key in metric_names}
    metrics_B_all = {key: [] for key in metric_names}

    aligned_heads = []
    for i in range(num_heads):
        result = align_single_head(
            *(t[i] for t in heads_a), *(t[i] for t in heads_b),
            init_method, optimize
        )
        aligned_heads.append(result['aligned_params'])
        if optimize:
            for key in metric_names:
                metrics_A_all[key].append(result['metrics_A'][key])
                metrics_B_all[key].append(result['metrics_B'][key])

    def stacked(projection, part, axis):
        return np.stack([head[projection][part] for head in aligned_heads], axis=axis)

    def like_b(value, ref_name):
        # The alignment math runs in float64 (np.linalg.inv); cast back so the
        # aligned block keeps model B's dtypes (float64 is canonicalized to
        # float32 when JAX x64 mode is off).
        dtype = jax.dtypes.canonicalize_dtype(params_b_np[ref_name].dtype)
        return jnp.asarray(value, dtype=dtype)

    return_dict = {
        'aligned_params': {
            'query': {'kernel': like_b(stacked('query', 'kernel', 1), 'query'),
                      'bias': like_b(stacked('query', 'bias', 0), 'query_bias')},
            'key': {'kernel': like_b(stacked('key', 'kernel', 1), 'key'),
                    'bias': like_b(stacked('key', 'bias', 0), 'key_bias')},
            'value': {'kernel': like_b(stacked('value', 'kernel', 1), 'value'),
                      'bias': like_b(stacked('value', 'bias', 0), 'value_bias')},
            # The output bias is not affected by per-head maps; it is copied
            # from model B, but as a JAX array like every other leaf.
            'out': {'kernel': like_b(stacked('out', 'kernel', 0), 'out'),
                    'bias': like_b(params_b_np['out_bias'], 'out_bias')},
        },
        'mha_path': mha_path_b,  # Path for updating model_b.
    }
    if optimize:
        return_dict['metrics_A_all'] = metrics_A_all
        return_dict['metrics_B_all'] = metrics_B_all
    return return_dict

def _set_nested_item(tree, path, value):
    """Set ``tree[path[0]]...[path[-1]] = value`` in a mutable nested dict."""
    node = tree
    for key in path[:-1]:
        node = node[key]
    node[path[-1]] = value


def _mean_metric_curve(head_curves):
    """Average per-head metric curves, padding shorter ones with their last value.

    Empty curves (no optimizer iterations recorded) are ignored. Returns None
    when every curve is empty.
    """
    curves = [list(c) for c in head_curves if c]
    if not curves:
        return None
    max_len = max(len(c) for c in curves)
    padded = [c + [c[-1]] * (max_len - len(c)) for c in curves]
    return np.mean(padded, axis=0).tolist()


def _plot_alignment_metrics(name, layers, layer_to_metrics_A, layer_to_metrics_B, num_heads, plot_path):
    """Save one figure per metric to ``plot_path``.

    Each figure has one row per layer (one curve per head) plus a bottom row
    with the per-layer mean across heads. The left column shows query/key
    alignment, the right column value/out alignment.
    """
    os.makedirs(plot_path, exist_ok=True)
    num_layers = len(layers)
    panels = ((layer_to_metrics_A, "Query/Key Alignment"),
              (layer_to_metrics_B, "Value/Out Alignment"))
    for metric_key in ('objective_values', 'grad_norms', 'condition_nums'):
        pretty = metric_key.replace('_', ' ').capitalize()
        # squeeze=False keeps `axs` 2-D even when there is a single row.
        fig, axs = plt.subplots(num_layers + 1, 2, figsize=(20, 5 * (num_layers + 1)),
                                sharex='col', squeeze=False)
        try:
            for col, (metrics_per_layer, alignment_type) in enumerate(panels):
                # Per-layer rows: one curve per head.
                for row, layer in enumerate(layers):
                    ax = axs[row, col]
                    head_curves = metrics_per_layer[layer][metric_key]
                    for head, data in zip(range(num_heads), head_curves):
                        if data:
                            ax.plot(data, label=f"Head {head} ({data[-1]:.4f})")
                    ax.set_title(f"Layer {layer}: {pretty} - {alignment_type}")
                    ax.set_xlabel('Iteration')
                    ax.set_ylabel(pretty)
                    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

                # Bottom row: mean across heads, one curve per layer.
                ax = axs[num_layers, col]
                for layer in layers:
                    mean_curve = _mean_metric_curve(metrics_per_layer[layer][metric_key])
                    if mean_curve is not None:
                        ax.plot(mean_curve, label=f"Layer {layer} ({mean_curve[-1]:.4f})")
                ax.set_title(f"All Layers Mean: {pretty} - {alignment_type}")
                ax.set_xlabel('Iteration')
                ax.set_ylabel(pretty)
                ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

            fig.tight_layout()
            fig.savefig(os.path.join(plot_path, f"{name}_{metric_key}.png"), bbox_inches='tight')
        finally:
            plt.close(fig)


def matching_attn(rng, params_a, params_b, activations_a, activations_b, finetune_layer_which, num_heads, plot_path):
    """Align model B's attention layers to model A for each configured strategy.

    For every configuration, a copy of ``params_b`` has the MHA block of each
    layer in ``finetune_layer_which`` replaced by the output of
    ``align_attention_params_main``. A sanity-check sum over all parameters is
    printed. When the configuration optimizes and ``plot_path`` is set, the
    optimizer metrics are plotted into ``plot_path``.

    Args:
        rng: forwarded to ``align_attention_params_main``.
        params_a, params_b: Flax parameter trees.
        activations_a, activations_b: indexed by layer index; each entry is
            that layer's activations for the corresponding model.
        finetune_layer_which: sequence of layer indices to align.
        num_heads: attention heads per layer.
        plot_path: output directory for metric plots (or falsy to skip).

    Returns:
        dict mapping configuration name -> aligned parameter tree (frozen).
    """
    params_dict = {}
    # (name, data_independent, init_method, permute_heads, optimize)
    configurations = [
        ("data_indep_permu_head_init_ortho_no_opt", True, 'ortho', True, False),
        # ("data_dep_permu_head_init_ortho_no_opt", False, 'ortho', True, False),
        # ("data_dep_permu_head_init_ortho_opt", False, 'ortho', True, True),
    ]

    for name, data_independent, init_method, permute_heads, optimize in configurations:
        # One mutable copy for all layers instead of unfreezing/freezing per layer.
        aligned_params = unfreeze(copy.deepcopy(params_b))
        layer_to_metrics_A = {}
        layer_to_metrics_B = {}
        for layer_idx in finetune_layer_which:
            result = align_attention_params_main(
                rng, params_a, aligned_params, layer_idx, num_heads,
                activations_a[layer_idx], activations_b[layer_idx], plot_path=None,
                init_method=init_method, permute_heads=permute_heads, optimize=optimize,
                method_name=name, data_independent=data_independent
            )
            # Replace this layer's MHA block with the aligned parameters.
            _set_nested_item(aligned_params, result['mha_path'], result['aligned_params'])

            if optimize:
                layer_to_metrics_A[layer_idx] = result['metrics_A_all']
                layer_to_metrics_B[layer_idx] = result['metrics_B_all']

        aligned_params = freeze(aligned_params)
        total_sum = tree_util.tree_reduce(lambda acc, x: acc + jnp.sum(x), aligned_params, initializer=0)
        print(f"{name}: {total_sum}, sanity check")
        params_dict[name] = aligned_params

        if optimize and plot_path:
            _plot_alignment_metrics(name, list(finetune_layer_which), layer_to_metrics_A,
                                    layer_to_metrics_B, num_heads, plot_path)

    return params_dict

#############RoPE#################
def get_rope_matrix(seq_len, d_head):
    """Build the per-position, per-plane RoPE rotation matrices.

    Entry ``R[m, h]`` is the 2x2 matrix
    ``[[cos(m * w_h), -sin(m * w_h)], [sin(m * w_h), cos(m * w_h)]]`` with
    ``w_h = 10000 ** (-2h / d_head)``.

    Args:
        seq_len: Number of positions m = 0 .. seq_len - 1.
        d_head: Head dimension; must be even (one 2-D plane per pair of dims).

    Returns:
        Array of shape (seq_len, d_head // 2, 2, 2).
    """
    assert d_head % 2 == 0, "d_head must be even"
    plane_starts = jnp.arange(0, d_head, 2)                     # (d_head/2,)
    inv_freq = 1.0 / (10000 ** (plane_starts / d_head))          # (d_head/2,)
    positions = jnp.arange(seq_len)                              # (seq_len,)
    angles = positions[:, None] * inv_freq[None, :]              # (seq_len, d_head/2)

    cos_a = jnp.cos(angles)
    sin_a = jnp.sin(angles)
    top_row = jnp.stack([cos_a, -sin_a], axis=-1)                # [cos, -sin]
    bottom_row = jnp.stack([sin_a, cos_a], axis=-1)              # [sin,  cos]
    return jnp.stack([top_row, bottom_row], axis=-2)             # (seq_len, h, 2, 2)

@jax.jit
def apply_rope(x, R):
    """Rotate consecutive feature pairs of ``x`` with the RoPE matrices ``R``.

    Args:
        x: Array of shape (B, L, D_k); D_k must be even.
        R: Array of shape (L, D_k // 2, 2, 2), e.g. from ``get_rope_matrix``.

    Returns:
        Array of shape (B, L, D_k). The pair ``(x[..., 2h], x[..., 2h + 1])``
        at position l is multiplied, as a row vector, by ``R[l, h]``.
    """
    batch, length, dim = x.shape
    assert dim % 2 == 0
    pairs = jnp.reshape(x, (batch, length, dim // 2, 2))
    # (B, L, h, 1, 2) @ (L, h, 2, 2) broadcasts over (B, L, h) -> (B, L, h, 1, 2)
    rotated = jnp.matmul(pairs[..., None, :], R)[..., 0, :]
    return jnp.reshape(rotated, (batch, length, dim))

def _rope_head_features(W_Q, b_Q, W_K, b_K, W_V, b_V, W_O, num_heads, X_tilde, rope_matrices, sqrt_d):
    """Stack one model's per-head RoPE score features and value->output features.

    Returns ``(S_bar, V)``:
      - row h of ``S_bar`` is head h's pre-softmax RoPE score tensor (B, L, L),
        with every query row mean-centred, flattened;
      - row h of ``V`` is ``X_tilde @ tilde_W_V[h] @ W_O[h]``, flattened.
    """
    S_bar_rows, V_rows = [], []
    for h in range(num_heads):
        # NOTE(review): compute_extended_weights presumably appends the bias as an
        # extra input row, matching the ones column appended to X_tilde — confirm.
        tilde_W_Q = compute_extended_weights(W_Q[h], b_Q[h])
        tilde_W_K = compute_extended_weights(W_K[h], b_K[h])
        tilde_W_V = compute_extended_weights(W_V[h], b_V[h])
        Q_rope = apply_rope(X_tilde @ tilde_W_Q, rope_matrices)
        K_rope = apply_rope(X_tilde @ tilde_W_K, rope_matrices)
        S = jnp.einsum('bld,bmd->blm', Q_rope, K_rope) / sqrt_d
        # softmax over keys is invariant to a per-row shift, so compare centred rows
        S_bar_rows.append((S - jnp.mean(S, axis=2, keepdims=True)).flatten())
        V_rows.append((X_tilde @ tilde_W_V @ W_O[h]).flatten())
    return jnp.stack(S_bar_rows), jnp.stack(V_rows)


def _pairwise_cosine_cost(F_a, F_b, epsilon):
    """Return the (n_a, n_b) matrix ``1 - <F_a[i], F_b[j]> / (|F_a[i]| * |F_b[j]| + epsilon)``."""
    dots = F_a @ F_b.T
    norms = jnp.outer(jnp.linalg.norm(F_a, axis=1), jnp.linalg.norm(F_b, axis=1))
    return 1.0 - dots / (norms + epsilon)


def compute_cost_matrix_presoftmax_rope(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a,
                                        W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b,
                                        num_heads, activations_a, activations_b, alpha=0.5, epsilon=1e-8):
    """Data-dependent head-matching cost for RoPE attention.

    ``C[i, j] = alpha * (1 - cos(S_a_i, S_b_j)) + (1 - alpha) * (1 - cos(V_a_i, V_b_j))``
    where ``S_*`` are the row-centred pre-softmax RoPE scores of a head on its own
    model's activations, ``V_*`` the head's value->output contribution, and
    ``cos(u, v) = <u, v> / (|u| |v| + epsilon)``.

    Args:
        W_*_a, b_*_a, W_O_a / W_*_b, b_*_b, W_O_b: Per-head weights of models A and B,
            indexable by head; ``W_Q_a[0].shape[1]`` is taken as the head dimension.
        num_heads: Number of heads.
        activations_a, activations_b: Layer inputs of shape (B, L, D); shapes must match.
        alpha: Weight of the score term vs. the value/output term.
        epsilon: Added to the norm product to avoid division by zero.

    Returns:
        np.ndarray of shape (num_heads, num_heads), float64; rows index A's heads,
        columns B's heads.
    """
    B_a, L_a, D_a = activations_a.shape
    B_b, L_b, D_b = activations_b.shape
    assert B_a == B_b and L_a == L_b and D_a == D_b
    d_head = W_Q_a[0].shape[1]

    rope_matrices = get_rope_matrix(L_a, d_head)
    sqrt_d = jnp.sqrt(float(d_head))
    # append a ones column so biases can be folded into extended weights
    X_tilde_a = jnp.concatenate([activations_a, jnp.ones((B_a, L_a, 1))], axis=-1)
    X_tilde_b = jnp.concatenate([activations_b, jnp.ones((B_b, L_b, 1))], axis=-1)

    S_a, V_a = _rope_head_features(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a,
                                   num_heads, X_tilde_a, rope_matrices, sqrt_d)
    S_b, V_b = _rope_head_features(W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b,
                                   num_heads, X_tilde_b, rope_matrices, sqrt_d)

    cost = (alpha * _pairwise_cosine_cost(S_a, S_b, epsilon)
            + (1 - alpha) * _pairwise_cosine_cost(V_a, V_b, epsilon))
    return np.asarray(cost, dtype=np.float64)

def find_heads_permutation_rope(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a,
                                W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b,
                                num_heads, activations_a, activations_b, alpha, data_independent):
    """Match B's heads to A's heads by solving a linear assignment problem.

    The cost matrix is ``compute_cost_matrix_data_independent`` when
    ``data_independent`` is true (activations ignored), otherwise
    ``compute_cost_matrix_presoftmax_rope`` on each model's own activations.
    The chosen mapping is printed.

    Returns:
        ``(row_ind, col_ind)`` from ``linear_sum_assignment``: A's head
        ``row_ind[k]`` is paired with B's head ``col_ind[k]``.
    """
    weights_a = (W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a)
    weights_b = (W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b)

    if not data_independent:
        cost = compute_cost_matrix_presoftmax_rope(
            *weights_a, *weights_b, num_heads, activations_a, activations_b, alpha
        )
    else:
        cost = compute_cost_matrix_data_independent(*weights_a, *weights_b, num_heads, alpha=alpha)

    rows, cols = linear_sum_assignment(to_numpy(cost))
    mapping = {int(r): int(c) for r, c in zip(rows, cols)}
    print("RoPE Head Permutation:", mapping)
    return rows, cols

from scipy.optimize import minimize_scalar
from scipy.linalg import block_diag
from math import sqrt, cos, sin, atan2
import numpy as np

def solve_rope_qk_alignment(W_Q_a_i, b_Q_a_i, W_K_a_i, b_K_a_i,
                            W_Q_b_i, b_Q_b_i, W_K_b_i, b_K_b_i):
    """
    Solve for the RoPE-compatible QK alignment matrix U of a single head.

    U is block-diagonal with one 2x2 block per RoPE plane. Each block is a scaled
    rotation ``U_j = r * [[cos t, -sin t], [sin t, cos t]]``; such blocks commute
    with the RoPE rotations, so mapping queries with ``U.T`` and keys with
    ``inv(U)`` (as ``align_attention_params_main_rope`` does) leaves the RoPE
    attention scores unchanged.

    For plane j (columns 2j, 2j+1 of the bias-extended weights) the block minimises

        ||Q_a_j - Q_b_j U_j^T||_F^2 + ||K_a_j - K_b_j U_j^{-1}||_F^2.

    For fixed r the optimal angle is closed-form. With x = r^2 the remaining
    problem is ``g(x) = x N_Q + N_K / x - 4 sqrt(A x + B / x + C)``, minimised
    numerically over x in [1e-8, 1e8].

    Returns:
        np.ndarray (float64) of shape (D_k, D_k), the block-diagonal U.
    """
    from math import atan2, cos, exp, log, sin, sqrt

    # Work in float64 NumPy: one host conversion per head instead of per plane.
    tilde_W_Q_a = np.asarray(compute_extended_weights(W_Q_a_i, b_Q_a_i), dtype=np.float64)
    tilde_W_K_a = np.asarray(compute_extended_weights(W_K_a_i, b_K_a_i), dtype=np.float64)
    tilde_W_Q_b = np.asarray(compute_extended_weights(W_Q_b_i, b_Q_b_i), dtype=np.float64)
    tilde_W_K_b = np.asarray(compute_extended_weights(W_K_b_i, b_K_b_i), dtype=np.float64)

    D_k = tilde_W_Q_a.shape[1]
    assert D_k % 2 == 0, "Head dimension must be even for RoPE."

    J = np.array([[0.0, -1.0], [1.0, 0.0]])
    # Search over log(x): the scale is multiplicative and the bounds span 16
    # decades, so an absolute tolerance on the linear scale would resolve small
    # optima poorly. log is monotone, so unimodality of g is preserved.
    log_bounds = (log(1e-8), log(1e8))

    U_blocks = []
    for j in range(D_k // 2):
        sl = slice(2 * j, 2 * j + 2)
        Q_a_j, Q_b_j = tilde_W_Q_a[:, sl], tilde_W_Q_b[:, sl]
        K_a_j, K_b_j = tilde_W_K_a[:, sl], tilde_W_K_b[:, sl]

        N_Q = float(np.sum(Q_b_j ** 2))
        N_K = float(np.sum(K_b_j ** 2))
        C_Q = Q_a_j.T @ Q_b_j
        C_K = K_a_j.T @ K_b_j
        # tr(C R(-t)) == 2 Re(c e^{it}) with c = (tr(C) + i tr(C J)) / 2
        c_q = 0.5 * complex(np.trace(C_Q), np.trace(C_Q @ J))
        c_k = 0.5 * complex(np.trace(C_K), np.trace(C_K @ J))

        # |r c_q + c_k / r|^2 == A x + B / x + C with x = r^2
        A = abs(c_q) ** 2
        B = abs(c_k) ** 2
        C = 2.0 * (c_q * c_k.conjugate()).real

        def g_objective(x):
            # Clamp the sqrt argument against tiny negative round-off.
            inner_term = A * x + B / x + C
            return x * N_Q + N_K / x - 4.0 * sqrt(max(inner_term, 1e-20))

        res = minimize_scalar(lambda t: g_objective(exp(t)), bounds=log_bounds, method='bounded')
        x_star = exp(float(res.x))

        # Optimal angle maximises Re((r c_q + c_k / r) e^{it}).
        r_star = sqrt(x_star)
        combined_c = r_star * c_q + c_k / r_star
        if abs(combined_c) < 1e-30:
            theta_star = 0.0
        else:
            theta_star = -atan2(combined_c.imag, combined_c.real)
        a = r_star * cos(theta_star)
        b = r_star * sin(theta_star)
        U_blocks.append(np.array([[a, -b], [b, a]]))

    U_opt = block_diag(*U_blocks)
    # With r in [1e-4, 1e4] (from the x bounds) cond(U) <= 1e8, so this safeguard
    # only matters if the bounds are widened. Adding eps*I keeps every block a
    # scaled rotation.
    if np.linalg.cond(U_opt) > 1e12:
        U_opt = U_opt + 1e-6 * np.eye(U_opt.shape[0])
    return U_opt

def align_attention_params_main_rope(params_a, params_b, layer_idx, num_heads, 
                                     activations_for_layer_a, activations_for_layer_b, 
                                     init_method_vo='ortho', permu_heads=True, optimize_vo=True, alpha=0.5, data_independent=False):
    """
    Align one RoPE multi-head-attention layer of model B to model A.

    Steps:
      1. If ``permu_heads``: reorder B's heads using ``find_heads_permutation_rope``.
         This is the only step that uses the activations, ``alpha`` and
         ``data_independent``.
      2. Per head, transform B's query/key with the block-diagonal matrix U from
         ``solve_rope_qk_alignment``: ``W_Q @ U.T`` and ``W_K @ inv(U)``, with
         the same maps applied to the biases.
      3. Per head, transform B's value/output with an invertible matrix B_vo:
         ``W_V @ inv(B_vo)`` and ``B_vo @ W_O``.
         - B_vo starts from ``solve_orthogonal`` when ``init_method_vo == 'ortho'``,
           otherwise from the identity.
         - It is refined by ``optimize_alignment`` when ``optimize_vo``.

    Parameters are located with ``extract_attention_params``, so any layout it
    supports works.

    Returns:
        dict with:
          - ``'aligned_params'``: query/key/value kernel+bias and out kernel+bias,
            as jnp arrays; the out bias is B's original.
          - ``'mha_path'``: B's MHA block path, as returned by
            ``extract_attention_params``.
    """
    field_names = ['key', 'key_bias', 'query', 'query_bias', 'value', 'value_bias', 'out', 'out_bias']

    raw_a, _ = extract_attention_params(params_a, layer_idx)
    raw_b, mha_path_b = extract_attention_params(params_b, layer_idx)
    params_a_np = dict(zip(field_names, map(np.array, raw_a)))
    params_b_np = dict(zip(field_names, map(np.array, raw_b)))

    W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a = reshape_to_per_head(params_a_np, num_heads)
    W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b = reshape_to_per_head(params_b_np, num_heads)

    if permu_heads:
        _, col_ind = find_heads_permutation_rope(
            W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a,
            W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b,
            num_heads, activations_for_layer_a, activations_for_layer_b, alpha, data_independent
        )
        # Head k of the reordered B is B's original head col_ind[k].
        W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b = (
            [group[j] for j in col_ind]
            for group in (W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b)
        )

    per_head = []
    for h in range(num_heads):
        # Query/key: RoPE-compatible scaled-rotation transform.
        U = solve_rope_qk_alignment(
            W_Q_a[h], b_Q_a[h], W_K_a[h], b_K_a[h],
            W_Q_b[h], b_Q_b[h], W_K_b[h], b_K_b[h]
        )
        U_inv = np.linalg.inv(U)
        U_T = U.T
        query = {'kernel': W_Q_b[h] @ U_T, 'bias': b_Q_b[h] @ U_T}
        key = {'kernel': W_K_b[h] @ U_inv, 'bias': b_K_b[h] @ U_inv}

        # Value/output: RoPE does not touch this path, so any invertible B_vo works.
        ext_V_a = compute_extended_weights(W_V_a[h], b_V_a[h])
        Y_O_a = W_O_a[h].T
        ext_V_b = compute_extended_weights(W_V_b[h], b_V_b[h])
        Y_O_b = W_O_b[h].T

        if init_method_vo == 'ortho':
            B_vo = solve_orthogonal(Y_O_a, Y_O_b, ext_V_a, ext_V_b)
        else:
            B_vo = np.identity(W_O_b[h].shape[0])
        if optimize_vo:
            B_vo, _, _, _ = optimize_alignment(B_vo, Y_O_a, Y_O_b, ext_V_a, ext_V_b)

        B_vo_inv = np.linalg.inv(B_vo)
        per_head.append({
            'query': query,
            'key': key,
            'value': {'kernel': W_V_b[h] @ B_vo_inv, 'bias': b_V_b[h] @ B_vo_inv},
            'out': {'kernel': B_vo @ W_O_b[h]},
        })

    def _stacked(group, field, axis):
        # Re-join the per-head arrays along the head axis.
        return jnp.array(np.stack([head[group][field] for head in per_head], axis=axis))

    aligned = {}
    for group in ('query', 'key', 'value'):
        aligned[group] = {'kernel': _stacked(group, 'kernel', 1), 'bias': _stacked(group, 'bias', 0)}
    aligned['out'] = {'kernel': _stacked('out', 'kernel', 0), 'bias': jnp.array(params_b_np['out_bias'])}

    return {'aligned_params': aligned, 'mha_path': mha_path_b}

def matching_attn_rope(params_a, params_b, activations_a, activations_b, finetune_layer_which, num_heads, alpha=0.5):
    """
    Align the RoPE multi-head-attention layers of model B to model A.

    For each configuration, a deep copy of ``params_b`` is made. Each layer in
    ``finetune_layer_which`` is then replaced in turn by the output of
    ``align_attention_params_main_rope``. Later layers therefore see the copy
    with earlier layers already aligned.

    Args:
        params_a (dict): Parameters of the reference model A.
        params_b (dict): Parameters of model B. Not modified; a deep copy is aligned.
        activations_a (dict): Layer index -> input activations of that layer in model A.
        activations_b (dict): Layer index -> input activations of that layer in model B.
        finetune_layer_which (list): Indices of the layers to align.
        num_heads (int): Number of attention heads per layer.
        alpha (float): Forwarded to the head-permutation cost computation.

    Returns:
        dict: Configuration name -> aligned parameter tree for model B.
    """
    # Each entry: (name, data_independent, init_method_vo, permu_heads, optimize_vo)
    configurations = [
        ("data_indep_permu_head_init_ortho_no_opt", True, 'ortho', True, False),
        # ("data_dep_permu_head_init_ortho_no_opt", False, 'ortho', True, False),
        # ("data_dep_permu_head_init_ortho_opt", False, 'ortho', True, True),
    ]

    params_dict = {}
    for name, data_independent, init_method_vo, permu_heads, optimize_vo in configurations:
        print(f"--- Running RoPE Alignment Configuration: {name} ---")
        working_params = copy.deepcopy(params_b)

        for layer_idx in finetune_layer_which:
            print(f"Aligning Layer {layer_idx}...")
            layer_result = align_attention_params_main_rope(
                params_a=params_a,
                params_b=working_params,
                layer_idx=layer_idx,
                num_heads=num_heads,
                activations_for_layer_a=activations_a[layer_idx],
                activations_for_layer_b=activations_b[layer_idx],
                init_method_vo=init_method_vo,
                permu_heads=permu_heads,
                optimize_vo=optimize_vo,
                alpha=alpha,
                data_independent=data_independent,
            )

            # Swap the MHA block in place at the path reported for model B.
            mutable_tree = unfreeze(working_params)
            block_path = layer_result['mha_path']
            parent = mutable_tree
            for step in block_path[:-1]:
                parent = parent[step]
            parent[block_path[-1]] = layer_result['aligned_params']
            working_params = freeze(mutable_tree)

        # Cheap fingerprint of the result, for eyeballing runs.
        total_sum = tree_util.tree_reduce(lambda running, leaf: running + jnp.sum(leaf), working_params, initializer=0)
        print(f"Finished configuration '{name}'. Total parameter sum: {total_sum:.4f}\n")

        params_dict[name] = working_params

    return params_dict


# matching_attn_all_heads_permu (below) aligns model B to model A under a set of
# candidate head permutations and records the assignment cost of each one.
# It handles a single layer only, i.e. len(finetune_layer_which) == 1.
import itertools

# Key names used by matching_attn_all_heads_permu to locate the attention block:
# params[f'{layer_key_prefix}_{layer_idx}'][attention_key].
layer_key_prefix = 'TransformerEncoderLayer'
attention_key = 'MultiHeadDotProductAttention_0'

def _choose_head_permutations(required, num_heads, max_permutations):
    """Pick the head permutations to evaluate: every one in ``required`` plus extras.

    - If ``num_heads! <= max_permutations``, all permutations are returned. The
      search is then exhaustive and terminates even for ``num_heads <= 3``.
    - Otherwise, distinct random permutations drawn from NumPy's global RNG are
      added until ``max_permutations`` are collected, or ``required`` alone if
      it is already larger.

    Returns:
        set of tuples of ints, each a permutation of ``range(num_heads)``.
    """
    import math

    chosen = {tuple(int(k) for k in perm) for perm in required}
    if math.factorial(num_heads) <= max_permutations:
        chosen.update(itertools.permutations(range(num_heads)))
        return chosen
    while len(chosen) < max_permutations:
        chosen.add(tuple(int(k) for k in np.random.permutation(num_heads)))
    return chosen


def matching_attn_all_heads_permu(rng, params_a, params_b, finetune_layer_which, num_heads, plot_path, rope_use=False, activations_a=None, activations_b=None, max_permutations=24):
    """
    Align a single layer under many head permutations and record the
    optimal permutation for several data-dependent and independent methods.

    The permutations evaluated are:
      - every LAP-optimal permutation found by the cost methods below;
      - padded with random permutations up to ``max_permutations`` in total;
      - or all ``num_heads!`` permutations when that is not more than
        ``max_permutations``.

    Args:
        rng: JAX random key (forwarded to the non-RoPE alignment).
        params_a: Parameters of the first model.
        params_b: Parameters of the second model.
        finetune_layer_which: A list containing the index of the layer to finetune (must have length 1).
        num_heads: The number of attention heads.
        plot_path: Kept for interface consistency. Not used here; the non-RoPE
            alignment is called with plot_path=None.
        rope_use: Boolean indicating if RoPE is used in the model.
        activations_a: A dictionary of activations from model A, keyed by layer index.
        activations_b: A dictionary of activations from model B, keyed by layer index.
        max_permutations: Upper bound on the number of permutations evaluated
            (unless more LAP solutions than that are required).

    Returns:
        A tuple containing:
        - params_dict: setting name -> str(permutation) -> aligned result.
        - heads_objective_values: method name -> setting name -> str(permutation)
          -> sum_i C[i, perm[i]] under that method's cost matrix.
        - heads_permutation_sol: method name -> (row_ind, col_ind) lists from the LAP.

    NOTE(review): The parameter layout is assumed to be
    ``params[f'{layer_key_prefix}_{idx}'][attention_key]``, because model B's
    heads are permuted by direct indexing below. Separately,
    ``extract_attention_params`` is called here as (params, layer_key,
    attention_key) with a single return value, while
    ``align_attention_params_main_rope`` calls it as (params, layer_idx) and
    unpacks (params, path). Confirm which signature is current.
    """
    assert len(finetune_layer_which) == 1, "This function only supports one layer at a time."

    params_dict = defaultdict(lambda: defaultdict(dict))
    heads_objective_values = defaultdict(lambda: defaultdict(dict))
    heads_permutation_sol = {}
    C_dict = {}

    layer_idx = finetune_layer_which[0]
    layer_key = f'{layer_key_prefix}_{layer_idx}'

    # --- Extract and reshape weights ---
    params_a_extracted = extract_attention_params(params_a, layer_key, attention_key)
    params_b_extracted = extract_attention_params(params_b, layer_key, attention_key)
    params_a_np = {k: np.array(v) for k, v in zip(['key', 'key_bias', 'query', 'query_bias', 'value', 'value_bias', 'out', 'out_bias'], params_a_extracted)}
    params_b_np = {k: np.array(v) for k, v in zip(['key', 'key_bias', 'query', 'query_bias', 'value', 'value_bias', 'out', 'out_bias'], params_b_extracted)}
    W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a = reshape_to_per_head(params_a_np, num_heads)
    W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b = reshape_to_per_head(params_b_np, num_heads)

    # --- Define permutation settings: (name, data_independent, activation_b_use) ---
    head_permu_settings = [
        ("data-independent", True, None),
        #("data-dependent_acti-b-use", False, True),
        #("data-dependent_acti-b-notuse", False, False),
    ]

    alpha = 0.5  # Using a fixed alpha as in the original script

    # --- Calculate cost matrices and find optimal permutations for each method ---
    print("Calculating cost matrices and optimal permutations for each method...")
    for name, data_independent, activation_b_use in head_permu_settings:
        if data_independent:
            C = compute_cost_matrix_data_independent(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a,
                                                     W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b,
                                                     num_heads, alpha=alpha)
        else:
            assert activations_a is not None, "Activations for model A must be provided for data-dependent methods."
            # Use model B's activations if requested and available, otherwise model A's for both.
            current_activations_b = activations_b[layer_idx] if activation_b_use and activations_b else activations_a[layer_idx]

            if rope_use:
                C = compute_cost_matrix_presoftmax_rope(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a,
                                                        W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b,
                                                        num_heads, activations_a[layer_idx], current_activations_b, alpha=alpha)
            else:
                C = compute_cost_matrix_presoftmax(W_Q_a, b_Q_a, W_K_a, b_K_a, W_V_a, b_V_a, W_O_a,
                                                   W_Q_b, b_Q_b, W_K_b, b_K_b, W_V_b, b_V_b, W_O_b,
                                                   num_heads, activations_a[layer_idx], current_activations_b, alpha=alpha)

        C_dict[name] = C
        row_ind, col_ind = linear_sum_assignment(C)
        # Order by row_ind so entry i is the B head matched to A's head i.
        permutation_solution = [int(col_ind[i]) for i in np.argsort(row_ind)]
        heads_permutation_sol[name] = ([int(_) for _ in row_ind], [int(_) for _ in col_ind])
        print(f"  - Method '{name}': Optimal permutation is {permutation_solution}")

    # --- Choose which head permutations to evaluate ---
    # Never materialise all num_heads! permutations (479M tuples for 12 heads).
    # Enumerate them only when there are at most max_permutations, which also
    # avoids an endless sampling loop when num_heads! < max_permutations.
    required = [heads_permutation_sol[name][1] for name, _, _ in head_permu_settings]
    # sorted() gives a reproducible evaluation order.
    permutations_to_evaluate = sorted(_choose_head_permutations(required, num_heads, max_permutations))
    print(f"\nEvaluating {len(permutations_to_evaluate)} head permutations...")

    # Define alignment settings: (name, init_method, optimize)
    alignment_settings = [
        ("init_ortho_no_opt", 'ortho', False),
        #("init_ortho_opt", 'ortho', True),
    ]

    for setting_name, init_method, optimize in alignment_settings:
        for perm_tuple in permutations_to_evaluate:
            perm = list(perm_tuple)
            # Fresh copy of model B for this permutation.
            params_b_permuted = copy.deepcopy(unfreeze(params_b))

            # New head i of B is B's old head perm[i].
            params_b_layer = params_b_permuted[layer_key][attention_key]
            perm_array = np.array(perm)

            params_b_layer['query']['kernel'] = params_b_layer['query']['kernel'][:, perm_array, :]
            params_b_layer['query']['bias'] = params_b_layer['query']['bias'][perm_array, :]
            params_b_layer['key']['kernel'] = params_b_layer['key']['kernel'][:, perm_array, :]
            params_b_layer['key']['bias'] = params_b_layer['key']['bias'][perm_array, :]
            params_b_layer['value']['kernel'] = params_b_layer['value']['kernel'][:, perm_array, :]
            params_b_layer['value']['bias'] = params_b_layer['value']['bias'][perm_array, :]
            params_b_layer['out']['kernel'] = params_b_layer['out']['kernel'][perm_array, :, :]

            # Align the (already permuted) model B to model A without further head permutation.
            if rope_use:
                result_to_store = align_attention_params_main_rope(
                    params_a, params_b_permuted, layer_idx, num_heads,
                    activations_a[layer_idx],
                    activations_b[layer_idx],
                    init_method_vo=init_method, permu_heads=False, optimize_vo=optimize, alpha=alpha
                )
            else:
                result_to_store = align_attention_params_main(
                    rng, params_a, params_b_permuted, layer_idx, num_heads,
                    activations_a[layer_idx],
                    activations_b[layer_idx],
                    plot_path=None, init_method=init_method,
                    permute_heads=False, optimize=optimize, alpha=alpha
                )

            perm_str = str([int(_) for _ in perm])
            params_dict[setting_name][perm_str] = result_to_store

            # Cost of this permutation under each method: sum_i C[i, perm[i]].
            for method_name, C_matrix in C_dict.items():
                total_cost = sum(C_matrix[i, perm[i]] for i in range(num_heads))
                heads_objective_values[method_name][setting_name][perm_str] = float(total_cost)

    print("Finished evaluating all permutations.")
    return params_dict, heads_objective_values, heads_permutation_sol

##########Transformers matching
import copy
import jax.numpy as jnp
import numpy as np
from flax.core import unfreeze, freeze
from scipy.optimize import linear_sum_assignment

# --- Helper functions ---

def get_nested_item(d, keys):
    """Follow ``keys`` into the nested container ``d`` and return what is found.

    Each key is applied in turn as a subscript, so ``get_nested_item(d, ('a', 'b'))``
    is ``d['a']['b']``. An empty key sequence returns ``d`` itself. A missing
    key propagates the subscript's own exception (``KeyError`` for dicts).
    """
    node = d
    for step in keys:
        node = node[step]
    return node

def set_nested_item(d, keys, value):
    """Assign ``value`` at the nested location named by ``keys`` inside ``d`` (in place).

    All keys but the last must already exist. The last key is created or
    overwritten. An empty ``keys`` raises ``IndexError``.
    """
    leaf_key = keys[-1]
    container = d
    for parent_key in keys[:-1]:
        container = container[parent_key]
    container[leaf_key] = value

def extract_ffn_params(params, layer_idx):
    """
    Locate a Transformer layer's FFN (Dense_0 -> Dense_1) and return its weights and paths.

    Supported layouts, checked in this order:
      - cifar_vit: ``params['TransformerEncoderLayer_<i>']['Dense_0' | 'Dense_1']``;
      - vit-jax: ``params['Transformer']['encoderblock_<i>']``. Descends further
        into the first child whose name contains 'MlpBlock' when one exists.

    Args:
        params: The Flax parameter tree.
        layer_idx: The integer index of the transformer layer.

    Returns:
        ``(W1, b1, W2, dense0_path, dense1_path)``:
          - W1, b1: Dense_0's kernel and bias.
          - W2: Dense_1's kernel.
          - dense0_path, dense1_path: tuples of keys leading to each Dense block.

    Raises:
        KeyError: if neither layout matches ``layer_idx``.
    """
    cifar_key = f'TransformerEncoderLayer_{layer_idx}'
    vit_block_key = f'encoderblock_{layer_idx}'

    if cifar_key in params:
        base = (cifar_key,)
    elif 'Transformer' in params and vit_block_key in params['Transformer']:
        block = get_nested_item(params, ('Transformer', vit_block_key))
        mlp_key = next((child for child in block if 'MlpBlock' in child), None)
        base = ('Transformer', vit_block_key) + ((mlp_key,) if mlp_key else ())
    else:
        raise KeyError(f"Could not find a known path for FFN layer {layer_idx} in the provided params.")

    dense0_path = base + ('Dense_0',)
    dense1_path = base + ('Dense_1',)
    first_dense = get_nested_item(params, dense0_path)
    second_dense = get_nested_item(params, dense1_path)

    return first_dense['kernel'], first_dense['bias'], second_dense['kernel'], dense0_path, dense1_path

# --- Core Functions ---

def matching_transformer_ffn(params_a, params_b, finetune_layer_which):
    """
    Align the FFN hidden neurons of model B to model A for the specified layers.

    For each layer, a Linear Assignment Problem (LAP) is solved with cost

        C[i, j] = ||[W1_a[:, i], b1_a[i]] - [W1_b[:, j], b1_b[j]]||^2
                + ||W2_a[i, :] - W2_b[j, :]||^2

    B's hidden neurons are then reordered accordingly: columns of Dense_0's
    kernel, entries of its bias, and rows of Dense_1's kernel. Dense_1's bias is
    untouched because it does not depend on the hidden order.

    Args:
        params_a (dict): Parameters of the reference model A.
        params_b (dict): Parameters of model B. Not modified; a deep copy is aligned.
        finetune_layer_which (list): List of layer indices to align.

    Returns:
        The aligned parameters for model B. They are frozen once any layer has
        been processed.
    """
    from scipy.spatial.distance import cdist

    aligned_params_b = copy.deepcopy(params_b)

    for layer_idx in finetune_layer_which:
        W1_a, b1_a, W2_a, _, _ = extract_ffn_params(params_a, layer_idx)
        W1_b, b1_b, W2_b, dense0_path_b, dense1_path_b = extract_ffn_params(aligned_params_b, layer_idx)

        W1_a, b1_a, W2_a = np.array(W1_a), np.array(b1_a), np.array(W2_a)
        W1_b, b1_b, W2_b = np.array(W1_b), np.array(b1_b), np.array(W2_b)

        D_hidden = W1_a.shape[1]
        assert W1_b.shape[1] == D_hidden, f"FFN hidden dimensions for layer {layer_idx} must match."

        # Row i: incoming weights of hidden neuron i followed by its bias.
        in_a = np.concatenate([W1_a.T, b1_a[:, None]], axis=1)
        in_b = np.concatenate([W1_b.T, b1_b[:, None]], axis=1)
        # Vectorised squared-distance cost (float64). This replaces a
        # D_hidden^2 Python loop, which is millions of iterations for ViT-sized
        # FFNs.
        C = cdist(in_a, in_b, 'sqeuclidean') + cdist(W2_a, W2_b, 'sqeuclidean')

        # col_ind[i] is the B neuron assigned to A's neuron i.
        _, col_ind = linear_sum_assignment(C)

        W1_aligned = W1_b[:, col_ind]
        b1_aligned = b1_b[col_ind]
        W2_aligned = W2_b[col_ind, :]

        unfrozen_params = unfreeze(aligned_params_b)
        set_nested_item(unfrozen_params, dense0_path_b + ('kernel',), jnp.array(W1_aligned))
        set_nested_item(unfrozen_params, dense0_path_b + ('bias',), jnp.array(b1_aligned))
        set_nested_item(unfrozen_params, dense1_path_b + ('kernel',), jnp.array(W2_aligned))
        aligned_params_b = freeze(unfrozen_params)

    return aligned_params_b

def matching_transformer_block(rng, params_a, params_b, activations_a, activations_b, finetune_layer_which, num_heads, rope_use=False, plot_path=None):
    """
    Align whole Transformer blocks of model B to model A, attention first, then FFN.

    Stage 1 delegates multi-head-attention alignment to `matching_attn_rope` when
    `rope_use` is true, or to `matching_attn` otherwise. The MHA step returns a
    mapping from configuration name to an MHA-aligned copy of model B's parameters.
    Stage 2 passes each of those copies through `matching_transformer_ffn`, which
    aligns the feed-forward sub-layers against model A.

    Args:
        rng: JAX random key. It is forwarded only to `matching_attn` (the non-RoPE path).
        params_a (dict): Parameters of the reference model A.
        params_b (dict): Parameters of the model B that is being aligned.
        activations_a (dict): Activations recorded from model A, keyed by layer index.
        activations_b (dict): Activations recorded from model B, keyed by layer index.
        finetune_layer_which (list): Indices of the layers to align.
        num_heads (int): Number of attention heads.
        rope_use (bool): Selects the RoPE-specific MHA alignment when True.
        plot_path (str, optional): Location for diagnostic plots. It is forwarded
            only to `matching_attn` (the non-RoPE path).

    Returns:
        dict: Maps each configuration name produced by the MHA step (e.g.
        'data_indep_...') to the parameters after both MHA and FFN alignment.
    """
    print("--- Starting Transformer Block Alignment ---")

    # Stage 1: attention alignment, yielding one parameter set per configuration.
    print("\nStep 1: Aligning Multi-Head Attention components...")
    if not rope_use:
        mha_results = matching_attn(rng, params_a, params_b, activations_a, activations_b, finetune_layer_which, num_heads, plot_path)
    else:
        # The RoPE variant takes neither an RNG key nor a plot path.
        mha_results = matching_attn_rope(params_a, params_b, activations_a, activations_b, finetune_layer_which, num_heads)
    print("MHA alignment complete.")

    # Stage 2: FFN alignment applied on top of every MHA-aligned configuration.
    print("\nStep 2: Aligning Feed-Forward Network components for each configuration...")

    def _align_ffn_for(config, mha_params):
        # Log the configuration, then return its fully (MHA + FFN) aligned parameters.
        print(f"  - Aligning FFN for configuration: '{config}'")
        return matching_transformer_ffn(params_a, mha_params, finetune_layer_which)

    # Items are visited in insertion order, so the progress messages keep their order.
    block_aligned = {
        config: _align_ffn_for(config, mha_params)
        for config, mha_params in mha_results.items()
    }

    print("FFN alignment complete.")
    print("\n--- Transformer Block Alignment Finished ---")

    return block_aligned