import torch
import copy
from functools import partial

from collections import defaultdict
import numpy as np


def _select_rank(singular_values, k):
    """Translate the ``k`` setting into a number of leading singular directions.

    Args:
        singular_values: 1-D tensor of singular values as returned by
            ``torch.linalg.svd`` (descending order).
        k: ``-1`` keeps every direction; a value below 1 is used as a fraction
            of the total squared-singular-value energy that the kept
            directions must reach; any other value is used directly as the
            number of directions.

    Returns:
        The number of directions to keep, never more than
        ``len(singular_values)``.
    """
    m = singular_values.shape[0]
    if m == 0:
        # Degenerate (empty) matrix: there is nothing to keep.
        return 0
    if k == -1:
        return m
    if k < 1:
        energy = singular_values.pow(2)
        cum_energy = torch.cumsum(energy, dim=0)
        # Clamp so that an all-zero matrix does not yield a zero target.
        total_energy = cum_energy[-1].clamp_min(1e-12)
        target = k * total_energy
        # First index whose cumulative energy reaches the target; +1 turns
        # the index into a count.
        rank = int(torch.searchsorted(cum_energy, target).item()) + 1
        print('top ', k, 'singular is k=', rank)
        return min(rank, m)
    return min(k, m)


def _split_by_core_space(tv, U_r, V_r):
    """Split ``tv`` into four parts relative to the spans of ``U_r`` and ``V_r``.

    With ``P_U = U_r U_r^T`` and ``P_V = V_r V_r^T`` the returned mapping holds

    * ``'inin'``:   ``P_U tv P_V``
    * ``'inout'``:  ``P_U tv (I - P_V)``
    * ``'outin'``:  ``(I - P_U) tv P_V``
    * ``'outout'``: ``(I - P_U) tv (I - P_V)``

    The four parts sum to ``tv`` (up to floating-point rounding).
    """
    U_t = U_r.transpose(-2, -1)
    V_t = V_r.transpose(-2, -1)
    # P_U W = U (U^T W)
    PUW = U_r @ (U_t @ tv)
    # W P_V = (W V) V^T
    WPV = (tv @ V_r) @ V_t
    # P_U W P_V
    W_in_in = U_r @ (U_t @ tv @ V_r) @ V_t
    return {
        'inin': W_in_in,
        'inout': PUW - W_in_in,
        'outin': WPV - W_in_in,
        # (I-P_U) W (I-P_V) = W - P_U W - W P_V + P_U W P_V
        'outout': tv - PUW - WPV + W_in_in,
    }


def core_space_preservation(args, task_vector, pretrained_checkpoint):
    """Filter 2-D task-vector entries by the pretrained weights' top singular subspaces.

    For every 2-D parameter of the pretrained model, the weight is factored
    with an SVD, the leading ``r`` left/right singular vectors are kept, and
    the matching task-vector entry is split into four components relative to
    those two subspaces. The components named in ``args.parts`` are summed and
    written back to ``task_vector.vector[name]`` as a CPU tensor. Entries that
    are ``None`` and 1-D parameters are left untouched. If no component is
    selected, the entry becomes an all-zero tensor.

    Args:
        args: Object with attributes ``k`` (``-1`` for full rank, a value
            below 1 for an energy fraction, otherwise an explicit rank) and
            ``parts`` (container tested with ``in`` for the names ``'inin'``,
            ``'inout'``, ``'outin'`` and ``'outout'``).
        task_vector: Object whose ``vector`` attribute is indexed by parameter
            name and holds tensors or ``None``; it is modified in place.
        pretrained_checkpoint: Path (or file-like object) accepted by
            ``torch.load`` that holds a whole pickled model.
    """
    # Use the GPU when one is present; otherwise fall back to the CPU so the
    # function also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(security): the checkpoint is a fully pickled module, so it has to be
    # loaded with weights_only=False (the default became True in PyTorch 2.6).
    # Unpickling can execute arbitrary code -- only load trusted checkpoints.
    model = torch.load(pretrained_checkpoint, weights_only=False).to(device)
    for name, pp in model.named_parameters():
        if task_vector.vector[name] is None:
            continue
        elif len(pp.shape) == 1:
            pass
        elif len(pp.shape) == 2:
            W_pre = pp.detach()

            # U: [d_out, m], S: [m], Vh: [m, d_in]
            U, S, Vh = torch.linalg.svd(W_pre, full_matrices=False)
            r = _select_rank(S, args.k)
            U_r = U[:, :r].contiguous()
            V_r = Vh[:r, :].transpose(-2, -1).contiguous()

            # Decompose the task vector into its four subspace components.
            tv = task_vector.vector[name].to(device)
            components = _split_by_core_space(tv, U_r, V_r)

            # Sum the requested components. Starting from a zero tensor keeps
            # the result a tensor even when no part is selected.
            tv_filtered = torch.zeros_like(tv)
            for part, component in components.items():
                if part in args.parts:
                    tv_filtered += component

            task_vector.vector[name] = tv_filtered.to('cpu')