import torch
import math
import numpy as np


def truncated_normal_(tensor, mean=0, std=1):
    """Fill ``tensor`` in place with samples from a truncated normal distribution.

    Standard-normal values are drawn and every value outside the open
    interval (-2, 2) is redrawn until none remain. The result is then scaled
    by ``std`` and shifted by ``mean``, so the stored values lie within two
    standard deviations of ``mean``.

    The previous implementation drew four candidates per element and picked
    the first in-range one; when all four candidates were out of range it
    silently kept an out-of-range value. Redrawing until acceptance removes
    that failure mode.

    Args:
        tensor: Floating-point tensor to overwrite. Shape, dtype and device
            are preserved.
        mean: Offset added after scaling.
        std: Factor the truncated standard-normal samples are multiplied by.

    Returns:
        None. ``tensor`` is modified in place through ``tensor.data``.
    """
    draw = tensor.new_empty(tensor.shape).normal_()
    while True:
        rejected = (draw >= 2) | (draw <= -2)
        if not rejected.any():
            break
        # Replace only the rejected entries; accepted ones are kept as-is.
        draw = torch.where(rejected, torch.empty_like(draw).normal_(), draw)

    # Write through .data (as the original did) so autograd does not record
    # the initialisation.
    tensor.data.copy_(draw)
    tensor.data.mul_(std).add_(mean)


def get_fan_in(tensor, mode):
    """Return the fan-in of a weight tensor.

    The fan-in is ``tensor.size(1)`` multiplied by the extent of every
    dimension after the second one. For a 2-D tensor that is simply
    ``tensor.size(1)``.

    The value is derived from the shape alone, so tensors whose leading
    dimensions have size 0 are handled without indexing into them.

    Args:
        tensor: Tensor with at least two dimensions.
        mode: Must be the string ``"fan_in"``; no other mode is implemented.

    Returns:
        The fan-in as an integer.

    Raises:
        NotImplementedError: If ``mode`` is not ``"fan_in"``.
        AssertionError: If ``tensor`` has fewer than two dimensions. The
            tensor's shape is attached as the exception argument.
    """
    if mode != "fan_in":
        raise NotImplementedError("support only wage normal")

    if tensor.ndimension() < 2:
        # Raised explicitly rather than via ``assert False`` so the check is
        # not stripped when Python runs with -O.
        raise AssertionError(tensor.shape)

    fan_in = tensor.size(1)
    # Trailing dimensions form the receptive field; the slice is empty for
    # a 2-D tensor, leaving fan_in equal to size(1).
    for extent in tensor.shape[2:]:
        fan_in *= extent
    return fan_in


def wage_init_(tensor, bits_W, factor=1.0, mode="fan_in", beta=1.5):
    """Fill ``tensor`` in place with values drawn uniformly from [-limit, limit).

    ``limit`` is the larger of two bounds:

    * a floor derived from the bit width, ``beta / 2 ** (bits_W - 1)``;
    * a fan-in based bound, ``sqrt(3 * factor / fan_in)``.

    Args:
        tensor: Weight tensor to initialise; its fan-in is obtained from
            ``get_fan_in(tensor, mode)``.
        bits_W: Bit width used to compute the floor bound.
        factor: Multiplier inside the fan-in based bound.
        mode: Forwarded to ``get_fan_in``.
        beta: Numerator of the floor bound.

    Returns:
        None. ``tensor`` is modified in place through ``tensor.data``.
    """
    n_inputs = get_fan_in(tensor, mode)

    floor_bound = beta / (2 ** (bits_W - 1))
    fan_bound = math.sqrt(3 * factor / n_inputs)

    # Use whichever bound is larger so the range never drops below the floor.
    bound = fan_bound if fan_bound > floor_bound else floor_bound
    tensor.data.uniform_(-bound, bound)


def get_scale(tensor, bits_W, factor=1.0, mode="fan_in", beta=1.5, fan_in=None):
    """Return a power-of-two scale factor that is never smaller than 1.

    The ratio between the bit-width floor ``beta / 2 ** (bits_W - 1)`` and
    the fan-in based bound ``sqrt(3 * factor / fan_in)`` is rounded to the
    nearest power of two (in log2 space), and the result is clamped from
    below at 1.0.

    Args:
        tensor: Weight tensor; only consulted (via ``get_fan_in``) when
            ``fan_in`` is not supplied.
        bits_W: Bit width used to compute the floor bound.
        factor: Multiplier inside the fan-in based bound.
        mode: Forwarded to ``get_fan_in`` when ``fan_in`` is None.
        beta: Numerator of the floor bound.
        fan_in: Optional precomputed fan-in; skips the ``get_fan_in`` call.

    Returns:
        ``max(2 ** k, 1.0)`` where ``k`` is the rounded log2 of the ratio.
    """
    n_inputs = get_fan_in(tensor, mode) if fan_in is None else fan_in

    floor_bound = beta / (2 ** (bits_W - 1))
    fan_bound = math.sqrt(3 * factor / n_inputs)

    # Per the original author's note, the paper divides the floor by the
    # final (clamped) limit; the author believed that equation to be wrong,
    # so the unclamped fan-in bound is used as the divisor instead.
    exponent = round(np.log2(floor_bound / fan_bound))

    return max(2 ** exponent, 1.0)
