import torch
import math
import numpy as np

# Base power-of-two exponent shared by the helpers below: parameters are
# quantized to multiples of 2**-qbits, and scaling exponents are derived
# from it.
qbits = 14


def quantize(x, qp):
    """Snap the values of ``x`` onto a grid with spacing 2**-qp.

    Each element is multiplied by 2**qp, rounded with ``torch.round`` and
    divided by 2**qp again, so the result keeps the original magnitude but
    only takes multiples of 2**-qp.

    Args:
        x: Tensor to quantize.
        qp: Power-of-two exponent that sets the grid resolution.

    Returns:
        Tensor of quantized values; ``x`` itself is not modified.
    """
    grid = math.pow(2, qp)
    return torch.round(torch.mul(x, grid)) / grid

def scale(x, qp):
    """Multiply ``x`` by 2**qp and round to whole numbers.

    Unlike ``quantize``, the result is left in the scaled domain (it is not
    divided back by 2**qp).

    Args:
        x: Tensor to scale.
        qp: Power-of-two exponent of the scale factor.

    Returns:
        Tensor of rounded, scaled values; the dtype of the product is kept.
    """
    factor = math.pow(2, qp)
    scaled = torch.mul(x, factor)
    return torch.round(scaled)

def trunc(x, qp):
    """Divide ``x`` by 2**qp using the ``//`` operator.

    Args:
        x: Value (number, array or tensor) supporting ``//`` with a float.
        qp: Power-of-two exponent of the divisor.

    Returns:
        ``x // 2**qp``.

    NOTE(review): for plain Python numbers ``//`` floors (rounds toward
    negative infinity) rather than truncating toward zero -- confirm this
    is the rounding intended for negative inputs.
    """
    divisor = math.pow(2, qp)
    return x // divisor


def npscale(x, qp):
    """Multiply ``x`` by 2**qp and floor the result to integers (NumPy).

    Note that this rounds toward negative infinity with ``np.floor``; it
    does not round to nearest.

    Args:
        x: Number or NumPy array to scale.
        qp: Power-of-two exponent of the scale factor.

    Returns:
        The floored product converted with ``astype(int)``.
    """
    factor = math.pow(2, qp)
    return np.floor(x * factor).astype(int)



def analyze_model(model):
    """Print value-range diagnostics for ``model``.

    First prints the entries of ``model.layer_max`` (each converted with
    ``.item()``), then one line per named parameter giving the larger of
    ``|max|`` and ``|min|`` of its data.

    Args:
        model: Object exposing ``layer_max`` (iterable of items with
            ``.item()``) and ``named_parameters()``.
    """
    # Activation ranges recorded on the model.
    layer_peaks = [entry.item() for entry in model.layer_max]
    print("layer max = {}".format(layer_peaks))

    # Largest absolute value of every parameter tensor.
    for name, param in model.named_parameters():
        values = param.data
        upper = abs(values.max())
        lower = abs(values.min())
        print("{}: {}".format(name, max(upper, lower)))

    

def quantize_params(model):
    """Quantize every named parameter of ``model`` in place.

    Each parameter's ``data`` is replaced by ``quantize(data, qbits)``,
    i.e. rounded to a multiple of 2**-qbits.

    Args:
        model: Object exposing ``named_parameters()``.
    """
    for _name, weight in model.named_parameters():
        weight.data = quantize(weight.data, qbits)

#FIXME need to scale biases again
def scale_params(model, alpha, beta):
    """Rescale BatchNorm running means and all parameters of ``model`` in place.

    Every value is multiplied by a power of two and rounded via ``scale``:

    * ``running_mean`` of each ``BatchNorm2d`` layer: 2**(qbits - alpha)
    * parameters whose name ends in ``bias``: 2**(2*qbits - alpha - beta)
    * all other parameters: 2**(qbits - beta)

    Args:
        model: ``torch.nn.Module`` to modify in place.
        alpha: Exponent offset applied to running means and biases
            (presumably tied to the activation scaling -- confirm with
            callers).
        beta: Exponent offset applied to biases and all other parameters
            (presumably tied to the weight scaling -- confirm with callers).
    """
    # Visit each direct child together with its descendants; the top-level
    # module is not inspected itself.
    for child in model.children():
        for layer in child.modules():
            if not isinstance(layer, torch.nn.modules.batchnorm.BatchNorm2d):
                continue
            # BatchNorm2d(track_running_stats=False) keeps running_mean as
            # None; there is nothing to scale and torch.mul would raise.
            if layer.running_mean is None:
                continue
            layer.running_mean = scale(layer.running_mean, qbits - alpha)

    bias_shift = 2 * qbits - alpha - beta
    weight_shift = qbits - beta
    for name, param in model.named_parameters():
        shift = bias_shift if name.endswith('bias') else weight_shift
        param.data = scale(param.data, shift)

