import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def pruning_model(model, amount, state_only=False):
    """Apply global L1 unstructured pruning to the prunable layers of ``model``.

    The following parameters are gathered from ``model.modules()`` and handed
    to ``prune.global_unstructured`` with ``prune.L1Unstructured``:

    * ``nn.LSTMCell``: ``weight_ih``, ``weight_hh``, ``bias_ih``, ``bias_hh``
    * ``nn.Conv2d``:   ``weight``
    * ``nn.Linear``:   ``weight``, ``bias``

    Parameters that are ``None`` (e.g. a layer built with ``bias=False``) are
    skipped, because the pruning utilities cannot operate on them.

    Args:
        model: Module whose sub-modules are scanned.
        amount: Forwarded unchanged as the ``amount`` argument of
            ``prune.global_unstructured``.
        state_only: When true, no parameter is selected and therefore nothing
            is pruned.

    Returns:
        None. The model is modified in place.
    """
    parameters_to_prune = []
    if not state_only:
        for m in model.modules():
            if isinstance(m, nn.LSTMCell):
                param_names = ('weight_ih', 'weight_hh', 'bias_ih', 'bias_hh')
            elif isinstance(m, nn.Conv2d):
                param_names = ('weight',)
            elif isinstance(m, nn.Linear):
                param_names = ('weight', 'bias')
            else:
                continue

            for param_name in param_names:
                # A missing bias is stored as None; passing it on would make
                # global_unstructured fail while flattening the parameters.
                if getattr(m, param_name, None) is not None:
                    parameters_to_prune.append((m, param_name))

    print('Global pruning module number = {}'.format(len(parameters_to_prune)))

    if not parameters_to_prune:
        # global_unstructured cannot concatenate an empty parameter list.
        return

    prune.global_unstructured(
        tuple(parameters_to_prune),
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )

def prune_model_custom(model, mask_dict):
    """Re-apply previously extracted pruning masks to ``model`` in place.

    For every ``nn.Conv2d`` (``weight``), ``nn.LSTMCell`` (``weight_ih``,
    ``weight_hh``, ``bias_ih``, ``bias_hh``) and ``nn.Linear`` (``weight``,
    ``bias``) found by ``model.named_modules()``, the mask stored under
    ``'<module name>.<parameter>_mask'`` is applied with
    ``prune.CustomFromMask``.

    Args:
        model: Module to prune.
        mask_dict: Mapping from mask key to mask tensor.

    Raises:
        KeyError: If the mask of an existing parameter is missing from
            ``mask_dict``.
    """
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            param_names = ('weight',)
        elif isinstance(m, nn.LSTMCell):
            param_names = ('weight_ih', 'weight_hh', 'bias_ih', 'bias_hh')
        elif isinstance(m, nn.Linear):
            param_names = ('weight', 'bias')
        else:
            continue

        # The root module is named '' by named_modules(); its state-dict keys
        # carry no leading dot, so only add the separator for sub-modules.
        prefix = name + '.' if name else ''

        for param_name in param_names:
            # Layers built without a bias have nothing to mask (and no mask
            # entry), so skip parameters that are None.
            if getattr(m, param_name, None) is None:
                continue
            prune.CustomFromMask.apply(
                m, param_name, mask=mask_dict[prefix + param_name + '_mask']
            )

def remove_prune(model):
    """Remove the pruning re-parametrization from the prunable layers.

    Calls ``prune.remove`` on ``weight`` of every ``nn.Conv2d``, on
    ``weight_ih``/``weight_hh``/``bias_ih``/``bias_hh`` of every
    ``nn.LSTMCell`` and on ``weight``/``bias`` of every ``nn.Linear`` found
    by ``model.modules()``. Parameters that are ``None`` (layer built without
    a bias) are skipped because they can never have been pruned.

    Args:
        model: Module to clean up in place.

    Raises:
        ValueError: Propagated from ``prune.remove`` when an existing
            parameter has no pruning attached.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            param_names = ('weight',)
        elif isinstance(m, nn.LSTMCell):
            param_names = ('weight_ih', 'weight_hh', 'bias_ih', 'bias_hh')
        elif isinstance(m, nn.Linear):
            param_names = ('weight', 'bias')
        else:
            continue

        for param_name in param_names:
            # Only skip absent (None) parameters; a real parameter that was
            # never pruned still surfaces the ValueError from prune.remove.
            if getattr(m, param_name, None) is None:
                continue
            prune.remove(m, param_name)

def extract_mask(model_dict):
    """Return a new dict holding only the entries whose key contains 'mask'."""
    return {
        entry_name: entry
        for entry_name, entry in model_dict.items()
        if 'mask' in entry_name
    }

def check_sparsity(model):
    """Print and return the percentage of non-zero prunable parameters.

    Counts the entries equal to zero in ``weight`` of every ``nn.Conv2d``, in
    ``weight_ih``/``weight_hh``/``bias_ih``/``bias_hh`` of every
    ``nn.LSTMCell`` and in ``weight``/``bias`` of every ``nn.Linear`` found
    by ``model.modules()``. Parameters that are ``None`` (layer built without
    a bias) are ignored.

    Args:
        model: Module to inspect.

    Returns:
        ``100 * (1 - zeros / total)`` over the inspected parameters. When no
        parameter is inspected at all, 100.0 is returned (nothing is zero).
    """
    zero_count = 0
    all_count = 0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            param_names = ('weight',)
        elif isinstance(m, nn.LSTMCell):
            param_names = ('weight_ih', 'weight_hh', 'bias_ih', 'bias_hh')
        elif isinstance(m, nn.Linear):
            param_names = ('weight', 'bias')
        else:
            continue

        for param_name in param_names:
            tensor = getattr(m, param_name, None)
            if tensor is None:
                # e.g. bias=False: there is nothing to count.
                continue
            zero_count += float(torch.sum(tensor == 0))
            all_count += float(tensor.nelement())

    if all_count:
        remain_weight = 100 * (1 - zero_count / all_count)
    else:
        # Avoid dividing by zero for models without any inspected layer.
        remain_weight = 100.0
    print('remain weight = {:.4f}%'.format(remain_weight))

    return remain_weight






import numpy as np 

def print_dict(model_dict):
    """Print every key of ``model_dict`` followed by the ``shape`` of its value."""
    for entry_name, entry in model_dict.items():
        print(entry_name, entry.shape)

def count_pruning_weight_rate(model, rnn_type=nn.LSTM, bidirectional=True, Linear=False):
    """Print how many parameters fall inside the pruning scope.

    The pruning scope consists of ``weight`` of every ``nn.Conv2d``, the
    layer-0 ``weight_ih_l0``/``weight_hh_l0`` tensors of every ``rnn_type``
    module (plus their ``_reverse`` counterparts when ``bidirectional``), and
    ``weight`` of every ``nn.Linear`` when ``Linear`` is true. The overall
    number is the element count of all entries in ``model.state_dict()``.

    Args:
        model: Module to inspect.
        rnn_type: Recurrent module class whose layer-0 weights are counted.
        bidirectional: Also count the ``*_l0_reverse`` weights.
        Linear: Also count ``nn.Linear`` weights.

    Returns:
        None. Three lines (scope size, overall size, ratio) are printed; the
        ratio is reported as 0 when the state dict holds no elements.
    """
    all_count = 0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            all_count += float(m.weight.nelement())
        elif isinstance(m, rnn_type):
            all_count += float(m.weight_ih_l0.nelement())
            all_count += float(m.weight_hh_l0.nelement())
            if bidirectional:
                all_count += float(m.weight_ih_l0_reverse.nelement())
                all_count += float(m.weight_hh_l0_reverse.nelement())
        elif isinstance(m, nn.Linear):
            if Linear:
                all_count += float(m.weight.nelement())

    # Build the state dict once instead of once per key.
    overall_number = sum(
        tensor.nelement() for tensor in model.state_dict().values()
    )
    # Guard against an empty state dict, which would divide by zero.
    rate = all_count / overall_number if overall_number else 0.0

    print('parameters of pruning scope = {}'.format(all_count))
    print('parameters overall = {}'.format(overall_number))
    print('rate = {:.6f}'.format(rate))

def check_mask(mask_dict):
    """Print the percentage of non-zero entries across all masks in ``mask_dict``."""
    masks = list(mask_dict.values())
    pruned = sum(float(torch.sum(mask == 0)) for mask in masks)
    total = sum(float(mask.nelement()) for mask in masks)
    remaining_fraction = 1 - pruned / total
    print(100 * remaining_fraction)

def zero_rate(tensor):
    """Return the percentage (0 to 100) of entries of ``tensor`` equal to zero."""
    num_zero = float((tensor == 0).sum())
    num_total = float(tensor.nelement())
    return 100 * num_zero / num_total

def check_different(model, mask_dict, rnn_type=nn.LSTM, bidirectional=True, Linear=False):
    """Print, per weight tensor, the zero rate of its mask and of the weight.

    Inspects ``weight`` of every ``nn.Conv2d``, the layer-0
    ``weight_ih_l0``/``weight_hh_l0`` tensors of every ``rnn_type`` module
    (plus their ``_reverse`` counterparts when ``bidirectional``) and
    ``weight`` of every ``nn.Linear`` when ``Linear`` is true. Each mask is
    looked up in ``mask_dict`` under ``'<module name>.<parameter>_mask'``.

    Args:
        model: Module to inspect.
        mask_dict: Mapping from mask key to mask tensor.
        rnn_type: Recurrent module class whose layer-0 weights are compared.
        bidirectional: Also compare the ``*_l0_reverse`` weights.
        Linear: Also compare ``nn.Linear`` weights.

    Raises:
        KeyError: If an expected mask key is missing from ``mask_dict``.
    """
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            param_names = ['weight']
        elif isinstance(m, rnn_type):
            param_names = ['weight_ih_l0', 'weight_hh_l0']
            if bidirectional:
                param_names += ['weight_ih_l0_reverse', 'weight_hh_l0_reverse']
        elif isinstance(m, nn.Linear) and Linear:
            param_names = ['weight']
        else:
            continue

        # The root module is named '' by named_modules(); its state-dict keys
        # carry no leading dot, so only add the separator for sub-modules.
        prefix = name + '.' if name else ''

        for param_name in param_names:
            full_name = prefix + param_name
            mask_rate = zero_rate(mask_dict[full_name + '_mask'])
            module_rate = zero_rate(getattr(m, param_name))
            print(full_name, 'mask rate = {:.2f}'.format(mask_rate), 'module rate = {:.2f}'.format(module_rate))


if __name__ == '__main__':
    # Self-contained walkthrough of one round of magnitude pruning with weight
    # rewinding, using a toy container that holds each supported layer type.
    # The container is never run forward, so the layer sizes are arbitrary.
    import copy

    ################### example ###########################
    model = nn.ModuleDict({
        'conv': nn.Conv2d(1, 4, kernel_size=3),
        'cell': nn.LSTMCell(8, 8),
        'fc': nn.Linear(8, 4),
    })
    print(model)

    # keep a snapshot of the initial weights
    # be careful: deep-copy so the snapshot does not share memory with the
    # tensors of model.state_dict()
    original_weight = copy.deepcopy(model.state_dict())

    # global magnitude pruning
    amount = 0.2
    pruning_model(model, amount)

    # extract the masks stored in model.state_dict()
    mask_dict = extract_mask(model.state_dict())

    # remove pruning for the purpose of loading the original weights
    remove_prune(model)

    # rewind to the original weights
    model.load_state_dict(original_weight)

    # pruning with the custom masks
    prune_model_custom(model, mask_dict)

    # check current sparsity
    check_sparsity(model)

    # re-train for some epochs, then prune again (repeat from the global
    # magnitude pruning step)
















