from bypass.train import BypassTrainerBase
from bypass.core.activation import ActivationForBypass
import torch_pruning as tp
from torch_pruning.pruner import BasePruningFunc
import torch

class BypassActivationPruner(BasePruningFunc):
    """torch_pruning handler for ``ActivationForBypass`` layers.

    The layer's per-channel parameter ``delta`` is sliced when channels are
    removed. NOTE(review): this assumes ``delta`` has ``num_parameters``
    entries along dim 0 — confirm against ``ActivationForBypass``.

    A layer with ``num_parameters == 1`` is left untouched. Input and output
    channels are treated as the same set, so both directions share one
    implementation.
    """

    def prune_out_channels(self, layer: ActivationForBypass, idxs: list) -> ActivationForBypass:
        """Remove the channels listed in *idxs* from *layer*, in place.

        Args:
            layer: Activation whose ``num_parameters`` and ``delta`` are updated.
            idxs: Channel indices to drop. Duplicates and out-of-range
                values are ignored rather than corrupting the channel count.

        Returns:
            The same ``layer`` object after modification.
        """
        if layer.num_parameters == 1:  # single shared parameter: prune nothing
            return layer
        keep_idxs = sorted(set(range(layer.num_parameters)) - set(idxs))
        # Derive the new size from what is actually kept. Subtracting len(idxs)
        # over-counts when idxs has duplicates/out-of-range entries and would
        # leave num_parameters out of sync with the pruned delta.
        layer.num_parameters = len(keep_idxs)
        # Replace delta with its kept entries along dim 0 (via the base-class helper).
        layer.delta = self._prune_parameter_and_grad(layer.delta, keep_idxs, 0)
        return layer

    # Input and output channels coincide for this layer, so pruning either
    # side is the same operation.
    prune_in_channels = prune_out_channels

    def get_out_channels(self, layer):
        """Return ``layer.num_parameters``, or ``None`` when it equals 1.

        NOTE(review): returning ``None`` presumably tells torch_pruning the layer
        has no prunable channel dimension — confirm against the library contract.
        """
        if layer.num_parameters == 1:
            return None
        return layer.num_parameters

    def get_in_channels(self, layer):
        """Return the same value as :meth:`get_out_channels` (channels in == out)."""
        return self.get_out_channels(layer=layer)

if __name__ == '__main__':
    import argparse

    # Default preserves the originally hard-coded checkpoint so running the
    # script with no arguments behaves as before.
    DEFAULT_CKPT = ("/workspace/jaeheun_MildPruning/save/mild_pruning_W/"
                    "cifar10_2c2d/20230913-081426/save/model_opt2_end.pt")
    parser = argparse.ArgumentParser(
        description="Build a torch_pruning dependency graph for a saved bypass model.")
    parser.add_argument("checkpoint", nargs="?", default=DEFAULT_CKPT,
                        help="checkpoint path passed to BypassTrainerBase.load_from")
    args = parser.parse_args()

    trainer = BypassTrainerBase.load_from(args.checkpoint)

    # Create the example input on the model's own device instead of forcing
    # .cuda(), which fails on CPU-only machines or for CPU-loaded models.
    first_param = next(trainer.model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    example_inputs = torch.randn(1, *trainer.configs.input_shape, device=device)

    DG = tp.DependencyGraph()
    DG.register_customized_layer(ActivationForBypass, BypassActivationPruner())
    DG.build_dependency(trainer.model, example_inputs=example_inputs)
    print("Dependency graph built.")
