from collections import defaultdict
from typing import NamedTuple
from flax.core import freeze, unfreeze

import jax.numpy as jnp
from jax import random
from scipy.optimize import linear_sum_assignment

from .utils import rngmix

class PermutationSpec(NamedTuple):
  """Two-way index between named permutations and the parameter axes they act on."""
  # Maps a permutation name to a list of `(param_key, axis)` pairs that the
  # permutation applies to (see `permutation_spec_from_axes_to_perm`).
  perm_to_axes: dict
  # Maps a parameter key to a sequence with one entry per axis of that
  # parameter: the permutation name for the axis, or None if the axis is not
  # permuted.
  axes_to_perm: dict

def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
  """Build a `PermutationSpec` from a per-parameter axis description.

  Args:
    axes_to_perm: Maps each parameter key to a sequence holding, for every axis
      of that parameter, either a permutation name or None.

  Returns:
    A `PermutationSpec` whose `perm_to_axes` lists, for every permutation name
    that occurs, the `(param_key, axis)` pairs where it was found, in the order
    they were encountered. `axes_to_perm` is stored as given.
  """
  grouped = {}
  for weight_key, axis_perms in axes_to_perm.items():
    for axis_index, perm_name in enumerate(axis_perms):
      # Axes marked None are not tied to any permutation.
      if perm_name is None:
        continue
      grouped.setdefault(perm_name, []).append((weight_key, axis_index))
  return PermutationSpec(perm_to_axes=grouped, axes_to_perm=axes_to_perm)

def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None):
  """Return `params[k]` with the permutations registered for its axes applied.

  Args:
    ps: Permutation spec; `ps.axes_to_perm[k]` names the permutation per axis.
    perm: Mapping from permutation name to an index array.
    k: Key of the parameter to fetch from `params`.
    params: Mapping from parameter key to array.
    except_axis: Optional axis index to leave untouched (used when that axis is
      the one currently being solved for).

  Returns:
    The parameter unchanged if `k` has no entry in `ps.axes_to_perm`; otherwise
    the parameter with `jnp.take` applied along every axis that has a
    permutation, except `except_axis`.
  """
  result = params[k]
  if k in ps.axes_to_perm:
    for axis_index, perm_name in enumerate(ps.axes_to_perm[k]):
      # Leave alone both unpermuted axes (None) and the explicitly excluded axis.
      if perm_name is None or axis_index == except_axis:
        continue
      result = jnp.take(result, perm[perm_name], axis=axis_index)
  return result

def apply_permutation(ps: PermutationSpec, perm, params):
  """Return a new dict holding every entry of `params` permuted by `perm`.

  Each parameter is passed through `get_permuted_param`, so parameters that do
  not appear in `ps.axes_to_perm` are carried over unchanged.
  """
  permuted = {}
  for name in params.keys():
    permuted[name] = get_permuted_param(ps, perm, name, params)
  return permuted

def weight_matching(rng,
                    ps: PermutationSpec,
                    params_a,
                    params_b,
                    max_iter=150,
                    patience=10,
                    init_perm=None,
                    silent=True):
    """Find a permutation of `params_b` to make them match `params_a` with early stopping.

    Every outer iteration visits all permutations in a random order. For each
    one, the other permutations are held fixed and the visited permutation is
    re-solved as a linear assignment problem on the accumulated similarity
    matrix between `params_a` and the (partially permuted) `params_b`.
    After each outer iteration a global score is computed: the sum, over all
    keys of `params_a`, of the inner product between `params_a[k]` and the
    permuted `params_b[k]`. The best-scoring set of permutations is remembered.

    Args:
        rng: Random key. It is combined with the iteration index through
            `rngmix` to shuffle the order in which permutations are visited.
        ps: Permutation spec describing which permutation acts on which axis.
        params_a: Mapping from parameter key to array (the reference model).
        params_b: Mapping from parameter key to array (the model to permute).
            Must contain every key of `params_a`.
        max_iter: Maximum number of outer iterations.
        patience: Number of consecutive outer iterations without an improvement
            of the global score after which the search stops.
        init_perm: Optional mapping from permutation name to an initial index
            array. It is copied and never modified. When None, every
            permutation starts as the identity.
        silent: When False, print the change of the per-permutation objective
            whenever its magnitude exceeds 1e-3.

    Returns:
        A dict mapping each permutation name to the index array belonging to
        the highest global score seen (the initial permutation included).
    """
    perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}

    # Initialize permutation. A caller-supplied `init_perm` is copied so that the
    # in-place updates below never leak back into the caller's mapping.
    if init_perm is None:
        perm = {p: jnp.arange(n) for p, n in perm_sizes.items()}
    else:
        perm = dict(init_perm)
    perm_names = list(perm.keys())

    # Helper function to compute global alignment score
    def compute_score(perm):
        score = 0.0
        for wk in params_a.keys():
            w_a = params_a[wk]
            w_b = get_permuted_param(ps, perm, wk, params_b)
            score += jnp.vdot(w_a, w_b)
        return score

    # Initialize best permutation and score. A shallow copy is enough because
    # entries of `perm` are replaced, never modified in place.
    best_perm = dict(perm)
    best_score = compute_score(perm)
    no_improve_iter = 0

    # Main iteration loop
    for iteration in range(max_iter):
        # Update permutations one at a time, in a fresh random order.
        for p_ix in random.permutation(rngmix(rng, iteration), len(perm_names)):
            p = perm_names[p_ix]
            n = perm_sizes[p]

            # A[i, j] accumulates the similarity between unit i of model A and
            # unit j of model B over every parameter axis governed by `p`.
            A = jnp.zeros((n, n))
            for wk, axis in ps.perm_to_axes[p]:
                w_a = params_a[wk]
                # Apply all other permutations, but not the one being solved for.
                w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
                w_a = jnp.moveaxis(w_a, axis, 0).reshape((n, -1))
                w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1))
                A += w_a @ w_b.T

            ri, ci = linear_sum_assignment(A, maximize=True)
            # Sanity check: for a square problem the row indices come back sorted.
            assert (ri == jnp.arange(len(ri))).all()

            # The objective values are only needed for logging, so skip building
            # the two n x n permutation matrices when running silently.
            if not silent:
                oldL = jnp.vdot(A, jnp.eye(n)[perm[p]])
                newL = jnp.vdot(A, jnp.eye(n)[ci, :])
                if abs(newL - oldL) > 1e-3:
                    print(f"{iteration}/{p}: {newL - oldL}")

            perm[p] = jnp.array(ci)

        # Compute current global score and update best if improved
        current_score = compute_score(perm)
        if current_score > best_score + 1e-12:
            best_score = current_score
            best_perm = dict(perm)
            no_improve_iter = 0
        else:
            no_improve_iter += 1
            if no_improve_iter >= patience:
                break

    return best_perm

def test_weight_matching():
  """Smoke-test `weight_matching` on two random MLPs with one hidden layer.

  With a single hidden layer there is exactly one permutation (`P_0`), acting
  on the output axis of `Dense_0` and the input axis of `Dense_1`.
  """
  # The spec is built inline because this module does not define or import an
  # MLP spec helper.
  ps = permutation_spec_from_axes_to_perm({
      "Dense_0/kernel": (None, "P_0"),
      "Dense_0/bias": ("P_0", ),
      "Dense_1/kernel": ("P_0", None),
      "Dense_1/bias": (None, ),
  })
  rng = random.PRNGKey(123)
  num_hidden = 10
  shapes = {
      "Dense_0/kernel": (2, num_hidden),
      "Dense_0/bias": (num_hidden, ),
      "Dense_1/kernel": (num_hidden, 3),
      "Dense_1/bias": (3, )
  }
  params_a = {k: random.normal(rngmix(rng, f"a-{k}"), shape) for k, shape in shapes.items()}
  params_b = {k: random.normal(rngmix(rng, f"b-{k}"), shape) for k, shape in shapes.items()}
  perm = weight_matching(rng, ps, params_a, params_b)
  print(perm)
  # Whatever the matching found, it must be a rearrangement of 0..num_hidden-1.
  assert sorted(perm["P_0"].tolist()) == list(range(num_hidden))

def _register_expert_axes(axes_to_perm: dict, first_dense: str, second_dense: str, perm_name: str) -> None:
    """Add the kernel/bias entries of one two-layer expert MLP to `axes_to_perm`."""
    # First dense layer: its output units (kernel axis 1 and the bias) are permuted.
    axes_to_perm[f"{first_dense}/kernel"] = (None, perm_name)
    axes_to_perm[f"{first_dense}/bias"] = (perm_name,)
    # Second dense layer: only its input axis follows the permutation.
    axes_to_perm[f"{second_dense}/kernel"] = (perm_name, None)
    axes_to_perm[f"{second_dense}/bias"] = (None,)


def vit_permutation_spec_moe(num_experts: int = 2, shared = False) -> PermutationSpec:
    """
    Define the permutation specification for a ViT model where the last MLP layer is replaced with an MoE block.

    Only the parameters of `MoETransformerEncoderLayer_0/MoEBlock_0` are covered.
    Expert `e` uses `Dense_{2e+1}` as its first layer and `Dense_{2e+2}` as its
    second layer, and gets its own hidden-unit permutation `P_hidden_expert_{e}`.
    `Dense_0` of the block is not part of the spec.

    Args:
        num_experts (int): Number of experts in the MoE block (default 2).
        shared (bool): When true, also register the block's `shared_expert`
            (`Dense_0` / `Dense_1`) under the permutation `P_hidden_expert_shared`.

    Returns:
        PermutationSpec: Specifies how to permute the model's parameters.
    """
    block_prefix = "MoETransformerEncoderLayer_0/MoEBlock_0"
    axes_to_perm = {}

    # One independent hidden-unit permutation per expert.
    for e in range(num_experts):
        _register_expert_axes(
            axes_to_perm,
            f"{block_prefix}/Dense_{2 * e + 1}",
            f"{block_prefix}/Dense_{2 * e + 2}",
            f"P_hidden_expert_{e}",
        )

    # Optional shared expert, registered after the routed experts.
    if shared:
        _register_expert_axes(
            axes_to_perm,
            f"{block_prefix}/shared_expert/Dense_0",
            f"{block_prefix}/shared_expert/Dense_1",
            "P_hidden_expert_shared",
        )

    return permutation_spec_from_axes_to_perm(axes_to_perm)

def permute_moe_block(params, pi, num_layers: int, num_experts: int):
    """Reorder the experts of `MoETransformerEncoderLayer_0/MoEBlock_0` according to `pi`.

    After the call, expert slot `e` holds the parameters that were originally in
    slot `pi[e]`, and the gating layer (`Dense_0`) has its output columns and
    bias entries reordered with the same `pi`.

    Args:
        params: Nested parameter tree; it is passed through `unfreeze` before
            being edited and `freeze` before being returned.
        pi: Sequence or integer array of length `num_experts` giving, for each
            new expert slot, the index of the original expert to place there.
            Used as-is to index the gating kernel and bias.
        num_layers: Unused; kept so existing callers keep working.
        num_experts: Number of experts stored in the block.

    Returns:
        The frozen parameter tree with gating outputs and experts permuted.
    """
    params = unfreeze(params)
    block = params['MoETransformerEncoderLayer_0']['MoEBlock_0']

    # Permute gating parameters so that gate output e scores original expert pi[e].
    gating = block['Dense_0']
    gating_kernel = gating['kernel']
    gating_bias = gating['bias']
    gating['kernel'] = gating_kernel[:, pi]
    gating['bias'] = gating_bias[pi]

    # Snapshot the original expert parameters before any slot is overwritten.
    original_experts = {
        e: {
            'first': block[f'Dense_{2*e + 1}'],
            'second': block[f'Dense_{2*e + 2}'],
        }
        for e in range(num_experts)
    }

    # Assign permuted expert parameters.
    for e in range(num_experts):
        # Elements of a JAX array are not hashable, so convert to a plain int
        # before using the index as a dict key.
        original_e = int(pi[e])
        block[f'Dense_{2*e + 1}'] = original_experts[original_e]['first']
        block[f'Dense_{2*e + 2}'] = original_experts[original_e]['second']
    return freeze(params)



# Running this module directly executes the weight-matching smoke test.
if __name__ == "__main__":
    test_weight_matching()
