import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
from tqdm import tqdm

from pruning import prune_by_mask, remove_bias, activation_to_relu, same_mask, get_mask_by_saliency


def _aas_prepare_model(origin_model):
    '''
    Return a transformed deep copy of `origin_model` used to probe for dead units.

    On the copy: activations are converted by `activation_to_relu`, every
    parameter is replaced by its absolute value, modules whose name contains
    'bn' get running mean reset to 0 and running variance to 1, and
    `remove_bias` is applied. The copy is put in eval mode. `origin_model`
    itself is not modified.
    '''
    processed_model = copy.deepcopy(origin_model)

    # Convert activations to ReLU
    activation_to_relu(processed_model)

    # Weights to absolute value (in place, on the copy only)
    for p in processed_model.parameters():
        p.data.abs_()

    processed_model.eval()
    for name, m in processed_model.named_modules():
        if 'bn' not in name:
            continue
        # Only reset modules that actually hold running statistics: a container
        # whose name merely contains 'bn', or a norm layer built with
        # track_running_stats=False, has nothing to reset. In-place reset keeps
        # each buffer's dtype/device and also clears NaN/inf values.
        if getattr(m, 'running_mean', None) is not None:
            m.running_mean.zero_()
        if getattr(m, 'running_var', None) is not None:
            m.running_var.fill_(1)

    # Remove biases
    remove_bias(processed_model)
    return processed_model


def _aas_structural_grad_mask(model):
    '''
    Build a 0/1 integer numpy mask per parameter from the gradients stored on `model`.

    Parameters whose name contains 'bn' or 'bias', and any parameter with
    fewer than 2 dims, get an elementwise mask (1 where the gradient is
    nonzero). Other parameters are viewed as (dim0, dim1, rest) and reduced
    over the trailing dims. An entry is kept only if its dim0-row AND its
    dim1-column each carry some nonzero gradient; otherwise it is 0. The
    result is reshaped with trailing singleton dims so it broadcasts against
    the full parameter.

    Raises AttributeError if a parameter has no gradient (p.grad is None).
    '''
    grad_mask = {}
    for name, p in model.named_parameters():
        # Compare on-device and transfer a bool array: cheaper than copying the
        # float gradient and works for dtypes numpy lacks (e.g. bfloat16).
        nonzero = (p.grad != 0).cpu().numpy()

        # bn/bias (and any other vector-shaped) params keep an elementwise mask
        if 'bn' in name or 'bias' in name or nonzero.ndim < 2:
            grad_mask[name] = nonzero.astype(int)
            continue

        # Make 2-dimension: collapse trailing (e.g. kernel) dims
        nonzero = nonzero.reshape(nonzero.shape[0], nonzero.shape[1], -1).any(-1)

        # Only zero out all-zero rows/cols: an entry survives iff its row and
        # its column both contain some nonzero gradient.
        row_alive = nonzero.any(axis=1, keepdims=True)
        col_alive = nonzero.any(axis=0, keepdims=True)
        mask_2d = (row_alive & col_alive).astype(int)

        # Append singleton dims so the mask broadcasts against the full weight
        grad_mask[name] = mask_2d.reshape(mask_2d.shape + (1,) * (p.dim() - 2))
    return grad_mask


def get_mask_saliency_AAS(args, saliency, origin_model, batch, device, data_shape, num_classes,
                                                num_alive, alive_idx, alive_mask, masked=None):
    '''
    Get mask for All-alive subnetwork with given saliency scores.

    Algorithm:
    1. Compute an initial mask with `get_mask_by_saliency(..., masked=masked)`.
    2. Prune a preprocessed copy of the model with the current mask.
    3. Back-propagate from an all-ones input to find parameters that receive
       no gradient (dead rows/cols).
    4. Accumulate those into `dead_mask`, restricted by `alive_mask`.
    5. Recompute the mask with `masked=dead_mask`.
    Steps 2-5 repeat until `same_mask` reports the mask unchanged.

    Args:
        args, saliency: forwarded unchanged to `get_mask_by_saliency`.
        origin_model: torch module; only a deep copy of it is modified.
        batch, device, num_alive, alive_idx, alive_mask: forwarded to
            `get_mask_by_saliency`. `device` is also where the probe runs.
        data_shape: per-sample input shape (tuple, list or torch.Size) with
            the batch dimension excluded.
        num_classes: size of the model output's last dimension.
        alive_mask: additionally multiplied into the dead mask for every
            parameter name it contains. Presumably 0/1 arrays keyed by
            parameter name; TODO confirm against callers.
        masked: optional mask passed to the first `get_mask_by_saliency` call.

    Returns:
        The final mask returned by `get_mask_by_saliency`.
    '''

    mask_kwargs = {
        'batch': batch,
        'device': device,
        'data_shape': data_shape,
        'num_classes': num_classes,
        'num_alive': num_alive,
        'alive_idx': alive_idx,
        'alive_mask': alive_mask
    }

    mask = get_mask_by_saliency(args, saliency, masked=masked, **mask_kwargs)
    processed_model = _aas_prepare_model(origin_model)

    # Normalize so list-valued shapes work as well as tuples / torch.Size
    input_shape = (1,) + tuple(data_shape)

    tmp_mask = {}
    dead_mask = None
    while not same_mask(tmp_mask, mask):
        tmp_mask = copy.deepcopy(mask)

        # Fresh copy of the preprocessed model, pruned with the current mask
        model = copy.deepcopy(processed_model)
        prune_by_mask(model, mask)

        # Take arbitrary gradients. The -1 target presumably keeps
        # d(loss)/d(output) nonzero, since the output should be >= 0 after
        # |w|, ReLU and bias removal (TODO confirm for all architectures).
        model.zero_grad()
        one_data = torch.ones(input_shape).to(device)
        one_target = (torch.ones((1, num_classes)) * -1).to(device)
        output = model(one_data)
        loss = F.mse_loss(output, one_target, reduction='none')
        loss.sum().backward()

        grad_mask = _aas_structural_grad_mask(model)

        # Free this iteration's model and graph before the next deepcopy
        del model, output, loss

        # Make new mask: dead entries accumulate across iterations
        if dead_mask is None:
            dead_mask = dict(grad_mask)
        else:
            for name in grad_mask:
                dead_mask[name] = dead_mask[name] * grad_mask[name]

        # Remove original models' pruned weights
        for name in dead_mask:
            if name in alive_mask:
                dead_mask[name] = alive_mask[name] * dead_mask[name]

        mask = get_mask_by_saliency(args, saliency, masked=dead_mask, **mask_kwargs)

    return mask