from collections import defaultdict
from typing import NamedTuple
from flax.core import freeze, unfreeze

import jax.numpy as jnp
from jax import random
from scipy.optimize import linear_sum_assignment

def rngmix(rng, x):
  """Derive a new PRNG key from `rng` by folding in a 16-bit hash of `x`.

  The folded-in value is `hash(x) % 2**16`, so inputs whose hashes agree
  modulo 65536 yield the same key. Python salts `str`/`bytes` hashes per
  process unless PYTHONHASHSEED is fixed; integer hashes are not salted.
  """
  folded_value = hash(x) % 2**16
  return random.fold_in(rng, folded_value)

class PermutationSpec(NamedTuple):
  """Describes which named permutation acts on which axis of which parameter.

  The two fields are inverse views of the same relation; see
  `permutation_spec_from_axes_to_perm` for how `perm_to_axes` is derived
  from `axes_to_perm`.
  """
  # Permutation name -> list of (parameter key, axis index) pairs it acts on.
  perm_to_axes: dict
  # Parameter key -> sequence with one entry per axis: the name of the
  # permutation acting on that axis, or None when no permutation applies.
  axes_to_perm: dict

def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
  """Build a `PermutationSpec` by inverting an axes-to-permutation mapping.

  Args:
    axes_to_perm: maps each parameter key to a sequence holding, for every
      axis, the name of the permutation acting on that axis or None.

  Returns:
    A `PermutationSpec` whose `perm_to_axes` lists, for every permutation
    name, the `(parameter key, axis)` pairs it touches, in the iteration
    order of `axes_to_perm`. `axes_to_perm` itself is stored as given.
  """
  inverse = {}
  for param_key, names_per_axis in axes_to_perm.items():
    for axis_index, perm_name in enumerate(names_per_axis):
      # Axes marked None are not governed by any permutation.
      if perm_name is None:
        continue
      inverse.setdefault(perm_name, []).append((param_key, axis_index))
  return PermutationSpec(perm_to_axes=inverse, axes_to_perm=axes_to_perm)

def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None):
  """Return `params[k]` with every applicable permutation applied.

  Args:
    ps: spec whose `axes_to_perm[k]` names the permutation for each axis of
      parameter `k` (None for axes without one).
    perm: mapping from permutation name to an index array.
    k: key of the parameter to fetch.
    params: mapping from parameter key to array.
    except_axis: optional axis index to leave unpermuted.

  Returns:
    `params[k]` unchanged when `k` is absent from `ps.axes_to_perm`;
    otherwise the parameter re-indexed with `jnp.take` along each axis that
    has a permutation, apart from `except_axis`.
  """
  permuted = params[k]
  if k in ps.axes_to_perm:
    for axis_index, perm_name in enumerate(ps.axes_to_perm[k]):
      # Skip axes with no permutation and the axis the caller is solving for.
      if perm_name is None or axis_index == except_axis:
        continue
      permuted = jnp.take(permuted, perm[perm_name], axis=axis_index)
  return permuted

def apply_permutation(ps: PermutationSpec, perm, params):
  """Return a new dict holding every entry of `params` permuted by `perm`.

  Each key of `params` is passed through `get_permuted_param`; keys that do
  not appear in `ps.axes_to_perm` are carried over unchanged.
  """
  permuted = {}
  for key in params.keys():
    permuted[key] = get_permuted_param(ps, perm, key, params)
  return permuted

def weight_matching(rng,
                    ps: PermutationSpec,
                    params_a,
                    params_b,
                    max_iter=150,
                    patience=10,
                    init_perm=None,
                    silent=True):
    """Find a permutation of `params_b` to make them match `params_a` with early stopping.

    Each outer iteration visits every permutation in a random order. For one
    permutation at a time, with all others held fixed, it accumulates the
    matrix of inner products between rows of `params_a` and rows of the
    partially permuted `params_b` over every axis that permutation governs,
    then solves a linear assignment problem maximizing the matched sum.

    After each outer iteration a global score (sum of `vdot` between every
    parameter of `params_a` and the permuted `params_b`) is computed. The best
    scoring permutation seen so far is kept, and the loop stops once the score
    has failed to beat the best by more than 1e-12 for `patience` consecutive
    iterations.

    Args:
        rng: JAX PRNG key; mixed with the iteration index to shuffle the
            order in which permutations are updated.
        ps: permutation spec describing which permutation acts on which axis.
        params_a: mapping from parameter key to array (the reference model).
        params_b: mapping from parameter key to array (the model to permute).
            Must contain every key of `params_a`.
        max_iter: maximum number of outer iterations.
        patience: number of consecutive non-improving iterations tolerated
            before stopping.
        init_perm: optional mapping from permutation name to index array used
            as the starting point. It is copied, never modified.
        silent: when False, print the per-step change of the local objective
            whenever it exceeds 1e-3 in absolute value.

    Returns:
        A dict mapping permutation name to index array: the permutation with
        the highest global score observed (possibly the initial one).
    """
    perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}

    # Work on a private dict: entries are reassigned below, and writing into
    # the caller's `init_perm` would leave it holding the *last* permutation
    # visited, which may differ from the best one that is returned.
    if init_perm is None:
        perm = {p: jnp.arange(n) for p, n in perm_sizes.items()}
    else:
        perm = dict(init_perm)
    perm_names = list(perm.keys())

    def compute_score(candidate):
        """Global alignment score: sum of inner products over all parameters."""
        score = 0.0
        for wk in params_a.keys():
            w_a = params_a[wk]
            w_b = get_permuted_param(ps, candidate, wk, params_b)
            score += jnp.vdot(w_a, w_b)
        return score

    # The starting permutation is the baseline the iterations must beat.
    best_perm = dict(perm)
    best_score = compute_score(perm)
    no_improve_iter = 0

    for iteration in range(max_iter):
        # Coordinate-descent sweep over the permutations in random order.
        for p_ix in random.permutation(rngmix(rng, iteration), len(perm_names)):
            p = perm_names[p_ix]
            n = perm_sizes[p]
            A = jnp.zeros((n, n))
            for wk, axis in ps.perm_to_axes[p]:
                w_a = params_a[wk]
                # Apply every other permutation but leave `axis` untouched:
                # that is the axis being solved for.
                w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
                w_a = jnp.moveaxis(w_a, axis, 0).reshape((n, -1))
                w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1))
                A += w_a @ w_b.T

            ri, ci = linear_sum_assignment(A, maximize=True)
            assert (ri == jnp.arange(len(ri))).all()

            # The objective delta is diagnostic only, so skip the work unless
            # it is going to be reported.
            if not silent:
                oldL = jnp.vdot(A, jnp.eye(n)[perm[p]])
                newL = jnp.vdot(A, jnp.eye(n)[ci, :])
                if abs(newL - oldL) > 1e-3:
                    print(f"{iteration}/{p}: {newL - oldL}")

            perm[p] = jnp.array(ci)

        # Early stopping on the global score.
        current_score = compute_score(perm)
        if current_score > best_score + 1e-12:
            best_score = current_score
            best_perm = dict(perm)
            no_improve_iter = 0
        else:
            no_improve_iter += 1
            if no_improve_iter >= patience:
                break

    return best_perm

def vit_permutation_spec_moe(moe_layer_which: int, num_experts: int, shared=False) -> PermutationSpec:
    """
    Define the permutation specification for the MoE block of one ViT encoder layer.

    Every expert is described as a pair of dense layers sharing one
    permutation: it acts on axis 1 of the first layer's kernel, axis 0 of the
    first layer's bias and axis 0 of the second layer's kernel. The second
    layer's bias has no permutation.

    Args:
        moe_layer_which (int): Index of the encoder block that holds the MoE
            block; used in the `encoderblock_{moe_layer_which}` key segment.
        num_experts (int): Number of experts; expert `e` gets the permutation
            `P_hidden_expert_{e}`.
        shared: When truthy, also add entries for the `shared_expert`
            sub-module under the permutation `P_hidden_expert_shared`.

    Returns:
        PermutationSpec: Specifies how to permute the MoE block's parameters.
    """
    block_prefix = f"Transformer/encoderblock_{moe_layer_which}/MoEBlock"

    # (key stem, permutation name) for every two-layer MLP in the block.
    # Expert layers are named `expert_{e}_dense{1,2}`, while the shared expert
    # nests its layers as `shared_expert/dense{1,2}`.
    mlps = [
        (f"{block_prefix}/expert_{e}_dense", f"P_hidden_expert_{e}")
        for e in range(num_experts)
    ]
    if shared:
        mlps.append((f"{block_prefix}/shared_expert/dense", "P_hidden_expert_shared"))

    axes_to_perm = {}
    for stem, perm_name in mlps:
        axes_to_perm[f"{stem}1/kernel"] = (None, perm_name)
        axes_to_perm[f"{stem}1/bias"] = (perm_name,)
        axes_to_perm[f"{stem}2/kernel"] = (perm_name, None)
        axes_to_perm[f"{stem}2/bias"] = (None,)

    return permutation_spec_from_axes_to_perm(axes_to_perm)


def permute_moe_block(params, pi, moe_layer_which: int, num_experts: int):
    """Reorder the experts of one MoE block according to `pi`.

    After the call, expert slot `e` holds the parameters that were in slot
    `pi[e]`, and column `e` of the gate kernel / entry `e` of the gate bias
    is the original column / entry `pi[e]`, so gate outputs stay aligned
    with the experts they select.

    Args:
        params: nested parameter tree, indexed as
            `params["Transformer"][f"encoderblock_{moe_layer_which}"]["MoEBlock"]`.
            It is passed through flax's `unfreeze` before being edited.
        pi: integer index sequence or array with one entry per expert;
            `pi[e]` is the original slot whose parameters move to slot `e`.
        moe_layer_which: index of the encoder block containing the MoE block.
        num_experts: number of `expert_{e}_dense1` / `expert_{e}_dense2`
            pairs to reorder.

    Returns:
        The updated parameter tree, wrapped with flax's `freeze`.

    Raises:
        KeyError: if an entry of `pi` is not in `range(num_experts)`, or an
            expected key is missing from `params`.
    """
    params = unfreeze(params)
    moe_block = params["Transformer"][f'encoderblock_{moe_layer_which}']['MoEBlock']

    # Permute the gate so output `e` routes to what is now expert `e`.
    gate = moe_block['gate']
    gate['kernel'] = gate['kernel'][:, pi]
    gate['bias'] = gate['bias'][pi]

    # Snapshot every expert first: slots are overwritten in place below, and
    # a later slot may need the original contents of an earlier one.
    original_experts = {
        e: (moe_block[f'expert_{e}_dense1'], moe_block[f'expert_{e}_dense2'])
        for e in range(num_experts)
    }
    for e in range(num_experts):
        # Elements of a JAX array are themselves (unhashable) arrays, so
        # convert to a plain int before using the value as a dict key.
        source = int(pi[e])
        first, second = original_experts[source]
        moe_block[f'expert_{e}_dense1'] = first
        moe_block[f'expert_{e}_dense2'] = second
    return freeze(params)

# Script entry point.
if __name__ == "__main__":
    # NOTE(review): `test_weight_matching` is neither defined nor imported in
    # the visible part of this file; unless it exists elsewhere, running the
    # module as a script raises NameError here. Confirm.
    test_weight_matching()
    #verify_permutation_spec_moe()
