
import copy
import torch
from src.utils.utils import vector_to_state_dict, state_dict_to_vector

def _is_svd_target(key, param):
    """Return True for 2-D ``attn``/``mlp`` weight tensors that receive SVD treatment.

    Keys containing ``ln`` or ``bias`` are excluded. The 2-D check keeps
    ``torch.linalg.svd`` and the column/row slicing in ``_truncate_rank`` away
    from 1-D or batched tensors that merely happen to match the name filter.
    """
    return (
        isinstance(param, torch.Tensor)
        and param.dim() == 2
        and ("attn" in key or "mlp" in key)
        and not ("ln" in key or "bias" in key)
    )


def _truncate_rank(weight, keep_ratio):
    """Reconstruct ``weight`` from its top ``int(keep_ratio * min(rows, cols))`` singular triplets.

    If the product rounds down to 0 the result is an all-zero matrix of the
    same shape.
    """
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    num_keep = int(keep_ratio * S.shape[0])
    # Broadcasting U_k * S_k scales each kept column of U by its singular value.
    return (U[:, :num_keep] * S[:num_keep].unsqueeze(0)) @ Vh[:num_keep, :]


def _boost_subspace(weight, svd_thresh):
    """Raise the small singular values of ``weight`` up to a cutoff value.

    The cutoff index is the first position at which the cumulative fraction of
    the singular-value sum reaches ``svd_thresh``. Every singular value smaller
    than the one at that index is clamped up to it, and the matrix is rebuilt.
    """
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    total_sum = S.sum()
    if not total_sum > 0:
        # All-zero (or empty) matrix: there is nothing to boost, and
        # normalising by the sum would yield NaNs and no cutoff index.
        return weight
    cumulative_frac = torch.cumsum(S, dim=0) / total_sum
    hits = (cumulative_frac >= svd_thresh).nonzero(as_tuple=False)
    # Float rounding can leave the last fraction a hair below a threshold of
    # 1.0; fall back to the last index instead of raising IndexError.
    cutoff_idx = hits[0].item() if hits.numel() > 0 else S.shape[0] - 1
    S_boosted = torch.clamp(S, min=S[cutoff_idx])
    return (U * S_boosted.unsqueeze(0)) @ Vh


def cart(
    mean_tv_flat_checks,
    ptm_check,
    pruning_ratio=0.08,
    apply_subspace_boosting=False,
    svd_thresh=0.0,
):
    """Sum rank-reduced task vectors into a single flat vector.

    Each flat task vector is unflattened with ``vector_to_state_dict``, using
    ``ptm_check`` as the template. Every 2-D ``attn``/``mlp`` weight (keys
    containing ``ln`` or ``bias`` are excluded) is replaced by its truncated-SVD
    reconstruction before being accumulated. All other tensors are summed
    unchanged. The accumulator starts from zeros, so the result holds only the
    summed task vectors: the pretrained weights themselves are NOT added.

    Args:
        mean_tv_flat_checks: iterable of flattened task vectors. They are
            presumably in the layout ``state_dict_to_vector`` produces for
            ``ptm_check``; TODO confirm against the caller.
        ptm_check: state dict used as the key/shape template. This function
            deep-copies it before zeroing, so the zeroing does not touch it.
        pruning_ratio: fraction of singular components KEPT per matrix,
            computed as ``int(pruning_ratio * min(rows, cols))``. Despite the
            name this is a keep ratio. A matrix for which this rounds to 0
            contributes nothing.
        apply_subspace_boosting: if True, post-process each merged 2-D
            ``attn``/``mlp`` weight by clamping its small singular values up to
            a cutoff (see ``svd_thresh``).
        svd_thresh: threshold on the cumulative fraction of the singular-value
            sum that picks the boosting cutoff. With the default 0.0, every
            singular value of a non-zero matrix is raised to the largest one.

    Returns:
        ``state_dict_to_vector`` applied to the merged state dict.
    """
    merged_state_dict = copy.deepcopy(ptm_check)

    # Zero the copy so it serves as an accumulator with ptm_check's keys/shapes.
    for param in merged_state_dict.values():
        if isinstance(param, torch.Tensor):
            param.zero_()

    for tv in mean_tv_flat_checks:
        print("Processing task vector")
        model_state_dict = vector_to_state_dict(tv, ptm_check, remove_keys=[])

        for key, param in model_state_dict.items():
            if _is_svd_target(key, param):
                merged_state_dict[key] += _truncate_rank(param, pruning_ratio)
            elif isinstance(param, torch.Tensor):
                merged_state_dict[key] += param

    if apply_subspace_boosting:
        for key, param in merged_state_dict.items():
            if _is_svd_target(key, param):
                # Rebinding an existing key while iterating is safe: the
                # dict's size does not change.
                merged_state_dict[key] = _boost_subspace(param, svd_thresh)

    return state_dict_to_vector(merged_state_dict)


