import re
from collections.abc import Callable

import torch
import torch.nn as nn


def _get_param_names_to_merge(
    input_param_names: list[str], exclude_param_names_regex: list[str] | None
) -> list[str]:
    """
    get the names of parameters that need to be merged
    :param input_param_names: list, names of input parameters
    :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
    :return:
    """
    if exclude_param_names_regex is None:
        exclude_param_names_regex = []
    param_names_to_merge = []
    for param_name in input_param_names:
        exclude = any(
            [
                re.match(exclude_pattern, param_name)
                for exclude_pattern in exclude_param_names_regex
            ]
        )
        if not exclude:
            param_names_to_merge.append(param_name)
    return param_names_to_merge


class TaskVector:
    def __init__(
        self,
        base_model: nn.Module | None = None,
        finetuned_model: nn.Module | None = None,
        exclude_param_names_regex: list[str] | None = None,
        finetuned_param_name_convert_fn: Callable[[str], str] | None = None,
    ):
        """
        Task vector. Initialize the task vector from a pretrained model and a finetuned model, or
        directly passing the task_vector_param_dict dictionary.
        :param pretrained_model: nn.Module, pretrained model
        :param finetuned_model: nn.Module, finetuned model
        :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
        :param task_vector_param_dict: dict, task vector to initialize self.task_vector_param_dict
        """
        self.task_vector_param_dict = {}
        if base_model is None or finetuned_model is None:
            return
        pretrained_param_dict = {
            param_name: param_value
            for param_name, param_value in base_model.named_parameters()
        }
        self.base_model_param_names = [
            name for name, _ in base_model.named_parameters()
        ]
        finetuned_param_dict = {
            (finetuned_param_name_convert_fn or (lambda s: s))(param_name): param_value
            for param_name, param_value in finetuned_model.named_parameters()
        }
        param_names_to_merge = _get_param_names_to_merge(
            input_param_names=list(pretrained_param_dict.keys()),
            exclude_param_names_regex=exclude_param_names_regex,
        )
        with torch.no_grad():
            for param_name in param_names_to_merge:
                if param_name not in finetuned_param_dict:
                    print(
                        f"param_name {param_name} is not contained in finetuned model!"
                    )
                    self.task_vector_param_dict[param_name] = torch.zeros_like(
                        input=pretrained_param_dict[param_name],
                        dtype=torch.float16,
                    )
                else:
                    self.task_vector_param_dict[param_name] = (
                        finetuned_param_dict[param_name]
                        - pretrained_param_dict[param_name]
                    )

    @classmethod
    def from_param_dict(cls, task_vector_param_dict: dict, base_model: nn.Module):
        """
        create TaskVector from task_vector_param_dict
        :param task_vector_param_dict: dict, task vector to initialize self.task_vector_param_dict
        :return:
        """
        task_vector = cls()
        task_vector.task_vector_param_dict = task_vector_param_dict
        task_vector.base_model_param_names = [
            name for name, _ in base_model.named_parameters()
        ]
        return task_vector

    def __add__(self, other):
        """
        add task vector
        :param other: TaskVector to add, at right side
        :return:
        """
        assert isinstance(
            other, TaskVector
        ), "addition of TaskVector can only be done with another TaskVector!"
        new_task_vector_param_dict = {}
        with torch.no_grad():
            for param_name in self.task_vector_param_dict:
                assert (
                    param_name in other.task_vector_param_dict.keys()
                ), f"param_name {param_name} is not contained in both task vectors!"
                new_task_vector_param_dict[param_name] = (
                    self.task_vector_param_dict[param_name]
                    + other.task_vector_param_dict[param_name]
                )
        return TaskVector.from_param_dict(
            task_vector_param_dict=new_task_vector_param_dict
        )

    def __radd__(self, other):
        """
        other + self = self + other
        :param other: TaskVector to add, at left side
        :return:
        """
        return self.__add__(other)

    def combine_with_pretrained_model(
        self, base_model: nn.Module, scaling_coefficient: float = 1.0
    ):
        """
        combine the task vector with pretrained model
        :param pretrained_model: nn.Module, pretrained model
        :param scaling_coefficient: float, scaling coefficient to merge the task vector
        :return:
        """
        pretrained_param_dict = {
            param_name: param_value
            for param_name, param_value in base_model.named_parameters()
        }

        with torch.no_grad():
            merged_params = {}
            for param_name in self.task_vector_param_dict:
                merged_params[param_name] = (
                    pretrained_param_dict[param_name]
                    + scaling_coefficient * self.task_vector_param_dict[param_name]
                )

        return merged_params
