from copy import deepcopy
from typing import Optional, List

import torch
from scipy.optimize import linear_sum_assignment
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers.models.switch_transformers.modeling_switch_transformers import (
    SwitchTransformersDenseActDense,
    SwitchTransformersForConditionalGeneration
)

from model.fsgpt_moe import (
    FSGPTMoEDenseActDense,
    FSGPTMoEForCausalLM,
    FSGPTMoEConfig
)
from utils.constants import FP32_EPS

# Names exported by ``from <this module> import *``: the in-place expert
# permutation helpers, the permutation solvers (weight- and activation-based),
# the whole-model alignment routines, and the expert-merging routines.
__all__ = [
    "permute_switch_mlp_dense_expert_",
    "compute_switch_permutation_by_weight_matching",
    "align_switch_permutation_for_all_experts_by_weight_matching",
    "align_switch_permutation_for_all_experts_by_activation_matching",
    "align_fsgpt_permutation_for_all_experts_by_weight_matching",
    "compute_switch_permutation_by_activation_matching",
    "merge_switch_mlp_by_activation_matching_within_and_across_models",
    "merge_switch_mlp_by_weight_matching_within_and_across_models",
    "permute_fsgpt_ffn_dense_expert_",
    "compute_fsgpt_permutation_by_weight_matching",
]


def permute_switch_mlp_dense_expert_(
        dense_mlp: SwitchTransformersDenseActDense,
        perm: torch.Tensor,
) -> SwitchTransformersDenseActDense:
    """Reorder the hidden units of a dense expert in place.

    After the call, row ``i`` of ``wi.weight`` and column ``i`` of
    ``wo.weight`` hold what was previously stored at index ``perm[i]``. Both
    projections are indexed with the same ``perm``, so every hidden unit keeps
    its own input row and output column.

    Args:
        dense_mlp: Expert whose ``wi`` / ``wo`` weights are permuted in place.
        perm: 1-D integer tensor of length ``dense_mlp.wi.out_features`` that
            contains every index in ``[0, d_ff)`` exactly once. Any integer
            dtype is accepted.

    Returns:
        The same ``dense_mlp`` object that was passed in.

    Raises:
        ValueError: If ``perm`` has the wrong shape, a non-integer dtype, or is
            not a permutation of ``range(d_ff)``.
    """
    d_ff = dense_mlp.wi.out_features

    if perm.shape != (d_ff,):
        raise ValueError(f"The shape of the permutation vector should be (d_ff, ), but got {perm.shape}.")
    # Only integer tensors can describe a permutation; reject everything else
    # up front instead of failing later inside the indexing call.
    if perm.is_floating_point() or perm.is_complex() or perm.dtype == torch.bool:
        raise ValueError("The permutation vector should be a permutation.")

    # Normalise to int64 so the exact comparison below never trips over a
    # dtype mismatch (e.g. int32 input) and indexing behaves uniformly.
    index = perm.to(torch.long)
    if not torch.equal(index.sort().values, torch.arange(d_ff, device=index.device)):
        raise ValueError("The permutation vector should be a permutation.")

    with torch.no_grad():
        dense_mlp.wi.weight.data = dense_mlp.wi.weight.data[index, :]
        dense_mlp.wo.weight.data = dense_mlp.wo.weight.data[:, index]

    return dense_mlp


def compute_switch_permutation_by_weight_matching(
        reference_mlp: SwitchTransformersDenseActDense,
        target_mlp: SwitchTransformersDenseActDense,
        include_wo: bool,
) -> torch.Tensor:
    """Find the hidden-unit permutation that best matches ``target_mlp`` to ``reference_mlp``.

    A square similarity matrix is built whose entry ``(i, j)`` is the inner
    product of row ``i`` of the reference ``wi`` with row ``j`` of the target
    ``wi``; when ``include_wo`` is true the inner product of the matching
    ``wo`` columns is added. The assignment that maximises the total
    similarity is then solved with ``linear_sum_assignment``.

    Args:
        reference_mlp: Expert that defines the desired unit ordering.
        target_mlp: Expert whose units are to be re-ordered.
        include_wo: Whether the output projection contributes to the similarity.

    Returns:
        1-D tensor ``perm`` where ``perm[i]`` is the target unit assigned to
        reference unit ``i``; it lives on the device of the similarity matrix.
    """
    with torch.no_grad():
        # Upcast to fp32: half / bfloat16 weights cannot be converted to NumPy
        # for the assignment solver, and fp32 keeps the similarity sums accurate.
        similarity = torch.mm(
            reference_mlp.wi.weight.data.float(), target_mlp.wi.weight.data.float().t()
        )
        if include_wo:
            similarity += torch.mm(
                reference_mlp.wo.weight.data.float().t(), target_mlp.wo.weight.data.float()
            )
    # The first return value (row indices) is discarded; only the column
    # assignment is needed.
    _, perm = linear_sum_assignment(similarity.cpu().numpy(), maximize=True)
    return torch.from_numpy(perm).to(similarity.device)


def align_switch_permutation_for_all_experts_by_weight_matching(
        switch_model: SwitchTransformersForConditionalGeneration,
        include_wo: bool,
) -> SwitchTransformersForConditionalGeneration:
    """Align every expert of every sparse layer to that layer's ``expert_0`` by weight matching.

    For each selected layer index, and for both the encoder block and the
    decoder block at that index, every expert other than ``expert_0`` is
    permuted in place with the permutation returned by
    ``compute_switch_permutation_by_weight_matching`` against ``expert_0``.

    Args:
        switch_model: Model whose experts are permuted in place.
        include_wo: Forwarded to the permutation solver.

    Returns:
        The same ``switch_model`` object.
    """
    cfg = switch_model.config
    expert_count = cfg.num_experts
    # NOTE(review): the same indices (derived from the encoder settings) are
    # used for the decoder as well -- confirm encoder and decoder share depth
    # and sparse step for the checkpoints this is used with.
    layer_ids = list(range(1, cfg.num_layers, cfg.encoder_sparse_step))

    for layer_id in tqdm(layer_ids,
                         desc=f"[Permutation]Aligning permutation {'with' if include_wo else 'without'} Wo"):
        # Encoder first, then decoder, for each expert.
        moe_pair = (
            switch_model.encoder.block[layer_id].layer[-1].mlp,
            switch_model.decoder.block[layer_id].layer[-1].mlp,
        )
        for expert_id in range(1, expert_count):
            key = f"expert_{expert_id}"
            for moe in moe_pair:
                mapping = compute_switch_permutation_by_weight_matching(
                    moe.experts["expert_0"],
                    moe.experts[key],
                    include_wo=include_wo,
                )
                moe.experts[key] = permute_switch_mlp_dense_expert_(moe.experts[key], mapping)

    return switch_model


def permute_fsgpt_ffn_dense_expert_(
        dense_ffn: FSGPTMoEDenseActDense,
        perm: torch.Tensor,
) -> FSGPTMoEDenseActDense:
    """Reorder the hidden units of an FSGPT dense expert in place.

    After the call, row ``i`` of ``fc1.weight``, element ``i`` of ``fc1.bias``
    (when a bias exists) and column ``i`` of ``fc2.weight`` hold what was
    previously stored at index ``perm[i]``.

    Args:
        dense_ffn: Expert whose ``fc1`` / ``fc2`` parameters are permuted in place.
        perm: 1-D integer tensor of length ``dense_ffn.fc1.out_features`` that
            contains every index in ``[0, hidden_size)`` exactly once. Any
            integer dtype is accepted.

    Returns:
        The same ``dense_ffn`` object that was passed in.

    Raises:
        ValueError: If ``perm`` has the wrong shape, a non-integer dtype, or is
            not a permutation.
    """
    hidden_size = dense_ffn.fc1.out_features

    if perm.shape != (hidden_size,):
        raise ValueError(f"The shape of the permutation vector should be (d_ff, ), but got {perm.shape}.")
    # Only integer tensors can describe a permutation.
    if perm.is_floating_point() or perm.is_complex() or perm.dtype == torch.bool:
        raise ValueError("The permutation vector should be a permutation.")

    # int64 copy: makes the exact comparison dtype-safe and indexing uniform.
    index = perm.to(torch.long)
    if not torch.equal(index.sort().values, torch.arange(hidden_size, device=index.device)):
        raise ValueError("The permutation vector should be a permutation.")

    with torch.no_grad():
        dense_ffn.fc1.weight.data = dense_ffn.fc1.weight.data[index, :]
        # A bias-free first projection has nothing to permute here.
        if dense_ffn.fc1.bias is not None:
            dense_ffn.fc1.bias.data = dense_ffn.fc1.bias.data[index]
        dense_ffn.fc2.weight.data = dense_ffn.fc2.weight.data[:, index]

    return dense_ffn


def compute_fsgpt_permutation_by_weight_matching(
        reference_ffn: FSGPTMoEDenseActDense,
        target_ffn: FSGPTMoEDenseActDense,
        include_wo: Optional[bool] = True,
) -> torch.Tensor:
    """Solve for the hidden-unit permutation that best matches ``target_ffn`` to ``reference_ffn``.

    The score between reference unit ``i`` and target unit ``j`` is the inner
    product of their ``fc1`` rows, plus the inner product of their ``fc2``
    columns when ``include_wo`` is truthy. All products are taken in fp32 and
    the total score is maximised with ``linear_sum_assignment``.

    Args:
        reference_ffn: Expert that defines the desired unit ordering.
        target_ffn: Expert whose units are to be re-ordered.
        include_wo: Whether ``fc2`` contributes to the score.

    Returns:
        1-D tensor whose ``i``-th entry is the target unit assigned to
        reference unit ``i``, placed on the device of the score matrix.
    """
    with torch.no_grad():
        ref_in = reference_ffn.fc1.weight.data.float()
        tgt_in_t = target_ffn.fc1.weight.data.t().float()
        score = torch.mm(ref_in, tgt_in_t)
        if include_wo:
            ref_out_t = reference_ffn.fc2.weight.data.t().float()
            tgt_out = target_ffn.fc2.weight.data.float()
            score = score + torch.mm(ref_out_t, tgt_out)

    # Only the column assignment (second element) is of interest.
    assignment = linear_sum_assignment(score.cpu().numpy(), maximize=True)
    columns = assignment[1]
    return torch.from_numpy(columns).to(score.device)


def align_fsgpt_permutation_for_all_experts_by_weight_matching(
        fsgpt_model: FSGPTMoEForCausalLM,
        include_wo: bool = True,
) -> FSGPTMoEForCausalLM:
    """Align every expert of every sparse FSGPT layer to that layer's ``expert_0``.

    For each selected decoder layer, every expert other than ``expert_0`` is
    permuted in place with the permutation returned by
    ``compute_fsgpt_permutation_by_weight_matching`` against ``expert_0``.

    Args:
        fsgpt_model: Model whose experts are permuted in place.
        include_wo: Forwarded to the permutation solver.

    Returns:
        The same ``fsgpt_model`` object.
    """
    config = fsgpt_model.config
    sparse_layer_indices = list(range(1, config.num_layers, config.sparse_step))
    num_experts = config.num_experts

    # The bar is advanced manually, so give it an explicit total and let the
    # context manager close it even if alignment raises part-way through.
    with tqdm(
            total=len(sparse_layer_indices) * max(num_experts - 1, 0),
            desc=f"[Permutation]Aligning permutation {'with' if include_wo else 'without'} Wo",
    ) as progress_bar:
        for layer_idx in sparse_layer_indices:
            ffn = fsgpt_model.decoder.layers[layer_idx].ffn
            for expert_idx in range(1, num_experts):
                expert_key = f"expert_{expert_idx}"
                perm = compute_fsgpt_permutation_by_weight_matching(
                    ffn.experts["expert_0"],
                    ffn.experts[expert_key],
                    include_wo=include_wo,
                )
                ffn.experts[expert_key] = permute_fsgpt_ffn_dense_expert_(
                    ffn.experts[expert_key], perm
                )
                progress_bar.update(1)

    return fsgpt_model


def compute_switch_permutation_by_activation_matching(
        reference_mlp: SwitchTransformersDenseActDense,
        target_mlp: SwitchTransformersDenseActDense,
        forwarded_hidden_states: torch.Tensor,
        mini_batch_size: Optional[int] = None,
) -> torch.Tensor:
    """Find the hidden-unit permutation of ``target_mlp`` that best matches ``reference_mlp`` by activations.

    Both experts are run on ``forwarded_hidden_states`` while forward hooks on
    their ``wo`` layers record the tensors fed into the output projection.
    The Pearson-style correlation between every reference unit and every
    target unit is computed, and the assignment with the highest total
    correlation is solved with ``linear_sum_assignment``.

    Args:
        reference_mlp: Expert that defines the desired unit ordering.
        target_mlp: Expert whose units are to be re-ordered.
        forwarded_hidden_states: Inputs fed to both experts; the last dimension
            must equal ``reference_mlp.wi.in_features``.
        mini_batch_size: Number of leading-dimension slices per forward pass.
            ``None`` runs everything in one pass.

    Returns:
        1-D tensor ``perm`` where ``perm[i]`` is the target unit assigned to
        reference unit ``i``. With fewer than two samples along the first
        dimension, the identity permutation is returned instead.

    Raises:
        ValueError: If the last dimension of ``forwarded_hidden_states`` does
            not match the experts' input width.
    """
    # A correlation needs at least two samples; fall back to "no permutation".
    if len(forwarded_hidden_states) == 0 or len(forwarded_hidden_states) == 1:
        return torch.arange(reference_mlp.wi.out_features, device=forwarded_hidden_states.device)

    if forwarded_hidden_states.shape[-1] != reference_mlp.wi.in_features:
        raise ValueError(
            f"The last dimension of forwarded_hidden_states should be {reference_mlp.wi.in_features}, "
            f"but got {forwarded_hidden_states.shape[-1]}."
        )
    if mini_batch_size is None:
        mini_batch_size = forwarded_hidden_states.shape[0]
    reference_activations = []
    target_activations = []

    def _ref_activation_hook(module, input, output):
        # Record the input of ``wo`` flattened to 2-D (last dimension kept).
        reference_activations.append(input[0].detach().reshape(-1, input[0].shape[-1]))

    def _target_activation_hook(module, input, output):
        target_activations.append(input[0].detach().reshape(-1, input[0].shape[-1]))

    reference_handle = reference_mlp.wo.register_forward_hook(_ref_activation_hook)
    target_handle = target_mlp.wo.register_forward_hook(_target_activation_hook)
    try:
        with torch.no_grad():
            for i in range(0, forwarded_hidden_states.shape[0], mini_batch_size):
                chunk = forwarded_hidden_states[i:i + mini_batch_size]
                reference_mlp(chunk)
                target_mlp(chunk)
    finally:
        # Always detach the hooks so a failed forward pass (e.g. out of
        # memory) does not leave them attached to the experts.
        reference_handle.remove()
        target_handle.remove()

    # Statistics are taken in fp32: reduced-precision dtypes lose accuracy in
    # the variance computation and bfloat16 cannot be converted to NumPy.
    reference_activations = torch.cat(reference_activations, dim=0).float()
    target_activations = torch.cat(target_activations, dim=0).float()

    mean_ref = reference_activations.mean(dim=0, keepdim=True)
    mean_target = target_activations.mean(dim=0, keepdim=True)
    std_ref = reference_activations.std(dim=0, keepdim=True)
    std_target = target_activations.std(dim=0, keepdim=True)
    covar = torch.mm(
        (reference_activations - mean_ref).t(),
        (target_activations - mean_target)
    ) / (reference_activations.shape[0] - 1)
    # FP32_EPS keeps the division finite for units with zero variance.
    cost_matrix = covar / (std_ref.t() * std_target + FP32_EPS)

    _, perm = linear_sum_assignment(cost_matrix.cpu().numpy(), maximize=True)
    return torch.from_numpy(perm).to(cost_matrix.device)


def align_switch_permutation_for_all_experts_by_activation_matching(
        switch_model: SwitchTransformersForConditionalGeneration,
        dataloader: DataLoader,
) -> SwitchTransformersForConditionalGeneration:
    """Align every expert of every sparse layer to that layer's ``expert_0`` by activation matching.

    The model is run over ``dataloader`` while forward hooks record the inputs
    of each selected encoder / decoder MLP together with the per-token expert
    index read from the model output. For every expert other than
    ``expert_0``, the tokens routed to ``expert_0`` and to that expert are
    concatenated and passed to
    ``compute_switch_permutation_by_activation_matching``; the expert is then
    permuted in place.

    Args:
        switch_model: Model whose experts are permuted in place. It is put
            into eval mode.
        dataloader: Yields dict batches of tensors that are passed to the
            model as keyword arguments.

    Returns:
        The same ``switch_model`` object.
    """
    switch_model.eval()
    config = switch_model.config
    # NOTE(review): the same indices (derived from the encoder settings) are
    # used for the decoder as well -- confirm encoder and decoder share depth
    # and sparse step for the checkpoints this is used with.
    sparse_layer_indices = list(range(1, config.num_layers, config.encoder_sparse_step))
    # Feed batches to wherever the model lives instead of assuming the
    # default CUDA device.
    device = next(switch_model.parameters()).device

    # MLP name -> list with one flattened (tokens, features) tensor per forward call.
    forwarded_hidden_states = dict()
    handles = []

    def _get_activation_hook(name):
        def hook(module, input, output):
            forwarded_hidden_states[name].append(input[0].detach().reshape(-1, input[0].shape[-1]))

        return hook

    def _tokens_routed_to(mlp_name, expert_idx):
        # Gather, across all recorded batches, the hidden states whose router
        # index equals ``expert_idx``.
        return torch.cat([
            states[router_indices[mlp_name][i] == expert_idx]
            for i, states in enumerate(forwarded_hidden_states[mlp_name])
        ])

    try:
        for layer_idx in tqdm(sparse_layer_indices, desc="[Permutation]Registering forward hook..."):
            encoder_mlp_name = f"encoder.block.{layer_idx}.layer.1.mlp"
            decoder_mlp_name = f"decoder.block.{layer_idx}.layer.2.mlp"
            forwarded_hidden_states[encoder_mlp_name] = []
            forwarded_hidden_states[decoder_mlp_name] = []
            handles.append(switch_model.encoder.block[layer_idx].layer[-1].mlp.register_forward_hook(
                _get_activation_hook(encoder_mlp_name))
            )
            handles.append(switch_model.decoder.block[layer_idx].layer[-1].mlp.register_forward_hook(
                _get_activation_hook(decoder_mlp_name))
            )

        # MLP name -> list with one flattened expert-index tensor per batch.
        router_indices = {name: [] for name in forwarded_hidden_states}
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="[Permutation]Computing activations..."):
                batch = {k: v.to(device) for k, v in batch.items()}
                output = switch_model(**batch)
                for layer_idx in sparse_layer_indices:
                    # Element [1] of each per-layer entry is used as the
                    # expert index of every token -- presumably requires the
                    # model to return router logits; verify against caller.
                    router_indices[f"encoder.block.{layer_idx}.layer.1.mlp"].append(
                        output.encoder_router_logits[layer_idx][1].reshape(-1)
                    )
                    router_indices[f"decoder.block.{layer_idx}.layer.2.mlp"].append(
                        output.decoder_router_logits[layer_idx][1].reshape(-1)
                    )
    finally:
        # The hooks are only needed while collecting activations; remove them
        # even if a forward pass fails so the model is left clean.
        for handle in handles:
            handle.remove()

    num_experts = config.num_experts
    with tqdm(
            total=len(sparse_layer_indices) * max(num_experts - 1, 0),
            desc="[Permutation]Aligning permutation by activation matching",
    ) as progress_bar:
        for layer_idx in sparse_layer_indices:
            # (sparse MLP module, recording name, tokens routed to expert_0)
            stacks = []
            for mlp, mlp_name in (
                    (switch_model.encoder.block[layer_idx].layer[-1].mlp,
                     f"encoder.block.{layer_idx}.layer.1.mlp"),
                    (switch_model.decoder.block[layer_idx].layer[-1].mlp,
                     f"decoder.block.{layer_idx}.layer.2.mlp"),
            ):
                stacks.append((mlp, mlp_name, _tokens_routed_to(mlp_name, 0)))

            for expert_idx in range(1, num_experts):
                expert_key = f"expert_{expert_idx}"
                # Encoder first, then decoder, for each expert.
                for mlp, mlp_name, expert0_hidden_states in stacks:
                    expert_hidden_states = _tokens_routed_to(mlp_name, expert_idx)
                    perm = compute_switch_permutation_by_activation_matching(
                        reference_mlp=mlp.experts["expert_0"],
                        target_mlp=mlp.experts[expert_key],
                        forwarded_hidden_states=torch.cat(
                            [expert0_hidden_states, expert_hidden_states], dim=0
                        ),
                    )
                    mlp.experts[expert_key] = permute_switch_mlp_dense_expert_(
                        mlp.experts[expert_key], perm
                    )

                progress_bar.update(1)

    return switch_model


def merge_switch_mlp_by_activation_matching_within_and_across_models(
        mlp_list: List[SwitchTransformersDenseActDense],
        forwarded_hidden_states: torch.Tensor,
        mini_batch_size: Optional[int] = None,
        alpha_for_repeated_merging: Optional[float] = 0.1,
        average_coefs: Optional[List[float]] = None,
) -> SwitchTransformersDenseActDense:
    """Merge several dense experts into one by greedily fusing their most correlated hidden units.

    The ``wi`` rows / ``wo`` columns of all experts are concatenated into one
    wide MLP, which is run on ``forwarded_hidden_states`` to record the inputs
    of its ``wo`` layer. The unit-by-unit correlation matrix of those
    activations is computed, and the most correlated pair of units is
    repeatedly replaced by its coefficient-weighted average until only
    ``d_ff`` units remain.

    Args:
        mlp_list: Experts to merge; all are put into eval mode. The first one
            provides the template for the returned module.
        forwarded_hidden_states: Inputs used to collect activations.
        mini_batch_size: Number of leading-dimension slices per forward pass.
            ``None`` runs everything in one pass.
        alpha_for_repeated_merging: After a merge, the surviving unit's
            correlation with every other unit is set to this factor times the
            element-wise minimum of the two merged units' correlations.
        average_coefs: Averaging weights, either one per expert or one per
            concatenated hidden unit. ``None`` weights every unit with 1.0.
            The caller's list is not modified.

    Returns:
        A new module (deep copy of ``mlp_list[0]``) holding the merged
        weights. With fewer than two samples along the first dimension of
        ``forwarded_hidden_states`` an unmodified deep copy of ``mlp_list[0]``
        is returned.

    Raises:
        ValueError: If ``average_coefs`` has an unsupported length.
    """
    mlp_list = [mlp.eval() for mlp in mlp_list]
    concat_mlp = deepcopy(mlp_list[0])
    d_ff, d_model = concat_mlp.wi.out_features, concat_mlp.wi.in_features
    if average_coefs is None:
        average_coefs = [1.0] * len(mlp_list) * d_ff
    elif len(average_coefs) == len(mlp_list):
        average_coefs = [coef for coef in average_coefs for _ in range(d_ff)]
    elif len(average_coefs) != len(mlp_list) * d_ff:
        raise ValueError(
            f"The length of average_coefs should be either {len(mlp_list)} or {len(mlp_list) * d_ff}, "
            f"but got {len(average_coefs)}."
        )
    else:
        # Work on a private copy: the merge loop updates coefficients in place
        # and must not alter the list owned by the caller.
        average_coefs = list(average_coefs)
    num_mlp = len(mlp_list)
    # A correlation needs at least two samples.
    if len(forwarded_hidden_states) == 0 or len(forwarded_hidden_states) == 1:
        return concat_mlp
    if mini_batch_size is None:
        mini_batch_size = forwarded_hidden_states.shape[0]

    # Stack all experts' hidden units side by side into one wide MLP.
    mlp_all_wi = torch.cat([
        mlp.wi.weight.data for mlp in mlp_list
    ], dim=0)
    mlp_all_wo = torch.cat([
        mlp.wo.weight.data for mlp in mlp_list
    ], dim=1)
    concat_mlp.wi = torch.nn.Linear(d_model, d_ff * num_mlp, bias=False)
    concat_mlp.wo = torch.nn.Linear(d_ff * num_mlp, d_model, bias=False)
    with torch.no_grad():
        concat_mlp.wi.weight.data = mlp_all_wi
        concat_mlp.wo.weight.data = mlp_all_wo
    concat_mlp = concat_mlp.eval().to(forwarded_hidden_states.device)

    activations = []

    def _activation_hook(module, input, output):
        # Record the input of ``wo`` flattened to 2-D (last dimension kept).
        activations.append(input[0].detach().reshape(-1, input[0].shape[-1]))

    handle = concat_mlp.wo.register_forward_hook(_activation_hook)

    print(f"Compute activations with data shape {forwarded_hidden_states.shape} and batch size {mini_batch_size}")

    try:
        with torch.no_grad():
            for i in range(0, forwarded_hidden_states.shape[0], mini_batch_size):
                concat_mlp(forwarded_hidden_states[i:i + mini_batch_size])
    finally:
        # Detach the hook even if a forward pass fails.
        handle.remove()
    activations = torch.cat(activations, dim=0)

    # Correlation between every pair of concatenated hidden units.
    mean = activations.mean(dim=0, keepdim=True)
    std = activations.std(dim=0, keepdim=True)
    covar = torch.mm(
        (activations - mean).t(),
        (activations - mean)
    ) / (activations.shape[0] - 1)
    corr_matrix = covar / (std.t() * std + FP32_EPS)

    del activations, covar, std, mean
    torch.cuda.empty_cache()

    # A unit must never be selected for merging with itself.
    corr_matrix[torch.arange(d_ff * num_mlp), torch.arange(d_ff * num_mlp)] = -1

    # Greedy merging: every iteration removes exactly one unit (``max_j``).
    while mlp_all_wi.shape[0] > d_ff:
        # Locate the most correlated pair; plain ints index tensors and the
        # Python coefficient list alike.
        max_i, max_j = divmod(int(torch.argmax(corr_matrix)), corr_matrix.shape[0])

        # Fold unit ``max_j`` into unit ``max_i`` as a weighted average.
        i_coef, j_coef = average_coefs[max_i], average_coefs[max_j]
        mlp_all_wi[max_i] = (i_coef * mlp_all_wi[max_i] + j_coef * mlp_all_wi[max_j]) / (i_coef + j_coef + FP32_EPS)
        mlp_all_wo[:, max_i] = (i_coef * mlp_all_wo[:, max_i] + j_coef * mlp_all_wo[:, max_j]) / (
                i_coef + j_coef + FP32_EPS)

        # Drop unit ``max_j`` from both projections.
        mlp_all_wi = torch.cat([
            mlp_all_wi[:max_j],
            mlp_all_wi[max_j + 1:]
        ], dim=0)
        mlp_all_wo = torch.cat([
            mlp_all_wo[:, :max_j],
            mlp_all_wo[:, max_j + 1:]
        ], dim=1)

        # Refresh the surviving unit's correlations before the row / column of
        # ``max_j`` is deleted, so all indices still refer to the old layout.
        updated_corr_vec = alpha_for_repeated_merging * torch.min(
            torch.stack([corr_matrix[max_i], corr_matrix[max_j]]), dim=0
        ).values
        corr_matrix[max_i] = updated_corr_vec
        corr_matrix[:, max_i] = updated_corr_vec
        corr_matrix[max_i, max_i] = -1

        # Drop column and row ``max_j`` from the correlation matrix.
        corr_matrix = torch.cat([
            corr_matrix[:, :max_j],
            corr_matrix[:, max_j + 1:]
        ], dim=1)
        corr_matrix = torch.cat([
            corr_matrix[:max_j],
            corr_matrix[max_j + 1:]
        ], dim=0)

        # The surviving unit now carries the weight of both merged units.
        average_coefs[max_i] += average_coefs[max_j]
        average_coefs = average_coefs[:max_j] + average_coefs[max_j + 1:]

    merged_mlp = deepcopy(mlp_list[0])
    with torch.no_grad():
        merged_mlp.wi.weight.data = mlp_all_wi
        merged_mlp.wo.weight.data = mlp_all_wo

    return merged_mlp


def merge_switch_mlp_by_weight_matching_within_and_across_models(
        mlp_list: List[SwitchTransformersDenseActDense],
        include_wo: Optional[bool] = True,
        average_coefs: Optional[List[float]] = None,
) -> SwitchTransformersDenseActDense:
    """Merge several dense experts into one by greedily fusing their most similar hidden units.

    The hidden units of all experts are concatenated. The similarity of two
    units is the inner product of their ``wi`` rows, plus the inner product of
    their ``wo`` columns when ``include_wo`` is truthy. The most similar pair
    is repeatedly replaced by its coefficient-weighted average until only
    ``d_ff`` units remain.

    Args:
        mlp_list: Experts to merge. The first one provides the template for
            the returned module.
        include_wo: Whether ``wo`` contributes to the similarity, both for the
            initial matrix and for every refresh after a merge.
        average_coefs: Averaging weights, either one per expert or one per
            concatenated hidden unit. ``None`` weights every unit with 1.0.
            The caller's list is not modified.

    Returns:
        A new module (deep copy of ``mlp_list[0]``) holding the merged weights.

    Raises:
        ValueError: If ``average_coefs`` has an unsupported length.
    """
    d_ff = mlp_list[0].wi.out_features
    if average_coefs is None:
        average_coefs = [1.0] * len(mlp_list) * d_ff
    elif len(average_coefs) == len(mlp_list):
        average_coefs = [coef for coef in average_coefs for _ in range(d_ff)]
    elif len(average_coefs) != len(mlp_list) * d_ff:
        raise ValueError(
            f"The length of average_coefs should be either {len(mlp_list)} or {len(mlp_list) * d_ff}, "
            f"but got {len(average_coefs)}."
        )
    else:
        # Work on a private copy: the merge loop updates coefficients in place
        # and must not alter the list owned by the caller.
        average_coefs = list(average_coefs)

    # One row per hidden unit in both matrices (``wo`` is stored transposed).
    concat_wi = torch.cat([
        mlp.wi.weight.data for mlp in mlp_list
    ], dim=0)
    concat_wo = torch.cat([
        mlp.wo.weight.data.t() for mlp in mlp_list
    ], dim=0)
    if include_wo:
        cost_matrix = torch.mm(concat_wi, concat_wi.t()) + torch.mm(concat_wo, concat_wo.t())
    else:
        cost_matrix = torch.mm(concat_wi, concat_wi.t())

    # A unit must never be selected for merging with itself.
    dtype_min = torch.finfo(cost_matrix.dtype).min
    cost_matrix[torch.arange(d_ff * len(mlp_list)), torch.arange(d_ff * len(mlp_list))] = dtype_min

    # Greedy merging: every iteration removes exactly one unit (``max_j``).
    while concat_wi.shape[0] > d_ff:
        # Locate the most similar pair; plain ints index tensors and the
        # Python coefficient list alike.
        max_i, max_j = divmod(int(torch.argmax(cost_matrix)), cost_matrix.shape[0])

        # Fold unit ``max_j`` into unit ``max_i`` as a weighted average.
        i_coef = average_coefs[max_i]
        j_coef = average_coefs[max_j]
        concat_wi[max_i] = (i_coef * concat_wi[max_i] + j_coef * concat_wi[max_j]) / (i_coef + j_coef + FP32_EPS)
        concat_wo[max_i] = (i_coef * concat_wo[max_i] + j_coef * concat_wo[max_j]) / (i_coef + j_coef + FP32_EPS)

        # Recompute the merged unit's similarity to every unit with the same
        # definition as the initial matrix, so ``include_wo=False`` stays
        # wi-only throughout.
        updated_cost_vec = torch.mm(
            concat_wi[max_i].reshape(1, -1),
            concat_wi.t()
        )
        if include_wo:
            updated_cost_vec = updated_cost_vec + torch.mm(
                concat_wo[max_i].reshape(1, -1),
                concat_wo.t()
            )
        updated_cost_vec = updated_cost_vec.reshape(-1)
        cost_matrix[max_i] = updated_cost_vec
        cost_matrix[:, max_i] = updated_cost_vec
        cost_matrix[max_i, max_i] = dtype_min

        # Drop unit ``max_j`` from both weight matrices.
        concat_wi = torch.cat([
            concat_wi[:max_j],
            concat_wi[max_j + 1:]
        ], dim=0)
        concat_wo = torch.cat([
            concat_wo[:max_j],
            concat_wo[max_j + 1:]
        ], dim=0)

        # Drop column and row ``max_j`` from the similarity matrix.
        cost_matrix = torch.cat([
            cost_matrix[:, :max_j],
            cost_matrix[:, max_j + 1:]
        ], dim=1)
        cost_matrix = torch.cat([
            cost_matrix[:max_j],
            cost_matrix[max_j + 1:]
        ], dim=0)

        # The surviving unit now carries the weight of both merged units.
        average_coefs[max_i] += average_coefs[max_j]
        average_coefs = average_coefs[:max_j] + average_coefs[max_j + 1:]

    merged_mlp = deepcopy(mlp_list[0])
    with torch.no_grad():
        merged_mlp.wi.weight.data = concat_wi
        # Transpose back to the (d_model, d_ff) layout of ``wo``.
        merged_mlp.wo.weight.data = concat_wo.t().contiguous()

    return merged_mlp
