import torch
from torch import nn

from ._interface import BaseSearch


class UNIFORMSearch(BaseSearch):
    """Uniform ratio search: every non-omitted module gets the same ratio.

    Attributes set by ``__init__``:
        eval_data: stored as given; not read by the methods in this class.
        mixup_fn: forwarded to ``lrd_method.compute_scaling``.
        name_omit: substrings; a module whose name contains any of them
            keeps ratio ``1.0`` in :meth:`search`.
        ratio_target: ratio assigned to every other module.
        dev: device the factorized matrices are moved to before being
            multiplied in :meth:`search_blockwise`.
        sensitivity_dict: starts empty (needed for ASVD search).
        lrd_method: starts as ``None``; it must be assigned by the caller
            before :meth:`search_blockwise` is used.
    """

    def __init__(
        self,
        eval_data,
        mixup_fn,
        name_omit=None,
        ratio_target=0.5,
        stage_name_in_current_model="stages",
    ):
        """Store the search configuration.

        Args:
            eval_data: evaluation data, kept on the instance.
            mixup_fn: mixup callable passed through to ``compute_scaling``.
            name_omit: iterable of name substrings to leave at ratio 1.0.
                ``None`` (the default) means "omit nothing".
            ratio_target: ratio applied to every non-omitted module.
            stage_name_in_current_model: kept on the instance; not read by
                the methods in this class.
        """
        self.eval_data = eval_data
        # A fresh list per instance: a ``[]`` default in the signature would
        # be shared (and mutable) across every UNIFORMSearch object.
        self.name_omit = [] if name_omit is None else name_omit

        self.mixup_fn = mixup_fn
        # Same device as before when CUDA is present; fall back to the CPU
        # instead of raising on hosts without CUDA.
        if torch.cuda.is_available():
            self.dev = torch.device(torch.cuda.current_device())
        else:
            self.dev = torch.device("cpu")
        self.stage_name_in_current_model = stage_name_in_current_model
        # sensitivity dict needed for ASVD search
        self.sensitivity_dict = {}
        self.lrd_method = None
        self.ratio_target = ratio_target

    def search(self, model: nn.Module):
        """Return a ``{module_name: ratio}`` dict covering the whole model.

        Every name yielded by ``model.named_modules()`` (including the root
        module ``''``) maps to ``self.ratio_target``, except names that
        contain any entry of ``self.name_omit`` as a substring, which map to
        ``1.0``.

        Args:
            model: the model whose modules are enumerated.

        Returns:
            dict mapping each module name to its ratio.
        """
        default_param_ratio = 1.0
        # Abridged example of a recorded result (ViT-style model with blocks
        # 0..23, ratio_target=0.5; presumably name_omit covered "norm",
        # "patch_embed" and "head" -- confirm against the caller):
        #   {'': 0.5, 'patch_embed': 1.0, 'patch_embed.proj': 1.0,
        #    'pos_drop': 0.5, 'norm_pre': 1.0, 'blocks': 0.5, 'blocks.0': 0.5,
        #    'blocks.0.norm1': 1.0, 'blocks.0.attn.qkv': 0.5,
        #    'blocks.0.attn.proj': 0.5, 'blocks.0.mlp.fc1': 0.5,
        #    'blocks.0.mlp.norm': 1.0, 'blocks.0.mlp.fc2': 0.5, ...,
        #    'norm': 1.0, 'fc_norm': 1.0, 'head_drop': 1.0, 'head': 1.0}
        return {
            name: (
                default_param_ratio
                if any(n in name for n in self.name_omit)
                else self.ratio_target
            )
            for name, _ in model.named_modules()
        }

    def search_blockwise(self, model: nn.Module, stage_name: str, calib_data=None):
        """Factorize the model block by block, overwriting weights in place.

        For each block returned by ``self.get_model_blocks`` (not defined in
        this class; presumably inherited from ``BaseSearch``), scaling is
        recomputed via ``self.lrd_method.compute_scaling`` restricted to that
        block's layers, then each listed layer's weight is replaced by the
        product ``mat_l @ mat_r`` of its factorization at ``self.ratio_target``.

        ``self.lrd_method`` must have been assigned beforehand; it is ``None``
        right after construction.

        Args:
            model: the model to process; its layer weights are mutated.
            stage_name: forwarded to ``self.get_model_blocks``.
            calib_data: forwarded to ``compute_scaling``.

        Returns:
            dict mapping every module name of ``model`` to ``1.0``, except the
            layers that were factorized, which map to ``self.ratio_target``.
        """
        default_param_ratio = 1.0
        compression_dict = {
            name: default_param_ratio for name, _ in model.named_modules()
        }
        blocks, blocks_layer_names = self.get_model_blocks(model, stage_name)
        for block, block_layer_names in zip(blocks, blocks_layer_names):
            self.lrd_method.compute_scaling(
                model,
                name_omit=self.name_omit,
                calib_data=calib_data,
                mixup_fn=self.mixup_fn,
                white_list=block_layer_names,
            )
            # Built once per block rather than twice per layer.
            block_modules = dict(block.named_modules())
            for layer_name, _ in block_layer_names:
                compression_dict[layer_name] = self.ratio_target
                # ``layer_name`` is the name within the full model; pick the
                # longest block-relative module name contained in it. The
                # block's own root name '' always matches, so ``max`` never
                # receives an empty sequence.
                layer_name_within_block = max(
                    (name for name in block_modules if name in layer_name),
                    key=len,
                )
                smodule: nn.Linear = block_modules[layer_name_within_block]
                factorized_matrix = self.lrd_method.factorize_matrix(
                    smodule.weight, ratio=self.ratio_target, name=layer_name
                )
                # Overwrite the layer's weight with its low-rank product.
                smodule.weight.data.copy_(
                    factorized_matrix.mat_l.to(self.dev)
                    @ factorized_matrix.mat_r.to(self.dev)
                )
        return compression_dict
