from collections import defaultdict, OrderedDict
from copy import deepcopy
import torch.nn.functional as F
from tqdm.auto import tqdm
from time import time
from torch import nn
import torch
import pdb
#import ipdb
from collections import defaultdict
from argparse import ArgumentError
import random
from typing import Callable, Optional, Union

from utils import compute_fisher_qkv_outproj, get_clip_encodings, get_merging_fn, get_mask_fn
from masking_ops import masked_merge
from transformers.masking_utils import create_causal_mask
from transformers.integrations.sdpa_attention import sdpa_attention_forward

import numpy as np
from scipy.sparse.linalg import LinearOperator
from scipy.sparse.linalg import cg
from peft.tuners.lora.layer import Linear as PeftLinear
from time import time

def forward_wrapper(func: Callable, *args):
    """
    Call ``func`` with the first positional argument passed as ``hidden_states``.

    Args:
        func: Callable that accepts a ``hidden_states`` keyword argument.
        *args: ``args[0]`` is forwarded as ``hidden_states``. Any further
            positional arguments (e.g. a pad-token count in ``args[1]``) are
            accepted so callers can keep a uniform call shape, but they are
            not forwarded to ``func``.

    Returns:
        Whatever ``func(hidden_states=args[0])`` returns.

    Raises:
        IndexError: If no positional argument is supplied.
    """
    hidden_states: torch.Tensor = args[0]
    return func(hidden_states=hidden_states)

class StaticPrinter():
    """Debug helper that prints a set of named values at most once."""

    # Class-level latch shared by every caller; set after the first print.
    printed=False

    @staticmethod
    def static_print(**kwargs):
        """
        Print ``kwargs`` as a list of ``(name, value)`` pairs on the first call.

        Every later call is a no-op, whatever arguments it receives.
        """
        if StaticPrinter.printed:
            return
        print(list(kwargs.items()))
        StaticPrinter.printed = True

class CustomLinear(PeftLinear):
    """
    Wrapper around an existing PEFT ``Linear`` layer whose ``forward`` accepts
    (and ignores) an extra ``pad_tokens`` argument.
    """

    def __init__(self, matrix: PeftLinear):
        """
        Adopt the state of ``matrix`` instead of building a new layer.

        The parent constructor is intentionally not called: every entry of
        ``vars(matrix)`` is copied onto this instance via ``setattr``, and the
        wrapped layer's bound ``forward`` is kept as ``default_forward``.
        """
        for attr_name, attr_value in vars(matrix).items():
            setattr(self, attr_name, attr_value)
        self.default_forward = matrix.forward

    def forward(self, x, pad_tokens=0):
        """Delegate to the wrapped layer's forward; ``pad_tokens`` is unused."""
        return self.default_forward(x)


class VectorOps(nn.Module):
    """
    Conversions between state-dict style parameter collections and flat
    vectors, plus a ``forward`` that merges several collections with a
    user-supplied merging function.
    """

    def directions_to_reps(self, directions):
        """
        Flatten a mapping of tensors into a single 1-D vector.

        Args:
            directions: A mapping ``name -> tensor`` or a list of such
                mappings (handled recursively, returning a list of vectors).

        Returns:
            The concatenation of all values, in the mapping's own iteration
            order, each reshaped to 1-D.
        """
        if isinstance(directions, list):
            return [self.directions_to_reps(direction) for direction in directions]
        return torch.nn.utils.parameters_to_vector(
            [value.reshape(-1) for value in directions.values()]
        )

    def rep_to_state_dict(self, vector, state_dict, remove_keys=[]):
        """
        Split a flat vector back into named tensors shaped like ``state_dict``.

        Args:
            vector: 1-D tensor to split. A list of vectors or a 2-D tensor is
                processed row by row and yields a list of results.
            state_dict: Reference mapping ``name -> tensor``; only the names
                and shapes are used. It is neither modified nor aliased by
                the result.
            remove_keys: Names to leave out of the reference layout. They are
                re-added as references to ``"transformer.shared.weight"`` when
                that entry exists. The default list is never mutated.

        Returns:
            An ``OrderedDict`` keyed in sorted name order whose values are
            views into ``vector`` with the reference shapes.
        """
        if isinstance(vector, list) or len(vector.shape) == 2:
            return [self.rep_to_state_dict(v, state_dict, remove_keys) for v in vector]

        # The vector layout is defined by the sorted reference names. Work on
        # a filtered view instead of deleting from the caller's dict.
        excluded = set(remove_keys)
        reference_items = sorted(
            ((key, ref) for key, ref in state_dict.items() if key not in excluded),
            key=lambda item: item[0],
        )

        # Build fresh tensors rather than overwriting the reference tensors in
        # place: writing into the reference made every dict produced from the
        # same reference share (and clobber) the same tensor objects.
        rebuilt = OrderedDict()
        offset = 0
        for key, reference in reference_items:
            count = reference.numel()
            rebuilt[key] = vector[offset:offset + count].view_as(reference)
            offset += count

        # add back the encoder and decoder embedding weights.
        if "transformer.shared.weight" in rebuilt:
            for key in remove_keys:
                rebuilt[key] = rebuilt["transformer.shared.weight"]
        return rebuilt

    def mask_to_state_dict(self, mask, state_dict, remove_keys=[]):
        """Same as :meth:`rep_to_state_dict`, applied to a mask or list of masks."""
        if isinstance(mask, list):
            return [self.mask_to_state_dict(m, state_dict, remove_keys) for m in mask]
        return self.rep_to_state_dict(mask, state_dict, remove_keys)

    def forward(self, directions, merging_fn, merge_config=None):
        """
        Merge several parameter collections through their flat representation.

        Args:
            directions: List of mappings ``name -> tensor`` sharing one layout;
                ``directions[0]`` serves as the shape reference.
            merging_fn: Callable taking the list of flat vectors and returning
                a 3-tuple ``(merged_vector, rows_to_keep, topk_mask)``. Only
                the first two elements are used.
            merge_config: Accepted for call-site compatibility; not used here.
                Optional so that callers which omit it do not fail.

        Returns:
            ``(sd, ties_mask)``: the merged vector as a state dict, and one
            state dict per entry of ``rows_to_keep``.
        """
        vectors = self.directions_to_reps(directions)
        merged_vector, rows_to_keep, _topk_mask = merging_fn(vectors)
        reference = directions[0]

        ties_mask = [
            self.rep_to_state_dict(rows_to_keep[idx], reference)
            for idx in range(len(rows_to_keep))
        ]
        sd = self.rep_to_state_dict(merged_vector, reference)

        return sd, ties_mask


class TaskMerger(nn.Module):
    """
    Base class for model mergers.

    Holds the pretrained model, the fine-tuned models and per-model scaling
    coefficients, and provides helpers to compute task directions
    (fine-tuned minus pretrained), to convert them to and from per-layer
    matrices, and to add merged parameters onto a base model.
    """

    def __init__(self, finetuned_models, pretrained_model, param_handler, device=0, merge_config=None):
        """
        Args:
            finetuned_models: Sequence of fine-tuned models.
            pretrained_model: Base model; moved to CPU via ``.cpu()``.
            param_handler: Callable applied to each fine-tuned model; the
                results are kept in ``self.ftms_params``.
            device: Device used by subclasses for heavy computations.
            merge_config: Optional configuration mapping, stored as-is.
        """
        super().__init__()

        self.device = device
        # One coefficient per fine-tuned model, all 1.0 until overridden.
        self.scaling_coeffs = torch.tensor([1.] * len(finetuned_models))
        self.param_handler = param_handler
        self.finetuned_models = finetuned_models
        self.ftms_params = [param_handler(ft_model) for ft_model in finetuned_models]
        self.pretrained_model = pretrained_model.cpu()
        self.pt_params = self.pretrained_model.state_dict()
        self.merge_config = merge_config

    def randbin(self, M, N, P):
        """
        Return an ``(M, N)`` float32 tensor sampled with ``.bernoulli(1 - P)``.

        NOTE(review): per the torch docs ``Tensor.bernoulli(p)`` samples with
        probability ``p`` regardless of the tensor's values, so the
        ``randint`` call would only fix shape/dtype while still consuming RNG
        state - confirm before simplifying, as that would change seeded
        results.
        """
        P = 1-P
        return torch.randint(2, size=(M, N), dtype=torch.float32).bernoulli(P)

    def apply_dare(self, ftms_params, p, dare_seed = 0):
        """
        Randomly drop entries of every parameter and rescale the rest.

        Each value is multiplied by a :meth:`randbin` mask built from its
        first two dimensions and by ``1 / (1 - p)``; values therefore need at
        least two dimensions. The global torch RNG is re-seeded with
        ``dare_seed``.

        Args:
            ftms_params: List of mappings ``name -> tensor``.
            p: Value passed to :meth:`randbin` as ``P``.
            dare_seed: Seed passed to ``torch.manual_seed``.

        Returns:
            List of ``OrderedDict``s (sorted by name), one per input mapping.
        """
        print("DARE seed: ", dare_seed)
        torch.manual_seed(dare_seed)
        finetuned_directions = []
        for ftm_params in ftms_params:
            direction_sd = {}
            for key, finetuned_val in ftm_params.items():
                direction_sd[key] = finetuned_val * self.randbin(finetuned_val.shape[0], finetuned_val.shape[1], p) * (1/(1-p))
            finetuned_directions.append(OrderedDict(sorted(direction_sd.items())))
        return finetuned_directions

    def get_task_directions(self, ptm_params, ftms_params):
        """
        Compute ``fine-tuned - pretrained`` for every fine-tuned parameter.

        Parameters absent from ``ptm_params`` are treated as having a zero
        pretrained value.

        Returns:
            List of ``OrderedDict``s (sorted by name), one per fine-tuned model.
        """
        finetuned_directions = []
        for ftm_params in ftms_params:
            direction_sd = {}
            for key, finetuned_val in ftm_params.items():
                if key in ptm_params:
                    ptm_val = ptm_params[key]
                else:
                    ptm_val = torch.zeros_like(finetuned_val)
                direction_sd[key] = finetuned_val - ptm_val
            finetuned_directions.append(OrderedDict(sorted(direction_sd.items())))
        return finetuned_directions

    def set_scaling_coeffs(self, scaling_coeffs):
        """
        Set the per-model scaling coefficients.

        A scalar (``int`` or ``float``) or a length-1 sequence is repeated
        once per entry of ``self.ftms_params``; any other sequence is
        converted to a tensor as-is. A length-1 sequence is repeated as a
        whole, so the resulting tensor then has shape ``(n_models, 1)``.
        """
        if isinstance(scaling_coeffs, (int, float)) or len(scaling_coeffs) == 1:
            self.scaling_coeffs = torch.tensor([scaling_coeffs] * len(self.ftms_params))
        else:
            self.scaling_coeffs = torch.tensor(scaling_coeffs)

    def get_layer_names(self, state_dict):
        """
        Group state-dict keys by layer.

        Returns:
            A ``defaultdict`` mapping the key with ``.weight``/``_weight``
            (or ``.bias``/``_bias``) stripped to ``{'weight': key}`` and/or
            ``{'bias': key}``. Keys matching neither pattern map to
            ``{'other': key + ':other'}``.
        """
        layer_names = defaultdict(dict)
        for key in state_dict:
            if ('.weight' in key) or ('_weight' in key):
                strip_key = key.replace('.weight', '').replace('_weight', '')
                layer_names[strip_key]['weight'] = key
            elif ('.bias' in key) or ('_bias' in key):
                strip_key = key.replace('.bias', '').replace('_bias', '')
                layer_names[strip_key]['bias'] = key
            else:
                layer_names[key]['other'] = key + ':other'
        return layer_names

    def add_task_parameters(self, base_model, parameters, concat_across_output = True, scaling_coeffs=1.):
        """
        Add ``parameters`` onto ``base_model`` in place.

        Each key has ``".weight"`` rewritten to ``".base_layer.weight"`` and
        the (optionally transposed) value, scaled by ``scaling_coeffs``, is
        added in place to the matching ``state_dict()`` tensor.

        Args:
            base_model: Model to update; it is modified and returned.
            parameters: Mapping ``name -> tensor`` or a list of such mappings
                (applied one after another to the same model).
            concat_across_output: If falsy, values are transposed before
                being added.
            scaling_coeffs: Multiplier applied to every value.

        Returns:
            ``base_model`` (or a list of it, one per mapping).

        Raises:
            Exception: Any lookup or shape error is reported with the
                offending key and re-raised instead of dropping into a
                debugger.
        """
        if isinstance(parameters, list):
            return [self.add_task_parameters(
                base_model,
                parameter,
                concat_across_output=concat_across_output,
                scaling_coeffs=scaling_coeffs
            ) for parameter in parameters]
        sd = base_model.state_dict()
        for key, val in parameters.items():
            new_key = key.replace(".weight", ".base_layer.weight")
            try:
                delta = val if concat_across_output else val.T
                sd[new_key].add_(delta.cpu() * scaling_coeffs)
            except Exception as e:
                print(f"Error adding parameter {key} to base model: {e}")
                raise
        return base_model

    def directions_to_matrices(self, directions, reference_layer_names=None):
        """
        Turn a mapping of parameters into one 2-D matrix per layer.

        Weights are flattened from dim 1 onward; norm-like layers (``'norm'``
        or ``'ln'`` in the layer name) are expanded to a diagonal matrix; a
        bias, if present, is appended as the last column. "Other" parameters
        are cast to float32 and stored under ``layer_name + ':other'``.

        Args:
            directions: Mapping ``name -> tensor`` or a list of mappings.
            reference_layer_names: Optional precomputed result of
                :meth:`get_layer_names`.

        Returns:
            Mapping ``layer name -> matrix`` (or a list of such mappings).
        """
        if isinstance(directions, list):
            return [self.directions_to_matrices(direction, reference_layer_names) for direction in directions]

        if reference_layer_names is None:
            layer_names = self.get_layer_names(directions)
        else:
            layer_names = reference_layer_names

        matrices = {}
        for layer_name, parameter_names in layer_names.items():
            if 'other' in parameter_names:
                other_parameter = directions[parameter_names['other'].replace(':other', '')].to(torch.float32)
                # Ensure parameters are always two dimensional
                if len(other_parameter.shape) == 1: # e.g., class token, positional embeddings
                    other_parameter = other_parameter[None, :]
                elif len(other_parameter.shape) > 2: # e.g., patch embeddings
                    other_parameter = other_parameter.flatten(1)
                matrices[layer_name + ':other'] = other_parameter
            elif 'weight' in parameter_names:
                weight = directions[parameter_names['weight']]
                if 'norm' in layer_name or 'ln' in layer_name:
                    weight = torch.diag(weight)
                matrices[layer_name] = weight.flatten(1)
                if 'bias' in parameter_names:
                    bias = directions[parameter_names['bias']]
                    matrices[layer_name] = torch.concat((matrices[layer_name], bias.reshape(-1, 1)), dim=1)
        return matrices

    def matrix_to_state_dict(self, matrix, state_dict, remove_keys=[]):
        """
        Inverse of :meth:`directions_to_matrices`.

        Args:
            matrix: Mapping ``layer name -> matrix`` or a list of mappings.
            state_dict: Reference mapping providing names and shapes. It is
                not modified.
            remove_keys: Names excluded from the reference layout; they are
                re-added as references to ``"transformer.shared.weight"``
                when that entry is produced. Honoured for list inputs too.

        Returns:
            Mapping ``parameter name -> tensor`` (or a list of mappings).

        Raises:
            RuntimeError: If a matrix cannot be mapped back to the reference
                layout (unknown layer name, incompatible shape, ...).
        """
        if isinstance(matrix, list):
            return [self.matrix_to_state_dict(m, state_dict, remove_keys) for m in matrix]

        # Filter into a new dict so the caller's state_dict keeps its keys.
        excluded = set(remove_keys)
        reference_dict = {key: val for key, val in state_dict.items() if key not in excluded}

        layer_names = self.get_layer_names(reference_dict)
        merged_state_dict = {}
        for layer_name, value in matrix.items():
            try:
                parameter_types = layer_names[layer_name.replace(':other', '')]
                if 'other' in parameter_types:
                    name = parameter_types['other'].replace(':other', '')
                    merged_state_dict[name] = value.reshape(reference_dict[name].shape)
                else:
                    if 'bias' in parameter_types:
                        # The bias was appended as the last column.
                        bias_index = value.shape[1] - 1
                        value, bias = value[:, :bias_index], value[:, -1].flatten()
                        merged_state_dict[parameter_types['bias']] = bias
                    if 'norm' in layer_name or 'ln' in layer_name:
                        value = torch.diagonal(value)
                    name = parameter_types['weight']
                    merged_state_dict[name] = value.reshape(*(reference_dict[name].shape))
            except Exception as err:
                raise RuntimeError(
                    f"Could not map matrix '{layer_name}' back to a state-dict entry"
                ) from err

        # add back the encoder and decoder embedding weights.
        if "transformer.shared.weight" in merged_state_dict:
            for key in remove_keys:
                merged_state_dict[key] = merged_state_dict[
                    "transformer.shared.weight"
                ]
        return merged_state_dict

    def transform(self, *args, **kwargs):
        """No-op hook; subclasses may override it to precompute merge inputs."""
        return

class VectorMerger(TaskMerger):
    """
    Merger that flattens every task direction into one vector, merges the
    vectors with the configured merging function and adds the result onto a
    copy of the pretrained model.
    """

    # Index of the state-dict entry whose sum is printed before/after merging
    # as a sanity check; clamped to the last entry for smaller models.
    _PROBE_INDEX = 52

    def __init__(self, finetuned_models, pretrained_model, param_handler, device=0, merge_config=None, dataloaders=None):
        """See :class:`TaskMerger`; ``dataloaders`` is accepted but unused."""
        super().__init__(
            finetuned_models=finetuned_models,
            pretrained_model=pretrained_model,
            param_handler=param_handler,
            device=device,
            merge_config=merge_config
        )

        self.representation_helper = VectorOps()

    def _flat_sum(self, param_dicts, keys):
        """Sum (in float32, on ``self.device``) of all ``keys`` over all dicts."""
        return torch.stack(
            [torch.cat([params[key].flatten() for key in keys], dim=0) for params in param_dicts],
            dim=0,
        ).to(self.device).to(torch.float32).sum()

    def merge(self, merge_config={'merge_method': 'tv'}):
        """
        Merge the fine-tuned models and return the merged model.

        Side effects: prints diagnostic sums, moves the pretrained model and
        the collected parameters to CPU, and writes the flattened merged
        weights to ``final_model_delta.pt`` in the working directory.

        Args:
            merge_config: Mapping with at least ``'merge_method'``; the whole
                mapping is forwarded to the merging function as keyword
                arguments. Optional ``'dare'``, ``'dare_pruning_coeffs'`` and
                ``'dare_seed'`` enable random dropping of direction entries.
                The default mapping is never mutated.

        Returns:
            A deep copy of the pretrained model with the merged parameters
            added.

        Raises:
            Exception: Failures while collecting the fine-tuned parameters
                are logged (with CUDA memory statistics when CUDA is
                available) and re-raised.
        """
        print(merge_config['merge_method'])
        merging_fn = lambda x: get_merging_fn(merge_config['merge_method'])(
            x, **merge_config, weights=self.scaling_coeffs
        )

        ptm_reference_params = self.param_handler(self.pretrained_model).get_ft_parameters()
        try:
            ftms_relevant_params = [ftm.get_ft_parameters() for ftm in self.ftms_params]
        except Exception as e:
            print(e)
            if torch.cuda.is_available():
                reserved = torch.cuda.memory_reserved()
                allocated = torch.cuda.memory_allocated()
                print(f"Reserved memory: {reserved / 1000000000}GB, Allocated memory: {allocated / 1000000000}GB.")
            # Nothing below can run without the parameters, so propagate.
            raise

        keys = ftms_relevant_params[0].keys()
        print(f"Total sum: {self._flat_sum(ftms_relevant_params, keys)}")
        ftms_task_dirs = self.get_task_directions(ptm_reference_params, ftms_relevant_params)
        print(f"Total sum directions: {self._flat_sum(ftms_task_dirs, keys)}")

        if merge_config.get('dare', False):
            ftms_task_dirs = self.apply_dare(
                ftms_task_dirs, merge_config['dare_pruning_coeffs'], merge_config['dare_seed']
            )

        merged_sd = self.representation_helper(ftms_task_dirs, merging_fn, merge_config)

        # Move everything to CPU and release cached GPU memory before the
        # pretrained model is deep-copied.
        self.pretrained_model = self.pretrained_model.to("cpu")
        for key in ptm_reference_params.keys():
            ptm_reference_params[key] = ptm_reference_params[key].to("cpu")
        torch.cuda.empty_cache()
        for params in ftms_relevant_params:
            for key in params.keys():
                params[key] = params[key].to("cpu")
        torch.cuda.empty_cache()

        merged_base = deepcopy(self.pretrained_model)
        print(f"merged_base model device: {next(iter(merged_base.parameters())).device}")
        if len(merged_sd) == 2:
            # The helper returns (state dict, masks); the masks are not used.
            merged_sd, _ = merged_sd

        keys = list(self.pretrained_model.state_dict().keys())
        probe_key = keys[min(self._PROBE_INDEX, len(keys) - 1)] if keys else None
        if probe_key is not None:
            print(f"Pre  {probe_key}: {self.pretrained_model.state_dict()[probe_key].to(torch.float32).sum()}")
        merged_model = self.add_task_parameters(merged_base, merged_sd)

        delta_keys = list(ftms_relevant_params[0].keys())
        merged_state = merged_model.state_dict()
        final_model_delta = torch.cat([merged_state[key.replace('.weight', '.base_layer.weight')].flatten() for key in delta_keys], dim=0)
        torch.save(final_model_delta, 'final_model_delta.pt')
        print("Saved final_model_delta.pt")
        if probe_key is not None:
            print(f"Post {probe_key}: {self.pretrained_model.state_dict()[probe_key].to(torch.float32).sum()}")

        return merged_model
    
class SensitivityMerger(TaskMerger):
    """
    Merger operating on flattened task directions: the directions are merged
    with the configured merging function and the result is added onto a copy
    of the pretrained model.
    """

    # Index of the state-dict entry used for the before/after sanity print;
    # clamped to the last entry when the model has fewer entries.
    _PROBE_INDEX = 52

    def __init__(self, finetuned_models, pretrained_model, param_handler, device=0, merge_config=None, dataloaders=None):
        """See :class:`TaskMerger`; ``dataloaders`` is accepted but unused."""
        super().__init__(
            finetuned_models=finetuned_models,
            pretrained_model=pretrained_model,
            param_handler=param_handler,
            device=device,
            merge_config=merge_config
        )

        self.representation_helper = VectorOps()

    def _checksum(self, param_dicts, keys):
        """Return the float32 sum of every ``keys`` entry across ``param_dicts``."""
        stacked = torch.stack(
            [torch.cat([params[key].flatten() for key in keys], dim=0) for params in param_dicts],
            dim=0,
        )
        return stacked.to(self.device).to(torch.float32).sum()

    def _print_probe(self, label, probe_key):
        """Print the sum of one pretrained-model entry, if a probe key exists."""
        if probe_key is None:
            return
        value = self.pretrained_model.state_dict()[probe_key].to(torch.float32).sum()
        print(f"{label} {probe_key}: {value}")

    def merge(self, merge_config={'merge_method': 'tv'}):
        """
        Merge the fine-tuned models and return the merged model.

        Side effects: prints diagnostic sums, moves the pretrained model and
        the collected parameters to CPU, and writes the flattened merged
        weights to ``final_model_delta.pt`` in the working directory.

        Args:
            merge_config: Mapping with at least ``'merge_method'``; it is
                forwarded in full to the merging function as keyword
                arguments. ``'dare'``, ``'dare_pruning_coeffs'`` and
                ``'dare_seed'`` optionally enable random dropping. The
                default mapping is never mutated.

        Returns:
            A deep copy of the pretrained model with the merged parameters
            added.

        Raises:
            Exception: Failures while collecting the fine-tuned parameters
                are logged (with CUDA memory statistics when CUDA is
                available) and re-raised.
        """
        print(merge_config['merge_method'])
        merging_fn = lambda x: get_merging_fn(merge_config['merge_method'])(
            x, **merge_config, weights=self.scaling_coeffs
        )

        ptm_reference_params = self.param_handler(self.pretrained_model).get_ft_parameters()
        try:
            ftms_relevant_params = [ftm.get_ft_parameters() for ftm in self.ftms_params]
        except Exception as e:
            print(e)
            if torch.cuda.is_available():
                reserved = torch.cuda.memory_reserved()
                allocated = torch.cuda.memory_allocated()
                print(f"Reserved memory: {reserved / 1000000000}GB, Allocated memory: {allocated / 1000000000}GB.")
            # The merge cannot continue without these parameters.
            raise

        keys = ftms_relevant_params[0].keys()
        print(f"Total sum: {self._checksum(ftms_relevant_params, keys)}")
        ftms_task_dirs = self.get_task_directions(ptm_reference_params, ftms_relevant_params)
        print(f"Total sum directions: {self._checksum(ftms_task_dirs, keys)}")

        if merge_config.get('dare', False):
            ftms_task_dirs = self.apply_dare(
                ftms_task_dirs, merge_config['dare_pruning_coeffs'], merge_config['dare_seed']
            )

        merged_sd = self.representation_helper(ftms_task_dirs, merging_fn, merge_config)

        # Free GPU memory before duplicating the pretrained model.
        self.pretrained_model = self.pretrained_model.to("cpu")
        for key in ptm_reference_params.keys():
            ptm_reference_params[key] = ptm_reference_params[key].to("cpu")
        torch.cuda.empty_cache()
        for params in ftms_relevant_params:
            for key in params.keys():
                params[key] = params[key].to("cpu")
        torch.cuda.empty_cache()

        merged_base = deepcopy(self.pretrained_model)
        print(f"merged_base model device: {next(iter(merged_base.parameters())).device}")
        if len(merged_sd) == 2:
            # (state dict, masks) - only the state dict is needed here.
            merged_sd, _ = merged_sd

        keys = list(self.pretrained_model.state_dict().keys())
        probe_key = keys[min(self._PROBE_INDEX, len(keys) - 1)] if keys else None
        self._print_probe("Pre ", probe_key)
        merged_model = self.add_task_parameters(merged_base, merged_sd)

        delta_keys = list(ftms_relevant_params[0].keys())
        merged_state = merged_model.state_dict()
        final_model_delta = torch.cat([merged_state[key.replace('.weight', '.base_layer.weight')].flatten() for key in delta_keys], dim=0)
        torch.save(final_model_delta, 'final_model_delta.pt')
        print("Saved final_model_delta.pt")
        self._print_probe("Post", probe_key)

        return merged_model

class SVDMerger(TaskMerger):
    """
    Merger that aligns task directions in a shared basis.

    :meth:`transform` concatenates the per-layer direction matrices of all
    models, takes an SVD of each concatenation and stores the shared ``U``
    together with each model's scaled slice of the right factor.
    :meth:`merge` then masks and merges those slices and reconstructs the
    merged parameters as ``U @ merged``.
    """

    def __init__(self, finetuned_models, pretrained_model, param_handler, device=0, merge_config=None, dataloaders=None):
        """See :class:`TaskMerger`; ``dataloaders`` is accepted but unused."""
        super().__init__(
            finetuned_models=finetuned_models,
            pretrained_model=pretrained_model,
            param_handler=param_handler,
            device=device,
            merge_config=merge_config
        )

        self.layer_names = self.get_layer_names(self.ftms_params[0].get_ft_parameters())
        self.representation_helper = VectorOps()
        # Filled by transform(); consumed by merge() unless a path is given.
        self.ingredients = None

    def variable_extend_dim(self, elements, op_dim):
        """Append trailing singleton dims until a tensor has ``op_dim + 1`` dims."""
        if isinstance(elements, list):
            return [self.variable_extend_dim(element, op_dim) for element in elements]
        while len(elements.shape) < (op_dim+1):
            elements = elements.unsqueeze(-1)
        return elements

    def dict_of_concat_matrices(self, list_of_dictmatrices, dim=0, concat_across_output = True):
        """
        Concatenate same-named matrices of several mappings along ``dim``.

        Values are moved to ``self.device``; they are transposed first unless
        ``concat_across_output == True``.

        Returns:
            A ``defaultdict`` mapping each key to the concatenated tensor.
        """
        dict2matrix_stack = defaultdict(list)
        for dict2matrix in list_of_dictmatrices:
            for key, val in dict2matrix.items():
                if(concat_across_output == True):
                    dict2matrix_stack[key].append(val.to(self.device))
                else:
                    dict2matrix_stack[key].append(val.T.to(self.device))

        for key, list_of_vals in dict2matrix_stack.items():
            # Extend dim as necessary
            list_of_vals = self.variable_extend_dim(list_of_vals, op_dim=dim)
            dict2matrix_stack[key] = torch.concat(list_of_vals, dim=dim)
        return dict2matrix_stack

    def reconstruct_merged_sd(self, U_sd, sV_sd):
        """
        Compute ``U @ sV`` per key, as float32.

        ``sV_sd`` may be a list (one result per entry); ``U_sd`` may then be
        a matching list or a single mapping shared by all entries.
        """
        if isinstance(sV_sd, list):
            if isinstance(U_sd, list):
                return [self.reconstruct_merged_sd(U, sV) for U, sV in zip(U_sd, sV_sd)]
            return [self.reconstruct_merged_sd(U_sd, sV) for sV in sV_sd]
        return {key: (U @ sV_sd[key]).to(torch.float32) for key, U in U_sd.items()}

    def apply_svd(self, ft_params, concat_across_output = True):
        """
        SVD every layer's matrices concatenated across models (along dim 1).

        Components with singular value ``<= 1e-5`` are discarded. The right
        factor is split into one equal-width slice per model and each slice
        is pre-multiplied by ``diag(s)``.

        Args:
            ft_params: List of mappings ``layer name -> matrix``.
            concat_across_output: Forwarded to :meth:`dict_of_concat_matrices`.

        Returns:
            ``(basis_dict, s_compositions_dict, V_compositions_dict, UsV_dict)``:
            the shared ``U`` per layer (CPU); per model, a vector of ones
            with one entry per kept component (CPU); per model, the scaled
            slice (CPU); and the truncated ``U``/``s``/``V`` per layer.
        """
        UsV_dict = {}
        basis_dict = {} # basis for reconstruction
        s_compositions_dict = [dict() for _ in range(len(ft_params))]
        V_compositions_dict = [dict() for _ in range(len(ft_params))] # basis composition information per task

        print(f'Calculating SVD over {len(ft_params)} models. S > 1e-5')
        concated_ft_params = self.dict_of_concat_matrices(ft_params, dim=1, concat_across_output = concat_across_output)
        for key, val in tqdm(concated_ft_params.items(), desc='Obtaining SVDs...'):
            U, s, V = torch.linalg.svd(val.to(torch.float64), full_matrices=False)
            # Keep only supported basis components
            supported = s > 1e-5
            U = U[:, supported].type(torch.float32)
            V = V[supported].type(torch.float32)
            s = s[supported].type(torch.float32)
            UsV_dict[key] = {'U': U, 's': s, 'V': V}
            cat_hidden_dim = V.shape[1] // len(ft_params)

            basis_dict[key] = U.cpu()
            for idx, V_task in enumerate(torch.split(V, cat_hidden_dim, dim=1)):
                # Fold the singular values into the slice, so the separately
                # stored per-task scale is all ones.
                scaled_V = torch.diag(s) @ V_task # Simple and safe for all merging methods we use.
                s_model = s / s

                s_compositions_dict[idx][key] = s_model.cpu()
                V_compositions_dict[idx][key] = scaled_V.cpu()
        return basis_dict, s_compositions_dict, V_compositions_dict, UsV_dict

    def apply_Ss_on_Vs(self, task_Vs, task_Ss):
        """
        Left-multiply each ``V`` by its scale, per task and key.

        A 2-D scale is used as a matrix; otherwise it is expanded with
        ``torch.diag`` first.
        """
        task_sVs = [dict() for _ in range(len(task_Vs))]
        for idx, (Vs, Ss) in enumerate(zip(task_Vs, task_Ss)):
            for key, V in Vs.items():
                if len(Ss[key].shape) == 2:
                    task_sVs[idx][key] = Ss[key] @ V
                else:
                    task_sVs[idx][key] = torch.diag(Ss[key]) @ V
        return task_sVs

    def remove_others(self, ftms_mats):
        """
        Split each mapping into entries excluded from the SVD and the rest.

        Keys containing ``':other'`` or ``'modules_to_save'`` are excluded.

        Returns:
            ``(other_mats, transform_mats)``, two lists aligned with the input.
        """
        other_mats = [dict() for _ in range(len(ftms_mats))]
        transform_mats = [dict() for _ in range(len(ftms_mats))]

        for m_idx, ftm_mats in enumerate(ftms_mats):
            for key, val in ftm_mats.items():
                if ':other' in key or 'modules_to_save' in key:
                    other_mats[m_idx][key] = val
                else:
                    transform_mats[m_idx][key] = val
        print(f'Len other: {len(other_mats[0])}| len: transform: {len(transform_mats[0])}')
        return other_mats, transform_mats

    def add_others(self, ftms_mats, ftms_others):
        """Copy every entry of ``ftms_others`` into ``ftms_mats`` (in place) and return it."""
        if isinstance(ftms_mats, list):
            return [self.add_others(ftms_mat, ftms_other) for ftms_mat, ftms_other in zip(ftms_mats, ftms_others)]

        for key, val in ftms_others.items():
            ftms_mats[key] = val
        return ftms_mats

    def transform(self, merge_config):
        """
        Precompute the SVD ingredients used by :meth:`merge`.

        The result is stored in ``self.ingredients`` and, when
        ``merge_config['ingredients_path']`` is set, also saved there with
        ``torch.save``.
        """
        # Setup parameters
        ptm_reference_params = self.param_handler(self.pretrained_model).get_ft_parameters()
        ftms_relevant_params = [ftm.get_ft_parameters() for ftm in self.ftms_params]
        ftms_task_dirs = self.get_task_directions(ptm_reference_params, ftms_relevant_params)

        ftms_task_mats = self.directions_to_matrices(ftms_task_dirs)
        ftms_others, ftms_mats = self.remove_others(ftms_task_mats)

        U, task_Ss, task_sVs, UsV_dict = self.apply_svd(
            ftms_mats,
            concat_across_output = merge_config.get('concat_across_output', True),
        )

        self.ingredients = {
            'ftms_relevant_params': ftms_relevant_params,
            'ftms_others': ftms_others,
            'ptm_reference_params': ptm_reference_params,
            'U': U,
            'task_Ss': task_Ss,
            'task_sVs': task_sVs,
            'UsV_dict': UsV_dict,
        }

        if merge_config.get('ingredients_path') is not None:
            torch.save(self.ingredients, merge_config['ingredients_path'])

    def merge(self, merge_config):
        """
        Merge the models from previously computed SVD ingredients.

        Args:
            merge_config: Mapping read for ``'merge_method'`` (mask and
                merging function), ``'merging_type'``, and optionally
                ``'ingredients_path'``, ``'dare'`` (+ ``'dare_pruning_coeffs'``,
                ``'dare_seed'``), ``'merge_other_params'`` and
                ``'concat_across_output'``.

        Returns:
            ``self.pretrained_model`` with the merged parameters added.
            NOTE(review): the pretrained model is updated in place (no copy
            is made here), so calling ``merge`` repeatedly on one instance
            accumulates deltas - confirm this is intended.

        Raises:
            RuntimeError: If no ingredients are available, i.e. no
                ``'ingredients_path'`` was given and :meth:`transform` has
                not been run.
        """
        if merge_config.get('ingredients_path') is not None:
            ingredients = torch.load(merge_config['ingredients_path'])
        else:
            ingredients = self.ingredients
        if ingredients is None:
            raise RuntimeError(
                "No SVD ingredients available: call transform() before merge() "
                "or provide 'ingredients_path' in merge_config."
            )

        ftms_others = ingredients['ftms_others']
        ptm_reference_params = ingredients['ptm_reference_params']
        U = ingredients['U']
        task_Ss = ingredients['task_Ss']
        task_sVs = ingredients['task_sVs']

        if merge_config.get('dare', False):
            print("Applying DARE")
            task_sVs = self.apply_dare(
                task_sVs, merge_config['dare_pruning_coeffs'], merge_config['dare_seed']
            )

        ftms_reps = self.representation_helper.directions_to_reps(task_sVs)

        mask_fn = get_mask_fn(merge_config['merge_method'])
        masks = mask_fn(ftms_reps, **merge_config)
        ftms_reps = torch.vstack(ftms_reps).clone()
        masked_sVs = ftms_reps * masks
        pre_merge_sVs_dict = self.representation_helper.rep_to_state_dict(masked_sVs, task_sVs[0])
        rescaled_Vs = self.apply_Ss_on_Vs(pre_merge_sVs_dict, task_Ss)

        rescaled_Vs = torch.stack(self.representation_helper.directions_to_reps(rescaled_Vs), dim=0)
        merged_sV_ = masked_merge(
            merge_func=merge_config.get('merging_type'), vectors=rescaled_Vs, weights=self.scaling_coeffs
        )
        merged_sV_sd = self.representation_helper.rep_to_state_dict(merged_sV_, task_sVs[0])

        merged_sd = self.reconstruct_merged_sd(U, merged_sV_sd)

        if merge_config.get('merge_other_params', False):
            merging_fn = lambda x: get_merging_fn(merge_config['merge_method'])(
                x, **merge_config, weights=self.scaling_coeffs
            )
            # The helper's forward takes the merge config as its third
            # argument; omitting it made this call raise TypeError.
            merged_others, _ = self.representation_helper(ftms_others, merging_fn, merge_config)
            merged_sd = self.add_others(merged_sd, merged_others)

        merged_sd = self.matrix_to_state_dict(merged_sd, ptm_reference_params)
        # Add merged sd to the ptm
        merged_base = self.pretrained_model
        merged_model = self.add_task_parameters(merged_base, merged_sd,  concat_across_output = merge_config.get('concat_across_output', True))
        return merged_model


class RegMeanMerger(TaskMerger):
    """
    Merger that weights each model's parameters by the Gram matrix of the
    inputs seen by the corresponding ``base_layer`` module.

    For every ``base_layer.weight`` the update added to the pretrained weight
    is ``(sum_i W_i @ G_i) @ pinv(sum_i G_i)``, where ``G_i`` is model ``i``'s
    accumulated input Gram matrix with its off-diagonal scaled by
    ``alpha_regmean``.
    """

    # Accepted spellings for the ``gram_dtype`` / ``inv_gram_dtype`` options.
    _DTYPE_BY_NAME = {
        "torch.float16": torch.float16,
        "torch.bfloat16": torch.bfloat16,
        "torch.float32": torch.float32,
        "torch.float64": torch.float64,
    }

    def __init__(self, finetuned_models, pretrained_model, param_handler, device=0, merge_config=None, dataloaders=None):
        """
        Args:
            finetuned_models, pretrained_model, param_handler, device,
            merge_config: See :class:`TaskMerger`. ``merge_config`` may be
                ``None``; its ``"normalize_grams"`` entry is stored as a bool.
            dataloaders: Indexable by model position; ``merge`` reads
                ``dataloaders[t]["train"]["full"]``.
        """
        super().__init__(
            finetuned_models=finetuned_models,
            pretrained_model=pretrained_model,
            param_handler=param_handler,
            device=device,
            merge_config=merge_config
        )

        self.representation_helper = VectorOps()
        self.dataloaders = dataloaders
        # NOTE(review): this flag is only stored and printed; merge() does not
        # consult it - confirm whether normalisation should be conditional.
        self.normalize_grams = bool((merge_config or {}).get("normalize_grams"))
        print(f"Normalizing grams? {self.normalize_grams}")

    @staticmethod
    def _resolve_dtype(dtype_name, option):
        """Map a dtype name such as ``"torch.float32"`` to the torch dtype."""
        dtype = RegMeanMerger._DTYPE_BY_NAME.get(dtype_name)
        if dtype is None:
            # ArgumentError needs (argument, message); raising the bare class
            # would fail with a TypeError instead.
            raise ArgumentError(None, f"Unsupported {option}: {dtype_name!r}")
        return dtype

    @staticmethod
    def _target_sample_count(merge_config, loader):
        """Number of samples to feed through a model, per ``grams_data_measure``."""
        measure = merge_config["grams_data_measure"]
        if measure == "quantity":
            return merge_config["grams_data_quantity"]
        if measure == "percentage":
            return int(round(merge_config["grams_data_perc"] * len(loader.dataset)))
        raise ValueError(f"Unsupported grams_data_measure: {measure!r}")

    def hook_handler(self, name):
        """
        Build a forward hook that accumulates ``x.T @ x`` of the module's
        first input into ``self.grams[-1][name]``.

        3-D inputs are flattened to 2-D over their leading dimensions first.
        """
        def hook_forward(module, inputs, _):
            x = inputs[0].detach().to(self.gram_dtype)
            if len(x.shape) == 3:
                x = x.view(-1, x.size(-1))
            tmp = torch.zeros(x.size(-1), x.size(-1), device=self.device, dtype=self.gram_dtype)
            torch.matmul(x.T, x, out=tmp)
            self.grams[-1][name] = self.grams[-1][name].to(self.device)
            # An empty tensor marks "nothing accumulated yet".
            if len(self.grams[-1][name]) == 0:
                self.grams[-1][name] = tmp
            else:
                self.grams[-1][name] += tmp

        return hook_forward

    def merge(self, merge_config={'merge_method': 'regmean'}):
        """
        Compute input Gram matrices for every model and return the merged model.

        Args:
            merge_config: Mapping read for ``'merge_method'``, ``'gram_dtype'``,
                ``'inv_gram_dtype'``, ``'grams_data_measure'`` (with
                ``'grams_data_quantity'`` or ``'grams_data_perc'``) and
                ``'alpha_regmean'``.

        Returns:
            A deep copy of the pretrained model whose ``base_layer.weight``
            tensors have the merged update added.

        Raises:
            argparse.ArgumentError: For an unsupported dtype name.
            ValueError: For an unsupported ``grams_data_measure``.
        """
        print(merge_config['merge_method'])
        self.gram_dtype = self._resolve_dtype(merge_config["gram_dtype"], "gram_dtype")

        ftms_relevant_params = [ftm.get_ft_parameters() for ftm in self.ftms_params]

        self.grams = []
        gram_modules = [n for n, m in self.pretrained_model.named_modules() if "base_layer" in n]
        gram_module_set = set(gram_modules)
        VISION = getattr(self.pretrained_model, 'vision_model', None) is not None

        # Compute input grams for each model
        all_real_tokens = []
        for t, ftm in enumerate(ftms_relevant_params):
            self.grams.append({key: torch.tensor([], dtype=self.gram_dtype) for key in gram_modules})
            cur_model = deepcopy(self.pretrained_model)
            cur_state_dict = {
                k.replace(".weight", ".base_layer.weight"): v for k, v in ftm.items()
            }
            cur_model.load_state_dict(cur_state_dict, strict=False)
            cur_model = cur_model.to(self.device)
            cur_train_loader = self.dataloaders[t]["train"]["full"]
            hooks = {name: None for name in gram_modules}

            for name, module in cur_model.named_modules():
                if name in gram_module_set:
                    hooks[name] = module.register_forward_hook(self.hook_handler(name))

            num_samples = self._target_sample_count(merge_config, cur_train_loader)
            done = False
            tot_samples = 0
            real_tokens = 0
            with torch.no_grad():
                if VISION:
                    for x, y in tqdm(cur_train_loader, desc="Computing Gram matrices"):
                        if not done:
                            print(f"x.shape: {x.shape}")
                            done = True
                        if tot_samples >= num_samples:
                            break
                        x, y = x.to(self.device), y.to(self.device)
                        cur_model(x)
                        tot_samples += y.shape[0]
                else:
                    for stuff in tqdm(cur_train_loader, desc="Computing Gram matrices"):
                        # Read the batch by key; unpacking ``stuff.values()``
                        # positionally depended on the dict's key order.
                        input_ids = stuff["input_ids"]
                        attention_masks = stuff["attention_mask"]
                        y = stuff["labels"]
                        if not done:
                            print(f"y.shape: {y.shape}")
                            done = True
                        if tot_samples >= num_samples:
                            break
                        input_ids, attention_masks = input_ids.to(self.device), attention_masks.to(self.device)
                        cur_model(input_ids, attention_masks)
                        tot_samples += y.shape[0]
                        real_tokens += (attention_masks != 0).sum().item()
            all_real_tokens.append(real_tokens)

            alpha = merge_config['alpha_regmean']
            for name, module in cur_model.named_modules():
                if name in gram_module_set:
                    cur_feats = self.grams[-1][name]
                    eye = torch.eye(cur_feats.shape[-1], dtype=self.gram_dtype, device=self.device)
                    # Keep the diagonal, scale every off-diagonal entry by alpha.
                    self.grams[-1][name] *= eye + (1 - eye) * alpha
                    self.grams[-1][name] = self.grams[-1][name].cpu()
                    hooks[name].remove()

        cur_model = None
        del cur_model

        # Regmean merge
        self.inv_gram_dtype = self._resolve_dtype(merge_config["inv_gram_dtype"], "inv_gram_dtype")

        regmean_keys = [n for n in self.pretrained_model.state_dict().keys() if "base_layer.weight" in n]
        merged_base = deepcopy(self.pretrained_model)
        sd = merged_base.state_dict()
        for key in tqdm(regmean_keys, desc="Computing new modules"):
            ft_key = key.replace("base_layer.", "")
            gram_key = key.replace(".weight", "")
            for num_model, token_count in enumerate(all_real_tokens):
                # No tokens are counted on the vision path; dividing by zero
                # there would turn the Gram matrix into inf/NaN.
                if token_count > 0:
                    self.grams[num_model][gram_key] /= token_count
            sd[key] += (
                torch.stack(
                    [
                        ftm[ft_key].to(self.inv_gram_dtype) @ gram[gram_key].to(self.inv_gram_dtype)
                        for ftm, gram in zip(ftms_relevant_params, self.grams)
                    ]
                ).sum(0)
                @ torch.pinverse(torch.stack([gram[gram_key].to(self.inv_gram_dtype) for gram in self.grams]).sum(0))
            ).to(torch.float32)

        merged_base.load_state_dict(sd)
        return merged_base


class CoMMerger(TaskMerger):
    def __init__(self, finetuned_models, pretrained_model, param_handler, device=0, merge_config=None, dataloaders=None):
        """Set up the merger and choose the matrix-inversion backend.

        Args:
            finetuned_models: Forwarded unchanged to the parent constructor.
            pretrained_model: Forwarded unchanged to the parent constructor.
            param_handler: Forwarded unchanged to the parent constructor.
            device: Forwarded unchanged to the parent constructor.
            merge_config: Optional mapping of options, forwarded unchanged to
                the parent constructor. Keys read here: ``normalize_grams``
                (default ``True``), ``use_real_inverse`` (default ``False``)
                and ``use_cupy`` (default ``False``). ``None`` is treated as
                an empty mapping for these lookups.
            dataloaders: Stored as ``self.dataloaders``.
        """
        super().__init__(
            finetuned_models=finetuned_models,
            pretrained_model=pretrained_model,
            param_handler=param_handler,
            device=device,
            merge_config=merge_config
        )
        # The signature allows merge_config=None, so the lookups below must not
        # assume a mapping is present.
        options = merge_config if merge_config is not None else {}

        self.normalize_grams = options.get("normalize_grams", True)
        print(f"Normalizing grams? {self.normalize_grams}")

        self.use_real_inverse = options.get("use_real_inverse", False)
        self.use_cupy = options.get("use_cupy", False)

        # Exact inverse vs. pseudo-inverse, with either the PyTorch or the CuPy
        # implementation.
        if self.use_cupy:
            import cupy as cp
            self.inverse_function = cp.linalg.inv if self.use_real_inverse else cp.linalg.pinv
        else:
            self.inverse_function = torch.inverse if self.use_real_inverse else torch.pinverse

        # Conversion helpers between the two array libraries; both are identity
        # functions when CuPy is disabled. `to_torch` must return the converted
        # array (it previously returned the `torch.tensor` function itself).
        self.to_torch = (lambda x: torch.as_tensor(x)) if self.use_cupy else (lambda x: x)
        self.to_cupy = (lambda x: cp.array(x)) if self.use_cupy else (lambda x: x)

        if self.use_real_inverse:
            print("Using real inverse matrix", end=" ")
        else:
            print("Using (pseudo) inverse matrix", end=" ")
        if self.use_cupy:
            print("with CuPy.")
        else:
            print("with PyTorch.")
        self.representation_helper = VectorOps()
        self.dataloaders = dataloaders

    def hook_handler(self, name, VISION=True, diagonal=False):
        """Build a forward hook that accumulates Gram statistics of a module's input.

        The returned hook reads the first positional input of the hooked
        module, L2-normalises every row and adds either ``x.T @ x`` or, when
        ``diagonal`` is true, the per-feature sum of squares to
        ``self.grams[-1]``.

        Args:
            name: Not used by this method.
            VISION: When false and the hooked module received a tensor as its
                second positional input, that tensor is read as per-sample
                padding counts: the last ``pad`` positions of each sample are
                dropped and the number of kept positions is added to
                ``self.num_tokens[-1]``.
            diagonal: Accumulate only the diagonal statistic.

        Returns:
            A callable with the ``(module, inputs, output)`` forward-hook
            signature.
        """
        def hook_forward(module, inputs, _):
            x = inputs[0].detach().to(self.gram_dtype)

            # One-shot debug print. The payload is only built while it can still
            # be printed, and a second input without a usable length (an int, a
            # 0-dim tensor, None) is reported as "no len" instead of raising.
            if not StaticPrinter.printed:
                try:
                    second_input_len = len(inputs[1]) if len(inputs) > 1 else "no len"
                except TypeError:
                    second_input_len = "no len"
                StaticPrinter.static_print(**{"VISION" : VISION, "diagonal" : diagonal, "len(inputs)" : len(inputs), "len(inputs[1])" : second_input_len})

            if not VISION and len(inputs) > 1 and isinstance(inputs[1], torch.Tensor):
                pad_tokens = inputs[1]  # number of pad tokens for each sample
                real_tokens = x.shape[1] - pad_tokens
                feature_dim = x.shape[2]
                self.num_tokens[-1] += real_tokens.sum().item()

                # Keep only the leading real tokens of every sample and stack
                # them into a single (tokens, features) matrix.
                x = torch.cat([x[i, :real_tokens[i]].reshape(-1, feature_dim) for i in range(len(real_tokens))], dim=0)

            if len(x.shape) == 3:
                x = x.view(-1, x.size(-1))
            # NOTE(review): a row that is entirely zero has norm 0 and yields
            # NaN here; confirm such rows cannot reach this hook.
            x = x / x.norm(dim=-1, keepdim=True)
            if diagonal:
                tmp = (x ** 2).sum(dim=0)
            else:
                tmp = torch.zeros(x.size(-1), x.size(-1), device=self.device, dtype=self.gram_dtype)
                torch.matmul(x.T, x, out=tmp)

            # The accumulator starts out empty; the first contribution replaces
            # it and later ones are added in place.
            self.grams[-1] = self.grams[-1].to(self.device)
            if len(self.grams[-1]) == 0:
                self.grams[-1] = tmp
            else:
                self.grams[-1] += tmp

        return hook_forward

    @staticmethod
    def forward_self_attn(self_attn, x, proj_idx):
        """Run one of the two stages of an attention module.

        Stage ``proj_idx == 0`` applies the query/key/value projections and the
        softmax attention, returning the per-head results merged back to the
        input's three-dimensional layout. Any other ``proj_idx`` applies only
        ``self_attn.out_proj`` to ``x``.
        """
        if proj_idx != 0:
            return self_attn.out_proj(x)

        batch, seq_len, width = x.size()
        heads = self_attn.num_heads
        head_dim = self_attn.head_dim

        def split_heads(projection):
            # Project the input, then fold the heads into the batch dimension.
            return self_attn._shape(projection(x), -1, batch).view(batch * heads, -1, head_dim)

        queries = split_heads(self_attn.q_proj)
        keys = split_heads(self_attn.k_proj)
        values = split_heads(self_attn.v_proj)

        scores = torch.matmul(queries, keys.transpose(1, 2)) * self_attn.scale
        weights = F.softmax(scores, dim=-1)
        weights = F.dropout(weights, p=self_attn.dropout, training=self_attn.training)

        context = torch.matmul(weights, values)
        context = context.view(batch, heads, seq_len, head_dim)
        return context.transpose(1, 2).reshape(batch, seq_len, width)

    @staticmethod
    def rotate_half(x):
        """Swap the two halves of the last dimension, negating the second half."""
        midpoint = x.shape[-1] // 2
        front, back = x[..., :midpoint], x[..., midpoint:]
        return torch.cat((-back, front), dim=-1)

    @staticmethod
    def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        """Applies Rotary Position Embedding to the query and key tensors.

        Args:
            q (`torch.Tensor`): The query tensor.
            k (`torch.Tensor`): The key tensor.
            cos (`torch.Tensor`): The cosine part of the rotary embedding.
            sin (`torch.Tensor`): The sine part of the rotary embedding.
            position_ids (`torch.Tensor`, *optional*):
                Deprecated and unused.
            unsqueeze_dim (`int`, *optional*, defaults to 1):
                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
        Returns:
            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
        """
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (CoMMerger.rotate_half(q) * sin)
        k_embed = (k * cos) + (CoMMerger.rotate_half(k) * sin)
        return q_embed, k_embed

    @staticmethod
    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        """
        Repeat every key/value head ``n_rep`` times along dimension 1, like
        torch.repeat_interleave(x, dim=1, repeats=n_rep): (batch,
        num_key_value_heads, seqlen, head_dim) becomes (batch,
        num_key_value_heads * n_rep, seqlen, head_dim). With ``n_rep == 1`` the
        input tensor itself is returned.
        """
        bsz, kv_heads, seq_len, dim = hidden_states.shape
        if n_rep != 1:
            tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
            return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
        return hidden_states

    
    def forward_self_attn_llama(
        self_attn,
        #hidden_states: torch.Tensor,
        #position_embeddings = tuple[torch.Tensor, torch.Tensor],
        #attention_mask: Optional[torch.Tensor] = None,
        #past_key_value = None,
        #cache_position = None,
        **kwargs,
    ):
        do_out_proj = kwargs.get('do_out_proj', True)
        pad_tokens = kwargs.get('pad_tokens', 0)
        hidden_states = kwargs.get('hidden_states')
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self_attn.head_dim)
        past_key_value = kwargs.get('past_key_value', None)
        if not do_out_proj:
            position_embeddings = kwargs.get('position_embeddings', None)
            attention_mask = kwargs.get('attention_mask', None)
            kwargs.pop('attention_mask')
            cache_position = kwargs.get('cache_position', None)

            query_states = self_attn.q_proj(hidden_states, pad_tokens).view(hidden_shape).transpose(1, 2)
            key_states = self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            value_states = self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)


            #TODO: controllare se è corretto che si triggeri l'if
            if position_embeddings is None:
                cos, sin = self_attn.rotary_emb(value_states, kwargs['position_ids'])
            else:
                cos, sin = position_embeddings
            query_states, key_states = CoMMerger.apply_rotary_pos_emb(query_states, key_states, cos, sin)

            if past_key_value is not None:
                # sin and cos are specific to RoPE models; cache_position needed for the static cache
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self_attn.layer_idx, cache_kwargs)

            attn_output, _ = sdpa_attention_forward(
                self_attn,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=0.0 if not self_attn.training else self_attn.attention_dropout,
                scaling=self_attn.scaling,
                **kwargs,
            )
        else: #do out proj only
            #attn_output = kwargs.get('attn_output', None)
            attn_output = hidden_states
            assert attn_output is not None
            input_shape = input_shape[:-1]
            attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            attn_output = self_attn.o_proj(attn_output, pad_tokens)
        return attn_output, past_key_value


    @staticmethod
    def forward_block(block, x, residual, proj_idx):
        if proj_idx == 0:
            residual = x
            x = block.layer_norm1(x)
        x = CoMMerger.forward_self_attn(block.self_attn, x, proj_idx=proj_idx)
        if proj_idx == 1:
            x = residual + x
            residual = x
            x = block.layer_norm2(x)
            x = block.mlp(x)
            x = residual + x
        return x, residual

    @staticmethod
    def forward_layer_llama(layer, mlp_forward = False, hidden_states=None, attention_mask=None, residuals=None, position_ids=None, past_key_values=None, 
                            output_attentions=None, use_cache=None, cache_position=None, position_embeddings=None, pad_tokens=0):
        ret = {}
        #print(f"\nmlp_forward = {mlp_forward}")
        if not mlp_forward:
            residual = hidden_states
            hidden_states = layer.input_layernorm(hidden_states)

        forward_self_attn_kwargs = {
            "self_attn": layer.self_attn,
            "hidden_states": hidden_states,
            "position_embeddings": position_embeddings,
            "attention_mask": attention_mask,
            #"position_ids": position_ids,
            "past_key_value": past_key_values,
            #"output_attentions": output_attentions,
            #"use_cache": use_cache,
            "cache_position": cache_position,
            "pad_tokens": pad_tokens,
            "do_out_proj": mlp_forward,  # if True, only do out_proj
        }

        hidden_states, past_key_values = CoMMerger.forward_self_attn_llama(
            **forward_self_attn_kwargs
        )

        if mlp_forward:
            residual = residuals.to(hidden_states.dtype)
            hidden_states = residual + hidden_states
            # Fully Connected
            residual = hidden_states
            hidden_states = layer.post_attention_layernorm(hidden_states)
            hidden_states = layer.mlp(hidden_states)
            hidden_states = residual + hidden_states
        ret['hidden_states'] = hidden_states
        ret['attention_mask'] = attention_mask
        ret['position_ids'] = position_ids
        ret['past_key_values'] = past_key_values
        ret['output_attentions'] = output_attentions
        ret['use_cache'] = use_cache
        ret['cache_position'] = cache_position
        ret['position_embeddings'] = position_embeddings
        ret['residuals'] = residual
        return ret

    @staticmethod
    def clip_forward(model, stuff, start_layer_idx, end_layer_idx):
        """Run the vision embedding stage and/or a range of encoder half-blocks.

        Encoder block ``i`` is addressed as the two consecutive indices
        ``2 * i`` and ``2 * i + 1``, forwarded to ``forward_block`` as
        ``proj_idx`` 0 and 1. ``start_layer_idx == -1`` additionally runs the
        embeddings followed by ``pre_layrnorm``.

        Args:
            model: Object exposing ``device`` and ``vision_model.model``.
            stuff: Dict with the tensors ``"inputs"`` and ``"residual"``.
            start_layer_idx: First half-block index, or -1 for the embeddings.
            end_layer_idx: Exclusive upper bound of the half-block indices.

        Returns:
            ``{"inputs": ..., "residual": ...}`` captured after the embedding
            stage and overwritten after the half-block with index
            ``end_layer_idx - 2`` when that half-block runs; ``None`` when
            neither happens.
        """
        target = model.device
        x = stuff["inputs"].to(target)
        residual = stuff["residual"].to(target)
        if hasattr(x, "pixel_values"):
            x = x.pixel_values.squeeze(1)

        vision_model = model.vision_model.model
        captured = None
        if start_layer_idx == -1:
            x = vision_model.pre_layrnorm(vision_model.embeddings(x))
            captured = {"inputs": x, "residual": residual}

        for half_idx in range(max(start_layer_idx, 0), end_layer_idx):
            block = vision_model.encoder.layers[half_idx // 2]
            x, residual = CoMMerger.forward_block(block, x, residual, proj_idx=half_idx % 2)
            if half_idx + 2 == end_layer_idx:
                captured = {"inputs": x, "residual": residual}
        return captured

    @staticmethod
    def llama_block_forward(model, input_ids, attention_mask, residuals, position_ids, past_key_values, use_cache, inputs_embeds = None,
                            output_attentions = None, output_hidden_states = None, return_dict = None, cache_position = None, position_embeddings = None, start_layer_idx=0,
                            end_layer_idx=-1, pad_tokens=0, device="cuda", return_device="cpu"):
        """Run the embedding stage and/or a range of decoder half-layers.

        Decoder layer ``i`` (``model.model.model.layers[i]``) is addressed as
        two consecutive indices: ``2 * i`` calls ``forward_layer_llama`` with
        ``mlp_forward=0`` and ``2 * i + 1`` with ``mlp_forward=1``.
        ``start_layer_idx == -1`` first runs the embedding stage: token
        embedding, mask construction and rotary embeddings.

        Args:
            model: Wrapper exposing ``model.model.model`` (``embed_tokens``,
                ``rotary_emb``, ``layers``, ``config``) and
                ``model.model.config``.
            input_ids: Token ids for the embedding stage; for later stages the
                hidden states produced by a previous call.
            attention_mask: Mask given to ``create_causal_mask`` in the
                embedding stage and to the layers otherwise.
            residuals: Residual tensor consumed by odd half-layers.
            position_ids, past_key_values, use_cache, position_embeddings,
            pad_tokens: State forwarded to ``forward_layer_llama``.
            inputs_embeds: Optional precomputed embeddings; exactly one of
                ``input_ids`` and ``inputs_embeds`` must be given in the
                embedding stage.
            output_attentions, output_hidden_states, return_dict: Accepted for
                interface compatibility; not used.
            cache_position: Optional positions used by the embedding stage.
            start_layer_idx: First half-layer index, or -1 for the embeddings.
            end_layer_idx: Exclusive upper bound of the half-layer indices.
            device: Device the tensor inputs are moved to before computing.
            return_device: Device of the returned tensors. Tensors are only
                moved (and the CUDA cache emptied) when this is not a CUDA
                device. Strings, ints and ``torch.device`` are accepted.

        Returns:
            dict with the keys ``hidden_states``, ``attention_mask``,
            ``residuals``, ``position_ids``, ``past_key_values``,
            ``use_cache``, ``position_embeddings`` and ``pad_tokens``. It holds
            the input (or embedding-stage) state, overwritten with the state
            after the half-layer with index ``end_layer_idx - 2`` when that
            half-layer runs.

        Raises:
            ValueError: In the embedding stage, unless exactly one of
                ``input_ids`` and ``inputs_embeds`` is provided.
        """
        ret = {
            'hidden_states': input_ids,
            'attention_mask': attention_mask,
            'residuals': residuals,
            'position_ids': position_ids,
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'position_embeddings': position_embeddings,
            'pad_tokens': pad_tokens,
        }
        for key, value in ret.items():
            if isinstance(value, torch.Tensor):
                ret[key] = value.to(device)

        if start_layer_idx == -1:
            # Compute on the device-resident copies made above rather than on
            # the raw arguments, which may still live on another device.
            input_ids = ret['hidden_states']
            attention_mask = ret['attention_mask']
            position_ids = ret['position_ids']
            if isinstance(inputs_embeds, torch.Tensor):
                inputs_embeds = inputs_embeds.to(device)
            if isinstance(cache_position, torch.Tensor):
                cache_position = cache_position.to(device)

            use_cache = use_cache if use_cache is not None else model.model.config.use_cache

            if (input_ids is None) ^ (inputs_embeds is not None):
                raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

            if inputs_embeds is None:
                inputs_embeds = model.model.model.embed_tokens(input_ids)

            if cache_position is None:
                past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
                cache_position = torch.arange(
                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
                )
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0)

            causal_mask = create_causal_mask(
                config=model.model.model.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )
            if causal_mask is None:
                # No mask was built, but callers need a tensor they can store
                # and batch. Materialise an explicit boolean causal mask of
                # shape (batch, 1, query_len, key_len): an all-True mask would
                # let every token attend to future positions, and a
                # (batch, batch, ...) layout only broadcasts for batch size 1.
                batch_size, q_len = inputs_embeds.shape[:2]
                kv_len = attention_mask.shape[-1] if attention_mask is not None else q_len
                causal_mask = torch.ones(
                    q_len, kv_len, dtype=torch.bool, device=inputs_embeds.device
                ).tril(diagonal=kv_len - q_len)
                causal_mask = causal_mask[None, None].expand(batch_size, 1, q_len, kv_len).contiguous()

            hidden_states = inputs_embeds
            position_embeddings = model.model.model.rotary_emb(hidden_states, position_ids)

            # Updating in place keeps the key order of the returned dict.
            ret['hidden_states'] = hidden_states
            ret['attention_mask'] = causal_mask
            ret['position_ids'] = position_ids
            ret['use_cache'] = use_cache
            ret['position_embeddings'] = position_embeddings

        # `stuff` is the running state threaded through the half-layers; `ret`
        # is the snapshot handed back to the caller.
        stuff = dict(ret)
        carried_keys = ('hidden_states', 'attention_mask', 'residuals', 'position_ids',
                        'past_key_values', 'use_cache', 'position_embeddings')
        for layer_idx in range(max(start_layer_idx, 0), end_layer_idx):
            block = model.model.model.layers[layer_idx // 2]
            x = CoMMerger.forward_layer_llama(layer=block, mlp_forward=layer_idx % 2, hidden_states=stuff['hidden_states'],
                                              attention_mask=stuff["attention_mask"], residuals=stuff["residuals"], position_ids=stuff["position_ids"],
                                              past_key_values=stuff["past_key_values"], use_cache=stuff["use_cache"],
                                              position_embeddings=stuff["position_embeddings"], pad_tokens=stuff["pad_tokens"])

            for key, value in x.items():
                if key in stuff:
                    stuff[key] = value
            if layer_idx + 2 == end_layer_idx:
                # Snapshot the state after this half-layer; any later
                # half-layer in the range still runs but is not returned.
                for key in carried_keys:
                    ret[key] = x[key]

        if torch.device(return_device).type != "cuda":
            for key, value in ret.items():
                if isinstance(value, torch.Tensor):
                    ret[key] = value.to(return_device)
            torch.cuda.empty_cache()
        return ret

    @torch.no_grad()
    def compare_with_model(self, model, cur_stuff, **kwargs):
        """Collect hidden states from the model's own layers and from the stage-wise forward.

        The reference pass calls every layer of ``model.model.model.layers``
        directly. The second pass drives ``llama_block_forward`` once per
        half-layer, feeding each call the state returned by the previous one.

        Args:
            model: Wrapper exposing ``model.model.model.layers``.
            cur_stuff: State dict as produced by ``llama_block_forward``. Its
                tensor values are moved to ``self.device`` in place.
            **kwargs: Extra keyword arguments forwarded to each decoder layer
                of the reference pass.

        Returns:
            ``(real_hidden_states, my_hidden_states)``: the starting hidden
            states followed by one entry per decoder layer for the reference
            pass, and the starting hidden states followed by one entry per
            half-layer call for the stage-wise pass.
        """
        for key, value in cur_stuff.items():
            if isinstance(value, torch.Tensor):
                cur_stuff[key] = value.to(self.device)
        expected_device_type = torch.device(self.device).type

        hidden_states = deepcopy(cur_stuff.get("hidden_states"))
        attention_mask = deepcopy(cur_stuff.get("attention_mask"))
        position_ids = deepcopy(cur_stuff.get("position_ids"))
        past_key_values = deepcopy(cur_stuff.get("past_key_values"))
        cache_position = deepcopy(cur_stuff.get("cache_position"))
        position_embeddings = deepcopy(cur_stuff.get("position_embeddings"))
        real_hidden_states = [hidden_states]
        for decoder_layer in tqdm(model.model.model.layers, desc="Original forward"):
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            if isinstance(hidden_states, tuple):
                hidden_states = hidden_states[0]
            # Warn when a layer output left the device the inputs were placed
            # on; comparing against self.device keeps CPU runs quiet.
            if hidden_states.device.type != expected_device_type:
                print(f"Warning: reference hidden states are on {hidden_states.device}, expected a {expected_device_type} device")
            real_hidden_states.append(hidden_states)

        transformer_blocks = model.model.model.layers
        my_hidden_states = [cur_stuff['hidden_states']]
        for layer_idx in tqdm(range(len(transformer_blocks) * 2), desc="My forward"):
            cur_stuff = self.llama_block_forward(
                model,
                input_ids=cur_stuff['hidden_states'].to(self.device),
                attention_mask=cur_stuff['attention_mask'].to(self.device),
                residuals=cur_stuff['residuals'].to(self.device),
                position_ids=cur_stuff['position_ids'],
                past_key_values=cur_stuff['past_key_values'],
                use_cache=cur_stuff['use_cache'],
                inputs_embeds=cur_stuff.get("inputs_embeds", None),
                output_attentions=cur_stuff.get('output_attentions', None),
                output_hidden_states=cur_stuff.get('output_hidden_states', None),
                return_dict=cur_stuff.get('return_dict', None),
                cache_position=cur_stuff.get('cache_position', None),
                position_embeddings=cur_stuff.get('position_embeddings', None),
                start_layer_idx=max(layer_idx - 1, 0),
                end_layer_idx=layer_idx + 1,
                pad_tokens=cur_stuff.get('pad_tokens', 0),
                device=self.device,
                return_device="cpu",
            )
            # The state is explicitly requested back on the CPU, so no device
            # check is made here: it would fire on every iteration.
            my_hidden_states.append(cur_stuff['hidden_states'])
        return real_hidden_states, my_hidden_states

    @torch.no_grad()
    def merge(self, merge_config={'merge_method': 'regmean'}):
        print(merge_config['merge_method'])
        if merge_config["gram_dtype"] == "torch.float16":
            self.gram_dtype = torch.float16
        if merge_config["gram_dtype"] == "torch.bfloat16":
            self.gram_dtype = torch.bfloat16
        elif merge_config["gram_dtype"] == "torch.float32":
            self.gram_dtype = torch.float32
        elif merge_config["gram_dtype"] == "torch.float64":
            self.gram_dtype = torch.float64
        else:
            raise ArgumentError

        ftms_relevant_params = [ftm.get_ft_parameters() for ftm in self.ftms_params] #only delta parameters
        VISION = True if getattr(self.pretrained_model, 'vision_model', None) is not None else False

        if VISION:
            if merge_config["cache_data"]:
                all_stuff = []
                self.pretrained_model.to(self.device)
                print("Cacheing data...")
                for t in tqdm(range(len(self.dataloaders))):
                    all_stuff.append([])
                    if merge_config["merge_split"] == "val":
                        cur_train_loader = self.dataloaders[t]["test"]["val"]
                    else:
                        cur_train_loader = self.dataloaders[t]["train"]["full"]

                    if merge_config["grams_data_measure"] == "quantity":
                        num_samples = merge_config["grams_data_quantity"]
                    elif merge_config["grams_data_measure"] == "percentage":
                        num_samples = int(round(merge_config["grams_data_perc"] * len(cur_train_loader.dataset)))
                    else:
                        raise ArgumentError

                    num_classes = [200, 47, 10, 43, 10, 45, 397, 10]
                    counts = defaultdict(int)
                    num_cls = num_classes[t]
                    base, rem = divmod(num_samples, num_cls)
                    classes = random.sample(range(num_cls), num_cls)
                    targets = {c: base + (1 if i < rem else 0) for i, c in enumerate(classes)}
                    total = 0

                    for (x, y) in cur_train_loader:
                        x = x.pixel_values.squeeze(1)
                        y = y.tolist()
                        if merge_config["grams_data_distribution"] == "class-wise":
                            keep = [i for i, label in enumerate(y) if counts[label] < targets[label]]
                        else:
                            keep = list(range(len(x)))
                        if not keep:
                            continue
                        for i in keep:
                            if total >= num_samples:
                                break
                            lbl = y[i]
                            counts[lbl] += 1
                            total += 1
                            xi = x[i].unsqueeze(0)

                            blocks_input = self.clip_forward(self.pretrained_model, {"inputs": xi, "residual": torch.tensor([0.])}, start_layer_idx=-1, end_layer_idx=0)
                            all_stuff[-1].append(blocks_input)
                        if total >= num_samples:
                            break

                    tmp_all_stuff = {}
                    vision_keys = list(all_stuff[-1][0].keys())
                    for key in vision_keys:
                        cur_batches = [all_stuff[-1][i][key] for i in range(len(all_stuff[-1]))]
                        tmp_all_stuff[key] = list(torch.cat(cur_batches)[:num_samples].split(cur_train_loader.batch_size))

                    all_stuff[-1] = []
                    for batch_idx in range(len(tmp_all_stuff[vision_keys[0]])):
                        cur_values = [tmp_all_stuff[key][batch_idx] for key in vision_keys]
                        all_stuff[-1].append({k: v for k, v in zip(vision_keys, cur_values)})

        else:
            # TEXT_KEYS = ['hidden_states', 'attention_mask', 'position_ids', 'past_key_values', 'inputs_embeds', 'use_cache', 'output_attentions', 'output_hidden_states', 'return_dict', 'cache_position', 'past_key_values']
            if merge_config["cache_data"]:
                all_stuff = []
                self.pretrained_model.to(self.device)
                print("Cacheing data...")
                for t in tqdm(range(len(self.dataloaders))):
                    all_stuff.append([])
                    if merge_config["merge_split"] == "val":
                        cur_train_loader = self.dataloaders[t]["test"]["val"]
                    else:
                        cur_train_loader = self.dataloaders[t]["train"]["full"]

                    if merge_config["grams_data_measure"] == "quantity":
                        num_samples = merge_config["grams_data_quantity"]
                    elif merge_config["grams_data_measure"] == "percentage":
                        num_samples = int(round(merge_config["grams_data_perc"] * len(cur_train_loader.dataset)))
                    else:
                        raise ArgumentError

                    if merge_config["grams_data_distribution"] == "class-wise":
                        num_classes = [3, 3, 2, 2, 2, 2]
                        counts = defaultdict(int)
                        num_cls = num_classes[t]
                        base, rem = divmod(num_samples, num_cls)
                        classes = torch.from_numpy(cur_train_loader.dataset.data["labels"].to_numpy().copy()).unique().tolist()
                        targets = {c: base + (1 if i < rem else 0) for i, c in enumerate(classes)}
                        total = 0
                        max_len = 0
                        input_ids_list = []
                        attention_masks_list = []

                        for stuff in cur_train_loader:
                            y, input_ids, attention_masks = stuff.values()
                            y = y.tolist()
                            keep = [i for i, label in enumerate(y) if counts[label] < targets[label]]
                            if not keep:
                                continue
                            for i in keep:
                                if total >= num_samples:
                                    break
                                lbl = y[i]
                                counts[lbl] += 1
                                total += 1

                                input_id = input_ids[i].unsqueeze(0).to(self.device)
                                if input_ids.shape[-1] > max_len:
                                    max_len = input_id.shape[-1]
                                attention_mask = attention_masks[i].unsqueeze(0).to(self.device)
                                input_ids_list.append(input_id)
                                attention_masks_list.append(attention_mask)

                            if total >= num_samples:
                                break
                        
                        print(f"Number of samples: {len(input_ids_list)}, {len(attention_masks_list)}")
                        tot_real_tokens = 0
                        tot_pad_tokens = 0
                        for input_id, attention_mask in zip(input_ids_list, attention_masks_list):
                            pad_tokens = max_len - input_id.shape[-1]
                            input_id = F.pad(input_id, (0, pad_tokens), value=0)
                            attention_mask = F.pad(attention_mask, (0, pad_tokens), value=0)
                            ret = self.llama_block_forward(
                                    self.pretrained_model,
                                    input_ids=input_id.to(self.device),
                                    attention_mask=attention_mask.to(self.device),
                                    residuals=None,
                                    position_ids=None,
                                    past_key_values=None,
                                    inputs_embeds=None,
                                    use_cache=False,
                                    output_attentions=True,
                                    output_hidden_states=False,
                                    return_dict=False,
                                    cache_position=None,
                                    position_embeddings=None,
                                    start_layer_idx=-1,
                                    end_layer_idx=0,
                                    #return_output_idx=-1,
                                    device=self.device,
                                    return_device="cpu"
                                )
                            ret["pad_tokens"] = torch.tensor(pad_tokens)
                            ret["residuals"] = torch.tensor([0.])
                            all_stuff[-1].append(ret)
                            tot_real_tokens += input_id.shape[-1]
                            tot_pad_tokens += pad_tokens

                        tmp_all_stuff = {}
                        keys = list(all_stuff[-1][0].keys())
                        #to_be_reshaped_keys = ['hidden_states', 'attention_mask', 'position_ids', 'past_key_values', 'inputs_embeds']
                        tensorial_keys = ['hidden_states', 'attention_mask', 'position_ids', 'inputs_embeds', 'pad_tokens']
                        for key in keys:
                            cur_batches = [all_stuff[-1][i][key] for i in range(len(all_stuff[-1]))]
                            if key in tensorial_keys:
                                if 'pad_tokens' in key:
                                    tmp_all_stuff[key] = list(torch.tensor(cur_batches).split(cur_train_loader.batch_size))
                                else:
                                    tmp_all_stuff[key] = list(torch.cat(cur_batches).split(cur_train_loader.batch_size))
                            else:
                                tmp_all_stuff[key] = cur_batches

                        all_stuff[-1] = []
                        for batch_idx in range(len(tmp_all_stuff[keys[0]])):
                            cur_values = [tmp_all_stuff[key][batch_idx] for key in keys]
                            all_stuff[-1].append({k: v for k, v in zip(keys, cur_values)})
                    else:
                        raise ArgumentError

        self.pretrained_model.to("cpu")
        merged_base = deepcopy(self.pretrained_model)
        sd = self.pretrained_model.state_dict()
        transformer_blocks = self.pretrained_model.vision_model.encoder.layers if VISION else merged_base.model.model.layers
        diagonal = merge_config["alpha_regmean"] == 0.0
        if not VISION:
            for cur_transformer_block in transformer_blocks:
                cur_transformer_block.self_attn.q_proj = CustomLinear(
                        cur_transformer_block.self_attn.q_proj
                    )
                cur_transformer_block.self_attn.o_proj = CustomLinear(
                        cur_transformer_block.self_attn.o_proj
                    )

                
        print("Merging layers...")
        cur_model = deepcopy(merged_base)
        for layer_idx in tqdm(range(len(transformer_blocks) * 2)):
            # Compute input grams for each layer of each model
            block_idx = layer_idx // 2
            gram_modules = [n for n, _ in transformer_blocks[block_idx].named_modules() \
                            if "base_layer" in n and (int("out_proj" in n or "o_proj" in n) == (layer_idx % 2 == 1))]
            original_gram_modules = deepcopy(gram_modules)
            if not VISION:
                for i in range(len(gram_modules)):
                    #print(gram_modules[i])
                    gram_modules[i] = gram_modules[i].replace(".base_layer", "")
                    #print(gram_modules[i])
            self.grams = []
            if getattr(self, "num_tokens", None) is None:
                setattr(self, "num_tokens", [])
            self.num_tokens = []
            for t, ftm in enumerate(ftms_relevant_params):
                self.grams.append(torch.tensor([], dtype=self.gram_dtype))
                self.num_tokens.append(0)
                cur_state_dict = {
                    k.replace(".weight", ".base_layer.weight"): v
                    for k, v in ftm.items()
                }
                cur_model.load_state_dict(cur_state_dict, strict=False) # load only the delta parameters
                cur_model = cur_model.to(self.device)

                cur_transformer_block = cur_model.vision_model.encoder.layers[block_idx] if VISION else cur_model.model.model.layers[block_idx]
                cur_train_loader = self.dataloaders[t]["train"]["full"]
                hook = None

                for name, module in cur_transformer_block.named_modules():
                    if name == gram_modules[0]:  # just the first one to save computation
                        hook = module.register_forward_hook(self.hook_handler(name, VISION, diagonal))

                if merge_config["cache_data"]:
                    start_fetch = time()
                    for cache_idx, cur_stuff in enumerate(all_stuff[t]):
                        if VISION:
                            cur_output_stuff = self.clip_forward(cur_model, cur_stuff,
                                                                start_layer_idx=max(layer_idx - 1, 0),
                                                                end_layer_idx=layer_idx + 1)
                            if cur_output_stuff is not None:
                                cur_output_stuff = {k: v.cpu() for k, v in cur_output_stuff.items()}
                                all_stuff[t][cache_idx] = cur_output_stuff
                        else:
                            #print(f"Processing {cache_idx} of {len(all_stuff[ti])} for block {block_idx}")
                            pre_forward_time = time()
                            #print(f"Time fetch cache: {pre_forward_time - start_fetch}")
                            #print(f"Time pre forw: {pre_forward_time - start_time_model}")
                            #real_hidden, my_hidden = self.compare_with_model(cur_model, cur_stuff)
                            #import ipdb
                            #ipdb.set_trace()
                            all_stuff[t][cache_idx] = self.llama_block_forward(
                                cur_model,
                                input_ids=cur_stuff['hidden_states'].to(self.device),
                                attention_mask=cur_stuff['attention_mask'].to(self.device),
                                residuals=cur_stuff['residuals'].to(self.device),
                                position_ids=cur_stuff['position_ids'],
                                past_key_values=cur_stuff['past_key_values'],
                                use_cache=cur_stuff['use_cache'],
                                inputs_embeds=cur_stuff.get("inputs_embeds", None),
                                output_attentions=cur_stuff.get('output_attentions', None),
                                output_hidden_states=cur_stuff.get('output_hidden_states', None),
                                return_dict=cur_stuff.get('return_dict', None),
                                cache_position=cur_stuff.get('cache_position', None),
                                position_embeddings=cur_stuff.get('position_embeddings', None),
                                start_layer_idx=max(layer_idx - 1, 0),
                                end_layer_idx=layer_idx + 1,
                                pad_tokens=cur_stuff.get('pad_tokens', None),
                                #return_output_idx=block_idx - 1 if prev_mlp_forward else -1
                                device=self.device,
                                return_device="cpu",
                            )
                            #cur_output = all_stuff[t][cache_idx]['hidden_states']
                            #all_stuff[t][cache_idx] = cur_output.cpu()
                else:
                    num_samples = merge_config["grams_data_quantity"]
                    for id, (x, y) in enumerate(cur_train_loader): # enumerate(tqdm(cur_train_loader, desc="Computing Gram matrices")):
                        if id / len(cur_train_loader) > merge_config["grams_data_perc"]:
                            break
                        x = x.pixel_values.squeeze(1)
                        if num_samples < 0:
                            break
                        num_samples -= len(x)
                        if num_samples < 0:
                            x = x[:-num_samples]
                        x = x.to(self.device)
                        blocks_input = self.clip_forward(self.pretrained_model, {"inputs": xi, "residual": torch.tensor([0.])}, start_layer_idx=-1, end_layer_idx=layer_idx + 1)[0]

                for name, module in cur_transformer_block.named_modules():
                    if name == gram_modules[0]:
                        alpha = merge_config['alpha_regmean']
                        cur_feats = self.grams[-1]
                        eye = torch.eye(cur_feats.shape[-1], dtype=self.gram_dtype, device=self.device)
                        if not diagonal:
                            self.grams[-1] *= eye + (1 - eye) * alpha
                        self.grams[-1] = self.grams[-1].cpu()
                        hook.remove()
                end_time_model = time()
                #print(f"Total time model: {end_time_model - start_time_model}")

            # Regmean merge
            if merge_config["inv_gram_dtype"] == "torch.float16":
                self.inv_gram_dtype = torch.float16
            if merge_config["inv_gram_dtype"] == "torch.bfloat16":
                self.inv_gram_dtype = torch.bfloat16
            elif merge_config["inv_gram_dtype"] == "torch.float32":
                self.inv_gram_dtype = torch.float32
            elif merge_config["inv_gram_dtype"] == "torch.float64":
                self.inv_gram_dtype = torch.float64
            else:
                raise ArgumentError

            task_weights = torch.tensor([gram[~torch.eye(gram.shape[0], dtype=bool, device=gram.device)].abs().mean() for gram in self.grams])
            task_weights = (task_weights - task_weights.min() + 0.1) / (task_weights.max() - task_weights.min() + 0.1)
            task_weights /= task_weights.sum()
            if layer_idx:
                self.task_weights = (self.task_weights * (layer_idx) + task_weights) / (layer_idx + 1)
            else:
                self.task_weights = task_weights

            if not merge_config["weigh_tasks_by_grams"] or diagonal:
                print("Not using task_weights!")
                task_weights = torch.ones(len(self.grams))

            sd = merged_base.state_dict()
            if self.normalize_grams:
                total_tokens = sum(self.num_tokens)
                self.grams = [gram / local_tokens for gram, local_tokens in zip(self.grams, self.num_tokens)]
            base_name = f"vision_model.base_model.model.encoder.layers.{block_idx}" if VISION else f"base_model.model.model.layers.{block_idx}"

            for gram_key in original_gram_modules:
                ft_key = f"{base_name}.{gram_key.replace('.base_layer', '.weight')}"
                sd_key = f"{base_name}.{gram_key}.weight"
                if not diagonal:
                    sd[sd_key] += (
                        torch.stack(
                            [
                                ftm[ft_key].to(self.inv_gram_dtype) @ (gram * tw).to(self.inv_gram_dtype)
                                for ftm, gram, tw in zip(ftms_relevant_params, self.grams, task_weights)
                            ]
                        ).sum(0)
                        @ self.to_torch(self.inverse_function(self.to_cupy(torch.stack([(gram * tw).to(self.inv_gram_dtype) for gram, tw in zip(self.grams, task_weights)])).sum(0)))
                    ).to(torch.float32) * merge_config["scaling_coeffs"] # *0.38 if in the original joint setting
                else:
                    sd[sd_key] += (
                        torch.stack(
                            [
                                ftm[ft_key].to(self.inv_gram_dtype) @ (((gram * 1 + 0) * tw).to(self.inv_gram_dtype)).diag()
                                for ftm, gram, tw in zip(ftms_relevant_params, self.grams, task_weights)
                            ]
                        ).sum(0)
                        @ (1 / (torch.stack([((gram * 1 + 0) * tw).to(self.inv_gram_dtype) for gram, tw in zip(self.grams, task_weights)]).sum(0) + 1e-8)).diag()
                    ).to(torch.float32) * merge_config["scaling_coeffs"] # *0.38 if in the original joint setting
                for t, ftm in enumerate(ftms_relevant_params):
                    ftm[ft_key] = sd[sd_key]

        merged_base.load_state_dict(sd)
        print(self.task_weights)
        return merged_base


class ISOMerger(TaskMerger):
    """Merger that flattens the singular-value spectrum of the summed task vectors.

    For every fine-tuned parameter, the per-task deltas are summed, the sum is
    decomposed with an SVD, and all singular values are replaced by their mean
    before the result is added to the pretrained weights.
    """

    def __init__(self, finetuned_models, pretrained_model, param_handler, device=0, merge_config=None, dataloaders=None):
        super().__init__(
            finetuned_models=finetuned_models,
            pretrained_model=pretrained_model,
            param_handler=param_handler,
            device=device,
            merge_config=merge_config
        )

        self.representation_helper = VectorOps()
        self.dataloaders = dataloaders

    def merge(self, merge_config={'merge_method': 'regmean'}):
        """Add the isotropic merged delta to ``self.pretrained_model`` and return it.

        Args:
            merge_config: mapping that must provide ``'merge_method'`` (only
                printed) and ``'scaling_coeffs'`` (multiplier applied to every
                merged delta). The default dict has no ``'scaling_coeffs'``
                entry, so callers have to pass a complete config.

        Returns:
            ``self.pretrained_model`` after its state dict has been updated;
            no copy of the model is made.
        """
        print(merge_config['merge_method'])
        ftms_relevant_params = [ftm.get_ft_parameters() for ftm in self.ftms_params]

        sd = self.pretrained_model.state_dict()

        for key in ftms_relevant_params[0]:
            summed_delta = torch.stack([ftm[key] for ftm in ftms_relevant_params]).sum(dim=0)
            # Reduced SVD: with the default full_matrices=True, U is (m, m) and
            # Vh is (n, n), so U @ Vh is only defined for square weights.
            # The reduced factors are (m, k) and (k, n) and give the same
            # product whenever the weight is square.
            U, S, Vh = torch.linalg.svd(summed_delta.type(torch.float64), full_matrices=False)
            # Every singular value is replaced by the mean one, i.e.
            # U @ (mean(S) * I) @ Vh.
            isotropic_delta = (U @ Vh * S.mean()).type(torch.float32)
            sd[key.replace(".weight", ".base_layer.weight")] += isotropic_delta * merge_config["scaling_coeffs"]

        self.pretrained_model.load_state_dict(sd)
        return self.pretrained_model
            

class ConsensusMerger(TaskMerger):
    """Merger that keeps only the task-vector entries enough tasks agree on.

    An entry of the summed task vector survives when at least
    ``consensus_threshold`` tasks have a per-task magnitude that is at least
    ``tall_threshold`` times the magnitude of the sum of all other tasks.
    """

    def __init__(self, finetuned_models, pretrained_model, param_handler, device=0, merge_config=None, dataloaders=None):
        super().__init__(
            finetuned_models=finetuned_models,
            pretrained_model=pretrained_model,
            param_handler=param_handler,
            device=device,
            merge_config=merge_config
        )

        self.representation_helper = VectorOps()
        self.dataloaders = dataloaders

    def merge(self, merge_config={'merge_method': 'consensus'}):
        """Build the consensus-masked task vector and add it to a copy of the base model.

        Args:
            merge_config: mapping providing ``'merge_method'`` (only printed),
                ``'tall_threshold'``, ``'consensus_threshold'`` and
                ``'scaling_coeffs'``. The default dict only carries
                ``'merge_method'``, so callers have to pass a complete config.

        Returns:
            A deep copy of ``self.pretrained_model`` with the masked, scaled
            task vector added to every ``base_layer.weight`` entry.

        Raises:
            RuntimeError: re-raised, after printing diagnostics, when a slice
                of the flat vector cannot be reshaped or moved for a key.
        """
        print(merge_config['merge_method'])
        ftms_relevant_params = [ftm.get_ft_parameters() for ftm in self.ftms_params]
        # Pretrained model's keys (i.e. they have '.base_layer').
        sd_keys = [n for n in self.pretrained_model.state_dict() if "base_layer.weight" in n]
        # Fine-tuned model's keys (i.e. they don't have '.base_layer').
        keys = [key.replace('.base_layer', '') for key in sd_keys]

        # One flat row per task; kept on the CPU while the masks are built.
        all_tvs = torch.stack(
            [torch.cat([ftm[key].flatten() for key in keys], dim=0) for ftm in ftms_relevant_params],
            dim=0,
        ).to("cpu")  # shape [ftm_num, param_num]
        multi_task_tv = torch.sum(all_tvs, dim=0)

        # Per-task mask: the task's own entry against the sum of all the others.
        single_masks = torch.abs(all_tvs) >= torch.abs(multi_task_tv - all_tvs) * merge_config["tall_threshold"]

        # The vote count and the summed task vector must share a device for
        # the product below, so both follow self.device instead of a
        # hard-coded "cuda".
        summed_consensus = single_masks.sum(0).to(self.device)  # shape [param_num]
        multi_task_tv = multi_task_tv.to(self.device)
        multi_mask = summed_consensus >= merge_config["consensus_threshold"]  # shape [param_num]
        print(f"{multi_mask.sum()} positive out of {multi_mask.flatten().shape}: {multi_mask.sum() / multi_mask.flatten().shape[0]}")
        final_tv = multi_mask * multi_task_tv  # shape [param_num]
        print(f"Total sum of elements of final_tv: {final_tv.sum()}")
        device = self.pretrained_model.device if hasattr(self.pretrained_model, 'device') else next(iter(self.pretrained_model.parameters())).device

        # Cut the flat vector back into per-parameter tensors, in `keys` order.
        final_tv_dict = {}
        offset = 0
        for key in keys:
            reference = ftms_relevant_params[0][key]
            numel = reference.numel()
            sd_key = key.replace(".weight", ".base_layer.weight")
            try:
                final_tv_dict[sd_key] = final_tv[offset:offset + numel].view(reference.shape).to(device)
            except RuntimeError:
                # Print what is known about the failing slice, then propagate:
                # continuing would leave this key out of the merge.
                print(f"final_tv.shape:\t{final_tv.shape}")
                print(f'key:\t{key},{sd_key}, tmp:\t{offset}, ftms_relevant_params[0][key].numel():\t{numel}, ftms_relevant_params[0][key].shape:\t{reference.shape}')
                raise
            offset += numel

        merged_base = deepcopy(self.pretrained_model)
        merged_base = merged_base.to(device)
        sd = merged_base.state_dict()
        for key, delta in final_tv_dict.items():
            sd[key] += delta * merge_config["scaling_coeffs"]

        merged_base.load_state_dict(sd)

        return merged_base
    
class LinesMerger(TaskMerger):
    """Merger that scales the summed task vector linearly with block depth.

    Block ``l`` (0-based) of ``L`` blocks has its share of the summed task
    vector multiplied by ``alpha + beta * l / (L - 1)``.
    """

    def __init__(self, finetuned_models, pretrained_model, param_handler, device=0, merge_config=None, dataloaders=None):
        super().__init__(
            finetuned_models=finetuned_models,
            pretrained_model=pretrained_model,
            param_handler=param_handler,
            device=device,
            merge_config=merge_config,
        )

        self.representation_helper = VectorOps()
        self.dataloaders = dataloaders

    def merge(self, merge_config={'merge_method': 'lines'}):
        """Apply the depth-scaled summed task vector to a copy of the base model.

        Args:
            merge_config: mapping providing ``'merge_method'`` (only printed),
                ``'beta_coeff'``, ``'scaling_coeffs'`` (multiplies the summed
                task vector before the per-block scaling) and optionally
                ``'alpha_coeff'`` (defaults to 1.0). The default dict only
                carries ``'merge_method'``, so callers have to pass a complete
                config.

        Returns:
            A deep copy of ``self.pretrained_model`` with the scaled task
            vector added to the ``base_layer.weight`` entries that live under
            the transformer blocks. Keys outside the blocks are left untouched.
        """
        print(merge_config['merge_method'])
        ftms_relevant_params = [ftm.get_ft_parameters() for ftm in self.ftms_params]
        # Pretrained model's keys (i.e. they have '.base_layer').
        sd_keys = [n for n in self.pretrained_model.state_dict() if "base_layer.weight" in n]
        # Fine-tuned model's keys (i.e. they don't have '.base_layer').
        keys = [key.replace('.base_layer', '') for key in sd_keys]

        alpha = merge_config.get("alpha_coeff", 1.0)
        beta = merge_config["beta_coeff"]
        print(f"Using: alpha = {alpha}\tbeta = {beta}\tscaling_coeff = {merge_config['scaling_coeffs']}")

        multi_task_tv = torch.stack([torch.cat([ftm[key].flatten() for key in keys], dim=0) for ftm in ftms_relevant_params], dim=0).to(self.device).to(torch.float32).sum(dim=0)
        multi_task_tv *= merge_config['scaling_coeffs']
        print(f"multi_task_tv.shape = {multi_task_tv.shape}")

        # Start offset of every key inside the flat vector, in the exact order
        # used to build it. Both loops below slice through this table, so a
        # key that is not under the blocks cannot shift the slices of the
        # keys that follow it.
        key_offsets = {}
        cursor = 0
        for ft_key in keys:
            key_offsets[ft_key] = cursor
            cursor += ftms_relevant_params[0][ft_key].numel()

        VISION = getattr(self.pretrained_model, 'vision_model', None) is not None
        if VISION:
            base_name = "vision_model.base_model.model.encoder.layers"
        else:
            base_name = "base_model.model.model.layers"
        blocks = self.pretrained_model.vision_model.encoder.layers if VISION else self.pretrained_model.model.model.layers
        num_blocks = len(blocks)

        prev = 0  # running count of parameters handled, used for the log lines only
        for idx, block in enumerate(blocks):
            # A single block has no depth range to interpolate over; it gets
            # alpha alone instead of dividing by zero.
            depth_fraction = idx / (num_blocks - 1) if num_blocks > 1 else 0.0
            multiplying_constant = alpha + beta * depth_fraction

            total_params = 0
            for sub_key in block.state_dict().keys():
                if "base_layer" in sub_key and "bias" not in sub_key:
                    full_name = f"{base_name}.{idx}.{sub_key}"
                    ft_key = full_name.replace("base_layer.", "")
                    assert ft_key in ftms_relevant_params[0]
                    numel = ftms_relevant_params[0][ft_key].numel()
                    start = key_offsets[ft_key]
                    multi_task_tv[start:start + numel] *= multiplying_constant
                    total_params += numel

            print(f"Multiplying constant for layer {idx}/{num_blocks}:\t{multiplying_constant}")
            print(f"From {prev} to {total_params + prev}")
            prev += total_params
        print(f"Total params handled: {prev}")

        merged_base = deepcopy(self.pretrained_model)
        # merged_base is already a private copy, so its state dict can be
        # updated directly without copying every tensor a second time.
        sd = merged_base.state_dict()
        multi_task_tv = multi_task_tv.to(sd[sd_keys[0]].device)
        # NOTE(review): multi_task_tv was already multiplied by scaling_coeffs
        # above, so this logged figure applies the coefficient twice; the
        # weights themselves are not affected by it.
        print(f"Effect: {torch.mean(multi_task_tv.view(-1).abs() * merge_config['scaling_coeffs'])}")
        for key, key_ft in zip(sd_keys, keys):
            if key.startswith(base_name):
                num_params = ftms_relevant_params[0][key_ft].numel()
                start = key_offsets[key_ft]
                sd[key].data.add_(multi_task_tv[start:start + num_params].reshape(sd[key].shape))
        merged_base.load_state_dict(sd)
        return merged_base


class FisherMerger(TaskMerger):
    """Merger that averages task deltas weighted by per-task Fisher estimates."""

    def __init__(self, finetuned_models, pretrained_model, param_handler, device=0, merge_config=None, dataloaders=None):
        super().__init__(
            finetuned_models=finetuned_models,
            pretrained_model=pretrained_model,
            param_handler=param_handler,
            device=device,
            merge_config=merge_config
        )

        self.representation_helper = VectorOps()
        self.dataloaders = dataloaders

    def merge(self, merge_config={'merge_method': 'regmean'}):
        """Add the Fisher-weighted average of the task deltas to the base model.

        Args:
            merge_config: mapping providing ``'merge_method'`` (only printed)
                and ``'scaling_coeffs'``. The default dict has no
                ``'scaling_coeffs'`` entry, so callers have to pass a complete
                config.

        Returns:
            ``self.pretrained_model`` after its state dict has been updated;
            no copy of the merged model is made.
        """
        print(merge_config['merge_method'])
        ftms_relevant_params = [ftm.get_ft_parameters() for ftm in self.ftms_params]

        # The architecture label is picked from the number of fine-tuned tensors.
        model_architecture = "ViT-B-32-CLIP" if len(ftms_relevant_params[0]) < 60 else "ViT-L-14-CLIP"

        base_path = "./data/heads"
        dataset_names = ['stanford_cars', 'dtd', 'eurosat', 'gtsrb', 'mnist', 'resisc45', 'sun397', 'svhn']
        all_clip_encodings = [get_clip_encodings(f"{base_path}/{model_architecture}/{dataset}_head.pt") for dataset in dataset_names]
        # Normalise every class vector to unit length.
        all_clip_encodings = [class_vectors / class_vectors.norm(dim=-1, keepdim=True) for class_vectors in all_clip_encodings]

        all_fishers, num_examples_list = [], []
        # Compute the Fisher estimates of each fine-tuned model on its own
        # training data.
        for t, (ftm, cur_classifier_weights) in enumerate(zip(ftms_relevant_params, all_clip_encodings)):
            cur_model = deepcopy(self.pretrained_model)
            cur_state_dict = {
                k.replace(".weight", ".base_layer.weight"): v for k, v in ftm.items()
            }
            cur_model.load_state_dict(cur_state_dict, strict=False)
            cur_model = cur_model.to(self.device)
            cur_train_loader = self.dataloaders[t]["train"]["full"]

            # Linear head whose weight is the task's class-vector matrix.
            cur_classifier = nn.Linear(*cur_classifier_weights.shape[::-1])
            cur_classifier.weight.data = cur_classifier_weights
            cur_classifier = cur_classifier.to(cur_model.device)
            with torch.enable_grad():
                fisher_list, num_examples = compute_fisher_qkv_outproj(cur_model, cur_train_loader, cur_classifier, dataset_names[t], model_architecture)
                all_fishers.append(fisher_list)
                num_examples_list.append(num_examples)

        sd_keys = [n for n in self.pretrained_model.state_dict() if "base_layer.weight" in n]
        sd = self.pretrained_model.state_dict()
        # Assumes fisher_list[i] lines up with the i-th "base_layer.weight"
        # key - TODO confirm against compute_fisher_qkv_outproj.
        for layer_pos, key in enumerate(sd_keys):
            ft_key = key.replace("base_layer.", "")
            # Each task's Fisher is normalised by that task's own example
            # count, rather than by whichever count the loop above left behind.
            normalized_fishers = [
                cur_fisher[layer_pos] / task_examples
                for cur_fisher, task_examples in zip(all_fishers, num_examples_list)
            ]
            weighted_delta_sum = torch.stack(
                [ftm[ft_key] * fisher for ftm, fisher in zip(ftms_relevant_params, normalized_fishers)]
            ).sum(0)
            fisher_sum = torch.stack(normalized_fishers).sum(0)
            # NOTE(review): no floor is applied to fisher_sum, so an entry
            # whose Fisher is zero for every task evaluates 0 / 0 here.
            sd[key] += (weighted_delta_sum / fisher_sum) * merge_config["scaling_coeffs"]

        self.pretrained_model.load_state_dict(sd)
        return self.pretrained_model

class MatsMerger(TaskMerger):
    """Merger that solves a diagonal Fisher-weighted system with conjugate gradient."""

    def __init__(self, finetuned_models, pretrained_model, param_handler, device=0, merge_config=None, dataloaders=None):
        super().__init__(
            finetuned_models=finetuned_models,
            pretrained_model=pretrained_model,
            param_handler=param_handler,
            device=device,
            merge_config=merge_config
        )

        self.representation_helper = VectorOps()
        self.dataloaders = dataloaders

    def cg_forward(self, sum_fisher, sum_fisher_params, init_model, num_iters):
        """Solve ``diag(sum_fisher) @ x = sum_fisher_params`` with SciPy's CG.

        Args:
            sum_fisher: tensor holding the diagonal of the system matrix.
            sum_fisher_params: right-hand side; its shape is the shape of the
                returned tensor.
            init_model: starting point of the iteration.
            num_iters: maximum number of CG iterations.

        Returns:
            A CPU tensor shaped like ``sum_fisher_params`` with the dtype of
            ``init_model``.
        """
        weight_shape = sum_fisher_params.shape

        # SciPy works on host NumPy arrays. Converting explicitly keeps the
        # solve in float64 whatever the SciPy version, and also accepts
        # tensors that live on an accelerator.
        def as_flat_array(tensor):
            return tensor.detach().to("cpu", torch.float64).flatten().numpy()

        b = as_flat_array(sum_fisher_params)
        flat_sum_fisher = as_flat_array(sum_fisher)
        x0 = as_flat_array(init_model)

        # The matrix is diagonal, so the matrix-vector product is elementwise.
        A = LinearOperator(
            (b.shape[0], b.shape[0]),
            matvec=lambda v: flat_sum_fisher * v.reshape(-1),
            dtype=b.dtype,
        )

        # The convergence flag is not inspected: the iteration budget is a
        # hyper-parameter and stopping at maxiter is an accepted outcome.
        x_final, _ = cg(A, b, x0=x0, maxiter=num_iters)

        return torch.from_numpy(x_final.reshape(weight_shape)).to(init_model.dtype)

    def merge(self, merge_config={'merge_method': 'regmean'}):
        """Replace every ``base_layer.weight`` with the CG solution of the Fisher system.

        Args:
            merge_config: mapping providing ``'merge_method'`` (only printed)
                and ``'num_iterations'`` (CG iteration budget). The default
                dict has no ``'num_iterations'`` entry, so callers have to
                pass a complete config.

        Returns:
            ``self.pretrained_model`` after its state dict has been updated;
            no copy of the merged model is made.
        """
        print(merge_config['merge_method'])
        ftms_relevant_params = [ftm.get_ft_parameters() for ftm in self.ftms_params]

        # The architecture label is picked from the number of fine-tuned tensors.
        model_architecture = "ViT-B-32-CLIP" if len(ftms_relevant_params[0]) < 60 else "ViT-L-14-CLIP"

        base_path = "./data/heads"
        dataset_names = ['stanford_cars', 'dtd', 'eurosat', 'gtsrb', 'mnist', 'resisc45', 'sun397', 'svhn']
        all_clip_encodings = [get_clip_encodings(f"{base_path}/{model_architecture}/{dataset}_head.pt") for dataset in dataset_names]
        # Normalise every class vector to unit length.
        all_clip_encodings = [class_vectors / class_vectors.norm(dim=-1, keepdim=True) for class_vectors in all_clip_encodings]

        all_fishers, num_examples_list = [], []
        # Compute the Fisher estimates of each fine-tuned model on its own
        # training data.
        for t, (ftm, cur_classifier_weights) in enumerate(zip(ftms_relevant_params, all_clip_encodings)):
            cur_model = deepcopy(self.pretrained_model)
            cur_state_dict = {
                k.replace(".weight", ".base_layer.weight"): v for k, v in ftm.items()
            }
            cur_model.load_state_dict(cur_state_dict, strict=False)
            cur_model = cur_model.to(self.device)
            cur_train_loader = self.dataloaders[t]["train"]["full"]

            # Linear head whose weight is the task's class-vector matrix.
            cur_classifier = nn.Linear(*cur_classifier_weights.shape[::-1])
            cur_classifier.weight.data = cur_classifier_weights
            cur_classifier = cur_classifier.to(cur_model.device)
            with torch.enable_grad():
                fisher_list, num_examples = compute_fisher_qkv_outproj(cur_model, cur_train_loader, cur_classifier, dataset_names[t], model_architecture)
                all_fishers.append(fisher_list)
                num_examples_list.append(num_examples)

        sd_keys = [n for n in self.pretrained_model.state_dict() if "base_layer.weight" in n]
        sd = self.pretrained_model.state_dict()
        # Assumes fisher_list[i] lines up with the i-th "base_layer.weight"
        # key - TODO confirm against compute_fisher_qkv_outproj.
        for layer_pos, key in enumerate(sd_keys):
            ft_key = key.replace("base_layer.", "")

            # Each task's Fisher is normalised by that task's own example
            # count, rather than by whichever count the loop above left behind.
            normalized_fishers = [
                cur_fisher[layer_pos] / task_examples
                for cur_fisher, task_examples in zip(all_fishers, num_examples_list)
            ]

            sum_diagonalFisherMatrices = torch.stack(normalized_fishers).sum(0)
            # Starting point of the solve: base weights plus the mean delta.
            average_weights = sd[key] + torch.stack([ftm[ft_key] for ftm in ftms_relevant_params]).mean(0)
            sum_diagonalFisherTimesWeight = torch.stack(
                [(sd[key] + ftm[ft_key]) * fisher for ftm, fisher in zip(ftms_relevant_params, normalized_fishers)]
            ).sum(0)

            final_weights = self.cg_forward(
                sum_diagonalFisherMatrices,
                sum_diagonalFisherTimesWeight,
                average_weights,
                merge_config["num_iterations"],
            )

            sd[key] = final_weights

        self.pretrained_model.load_state_dict(sd)
        return self.pretrained_model

class TSVMerger(TaskMerger):
    """Merger that combines the leading singular directions of every task delta."""

    def __init__(self, finetuned_models, pretrained_model, param_handler, device=0, merge_config=None, **kwargs):
        super().__init__(
            finetuned_models=finetuned_models,
            pretrained_model=pretrained_model,
            param_handler=param_handler,
            device=device,
            merge_config=merge_config
        )

        self.layer_names = self.get_layer_names(self.ftms_params[0].get_ft_parameters())
        self.ingredients = None
        self.cache = {}

    def set_scaling_coeffs(self, scaling_coeffs):
        """Store ``scaling_coeffs`` as a tensor; it multiplies every applied delta."""
        self.scaling_coeffs = torch.tensor(scaling_coeffs)

    def _add_scaled_delta(self, param, delta_w):
        """Add ``self.scaling_coeffs * delta_w`` to ``param`` in place."""
        param.data += self.scaling_coeffs * delta_w.type_as(param.data)

    def _apply_delta(self, new_sd, key, delta_w):
        """Add the scaled ``delta_w`` to the parameter of ``new_sd`` named ``key``.

        Does nothing when ``new_sd`` has no parameter with that name.
        """
        param = dict(new_sd.named_parameters()).get(key)
        if param is not None:
            self._add_scaled_delta(param, delta_w)

    def get_tsv_delta_w(self, ftms_task_dirs):
        """Build one merged delta from a list of per-task 2-D deltas.

        Each task contributes its first ``int(rank / num_tasks)`` singular
        triplets, written into a disjoint column/row band of shared ``U``,
        ``S`` and ``V`` buffers. The stacked ``U`` and ``V`` are then each
        replaced by the product of their own reduced-SVD factors before the
        final matrix is assembled.

        Args:
            ftms_task_dirs: list of same-shaped 2-D tensors, one per task.

        Returns:
            The merged delta, with the type of the first input tensor.
        """
        sv_reduction = 1 / len(ftms_task_dirs)
        for i, vec in enumerate(ftms_task_dirs):
            u, s, v = torch.linalg.svd(vec.to(torch.float64), full_matrices=False)
            if i == 0:
                # The buffers take their shapes from the first task's factors.
                sum_u = torch.zeros_like(u)
                sum_s = torch.zeros_like(s)
                sum_v = torch.zeros_like(v)
            reduced_index_s = int(s.shape[0] * sv_reduction)
            band = slice(i * reduced_index_s, (i + 1) * reduced_index_s)
            # Keep only the first reduced_index_s columns of u, singular
            # values and rows of v, and place them in this task's band.
            sum_u[:, band] = u[:, :reduced_index_s]
            sum_s[band] = s[:reduced_index_s]
            sum_v[band, :] = v[:reduced_index_s, :]
        u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False)
        u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False)

        return torch.linalg.multi_dot((u_u, v_u, torch.diag(sum_s), u_v, v_v)).type_as(ftms_task_dirs[0])

    def merge(self, merge_config):
        """Apply the merged delta of every fine-tuned tensor to a copy of the base model.

        ``self.scaling_coeffs`` must be set before this is called;
        ``merge_config`` is not read here.

        Returns:
            A deep copy of ``self.pretrained_model`` with the scaled merged
            deltas added to the matching parameters.
        """
        ftms_task_dirs = [ftm.get_ft_parameters() for ftm in self.ftms_params]

        # With newer peft versions, the keys may not have '.base_layer' in them.
        all_keys = self.pretrained_model.state_dict().keys()
        new_sd = deepcopy(self.pretrained_model)
        # One name -> parameter table replaces a scan over every parameter
        # for each key.
        params_by_name = dict(new_sd.named_parameters())
        for key in tqdm(all_keys, desc="Merging full space"):
            key_base = key.replace('.base_layer', '')
            if key_base not in ftms_task_dirs[0]:
                continue
            param = params_by_name.get(key)
            if param is None:
                # Not a parameter of the copied model: nothing to update.
                continue
            # get_tsv_delta_w only reads its inputs, so the task tensors are
            # passed as they are rather than deep-copied.
            delta_w = self.get_tsv_delta_w([ft_dir[key_base] for ft_dir in ftms_task_dirs])
            self._add_scaled_delta(param, delta_w)

        return new_sd


def get_merge_handler(rep_type):
    """Return the merger class registered under ``rep_type``.

    Returns ``None`` when ``rep_type`` matches none of the registered names.
    """
    # Each entry resolves its class lazily, so only the matching name is
    # ever looked up.
    registry = (
        ('svd-vector', lambda: SVDMerger),
        ('vector', lambda: VectorMerger),
        ("regmean-vector", lambda: RegMeanMerger),
        ("com-vector", lambda: CoMMerger),
        ("iso-vector", lambda: ISOMerger),
        ("consensus-vector", lambda: ConsensusMerger),
        ("lines-vector", lambda: LinesMerger),
        ("fisher-vector", lambda: FisherMerger),
        ("tsv-vector", lambda: TSVMerger),
        ("mats-vector", lambda: MatsMerger),
    )
    for registered_name, resolve in registry:
        if rep_type == registered_name:
            return resolve()
    return None
