from collections import defaultdict
from typing import NamedTuple
from flax.core import freeze, unfreeze
import copy
import jax.numpy as jnp
from jax import random
from scipy.optimize import linear_sum_assignment
import matplotlib.pyplot as plt
from utils import rngmix
import numpy as np

class PermutationSpec(NamedTuple):
    """Two-way index describing which parameter axes each permutation reorders."""

    # perm name -> list of (param name, axis) pairs that this permutation acts on.
    perm_to_axes: dict
    # param name -> tuple holding one perm name (or None) per axis of that parameter.
    axes_to_perm: dict

def get_nested(params, dotted_key):
    """Follow a dot-separated key path through nested mappings.

    ``get_nested(tree, "a.b.c")`` is ``tree["a"]["b"]["c"]``. A missing
    segment raises whatever the container raises (e.g. ``KeyError``).
    """
    node = params
    for segment in dotted_key.split('.'):
        node = node[segment]
    return node

def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
    """Build a PermutationSpec by inverting a param-name -> per-axis-perm mapping.

    Args:
        axes_to_perm: Maps each parameter name to a tuple with one entry per
            axis: the name of the permutation acting on that axis, or None.

    Returns:
        A PermutationSpec whose ``perm_to_axes`` maps every permutation name to
        the ``(param name, axis)`` pairs it acts on, in the order they first
        appear in ``axes_to_perm``. ``axes_to_perm`` is stored unchanged.
    """
    grouped = {}
    for name, per_axis in axes_to_perm.items():
        for ax, perm_id in enumerate(per_axis):
            if perm_id is None:
                continue
            grouped.setdefault(perm_id, []).append((name, ax))
    return PermutationSpec(perm_to_axes=grouped, axes_to_perm=axes_to_perm)
def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None):
    """Return ``params[k]`` with every permutation that applies to it taken along its axis.

    Args:
        ps: Spec whose ``axes_to_perm[k]`` gives one permutation name (or None) per axis.
        perm: Mapping permutation name -> index array, applied with ``jnp.take``.
        k: Parameter name.
        params: Mapping parameter name -> array.
        except_axis: Axis to leave unpermuted (the one currently being solved for).

    Parameters not listed in ``ps.axes_to_perm`` are returned as-is.
    """
    weight = params[k]
    if k not in ps.axes_to_perm:
        return weight
    for axis, perm_name in enumerate(ps.axes_to_perm[k]):
        # None means no permutation acts on this axis.
        if perm_name is None or axis == except_axis:
            continue
        weight = jnp.take(weight, perm[perm_name], axis=axis)
    return weight

def apply_permutation(ps: PermutationSpec, perm, params):
    """Return a new dict with ``perm`` applied to every entry of ``params``.

    Each value is produced by ``get_permuted_param``; parameters the spec does
    not mention are carried over unchanged.
    """
    permuted = {}
    for name in params:
        permuted[name] = get_permuted_param(ps, perm, name, params)
    return permuted


def weight_matching(rng,
                    ps: PermutationSpec,
                    params_a,
                    params_b,
                    max_iter=150,
                    patience=10,
                    init_perm=None,
                    silent=True):
    """Find a permutation of `params_b` that aligns it with `params_a`, with early stopping.

    Coordinate ascent over the permutation groups in `ps`: every sweep visits
    all groups in a random order and replaces each group's permutation with the
    linear-assignment solution that maximises its inner product with
    `params_a`, holding the other groups fixed. After each sweep the global
    score ``sum_k <params_a[k], permuted params_b[k]>`` is evaluated. The search
    stops once that score has not improved for `patience` consecutive sweeps.

    Args:
        rng: PRNG key, mixed with the sweep index via `rngmix` to shuffle the
            order in which groups are visited.
        ps: Permutation spec describing which axes each permutation acts on.
        params_a: Flat mapping name -> array for the reference model.
        params_b: Flat mapping name -> array for the model to permute; must
            contain every key of `params_a`.
        max_iter: Maximum number of sweeps.
        patience: Consecutive non-improving sweeps tolerated before stopping.
        init_perm: Optional mapping perm name -> index array to start from.
            The caller's dict is not modified.
        silent: If False, print every per-group objective change above 1e-3.

    Returns:
        Mapping perm name -> index array with the best global score seen.
    """
    perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}
    # Work on a copy so the in-place updates below never leak into the caller's init_perm.
    if init_perm is None:
        perm = {p: jnp.arange(n) for p, n in perm_sizes.items()}
    else:
        perm = dict(init_perm)
    perm_names = list(perm.keys())

    def compute_score(current):
        """Global alignment score: sum over params of <w_a, permuted w_b>."""
        score = 0.0
        for wk in params_a.keys():
            w_b = get_permuted_param(ps, current, wk, params_b)
            score += jnp.vdot(params_a[wk], w_b)
        return score

    best_perm = dict(perm)
    best_score = compute_score(perm)
    no_improve_iter = 0
    for iteration in range(max_iter):
        for p_ix in random.permutation(rngmix(rng, iteration), len(perm_names)):
            p = perm_names[p_ix]
            n = perm_sizes[p]
            # Similarity between the n units of A and of B along this group's axes,
            # with every other permutation already applied to B.
            A = jnp.zeros((n, n))
            for wk, axis in ps.perm_to_axes[p]:
                w_a = jnp.moveaxis(params_a[wk], axis, 0).reshape((n, -1))
                w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
                w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1))
                A += w_a @ w_b.T
            ri, ci = linear_sum_assignment(A, maximize=True)
            # For a square matrix scipy returns the row indices 0..n-1 in order.
            assert (ri == jnp.arange(len(ri))).all()
            if not silent:
                # sum_i A[i, perm[i]] equals vdot(A, eye(n)[perm]) without building n x n matrices.
                rows = jnp.arange(n)
                oldL = A[rows, perm[p]].sum()
                newL = A[rows, ci].sum()
                if abs(newL - oldL) > 1e-3:
                    print(f"{iteration}/{p}: {newL - oldL}")
            perm[p] = jnp.array(ci)
        # Track the best permutation by global score; stop after `patience` flat sweeps.
        current_score = compute_score(perm)
        if current_score > best_score + 1e-12:
            best_score = current_score
            best_perm = dict(perm)
            no_improve_iter = 0
        else:
            no_improve_iter += 1
            if no_improve_iter >= patience:
                break
    return best_perm
def flax_vit_permutation_spec_moe(config) -> PermutationSpec:
    """Create the PermutationSpec for the MoE block of a ViT-MoE model.

    Parameter names are "/"-joined paths under
    ``vit/encoder/layer/{config.moe_idx}/moe_block``, so the params later
    matched against this spec must be flattened with a "/" separator
    (presumably via ``flax.traverse_util.flatten_dict(..., sep="/")`` —
    confirm against caller).

    Each expert is an ``intermediate/dense -> output/dense`` MLP whose hidden
    units form one permutation group. That group permutes the output axis
    (axis 1) of the intermediate kernel and its bias, and the input axis
    (axis 0) of the output kernel. The output bias is listed but not permuted.

    * Shared experts (only when ``config.num_shared_experts > 0``): a single
      MLP under ``shared_intermediate/dense`` and ``shared_output/dense``,
      permutation ``"P_shared"``.
    * Routed expert ``i`` for ``i in range(config.num_routed_experts)``:
      ``routed_intermediates_{i}/dense`` and ``routed_outputs_{i}/dense``,
      permutation ``"P_routed_{i}"``.

    The router gate is not permuted.
    """
    moe_prefix = f"vit/encoder/layer/{config.moe_idx}/moe_block"
    axes_to_perm = {}

    def add_mlp(inter_prefix, out_prefix, perm_name):
        # Intermediate dense: kernel (in=hidden, out=ff_dim) -> permute its output axis.
        axes_to_perm[f"{inter_prefix}/kernel"] = (None, perm_name)
        axes_to_perm[f"{inter_prefix}/bias"] = (perm_name,)
        # Output dense: kernel (in=ff_dim, out=hidden) -> permute its input axis.
        axes_to_perm[f"{out_prefix}/kernel"] = (perm_name, None)
        axes_to_perm[f"{out_prefix}/bias"] = (None,)

    if config.num_shared_experts > 0:
        add_mlp(f"{moe_prefix}/shared_intermediate/dense",
                f"{moe_prefix}/shared_output/dense",
                "P_shared")
    for i in range(config.num_routed_experts):
        add_mlp(f"{moe_prefix}/routed_intermediates_{i}/dense",
                f"{moe_prefix}/routed_outputs_{i}/dense",
                f"P_routed_{i}")

    return permutation_spec_from_axes_to_perm(axes_to_perm)
def permute_moe_block(params, pi, config):
    """Reorder the routed experts of the MoE block at layer ``config.moe_idx``.

    Expert slot ``e`` of the result holds what was routed expert ``pi[e]`` in
    the input. The router gate's output column and bias entry ``e`` are taken
    from position ``pi[e]``, so routing stays consistent with the moved
    experts. Shared experts and all other parameters are left as they are.

    Args:
        params: (Frozen)dict of model parameters nested as
            ``params['vit']['encoder']['layer'][str(moe_idx)]['moe_block']``.
        pi: 1-D sequence/array that is a permutation of
            ``range(config.num_routed_experts)``.
        config: Object providing ``moe_idx`` and ``num_routed_experts``.

    Returns:
        FrozenDict with the permuted MoE parameters.

    Raises:
        ValueError: If ``pi`` is not a permutation of ``range(num_routed_experts)``.
    """
    num_experts = config.num_routed_experts
    order = [int(x) for x in np.asarray(pi)]
    # A non-permutation would silently duplicate some experts and drop others.
    if sorted(order) != list(range(num_experts)):
        raise ValueError(
            f"pi must be a permutation of range({num_experts}), got {order}")
    # One integer index array for both gate tensors; int dtype keeps the empty case valid.
    idx = jnp.asarray(order, dtype=jnp.int32)

    params = unfreeze(params)
    moe_block = params['vit']['encoder']['layer'][str(config.moe_idx)]['moe_block']

    # --- Permute gate: output e must score the expert that now sits in slot e ---
    moe_block['gate']['kernel'] = moe_block['gate']['kernel'][:, idx]
    moe_block['gate']['bias'] = moe_block['gate']['bias'][idx]

    # --- Snapshot the original expert subtrees before any slot is overwritten ---
    # Slots are only rebound (never mutated in place), so references suffice; no deep copy.
    inter = [moe_block[f"routed_intermediates_{e}"]["dense"] for e in range(num_experts)]
    out = [moe_block[f"routed_outputs_{e}"]["dense"] for e in range(num_experts)]

    # --- Reassign experts from their permuted source ---
    for e, src in enumerate(order):
        moe_block[f"routed_intermediates_{e}"]["dense"] = inter[src]
        moe_block[f"routed_outputs_{e}"]["dense"] = out[src]

    return freeze(params)

# Shared styling for the interpolation plots below.
plot_line_width = 1  # base line width; naive curves and the best permuted run are drawn 2x
legend_size = 5      # legend font size
# Pastel palette: entry 0 is used for the naive interpolation, the rest for permuted runs.
colors = [
    # Reds, blues, greens, oranges, purples
    '#F4A7A7', '#A7C6F4', '#C2F4A7', '#F4D7A7', '#D7A7F4',
    # Pinks, cyans, yellows, etc.
    '#F4A7C6', '#A7F4D7', '#E6F4A7', '#A7BFF4', '#F4A7E6',
    '#FFB3BA', '#BAFFC9', '#BAE1FF', '#FFFFBA', '#FFDFBA',
    '#E6A7F4', '#A7F4E6', '#F4C2A7', '#A7E6F4', '#D7F4A7',
    '#F4A7B3', '#A7D7F4', '#C6A7F4', '#F4E6A7', '#A7F4C2',
]
def plot_interp_loss(lambdas, train_loss_interp_naive, test_loss_interp_naive,
                     train_loss_interp_clever_list, test_loss_interp_clever_list, perm_labels):
    """Plot loss landscape for naive and all permuted interpolations.

    Args:
        lambdas: Interpolation coefficients (x-axis). Ticks at 0 and 1 are
            labelled "Model A" and "Model B".
        train_loss_interp_naive, test_loss_interp_naive: Losses of the naive
            (unpermuted) interpolation, one value per lambda.
        train_loss_interp_clever_list, test_loss_interp_clever_list: One loss
            curve per permuted interpolation.
        perm_labels: Label suffix for each permuted curve.

    The permuted run with the lowest mean test loss is drawn with a doubled
    line width. Curves and labels are paired with ``zip``, so surplus entries
    in any of the three lists are ignored.

    Returns:
        The matplotlib Figure.
    """
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111)

    # Index of the permuted run with the lowest mean test loss (None if there are no runs).
    test_means = [np.mean(test_loss) for test_loss in test_loss_interp_clever_list]
    i_test = int(np.argmin(test_means)) if test_means else None

    # Naive interpolation
    ax.plot(lambdas, train_loss_interp_naive, linestyle="dashed", color=colors[0],
            alpha=0.5, linewidth=2 * plot_line_width, label="Train, naïve interp.")
    ax.plot(lambdas, test_loss_interp_naive, linestyle="solid", color=colors[0],
            alpha=0.5, linewidth=2 * plot_line_width, label="Test, naïve interp.")

    # colors[0] is reserved for the naive curves; permuted runs cycle through the rest.
    n_perm_colors = len(colors) - 1
    for i, (train_loss, test_loss, label) in enumerate(zip(train_loss_interp_clever_list,
                                                           test_loss_interp_clever_list,
                                                           perm_labels)):
        color = colors[1 + i % n_perm_colors]
        linewidth = 2 * plot_line_width if i == i_test else plot_line_width
        ax.plot(lambdas, train_loss, linestyle="dashed", color=color,
                linewidth=linewidth, label=f"Train, permuted interp. {label}")
        ax.plot(lambdas, test_loss, linestyle="solid", color=color, alpha=0.7,
                linewidth=linewidth, label=f"Test, permuted interp. {label}")

    # Configure axes and legend (legend sits outside the axes, to the right).
    ax.set_xlabel(r"$\lambda$")
    ax.set_xticks([0, 1])
    ax.set_xticklabels(["Model $A$", "Model $B$"])
    ax.set_ylabel("Loss")
    ax.set_title("Loss landscape between the two models")
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), framealpha=0.5, fontsize=legend_size)
    fig.subplots_adjust(right=0.65)
    fig.tight_layout()
    return fig

def plot_interp_acc(lambdas, train_acc_interp_naive, test_acc_interp_naive,
                    train_acc_interp_clever_list, test_acc_interp_clever_list, perm_labels):
    """Plot accuracy landscape for naive and all permuted interpolations.

    Args:
        lambdas: Interpolation coefficients (x-axis). Ticks at 0 and 1 are
            labelled "Model A" and "Model B".
        train_acc_interp_naive, test_acc_interp_naive: Accuracies of the naive
            (unpermuted) interpolation, one value per lambda.
        train_acc_interp_clever_list, test_acc_interp_clever_list: One accuracy
            curve per permuted interpolation.
        perm_labels: Label suffix for each permuted curve.

    The permuted run with the highest mean test accuracy is drawn with a
    doubled line width. Curves and labels are paired with ``zip``, so surplus
    entries in any of the three lists are ignored.

    Returns:
        The matplotlib Figure.
    """
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111)

    # Index of the permuted run with the highest mean test accuracy (None if there are no runs).
    test_means = [np.mean(test_acc) for test_acc in test_acc_interp_clever_list]
    i_test = int(np.argmax(test_means)) if test_means else None

    # Naive interpolation
    ax.plot(lambdas, train_acc_interp_naive, linestyle="dashed", color=colors[0],
            alpha=0.5, linewidth=2 * plot_line_width, label="Train, naïve interp.")
    ax.plot(lambdas, test_acc_interp_naive, linestyle="solid", color=colors[0],
            alpha=0.5, linewidth=2 * plot_line_width, label="Test, naïve interp.")

    # colors[0] is reserved for the naive curves; permuted runs cycle through the rest.
    n_perm_colors = len(colors) - 1
    for i, (train_acc, test_acc, label) in enumerate(zip(train_acc_interp_clever_list,
                                                         test_acc_interp_clever_list,
                                                         perm_labels)):
        color = colors[1 + i % n_perm_colors]
        linewidth = 2 * plot_line_width if i == i_test else plot_line_width
        ax.plot(lambdas, train_acc, linestyle="dashed", color=color,
                linewidth=linewidth, label=f"Train, permuted interp. {label}")
        ax.plot(lambdas, test_acc, linestyle="solid", color=color, alpha=0.7,
                linewidth=linewidth, label=f"Test, permuted interp. {label}")

    # Configure axes and legend (legend sits outside the axes, to the right).
    ax.set_xlabel(r"$\lambda$")
    ax.set_xticks([0, 1])
    ax.set_xticklabels(["Model $A$", "Model $B$"])
    ax.set_ylabel("Accuracy")
    ax.set_title("Accuracy between the two models")
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), framealpha=0.5, fontsize=legend_size)
    fig.subplots_adjust(right=0.65)
    fig.tight_layout()
    return fig