import torch
from peft_utils.peft_layers import low_rank_CP,low_rank_linear
from peft_utils.bilevel_trainer import get_s,prune,update_state

def orth_regularize(layers,reg = 1e-2):
    """Return a soft penalty pushing low-rank factors toward orthonormal columns.

    For every factor matrix ``F`` found on a supported layer, the term
    ``reg * ||F^T F - I||^2`` is accumulated. The norm is the default
    ``torch.linalg.norm``, i.e. the 2-norm of the flattened difference
    (Frobenius norm). Supported layers:

    * ``low_rank_linear`` -- contributes its ``us`` and ``vs`` factors.
    * ``low_rank_CP``     -- contributes every factor in its ``us`` collection.

    Layers of any other type are ignored.

    Args:
        layers: iterable of layer objects to inspect.
        reg: weight applied to each factor's penalty term.

    Returns:
        ``0.0`` (a Python float) when no factor was found, otherwise a scalar
        tensor holding the summed penalty.
    """

    def _factor_penalty(mat):
        # Gram matrix of the columns; equals the identity iff the columns are
        # orthonormal. Assumes ``mat`` is 2-D (it uses ``.T`` and ``shape[1]``).
        gram = mat.T @ mat
        target = torch.eye(mat.shape[1], device=mat.device)
        return reg * torch.linalg.norm(gram - target) ** 2

    total = 0.0
    for layer in layers:
        if isinstance(layer, low_rank_linear):
            factors = (layer.us, layer.vs)
        elif isinstance(layer, low_rank_CP):
            factors = layer.us
        else:
            continue
        for factor in factors:
            total += _factor_penalty(factor)
    return total


def magnitude_compress(layers,cr):
    """Prune the layers' scores by magnitude and write the result back.

    Steps:

    1. Collect the scores with ``get_s(layers)`` and take their absolute value.
    2. Pass the scores and ``cr`` (wrapped in a tensor) to ``prune``.
    3. Store the pruned result on the layers via ``update_state``.

    Args:
        layers: layers handed to ``get_s`` and ``update_state``.
        cr: value forwarded to ``prune`` as ``torch.tensor(cr)``. It is
            presumably a compression ratio; confirm against ``prune``.

    Returns:
        None. The function works only through ``update_state``'s side effects.
    """
    magnitudes = torch.abs(get_s(layers))
    ratio = torch.tensor(cr)
    pruned = prune(magnitudes, ratio)
    update_state(pruned, layers)
