import torch
from typing import Dict, Union, List, Tuple
from metrics import flops
import math

# Default magnitude threshold for the "almost" metrics in this module: entries with
# |w| <= TOL are counted as zero by global_almost_sparsity / almost_compression_rate.
TOL = 1e-8

def global_sparsity(module: torch.nn.Module, param_type: Union[str, None] = None) -> float:
    """Return the fraction of exactly-zero entries in the parameters of ``module``.

    All submodules yielded by ``module.modules()`` (including ``module`` itself)
    are inspected. For each one, every attribute named in the parameter list that
    exists and is not ``None`` contributes its elements to the count.

    Args:
        module: Module (mostly the entire model) whose parameters are inspected.
        param_type: Attribute name to restrict the count to (e.g. ``'weight'``).
            If falsy, both ``'weight'`` and ``'bias'`` are counted.

    Returns:
        ``n_zero / n_total`` as a float in [0, 1]. If no matching parameter is
        found at all, 0.0 is returned instead of dividing by zero.
    """
    # Use separate local names so the arguments are not rebound inside the loops.
    param_names = [param_type] if param_type else ['weight', 'bias']
    n_total, n_zero = 0., 0.
    for submodule in module.modules():
        for param_name in param_names:
            param = getattr(submodule, param_name, None)
            if param is None:
                # Attribute missing or explicitly disabled (e.g. bias=False).
                continue
            n_total += float(torch.numel(param))
            n_zero += float(torch.sum(param == 0))
    if n_total == 0:
        # Nothing to measure: there are no zero entries, so report no sparsity.
        return 0.0
    return float(n_zero) / n_total

def per_layer_sparsity(model: torch.nn.Module) -> Dict[str, float]:
    """Map each weight-carrying submodule name to its weight sparsity.

    Every entry of ``model.named_modules()`` that has a non-``None`` ``weight``
    attribute gets one dictionary entry, computed by calling ``global_sparsity``
    on that submodule with ``param_type='weight'``.

    Args:
        model: Model whose submodules are inspected.

    Returns:
        Dictionary from submodule name (as given by ``named_modules``) to sparsity.
    """
    # Only compute for weights, since we do not sparsify biases
    param_type = 'weight'
    sparsity_by_layer = {}
    for layer_name, layer in model.named_modules():
        if getattr(layer, param_type, None) is None:
            # No usable weight on this submodule.
            continue
        if layer_name not in sparsity_by_layer:
            # First occurrence of this name wins.
            sparsity_by_layer[layer_name] = global_sparsity(layer, param_type=param_type)
    return sparsity_by_layer

def compression_rate(module: torch.nn.Module, param_type: Union[str, None] = None) -> float:
    """Return the global compression rate of module (mostly of entire model).

    The rate is defined as |W|/||W||_0, which equals 1/(1 - sparsity) whenever
    the sparsity is below 1.

    Args:
        module: Module whose parameters are inspected.
        param_type: Forwarded unchanged to ``global_sparsity``.

    Returns:
        ``1 / (1 - sparsity)``. When the sparsity is exactly 1 the rate would be
        infinite; in that case an empty dict ``{}`` is returned instead (note
        that this does not match the ``float`` annotation).
    """
    sparsity = global_sparsity(module, param_type=param_type)
    if sparsity != 1:
        return 1. / (1 - sparsity)
    # Fully sparse: the compression rate is infinitely large, signalled by {}.
    return {}

def almost_compression_rate(module: torch.nn.Module, param_type: Union[str, None] = None, tol: float = TOL) -> float:
    """Return the global "almost" compression rate of module (mostly of entire model).

    Same as the plain compression rate, except that every entry w with
    |w| <= tol is treated as zero.

    Args:
        module: Module whose parameters are inspected.
        param_type: Forwarded unchanged to ``global_almost_sparsity``.
        tol: Magnitude threshold, forwarded to ``global_almost_sparsity``.

    Returns:
        ``1 / (1 - almost_sparsity)``. When the almost-sparsity is exactly 1 the
        rate would be infinite; in that case an empty dict ``{}`` is returned
        instead (note that this does not match the ``float`` annotation).
    """
    almost_sparsity = global_almost_sparsity(module=module, param_type=param_type, tol=tol)
    if almost_sparsity != 1:
        return 1. / (1 - almost_sparsity)
    # Everything is (almost) zero: the rate is infinitely large, signalled by {}.
    return {}

def global_almost_sparsity(module: torch.nn.Module, param_type: Union[str, None] = None, tol: float = TOL) -> float:
    """Return the fraction of parameter entries of ``module`` with |w| <= tol.

    All submodules yielded by ``module.modules()`` (including ``module`` itself)
    are inspected. For each one, every attribute named in the parameter list that
    exists and is not ``None`` contributes its elements to the count.

    Args:
        module: Module (mostly the entire model) whose parameters are inspected.
        param_type: Attribute name to restrict the count to (e.g. ``'weight'``).
            If falsy, both ``'weight'`` and ``'bias'`` are counted.
        tol: Entries whose absolute value is at most ``tol`` count as zero.

    Returns:
        ``n_almost_zero / n_total`` as a float in [0, 1]. If no matching
        parameter is found at all, 0.0 is returned instead of dividing by zero.
    """
    # Use separate local names so the arguments are not rebound inside the loops.
    param_names = [param_type] if param_type else ['weight', 'bias']
    n_total, n_zero = 0., 0.
    for submodule in module.modules():
        for param_name in param_names:
            param = getattr(submodule, param_name, None)
            if param is None:
                # Attribute missing or explicitly disabled (e.g. bias=False).
                continue
            n_total += float(torch.numel(param))
            n_zero += float(torch.sum(torch.abs(param) <= tol))
    if n_total == 0:
        # Nothing to measure: there are no (almost) zero entries to report.
        return 0.0
    return float(n_zero) / n_total

@torch.no_grad()
def get_flops(model, x_input):
    """Run ``flops.flops(model, x_input)`` with gradient tracking disabled.

    The arguments are passed through unchanged and the result of ``flops.flops``
    is returned as-is.
    NOTE(review): presumably this yields FLOP counts for a forward pass on
    ``x_input`` -- confirm the exact return format in ``metrics.flops``.
    """
    flop_result = flops.flops(model, x_input)
    return flop_result

@torch.no_grad()
def get_theoretical_speedup(n_flops: int, n_nonzero_flops: int) -> float:
    """Return the ratio ``n_flops / n_nonzero_flops``.

    Args:
        n_flops: Total FLOP count.
        n_nonzero_flops: FLOP count restricted to nonzero entries.

    Returns:
        The quotient as a float. When ``n_nonzero_flops`` is 0 the speedup would
        be infinite; in that case an empty dict ``{}`` is returned instead (note
        that this does not match the ``float`` annotation).
    """
    if n_nonzero_flops != 0:
        return float(n_flops) / n_nonzero_flops
    # Would yield infinite speedup, signalled by {}.
    return {}

@torch.no_grad()
def get_parameter_count(model: torch.nn.Module) -> Tuple[int, int]:
    """Count all and all nonzero entries of the weights and biases of ``model``.

    Every submodule from ``model.named_modules()`` is inspected; its ``weight``
    and ``bias`` attributes are counted whenever they exist and are not ``None``.

    Args:
        model: Model whose parameters are counted.

    Returns:
        Tuple ``(n_total, n_nonzero)`` of integer element counts.
    """
    tensors = [
        getattr(submodule, attr)
        for _, submodule in model.named_modules()
        for attr in ('weight', 'bias')
        if getattr(submodule, attr, None) is not None
    ]
    n_total = sum(int(tensor.numel()) for tensor in tensors)
    n_nonzero = sum(int(torch.sum(tensor != 0)) for tensor in tensors)
    return n_total, n_nonzero

@torch.no_grad()
def get_active_inactive_ratio(model: torch.nn.Module, comparison_model: torch.nn.Module) -> Tuple[float, float]:
    """Compare parameter magnitudes of ``model`` against ``comparison_model``.

    An entry counts as "active" when its absolute value in ``model`` is at least
    the absolute value of the matching entry in ``comparison_model``. Weights and
    biases are matched by submodule name, not by position.

    Args:
        model: Model whose entries are classified.
        comparison_model: Reference model; it must provide every
            (parameter type, submodule name) pair found in ``model``.

    Returns:
        Tuple ``(active_ratio, 1 - active_ratio)`` over all counted entries.
    """
    kinds = ('weight', 'bias')

    def _collect(net):
        # Gather the non-None weight/bias tensors of `net`, keyed by type and then name.
        table = {kind: {} for kind in kinds}
        for mod_name, mod in net.named_modules():
            for kind in kinds:
                tensor = getattr(mod, kind, None)
                if tensor is not None:
                    table[kind][mod_name] = tensor
        return table

    # We have to map parameters by name as they can be in different ordering
    current = _collect(model)
    reference = _collect(comparison_model)

    n_total = 0
    n_active = 0
    for kind, tensors_by_name in current.items():
        for mod_name, tensor in tensors_by_name.items():
            ref_tensor = reference[kind][mod_name]
            n_total += int(tensor.numel())
            n_active += int(torch.sum(torch.abs(tensor) >= torch.abs(ref_tensor)))

    active_ratio = float(n_active) / n_total
    return active_ratio, 1 - active_ratio

@torch.no_grad()
def get_distance_to_pruned(model: torch.nn.Module, sparsity: float) -> Tuple[float, float]:
    """Measure how much L2 mass lies outside the largest-magnitude weights.

    All non-``None`` ``weight`` tensors of submodules that are not
    ``torch.nn.BatchNorm2d`` are flattened into one vector. With
    ``k = int((1 - sparsity) * n)`` for a vector of ``n`` entries, the distance
    is ``sqrt(|total_norm**2 - topk_norm**2|)``, where ``topk_norm`` is the L2
    norm of the ``k`` largest absolute values.

    Args:
        model: Model whose weights are inspected.
        sparsity: Target sparsity used to derive ``k``.

    Returns:
        Tuple ``(distance, relative_distance)``; the relative distance is the
        distance divided by the total L2 norm, or 0 if that norm is not positive.
    """
    flat_weights = []
    for _, submodule in model.named_modules():
        if isinstance(submodule, torch.nn.BatchNorm2d):
            continue
        weight = getattr(submodule, 'weight', None)
        if weight is None:
            continue
        flat_weights.append(weight.view(-1))
    prune_vector = torch.cat(flat_weights)

    n_params = float(prune_vector.numel())
    k = int((1 - sparsity) * n_params)

    magnitudes = torch.abs(prune_vector)
    kept_values = torch.topk(magnitudes, k=k).values
    total_norm = float(torch.norm(prune_vector, p=2))
    pruned_norm = float(torch.norm(kept_values, p=2))

    # abs() guards against a tiny negative difference caused by rounding.
    distance_to_pruned = math.sqrt(abs(total_norm**2 - pruned_norm**2))
    if total_norm > 0:
        rel_distance_to_pruned = distance_to_pruned / total_norm
    else:
        rel_distance_to_pruned = 0
    return distance_to_pruned, rel_distance_to_pruned

@torch.no_grad()
def get_distance_to_origin(model: torch.nn.Module) -> float:
    """Return the L2 norm of all weights of ``model``, flattened into one vector.

    Only non-``None`` ``weight`` tensors are used, and submodules that are
    instances of ``torch.nn.BatchNorm2d`` are left out.

    Args:
        model: Model whose weights are inspected.

    Returns:
        The L2 norm of the concatenated weight vector as a Python float.
    """
    flat_weights = []
    for _, submodule in model.named_modules():
        if isinstance(submodule, torch.nn.BatchNorm2d):
            continue
        weight = getattr(submodule, 'weight', None)
        if weight is None:
            continue
        flat_weights.append(weight.view(-1))
    prune_vector = torch.cat(flat_weights)
    l2_norm = torch.norm(prune_vector, p=2)
    return float(l2_norm)
