import torch_pruning as tp
from torch_pruning import ops
from torch_pruning.pruner import BasePruningFunc
import torch

def prune_tp(module:torch.nn.Module,prune_indices,target='out'):
    """Prune channels of a single module in place using torch_pruning.

    Args:
        module: The layer to prune. It must be a module type for which
            ``tp.function.PrunerBox`` has a registered pruner.
        prune_indices: Tensor of channel indices to remove. It is converted with
            ``.cpu().numpy()`` before being handed to torch_pruning.
        target: ``'out'`` to prune output channels, ``'in'`` to prune input
            channels.

    Returns:
        The same ``module`` object. After a successful prune it carries a
        (persistent) buffer named ``f'{target}_mask'`` holding
        ``prune_indices``. If torch_pruning has no pruner for the module's
        type, a message is printed and the module is returned unchanged.

    Raises:
        NotImplementedError: if ``target`` is neither ``'out'`` nor ``'in'``.
    """
    # Validate the argument before anything else, so a bad target is reported
    # even for module types torch_pruning does not support.
    if target not in ('out', 'in'):
        raise NotImplementedError()
    try:
        pruning_func = tp.function.PrunerBox[ops.module2type(module)]
    except KeyError:
        # Only a missing registry entry means "unsupported"; any other failure
        # (e.g. an API change in torch_pruning) must surface instead of being
        # silently swallowed.
        print(f'{type(module)} is not supported by torch_pruner. Possibly it is not a leaf node')
        return module

    idxs = prune_indices.cpu().numpy()
    if target == 'out':
        pruning_func.prune_out_channels(module, idxs)
    else:
        pruning_func.prune_in_channels(module, idxs)

    # Persist the pruned indices so they travel with the module's state_dict.
    module.register_buffer(f'{target}_mask', prune_indices)
    return module

def _param_from_checkpoint(weight, owner):
    """Build a trainable Parameter holding a private copy of checkpoint tensor ``weight``.

    The copy does not share storage with ``weight``. If ``owner`` has a direct
    parameter or buffer, the copy is moved to that tensor's device so the new
    parameter lives alongside the rest of the layer.
    """
    data = weight.detach().clone()
    reference = next(owner.parameters(recurse=False), None)
    if reference is None:
        reference = next(owner.buffers(recurse=False), None)
    if reference is not None:
        data = data.to(reference.device)
    return torch.nn.Parameter(data=data, requires_grad=True)


def _prune_to_match(parent_module, name, model_size, ckpt_size, target):
    """Prune ``model_size - ckpt_size`` ``target`` channels from ``parent_module``.

    The leading indices ``0..k-1`` are the ones pruned. Actual weight values
    (and the saved masks) are copied from the checkpoint afterwards by
    ``load_state_dict``, so this step only has to produce the right size.

    Raises:
        RuntimeError: if the checkpoint is larger than the model along this
            dimension (a layer cannot be grown by pruning).
    """
    prune_nodes = model_size - ckpt_size
    if prune_nodes == 0:
        return
    if prune_nodes < 0:
        raise RuntimeError(
            f"checkpoint entry '{name}' has {ckpt_size} {target} channels but the "
            f"model has only {model_size}; pruning cannot enlarge a layer"
        )
    prune_tp(parent_module, torch.arange(prune_nodes), target)


def prune_loader_tp(module:torch.nn.Module,state_dict:dict):
    """Shrink ``module`` in place to match a pruned checkpoint, then load it.

    For each checkpoint entry whose shape differs from the matching tensor in
    ``module``, the owning submodule is pruned with ``prune_tp``. Dimension 0
    is treated as output channels (``'out'``). For tensors with more than one
    dimension, dimension 1 is treated as input channels (``'in'``). Afterwards
    the checkpoint is loaded with ``strict=False``.

    Args:
        module: Unpruned model to resize and load into (modified in place).
        state_dict: Checkpoint produced from a pruned model.

    Returns:
        The same ``module`` object.

    Raises:
        RuntimeError: if a checkpoint tensor is larger than the model's along a
            compared dimension, or if ``load_state_dict`` still finds a size
            mismatch (e.g. for a layer type ``prune_tp`` could not prune).
        KeyError: if the checkpoint has unexpected keys other than ones ending
            in ``'out_mask'`` or ``'delta'``.
    """
    for name, weight in state_dict.items():
        # NOTE(review): these keys are skipped purely by name; presumably they
        # have no prunable counterpart in the model -- confirm with the saver.
        if 'delta' in name or 'relu1.out_mask' in name or 'relu2.out_mask' in name:
            continue
        parent_module_name, _, attr_name = name.rpartition('.')
        try:
            parent_module = module.get_submodule(parent_module_name)
        except AttributeError:
            # No such submodule in this model: leave it to load_state_dict,
            # which reports it as an unexpected key below.
            continue
        if not hasattr(parent_module, attr_name):
            continue
        model_weight = getattr(parent_module, attr_name)
        if model_weight is None:
            # Slot declared but unset (e.g. a layer built with bias=False while
            # the checkpoint has a bias): create it from the checkpoint tensor.
            setattr(parent_module, attr_name, _param_from_checkpoint(weight, parent_module))
            model_weight = getattr(parent_module, attr_name)
        if model_weight.shape == weight.shape:
            continue
        # model_weight still references the pre-pruning tensor, so its dim-1
        # size is the original one even after the 'out' prune below.
        _prune_to_match(parent_module, name, model_weight.shape[0], weight.shape[0], 'out')
        if model_weight.dim() > 1:
            _prune_to_match(parent_module, name, model_weight.shape[1], weight.shape[1], 'in')

    missing_keys, unexpected_keys = module.load_state_dict(state_dict, strict=False)
    # Leftover pruning masks / 'delta' entries are tolerated; anything else is an error.
    if any(not key.endswith(('out_mask', 'delta')) for key in unexpected_keys):
        raise KeyError(f'unexpected keys found: {unexpected_keys}')
    # Masks that are absent from the checkpoint are expected; only warn about the rest.
    reported_missing = [key for key in missing_keys if not key.endswith('mask')]
    if reported_missing:
        print(f'[WARN] missing keys found while loading model: {reported_missing}')
    return module