import torch

class TaskVector:
    """A task vector: the per-key difference ``finetuned - pretrained``.

    ``self.vector`` maps state-dict keys to tensors. Task vectors can be
    summed (``tv1 + tv2`` or ``sum([tv1, tv2])``), linearly combined with
    ``weightmerging``, sparsified in place with ``sparsify`` and added back
    onto a model with ``apply_to``.
    """

    # Entry dtypes excluded when building a vector from state dicts: integer
    # counters/index buffers are not meaningful to subtract, and torch does
    # not support subtraction of bool tensors at all.
    _SKIPPED_DTYPES = (torch.int64, torch.uint8, torch.bool)
    # Keys copied verbatim from the fine-tuned state dict in ``apply_to``.
    _HEAD_KEYS = ('head.weight', 'head.bias')

    def __init__(self, pretrained_state_dict=None, finetuned_state_dict=None, vector=None):
        """Build a task vector.

        Args:
            pretrained_state_dict: state dict of the base model; required
                unless ``vector`` is given.
            finetuned_state_dict: state dict of the fine-tuned model; must
                contain every non-skipped key of ``pretrained_state_dict``
                (a missing key raises ``KeyError``).
            vector: a precomputed ``{key: tensor}`` dict. When given, the
                state dicts are ignored and the dict is stored as-is.
        """
        if vector is not None:
            self.vector = vector
            return
        assert pretrained_state_dict is not None and finetuned_state_dict is not None
        with torch.no_grad():
            self.vector = {
                key: finetuned_state_dict[key] - value
                for key, value in pretrained_state_dict.items()
                if value.dtype not in self._SKIPPED_DTYPES
            }

    def __add__(self, other):
        """Return the key-wise sum of two task vectors.

        Keys of ``self`` missing from ``other`` are dropped with a printed
        warning; keys only in ``other`` are ignored.
        """
        with torch.no_grad():
            new_vector = {}
            for key, value in self.vector.items():
                if key not in other.vector:
                    print(f'Warning, key {key} is not present in both task vectors.')
                    continue
                new_vector[key] = value + other.vector[key]
        return TaskVector(vector=new_vector)

    def __radd__(self, other):
        """Support ``sum(task_vectors)``, whose implicit start value is ``0``."""
        if other is None or isinstance(other, int):
            return self
        return self.__add__(other)

    def weightmerging(self, taskvectors, coefficients):
        """Return ``sum_i coefficients[i] * taskvectors[i]`` over ``self``'s keys.

        Keys of ``self`` not present in every task vector are dropped with a
        printed warning. Keys containing ``'head'`` or ``'classifier'`` are
        excluded from the merge.

        Args:
            taskvectors: iterable of ``TaskVector`` objects to combine.
            coefficients: indexable weights; ``coefficients[i]`` scales the
                i-th task vector.

        Raises:
            ValueError: if ``taskvectors`` is empty.
        """
        # Materialize once: it is traversed twice (key intersection, then sum).
        taskvectors = list(taskvectors)
        if not taskvectors:
            raise ValueError('weightmerging requires at least one task vector.')
        with torch.no_grad():
            common_keys = set.intersection(*(set(tv.vector) for tv in taskvectors))
            new_vector = {}
            for key in self.vector:
                if key not in common_keys:
                    print(f'Warning, key {key} is not present in all task vectors.')
                    continue
                # Task-specific output layers are not merged.
                if 'head' in key or 'classifier' in key:
                    continue
                new_vector[key] = sum(
                    coefficients[i] * tv.vector[key] for i, tv in enumerate(taskvectors)
                )
        return TaskVector(vector=new_vector)

    def apply_to(self, pretrained_model, finetuned_state_dict=None, scaling_coef=1.0):
        """Add ``scaling_coef * vector`` onto ``pretrained_model`` in place.

        Args:
            pretrained_model: module updated via
                ``load_state_dict(..., strict=False)``; model keys absent from
                the vector are left untouched.
            finetuned_state_dict: optional; when given, any of its
                ``head.weight`` / ``head.bias`` entries are copied verbatim
                into the model after the vector is applied.
            scaling_coef: multiplier applied to every vector entry.

        Returns:
            ``pretrained_model`` itself (modified in place).
        """
        with torch.no_grad():
            pretrained_state_dict = pretrained_model.state_dict()
            new_state_dict = {}
            for key, value in pretrained_state_dict.items():
                if key not in self.vector:
                    continue
                # Align devices so a vector kept on CPU can be applied to a
                # model on another device; a no-op when devices already match.
                delta = self.vector[key].to(value.device)
                new_state_dict[key] = value + scaling_coef * delta
            pretrained_model.load_state_dict(new_state_dict, strict=False)

            if finetuned_state_dict is not None:
                # state_dict() tensors share storage with the parameters, so
                # copy_ writes straight into the model.
                model_state = pretrained_model.state_dict()
                for head_key in self._HEAD_KEYS:
                    if head_key in finetuned_state_dict:
                        model_state[head_key].copy_(finetuned_state_dict[head_key])
        return pretrained_model

    def sparsify(self, keep_ratio):
        """Zero all but the largest-magnitude entries, in place.

        A single global magnitude threshold is chosen so that roughly
        ``keep_ratio`` of all entries (at least one, at most all) survive.
        Entries tied with the threshold are all kept. An empty vector is left
        unchanged.

        Args:
            keep_ratio: fraction of entries to keep; values >= 1 keep all.
        """
        with torch.no_grad():
            if not self.vector:
                return
            flat_abs = torch.cat(
                [v.detach().flatten().abs().cpu() for v in self.vector.values()]
            )
            total = flat_abs.numel()
            if total == 0:
                return
            k = min(max(int(total * keep_ratio), 1), total)
            threshold = torch.topk(flat_abs, k, largest=True).values[-1]
            del flat_abs  # release the full flattened copy before masking
            for key, tensor in list(self.vector.items()):
                mask = tensor.abs() >= threshold.to(tensor.device)
                self.vector[key] = tensor * mask
            