from ..utils.getting_modules import get_module_by_name
from ..utils.tensor_sorting import sort_layer
import torch.nn as nn


class SortedWeightsInit:
    """Replace selected submodules of a model with ``sort_layer`` results.

    The names of the layers to process are fixed at construction time;
    :meth:`apply` performs the replacement on a given model.
    """

    def __init__(self, layer_names):
        """Store the submodule names to process.

        Args:
            layer_names: Iterable of submodule names, each resolved against the
                model with ``get_module_by_name``. Dotted paths such as
                ``"encoder.fc1"`` are presumably supported by that helper —
                TODO confirm. NOTE(review): the iterable is stored as given, so
                a one-shot iterator will be exhausted after the first ``apply``.
        """
        self.layer_names = layer_names

    def apply(self, model: nn.Module) -> nn.Module:
        """Swap each configured layer in *model* for ``sort_layer(layer)``.

        The replacement is installed on the layer's direct parent module.
        For a nested (dotted) name, this replaces the real submodule rather
        than adding a new top-level child whose name contains dots.

        Args:
            model: Module whose submodules are replaced in place.

        Returns:
            The same *model* object, for chaining.
        """
        for name in self.layer_names:
            layer = get_module_by_name(module=model, name=name)
            sorted_layer = sort_layer(layer=layer)
            # ``setattr(model, "a.b", m)`` would register a brand-new child
            # literally named "a.b" and leave ``model.a.b`` untouched. Split off
            # the last path component and assign it on the owning parent.
            # A name without a dot yields parent_name == "", i.e. the model itself.
            parent_name, _, child_name = name.rpartition(".")
            parent = (
                get_module_by_name(module=model, name=parent_name)
                if parent_name
                else model
            )
            setattr(parent, child_name, sorted_layer)
        return model
