import torch
import math
import numpy as np

qbits = 14  # fixed-point word width in bits; fractional-bit counts below are derived from it (e.g. qbits - 1)


def quantize(x, qp):
	"""Snap ``x`` to the nearest multiple of ``2**-qp``.

	The tensor is scaled up by ``2**qp``, rounded (``torch.round`` sends
	exact halves to the even neighbour), then scaled back down, so the
	result stays in real units but carries only ``qp`` fractional bits.

	Args:
		x: tensor to quantize.
		qp: number of fractional bits to keep.

	Returns:
		The quantized tensor.
	"""
	factor = 2 ** qp
	rounded = torch.round(x * factor)
	return rounded / factor

def scale(x, qp):
	"""Return ``x * 2**qp`` rounded to the nearest integer value.

	Rounding is ``torch.round`` (halves go to the even neighbour). For a
	floating-point ``x`` the result keeps its floating dtype and only
	holds integer values.

	Args:
		x: tensor to scale.
		qp: power-of-two exponent (number of fractional bits to lift).
	"""
	lifted = x * (2 ** qp)
	return torch.round(lifted)

def trunc(x, qp):
	"""Divide ``x`` by ``2**qp`` (undo a power-of-two scaling).

	NOTE(review): despite the name, nothing is truncated here. ``torch.div``
	performs true division, so fractional bits survive in the result. Confirm
	that callers do not expect floor/right-shift semantics.

	Args:
		x: tensor to divide.
		qp: power-of-two exponent of the divisor.
	"""
	divisor = 2 ** qp
	return torch.div(x, divisor)


def npscale(x, qp):
	"""NumPy counterpart of scaling by ``2**qp``, but flooring instead of rounding.

	Args:
		x: number or NumPy array to scale.
		qp: power-of-two exponent.

	Returns:
		``floor(x * 2**qp)`` cast with ``astype(int)``.
	"""
	factor = math.pow(2, qp)
	return np.floor(x * factor).astype(int)



def analyze_model(model):
	"""Print peak magnitudes of a model's recorded activations and its parameters.

	The first line shows ``model.layer_max`` as a list of Python numbers.
	That attribute is not part of ``torch.nn.Module``; the model must
	provide it as an iterable of tensors. Presumably it holds per-layer
	activation maxima (confirm with the model definition). After that,
	one ``name: value`` line is printed per parameter, where value is
	the largest absolute entry of that parameter.

	Args:
		model: module exposing ``layer_max`` and ``named_parameters()``.
	"""
	# activation values
	print("layer max = {}".format([x.item() for x in model.layer_max]))

	# weights: label each value with its parameter name and print a plain
	# number (via .item()) rather than a tensor repr
	for name, param in model.named_parameters():
		data = param.data
		# abs().max() raises on an empty tensor; report 0.0 instead
		peak = data.abs().max().item() if data.numel() > 0 else 0.0
		print("{}: {}".format(name, peak))

	

def quantize_params(model):
	"""Round every parameter of ``model`` in place to ``qbits - 1`` fractional bits.

	Each parameter's ``.data`` is rebound to ``quantize(param.data, qbits - 1)``.
	The values stay in real (unscaled) units.

	Args:
		model: module whose parameters are rewritten.
	"""
	frac_bits = qbits - 1
	for p in model.parameters():
		p.data = quantize(p.data, frac_bits)

# FIXME: "need to scale biases again" (original note). Biases are already
# scaled below, so the remaining intent is unclear; confirm before relying on it.
def scale_params(model):
	"""Convert a model to power-of-two-scaled integer values, in place.

	Every value goes through ``scale`` (multiply by a power of two, then
	round), so tensors keep a floating dtype but hold integer values:

	* ``BatchNorm2d.running_mean`` buffers: ``2**(qbits - 6)``
	* parameters whose name ends in ``bias``: ``2**(2*qbits - 7)``
	* all other parameters: ``2**(qbits - 1)``

	``running_var`` is left untouched. The function is not idempotent:
	calling it twice scales everything twice.

	Args:
		model: ``torch.nn.Module`` to mutate.
	"""
	# model.modules() yields model itself plus every submodule exactly once.
	# Iterating child.modules() per direct child would miss model itself and
	# would rescale a submodule shared between two children.
	for layer in model.modules():
		# running_mean is None when track_running_stats=False; nothing to scale.
		if isinstance(layer, torch.nn.BatchNorm2d) and layer.running_mean is not None:
			layer.running_mean = scale(layer.running_mean, qbits - 6)

	for name, param in model.named_parameters():
		if name.endswith('bias'):
			# Presumably so biases match the scale of weight*activation products,
			# i.e. (qbits - 1) + (qbits - 6) fractional bits. Confirm.
			param.data = scale(param.data, 2 * qbits - 6 - 1)
		else:
			param.data = scale(param.data, qbits - 1)

