import os
import copy
import time
import random
import itertools
import numpy as np
import jax.lax as lax
import jax.numpy as jnp
from utils import rngmix
from typing import NamedTuple
from collections import defaultdict
from flax.core import freeze, unfreeze
from scipy.optimize import linear_sum_assignment, minimize, minimize_scalar
from jax import random, tree_util, jit, grad, value_and_grad
from scipy.linalg import block_diag
from math import sqrt, cos, sin, atan2
import numpy as np
import matplotlib.pyplot as plt
def l2_dist(A, B):
    """Return the sum of squared element-wise differences between A and B."""
    residual = A - B
    return np.sum(np.square(residual))
def l1_dist(A, B):
    """Return the sum of absolute element-wise differences between A and B."""
    residual = A - B
    return np.sum(np.absolute(residual))
def cosine_dist(A, B, eps=1e-8):
    """Return one minus the cosine similarity of A and B, both flattened.

    `eps` is added to the product of the two norms so an all-zero input
    does not divide by zero.
    """
    flat_a = A.reshape(-1)
    flat_b = B.reshape(-1)
    norm_product = np.linalg.norm(flat_a) * np.linalg.norm(flat_b)
    similarity = np.dot(flat_a, flat_b) / (norm_product + eps)
    return 1.0 - similarity
def corr_dist(A, B, eps=1e-8):
    """Return one minus the correlation of A and B.

    Each input has its own overall mean subtracted and is then flattened;
    the result is one minus the cosine similarity of those centered vectors.
    `eps` is added to the product of the norms to avoid division by zero.
    """
    centered_a = (A - A.mean()).reshape(-1)
    centered_b = (B - B.mean()).reshape(-1)
    norm_product = np.linalg.norm(centered_a) * np.linalg.norm(centered_b)
    correlation = np.dot(centered_a, centered_b) / (norm_product + eps)
    return 1.0 - correlation
def spectral_dist(A, B):
    """Return the absolute difference between the order-2 norms of A and B.

    For 2-D inputs `np.linalg.norm(..., 2)` is the largest singular value.
    """
    norm_a = np.linalg.norm(A, ord=2)
    norm_b = np.linalg.norm(B, ord=2)
    return abs(norm_a - norm_b)
def compute_objective(A, X, X_prime, Y, Y_prime):
    """Return ||X - X_prime @ A.T||^2 + ||Y - Y_prime @ inv(A)||^2.

    Both norms are sums of squared entries. `A` must be invertible;
    `np.linalg.inv` raises `LinAlgError` for a singular matrix.
    """
    inverse = np.linalg.inv(A)
    forward_residual = X - X_prime @ A.T
    inverse_residual = Y - Y_prime @ inverse
    return np.sum(np.square(forward_residual)) + np.sum(np.square(inverse_residual))

def compute_gradient(A, X, X_prime, Y, Y_prime):
    """Return the gradient with respect to A of the alignment objective.

    The objective is ||X - X_prime @ A.T||^2 + ||Y - Y_prime @ inv(A)||^2.
    The result has the same shape as `A`.
    """
    A_inv = np.linalg.inv(A)
    A_inv_T = A_inv.T

    # Derivative of the first (direct) fitting term.
    grad_forward = 2 * A @ X_prime.T @ X_prime - 2 * X.T @ X_prime

    # Derivative of the second term, which depends on A through inv(A).
    residual_inverse = Y - Y_prime @ A_inv
    grad_inverse = 2 * A_inv_T @ Y_prime.T @ residual_inverse @ A_inv_T

    return grad_forward + grad_inverse

def line_search(A, grad, X, X_prime, Y, Y_prime, max_step=1, tau=0.5, c1=1e-4):
    """Backtracking search for a step size along the direction -grad.

    Starting at `max_step`, the step is multiplied by `tau` until the
    candidate `A - step * grad` has full rank and satisfies the
    sufficient-decrease test
    ``f(candidate) <= f(A) - c1 * step * sum(grad**2)``.

    Returns the accepted step, or 0 when the step has shrunk to 1e-10 or
    below without being accepted.
    """
    baseline = compute_objective(A, X, X_prime, Y, Y_prime)
    squared_grad_norm = np.sum(grad**2)
    full_rank = A.shape[0]

    step = max_step
    while step > 1e-10:
        candidate = A - step * grad
        # The objective inverts its argument, so rank-deficient candidates
        # are skipped without being evaluated.
        if np.linalg.matrix_rank(candidate) >= full_rank:
            trial = compute_objective(candidate, X, X_prime, Y, Y_prime)
            if trial <= baseline - c1 * step * squared_grad_norm:
                return step
        step *= tau
    return 0
@jit
def compute_objective_jax(A, X, X_prime, Y, Y_prime, cond_threshold=1e6):
    """JAX version of the alignment objective, guarded by a conditioning check.

    Returns ||X - X_prime @ A.T||^2 + ||Y - Y_prime @ inv(A)||^2 (sums of
    squared entries), or +inf when the condition number of `A` exceeds
    `cond_threshold`.
    """
    def _objective():
        inverse = jnp.linalg.inv(A)
        forward_residual = X - X_prime @ A.T
        inverse_residual = Y - Y_prime @ inverse
        return jnp.sum(forward_residual**2) + jnp.sum(inverse_residual**2)

    def _rejected():
        return jnp.inf

    ill_conditioned = jnp.linalg.cond(A) > cond_threshold
    # lax.cond selects the branch at run time inside the jitted function.
    return lax.cond(ill_conditioned, _rejected, _objective)
# Jitted callable returning the pair (objective value, gradient) of
# compute_objective_jax; value_and_grad differentiates with respect to the
# first positional argument (A) by default.
compute_value_and_grad_jax = jit(value_and_grad(compute_objective_jax))
def solve_orthogonal(X, X_prime, Y, Y_prime):
    B = X.T @ X_prime + Y.T @ Y_prime
    U, _, Vt = np.linalg.svd(B)
    return U @ Vt

def solve_rope_qk_alignment(W_Q_a_i, W_K_a_i, W_Q_b_i, W_K_b_i):
    """
    Solve U for one head, using the SAME RoPE pairing as rotate_half():
      planes are (j, j + D_k/2), not (2j, 2j+1).

    U is the identity except for one 2x2 block [[a, -b], [b, a]] per plane
    (a scale times a rotation). Each block is fitted so that, restricted to
    that plane, W_Q_b_i @ U.T approximates W_Q_a_i and W_K_b_i @ inv(U)
    approximates W_K_a_i in the sum-of-squares sense.

    Inputs are copied and never modified; the head dimension is read from
    axis 1 and must be even. Returns a (D_k, D_k) float64 array.
    """
    q_ref = np.array(W_Q_a_i, copy=True)
    k_ref = np.array(W_K_a_i, copy=True)
    q_src = np.array(W_Q_b_i, copy=True)
    k_src = np.array(W_K_b_i, copy=True)

    head_dim = q_ref.shape[1]
    assert head_dim % 2 == 0, "Head dimension must be even for RoPE."
    half = head_dim // 2

    # Quarter-turn generator; kept as NumPy because SciPy and the math
    # module are used downstream.
    quarter_turn = np.array([[0.0, -1.0], [1.0, 0.0]])

    def complex_overlap(ref_cols, src_cols):
        # Packs trace(C) and trace(C @ J) of the 2x2 cross-product C into a
        # single complex number.
        cross = ref_cols.T @ src_cols
        return 0.5 * (np.trace(cross) + 1j * np.trace(cross @ quarter_turn))

    U_opt = np.eye(head_dim, dtype=np.float64)

    for j in range(half):
        plane = np.array([j, j + half])  # RoPE plane for this frequency

        q_a, q_b = q_ref[:, plane], q_src[:, plane]
        k_a, k_b = k_ref[:, plane], k_src[:, plane]

        norm_q = float(np.sum(q_b**2))
        norm_k = float(np.sum(k_b**2))
        c_q = complex_overlap(q_a, q_b)
        c_k = complex_overlap(k_a, k_b)

        coef_x = float(np.abs(c_q)**2)
        coef_inv_x = float(np.abs(c_k)**2)
        coef_const = float(2 * np.real(c_q * np.conj(c_k)))

        def radial_cost(x):
            # Cost as a function of x = scale**2, with the best angle
            # already substituted in.
            x = float(x)
            radicand = max(coef_x * x + (coef_inv_x / x) + coef_const, 1e-20)
            return x * norm_q + norm_k / x - 4.0 * sqrt(radicand)

        solution = minimize_scalar(radial_cost, bounds=(1e-8, 1e8), method="bounded")
        scale = sqrt(float(solution.x))

        combined = scale * c_q + (1.0 / scale) * c_k
        # The angle is undefined when the combined overlap vanishes; use 0.
        angle = 0.0 if abs(combined) < 1e-30 else -atan2(combined.imag, combined.real)

        a = scale * cos(angle)
        b = scale * sin(angle)

        # Write the 2x2 block at rows/columns (j, j + half).
        U_opt[np.ix_(plane, plane)] = np.array([[a, -b], [b, a]], dtype=np.float64)

    # Optional conditioning safeguard.
    if np.linalg.cond(U_opt) > 1e12:
        U_opt = U_opt + 1e-6 * np.eye(head_dim)

    return U_opt

def optimize_alignment(A_init, X, X_prime, Y, Y_prime, max_iter=5000):
    """Refine an alignment matrix with L-BFGS-B on the JAX objective.

    Minimizes `compute_objective_jax` over the entries of a matrix shaped
    like `A_init`, starting from `A_init`, using values and gradients from
    `compute_value_and_grad_jax`.

    Args:
        A_init: starting matrix (anything with `.shape` and `.flatten()`).
        X, X_prime, Y, Y_prime: data matrices passed through to the objective.
        max_iter: `maxiter` option given to SciPy.

    Returns:
        (A_opt, objective_values, grad_norms, condition_nums): the optimized
        matrix reshaped to `A_init.shape`, followed by three lists with one
        entry per callback invocation (objective value, Frobenius norm of
        the gradient, condition number of the iterate).
    """
    shape = A_init.shape

    # The data matrices never change during the optimization, so convert
    # them to JAX arrays once instead of on every function evaluation.
    data = tuple(jnp.array(M) for M in (X, X_prime, Y, Y_prime))

    objective_values = []
    grad_norms = []
    condition_nums = []

    def obj_fn(flat_A):
        A = jnp.array(flat_A.reshape(shape))
        obj, grad_val = compute_value_and_grad_jax(A, *data)
        # L-BFGS-B works in float64, while JAX returns float32 unless x64 is
        # enabled; hand SciPy an explicit writable float64 copy.
        return float(obj), np.array(grad_val, dtype=np.float64).ravel()

    def callback(flat_A):
        # Records diagnostics only; it does not influence the optimizer.
        A = jnp.array(flat_A.reshape(shape))
        obj, grad_val = compute_value_and_grad_jax(A, *data)
        objective_values.append(float(obj))
        grad_norms.append(float(jnp.linalg.norm(grad_val, 'fro')))
        condition_nums.append(float(jnp.linalg.cond(A)))

    res = minimize(
        obj_fn,
        A_init.flatten(),
        jac=True,  # obj_fn returns (value, gradient)
        method='L-BFGS-B',
        options={'maxiter': max_iter},
        callback=callback,
    )
    A_opt = res.x.reshape(shape)
    return A_opt, objective_values, grad_norms, condition_nums


def extract_attention_params(attn):
    """Return the q/k/v/o projection kernels of an attention block as NumPy arrays.

    `attn` is indexed as ``attn[<name>]['kernel']`` for the names 'q_proj',
    'k_proj', 'v_proj' and 'o_proj'; the result is the tuple
    (query, key, value, out) in that order.
    """
    projection_names = ('q_proj', 'k_proj', 'v_proj', 'o_proj')
    query, key, value, out = (
        np.array(attn[name]['kernel']) for name in projection_names
    )
    return query, key, value, out

def extract_ffn_params(ffn):
    """Return the (up_proj, down_proj) kernels of an MLP block, unconverted.

    `ffn` is indexed as ``ffn['up_proj']['kernel']`` and
    ``ffn['down_proj']['kernel']``; the stored objects are returned as-is.
    """
    up_kernel, down_kernel = (
        ffn[name]['kernel'] for name in ('up_proj', 'down_proj')
    )
    return up_kernel, down_kernel
def reshape_attention_weights(query, key, value, out, num_heads):
    """Split fused attention kernels into per-head stacks.

    `query`, `key` and `value` are split into `num_heads` equal column
    blocks; `out` is split into `num_heads` equal row blocks. The per-head
    width is taken from the axis that is actually sliced
    (``query.shape[1] // num_heads``), so the input dimension of the
    projections does not have to equal ``num_heads * head_dim``.

    Args:
        query, key, value: 2-D kernels whose columns are grouped by head.
        out: 2-D kernel whose rows are grouped by head.
        num_heads: number of heads to split into.

    Returns:
        (W_Q, W_K, W_V, W_O), each stacked along a new leading head axis.

    Raises:
        ValueError: if ``query.shape[1]`` is not divisible by `num_heads`.
    """
    fused_width = query.shape[1]
    if fused_width % num_heads != 0:
        raise ValueError(
            f"Fused head dimension {fused_width} is not divisible by "
            f"num_heads={num_heads}"
        )
    head_dim = fused_width // num_heads
    bounds = [(i * head_dim, (i + 1) * head_dim) for i in range(num_heads)]

    def split_columns(kernel):
        return np.stack([kernel[:, lo:hi] for lo, hi in bounds])

    def split_rows(kernel):
        return np.stack([kernel[lo:hi, :] for lo, hi in bounds])

    W_Q = split_columns(query)
    W_K = split_columns(key)
    W_V = split_columns(value)
    W_O = split_rows(out)
    return W_Q, W_K, W_V, W_O
def compute_extended_weights(W, b):
    """Return W with b appended as one extra bottom row."""
    bias_row = b.reshape(1, -1)
    return np.vstack((W, bias_row))
def compute_cost_matrix(W_Q_a, W_K_a, W_V_a, W_O_a,
                        W_Q_b, W_K_b, W_V_b, W_O_b,
                        h, alpha=0.5,dist="L2"):
    """Build the h x h head-matching cost matrix between models A and B.

    For every head the two products ``W_Q[i] @ W_K[i].T`` (with the mean of
    each row subtracted) and ``W_V[i] @ W_O[i]`` are formed. Entry
    ``C[i, j]`` is the chosen distance between the QK products of head i of
    A and head j of B, plus the same distance between their VO products.

    Args:
        W_Q_a ... W_O_b: per-head weights, indexable by head.
        h: number of heads.
        alpha: accepted for interface compatibility; the cost does not
            currently use it.
        dist: one of "L2", "L1", "cosine", "corr", "spectral".

    Returns:
        (h, h) NumPy array of costs.

    Raises:
        ValueError: if `dist` is not one of the supported names (checked
            before any product is computed).
    """
    metrics = {
        "L2": l2_dist,
        "L1": l1_dist,
        "cosine": cosine_dist,
        "corr": corr_dist,
        "spectral": spectral_dist,
    }
    if dist not in metrics:
        raise ValueError(f"Unknown dist: {dist}")
    metric = metrics[dist]

    def head_products(W_Q, W_K, W_V, W_O, idx):
        # The weights are only read, so no defensive copies are needed.
        qk = W_Q[idx] @ W_K[idx].T
        vo = W_V[idx] @ W_O[idx]
        return qk - np.mean(qk, axis=1, keepdims=True), vo

    C = np.zeros((h, h))
    for i in range(h):
        qk_a, vo_a = head_products(W_Q_a, W_K_a, W_V_a, W_O_a, i)
        for j in range(h):
            # Model-B products are rebuilt per pair rather than cached:
            # caching would keep h full-size product matrices alive at once.
            qk_b, vo_b = head_products(W_Q_b, W_K_b, W_V_b, W_O_b, j)
            C[i, j] = metric(qk_a, qk_b) + metric(vo_a, vo_b)
    return C

def multiplytive_align_single_head(W_Q_a_i, W_K_a_i, W_V_a_i, W_O_a_i,
                                   W_Q_b_i, W_K_b_i, W_V_b_i, W_O_b_i, optimize):
    """Align one head of model B to the matching head of model A.

    Query/key: a map U from `solve_rope_qk_alignment` is applied as
    ``W_Q_b_i @ U.T`` and ``W_K_b_i @ inv(U)``.
    Value/out: a map B, initialised by `solve_orthogonal` and refined by
    `optimize_alignment`, is applied as ``W_V_b_i @ inv(B)`` and
    ``B @ W_O_b_i``.

    NOTE: `optimize` is accepted but not read here; the refinement step
    always runs.

    Returns a dict ``{'aligned_params': {...}}`` holding the aligned
    'query', 'key', 'value' and 'out' kernels.
    """
    # Query/key pair.
    qk_map = solve_rope_qk_alignment(W_Q_a_i, W_K_a_i, W_Q_b_i, W_K_b_i)
    query_aligned = W_Q_b_i @ qk_map.T
    key_aligned = W_K_b_i @ np.linalg.inv(qk_map)

    # Value/out pair; the out kernels are transposed before fitting.
    value_ref = copy.deepcopy(W_V_a_i)
    out_ref_t = W_O_a_i.T
    value_src = copy.deepcopy(W_V_b_i)
    out_src_t = W_O_b_i.T
    vo_start = solve_orthogonal(out_ref_t, out_src_t, value_ref, value_src)
    vo_map, *_diagnostics = optimize_alignment(
        vo_start, out_ref_t, out_src_t, value_ref, value_src
    )
    value_aligned = W_V_b_i @ np.linalg.inv(vo_map)
    out_aligned = vo_map @ W_O_b_i

    return {
        'aligned_params': {
            'query': {'kernel': query_aligned},
            'key': {'kernel': key_aligned},
            'value': {'kernel': value_aligned},
            'out': {'kernel': out_aligned},
        }
    }

def merge_aligned_params(aligned_params, h, D):
    """Fuse per-head aligned kernels back into q/k/v/o projection kernels.

    `aligned_params` maps 'head_0' ... 'head_{h-1}' to dicts holding
    'query', 'key', 'value' and 'out' kernels. Query/key/value kernels are
    stacked along axis 1 and reshaped to (-1, D); out kernels are stacked
    along axis 0 and reshaped to (D, -1). The results are returned as JAX
    arrays under the keys 'q_proj', 'k_proj', 'v_proj' and 'o_proj'.
    """
    per_head = [aligned_params[f'head_{i}'] for i in range(h)]

    def fuse(name, stack_axis, target_shape):
        stacked = np.stack(
            [head[name]['kernel'] for head in per_head], axis=stack_axis
        )
        return jnp.array(stacked.reshape(target_shape))

    return {
        'q_proj': {'kernel': fuse('query', 1, (-1, D))},
        'k_proj': {'kernel': fuse('key', 1, (-1, D))},
        'v_proj': {'kernel': fuse('value', 1, (-1, D))},
        'o_proj': {'kernel': fuse('out', 0, (D, -1))},
    }
def align_attention_params(rng, params_a, params_b, layer_idx, config, permute_heads=True, optimize=False, dist = 'L2', alpha=0.5):
    """Align the self-attention weights of one layer of model B to model A.

    The q/k/v/o kernels at ``params['model']['layers'][str(layer_idx)]
    ['self_attn']`` are split per head (head count read from
    ``config.lmc_config.num_attention_heads``). With `permute_heads`, the
    heads of B are first reordered by solving a linear assignment on
    `compute_cost_matrix` (which receives `alpha` and `dist`). Each head
    pair is then aligned by `multiplytive_align_single_head` and the
    results are fused by `merge_aligned_params`.

    `rng` is not read by this function.

    Returns ``{'aligned_params': <merged q/k/v/o kernels>}``.
    """
    num_heads = config.lmc_config.num_attention_heads
    layer_key = str(layer_idx)
    attn_a = params_a['model']['layers'][layer_key]['self_attn']
    attn_b = params_b['model']['layers'][layer_key]['self_attn']

    query_a, key_a, value_a, out_a = extract_attention_params(attn_a)
    heads_a = reshape_attention_weights(query_a, key_a, value_a, out_a, num_heads)
    heads_b = reshape_attention_weights(*extract_attention_params(attn_b), num_heads)

    if permute_heads:
        C = compute_cost_matrix(*heads_a, *heads_b, num_heads, alpha, dist)
        _, col_ind = linear_sum_assignment(C)
        print("Best Permutation Heads:", col_ind)
        # Reorder B so that its i-th head is the one matched to head i of A.
        heads_b = tuple([W[j] for j in col_ind] for W in heads_b)

    per_head = {}
    for i in range(num_heads):
        result = multiplytive_align_single_head(
            *(W[i] for W in heads_a),
            *(W[i] for W in heads_b),
            optimize,
        )
        per_head[f'head_{i}'] = result['aligned_params']

    merged = merge_aligned_params(per_head, num_heads, query_a.shape[1])
    return {'aligned_params': merged}
def align_ffn_params(rng, params_a, params_b, layer_idx, config):
    """Permute the hidden neurons of model B's MLP to best match model A.

    The up/down projection kernels are read from
    ``params['model']['layers'][str(layer_idx)]['mlp']``. Hidden neuron i is
    described by column i of the up kernel (incoming weights) and row i of
    the down kernel (outgoing weights). The cost of matching neuron i of A
    with neuron j of B is the squared Euclidean distance between their
    incoming weights plus that between their outgoing weights; a linear
    assignment over this cost gives the permutation applied to B.

    `rng` and `config` are not read by this function.

    Returns:
        ``{'aligned_params': {'up_proj': {'kernel': ...},
        'down_proj': {'kernel': ...}}}`` with B's kernels permuted
        (NumPy arrays).
    """
    layer_key = str(layer_idx)
    ffn_a = params_a['model']['layers'][layer_key]['mlp']
    ffn_b = params_b['model']['layers'][layer_key]['mlp']

    # Convert to NumPy for computation.
    W1_a, W2_a = (np.array(W) for W in extract_ffn_params(ffn_a))
    W1_b, W2_b = (np.array(W) for W in extract_ffn_params(ffn_b))

    # One row per hidden neuron; float64 keeps the expansion below accurate.
    in_a = W1_a.T.astype(np.float64)
    in_b = W1_b.T.astype(np.float64)
    out_a = W2_a.astype(np.float64)
    out_b = W2_b.astype(np.float64)

    # All pairwise squared distances at once, using
    # ||u - v||^2 = ||u||^2 + ||v||^2 - 2 u.v, instead of a Python loop over
    # every (i, j) neuron pair.
    sq_a = np.sum(in_a**2, axis=1) + np.sum(out_a**2, axis=1)
    sq_b = np.sum(in_b**2, axis=1) + np.sum(out_b**2, axis=1)
    cross = in_a @ in_b.T + out_a @ out_b.T
    C = sq_a[:, None] + sq_b[None, :] - 2.0 * cross
    # Rounding in the expansion can leave tiny negative values; clamp them.
    np.maximum(C, 0.0, out=C)

    # Solve the assignment; col_ind[i] is the neuron of B matched to neuron i of A.
    _, col_ind = linear_sum_assignment(C)

    # Permute the weights of model B according to the solution.
    W1_aligned = W1_b[:, col_ind]
    W2_aligned = W2_b[col_ind, :]
    return {
        'aligned_params': {'up_proj': {'kernel':W1_aligned}, 'down_proj': {'kernel':W2_aligned}}
    }

def weight_matching(rng, params_a, params_b, config, args):
    """Align model B to model A layer by layer and return the aligned copies.

    For each named configuration (currently a single one), a deep copy of
    `params_b` is aligned in place: for every index in
    `config.lmc_layer_indices` the self-attention weights are replaced by
    the output of `align_attention_params` (distance taken from
    `args.dist`), and, when `config.finetune_mlp` is set and
    `config.mlp_type == "Normal"`, the MLP weights are replaced by the
    output of `align_ffn_params`. The sum of all parameters is printed as a
    sanity check.

    Returns a dict mapping each configuration name to its aligned
    parameter tree. `params_b` itself is not modified.
    """
    # (name, init_method, permute_heads, optimize); init_method is not read.
    configurations = [
        ("permu_head_init_ortho_opt", 'ortho', True, True),
    ]

    def _add_leaf_sum(acc, leaf):
        return acc + jnp.sum(leaf)

    params_dict = {}
    for name, init_method, permute_heads, optimize in configurations:
        aligned_params = copy.deepcopy(params_b)
        for layer_idx in config.lmc_layer_indices:
            layer_key = str(layer_idx)
            attention_result = align_attention_params(
                rng, params_a, aligned_params, layer_idx, config,
                permute_heads=permute_heads, optimize=optimize, dist=args.dist,
            )
            aligned_params['model']['layers'][layer_key]['self_attn'] = attention_result['aligned_params']

            wants_mlp = getattr(config, "finetune_mlp", False) and config.mlp_type == "Normal"
            if wants_mlp:
                ffn_result = align_ffn_params(rng, params_a, aligned_params, layer_idx, config)
                aligned_params['model']['layers'][layer_key]['mlp'] = ffn_result['aligned_params']

        total_sum = tree_util.tree_reduce(_add_leaf_sum, aligned_params, initializer=0)
        print(f"{name}: {total_sum}, sanity check")
        params_dict[name] = aligned_params
    return params_dict

