from dataclasses import dataclass, field, InitVar

import numpy as np
import torch
from torch import nn
import gc
from tqdm import tqdm
from all_utils.other import get_model_name
import pickle
import os

@dataclass
class FactorizedMatrix:
    """Low-rank factorization ``W ≈ mat_l @ mat_r`` of a weight matrix.

    The factors are moved to CPU on construction and truncated to at most
    ``eq_rank`` columns (``mat_l``) / rows (``mat_r``). Reading ``mat_l`` /
    ``mat_r`` returns a view restricted to the first ``active_rank``
    components, so the active rank can be changed without touching the stored
    factors.

    Attributes:
        eq_rank: upper bound on the stored rank; wider factors are truncated in
            ``__post_init__`` (note: with the default 0, any passed factor is
            truncated to width 0).
        active_rank: number of leading components exposed by ``mat_l``/``mat_r``.
        singular_values: optional tensor; moved to CPU if it lives on CUDA.
    """

    mat_l: InitVar[torch.Tensor] = None
    mat_r: InitVar[torch.Tensor] = None
    eq_rank: int = 0
    active_rank: int = 0
    singular_values: torch.Tensor = None

    _mat_l: torch.Tensor = field(init=False, default=None, repr=False)
    _mat_r: torch.Tensor = field(init=False, default=None, repr=False)

    @property
    def mat_l(self):
        """Left factor restricted to the first ``active_rank`` columns."""
        return self._mat_l[:, :self.active_rank]

    @mat_l.setter
    def mat_l(self, value):
        self._mat_l = value

    @property
    def mat_r(self):
        """Right factor restricted to the first ``active_rank`` rows."""
        return self._mat_r[:self.active_rank, :]

    @mat_r.setter
    def mat_r(self, value):
        self._mat_r = value

    # Backwards-compatible aliases for the previously (mis-)named setter properties.
    mat_l_view = mat_l
    mat_r_view = mat_r

    def __post_init__(self, mat_l, mat_r):
        # The properties above replace the ``= None`` class attributes, so dataclasses
        # uses the property object as the InitVar default when an argument is omitted.
        if isinstance(mat_l, property):
            mat_l = None
        if isinstance(mat_r, property):
            mat_r = None

        if mat_l is not None:
            self._mat_l = mat_l.cpu()
        if mat_r is not None:
            self._mat_r = mat_r.cpu()

        # clone(): a plain slice is a view that keeps (and pickles) the full,
        # untruncated storage, which defeats the purpose of truncating.
        if self._mat_l is not None and self._mat_l.shape[1] > self.eq_rank:
            print("Truncating mat_l to eq_rank for storage efficiency")
            self._mat_l = self._mat_l[:, :self.eq_rank].clone()

        if self._mat_r is not None and self._mat_r.shape[0] > self.eq_rank:
            print("Truncating mat_r to eq_rank for storage efficiency")
            self._mat_r = self._mat_r[:self.eq_rank, :].clone()

        if self.singular_values is not None and self.singular_values.is_cuda:
            self.singular_values = self.singular_values.cpu()

def is_linear_like_conv(layer):
    """Return True if ``layer`` is an ungrouped 1x1 ``nn.Conv2d`` with >= 10 output channels.

    Such a convolution applies one dense channel-mixing matrix at every spatial
    position, i.e. it acts like an ``nn.Linear`` over the channel dimension.
    """
    if not isinstance(layer, nn.Conv2d):
        return False
    pointwise = layer.kernel_size == (1, 1)
    ungrouped = layer.groups == 1
    return pointwise and ungrouped and layer.out_channels >= 10

def get_valid_layers(model: nn.Module, name_omit, white_list=None):
    """Return the ``(name, module)`` pairs of ``model`` that are candidates for factorization.

    A submodule qualifies when it is an ``nn.Linear`` with at least 10 output
    features, its qualified name contains none of the substrings in
    ``name_omit`` and, if ``white_list`` is non-empty, contains at least one of
    the substrings in ``white_list``.

    Args:
        model: module whose ``named_modules()`` are scanned; names are relative to it.
        name_omit: substrings that exclude a layer. A single string is treated as
            one substring; ``None`` means no exclusions.
        white_list: optional substrings a name must contain; ``None``/empty allows
            every name. A single string is treated as one substring.

    Returns:
        List of ``(qualified_name, nn.Linear)`` in ``named_modules()`` order.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(name_omit, str):
        name_omit = (name_omit,)
    if isinstance(white_list, str):
        white_list = (white_list,)
    omit = tuple(name_omit or ())
    allow = tuple(white_list or ())

    valid = []
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear) or module.out_features < 10:
            continue
        if any(o in name for o in omit):
            continue
        if allow and not any(a in name for a in allow):
            continue
        valid.append((name, module))
    return valid

def get_eq_rank(n, m):
    """Return the "equivalent rank" of an ``n x m`` matrix.

    A rank-``r`` factorization ``(n x r) @ (r x m)`` stores ``r * (n + m)``
    values, which equals the ``n * m`` values of the dense matrix at
    ``r = n * m / (n + m)``. The result is floored, so a factorization of this
    rank never stores more values than the dense matrix.

    Integer floor division keeps the result exact for arbitrarily large
    dimensions; ``int(n * m / (n + m))`` goes through a float and can round up.
    """
    return int(n * m // (n + m))

def _find_decoder_layers(model: nn.Module) -> nn.ModuleList:
    """
    Dynamically finds the module list containing the decoder layers of a transformer model.
    This is a heuristic-based approach that should work for most modern LLMs.
    """
    for name, module in model.named_modules():
        # The decoder layers are usually in a ModuleList.
        if isinstance(module, nn.ModuleList):
            # Check if the children of this ModuleList are the decoder blocks.
            # A common heuristic is that decoder blocks have 'self_attn' and 'mlp' attributes.
            if (
                len(module) > 0 and 
                (
                    hasattr(module[0], 'self_attn') or
                    hasattr(module[0], 'mixer')
                )
            ):
                return module, name

    banned=("qkv", "mlp", "attn", "fc1", "fc2", "patch_embed")
    banned = tuple(s.lower() for s in banned)

    def is_banned(name: str) -> bool:
        lname = name.lower()
        return any(b in lname for b in banned)

    best_path, best_mod, best_depth = None, None, -1
    stack = [([], model)]  # DFS stack of (path, module)

    while stack:
        path, mod = stack.pop()

        # Candidate: this module itself is a Sequential and its own name isn't banned
        own_name = path[-1] if path else ""  # name within parent
        if isinstance(mod, nn.Sequential) and not is_banned(own_name) and len(mod) > 2:
            depth = len(path)
            if depth > best_depth:
                best_path, best_mod, best_depth = path, mod, depth

        # Recurse into children unless the child's *name* is banned
        for name, child in mod.named_children():
            if not is_banned(name):
                stack.append((path + [name], child))
    if best_mod is not None:
        print(f"Found decoder layers at path: {'.'.join(best_path + [best_mod._get_name()])} with depth {best_depth}")
        return best_mod, ".".join(best_path + [best_mod._get_name()])
    raise ValueError("Could not find any decoder layers module list or nn.Sequential in the model.")


class BaseFactorization:
    """Base class for low-rank factorization of a model's ``nn.Linear`` layers.

    Subclasses provide the decomposition itself (``_factorize_matrix``), the
    release of per-layer intermediate state (``_factorize_cleanup``), the
    ``post_search_calibration`` property and, when the method needs data
    statistics, ``_compute_scaling``. This base class orchestrates calibration,
    the in-memory (and optional on-disk) cache of ``FactorizedMatrix`` results
    and the in-place replacement of layers by ``SeqSVD`` modules.
    """

    def __init__(
        self,
        vision,
        calib_dataset_name,
        use_cache=True,
        blockwise_factorization=False,
        progressive_compression=False,
        do_post_calibration="default",
        calibration_ranks=None,
    ):
        """
        Args:
            vision: whether the model is a vision model; vision models never take
                the gradient-checkpointing (memory-efficient) path.
            calib_dataset_name: calibration dataset name; part of the cache file name.
            use_cache: load/store the factorization cache from/to a pickle file.
            blockwise_factorization: calibrate and factorize one decoder block at
                a time instead of the whole model at once.
            progressive_compression: compress each block before calibrating the
                next one; requires ``blockwise_factorization`` (forced off otherwise).
            do_post_calibration: ``"default"``, ``"True"``/``"False"`` or a bool.
                Strings are converted to bools; ``"default"`` is kept as-is.
            calibration_ranks: optional ``{layer_name: rank (int) | ratio (float)}``
                overriding the rank used while calibrating.

        Raises:
            ValueError: if ``do_post_calibration`` is not an accepted value.
        """
        self.scaling_dict = {}
        self.input_shapes = {}
        self.vision = vision
        self.use_file_cache = use_cache
        self.one_shot_factorization = not blockwise_factorization
        if not blockwise_factorization and progressive_compression:
            print("Warning: progressive compression requires blockwise "
                  "factorization. Setting progressive_compression to False.")
            progressive_compression = False
        self.progressive_compression = progressive_compression
        self.static_progressive_compression_ratio = 0.5
        # Fresh dict per instance (a shared mutable default would leak between instances).
        self.calibration_ranks = {} if calibration_ranks is None else calibration_ranks
        # The in-memory cache is always on; the file cache is built on top of it.
        self.use_local_cache = True
        self.factorized_layers_cache: dict[str, FactorizedMatrix] = {}
        self.fact_cache_dir = "./.cache/precomputed_SVDs/"
        self.dev = self._default_device()
        # Set properly by factorization_computations; defined here so the
        # block-wise path can also be driven directly.
        self.compute_memory_efficient = False
        self._do_post_calibration = self._parse_post_calibration(do_post_calibration)
        self.calib_dataset_name = calib_dataset_name

    @staticmethod
    def _default_device():
        """Return the current CUDA device, or the CPU when CUDA is unavailable."""
        if torch.cuda.is_available():
            return torch.device(torch.cuda.current_device())
        return torch.device("cpu")

    @staticmethod
    def _release_cuda_cache(dev):
        """Empty the CUDA caching allocator of ``dev``; no-op for non-CUDA devices."""
        if dev.type == "cuda":
            with torch.cuda.device(dev):
                torch.cuda.empty_cache()

    @staticmethod
    def _parse_post_calibration(value):
        """Normalize ``do_post_calibration`` to ``"default"``, ``True`` or ``False``."""
        if isinstance(value, bool):
            return value
        if value not in ("default", "True", "False"):
            raise ValueError("do_post_calibration must be 'default', 'True', 'False' or a bool.")
        return value if value == "default" else value == "True"

    @staticmethod
    def _split_rank_or_ratio(value):
        """Interpret a rank-dict entry as ``(rank, ratio)``.

        Integers (including numpy integers) are absolute ranks and floats
        (including numpy floats) are ratios of the equivalent rank. The unused
        slot is -1; unrecognized values give ``(-1, -1)``.
        """
        if isinstance(value, (int, np.integer)):
            return int(value), -1
        if isinstance(value, (float, np.floating)):
            return -1, float(value)
        return -1, -1

    @property
    def post_search_calibration(self):
        """Whether the method must re-calibrate after the rank search.

        E.g. because it uses the scaling statistics to determine the rank.
        Subclasses must override this; ``self._do_post_calibration`` holds the
        user's choice (``"default"`` or a bool) and can be honored there.
        """
        raise NotImplementedError("Subclasses should implement this attribute.")

    def get_cache_name(self) -> str:
        """Name identifying this factorization setup in the cache file name.

        Override if the cache of a subclass depends on more than the class
        name, the progressive flag and the calibration dataset.
        """
        decomp_name = self.__class__.__name__
        if self.progressive_compression:
            decomp_name += "_prog"
        decomp_name += f"_{self.calib_dataset_name}"
        return decomp_name

    def __get_cache_file(self, model) -> str:
        model_name = get_model_name(model=model)
        return os.path.join(self.fact_cache_dir, f"{model_name}_{self.get_cache_name()}.pkl")

    def is_memory_efficient_mode(self, model):
        """Return True when the model must run in train mode for gradient checkpointing.

        Only non-vision models are checked; a model without an
        ``is_gradient_checkpointing`` attribute counts as not checkpointed.
        """
        if not self.vision and getattr(model, "is_gradient_checkpointing", False):
            print("Gradient checkpointing enabled, using train mode.")
            # train mode required for grad checkpointing to work.
            return True
        return False

    def factorization_computations(self, model, name_omit, calib_data, mixup_fn, white_list=None):
        """Fill ``self.factorized_layers_cache`` for every valid layer of ``model``.

        Loads the cache from disk when ``use_file_cache`` is set and the file
        exists; otherwise calibrates (one-shot or block-wise), factorizes and,
        if ``use_file_cache`` is set, writes the cache to disk.
        """
        white_list = [] if white_list is None else white_list
        file_name = self.__get_cache_file(model=model)
        if self.use_file_cache and os.path.exists(file_name):
            print("Loading existing decomposed layer dictionaries...")
            # NOTE: unpickling executes arbitrary code; only load cache files this tool wrote.
            with open(file_name, 'rb') as f:
                self.factorized_layers_cache = pickle.load(f)
            print("Loaded successfully.")
            return

        self.compute_memory_efficient = self.is_memory_efficient_mode(model=model)
        # train mode required for grad checkpointing to work.
        if self.compute_memory_efficient:
            model = model.train().to(self.dev)
        else:
            model = model.eval().to(self.dev)

        if self.one_shot_factorization and not self.compute_memory_efficient:
            self._get_scale_and_factorize_one_shot(
                model, name_omit, calib_data, mixup_fn, white_list=white_list
            )
        else:
            self._get_scale_and_factorize_block_wise(
                model, name_omit, calib_data, mixup_fn, white_list=white_list
            )

        if self.use_file_cache:
            cache_dir = os.path.dirname(file_name)
            if cache_dir:
                os.makedirs(cache_dir, exist_ok=True)
            with open(file_name, 'wb') as f:
                pickle.dump(self.factorized_layers_cache, f)
            print(f"Saved decomposed layer dictionaries to {file_name}")

    def _get_scale_and_factorize_one_shot(self, model, name_omit, calib_data, mixup_fn, white_list):
        """
        Compute the factorization in one go: fast, but not memory efficient, as
        the statistics of all layers are gathered at once.
        """
        self._get_scale_and_factorize_module(
            model=model,
            hook_module=model,  # in one shot, hook the whole model
            calib_data=calib_data,
            name_omit=name_omit,
            name_prefix="",
            white_list=white_list,
            mixup_fn=mixup_fn,
            tqdm_message="Gathering ",
        )

    def _get_scale_and_factorize_block_wise(self, model, name_omit, calib_data, mixup_fn, white_list):
        """
        Compute the factorization block by block: slower, but only the statistics
        of one decoder block are kept at a time. With ``progressive_compression``
        each block is compressed in place before the next block is calibrated, so
        later statistics reflect the already-compressed earlier blocks.
        """
        white_list = [] if white_list is None else white_list
        layers, mod_name = _find_decoder_layers(model)
        for l_idx, layer in enumerate(layers):
            # Keys must match model.named_modules(); a container at the model
            # root has an empty path and therefore no leading "<name>.".
            name_prefix = f"{mod_name}.{l_idx}." if mod_name else f"{l_idx}."

            self._get_scale_and_factorize_module(
                model=model,
                hook_module=layer,
                calib_data=calib_data,
                name_omit=name_omit,
                name_prefix=name_prefix,
                white_list=white_list,
                mixup_fn=mixup_fn,
                tqdm_message=f"Layer {l_idx+1}/{len(layers)}: Gathering ",
            )

            for name, module_sub in get_valid_layers(layer, name_omit, white_list=white_list):
                key = f"{name_prefix}{name}"
                if self.progressive_compression:
                    self._compress_weight_in_place(module_sub, key)
                # Runs for every layer, including ones left uncompressed, so no
                # cached scalings/activations survive in memory-efficient mode.
                # This gives up recomputing the decomposition without rerunning
                # the scaling computations.
                if self.compute_memory_efficient:
                    print("Cleaning up factorization cache to save memory...")
                    self._factorize_cleanup(key)
                    torch.cuda.empty_cache()

    def _compress_weight_in_place(self, module_sub, key):
        """Overwrite ``module_sub.weight`` with its low-rank reconstruction ``mat_l @ mat_r``."""
        rank, ratio, skip = self._get_active_rank(
            module_sub.weight.shape, key, default_ratio=self.static_progressive_compression_ratio
        )
        if skip:
            return
        # Taken from the cache, or computed (and cached) if not available.
        factorized_matrix = self.factorize_matrix(
            module_sub.weight, name=key, rank=rank, ratio=ratio, verbose=False
        )
        if factorized_matrix is None:
            # factorize_matrix already printed why; keep the dense weight.
            return
        module_sub.weight.data.copy_(
            factorized_matrix.mat_l.to(self.dev) @ factorized_matrix.mat_r.to(self.dev)
        )

    def _get_scale_and_factorize_module(
            self, model, hook_module, name_prefix, calib_data,
            name_omit, mixup_fn=None, white_list=None, tqdm_message="Gathering "
        ):
        """Compute scalings for the layers under ``hook_module`` and factorize them.

        Results are stored in ``self.factorized_layers_cache`` under
        ``name_prefix + <name inside hook_module>``.
        """
        white_list = [] if white_list is None else white_list
        self._compute_scaling(
            model=model,
            hook_module=hook_module,
            name_prefix=name_prefix,
            calib_data=calib_data,
            name_omit=name_omit,
            mixup_fn=mixup_fn,
            white_list=white_list,
            tqdm_message=tqdm_message + "scalings..."
        )
        torch.cuda.empty_cache()
        # factorize_matrix only fills the cache when use_local_cache is True;
        # without it, factorizing here would be wasted work.
        if not self.use_local_cache:
            return
        candidates = get_valid_layers(hook_module, name_omit, white_list=white_list)
        for name, module_sub in tqdm(candidates, desc=tqdm_message + "factorizations..."):
            key = f"{name_prefix}{name}"
            # Without calibration_ranks this yields ratio 1.0 (full equivalent rank).
            rank, ratio, skip = self._get_active_rank(module_sub.weight.shape, key)
            if skip:
                continue
            # Stored in the internal cache; the returned value is not needed here.
            self.factorize_matrix(
                matrix=module_sub.weight.detach().clone(),
                rank=rank, ratio=ratio, name=key, verbose=False
            )

    def _compute_scaling(self, model, hook_module, name_prefix, calib_data, name_omit, mixup_fn=None, white_list=None, tqdm_message="Gathering "):
        """Gather calibration statistics; the base class needs none."""
        print("\nNo scaling method implemented. This is fine"
              " as long as your method is only requiring weights.")

    def _get_active_rank(self, shape, key, default_ratio=1.0):
        """Return ``(rank, ratio, skip)`` for the layer ``key`` with weight ``shape``.

        Without a ``calibration_ranks`` entry: ``(-1, default_ratio, False)``.
        With one: int entries become ``rank``, float entries ``ratio``. ``skip``
        is True when the entry means "no compression" (ratio 1.0, or a rank
        equal to the equivalent rank).
        """
        if not self.calibration_ranks or key not in self.calibration_ranks:
            return -1, default_ratio, False
        rank, ratio = self._split_rank_or_ratio(self.calibration_ranks[key])
        if ratio == 1.0:
            return -1, -1, True
        n, m = shape
        if rank != -1 and rank == get_eq_rank(n, m):
            return -1, -1, True
        return rank, ratio, False

    def factorize_matrix(self, matrix, name, rank=-1, ratio=-1, verbose=True) -> "FactorizedMatrix | None":
        """Return the (cached) factorization of ``matrix`` with the requested active rank.

        Exactly one of ``rank`` / ``ratio`` must be given (the other is -1).
        ``rank == 0`` means the equivalent rank; a ratio is multiplied with the
        equivalent rank and truncated. Returns None (after printing a warning)
        for an invalid request. Results are cached under ``name`` when
        ``use_local_cache`` is set; a cache hit only updates ``active_rank``.
        """
        if verbose:
            print(f"Factorizing {name} matrix")
        if rank == -1 and ratio == -1:
            print(f"Warning: {name} rank or ratio must be defined!")
            return None
        if rank != -1 and ratio != -1:
            print(
                f"Warning: {name} rank and ratio are both defined, "
                "only one can be used at a time!"
            )
            return None
        eq_rank = get_eq_rank(matrix.shape[0], matrix.shape[1])
        if rank == 0:
            rank = eq_rank
        elif ratio != -1:
            rank = int(eq_rank * ratio)
        elif rank > eq_rank:
            print(f"Warning: {name} rank is larger than equivalent rank!")
            return None

        if self.use_local_cache and name in self.factorized_layers_cache:
            fact_mat = self.factorized_layers_cache[name]
        else:
            fact_mat = self._factorize_matrix(
                matrix=matrix, name=name, eq_rank=eq_rank, rank=rank, dev=self.dev, verbose=verbose
            )
            if self.use_local_cache:
                # setdefault: never overwrite an entry a subclass may have stored itself.
                self.factorized_layers_cache.setdefault(name, fact_mat)
        fact_mat.active_rank = rank
        return fact_mat

    def _factorize_matrix(self, matrix, name, eq_rank, rank, dev, verbose=True) -> "FactorizedMatrix":
        """Decompose a single matrix; implemented by subclasses."""
        raise NotImplementedError("Subclasses should implement this method.")

    def _factorize_cleanup(self, name):
        """Release the internal per-layer calculation state after ``name`` was compressed."""
        raise NotImplementedError("Subclasses should implement this method.")

    def create_factorized_sequential(
        self, factorized_matrix: "FactorizedMatrix", original_module
    ) -> nn.Module:
        """Build a ``SeqSVD`` computing ``x @ (mat_l @ mat_r).T + bias``.

        Since ``W ≈ mat_l @ mat_r``, ``x W^T = (x mat_r^T) mat_l^T``: the first
        linear map (``module_l``) therefore holds ``mat_r`` and the second
        (``module_r``) holds ``mat_l``. Both use the original weight's dtype and
        device; the original bias (if any) is reused.
        """
        dev = original_module.weight.device
        dtype = original_module.weight.dtype
        rank = factorized_matrix.active_rank
        module_l = nn.Linear(original_module.in_features, rank, bias=False, dtype=dtype).to(dev)
        module_r = nn.Linear(rank, original_module.out_features, bias=False, dtype=dtype).to(dev)

        # mat_l / mat_r are already restricted to the active rank by FactorizedMatrix.
        module_l.weight.data.copy_(factorized_matrix.mat_r.to(dev))
        module_r.weight.data.copy_(factorized_matrix.mat_l.to(dev))
        torch.cuda.empty_cache()

        return SeqSVD(module_l, module_r, getattr(original_module, "bias", None)).to(dev)

    def factorize_model(self, uncom_model, rank_dict, name_omit, verbose=True, apply_fact=True) -> None:
        """
        Apply low-rank decomposition to the model in place.

        The model is put in eval mode and moved to the CPU. Every valid linear
        layer whose name maps to a rank (int) or ratio (float) in ``rank_dict``
        is factorized and, if ``apply_fact``, replaced by a ``SeqSVD`` module.
        Layers missing from ``rank_dict`` or mapped to -1 or 1.0 are left
        unchanged, so omitting a layer from the dict keeps it uncompressed.

        Args:
            uncom_model: model to compress (modified in place).
            rank_dict: ``{layer_name: rank (int) | ratio (float)}``.
            name_omit: substrings of layer names to exclude.
            verbose: print progress and plot the compression rates.
            apply_fact: if False, only compute/cache the factorizations.
        """
        print("\nApplying factorization")
        dev = self._default_device()
        self._release_cuda_cache(dev)

        model = uncom_model.eval().cpu()
        copied_modules = get_valid_layers(model, name_omit, white_list=[])
        if verbose:
            print(rank_dict)
            plot_compression_rates(rank_dict)
        for name, module_sub in tqdm(copied_modules):
            target = rank_dict.get(name, -1)
            # condition for not applying low rank
            if target == -1 or target == 1.0:
                continue

            rank, ratio = self._split_rank_or_ratio(target)
            factorized = self.factorize_matrix(
                matrix=module_sub.weight,
                rank=rank,
                ratio=ratio,
                name=name,
                verbose=verbose,
            )
            if factorized is None:
                if verbose:
                    print(f"Skipping {name} as factorization failed.")
                continue

            if not apply_fact:
                continue
            svd_replacement = self.create_factorized_sequential(
                factorized_matrix=factorized, original_module=module_sub
            )

            if verbose:
                print(f"Applying low rank on {name:^10}, rank {target}")

            parent_name, _, attr_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, attr_name, svd_replacement)

            self._release_cuda_cache(dev)
            gc.collect()


class SeqSVD(nn.Module):
    """Two chained linear maps (``mod_a`` then ``mod_b``) plus an optional additive bias.

    Serves as the drop-in replacement for a factorized ``nn.Linear``.
    """

    def __init__(self, mod_a, mod_b, bias=None):
        super().__init__()
        self.mod_a = mod_a
        self.mod_b = mod_b
        self.bias = bias

    def forward(self, x):
        out = self.mod_b(self.mod_a(x))
        if self.bias is None:
            return out
        out += self.bias
        return out

# NOTE: Sequential Layer to do latency comparison to MemViT
class SeqSVDMemViT(nn.Module):
    """Linear layer parameterized as two low-rank products: ``x U V^T + x G Y^T + b``."""

    def __init__(self, in_features, out_features, rank_r, rank_q=None, bias=True, init_from=None, gy_ratio=0.05):
        """
        Args:
            in_features: input dimension (k)
            out_features: output dimension (m)
            rank_r: low-rank dimension for U,V
            rank_q: optional override for G,Y rank (if None, use gy_ratio)
            bias: whether to include bias
            init_from: optional nn.Linear to initialize from. U,V receive its top
                ``rank_r`` singular components, G,Y the next ``rank_q``; factor
                columns beyond the available singular components are zero. Its
                bias is copied when both layers have one.
            gy_ratio: fraction of original W size to allocate for G,Y (default 5%)
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank_r = rank_r

        # Determine rank_q automatically if not provided
        if rank_q is None:
            k, m = in_features, out_features
            rank_q = int(round((gy_ratio * k * m) / (k + m)))
        self.rank_q = rank_q

        # Define factors
        self.U = nn.Parameter(torch.randn(in_features, rank_r) * 0.02)
        self.V = nn.Parameter(torch.randn(out_features, rank_r) * 0.02)
        self.G = nn.Parameter(torch.randn(in_features, rank_q) * 0.02)
        self.Y = nn.Parameter(torch.randn(out_features, rank_q) * 0.02)

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

        if init_from is not None:
            with torch.no_grad():
                # [out_features, in_features]; float() because torch.linalg.svd
                # only supports float/double (and complex) inputs.
                W = init_from.weight.detach().float()
                U_svd, S_svd, Vt_svd = torch.linalg.svd(W, full_matrices=False)
                # Top singular components into U,V, the following ones into G,Y.
                self._fill_pair(self.U, self.V, U_svd, S_svd, Vt_svd, start=0)
                self._fill_pair(self.G, self.Y, U_svd, S_svd, Vt_svd, start=rank_r)
                init_bias = getattr(init_from, "bias", None)
                if self.bias is not None and init_bias is not None:
                    self.bias.copy_(init_bias.detach())

    @staticmethod
    def _fill_pair(in_factor, out_factor, U_svd, S_svd, Vt_svd, start):
        """Copy singular components ``start:start+width`` into a factor pair, zeroing the rest.

        There are only ``S_svd.numel()`` (= min(out, in)) components, so a
        factor may be only partially filled.
        """
        count = in_factor.shape[1]
        end = min(start + count, S_svd.numel())
        width = max(end - start, 0)
        in_factor.zero_()
        out_factor.zero_()
        if width > 0:
            in_factor[:, :width].copy_(Vt_svd[start:end, :].T)
            out_factor[:, :width].copy_(U_svd[:, start:end] * S_svd[start:end])

    def forward(self, x):
        # x: [..., in_features]
        out = (x @ self.U) @ self.V.T + (x @ self.G) @ self.Y.T
        if self.bias is not None:
            out = out + self.bias
        return out



class Hookstuff:
    """Unified forward/backward hook manager used to gather per-layer statistics.

    Subclasses override the hook factories (``_hook_fn``, ``_bw_hook_fn``,
    ``_perturb_hook_fn``); the base implementations do nothing. Hooks are
    registered on the layers selected by ``get_valid_layers(model, ...)``, and
    ``name_prefix`` is prepended to each layer name handed to the factories.
    Every registered handle is kept in ``self.hooks`` until ``clear_hooks``.
    """

    def __init__(self, model: nn.Module, name_omit=None, dump_shape=False, white_list=None, name_prefix=""):
        self.model = model
        # Fresh lists per instance instead of shared mutable defaults.
        self.name_omit = [] if name_omit is None else name_omit
        self.white_list = [] if white_list is None else white_list
        self.dump_shape = dump_shape
        self.name_prefix = name_prefix

        self.column_scale = {}
        self.row_scale = {}
        self.activation_cache = {}
        self.buf_1 = {}
        self.buf_2 = {}
        self.layer_trigger = None

        self.profile = {}
        self.profile_gout = {}
        self.input_shape = {}
        self.hooks = []
        # NOTE: these are one-shot iterators over the valid layers in reverse
        # order; once consumed (e.g. by attach_hooks, or by printing them) they
        # are empty.
        self.cp_modules = reversed(
            get_valid_layers(model, self.name_omit, self.white_list)
        )

        self.bw_cp_modules = reversed(
            get_valid_layers(model, self.name_omit, self.white_list)
        )

    def _hook_fn(self, layer_name):
        def get_scaling_mat(module, input, output):
            pass

        return get_scaling_mat

    def _bw_hook_fn(self, layer_name):
        def get_scaling_mat_grad(module, ginput, goutput):
            pass

        return get_scaling_mat_grad

    def _perturb_hook_fn(self, layer_name):
        def perturb_activations(module, input, output):
            # Subclasses can perturb the layer output here (e.g. add noise);
            # returning the output unchanged leaves the forward pass untouched.
            return output

        return perturb_activations

    def _remove_hooks(self, handles):
        """Remove ``handles`` from their modules and from ``self.hooks``; empties ``handles``."""
        for handle in handles:
            handle.remove()
            if handle in self.hooks:
                self.hooks.remove(handle)
        handles.clear()

    def _register_hooks_recursive(self, cp_modules, prefix=""):
        """Register a forward hook on every ``(name, layer)`` in ``cp_modules``."""
        print("Registering forward hooks...")
        for name, layer in cp_modules:
            layer_name = self.name_prefix + name
            hook = layer.register_forward_hook(self._hook_fn(layer_name))
            self.hooks.append(hook)

    def _register_bw_hooks_recursive(self, cp_modules, prefix=""):
        """Register a full backward hook on every ``(name, layer)`` in ``cp_modules``."""
        print("Registering backward hooks...")
        for name, layer in cp_modules:
            layer_name = self.name_prefix + name
            bw_hook = layer.register_full_backward_hook(self._bw_hook_fn(layer_name))
            self.hooks.append(bw_hook)

    def _register_bw_hooks_singular(self, cp_modules, prefix=""):
        """Generator keeping a backward hook on exactly one layer at a time.

        Before moving to the next ``(name, layer)`` it removes the hook it
        registered for the previous layer (hooks registered elsewhere are left
        alone), then yields the pair. The last layer's hook stays in ``self.hooks``.
        """
        active = []
        for name, layer in cp_modules:
            self._remove_hooks(active)
            layer_name = self.name_prefix + name
            active.append(layer.register_full_backward_hook(self._bw_hook_fn(layer_name)))
            self.hooks.extend(active)
            yield name, layer

    def _register_hooks_singular(self, cp_modules, prefix=""):
        """Like ``_register_bw_hooks_singular`` but with a forward perturbation hook
        and a backward hook per layer; both are removed before the next layer."""
        active = []
        for name, layer in cp_modules:
            self._remove_hooks(active)
            layer_name = self.name_prefix + name
            active.append(layer.register_forward_hook(self._perturb_hook_fn(layer_name)))
            active.append(layer.register_full_backward_hook(self._bw_hook_fn(layer_name)))
            self.hooks.extend(active)
            yield name, layer

    def attach_hooks(self):
        self._register_hooks_recursive(self.cp_modules)

    def attach_bw_hooks(self):
        self._register_bw_hooks_recursive(self.bw_cp_modules)

    def clear_hooks(self):
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()


class ShapeHook(Hookstuff):
    """Records, per hooked layer, its input shape normalized to three dims plus ``out_features``."""

    def _hook_fn(self, layer_name, last_feat=False):
        def get_intermediate_shapes(module, input, output):
            # Only the shape is needed; the previous detach().float() made a
            # full fp32 copy of the activation on every forward pass.
            in_shape = input[0].shape
            if len(in_shape) > 3:
                # Same as reshape(b, -1, c): fold all middle dims into one.
                shape = [in_shape[0], in_shape[1:-1].numel(), in_shape[-1]]
            elif len(in_shape) == 2:
                shape = [1, *in_shape]
            else:
                shape = list(in_shape)
            # Stored as [*input_shape, out_features, 0]; the meaning of the
            # trailing 0 is not visible here (presumably a placeholder — confirm).
            self.input_shape[layer_name] = shape + [module.out_features, 0]

        return get_intermediate_shapes


import matplotlib.pyplot as plt

def plot_compression_rates(rank_dict, save_path="compression_rates.png"):
    """Bar-plot per-layer compression ratios, save the figure to ``save_path`` and show it.

    Nothing is plotted for an empty dict or when the first value is not a
    float (i.e. the dict holds absolute ranks rather than ratios).
    """
    if not rank_dict:
        return
    layer_names = list(rank_dict.keys())
    compression_rates = list(rank_dict.values())
    if not isinstance(compression_rates[0], float):
        return

    fig = plt.figure(figsize=(10, 6))
    plt.bar(layer_names, compression_rates, color='skyblue')
    plt.xlabel('Layer Names')
    plt.ylabel('Compression Rate')
    plt.title('Compression Rates by Layer')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    # Save before show(): a blocking show() closes the figure, so a later
    # savefig would write a new, empty figure.
    plt.savefig(save_path, dpi=300)
    plt.show()
    plt.close(fig)