from typing import Iterable, List, Tuple, Union
import math
import logging
import torch

# Public API: the names exported by a star-import of this module.
__all__ = [
    "count_params",
    "clip_grad_norm",
    "clip_grad_value",
    "compute_param_norm",
    "param_log",
    "split_params_for_weight_decay",
]

# Type alias used in the signatures below: a single tensor or an iterable of tensors.
PARAMETERS_DTYPE = Union[torch.Tensor, Iterable[torch.Tensor]]

# Module-level logger, looked up under the fixed name "pado"; used for the NaN/Inf gradient warnings.
logger = logging.getLogger("pado")


@torch.no_grad()
def count_params(parameters: PARAMETERS_DTYPE) -> Tuple[int, int]:
    """
    Count trainable tensors and their total number of elements.
    Tensors whose ``requires_grad`` is False are skipped.
    :param parameters:          a single tensor or an iterable of tensors (List, Tuple, Iter, ...)
    :return:
            (#tensors, #elements)   counted over tensors that require grad
    """
    tensors = [parameters] if isinstance(parameters, torch.Tensor) else parameters
    # One pass only, so that one-shot iterators (generators) are supported.
    sizes = [t.numel() for t in tensors if t.requires_grad]
    return len(sizes), sum(sizes)


@torch.no_grad()
def compute_param_norm(parameters: PARAMETERS_DTYPE,
                       norm_type: float = 2.0) -> torch.Tensor:
    """
    Compute the overall norm of all trainable tensors.
    The norm of each tensor is taken first, then the same norm is applied to the stacked
    per-tensor norms. Tensors whose ``requires_grad`` is False are skipped.
    :param parameters:          a single tensor or an iterable of tensors (List, Tuple, Iter, ...)
    :param norm_type:           order of the norm, default l2 norm
    :return:
            scalar tensor       total norm; a float32 zero if no tensor requires grad
    """
    candidates = [parameters] if isinstance(parameters, torch.Tensor) else parameters
    trainable = [t for t in candidates if t.requires_grad]
    if not trainable:
        return torch.as_tensor(0., dtype=torch.float32)

    # Gather every per-tensor norm on the device of the first tensor so they can be stacked.
    target_device = trainable[0].device
    per_tensor_norms = []
    for t in trainable:
        per_tensor_norms.append(torch.norm(t, norm_type).to(target_device))
    return torch.norm(torch.stack(per_tensor_norms), norm_type)


def clip_grad_norm(parameters: PARAMETERS_DTYPE,
                   clip_value: float = 0.0,
                   norm_type: float = 2.0) -> torch.Tensor:
    """
    Compute the total gradient norm and clip gradients by it. Gradients are modified in-place.
    Parameters without a gradient are ignored. When the total norm is NaN or Inf, no gradient
    is modified and a warning is logged for every parameter whose gradient has a NaN/Inf entry.
    :param parameters:          a single tensor or an iterable of parameters (List, Tuple, Iter, ...)
    :param clip_value:          maximum norm value to clip; clipping is disabled if <= 0
    :param norm_type:           type of norm, default l2 norm (math.inf selects the max norm)
    :return:
            scalar tensor       total gradient norm measured before clipping
                                (a float32 zero if no parameter has a gradient)
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    clip_value = float(clip_value)
    norm_type = float(norm_type)
    if not parameters:
        return torch.tensor(0., dtype=torch.float32)

    # Per-parameter norms are moved to one device so that they can be stacked.
    device = parameters[0].grad.device
    if norm_type == math.inf:
        norms = [p.grad.detach().abs().max().to(device) for p in parameters]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    else:
        norms = [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]
        total_norm = torch.norm(torch.stack(norms), norm_type)

    if not torch.isfinite(total_norm):
        for p in parameters:  # track down what caused the nan/inf
            # isfinite is False for NaN, +Inf and -Inf, so both failure kinds are reported.
            if not torch.isfinite(p.grad).all():
                p_name = getattr(p, "_name", "")
                logger.warning(f"Gradient of {p_name}({tuple(p.shape)}) is NaN or Inf.")
        return total_norm

    if clip_value > 0.0:
        # The small epsilon keeps the division well-defined when the norm is zero.
        clip_coefficient = clip_value / (total_norm + 1e-6)
        if clip_coefficient < 1:
            for p in parameters:
                p.grad.detach().mul_(clip_coefficient.to(p.grad.device))
    return total_norm


def clip_grad_value(parameters: PARAMETERS_DTYPE,
                    clip_value: float = 0.0) -> torch.Tensor:
    """
    Clip gradients element-wise into [-clip_value, clip_value]. Gradients are modified in-place.
    Parameters without a gradient are ignored. When the returned amplitude is NaN or Inf, no
    gradient is modified and a warning is logged for every parameter whose gradient has a
    NaN/Inf entry.
    :param parameters:          a single tensor or an iterable of parameters (List, Tuple, Iter, ...)
    :param clip_value:          maximum absolute value of a gradient entry; clipping is disabled if <= 0
    :return:
            scalar tensor       mean over parameters of each gradient's summed absolute value,
                                measured before clipping (a float32 zero if no parameter has a gradient)
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    clip_value = float(clip_value)
    if not parameters:
        return torch.tensor(0., dtype=torch.float32)

    # Per-parameter sums are moved to one device so that they can be stacked.
    device = parameters[0].grad.device
    amplitudes = [p.grad.detach().abs().sum().to(device) for p in parameters]
    total_amplitude = torch.mean(torch.stack(amplitudes))

    if not torch.isfinite(total_amplitude):
        for p in parameters:  # track down what caused the nan/inf
            # isfinite is False for NaN, +Inf and -Inf, so both failure kinds are reported.
            if not torch.isfinite(p.grad).all():
                p_name = getattr(p, "_name", "")
                logger.warning(f"Gradient of {p_name}({tuple(p.shape)}) is NaN or Inf.")
        return total_amplitude

    if clip_value > 0.0:
        for p in parameters:
            p.grad.detach().clamp_(min=-clip_value, max=clip_value)
    return total_amplitude


def param_log(model: torch.nn.Module) -> str:
    """
    Build a human-readable table of the model's named parameters.
    Each row shows the parameter name, its shape and the standard deviation of its values
    (computed with ``unbiased=False``), framed by two 72-dash rules.
    :param model:               module whose ``named_parameters()`` are listed
    :return:
            multi-line string ending with a newline
    """
    rule = "-" * 72 + "\n"
    pieces = [rule, "Parameters:\n"]
    for param_name, param in model.named_parameters():
        pieces.append(
            f"... {param_name:<60}\t{str(tuple(param.shape)):<20}(std: {torch.std(param, unbiased=False).item():.3f})\n"
        )
    pieces.append(rule)
    return "".join(pieces)


def split_params_for_weight_decay(parameters: PARAMETERS_DTYPE) -> List:
    """
    Split parameters into two optimizer parameter groups, with and without weight decay.
    Tensors with at most one dimension (``ndim <= 1``) go to the group whose "weight_decay"
    is set to 0.0; all other tensors go to the base group, which carries no "weight_decay" key.
    :param parameters:          a single tensor or an iterable of parameters (List, Tuple, Iter, ...)
    :return:
            [base_group, no_decay_group]    list of two dicts, each holding its tensors under "params"
    """
    # A bare tensor must be wrapped; iterating it directly would walk over its first dimension.
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    no_decay_params = []
    base_params = []
    for p in parameters:
        if p.ndim <= 1:
            no_decay_params.append(p)
        else:
            base_params.append(p)
    return [
        {"params": base_params},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
