# def parse_sync_matrix(sync_matrix, n, combinations, order):
#     """
#     Parse the large synchronization block matrix into the individual permutation matrices.
#     """
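#     # torch.Size is a tuple and never compares equal to a list; the block slicing below
#     # also needs room for every universe index in `order` (3 universes -> a (3n, 3n) matrix).
#     assert sync_matrix.shape == (3 * n, 3 * n)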

#     sync_perm_matrices = {comb: None for comb in combinations}
#     for comb in combinations:
#         new_perm_matrix_comb = sync_matrix[
#             n * order[comb[0]] : n * (order[comb[0]] + 1), n * order[comb[1]] : n * (order[comb[1]] + 1)
#         ]
#         assert is_valid_permutation_matrix(new_perm_matrix_comb)
#         sync_perm_matrices[comb] = new_perm_matrix_comb
#     return sync_perm_matrices


# def parse_sync_matrix(sync_matrix, n, combinations, order):
#     """
#     Parse the large synchronization block matrix into the individual permutation matrices.
#     """
#     assert sync_matrix.shape == (3 * n, n)

#     from_universe = {"a": sync_matrix[:n], "b": sync_matrix[n : 2 * n], "c": sync_matrix[2 * n :]}
#     to_universe = {"a": sync_matrix[:n].T, "b": sync_matrix[n : 2 * n].T, "c": sync_matrix[2 * n :].T}

#     sync_perm_matrices = {comb: None for comb in combinations}
#     for comb in combinations:
#         new_perm_matrix_comb = to_universe[comb[1]] @ from_universe[comb[0]]
#         assert is_valid_permutation_matrix(new_perm_matrix_comb)
#         sync_perm_matrices[comb] = new_perm_matrix_comb
#     return sync_perm_matrices

# def construct_uber_matrix(perm_matrices, order, n):
#     uber_matrix = torch.zeros((n * 3, n * 3))

#     uber_matrix[:n, :n] = torch.eye(n)
#     uber_matrix[n : n * 2, n : n * 2] = torch.eye(n)
#     uber_matrix[n * 2 :, n * 2 :] = torch.eye(n)

#     for comb, perm_matrix in perm_matrices.items():
#         uber_matrix[
#             n * order[comb[0]] : n * (order[comb[0]] + 1), n * order[comb[1]] : n * (order[comb[1]] + 1)
#         ] = perm_matrix

#     return uber_matrix


# def weight_matching(ps: PermutationSpec, params_a, params_b, max_iter=100, init_perm=None):
#     """
#     Find a permutation of params_b to make them match params_a.
#     :param ps: PermutationSpec
#     :param params_a: dict of params
#     """

#     # For an MLP of 4 layers it would be something like {'P_0': 512, 'P_1': 512, 'P_2': 512, 'P_3': 256}. Input and output dims are never permuted.
#     perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}

#     # initialize with the identity permutation if none given; despite the name, permutations are stored as index vectors, not matrices
#     perm_matrices = {p: torch.arange(n) for p, n in perm_sizes.items()} if init_perm is None else init_perm
#     # e.g. P0, P1, ..
#     perm_names = list(perm_matrices.keys())

#     for iteration in tqdm(range(max_iter), desc="Weight matching"):
#         progress = False

#         # iterate over the permutation matrices in random order
#         for p_ix in torch.randperm(len(perm_names)):
#             p = perm_names[p_ix]
#             n = perm_sizes[p]

#             A = torch.zeros((n, n))

#             # all the params that are permuted by this permutation matrix, together with the axis on which it acts
#             # e.g. ('layer_0.weight', 0), ('layer_0.bias', 0), ('layer_1.weight', 0)..
#             params_and_axes = ps.perm_to_axes[p]

#             for params_name, axis in params_and_axes:
#                 w_a = params_a[params_name]
#                 w_b = params_b[params_name]
#                 # w_c = params_c[params_name]

#                 assert w_a.shape == w_b.shape

#                 perms_to_apply = ps.axes_to_perm[params_name]

#                 w_b_to_a = get_permuted_param(w_b, perms_to_apply, perm_matrices, except_axis=axis)

#                 w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1))
#                 w_b = torch.moveaxis(w_b, axis, 0).reshape((n, -1))

#                 A += w_a @ w_b_to_a.T

#             ri, ci = linear_sum_assignment(A.detach().numpy(), maximize=True)
#             # pi_b_to_a, pi_c_to_a, pi_c_to_b = alg(A, B, C)

#             assert (torch.tensor(ri) == torch.arange(len(ri))).all()

#             old_similarity = compute_weights_similarity(A, perm_matrices[p])

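#             # torch.tensor keeps the integer dtype of `ci` (torch.Tensor would cast it to float32),
#             # matching the integer torch.arange initialisation so the result can be used for indexing
#             perm_matrices[p] = torch.tensor(ci)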

#             new_similarity = compute_weights_similarity(A, perm_matrices[p])

#             pylogger.info(f"Iteration {iteration}, Permutation {p}: {new_similarity - old_similarity}")

#             progress = progress or new_similarity > old_similarity + 1e-12

#         if not progress:
#             break

#     return perm_matrices
