import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import copy


# Maps the config's ``prune_criterion`` name to the torch unstructured
# pruning class that ``Augmenter`` hands to ``prune.global_unstructured``.
prune_crit_dict = dict(
    magnitude=prune.L1Unstructured,
    random=prune.RandomUnstructured,
)

class Augmenter:
    """Build a perturbed variant of an encoder by pruning and/or re-initialising its conv layers.

    Only ``torch.nn.Conv2d`` modules are touched; every other module, and the
    conv biases, are left as they are.

    Args:
        config: Object exposing ``prune_perc``, ``prune_method``,
            ``prune_criterion``, ``prune_layer_th``, ``prune_inc_perc``,
            ``reinit_method`` and ``reinit_layer_th``.
        copy: If True, ``augment`` works on a deep copy of the encoder and the
            caller's model is left untouched; if False it is modified in place.
            (The parameter name shadows the ``copy`` module only inside
            ``__init__``.)
        aux_model: Stored as ``self.aux_model``; not used by this class.
    """

    def __init__(self, config, copy=True, aux_model=None):
        self.aux_model = aux_model

        # Passed unchanged as ``amount`` to ``prune.global_unstructured``.
        self.prune_perc = config.prune_perc
        # '-' disables pruning; 'global' or 'threshold' choose which layers to prune.
        self.prune_method = config.prune_method
        # Name -> pruning class; an unknown name raises KeyError here.
        self.prune_criterion = prune_crit_dict[config.prune_criterion]

        self.prune_layer_th = config.prune_layer_th
        # NOTE(review): stored but never read inside this class.
        self.prune_inc_perc = config.prune_inc_perc

        # '-' disables re-initialisation; 'threshold' re-initialises conv layers
        # from index ``reinit_layer_th`` onwards; any other value is a no-op.
        self.reinit_method = config.reinit_method
        self.reinit_layer_th = config.reinit_layer_th

        self.copy = copy

    def augment(self, encoder):
        """Return a pruned and/or re-initialised version of ``encoder``.

        Pruning runs first (unless ``prune_method == '-'``), then
        re-initialisation (unless ``reinit_method == '-'``). When both are
        disabled the (possibly copied) encoder is returned unchanged instead of
        raising ``UnboundLocalError``.
        """
        model = copy.deepcopy(encoder) if self.copy else encoder

        if self.prune_method != '-':
            model = self.prune(model)
        if self.reinit_method != '-':
            model = self.reinit(model)
        return model

    @staticmethod
    def _conv_layers(model):
        """Return the model's Conv2d modules in ``model.modules()`` order."""
        return [m for m in model.modules() if isinstance(m, nn.Conv2d)]

    def prune(self, model):
        """Prune conv weights in place and make the pruning permanent.

        'global' prunes across every conv layer; 'threshold' prunes across the
        subset of conv layers selected by ``prune_layer_th``. If no layer is
        selected, the model is returned unchanged.

        Raises:
            ValueError: if ``prune_method`` is neither 'global' nor 'threshold'.
        """
        if self.prune_method not in ('global', 'threshold'):
            raise ValueError(f"unknown prune_method: {self.prune_method!r}")

        convs = self._conv_layers(model)

        if self.prune_method == 'threshold':
            # Skip the first ``prune_layer_th`` conv layers.
            convs = convs[self.prune_layer_th:]
            if self.prune_layer_th < 10:
                print(f'pruning the last {self.prune_layer_th} conv layers')
                # Small thresholds mean "only the last N layers". This slice is
                # applied to the already-trimmed list, so with fewer than 2*N
                # conv layers fewer than N are pruned, and N == 0 keeps every
                # layer ([-0:] == [0:]).
                # NOTE(review): confirm this double slicing is intended.
                convs = convs[-self.prune_layer_th:]

        # global_unstructured fails on an empty parameter list (torch.cat of
        # nothing), so leave the model untouched when nothing was selected.
        if not convs:
            return model

        prune.global_unstructured(
            [(conv, 'weight') for conv in convs],
            pruning_method=self.prune_criterion,
            amount=self.prune_perc,
        )
        # Fold the masks into the weights and drop the pruning reparametrisation.
        for conv in convs:
            prune.remove(conv, 'weight')
        return model

    def reinit(self, model):
        """Re-initialise conv weights in place and return ``model``.

        With ``reinit_method == 'threshold'``, every conv layer from index
        ``reinit_layer_th`` onwards gets ``kaiming_normal_(mode='fan_out',
        nonlinearity='relu')``; biases are not touched. Any other method is a
        no-op.
        """
        if self.reinit_method == 'threshold':
            for conv in self._conv_layers(model)[self.reinit_layer_th:]:
                nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
        return model
