import torch
from functools import partial


@torch.no_grad()
def collect_lr_layers(model, instance_module):
    """Return all modules of ``model`` (including ``model`` itself) that are
    instances of ``instance_module``.

    Traversal is delegated to ``model.apply``, so the returned list follows the
    order in which ``nn.Module.apply`` visits modules (children before their
    parent).
    """
    matches = []

    def _keep_if_match(module):
        # Skip anything that is not of the requested layer type.
        if not isinstance(module, instance_module):
            return
        matches.append(module)

    model.apply(_keep_if_match)
    return matches


def activate_upper_level(self, adapter_name):
    """Make ``lora_A``, ``lora_B`` and ``lora_E`` trainable and freeze ``bias``.

    Intended to be bound to a low-rank layer (see
    ``wrap_classinstance_with_lr_methods``); the name suggests it prepares the
    "upper level" step of an alternating/bilevel scheme — TODO confirm.

    Args:
        self: layer exposing ``lora_A``, ``lora_B``, ``lora_E`` and optionally
            ``bias``.
        adapter_name: key used to index ``lora_A``/``lora_B``/``lora_E`` when
            they are per-adapter containers; ``None`` when they are the
            tensors themselves.
    """
    if adapter_name is not None:
        lora_B = self.lora_B[adapter_name]
        lora_A = self.lora_A[adapter_name]
        lora_E = self.lora_E[adapter_name]
    else:
        lora_B, lora_A, lora_E = self.lora_B, self.lora_A, self.lora_E

    lora_B.requires_grad = True
    lora_A.requires_grad = True
    lora_E.requires_grad = True

    # A layer built without bias (e.g. nn.Linear(bias=False)) still has a
    # ``bias`` attribute set to None, so hasattr() alone is not enough.
    bias = getattr(self, "bias", None)
    if bias is not None:
        bias.requires_grad = False


def activate_lower_level(self, adapter_name):
    """Make ``lora_A``, ``lora_B`` and ``bias`` trainable; freeze ``lora_E``.

    ``lora_E``'s accumulated gradient is also cleared. Intended to be bound to a
    low-rank layer (see ``wrap_classinstance_with_lr_methods``); the name
    suggests it prepares the "lower level" step of an alternating/bilevel
    scheme — TODO confirm.

    Args:
        self: layer exposing ``lora_A``, ``lora_B``, ``lora_E`` and optionally
            ``bias``.
        adapter_name: key used to index ``lora_A``/``lora_B``/``lora_E`` when
            they are per-adapter containers; ``None`` when they are the
            tensors themselves.
    """
    if adapter_name is not None:
        lora_B = self.lora_B[adapter_name]
        lora_A = self.lora_A[adapter_name]
        lora_E = self.lora_E[adapter_name]
    else:
        lora_B, lora_A, lora_E = self.lora_B, self.lora_A, self.lora_E

    lora_B.requires_grad = True
    lora_A.requires_grad = True
    lora_E.requires_grad = False
    # Drop any stale gradient so E is not updated from a previous step.
    lora_E.grad = None

    # A layer built without bias (e.g. nn.Linear(bias=False)) still has a
    # ``bias`` attribute set to None, so hasattr() alone is not enough.
    bias = getattr(self, "bias", None)
    if bias is not None:
        bias.requires_grad = True

@torch.no_grad()
def get_hypergradient(self, adapter_name):
    """Add a correction term to ``lora_E.grad`` in place.

    For every rank component ``i`` the value added to ``lora_E.grad[i]`` is::

        diag(B^T @ B.grad + A.grad @ A^T)[i] / E[i]    if E[i] >= 1e-5
        0                                               otherwise

    The term is added as a column (shape ``(r, 1)``), so ``lora_E.grad`` must
    broadcast with that shape. Assumes ``lora_B`` is ``(out, r)``, ``lora_A``
    is ``(r, in)`` and ``lora_E`` holds ``r`` values (e.g. ``(r, 1)``) — TODO
    confirm against the layer definition.

    Args:
        self: layer exposing ``lora_A``, ``lora_B``, ``lora_E``; all three must
            already have a populated ``.grad``.
        adapter_name: key used to index the per-adapter containers, or ``None``
            when ``lora_A``/``lora_B``/``lora_E`` are the tensors themselves.
    """
    if adapter_name is not None:
        lora_A = self.lora_A[adapter_name]
        lora_B = self.lora_B[adapter_name]
        lora_E = self.lora_E[adapter_name]
    else:
        lora_A, lora_B, lora_E = self.lora_A, self.lora_B, self.lora_E

    # reshape(-1) instead of squeeze(): for rank 1, squeeze() yields a 0-d
    # tensor which cannot be iterated or used as a per-component vector.
    e_values = lora_E.reshape(-1)
    # Vectorised reciprocal with a cutoff: tiny/zero/negative entries get 0.
    # The 1/0 -> inf produced for zero entries is discarded by torch.where.
    inv_e = torch.where(
        e_values >= 1e-5,
        1.0 / e_values,
        torch.zeros_like(e_values),
    )

    correction = torch.diag(lora_B.T @ lora_B.grad + lora_A.grad @ lora_A.T) * inv_e
    lora_E.grad.add_(correction[:, None])


@torch.no_grad()
def wrap_classinstance_with_lr_methods(listofinstances, funlist, adapter_default_name):
    """Attach each function in ``funlist`` to every instance as a bound method.

    Every function ``fn`` is installed on each instance under ``fn.__name__`` as
    ``partial(fn, instance, adapter_name=adapter_default_name)``, so it can be
    called as ``instance.fn()``. Each instance also gets an ``adapter_name``
    attribute set to ``adapter_default_name``.

    Returns:
        The same ``listofinstances`` object that was passed in.
    """
    for instance in listofinstances:
        bound_methods = {
            fn.__name__: partial(fn, instance, adapter_name=adapter_default_name)
            for fn in funlist
        }
        for method_name, method in bound_methods.items():
            setattr(instance, method_name, method)
        setattr(instance, "adapter_name", adapter_default_name)
    return listofinstances


@torch.no_grad()
def load_weights(
    full_model, low_rak_model, low_rank_layers, layers_to_avoid=None
):
    """Copy dense weights into a low-rank model and factorize them.

    Args:
        full_model: original (dense) model, or ``None``. When ``None`` nothing
            is copied; instead the ``weight`` attribute is deleted from every
            layer in ``low_rank_layers``.
        low_rak_model: model in low-rank format. Every key of
            ``full_model.state_dict()`` not listed in ``layers_to_avoid`` must
            also exist in its state dict (a ``KeyError`` is raised otherwise).
        low_rank_layers: low-rank layers to finalise. Each one has
            ``format_weight()`` called after the copy, or has ``weight``
            deleted when ``full_model`` is ``None``.
        layers_to_avoid: state-dict keys to skip while copying, e.g.
            ``['conv1.weight', 'fc.weight', 'fc.bias']``. Defaults to no keys.

    Returns:
        ``low_rak_model``, modified in place.
    """
    if full_model is None:
        for layer in low_rank_layers:
            del layer.weight
        return low_rak_model

    # None default instead of a shared mutable list; set for O(1) lookups.
    skip = set(layers_to_avoid) if layers_to_avoid is not None else set()

    full_sd = full_model.state_dict()
    # state_dict() tensors share storage with the module parameters, so an
    # in-place copy_ into them updates low_rak_model directly.
    lr_sd = low_rak_model.state_dict()
    for key, value in full_sd.items():
        if key not in skip:
            lr_sd[key].copy_(value)
    for layer in low_rank_layers:
        layer.format_weight()
    return low_rak_model
