from collections import defaultdict
from typing import NamedTuple
from flax.core import freeze, unfreeze
import copy
import jax.numpy as jnp
from jax import random
from scipy.optimize import linear_sum_assignment
import matplotlib.pyplot as plt
from utils import rngmix
import numpy as np

class PermutationSpec(NamedTuple):
  """Two-way mapping between permutation names and the parameter axes they act on."""
  # Maps a permutation name to a list of (parameter name, axis index) pairs.
  perm_to_axes: dict
  # Maps a parameter name to a per-axis sequence of permutation names;
  # a None entry means that axis is not permuted.
  axes_to_perm: dict
def get_nested(params, dotted_key):
    """Return the value reached by following a dot-separated key path.

    `dotted_key` is split on '.', and each segment is used to index one
    level deeper, starting from `params`. Lookup errors raised by the
    containers (e.g. KeyError) propagate to the caller.
    """
    node = params
    for segment in dotted_key.split('.'):
        node = node[segment]
    return node

def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
    """Build a PermutationSpec by inverting an axes-to-permutation mapping.

    `axes_to_perm` maps each parameter name to a sequence with one entry per
    axis: a permutation name, or None when the axis is not permuted. The
    inverse mapping lists, for every permutation name, the
    (parameter name, axis index) pairs it applies to, in iteration order.
    """
    inverse = {}
    for name, per_axis in axes_to_perm.items():
        for ix, perm_name in enumerate(per_axis):
            if perm_name is None:
                # This axis carries no permutation.
                continue
            inverse.setdefault(perm_name, []).append((name, ix))
    return PermutationSpec(perm_to_axes=inverse, axes_to_perm=axes_to_perm)
def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None):
    """Return `params[k]` with its permutations applied along each axis.

    Args:
        ps: Spec whose `axes_to_perm[k]` names the permutation for each axis
            of parameter `k` (None for an axis without one).
        perm: Mapping from permutation name to an index array.
        k: Name of the parameter to fetch.
        params: Mapping from parameter name to array.
        except_axis: Optional axis index to leave unpermuted.

    Returns:
        The parameter unchanged if `k` is not covered by the spec; otherwise
        the result of `jnp.take` along every covered axis except `except_axis`.
    """
    tensor = params[k]
    if k in ps.axes_to_perm:
        for ax, perm_name in enumerate(ps.axes_to_perm[k]):
            # Skip the excluded axis and any axis with no permutation (None).
            if ax != except_axis and perm_name is not None:
                tensor = jnp.take(tensor, perm[perm_name], axis=ax)
    return tensor

def apply_permutation(ps: PermutationSpec, perm, params):
    """Return a new dict with `perm` applied to every entry of `params`.

    Each parameter is passed through `get_permuted_param`; parameters not
    covered by `ps` come back unchanged.
    """
    permuted = {}
    for name in params.keys():
        permuted[name] = get_permuted_param(ps, perm, name, params)
    return permuted


def weight_matching(rng,
                    ps: PermutationSpec,
                    params_a,
                    params_b,
                    max_iter=150,
                    patience=10,
                    init_perm=None,
                    silent=True):
    """Find a permutation of `params_b` to make them match `params_a` with early stopping.

    Each round visits every permutation in a random order. With all other
    permutations held fixed, the visited one is re-solved as a linear
    assignment problem that maximizes the inner product between `params_a`
    and the permuted `params_b`. After each round the global score (sum of
    inner products over all parameters in `params_a`) is evaluated; the
    best-scoring set of permutations is remembered, and the search stops
    after `patience` consecutive rounds without improvement.

    Args:
        rng: Value passed to `rngmix(rng, iteration)`; the result is used as
            the key for `random.permutation`. Presumably a JAX PRNG key —
            confirm against `utils.rngmix`.
        ps: Spec describing which parameter axes each permutation acts on.
        params_a: Mapping from parameter name to array (the reference).
        params_b: Mapping with the same keys as `params_a` (the one permuted).
        max_iter: Maximum number of rounds.
        patience: Number of consecutive non-improving rounds tolerated
            before stopping.
        init_perm: Optional mapping from permutation name to index array used
            as the starting point. It is copied, so the caller's mapping is
            never modified.
        silent: When False, print the per-permutation objective change
            whenever it exceeds 1e-3 in magnitude.

    Returns:
        A dict mapping permutation name to index array: the set with the
        highest global score seen (the initial set if nothing improved).
    """
    # Size of each permutation, read off the first (parameter, axis) pair it acts on.
    perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}

    if init_perm is None:
        perm = {p: jnp.arange(n) for p, n in perm_sizes.items()}
    else:
        # Shallow copy: entries are reassigned below, and those updates must
        # not leak into (or fail on) the mapping supplied by the caller.
        perm = dict(init_perm)
    perm_names = list(perm.keys())

    def compute_score(perm):
        """Sum of inner products between `params_a` and the permuted `params_b`."""
        score = 0.0
        for wk in params_a.keys():
            w_a = params_a[wk]
            w_b = get_permuted_param(ps, perm, wk, params_b)
            score += jnp.vdot(w_a, w_b)
        return score

    # Best permutation set so far; snapshots are shallow copies because
    # `perm` entries are replaced, never mutated in place.
    best_perm = dict(perm)
    best_score = compute_score(perm)
    no_improve_iter = 0

    for iteration in range(max_iter):
        # Visit the permutations in a fresh random order each round.
        for p_ix in random.permutation(rngmix(rng, iteration), len(perm_names)):
            p = perm_names[p_ix]
            n = perm_sizes[p]
            # A[i, j] accumulates the inner product between slice i of A's
            # parameters and slice j of B's, over every axis `p` acts on.
            A = jnp.zeros((n, n))
            for wk, axis in ps.perm_to_axes[p]:
                w_a = params_a[wk]
                # Apply every permutation except the one being solved for.
                w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
                w_a = jnp.moveaxis(w_a, axis, 0).reshape((n, -1))
                w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1))
                A += w_a @ w_b.T
            ri, ci = linear_sum_assignment(A, maximize=True)
            # Sanity check: row indices come back sorted, so `ci` alone
            # describes the assignment.
            assert (ri == jnp.arange(len(ri))).all()
            if not silent:
                # Diagnostics only: objective under the old vs. new assignment.
                eye = jnp.eye(n)
                oldL = jnp.vdot(A, eye[perm[p]])
                newL = jnp.vdot(A, eye[ci, :])
                if abs(newL - oldL) > 1e-3:
                    print(f"{iteration}/{p}: {newL - oldL}")
            perm[p] = jnp.array(ci)

        # Early stopping on the global score.
        current_score = compute_score(perm)
        if current_score > best_score + 1e-12:
            best_score = current_score
            best_perm = dict(perm)
            no_improve_iter = 0
        else:
            no_improve_iter += 1
            if no_improve_iter >= patience:
                break
    return best_perm
def flax_gpt2_permutation_spec_moe(config) -> PermutationSpec:
    """
    Build the PermutationSpec for the expert MLPs under
    `transformer/h/{config.moe_layer_indices}/mlp`.

    One permutation is declared per expert MLP: "P_shared" for
    `shared_experts` (only when `config.num_shared_experts > 0`) and
    "P_routed_{e}" for each `routed_experts_{e}` with `e` in
    `range(config.num_routed_experts)`. Gate parameters are not part of
    the spec.

    NOTE(review): `config.moe_layer_indices` is formatted straight into the
    key path, so this presumably expects a single layer index — confirm.
    """
    axes_to_perm = {}

    def register_mlp(fc_prefix, proj_prefix, perm_name):
        # The permutation acts on axis 0 of the c_fc kernel and bias and on
        # axis 1 of the c_proj kernel; the c_proj bias is left unpermuted.
        # Presumably kernels are stored (out, in), making these the hidden
        # feed-forward units — TODO confirm against the model definition.
        axes_to_perm[f"{fc_prefix}/kernel"] = (perm_name, None)
        axes_to_perm[f"{fc_prefix}/bias"] = (perm_name,)
        axes_to_perm[f"{proj_prefix}/kernel"] = (None, perm_name)
        axes_to_perm[f"{proj_prefix}/bias"] = (None,)

    if config.num_shared_experts > 0:
        register_mlp(
            f"transformer/h/{config.moe_layer_indices}/mlp/shared_experts/c_fc",
            f"transformer/h/{config.moe_layer_indices}/mlp/shared_experts/c_proj",
            "P_shared",
        )

    for expert in range(config.num_routed_experts):
        register_mlp(
            f"transformer/h/{config.moe_layer_indices}/mlp/routed_experts_{expert}/c_fc",
            f"transformer/h/{config.moe_layer_indices}/mlp/routed_experts_{expert}/c_proj",
            f"P_routed_{expert}",
        )

    return permutation_spec_from_axes_to_perm(axes_to_perm)

def permute_moe_block(params, pi, config):
    """
    Reorder the routed experts (and the matching gate entries) of the MoE
    block at `params["transformer"]["h"][str(config.moe_layer_indices)]["mlp"]`.

    Args:
        params: Model parameters; passed through `unfreeze`, so the input
            itself is not modified.
        pi: Sequence of expert indices, one per routed expert. After the
            call, expert slot `e` holds what was expert `pi[e]`.
        config: Object providing `moe_layer_indices` and `num_routed_experts`.

    Returns:
        The result of `freeze` on the updated parameters.
    """
    params = unfreeze(params)
    layer_key = str(config.moe_layer_indices)
    mlp = params["transformer"]["h"][layer_key]["mlp"]

    # Router: reindex the kernel's second axis and the bias entries by `pi`
    # (assumes the kernel's second axis indexes experts — TODO confirm).
    gate = mlp["gate"]
    gate["kernel"] = gate["kernel"][:, tuple(pi)]
    gate["bias"] = gate["bias"][jnp.array(pi)]

    # Snapshot every expert's c_fc / c_proj before any slot is overwritten.
    num_experts = config.num_routed_experts
    snapshot = {}
    for e in range(num_experts):
        expert = mlp[f"routed_experts_{e}"]
        snapshot[e] = (copy.deepcopy(expert["c_fc"]), copy.deepcopy(expert["c_proj"]))

    # Slot e receives the snapshot of expert pi[e]; other sub-entries of the
    # expert, if any, are left in place.
    for e in range(num_experts):
        c_fc, c_proj = snapshot[int(pi[e])]
        target = mlp[f"routed_experts_{e}"]
        target["c_fc"] = c_fc
        target["c_proj"] = c_proj

    return freeze(params)

# Base line width for the interpolation plots; the naive curves and the
# highlighted permuted curve are drawn at twice this width.
plot_line_width=1
# Font size passed to the legend of each plot.
legend_size=5
# Palette of 25 hex colors shared by the plots: index 0 is used for the
# naive-interpolation curves, later entries for the permuted curves.
colors = [
        '#F4A7A7', '#A7C6F4', '#C2F4A7', '#F4D7A7', '#D7A7F4',  # Reds, Blues, Greens, Oranges, Purples
        '#F4A7C6', '#A7F4D7', '#E6F4A7', '#A7BFF4', '#F4A7E6',  # Pinks, Cyans, Yellows, etc.
        '#FFB3BA', '#BAFFC9', '#BAE1FF', '#FFFFBA', '#FFDFBA',
        '#E6A7F4', '#A7F4E6', '#F4C2A7', '#A7E6F4', '#D7F4A7',
        '#F4A7B3', '#A7D7F4', '#C6A7F4', '#F4E6A7', '#A7F4C2'
]
def plot_interp_loss(lambdas, train_loss_interp_naive, test_loss_interp_naive,
                     train_loss_interp_clever_list, test_loss_interp_clever_list, perm_labels):
    """Plot loss landscape for naive and all permuted interpolations.

    Args:
        lambdas: x-axis values shared by every curve.
        train_loss_interp_naive: Train losses for the naive interpolation.
        test_loss_interp_naive: Test losses for the naive interpolation.
        train_loss_interp_clever_list: One train-loss sequence per permuted
            interpolation.
        test_loss_interp_clever_list: One test-loss sequence per permuted
            interpolation.
        perm_labels: One legend label per permuted interpolation.

    The three per-permutation arguments are zipped together, so extra
    entries in the longer ones are ignored. The permuted interpolation with
    the lowest mean test loss is drawn with a doubled line width. The
    per-permutation arguments may be empty, in which case only the naive
    curves are drawn.

    Returns:
        The matplotlib figure.
    """
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111)

    # Index of the permuted interpolation with the lowest mean test loss
    # (None when there are no permuted curves; argmin rejects empty input).
    test_means = [np.mean(test_loss) for test_loss in test_loss_interp_clever_list]
    i_test = int(np.argmin(test_means)) if test_means else None

    # Naive interpolation
    ax.plot(lambdas, train_loss_interp_naive, linestyle="dashed", color=colors[0],
            alpha=0.5, linewidth=2 * plot_line_width, label="Train, naïve interp.")
    ax.plot(lambdas, test_loss_interp_naive, linestyle="solid", color=colors[0],
            alpha=0.5, linewidth=2 * plot_line_width, label="Test, naïve interp.")

    for i, (train_loss, test_loss, label) in enumerate(zip(train_loss_interp_clever_list,
                                                           test_loss_interp_clever_list,
                                                           perm_labels)):
        # Parenthesized so the palette wraps around; `i+1 % len(colors)`
        # would parse as `i + (1 % len(colors))` and run off the end.
        color = colors[(i + 1) % len(colors)]
        # Double width for the lowest-mean curve, base width otherwise.
        linewidth = 2 * plot_line_width if i == i_test else plot_line_width
        ax.plot(lambdas, train_loss, linestyle="dashed", color=color,
                linewidth=linewidth, label=f"Train, permuted interp. {label}")
        ax.plot(lambdas, test_loss, linestyle="solid", color=color, alpha=0.7,
                linewidth=linewidth, label=f"Test, permuted interp. {label}")

    # Configure axes and legend (raw string keeps the backslash for mathtext).
    ax.set_xlabel(r"$\lambda$")
    ax.set_xticks([0, 1])
    ax.set_xticklabels(["Model $A$", "Model $B$"])
    ax.set_ylabel("Loss")
    ax.set_title("Loss landscape between the two models")
    # Legend is anchored outside the axes, to the right.
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), framealpha=0.5, fontsize=legend_size)
    fig.subplots_adjust(right=0.65)
    fig.tight_layout()
    return fig

def plot_interp_ppl(lambdas, train_ppl_interp_naive, test_ppl_interp_naive,
                    train_ppl_interp_clever_list, test_ppl_interp_clever_list, perm_labels):
    """Plot PPL landscape for naive and all permuted interpolations.

    Args:
        lambdas: x-axis values shared by every curve.
        train_ppl_interp_naive: Train perplexities for the naive interpolation.
        test_ppl_interp_naive: Test perplexities for the naive interpolation.
        train_ppl_interp_clever_list: One train-perplexity sequence per
            permuted interpolation.
        test_ppl_interp_clever_list: One test-perplexity sequence per
            permuted interpolation.
        perm_labels: One legend label per permuted interpolation.

    The three per-permutation arguments are zipped together, so extra
    entries in the longer ones are ignored. The permuted interpolation with
    the lowest mean test perplexity is drawn with a doubled line width. The
    per-permutation arguments may be empty, in which case only the naive
    curves are drawn.

    Returns:
        The matplotlib figure.
    """
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111)

    # Index of the permuted interpolation with the lowest mean test
    # perplexity; lower is better, so this highlights the best curve.
    # None when there are no permuted curves (argmin rejects empty input).
    test_means = [np.mean(test_ppl) for test_ppl in test_ppl_interp_clever_list]
    i_test = int(np.argmin(test_means)) if test_means else None

    # Naive interpolation
    ax.plot(lambdas, train_ppl_interp_naive, linestyle="dashed", color=colors[0],
            alpha=0.5, linewidth=2 * plot_line_width, label="Train, naïve interp.")
    ax.plot(lambdas, test_ppl_interp_naive, linestyle="solid", color=colors[0],
            alpha=0.5, linewidth=2 * plot_line_width, label="Test, naïve interp.")

    for i, (train_ppl, test_ppl, label) in enumerate(zip(train_ppl_interp_clever_list,
                                                         test_ppl_interp_clever_list,
                                                         perm_labels)):
        # Parenthesized so the palette wraps around; `i+1 % len(colors)`
        # would parse as `i + (1 % len(colors))` and run off the end.
        color = colors[(i + 1) % len(colors)]
        # Double width for the highlighted curve, base width otherwise.
        linewidth = 2 * plot_line_width if i == i_test else plot_line_width
        ax.plot(lambdas, train_ppl, linestyle="dashed", color=color,
                linewidth=linewidth, label=f"Train, permuted interp. {label}")
        ax.plot(lambdas, test_ppl, linestyle="solid", color=color, alpha=0.7,
                linewidth=linewidth, label=f"Test, permuted interp. {label}")

    # Configure axes and legend (raw string keeps the backslash for mathtext).
    ax.set_xlabel(r"$\lambda$")
    ax.set_xticks([0, 1])
    ax.set_xticklabels(["Model $A$", "Model $B$"])
    ax.set_ylabel("Perplexity")
    ax.set_title("PPL between the two models")
    # Legend is anchored outside the axes, to the right.
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), framealpha=0.5, fontsize=legend_size)
    fig.subplots_adjust(right=0.65)
    fig.tight_layout()
    return fig