import torch
import torch.nn as nn

def flatten_nn(model):
    """Concatenate every parameter of ``model`` into a single 1-D tensor.

    Parameters are visited in ``model.named_parameters()`` order and each is
    flattened before concatenation, so the result has ``get_n_params(model)``
    elements. Raises whatever ``torch.cat`` raises for a model with no
    parameters.
    """
    return torch.cat([p.flatten() for _, p in model.named_parameters()], dim=0)


def unflatten_nn(model, weights):
    """Map a flat parameter vector back onto ``model``'s state_dict layout.

    Consecutive slices of ``weights`` are reshaped to each parameter's shape,
    in ``model.named_parameters()`` order. Each slice is stored under that
    parameter's name in the dict returned by ``model.state_dict()``. Entries
    not yielded by ``named_parameters`` (e.g. buffers) keep their current
    values. Elements of ``weights`` beyond the total parameter count are
    ignored; too few elements make ``view`` raise.

    Returns:
        The updated state_dict. Each replaced entry is cast to float32; it is
        a view of ``weights`` when ``weights`` is already float32.
    """
    state = model.state_dict()
    start = 0
    for pname, p in model.named_parameters():
        stop = start + p.data.numel()
        piece = weights[start:stop].view(p.shape)
        state[pname] = piece.float().to(weights.device, non_blocking=True)
        start = stop
    return state


def get_n_params(model):
    """Return the total number of scalar parameters in ``model``.

    Shared parameters are counted once, because ``named_parameters`` yields
    each parameter object only once.
    """
    return sum(p.numel() for _, p in model.named_parameters())


def get_avg_and_std(param_list):
    """Return the per-column mean and standard deviation of ``param_list``.

    Both statistics are reduced over dim 0 with the reduced dimension
    dropped. The std uses torch's default (unbiased) estimator, so a single
    row yields NaN.
    """
    print('==> Calculating avg and std..')
    avg = param_list.mean(dim=0)
    spread = param_list.std(dim=0)
    return avg, spread

def mask_by_proportion(model, variation_vector, proportion):
    """Keep the largest ``proportion`` fraction of entries of ``variation_vector``.

    Args:
        model: module whose parameter layout is used by ``unflatten_nn`` to
            reshape the flat mask.
        variation_vector: 1-D tensor with one score per flattened parameter
            (presumably in ``flatten_nn(model)`` order — confirm with caller).
        proportion: fraction of entries to keep. The count is
            ``floor(proportion * len(variation_vector))``, computed after
            rounding away floating-point noise. It is clamped to
            ``[0, len(variation_vector)]``.

    Returns:
        ``(mask, idx)``. ``mask`` is the dict produced by ``unflatten_nn``
        from a 0/1 vector marking the kept entries. ``idx`` holds the kept
        positions in ascending order.
    """
    print('==> masking..')
    n = len(variation_vector)
    # Round before truncating: e.g. 0.29 * 100 == 28.999999999999996 would
    # otherwise keep 28 entries instead of 29.
    k = int(round(float(proportion) * n, 6))
    # A negative k would turn the slice below into "all but the last |k|".
    k = max(0, min(k, n))

    order = torch.argsort(variation_vector, descending=True)
    idx = order[:k].sort()[0]
    mask_vec = torch.zeros_like(variation_vector)
    mask_vec[idx] = 1.

    # define layer-wise masks based on this threshold, to then prune weights with it
    mask = unflatten_nn(model, mask_vec)
    return mask, idx


def mask_by_value(model, variation_vector, value):
    """Keep the entries of ``variation_vector`` strictly greater than ``value``.

    Args:
        model: module whose parameter layout is used by ``unflatten_nn`` to
            reshape the flat mask.
        variation_vector: 1-D tensor with one score per flattened parameter
            (presumably in ``flatten_nn(model)`` order — confirm with caller).
        value: threshold; entries equal to it are dropped.

    Returns:
        ``(mask, idx)``. ``mask`` is the dict produced by ``unflatten_nn``
        from the float 0/1 keep-vector. ``idx`` holds the kept positions in
        ascending order.
    """
    print('==> masking..')
    keep = variation_vector > value
    mask_vec = keep.float()
    # torch.nonzero gives the kept positions directly and in increasing
    # order. The previous builtin-sum count looped in Python and accumulated
    # in float32, which stalls past 2**24 kept entries.
    idx = torch.nonzero(keep, as_tuple=True)[0]

    mask = unflatten_nn(model, mask_vec)
    return mask, idx


def get_masked_value(param_list, idx):
    """Return the columns of ``param_list`` selected by ``idx`` (indexing dim 1)."""
    selected = param_list[:, idx]
    return selected

def print_nonzeros(mask):
    """ Print table of zeros and non-zeros count

    Args:
        mask: mapping of name -> tensor (e.g. the dict from ``unflatten_nn``).
            Every entry is counted, including any non-parameter entries.

    Prints one row per entry, then a summary with the overall compression
    rate (total / remaining). A mask with no remaining non-zeros (including
    an empty mapping) reports an infinite compression rate instead of
    raising ZeroDivisionError. Zero-element tensors report 0% remaining.
    """

    remain = total = 0
    for name, tensor in mask.items():
        nz_params = int(torch.count_nonzero(tensor))
        total_params = tensor.numel()
        remain += nz_params
        total += total_params
        pct_remaining = 100 * nz_params / total_params if total_params else 0.0
        remaining = f"{nz_params:7} / {total_params:7} ({pct_remaining:6.2f}%)"
        print(
            f"{name:35s} | remaining = {remaining} | pruned = {total_params - nz_params:7d} | shape = {tensor.size()}")
    ratio = total / remain if remain else float("inf")
    pct_pruned = 100 * (total - remain) / total if total else 0.0
    compr_rate = f"{ratio:10.2f}x  ({pct_pruned:6.2f}% pruned)"
    print("====================================================================================================")
    print(
        f"remaining: {remain}, pruned: {total - remain}, total: {total}, compression rate: {compr_rate}")