import torch
import torch.nn as nn
import torch.nn.functional as F


def step(x):
    """Return the element-wise unit step (Heaviside) function of ``x``.

    The output is 0 where ``x < 0``, 1 where ``x > 0`` and 0 where
    ``x == 0``.

    Args:
        x: input tensor (``torch.heaviside`` does not support complex dtypes).

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """
    # ``torch.heaviside`` requires ``values`` to share the input's dtype, so
    # derive the zero from ``x`` rather than using a default float32 scalar
    # (which raised for float64, half or integer inputs).
    return torch.heaviside(x, x.new_zeros(()))


def bincount(x: torch.LongTensor, n_bins: int, dim=-1):
    """Count occurrences of each value of ``x`` along a single dimension.

    Args:
        x: int64 tensor of bin indices; every entry must lie in
            ``[0, n_bins)`` because it is used as a scatter index.
        n_bins: number of bins, i.e. the size of the output along ``dim``.
        dim: dimension along which to count. Negative values index from the
            end, as usual in PyTorch.

    Returns:
        An int64 tensor shaped like ``x`` except that dimension ``dim`` has
        size ``n_bins``; entry ``k`` along ``dim`` is the number of times
        ``k`` occurs in the corresponding slice of ``x``.

    Raises:
        AssertionError: if ``x`` is not an int64 tensor.
        IndexError: if ``dim`` is out of range for ``x``.
    """
    # bincount along any dimension
    # c.f. https://github.com/pytorch/pytorch/issues/32306
    assert x.dtype == torch.int64, "only integral (int64) tensor is supported"
    ndim = x.dim()
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )
    # Normalise negative dims so the output shape is built correctly; the
    # default ``dim=-1`` previously never matched any axis index and the
    # result silently kept the shape of ``x``.
    dim = dim % ndim
    shape = list(x.shape)
    shape[dim] = n_bins
    cnt = x.new_zeros(shape)
    # no scalar or broadcasting `src` support yet
    # c.f. https://github.com/pytorch/pytorch/issues/5740
    return cnt.scatter_add_(dim=dim, index=x, src=x.new_ones(()).expand_as(x))


def ternary_to_id(x: torch.LongTensor) -> torch.LongTensor:
    """Collapse the last dimension of ``x`` into a single base-3 integer id.

    Each entry is shifted by +1 and weighted by ``3 ** position`` (position 0
    is the least significant digit); the weighted digits are summed over the
    last dimension.

    Args:
        x: integer tensor whose last dimension holds the digits. Presumably
            the entries are in ``{-1, 0, 1}`` so the shifted digits are valid
            base-3 digits -- verify against callers.

    Returns:
        A tensor with the last dimension of ``x`` reduced away.
    """
    n_digits = x.size(-1)
    # Place values 1, 3, 9, ... created on the same device as the input.
    place_values = 3 ** torch.arange(n_digits, device=x.device)
    shifted = x + 1
    return (shifted * place_values).sum(dim=-1)


def invert_permutation(p: torch.LongTensor, result_length) -> torch.LongTensor:
    """Build the inverse index map of the 1-D index tensor ``p``.

    The result ``r`` satisfies ``r[p[i]] = i`` for every position ``i`` of
    ``p``. Slots of ``r`` that are not referenced by ``p`` stay 0.

    Args:
        p: 1-D int64 tensor of target positions, each in
            ``[0, result_length)``.
        result_length: length of the returned tensor.

    Returns:
        A tensor of length ``result_length`` with the dtype and device of
        ``p``.
    """
    source_positions = torch.arange(p.size(0), device=p.device)
    inverse = p.new_zeros(result_length)
    return inverse.scatter(dim=0, index=p, src=source_positions)


@torch.jit.script
def fused_clip(max_norm: float, total_norm):
    """Return the gradient scaling coefficient ``max_norm / total_norm``.

    The coefficient is clamped to ``[0, 1]``, so gradients are only ever
    scaled down, never up.

    Args:
        max_norm: maximum allowed norm.
        total_norm: tensor holding the current norm.

    Returns:
        A tensor holding the clamped coefficient.
    """
    # The small epsilon keeps the ratio finite when the norm is exactly zero.
    return torch.clip(max_norm / (total_norm + 1e-6), 0.0, 1.0)


@torch.no_grad()
def clip_grad_norm_(
        parameters, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
    """Rescale the gradients of ``parameters`` in place to a maximum norm.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Parameters whose ``grad`` is ``None``
    are ignored.

    Args:
        parameters: a single tensor or an iterable of tensors whose ``.grad``
            should be clipped.
        max_norm: maximum allowed norm of the gradients.
        norm_type: order of the norm; ``float('inf')`` selects the max-norm.

    Returns:
        The total gradient norm measured before clipping, or ``tensor(0.)``
        when no parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.)

    if norm_type == float('inf'):
        peaks = [g.detach().abs().max() for g in grads]
        if len(peaks) == 1:
            total_norm = peaks[0]
        else:
            total_norm = torch.stack(peaks).max()
    else:
        # Norm of the per-parameter norms equals the norm of all gradients.
        per_param_norms = torch.stack(
            [torch.norm(g.detach(), norm_type) for g in grads])
        total_norm = torch.norm(per_param_norms, norm_type)

    # The coefficient is clamped to at most 1, so gradients already within
    # the limit are multiplied by 1 and left unchanged.
    scale = fused_clip(max_norm, total_norm)
    for g in grads:
        g.detach().mul_(scale)
    return total_norm


class ScaledWeightConv2d(nn.Conv2d):
    """2-D convolution whose kernel is multiplied by a per-output-channel scale.

    On top of the regular ``nn.Conv2d`` state the layer holds:

    * ``scales``: parameter of shape ``(out_channels, 1, 1, 1)``, initialised
      to ones, multiplied into the kernel in ``forward``.
    * ``w_prime``: parameter initialised to a copy of ``weight``; used as the
      kernel when ``use_float`` is true.
    * ``quant_scores``: buffer initialised to a copy of ``weight``; not read
      by this class itself.

    The constructor arguments are forwarded positionally to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__(
            in_channels, out_channels, kernel_size, stride,
            padding, dilation, groups, bias
        )

        # self.weight = q(w_q)
        # alpha: scale for each output channel
        self.scales = nn.Parameter(torch.ones(
            self.weight.shape[0], 1, 1, 1
        ))

        # w' = w / alpha
        # Cloned on purpose: wrapping ``self.weight.data`` directly would make
        # ``w_prime`` share storage with ``weight``, so any in-place update
        # (optimizer step, ``load_state_dict``) of one would overwrite the
        # other.
        self.w_prime = nn.Parameter(self.weight.detach().clone())

        # \hat{w_q} ~= w'
        # Cloned for the same reason: the buffer needs its own storage.
        self.register_buffer('quant_scores', self.weight.detach().clone())

        # Selects which kernel ``forward`` uses: ``w_prime`` when true,
        # ``weight`` when false.
        self.use_float = False

    def forward(self, x):
        """Convolve ``x`` with the selected kernel multiplied by ``scales``."""
        if self.use_float:
            weight = self.w_prime * self.scales
        else:
            weight = self.weight * self.scales
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

