import logging
from copy import deepcopy
from typing import Dict, List, Mapping, Optional, Union

import torch
from torch import nn

from fusion_bench.method.base_algorithm import BaseAlgorithm
from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
from fusion_bench.modelpool import BaseModelPool
from fusion_bench.models.utils import (
    get_target_state_dict,
    load_state_dict_into_target_modules,
    validate_target_modules_equal,
)
from fusion_bench.utils import LazyStateDict
from fusion_bench.utils.state_dict_arithmetic import (
    state_dict_add,
    state_dict_avg,
    state_dict_div,
    state_dict_mul,
)
from fusion_bench.utils.type import StateDictType

log = logging.getLogger(__name__)


def simple_average(
    modules: List[Union[nn.Module, StateDictType]],
    base_module: Optional[nn.Module] = None,
):
    R"""
    Averages the parameters of a list of PyTorch modules or state dictionaries.

    This function takes a list of PyTorch modules or state dictionaries and returns a new module with the averaged parameters, or a new state dictionary with the averaged parameters.
    Which of the two paths is taken is decided by the type of the first element of `modules`.

    If `_fusion_bench_target_modules` attribute is set on the modules, only the parameters of the specified target submodules will be averaged.

    Args:
        modules (List[Union[nn.Module, StateDictType]]): A list of PyTorch modules or state dictionaries.
        base_module (Optional[nn.Module]): A base module to use for the new module. If provided, the averaged parameters will be loaded into this module
            (the module is modified in place and returned). If not provided, a new module will be created by copying the first module in the list.
            Ignored when `modules` contains state dictionaries.

    Returns:
        module_or_state_dict (Union[nn.Module, StateDictType]): A new PyTorch module with the averaged parameters, or a new state dictionary with the averaged parameters.

    Raises:
        AssertionError: If `modules` is empty.
        TypeError: If the first element of `modules` is neither an `nn.Module` nor a mapping (state dict).

    Examples:
        >>> import torch.nn as nn
        >>> model1 = nn.Linear(10, 10)
        >>> model2 = nn.Linear(10, 10)
        >>> averaged_model = simple_average([model1, model2])

        >>> state_dict1 = model1.state_dict()
        >>> state_dict2 = model2.state_dict()
        >>> averaged_state_dict = simple_average([state_dict1, state_dict2])
    """
    assert len(modules) > 0, "modules must be a non-empty list"
    validate_target_modules_equal(modules)

    first = modules[0]
    if isinstance(first, nn.Module):
        # Write into the caller-supplied module, or into a fresh clone of the
        # first module so none of the inputs are mutated.
        new_module = deepcopy(first) if base_module is None else base_module
        state_dict = state_dict_avg(
            [get_target_state_dict(module) for module in modules]
        )
        load_state_dict_into_target_modules(new_module, state_dict)
        return new_module
    if isinstance(first, Mapping):
        # State-dict inputs: the averaged mapping itself is the result.
        return state_dict_avg(modules)
    # Previously this fell through and returned None silently.
    raise TypeError(
        "simple_average expects a list of nn.Module or state dicts (Mapping), "
        f"got elements of type {type(first).__name__}"
    )


@auto_register_config
class SimpleAverageAlgorithm(SimpleProfilerMixin, BaseAlgorithm):
    """
    Merge the models of a model pool by element-wise averaging of their
    (target) parameters.

    Models are loaded one at a time; their target state dicts are summed into
    a running total, which is divided by the number of merged models at the end
    and loaded into the output model.
    """

    def __init__(self, show_pbar: bool = False, inplace: bool = True, **kwargs):
        """
        Args:
            show_pbar (bool): If True, shows a progress bar during model loading and merging. Default is False.
            inplace (bool): If True, overwrites the weights of the first model in the model pool.
                If False, creates a new model for the merged weights. Default is True.
            **kwargs: Forwarded to the parent class initializers.
        """
        # show_pbar / inplace are not assigned here; presumably
        # @auto_register_config exposes them as self.show_pbar / self.inplace,
        # which run() reads — confirm against the decorator.
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]) -> nn.Module:
        """
        Fuse the models in the given model pool using simple averaging.

        Each model in the pool is loaded in turn and its target state dict is
        added to a running sum. The sum is then divided by the number of models
        and loaded into the output model. With ``inplace=True`` the output model
        is the first loaded model itself; otherwise it is a deep copy of it.

        Args:
            modelpool: The pool of models to fuse. A plain dict of modules is
                wrapped in a ``BaseModelPool``.

        Returns:
            The fused model obtained by simple averaging.

        Raises:
            ValueError: If the model pool is empty, or if the averaged state dict
                contains keys the output model does not have.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

        model_names = modelpool.model_names
        if not model_names:
            # Without this guard the loop below never runs and state_dict_div
            # would be called on None with a divisor of 0.
            raise ValueError("Cannot run simple average on an empty model pool.")

        log.info(
            f"Fusing models using simple average on {len(model_names)} models. "
            f"models: {model_names}"
        )
        if modelpool.has_instance_models and self.inplace:
            log.warning(
                "The model pool contains instance models, and inplace is set to True. "
                "Therefore, the weights of the first model will be overwritten. "
                "If this is desired behavior, this warning can be ignored."
            )

        # Running element-wise sum of target state dicts (None until the first model).
        sd: Optional[StateDictType] = None
        # The module that receives the averaged weights and is returned.
        forward_model = None
        merged_model_names: List[str] = []

        for model_name in model_names:
            with self.profile("load model"):
                model = modelpool.load_model(model_name)
                merged_model_names.append(model_name)
                log.info(f"load model of type: {type(model).__name__}")
            with self.profile("merge weights"):
                if sd is None:
                    # Initialize the running sum with the first model's state dict.
                    sd = get_target_state_dict(model)
                    forward_model = model if self.inplace else deepcopy(model)
                else:
                    # Add the current model's state dict to the running sum.
                    sd = state_dict_add(
                        sd, get_target_state_dict(model), show_pbar=self.show_pbar
                    )

        with self.profile("merge weights"):
            # Divide by the number of models actually accumulated to get the mean.
            sd = state_dict_div(
                sd, len(merged_model_names), show_pbar=self.show_pbar
            )

        if isinstance(forward_model, LazyStateDict):
            # A LazyStateDict is not a module: materialize one from its meta
            # module. to_empty() allocates uninitialized storage, so any
            # parameter not present in `sd` keeps garbage values (such keys are
            # presumably reported as missing below).
            forward_model = deepcopy(forward_model.meta_module).to_empty(
                device=forward_model._device
            )

        result = load_state_dict_into_target_modules(forward_model, sd, strict=False)
        if result.unexpected_keys:
            raise ValueError(f"Unexpected keys in state dict: {result.unexpected_keys}")
        if result.missing_keys:
            log.warning(f"Missing keys in state dict: {result.missing_keys}")

        # Print the profile report and log the merged models.
        self.print_profile_summary()
        log.info(f"merged {len(merged_model_names)} models:")
        for model_name in merged_model_names:
            log.info(f"  - {model_name}")
        return forward_model
