R"""
Route-based merging of two task vectors into a pretrained model.

Each task vector is multiplied element-wise by its mask, the two masked task
vectors are combined by `TSVM_utils.compute_and_sum_svd_mem_reduction`, and the
result is added to the pretrained weights.

```python
# Initialize the RouteMergedModel
merged_model = RouteMergedModel(
    pretrained_model,
    task_vector_pre,
    task_vector_post,
    masks_pre,
    masks_post,
    mask_alpha=0.5,
    device="cuda:1",
)

# Now you can use the merged_model like a regular PyTorch model
outputs = merged_model(inputs)
```
"""
import logging
from re import S
import types
import warnings
from copy import deepcopy
from typing import Any, Callable, Dict, Iterator, List

import torch
from torch import Tensor, nn
from torch.func import functional_call

from .ties_merging_utils import check_parameterNamesMatch
from .type import StateDict
from .utils import timeit_context

log = logging.getLogger(__name__)

__all__ = ["fuse_weights", "RouteMergedModel"]

# def _fuse_weights(
#     mask_pre: Tensor,
#     mask_post: Tensor,
#     task_vectors: List[Tensor],
#     mask_alpha: float = 0.5,
#     device: str = "cuda:1",
# ) -> Tensor:
#     """
#     This function fuses the weights of the models.

#     Args:
#         mask_pre (Tensor): The mask for the pre-task.
#         mask_post (Tensor): The mask for the post-task.
#         task_vectors (List[Tensor]): The task vectors.

#     Returns:
#         Tensor: The fused weights.
#     """
#     assert len(task_vectors) == 2
#     return mask_pre.to(device) * task_vectors[0].to(device) * (mask_alpha) + mask_post.to(device) * task_vectors[1].to(device) * (1-mask_alpha)


# def fuse_weights(
#     task_vector_pre: StateDict,
#     task_vector_post: StateDict,
#     masks_pre: Dict[str, Tensor],
#     masks_post: Dict[str, Tensor],
#     mask_alpha: float = 0.5,
#     device: str = "cuda:1",
# ) -> StateDict:
    
#     task_vector = {}
#     for k in task_vector_pre.keys():
#         task_vector[k] = _fuse_weights(masks_pre[k], masks_post[k], [task_vector_pre[k], task_vector_post[k]], mask_alpha, device=device)
#     return task_vector

from singular_vector import TSVM_utils

def _multi_weights(
    mask_pre: Tensor,
    task_vector: Tensor,
    device: str = "cuda:1",
) -> Tensor:

    return mask_pre.to(device) * task_vector.to(device)


def fuse_weights(
    task_vector_pre: StateDict,
    task_vector_post: StateDict,
    masks_pre: Dict[str, Tensor],
    masks_post: Dict[str, Tensor],
    device: str = "cuda:1",
) -> StateDict:
    """Mask two task vectors and combine them with TSVM_utils.

    For every key ``k`` of ``task_vector_pre``, ``task_vector_pre[k]`` is
    multiplied element-wise by ``masks_pre[k]`` and ``task_vector_post[k]`` by
    ``masks_post[k]``, both on ``device``. The two masked task vectors are then
    passed to ``TSVM_utils.compute_and_sum_svd_mem_reduction`` and its result
    is returned.

    The caller's dictionaries are left untouched: masked entries are written
    into shallow copies. Mutating the inputs would make a second call with the
    same dictionaries apply the masks twice. Keys present only in
    ``task_vector_post`` are forwarded unmasked, as before.

    Args:
        task_vector_pre: First task vector, keyed by parameter name.
        task_vector_post: Second task vector; must contain every key of
            ``task_vector_pre``.
        masks_pre: Masks for ``task_vector_pre``; must contain every key of
            ``task_vector_pre``.
        masks_post: Masks for ``task_vector_post``; must contain every key of
            ``task_vector_pre``.
        device: Device used for masking; also passed to TSVM_utils as
            ``config.device``.

    Returns:
        The value returned by ``compute_and_sum_svd_mem_reduction``.

    Raises:
        ValueError: If a mask for one of the keys contains NaN values.
        KeyError: If a key of ``task_vector_pre`` is missing from one of the
            other dictionaries.
    """
    # Shallow copies: reassigning entries below must not touch the caller's dicts.
    masked_pre = dict(task_vector_pre)
    masked_post = dict(task_vector_post)
    for k in task_vector_pre:
        # A NaN in either mask would propagate into the merged weights; fail early.
        for side, mask in (("pre", masks_pre[k]), ("post", masks_post[k])):
            if torch.isnan(mask).any():
                raise ValueError(f"Mask ({side}) for {k} contains NaN values.")
        masked_pre[k] = _multi_weights(masks_pre[k], task_vector_pre[k], device=device)
        masked_post[k] = _multi_weights(masks_post[k], task_vector_post[k], device=device)

    # TSVM_utils is configured through a plain attribute bag.
    # NOTE(review): DATASETS presumably only needs one entry per task vector — confirm.
    config = types.SimpleNamespace()
    config.DATASETS = ["0", "1"]
    config.device = device
    return TSVM_utils.compute_and_sum_svd_mem_reduction([masked_pre, masked_post], config)


class RouteMergedModel(nn.Module):
    """Wrap a pretrained model and run it with route-merged weights.

    On the first forward pass, :meth:`merge_weights` masks and fuses the two
    task vectors with ``fuse_weights``. It adds the result to a copy of the
    pretrained state dict placed on ``device``. The wrapped model is then
    executed through ``torch.func.functional_call`` with those merged weights.

    The merged state dict is cached in ``merged_state_dict``; call
    :meth:`merge_weights` again to rebuild it. Attribute reads and writes that
    ``nn.Module`` cannot resolve are forwarded to the wrapped model.
    """

    def __init__(
        self,
        pretrained_model: nn.Module,
        task_vector_pre: StateDict,
        task_vector_post: StateDict,
        masks_pre: Dict[str, Tensor],
        masks_post: Dict[str, Tensor],
        mask_alpha: float,
        device: str,
    ):
        """
        Args:
            pretrained_model: Model providing the base weights and the forward
                computation.
            task_vector_pre: First task vector, keyed by parameter name.
            task_vector_post: Second task vector, keyed by parameter name.
            masks_pre: Element-wise masks for ``task_vector_pre``.
            masks_post: Element-wise masks for ``task_vector_post``.
            mask_alpha: Stored as a frozen parameter; not used by the current
                merge (``fuse_weights`` takes no alpha argument).
            device: Device on which the merged weights are built.
        """
        super().__init__()
        # A tuple is not an nn.Module, so nn.Module does not register it as a submodule.
        # NOTE(review): the same model is registered below as `pretrained_model`,
        # so `.to()` on this wrapper still moves it; confirm the "cpu" intent.
        self._model = (pretrained_model,)  # self._model should be on cpu
        self.pretrained_model = pretrained_model
        self.task_vector_pre = task_vector_pre
        self.task_vector_post = task_vector_post
        self.masks_pre = masks_pre
        self.masks_post = masks_post
        self.mask_alpha = nn.Parameter(torch.tensor(mask_alpha), requires_grad=False)
        self.device = device
        # Detached tensors that share storage with the pretrained parameters.
        self.pretrained_state_dict: StateDict = self.model.state_dict(keep_vars=False)
        self.merged_state_dict = None

    @property
    def model(self):
        """The wrapped pretrained model."""
        return self._model[0]

    def merge_weights(self):
        """Build ``merged_state_dict`` from the pretrained weights plus the fused task vector.

        Every pretrained tensor is moved to ``self.device``. Entries named in
        the fused task vector get that delta added.

        The addition is out-of-place on purpose. When a tensor is already on
        ``self.device``, ``.to()`` returns the very same tensor, which shares
        storage with the pretrained parameters. An in-place ``+=`` would then
        overwrite the pretrained weights and make repeated merges accumulate.
        """
        # Shallow copies, so that reassigning entries inside fuse_weights cannot
        # replace the stored (unmasked) task vectors.
        task_vector = fuse_weights(
            dict(self.task_vector_pre),
            dict(self.task_vector_post),
            self.masks_pre,
            self.masks_post,
            device=self.device,
        )
        device = self.device
        merged = {
            name: tensor.to(device, non_blocking=True)
            for name, tensor in self.pretrained_state_dict.items()
        }
        for name in task_vector.keys():
            base = merged[name]
            # Keep the pretrained dtype, as the former in-place `+=` did.
            merged[name] = (base + task_vector[name]).to(base.dtype)
        # Publish only once complete, so a failed merge never leaves a partial dict.
        self.merged_state_dict = merged

    def forward(self, *args, **kwargs):
        """Run the wrapped model with the merged weights, merging lazily on first use."""
        if self.merged_state_dict is None:
            self.merge_weights()
        return functional_call(
            self.model,
            self.merged_state_dict,
            args=args,
            kwargs=kwargs,
            tie_weights=False,
        )

    def __getattr__(self, name: str) -> Any:
        """Resolve through nn.Module first, then fall back to the wrapped model."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            if name == "_model":
                # Not initialised yet (e.g. an instance created via __new__).
                # Forwarding would recurse through the `model` property forever.
                raise
            attr = getattr(self.model, name)
            if callable(attr):
                warnings.warn(f"forwarding `{name}` to the underlying model", UserWarning)
            return attr

    def __setattr__(self, name: str, value: Any) -> None:
        """Set through nn.Module; on AttributeError, set on the wrapped model instead."""
        try:
            super().__setattr__(name, value)
        except AttributeError:
            setattr(self.model, name, value)
