import torch
import torch.nn as nn
import numpy as np
from models import Swish, MemoryEfficientSwish


def same_mask(m1, m2):
    """Return True if two pruning masks are identical.

    A mask is a dict mapping parameter names to arrays. Two masks are the
    same only when both are present, they have exactly the same key set, and
    every pair of arrays has the same shape and elements (``np.array_equal``).
    A missing mask (``None``) on either side never compares equal, not even
    to another ``None``.
    """
    # Treat a missing mask symmetrically; previously a None m2 crashed on .keys().
    if m1 is None or m2 is None:
        return False
    if m1.keys() != m2.keys():
        return False
    return all(np.array_equal(m1[key], m2[key]) for key in m1)


def remove_bias(model, ignore=None):
    """Zero, in place, every parameter of ``model`` whose name contains 'bias'.

    Args:
        model: module whose ``named_parameters()`` are scanned.
        ignore: optional substring; parameters whose name contains it are
            left untouched.
    """
    for name, p in model.named_parameters():
        if ignore is not None and ignore in name:
            continue
        if 'bias' in name:
            # Zero on the parameter's own device and dtype. Rebuilding the
            # tensor from a numpy array forced float32 and silently changed the
            # dtype of half/double-precision parameters.
            p.data.zero_()


def get_alive_idx(args, model):
    """Collect per-parameter bookkeeping arrays used for pruning.

    Returns a tuple ``(num_alive, alive_idx, alive_mask)``:
      - ``num_alive``: total number of elements across all parameters.
      - ``alive_idx``: parameter name -> all-True boolean array of that
        parameter's shape.
      - ``alive_mask``: parameter name -> 0/1 array of that parameter's shape.
        When ``args.mode == 'prune'`` and ``args.prune_method == 'mp'`` it is 0
        where the parameter is currently exactly zero and 1 elsewhere (integer
        array); otherwise it is an all-ones float array.
    """
    total_elems = 0
    idx_by_name = {}
    mask_by_name = {}

    for pname, param in model.named_parameters():
        values = param.data.cpu().numpy()
        total_elems += values.size
        idx_by_name[pname] = np.ones(values.shape, dtype=bool)

        # Entries that are already zero start out masked only in this mode.
        mask_from_zeros = args.mode == 'prune' and args.prune_method == 'mp'
        if mask_from_zeros:
            mask_by_name[pname] = np.where(values == 0, 0, 1)
        else:
            mask_by_name[pname] = np.ones(values.shape)

    return total_elems, idx_by_name, mask_by_name


def prune_by_mask(model, mask, ignore=None):
    '''
    Pruning using mask: multiply parameters of ``model`` in place by ``mask``.

    Args:
        model: module whose ``named_parameters()`` are pruned.
        mask: dict mapping parameter names to arrays (e.g. numpy 0/1 arrays)
            that broadcast against the parameter. Parameters without a
            matching key are left untouched.
        ignore: optional substring; parameters whose name contains it are
            skipped.
    '''
    if not mask:
        return

    prefix = 'module.'
    # A wrapped model (e.g. nn.DataParallel) reports names with a leading
    # 'module.' prefix; strip it when the mask keys were saved without one.
    # Decided once from the first key instead of rebuilding the key list per
    # parameter.
    strip_prefix = not next(iter(mask)).startswith(prefix)

    for name, p in model.named_parameters():
        if ignore is not None and ignore in name:
            continue

        # Only strip a genuine leading prefix; a 'module.' in the middle of a
        # name (e.g. 'enc.module.weight') must not be truncated.
        if strip_prefix and name.startswith(prefix):
            name = name[len(prefix):]

        if name not in mask:
            continue

        # Multiply in place on the parameter's own device and dtype instead of
        # round-tripping through numpy, which forced every parameter to float32.
        m = torch.as_tensor(mask[name], dtype=p.dtype, device=p.device)
        p.data.mul_(m)


def activation_to_relu(model):
    """Recursively replace ReLU6 / Swish / MemoryEfficientSwish with ReLU.

    Each direct child of ``model`` that is one of those activation modules is
    swapped (via ``setattr``) for ``nn.ReLU(inplace=True)``. Every other child
    is searched recursively. ``model`` is modified in place and nothing is
    returned.
    """
    for attr, sub in model.named_children():
        is_target = isinstance(sub, nn.ReLU6) or isinstance(sub, (Swish, MemoryEfficientSwish))
        if not is_target:
            activation_to_relu(sub)
            continue
        setattr(model, attr, nn.ReLU(inplace=True))