"""
Implementation of the RegMean++ (layer-wise RegMean) method for model merging.
Modified from https://github.com/yule-BUAA/MergeLM/blob/6d49ad96fd69c92013654b837041b868aa806564/model_merging_methods/merging_methods.py
"""

import logging
import re
from collections import defaultdict
from typing import Dict, List, cast

import torch
from torch import Tensor, nn
from tqdm.autonotebook import tqdm

from fusion_bench.method import BaseAlgorithm
from fusion_bench.mixins import SimpleProfilerMixin
from fusion_bench.modelpool import BaseModelPool

log = logging.getLogger(__name__)


def get_param_names_to_merge(
    input_param_names: List[str], exclude_param_names_regex: list
):
    """
    Return the parameter names that are not excluded by any pattern.

    A name is excluded when ``re.match`` (which anchors at the start of the
    name) succeeds for at least one pattern in ``exclude_param_names_regex``.

    :param input_param_names: list, names of the candidate parameters
    :param exclude_param_names_regex: list, regular expressions of parameter names to exclude
    :return: list, the remaining names, in their original order
    """

    def _is_excluded(param_name: str) -> bool:
        match_results = [
            re.match(pattern, param_name) for pattern in exclude_param_names_regex
        ]
        return any(match_results)

    return [name for name in input_param_names if not _is_excluded(name)]


def get_modules_to_merge(model: nn.Module, include_module_types: list):
    """
    Collect the submodules of ``model`` whose type is one of ``include_module_types``.

    :param model: nn.Module, the model whose ``named_modules()`` are scanned
    :param include_module_types: list, accepted module classes; when empty (or
        otherwise falsy) every module is accepted
    :return: dict mapping module name to module, in ``named_modules()`` order
    """
    if not include_module_types:
        # no type filter: keep everything
        return dict(model.named_modules())

    accepted_types = tuple(include_module_types)
    selected: Dict[str, nn.Module] = {
        qualified_name: submodule
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, accepted_types)
    }
    return selected


def reduce_non_diagonal_elements(
    regmean_weights: torch.Tensor, reduce_non_diagonal_ratio: float
):
    """
    Scale the off-diagonal entries of a square regmean weight matrix.

    Diagonal entries are returned unchanged; every off-diagonal entry is
    multiplied by ``reduce_non_diagonal_ratio``.

    :param regmean_weights: Tensor, shape (hidden_dim, hidden_dim), input regmean weights
    :param reduce_non_diagonal_ratio: float, factor applied to the non-diagonal elements
    :return: Tensor of the same shape with the non-diagonal elements scaled
    """
    hidden_dim = regmean_weights.shape[0]
    # Build the scaling mask directly on the tensor's device. This avoids a
    # host allocation plus a host-to-device copy of a (hidden_dim, hidden_dim)
    # matrix on every call. The mask uses torch's default dtype, as
    # torch.ones did, so the dtype promotion of the product is unchanged.
    scale = torch.full(
        (hidden_dim, hidden_dim),
        float(reduce_non_diagonal_ratio),
        dtype=torch.get_default_dtype(),
        device=regmean_weights.device,
    )
    # Set the diagonal to exactly 1.0 so those entries are left untouched.
    # Summing (1 - ratio) + ratio is not guaranteed to round to 1.
    scale.fill_diagonal_(1.0)
    return regmean_weights * scale


def regmean_params_merge(
    param_weight_list: List[Tensor],
    param_regmean_list: List[Tensor],
    reduce_non_diagonal_ratio: float = 1.0,
    weight_transpose: bool = True,
    module_name: str = "",
    device = "cpu"
):
    """
    Merge one weight matrix across several models with RegMean.

    Solves ``(sum_i G_i) W = sum_i G_i W_i`` for ``W``. Here ``G_i`` is the
    regmean weight matrix of model ``i``, after its off-diagonal entries are
    scaled by ``reduce_non_diagonal_ratio``. ``W_i`` is that model's weight,
    transposed first when ``weight_transpose`` is True.

    :param param_weight_list: list, the weight of the same module taken from each model
    :param param_regmean_list: list, the regmean weight matrix of that module for
        each model, in the same order as ``param_weight_list``
    :param reduce_non_diagonal_ratio: float, factor applied to the off-diagonal
        regmean entries (1.0 keeps them unchanged)
    :param weight_transpose: bool, transpose weights stored as
        (output_size, input_size) before merging and transpose the result back
    :param module_name: str, unused; kept for backward compatibility
    :param device: unused; kept for backward compatibility
    :return: Tensor, the merged weight, in the same layout as the input weights
    :raises ValueError: if no regmean matrices are given or the two lists differ in length
    """
    if not param_regmean_list:
        raise ValueError("param_regmean_list must contain at least one regmean matrix")
    if len(param_weight_list) != len(param_regmean_list):
        raise ValueError(
            f"got {len(param_weight_list)} weights but "
            f"{len(param_regmean_list)} regmean matrices"
        )

    # two lists with length num_models_to_merge
    regmean_weights_list, weighted_params_list = [], []
    for regmean_weights, param in zip(param_regmean_list, param_weight_list):
        # reduce non-diagonal elements
        regmean_weights = reduce_non_diagonal_elements(
            regmean_weights=regmean_weights,
            reduce_non_diagonal_ratio=reduce_non_diagonal_ratio,
        )
        regmean_weights_list.append(regmean_weights)

        # The Linear weight is (output_size, input_size). The regmean matrix
        # left-multiplies it, so the weight is transposed to (input_size, output_size).
        param = param.transpose(0, 1) if weight_transpose else param
        weighted_params_list.append(torch.matmul(regmean_weights, param))

    # sum over all individual models
    sum_regmean_weights = sum(regmean_weights_list)
    sum_weighted_params = sum(weighted_params_list)

    # Mathematically equal to inverse(sum_regmean_weights) @ sum_weighted_params,
    # but solving the system is faster and more numerically stable than
    # forming the inverse explicitly.
    merged_param = torch.linalg.solve(sum_regmean_weights, sum_weighted_params)
    # transpose back to the original shape of "weight" in Linear module
    return merged_param.transpose(0, 1) if weight_transpose else merged_param


def merging_with_regmean_weights(
    models_to_merge_param_dict: dict,
    models_to_merge_regmean_weights_list: list,
    reduce_non_diagonal_ratio: float = 1.0,
    weight_transpose: bool = True,
):
    """
    Merge the parameters of several models, using RegMean where regmean weights exist.

    A parameter named ``<module>.weight`` is merged with ``regmean_params_merge``
    when ``<module>`` is a key of the first model's regmean dict. Every other
    parameter is merged by element-wise averaging across models.

    :param models_to_merge_param_dict: dict, maps a parameter name to the list of
        that parameter taken from every model to merge
    :param models_to_merge_regmean_weights_list: list, one dict per model (same
        order as the parameter lists) mapping module name to its regmean weight matrix
    :param reduce_non_diagonal_ratio: float, factor applied to the non-diagonal
        elements of the regmean weights
    :param weight_transpose: bool, whether weights are stored as
        (output_size, input_size) and must be transposed for merging
    :return: dict, maps each parameter name to its merged tensor
    """
    weight_suffix = ".weight"
    merged_params = {}

    for param_name, param_value_list in models_to_merge_param_dict.items():
        module_name = (
            param_name[: -len(weight_suffix)]
            if param_name.endswith(weight_suffix)
            else None
        )
        use_regmean = (
            module_name is not None
            and module_name in models_to_merge_regmean_weights_list[0]
        )

        if not use_regmean:
            # biases, non-Linear weights, and modules without regmean weights are averaged
            merged_params[param_name] = torch.stack(param_value_list, dim=0).mean(dim=0)
            continue

        # move each model's regmean matrix next to that model's parameter
        gram_matrices = []
        device = None
        for model_idx, regmean_weights in enumerate(models_to_merge_regmean_weights_list):
            device = param_value_list[model_idx].device
            gram_matrices.append(regmean_weights[module_name].to(device))

        merged_params[param_name] = regmean_params_merge(
            param_weight_list=param_value_list,
            param_regmean_list=gram_matrices,
            reduce_non_diagonal_ratio=reduce_non_diagonal_ratio,
            weight_transpose=weight_transpose,
            module_name=module_name,
            device=device,
        )

    return merged_params


class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
    """
    RegMean++: layer-by-layer RegMean model merging.

    The embedding layer is merged first. The backbone layers are then merged
    one at a time, front to back. For each layer, every model's regmean
    weights are computed from the inputs that model's data produces at that
    layer of the *partially merged* model. After merging, those inputs are
    forwarded through the merged layer to produce the next layer's inputs.

    Model-specific behaviour is delegated to hooks that subclasses implement:
    ``merge_embedding_layer``, ``get_input_for_first_layer``, ``get_layers``,
    ``update_merged_params_dict``, ``layer_batches_forward`` and
    ``get_regmean_weights`` (``on_regmean_start`` is optional).
    """

    # module types inside each layer for which regmean weights are collected
    _include_module_type = [nn.Linear]
    _config_mapping = {
        "num_regmean_examples": "num_regmean_examples",
        "exclude_param_names_regex": "exclude_param_names_regex",
        "reduce_non_diagonal_ratio": "reduce_non_diagonal_ratio",
        "weight_transpose": "weight_transpose",
    }

    def __init__(
        self,
        *,
        num_regmean_examples: int,
        exclude_param_names_regex: list,
        reduce_non_diagonal_ratio: float,
        weight_transpose: bool,
        **kwargs,
    ):
        """
        :param num_regmean_examples: number of examples for computing regmean weights
            (NOTE(review): only stored here; presumably consumed by subclasses)
        :param exclude_param_names_regex: regexes (``re.match``) of layer parameter
            names that are not merged
        :param reduce_non_diagonal_ratio: factor applied to the off-diagonal
            elements of the regmean weights
        :param weight_transpose: whether weights are stored as
            (output_size, input_size) and must be transposed for merging
        """
        self.num_regmean_examples = num_regmean_examples
        self.exclude_param_names_regex = exclude_param_names_regex
        self.reduce_non_diagonal_ratio = reduce_non_diagonal_ratio
        self.weight_transpose = weight_transpose
        super().__init__(**kwargs)

    def run(self, modelpool: BaseModelPool, **kwargs):
        """
        Merge the models in ``modelpool`` layer by layer with RegMean++.

        :param modelpool: the models to merge (wrapped in ``BaseModelPool`` if it is not one)
        :return: the pretrained model with the merged parameters loaded
        :raises ValueError: if the ``MERGE_LAYER_RANGE`` environment variable is invalid
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)
        self.modelpool = modelpool
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        models_to_merge_dict = {
            name: model.to(device) for name, model in modelpool.named_models()
        }
        self.on_regmean_start()

        # initialize the merged model as the pretrained model
        merged_model = modelpool.load_pretrained_model().to(device)
        merged_params_dict = {}

        # 1. merge embedding layer
        merged_embedding_dict = self.merge_embedding_layer(
            models_to_merge_dict=models_to_merge_dict
        )
        merged_model.load_state_dict(merged_embedding_dict, strict=False)

        # a missing (None) exclusion list means "exclude nothing"
        exclude_param_names_regex = self.exclude_param_names_regex or []

        with torch.no_grad():
            # 1.1. compute input for the first layer
            with (
                self.profile("merging models"),
                self.profile("computing first layer input"),
            ):
                batches_input_dict = defaultdict(list)
                for model_name in tqdm(
                    models_to_merge_dict.keys(), desc="computing input for first layer"
                ):
                    dataset = self._load_regmean_dataset(modelpool, model_name)
                    batches_input_dict[model_name] = self.get_input_for_first_layer(
                        merged_model, dataset
                    )

            # 2. iteratively merge layer by layer with regmean algorithm
            backbone_layers = self.get_layers(merged_model)
            num_layers = len(backbone_layers)
            # experiment-only restriction of the RegMean-merged layers (None = all layers)
            regmean_layer_range = self._get_experiment_layer_range(num_layers)

            models_to_merge_layers_dict = {
                model_name: self.get_layers(model)
                for model_name, model in models_to_merge_dict.items()
            }

            # resolved from the first layer seen and reused for every layer
            param_names_to_merge = None
            for layer_idx, backbone_layer in tqdm(
                enumerate(backbone_layers), desc="merging layers", total=num_layers
            ):
                # parameter name -> that parameter taken from every model to merge
                models_to_merge_param_dict = defaultdict(list)
                # one dict per model (len(models_to_merge)): module name -> regmean weights
                models_to_merge_regmean_weights_list = []

                for model_name, layers_to_merge in models_to_merge_layers_dict.items():
                    layer_to_merge = layers_to_merge[layer_idx]
                    param_dict = layer_to_merge.state_dict()

                    # exclude parameters whose name matches exclude_param_names_regex
                    layer_param_names = get_param_names_to_merge(
                        input_param_names=list(param_dict.keys()),
                        exclude_param_names_regex=exclude_param_names_regex,
                    )
                    if param_names_to_merge is None:
                        param_names_to_merge = layer_param_names

                    for param_name in param_names_to_merge:
                        models_to_merge_param_dict[param_name].append(
                            param_dict[param_name]
                        )

                    linear_modules_to_merge = get_modules_to_merge(
                        model=layer_to_merge, include_module_types=self._include_module_type
                    )
                    assert len(linear_modules_to_merge) > 0, "No linear modules to merge"

                    # 2.1. compute regmean weights for each model
                    with (
                        self.profile("merging models"),
                        self.profile("computing regmean weights"),
                    ):
                        regmean_weights = self.get_regmean_weights(
                            model_name,
                            layer_to_merge,
                            batches_input=batches_input_dict[model_name],
                            linear_modules_to_merge=linear_modules_to_merge,
                        )

                        # Keep only modules that own at least one non-excluded
                        # parameter. Only a *trailing* ".weight"/".bias" is
                        # stripped to recover the module name.
                        module_subset = {
                            re.sub(r"\.(weight|bias)$", "", param_name)
                            for param_name in layer_param_names
                        }
                        regmean_weights = {
                            module_name: weights
                            for module_name, weights in regmean_weights.items()
                            if module_name in module_subset
                        }

                        models_to_merge_regmean_weights_list.append(regmean_weights)

                # 2.2. merge parameters with regmean weights
                with self.profile("merging models"):
                    in_regmean_range = regmean_layer_range is None or (
                        regmean_layer_range[0] <= layer_idx <= regmean_layer_range[1]
                    )
                    if in_regmean_range:
                        merged_layer_params = merging_with_regmean_weights(
                            models_to_merge_param_dict=models_to_merge_param_dict,
                            models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
                            reduce_non_diagonal_ratio=self.reduce_non_diagonal_ratio,
                            weight_transpose=self.weight_transpose,
                        )
                    else:
                        # ONLY FOR EXPERIMENTS: layers outside the range are simply averaged
                        merged_layer_params = {
                            param_name: torch.stack(param_values, dim=0).mean(dim=0)
                            for param_name, param_values in models_to_merge_param_dict.items()
                        }

                    merged_params_dict = self.update_merged_params_dict(
                        merged_params_dict=merged_params_dict,
                        new_merged_params=merged_layer_params,
                        layer_idx=layer_idx,
                    )

                # 2.3. compute input for the next layer by forwarding through the merged layer
                with (
                    self.profile("merging models"),
                    self.profile("forwarding next layer"),
                ):
                    if layer_idx < num_layers - 1:
                        backbone_layer.load_state_dict(merged_layer_params, strict=False)
                        batches_input_dict = {
                            model_name: self.layer_batches_forward(
                                backbone_layer, batches_input_dict[model_name]
                            )
                            for model_name in models_to_merge_dict
                        }

            # 3. load state dict to the merged model
            merged_model.load_state_dict(merged_params_dict, strict=False)

        self.print_profile_summary()
        return merged_model

    def _load_regmean_dataset(self, modelpool: BaseModelPool, name: str):
        """
        Load the data used to compute regmean weights for model ``name``.

        The dataset comes from ``modelpool.load_train_dataset(name)``. Two
        experiment-only overrides are controlled by environment variables:

        - ``MERGED_TASKS`` (colon-separated task names): for the model named
          ``"imagenet"``, the data is replaced by a shuffled mix of at most 256
          examples drawn evenly from the listed tasks, all relabelled to ``1``.
        - ``EXP_CLASS_ID`` (int): only examples whose ``label`` equals it are kept.
        """
        import os

        dataset = modelpool.load_train_dataset(name)

        ##################### ONLY FOR EXPERIMENTS #####################
        # Sequential Merging
        merged_tasks = os.getenv("MERGED_TASKS", None)
        if merged_tasks is not None and name == "imagenet":
            from datasets import concatenate_datasets
            from hydra.utils import instantiate
            from omegaconf import OmegaConf

            merged_tasks = merged_tasks.split(":")
            seed = 42
            n_samples = int(256 / len(merged_tasks))
            all_datasets = []
            for task in merged_tasks:
                config = OmegaConf.load(
                    f"config/dataset/image_classification/train/{task}.yaml"
                )
                task_dataset = instantiate(config)[task]
                task_dataset = task_dataset.train_test_split(
                    test_size=n_samples, seed=seed
                )["test"]
                task_dataset = task_dataset.remove_columns("label")
                task_dataset = task_dataset.map(lambda example: {"label": 1})
                all_datasets.append(task_dataset)

            assert len(all_datasets) == len(merged_tasks)

            dataset = concatenate_datasets(all_datasets).shuffle(seed=seed)

            assert len(dataset) <= 256 and len(dataset) >= 256 - len(merged_tasks)

        # Class Imbalance
        class_id = os.getenv("EXP_CLASS_ID")
        if class_id is not None:
            class_id = int(class_id)
            dataset = dataset.filter(
                lambda example: example["label"] == class_id,
                num_proc=4,
                load_from_cache_file=False,
            )
        ##################### ONLY FOR EXPERIMENTS #####################

        return dataset

    def _get_experiment_layer_range(self, num_layers: int):
        """
        Return the inclusive ``(start, end)`` layer indices merged with RegMean,
        or ``None`` when every layer is.

        Experiment-only, controlled by environment variables:

        - ``LAYER_ID`` (int): merge only that layer.
        - ``MERGE_LAYER_RANGE``: one of ``early``, ``middle``, ``late`` or
          ``middle+late`` (thirds of the backbone); takes precedence over ``LAYER_ID``.

        Layers outside the range are merged by simple averaging.

        :raises ValueError: if ``MERGE_LAYER_RANGE`` holds an unknown value
        """
        import os

        layer_start, layer_end = None, None

        # Layer-wise Merging
        layer_i = os.getenv("LAYER_ID", None)
        if layer_i is not None:
            layer_start, layer_end = int(layer_i), int(layer_i)

        # Region-specific Merging
        layer_range = os.getenv("MERGE_LAYER_RANGE", None)
        if layer_range is not None:
            ranges = {
                "early": (0 * num_layers / 3, 1 * num_layers / 3 - 1),
                "middle": (1 * num_layers / 3, 2 * num_layers / 3 - 1),
                "late": (2 * num_layers / 3, num_layers - 1),
                "middle+late": (1 * num_layers / 3, num_layers - 1),
            }
            if layer_range not in ranges:
                raise ValueError(
                    f"Unknown MERGE_LAYER_RANGE {layer_range!r}; "
                    f"expected one of {sorted(ranges)}"
                )
            start, end = ranges[layer_range]
            layer_start, layer_end = int(start), int(end)

        if layer_start is None or layer_end is None:
            return None
        return layer_start, layer_end

    def merge_embedding_layer(self, models_to_merge_dict: Dict[str, nn.Module]):
        """
        Merge the embedding layer of the models to merge.

        Must be implemented by subclasses. ``run`` loads the returned state dict
        into the merged model with ``strict=False``.
        """
        raise NotImplementedError()

    def get_input_for_first_layer(self, model: nn.Module, train_dataset):
        """
        Compute the batched inputs of the first backbone layer of ``model``
        for ``train_dataset``.

        Must be implemented by subclasses. The result is later passed as
        ``batches_input`` to ``get_regmean_weights`` and ``layer_batches_forward``.
        """
        raise NotImplementedError

    def get_layers(self, model: nn.Module):
        """
        Return the backbone layers of ``model`` as an indexable, sized sequence.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def update_merged_params_dict(self, merged_params_dict, new_merged_params, layer_idx):
        """
        Add the merged parameters of layer ``layer_idx`` to ``merged_params_dict``
        and return the updated dict.

        Must be implemented by subclasses. The final dict is loaded into the
        merged model with ``strict=False``, so keys should be full parameter
        names of that model.
        """
        raise NotImplementedError

    def layer_batches_forward(self, layer: nn.Module, batches_input: List[Tensor]):
        """
        Forward ``batches_input`` through ``layer``. The result is used as the
        input of the next layer.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def on_regmean_start(self):
        """
        Optional hook.

        Called by ``run`` after the models to merge are moved to the device
        and before the pretrained model is loaded.
        """
        pass

    def get_regmean_weights(
        self,
        model_name: str,
        layer: nn.Module,
        batches_input: List[Tensor],
        linear_modules_to_merge: Dict[str, nn.Module],
    ):
        """
        Compute the regmean weights of ``linear_modules_to_merge`` inside ``layer``.

        Must be implemented by subclasses. It should return a mapping from
        module name (as in ``linear_modules_to_merge``) to a square matrix. That
        matrix left-multiplies the module's (transposed) weight during merging.
        """
        raise NotImplementedError
