from copy import deepcopy
from typing import Optional, List

import torch
from scipy.optimize import linear_sum_assignment
from torch.utils.data import DataLoader
from tqdm import tqdm

from transformers import OlmoeForCausalLM
from transformers.models.olmoe.modeling_olmoe import OlmoeMLP

from mergemoe.utils.constants import FP32_EPS

# Public API of this module: permutation alignment of the experts of an OLMoE model (by weight matching or
# by activation matching) and activation-based greedy merging of several OLMoE expert MLPs into one.
__all__ = [
    "align_OLMoE_permutation_for_all_experts_by_weight_matching",
    "align_OLMoE_permutation_for_all_experts_by_activation_matching",
    "merge_olmoe_mlp_by_activation_matching_within_and_across_models"
]


def permute_OLMoE_mlp_dense_expert_(
        dense_mlp: OlmoeMLP,
        perm: torch.Tensor,
) -> OlmoeMLP:
    """
    Permute the intermediate neurons of an OLMoE expert MLP in place, without changing its function.

    An OLMoE expert computes ``down_proj(act(gate_proj(x)) * up_proj(x))``. Let ``P`` be the permutation
    matrix whose i-th row is the unit vector ``e_{perm[i]}``. Reordering the rows of ``gate_proj`` and
    ``up_proj`` and the columns of ``down_proj`` with the same ``P`` leaves the output unchanged, because the
    activation and the element-wise product act neuron by neuron:

        (Wd @ P^T) @ (act(P @ Wg @ x) * (P @ Wu @ x)) = Wd @ (act(Wg @ x) * (Wu @ x))

    So the permutation of the MLP is (Wg, Wu, Wd) -> (Wg[perm], Wu[perm], Wd[:, perm]).

    Parameters
    ----------
    dense_mlp: OlmoeMLP
        The MLP to be permuted. Its ``gate_proj``, ``up_proj`` and ``down_proj`` weights are replaced.
    perm: torch.Tensor, of shape (d_ff, )
        The permutation vector (not permutation matrix) to be applied to the MLP. Any integer dtype is
        accepted.

    Returns
    -------
    mlp: OlmoeMLP
        The same ``dense_mlp`` object, permuted in place.

    Raises
    ------
    ValueError
        If ``perm`` has the wrong shape, a non-integer dtype, or is not a permutation of ``0 .. d_ff - 1``.

    Examples
    --------
    >>> from transformers import OlmoeConfig
    >>> mlp = OlmoeMLP(OlmoeConfig(hidden_size=16, intermediate_size=32))
    >>> dummy_input = torch.randn(4, 16)
    >>> permuted_mlp = permute_OLMoE_mlp_dense_expert_(deepcopy(mlp), torch.randperm(32))
    >>> torch.allclose(mlp(dummy_input), permuted_mlp(dummy_input), atol=1e-6)
    True
    """
    d_ff = dense_mlp.up_proj.out_features

    # Check the permutation vector
    if perm.shape != (d_ff,):
        raise ValueError(f"The shape of the permutation vector should be (d_ff, ), but got {perm.shape}.")
    if perm.is_floating_point() or perm.is_complex() or perm.dtype == torch.bool:
        raise ValueError(f"The permutation vector should be an integer tensor, but got dtype {perm.dtype}.")
    # Normalise to int64 so the comparison with `arange` below does not fail on a dtype mismatch
    # (e.g. int32 indices coming from numpy on some platforms).
    perm = perm.long()
    if not torch.equal(torch.sort(perm).values, torch.arange(d_ff, device=perm.device)):
        raise ValueError("The permutation vector should be a permutation.")

    # Permute the weights of the MLP: rows of the two input projections, columns of the output projection.
    with torch.no_grad():
        for proj in (dense_mlp.up_proj, dense_mlp.gate_proj):
            proj.weight.data = proj.weight.data[perm.to(proj.weight.device), :]
        down_weight = dense_mlp.down_proj.weight.data
        dense_mlp.down_proj.weight.data = down_weight[:, perm.to(down_weight.device)]

    return dense_mlp


def compute_OLMoE_permutation_by_weight_matching(
        reference_mlp: OlmoeMLP,
        target_mlp: OlmoeMLP,
        include_wo: bool = False,
        include_gate: bool = False,
) -> torch.Tensor:
    """
    Compute the permutation vector by weight matching that permutes the target MLP to match the reference MLP.

    Specifically, find a permutation of the target MLP's intermediate neurons such that the summed squared L2
    error between the (matched) neuron weights of the reference MLP and the target MLP is minimal. Because a
    permutation does not change the norms, this is a linear sum assignment problem on the inner products:

        argmin_perm sum_i ||Wr[i] - Wt[perm[i]]||^2 = argmax_perm sum_i <Wr[i], Wt[perm[i]]>

    where ``W[i]`` is row ``i`` of ``up_proj`` (optionally concatenated with row ``i`` of ``gate_proj`` and
    column ``i`` of ``down_proj``).

    Parameters
    ----------
    reference_mlp: OlmoeMLP
        The reference MLP.
    target_mlp: OlmoeMLP
        The target MLP.
    include_wo: bool, default False
        Whether to include the (each column) weights of ``down_proj`` in the weight matching.
    include_gate: bool, default False
        Whether to include the (each row) weights of ``gate_proj`` in the weight matching. ``gate_proj`` rows are
        permuted together with ``up_proj`` rows, so including them uses all weights that move with a neuron.

    Returns
    -------
    perm: torch.Tensor, of shape (d_ff, )
        ``perm[i]`` is the target neuron matched to reference neuron ``i``; permuting the target MLP with it
        aligns the target to the reference.
    """
    with torch.no_grad():
        # Accumulate similarities in float32: bf16/fp16 dot products over d_model lose too much precision
        # to rank candidate matches reliably.
        lsa_cost_matrix = torch.mm(
            reference_mlp.up_proj.weight.data.float(), target_mlp.up_proj.weight.data.float().t()
        )  # (d_ff, d_ff): [i, j] = <reference row i, target row j>
        if include_gate:
            lsa_cost_matrix += torch.mm(
                reference_mlp.gate_proj.weight.data.float(), target_mlp.gate_proj.weight.data.float().t()
            )
        if include_wo:
            lsa_cost_matrix += torch.mm(
                reference_mlp.down_proj.weight.data.float().t(), target_mlp.down_proj.weight.data.float()
            )
    _, perm = linear_sum_assignment(lsa_cost_matrix.cpu().numpy(), maximize=True)
    return torch.from_numpy(perm).to(lsa_cost_matrix.device)




def align_OLMoE_permutation_for_all_experts_by_weight_matching(
        model: OlmoeForCausalLM,
        include_wo: bool = False,
) -> OlmoeForCausalLM:
    """
    Align the permutation of all experts in the OLMoE model by weight matching.

    In every decoder layer, each expert ``e >= 1`` is permuted so that its intermediate neurons line up with
    those of expert 0 of the same layer. Each permutation is function-preserving (see
    ``permute_OLMoE_mlp_dense_expert_``), so only the neuron ordering changes, not the model's outputs.

    Parameters
    ----------
    model: OlmoeForCausalLM
        The OLMoE model to be aligned.
    include_wo: bool, default False
        Whether to include the (each column) weights of the second layer of MLP in the weight matching.

    Returns
    -------
    model: OlmoeForCausalLM
        The aligned OLMoE model, from in-place operation.
    """
    # OLMoE places a sparse MoE block in every decoder layer, so all layers are aligned (the previous
    # `range(1, ...)` was a Switch-Transformers leftover that skipped layer 0). Iterating the layer list
    # directly also avoids `config.num_layers`, which OlmoeConfig does not define (it uses `num_hidden_layers`).
    for layer in tqdm(model.model.layers,
                      desc=f"[Permutation]Aligning permutation {'with' if include_wo else 'without'} Wo"):
        experts = layer.mlp.experts
        reference_expert = experts[0]
        for expert_idx in range(1, len(experts)):
            perm = compute_OLMoE_permutation_by_weight_matching(
                reference_expert,
                experts[expert_idx],
                include_wo=include_wo,
            )
            # Permutes the expert's weights in place.
            permute_OLMoE_mlp_dense_expert_(experts[expert_idx], perm)

    return model


def compute_OLMoE_permutation_by_activation_matching(
        reference_mlp: OlmoeMLP,
        target_mlp: OlmoeMLP,
        forwarded_hidden_states: torch.Tensor,
        mini_batch_size: Optional[int] = None,
) -> torch.Tensor:
    """
    Compute the permutation that aligns the target MLP's intermediate neurons with the reference MLP's by activation.

    The forwarded_hidden_states are supposed to be gathered from tokens routed to reference_mlp and/or
    target_mlp. Both MLPs are run on these tokens and the inputs of their ``down_proj`` (the intermediate
    activations) are recorded. The Pearson correlation between every reference neuron and every target neuron
    then serves as the score of a maximum-weight linear sum assignment.

    Parameters
    ----------
    reference_mlp: OlmoeMLP
        The reference MLP.
    target_mlp: OlmoeMLP
        The target MLP.
    forwarded_hidden_states: torch.Tensor, of shape (num_tokens, d_model)
        The hidden states that are forwarded to both reference_mlp and target_mlp.
    mini_batch_size: int, default None
        The number of tokens per forward pass. If None, all tokens are forwarded at once.

    Returns
    -------
    perm: torch.Tensor, of shape (d_ff, )
        ``perm[i]`` is the target neuron matched to reference neuron ``i``; permuting the target MLP with it
        aligns the target to the reference. With fewer than two tokens the identity permutation is returned,
        because a correlation cannot be estimated.

    Raises
    ------
    ValueError
        If the last dimension of ``forwarded_hidden_states`` is not the MLPs' input size, or if
        ``mini_batch_size`` is not positive.
    """
    if len(forwarded_hidden_states) < 2:
        return torch.arange(reference_mlp.up_proj.out_features, device=forwarded_hidden_states.device)

    if forwarded_hidden_states.shape[-1] != reference_mlp.up_proj.in_features:
        raise ValueError(
            f"The last dimension of forwarded_hidden_states should be {reference_mlp.up_proj.in_features}, "
            f"but got {forwarded_hidden_states.shape[-1]}."
        )
    num_tokens = forwarded_hidden_states.shape[0]
    if mini_batch_size is None:
        mini_batch_size = num_tokens
    elif mini_batch_size <= 0:
        raise ValueError(f"mini_batch_size should be a positive integer, but got {mini_batch_size}.")

    reference_activations = []
    target_activations = []

    def _make_activation_hook(storage):
        # Records the input of `down_proj`, i.e. the intermediate activation, flattened to (tokens, d_ff).
        def hook(module, inputs, output):
            storage.append(inputs[0].detach().reshape(-1, inputs[0].shape[-1]))

        return hook

    reference_handle = reference_mlp.down_proj.register_forward_hook(_make_activation_hook(reference_activations))
    target_handle = target_mlp.down_proj.register_forward_hook(_make_activation_hook(target_activations))
    try:
        with torch.no_grad():
            for start in range(0, num_tokens, mini_batch_size):
                chunk = forwarded_hidden_states[start:start + mini_batch_size]
                reference_mlp(chunk)
                target_mlp(chunk)
    finally:
        # Always detach the hooks, even if a forward pass fails, so the experts are not left instrumented.
        reference_handle.remove()
        target_handle.remove()

    # Compute statistics in float32; half-precision activations lose too much accuracy in the
    # covariance reduction over many tokens.
    reference_activations = torch.cat(reference_activations, dim=0).float()  # (num_tokens, d_ff)
    target_activations = torch.cat(target_activations, dim=0).float()  # (num_tokens, d_ff)

    # Compute the correlation matrix as the cost matrix
    mean_ref = reference_activations.mean(dim=0, keepdim=True)  # (1, d_ff)
    mean_target = target_activations.mean(dim=0, keepdim=True)  # (1, d_ff)
    std_ref = reference_activations.std(dim=0, keepdim=True)  # (1, d_ff)
    std_target = target_activations.std(dim=0, keepdim=True)  # (1, d_ff)
    covar = torch.mm(
        (reference_activations - mean_ref).t(),
        (target_activations - mean_target)
    ) / (reference_activations.shape[0] - 1)  # (d_ff, d_ff)
    cost_matrix = covar / (std_ref.t() * std_target + FP32_EPS)  # (d_ff, d_ff)

    _, perm = linear_sum_assignment(cost_matrix.cpu().numpy(), maximize=True)
    return torch.from_numpy(perm).to(cost_matrix.device)

def align_OLMoE_permutation_for_all_experts_by_activation_matching(
        model: OlmoeForCausalLM,
        dataloader: DataLoader,
) -> OlmoeForCausalLM:
    """
    Align the permutation of all experts in the OLMoE model by activation matching.

    The model is first run over ``dataloader`` while the input of every MoE block and the router's top-k expert
    choices are recorded. Then, in every layer, each expert ``e >= 1`` is permuted to match expert 0 of that
    layer. The permutation is computed from the tokens routed to expert 0 or to expert ``e`` (each token
    counted once). The permutations are function-preserving, so the model's outputs are unchanged.

    Parameters
    ----------
    model: OlmoeForCausalLM
        The OLMoE model to be aligned.
    dataloader: DataLoader
        The dataloader to be used to gather the hidden states for activation matching. Each batch must be a
        dict of tensors accepted as keyword arguments by ``model``.

    Returns
    -------
    model: OlmoeForCausalLM
        The aligned OLMoE model, from in-place operation.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no batches.
    """
    model.eval()
    config = model.config
    # OLMoE places a sparse MoE block in every decoder layer (the old `range(1, config.num_layers)` came from
    # Switch Transformers; OlmoeConfig has no `num_layers`), so every layer is aligned.
    moe_layer_indices = list(range(len(model.model.layers)))
    top_k = config.num_experts_per_tok

    layer_inputs, layer_topk = _gather_olmoe_moe_inputs_and_routing(model, dataloader, moe_layer_indices, top_k)

    # Compute the permutation for all experts
    num_experts = config.num_experts
    progress_bar = tqdm(total=len(moe_layer_indices) * (num_experts - 1),
                        desc="[Permutation]Aligning permutation by activation matching")
    for layer_idx in moe_layer_indices:
        experts = model.model.layers[layer_idx].mlp.experts
        hidden_states = layer_inputs.pop(layer_idx)  # (num_tokens, d_model)
        selected_experts = layer_topk.pop(layer_idx).to(hidden_states.device)  # (num_tokens, top_k)
        routed_to_reference = (selected_experts == 0).any(dim=-1)

        for expert_idx in range(1, num_experts):
            routed_to_either = routed_to_reference | (selected_experts == expert_idx).any(dim=-1)
            perm = compute_OLMoE_permutation_by_activation_matching(
                reference_mlp=experts[0],
                target_mlp=experts[expert_idx],
                forwarded_hidden_states=hidden_states[routed_to_either],
            )
            # Permutes the expert's weights in place.
            permute_OLMoE_mlp_dense_expert_(experts[expert_idx], perm)
            progress_bar.update(1)
    progress_bar.close()

    return model


def _gather_olmoe_moe_inputs_and_routing(model, dataloader, layer_indices, top_k):
    """
    Run the model over the dataloader and record each listed layer's MoE-block inputs and top-k expert indices.

    Returns two dicts keyed by layer index: the flattened MoE-block inputs, shape (num_tokens, d_model), and the
    indices of the ``top_k`` highest router logits per token, shape (num_tokens, top_k). Row ``t`` of both refers
    to the same token.
    """
    inputs_per_layer = {layer_idx: [] for layer_idx in layer_indices}
    topk_per_layer = {layer_idx: [] for layer_idx in layer_indices}

    def _make_input_hook(layer_idx):
        def hook(module, inputs, output):
            hidden = inputs[0].detach()
            inputs_per_layer[layer_idx].append(hidden.reshape(-1, hidden.shape[-1]))

        return hook

    handles = [
        model.model.layers[layer_idx].mlp.register_forward_hook(_make_input_hook(layer_idx))
        for layer_idx in tqdm(layer_indices, desc="[Permutation]Registering forward hook...")
    ]
    device = model.device
    num_batches = 0
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="[Permutation]Computing activations..."):
                batch = {key: value.to(device) for key, value in batch.items()}
                # OLMoE only returns per-layer router logits when asked to; each entry is expected to be
                # flattened to (batch_size * seq_len, num_experts), row-aligned with the hooked inputs.
                output = model(**batch, output_router_logits=True)
                for layer_idx in layer_indices:
                    router_logits = output.router_logits[layer_idx]
                    topk_per_layer[layer_idx].append(torch.topk(router_logits, top_k, dim=-1).indices)
                num_batches += 1
    finally:
        # Remove the hooks even if a forward pass fails.
        for handle in handles:
            handle.remove()

    if num_batches == 0:
        raise ValueError("The dataloader yielded no batches, so no activations could be gathered.")

    inputs = {layer_idx: torch.cat(chunks, dim=0) for layer_idx, chunks in inputs_per_layer.items()}
    topk = {layer_idx: torch.cat(chunks, dim=0) for layer_idx, chunks in topk_per_layer.items()}
    return inputs, topk




def merge_olmoe_mlp_by_activation_matching_within_and_across_models(
        mlp_list,
        forwarded_hidden_states: torch.Tensor,
        mini_batch_size: Optional[int] = None,
        alpha_for_repeated_merging: Optional[float] = 0.1,
        average_coefs: Optional[List[float]] = None,
):
    """
    Merge the MLPs in mlp_list into one MLP by greedily merging their most correlated intermediate neurons.

    All intermediate neurons of all MLPs are pooled (``num_mlp * d_ff`` neurons), and their activations on
    ``forwarded_hidden_states`` are correlated. The most correlated pair is repeatedly merged, within or
    across MLPs, by a coefficient-weighted average of their weights, until only ``d_ff`` neurons remain.

    Parameters
    ----------
    mlp_list: List[OlmoeMLP]
        The MLPs to be merged. They must share ``d_model`` and ``d_ff``.
    forwarded_hidden_states: torch.Tensor, of shape (num_tokens, d_model)
        The hidden states that are forwarded to all MLPs in the mlp_list.
    mini_batch_size: int, default None
        The number of tokens per forward pass. If None, all tokens are forwarded at once.
    alpha_for_repeated_merging: float, default 0.1
        Used to update the correlation matrix after each merge. The correlation between the merged neuron and
        every other neuron is set to ``alpha`` times the element-wise minimum of the two merged neurons'
        correlations with it.
    average_coefs: Optional[List[float]], default None
        Weights of the neurons in the weighted averages: either one value per MLP (shared by its ``d_ff``
        neurons) or one per neuron (``num_mlp * d_ff`` values). If None, all weights are 1.0. A merged neuron
        carries the sum of its parts' weights. The caller's list is never modified.

    Returns
    -------
    merged_mlp: OlmoeMLP
        A deep copy of ``mlp_list[0]`` holding the merged weights. With fewer than two tokens, an unmodified
        deep copy of ``mlp_list[0]`` is returned.

    Raises
    ------
    ValueError
        If ``average_coefs`` has neither ``num_mlp`` nor ``num_mlp * d_ff`` entries.

    Examples
    --------
    >>> from transformers import OlmoeConfig
    >>> from transformers.models.olmoe.modeling_olmoe import OlmoeMLP
    >>> mlp = OlmoeMLP(OlmoeConfig(hidden_size=16, intermediate_size=32))
    >>> mlp_merged = merge_olmoe_mlp_by_activation_matching_within_and_across_models([mlp, deepcopy(mlp)], torch.randn(64, 16))
    >>> dummy_input = torch.randn(4, 16)
    >>> torch.allclose(mlp_merged(dummy_input), mlp(dummy_input))
    True
    """
    mlp_list = [mlp.eval() for mlp in mlp_list]
    num_mlp = len(mlp_list)
    concat_mlp = deepcopy(mlp_list[0])
    d_ff, d_model = concat_mlp.up_proj.out_features, concat_mlp.up_proj.in_features
    if average_coefs is None:
        average_coefs = [1.0] * (num_mlp * d_ff)
    elif len(average_coefs) == num_mlp:
        average_coefs = [coef for coef in average_coefs for _ in range(d_ff)]
    elif len(average_coefs) == num_mlp * d_ff:
        # Copy: the coefficients are updated in place during merging and must not leak back to the caller.
        average_coefs = list(average_coefs)
    else:
        raise ValueError(
            f"The length of average_coefs should be either {len(mlp_list)} or {len(mlp_list) * d_ff}, "
            f"but got {len(average_coefs)}."
        )
    if len(forwarded_hidden_states) < 2:
        return concat_mlp
    if mini_batch_size is None:
        mini_batch_size = forwarded_hidden_states.shape[0]

    # Build one wide MLP whose intermediate neurons are the concatenation of all MLPs' neurons.
    mlp_all_up_proj = torch.cat([mlp.up_proj.weight.data for mlp in mlp_list], dim=0)  # (d_ff * num_mlp, d_model)
    mlp_all_down_proj = torch.cat([mlp.down_proj.weight.data for mlp in mlp_list], dim=1)  # (d_model, d_ff * num_mlp)
    mlp_all_gate_proj = torch.cat([mlp.gate_proj.weight.data for mlp in mlp_list], dim=0)  # (d_ff * num_mlp, d_model)
    concat_mlp.up_proj = torch.nn.Linear(d_model, d_ff * num_mlp, bias=False)
    concat_mlp.gate_proj = torch.nn.Linear(d_model, d_ff * num_mlp, bias=False)
    concat_mlp.down_proj = torch.nn.Linear(d_ff * num_mlp, d_model, bias=False)
    with torch.no_grad():
        concat_mlp.up_proj.weight.data = mlp_all_up_proj
        concat_mlp.gate_proj.weight.data = mlp_all_gate_proj
        concat_mlp.down_proj.weight.data = mlp_all_down_proj
    concat_mlp = concat_mlp.eval().to(forwarded_hidden_states.device)

    activations = []

    def _activation_hook(module, inputs, output):
        # The input of `down_proj` is the intermediate activation of every pooled neuron.
        activations.append(inputs[0].detach().reshape(-1, inputs[0].shape[-1]))

    handle = concat_mlp.down_proj.register_forward_hook(_activation_hook)
    try:
        with torch.no_grad():
            for start in range(0, forwarded_hidden_states.shape[0], mini_batch_size):
                concat_mlp(forwarded_hidden_states[start:start + mini_batch_size])
    finally:
        handle.remove()
    del concat_mlp

    # Correlation statistics in float32 to avoid half-precision error in the covariance reduction.
    activations = torch.cat(activations, dim=0).float()  # (num_tokens, d_ff * num_mlp)

    # Initialize the correlation matrix
    mean = activations.mean(dim=0, keepdim=True)  # (1, d_ff * num_mlp)
    std = activations.std(dim=0, keepdim=True)  # (1, d_ff * num_mlp)
    covar = torch.mm(
        (activations - mean).t(),
        (activations - mean)
    ) / (activations.shape[0] - 1)  # (d_ff * num_mlp, d_ff * num_mlp)
    corr_matrix = covar / (std.t() * std + FP32_EPS)  # (d_ff * num_mlp, d_ff * num_mlp)

    del activations, covar, std, mean
    torch.cuda.empty_cache()

    corr_matrix.fill_diagonal_(-1)  # Remove self-correlation

    def _drop_index(tensor, index, dim):
        """Return ``tensor`` without the slice at position ``index`` along ``dim``."""
        return torch.cat([
            tensor.narrow(dim, 0, index),
            tensor.narrow(dim, index + 1, tensor.shape[dim] - index - 1),
        ], dim=dim)

    # Greedy Merging!
    while mlp_all_up_proj.shape[0] > d_ff:
        # Select the most correlated pair (converted to Python ints once, instead of indexing lists with tensors).
        max_i, max_j = divmod(int(torch.argmax(corr_matrix)), corr_matrix.shape[0])

        # Merge the most correlated pair, replace the first feature with the merged one
        i_coef, j_coef = average_coefs[max_i], average_coefs[max_j]
        denominator = i_coef + j_coef + FP32_EPS
        mlp_all_up_proj[max_i] = (i_coef * mlp_all_up_proj[max_i] + j_coef * mlp_all_up_proj[max_j]) / denominator
        mlp_all_gate_proj[max_i] = (i_coef * mlp_all_gate_proj[max_i] + j_coef * mlp_all_gate_proj[max_j]) / denominator
        mlp_all_down_proj[:, max_i] = (i_coef * mlp_all_down_proj[:, max_i] + j_coef * mlp_all_down_proj[:, max_j]) / denominator

        # Remove the second feature
        mlp_all_up_proj = _drop_index(mlp_all_up_proj, max_j, dim=0)
        mlp_all_gate_proj = _drop_index(mlp_all_gate_proj, max_j, dim=0)
        mlp_all_down_proj = _drop_index(mlp_all_down_proj, max_j, dim=1)

        # Update the correlation matrix (kept symmetric) for the merged feature
        updated_corr_vec = alpha_for_repeated_merging * torch.min(corr_matrix[max_i], corr_matrix[max_j])
        corr_matrix[max_i] = updated_corr_vec
        corr_matrix[:, max_i] = updated_corr_vec
        corr_matrix[max_i, max_i] = -1  # Remove self-correlation

        # Remove the second feature from the correlation matrix
        corr_matrix = _drop_index(_drop_index(corr_matrix, max_j, dim=1), max_j, dim=0)

        # Update the average coefs (our own copy, see above)
        average_coefs[max_i] += average_coefs[max_j]
        del average_coefs[max_j]

    merged_mlp = deepcopy(mlp_list[0])
    with torch.no_grad():
        merged_mlp.up_proj.weight.data = mlp_all_up_proj
        merged_mlp.down_proj.weight.data = mlp_all_down_proj
        merged_mlp.gate_proj.weight.data = mlp_all_gate_proj

    return merged_mlp


