# based on https://github.com/samuela/git-re-basin

from huggingface_hub import hf_hub_download, notebook_login
import numpy as np

from jax import numpy as jnp
from jax import random
from jax.random import split as rngmix
from scipy.optimize import linear_sum_assignment
from collections import defaultdict
from typing import NamedTuple
import pickle
from argparse import ArgumentParser


class PermutationSpec(NamedTuple):
    """Which weight axes each named permutation acts on, stored in both directions.

    Normally constructed with ``permutation_spec_from_axes_to_perm``.
    """

    # permutation name -> list of (weight name, axis index) pairs it reindexes
    perm_to_axes: dict
    # weight name -> tuple holding one permutation name (or None) per axis of that weight
    axes_to_perm: dict


def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
    """Build a PermutationSpec by inverting an ``axes_to_perm`` mapping.

    Args:
        axes_to_perm: maps each weight name to a tuple with one entry per axis;
            each entry is a permutation name, or ``None`` for an axis that is
            not permuted.

    Returns:
        A PermutationSpec whose ``perm_to_axes`` maps every permutation name to
        the list of ``(weight_name, axis)`` pairs it acts on (in the iteration
        order of ``axes_to_perm``), and whose ``axes_to_perm`` is the input
        mapping itself.
    """
    inverse = {}
    for weight_name, per_axis in axes_to_perm.items():
        for axis_index, perm_name in enumerate(per_axis):
            if perm_name is None:
                # Unpermuted axis: contributes nothing to the inverse map.
                continue
            inverse.setdefault(perm_name, []).append((weight_name, axis_index))
    return PermutationSpec(perm_to_axes=inverse, axes_to_perm=axes_to_perm)


def mlp_permutation_spec(num_hidden_layers: int) -> PermutationSpec:
    """Return the permutation spec of an MLP with `num_hidden_layers` hidden layers.

    Layers are named ``Dense_0`` .. ``Dense_{num_hidden_layers}``. Hidden layer i
    gets permutation ``P_i``, which acts on the output axis of ``Dense_i/kernel``,
    on ``Dense_i/bias``, and on the input axis of ``Dense_{i+1}/kernel``. The input
    axis of the first kernel and the final layer's outputs are not permuted.

    We assume that one permutation cannot appear in two axes of the same weight array.
    """
    assert num_hidden_layers >= 1
    last = num_hidden_layers
    # Insertion order mirrors: first kernel, inner kernels, hidden biases, final layer.
    axes = {"Dense_0/kernel": (None, "P_0")}
    for i in range(1, last):
        axes[f"Dense_{i}/kernel"] = (f"P_{i - 1}", f"P_{i}")
    for i in range(last):
        axes[f"Dense_{i}/bias"] = (f"P_{i}",)
    axes[f"Dense_{last}/kernel"] = (f"P_{last - 1}", None)
    axes[f"Dense_{last}/bias"] = (None,)
    return permutation_spec_from_axes_to_perm(axes)


def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None):
    """Return ``params[k]`` with every applicable permutation from `perm` applied.

    For each axis of weight `k`, ``ps.axes_to_perm[k]`` names the permutation for
    that axis. Named axes are reindexed with ``jnp.take(..., perm[name], axis)``.
    Axes mapped to ``None`` and the axis equal to `except_axis` (the one currently
    being solved for) are left as they are.
    """
    applicable = [
        (axis, name)
        for axis, name in enumerate(ps.axes_to_perm[k])
        if axis != except_axis and name is not None
    ]
    result = params[k]
    for axis, name in applicable:
        result = jnp.take(result, perm[name], axis=axis)
    return result


def weight_matching(rng,
                    ps: PermutationSpec,
                    params_a,
                    params_b,
                    max_iter=100,
                    init_perm=None,
                    silent=False):
    """Find a permutation of `params_b` to make them match `params_a`.

    Coordinate descent over the permutations in `ps`. Each sweep visits every
    permutation in a random order. For each one it builds the similarity matrix
    ``A = sum_w  w_a @ w_b.T`` over all weights the permutation touches, with the
    other permutations held fixed. It then solves the linear assignment problem
    that maximises ``sum_i A[i, perm[i]]``. Iteration stops after a sweep with no
    improvement, or after `max_iter` sweeps.

    Args:
        rng: JAX PRNG key used to shuffle the order permutations are visited in.
        ps: spec describing which axes of which weights each permutation acts on.
        params_a: mapping from weight name to array for the reference model.
        params_b: mapping from weight name to array for the model to permute.
        max_iter: maximum number of sweeps over all permutations.
        init_perm: optional starting permutations (name -> index array). The
            caller's dict is not modified.
        silent: if True, print nothing.

    Returns:
        dict mapping each permutation name to an index array; applying it with
        ``jnp.take`` along the permuted axes of `params_b` (see
        ``get_permuted_param``) best matches `params_a`.
    """
    perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}

    # Copy init_perm so the per-permutation updates below don't leak into the caller's dict.
    perm = {p: jnp.arange(n) for p, n in perm_sizes.items()} if init_perm is None else dict(init_perm)
    perm_names = list(perm.keys())

    def log(*args):
        """Print progress output unless `silent` is set."""
        if not silent:
            print(*args)

    for iteration in range(max_iter):
        progress = False
        mix_rng, rng = random.split(rng, 2)
        for p_ix in random.permutation(mix_rng, len(perm_names)):
            p = perm_names[p_ix]
            n = perm_sizes[p]
            A = jnp.zeros((n, n))
            log('constructing task...')
            for wk, axis in ps.perm_to_axes[p]:
                w_a = params_a[wk]
                # Apply every other permutation to w_b, but leave the axis being solved for.
                w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
                w_a = jnp.moveaxis(w_a, axis, 0).reshape((n, -1))
                w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1))
                A += w_a @ w_b.T
            log(A.shape)
            log('solving task...')
            ri, ci = linear_sum_assignment(np.asarray(A), maximize=True)
            # For a square cost matrix the row indices come back as 0..n-1.
            assert (ri == jnp.arange(len(ri))).all()
            log('solved')
            # sum_i A[i, perm[i]] equals vdot(A, eye(n)[perm]) without allocating an
            # n x n identity matrix (about 1 GB for n = 16384 in float32).
            rows = jnp.arange(n)
            oldL = A[rows, perm[p]].sum()
            newL = A[rows, ci].sum()
            log(f"{iteration}/{p}: {newL - oldL}")
            progress = progress or newL > oldL + 1e-12

            perm[p] = jnp.array(ci)

        if not progress:
            break

    return perm


# SAE checkpoint chosen for each residual-stream layer in google/gemma-scope-2b-pt-res:
# the path prefix inside the repo under which that layer's params.npz is stored.
layer_to_filename_base = {
    0: "layer_0/width_16k/average_l0_46",
    1: "layer_1/width_16k/average_l0_40",
    2: "layer_2/width_16k/average_l0_141",
    3: "layer_3/width_16k/average_l0_59",
    4: "layer_4/width_16k/average_l0_125",
    5: "layer_5/width_16k/average_l0_68",
    6: "layer_6/width_16k/average_l0_70",
    7: "layer_7/width_16k/average_l0_69",
    8: "layer_8/width_16k/average_l0_71",
    9: "layer_9/width_16k/average_l0_73",
    10: "layer_10/width_16k/average_l0_77",
    11: "layer_11/width_16k/average_l0_80",
    12: "layer_12/width_16k/average_l0_82",
    13: "layer_13/width_16k/average_l0_83",
    14: "layer_14/width_16k/average_l0_83",
    15: "layer_15/width_16k/average_l0_78",
    16: "layer_16/width_16k/average_l0_78",
    17: "layer_17/width_16k/average_l0_77",
    18: "layer_18/width_16k/average_l0_74",
    19: "layer_19/width_16k/average_l0_73",
    20: "layer_20/width_16k/average_l0_71",
    21: "layer_21/width_16k/average_l0_70",
    22: "layer_22/width_16k/average_l0_72",
    23: "layer_23/width_16k/average_l0_75",
    24: "layer_24/width_16k/average_l0_73",
    25: "layer_25/width_16k/average_l0_55"
}
# Consecutive-layer pairs: pair k matches layer k's SAE against layer k + 1's.
# Derived from key presence rather than a hard-coded last layer, so the table
# above can grow or shrink without breaking this comprehension.
pairs = {
    k: (layer_to_filename_base[k], layer_to_filename_base[k + 1])
    for k in layer_to_filename_base
    if k + 1 in layer_to_filename_base
}

# Gemma Scope SAE parameter names -> the Dense_i names used by mlp_permutation_spec:
# the encoder is treated as layer Dense_0 and the decoder as Dense_1.
sae_to_jax = {
    'W_enc': 'Dense_0/kernel',
    'W_dec': 'Dense_1/kernel',
    'b_dec': 'Dense_1/bias',
    'b_enc': 'Dense_0/bias',
}


def fold_thresholds(params):
    """Fold per-hidden-unit thresholds into the encoder/decoder weights.

    Hidden unit i is rescaled by ``1 / threshold[i]`` on the encoder side
    (columns of ``W_enc`` and ``b_enc``) and by ``threshold[i]`` on the decoder
    side (row i of ``W_dec``). Every threshold then becomes 1.
    NOTE(review): presumably this is an equivalent reparameterisation for a
    JumpReLU-style activation ``z * [z > threshold]``; confirm against the SAE's
    forward pass. Zero thresholds would produce inf/nan here.

    The entries of `params` are updated and the same dict is returned.
    """
    thr = params['threshold']
    # Encoder side: divide each hidden unit's pre-activation by its threshold.
    for enc_key in ('W_enc', 'b_enc'):
        params[enc_key] /= thr
    # Decoder side: W_dec rows index hidden units, so scale each row by its threshold.
    params['W_dec'] *= thr.reshape(-1, 1)
    params['threshold'] /= thr
    return params


if __name__ == "__main__":
    # Script entry point: download the Gemma Scope SAEs for one pair of consecutive
    # layers, fold their thresholds, find the hidden-unit permutation of the later
    # layer's SAE that best matches the earlier one, pickle it, and print per-tensor
    # MSEs before and after applying that permutation.
    from pathlib import Path

    parser = ArgumentParser()
    parser.add_argument(
        "--pair",
        default=0,
        type=int,
        help="Layer pair idx, i.e. 19 is for finding permutation between 19th and 20th layers."
    )
    parser.add_argument("--output-path", default="./permutations")
    args = parser.parse_args()
    # Report an unknown pair index as a usage error rather than a bare KeyError.
    if args.pair not in pairs:
        parser.error(f"--pair must be between {min(pairs)} and {max(pairs)}, got {args.pair}")
    pair_idx = args.pair
    print(f'permutations for {pairs[pair_idx]}')
    v1, v2 = pairs[pair_idx]
    # "layer_<n>/..." -> "<n>"
    from_layer = v1.split("/")[0].split("_")[1]
    to_layer = v2.split("/")[0].split("_")[1]

    path_to_params_prev = hf_hub_download(
        repo_id="google/gemma-scope-2b-pt-res",
        filename=f"{v1}/params.npz",
        force_download=False,
    )

    path_to_params_current = hf_hub_download(
        repo_id="google/gemma-scope-2b-pt-res",
        filename=f"{v2}/params.npz",
        force_download=False,
    )

    # np.load on an .npz returns an NpzFile holding an open file; close it once the
    # arrays have been copied out.
    with np.load(path_to_params_prev) as params_prev:
        jax_params_prev = {name: jnp.asarray(v) for name, v in params_prev.items()}
    jax_params_prev = fold_thresholds(jax_params_prev)

    with np.load(path_to_params_current) as params_current:
        jax_params_current = {name: jnp.asarray(v) for name, v in params_current.items()}
    jax_params_current = fold_thresholds(jax_params_current)

    jax_to_sae = {v: name for name, v in sae_to_jax.items()}

    # Rename to the Dense_i names used by the permutation spec. The threshold has
    # no entry in the spec, so it is left out of the matching.
    jax_params_sae_prev = {sae_to_jax[name]: v for name, v in jax_params_prev.items() if name != "threshold"}
    jax_params_sae_current = {sae_to_jax[name]: v for name, v in jax_params_current.items() if name != "threshold"}
    # One hidden layer: encoder = Dense_0, decoder = Dense_1, hidden units permuted by P_0.
    spec = mlp_permutation_spec(1)
    rng = random.PRNGKey(123)
    print(f"Doing {from_layer} ---> {to_layer}")
    result = weight_matching(rng, spec, jax_params_sae_prev, jax_params_sae_current)

    output_dir = Path(args.output_path)
    # Create the output directory if needed instead of failing inside open().
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / f'permutations_{from_layer}_{to_layer}.pkl', 'wb') as outp:
        pickle.dump(result, outp)
    print(f"Done {from_layer} ---> {to_layer}")
    print("vanilla \n\n\n")

    for name in jax_params_prev:
        if name == 'b_dec':
            continue
        mse = ((jax_params_prev[name] - jax_params_current[name]) ** 2).mean()
        print(f'MSE for {name}: {mse:.6f}')

    print("\n\n\nmatch \n\n\n")

    # Permute the current-layer SAE through the same spec used for matching, so each
    # tensor is reindexed along its own P_0 axis. This replaces guessing the axis with
    # try/except: JAX clamps out-of-bounds gathers instead of raising, so the old
    # axis-0 attempt on W_enc built a huge array before failing on the shape mismatch.
    permuted_current = {
        jax_to_sae[wk]: get_permuted_param(spec, result, wk, jax_params_sae_current)
        for wk in jax_params_sae_current
    }
    # threshold is indexed by hidden unit (fold_thresholds aligns it with W_dec's rows,
    # which the spec permutes with P_0).
    permuted_current['threshold'] = jnp.take(jax_params_current['threshold'], result['P_0'], axis=0)

    for name in jax_params_prev:
        if name == 'b_dec':
            continue
        mse = ((jax_params_prev[name] - permuted_current[name]) ** 2).mean()
        print(f'MSE for {name}: {mse:.6f}')

