'''ResNet in PyTorch.

For Pre-activation ResNet, see 'preact_resnet.py'.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
	Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.nn.init as init

import math
from torch.autograd import Variable
import utils_quant as qt
import numpy as np


# Modulus of the integer field used by arelu(): 2**31 - 1.
p=2147483647 #31 bit

# Truncation bit count handed to arelu() as `tr` by the layers in this file;
# arelu() divides by 2**tr, so 0 means "divide by 1".
trunc_bits = 0

def update_max(current_max, x):
	"""Return the larger of ``current_max`` and the peak magnitude found in ``x``.

	Args:
		current_max: Running maximum recorded so far, or ``None`` when nothing
			has been recorded yet.
		x: Tensor whose ``.data`` is inspected.

	Returns:
		``current_max`` when it is not smaller than the largest absolute value
		in ``x``; otherwise that largest absolute value (a 0-dim tensor).
	"""
	values = x.data
	# The largest magnitude is reached at one of the two extremes.
	peak = max(abs(values.max()), abs(values.min()))
	if current_max is not None and not (current_max < peak):
		return current_max
	return peak

class Truncate(nn.Module):
	"""Module that divides its input by ``2**qp``.

	Note that the division is a true (non-rounding) ``torch.div``.
	"""

	def __init__(self, qp):
		super().__init__()
		# Exponent of the power-of-two divisor applied in forward().
		self.qp = qp

	def forward(self, x):
		"""Return ``x`` divided element-wise by ``2**self.qp``."""
		divisor = 2 ** self.qp
		return torch.div(x, divisor)

# Names exported by a star-import of this module; the __main__ block at the
# bottom of the file also iterates over this list.
__all__ = ['ResNet', 'resnet18', 'resnet34']

def _weights_init(m):
	classname = m.__class__.__name__
	#print(classname)
	if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
		init.kaiming_normal_(m.weight)

# Stochastic ReLU used in Circa
def arelu(x, tr, enable_neg):
	"""Stochastic ReLU: zero out elements whose *estimated* sign is negative.

	The sign of each element is estimated from truncated values only:
	``x`` is cast to int64, mapped into the field ``[0, p)`` (negatives become
	``p + x``), and masked with a random value ``rs`` modulo ``p`` to give
	``x0``. Comparing ``rs // 2**tr`` with ``x0 // 2**tr`` then yields the
	sign estimate, which can be wrong.

	All intermediate tensors are created on ``x.device``, so the function
	works for CPU as well as CUDA inputs. The random mask is still drawn from
	the CPU generator and then moved, so seeded runs give the same mask as
	before.

	Args:
		x: Input tensor.
		tr: Number of low-order bits dropped before the comparison
			(both operands are floor-divided by ``2**tr``).
		enable_neg: Tie-breaking rule when the truncated mask equals the
			truncated masked value. ``True`` treats the tie as negative
			(output 0); ``False`` treats it as positive (output ``x``).

	Returns:
		A 5-tuple ``(relu, badsign_pos, badsign_neg, pos, neg)``:
		``relu`` is ``x`` multiplied by the 0/1 sign estimate;
		``badsign_pos`` counts non-zero elements with ``x >= 0`` that were
		estimated negative; ``badsign_neg`` counts elements with ``x < 0``
		that were estimated positive; ``pos`` and ``neg`` count the elements
		with ``x >= 0`` and ``x < 0`` respectively.
	"""
	device = x.device
	xl = x.long()

	# Random mask in [0, p); generated on the CPU RNG, then placed next to x.
	rs = torch.floor(torch.mul(torch.rand(xl.shape), p)).to(device).long()

	# Field representation of x, then the masked value (x + rs) mod p.
	xf = torch.where(xl < 0, xl + p, xl)
	xr = xf + rs
	x0 = torch.where(xr < p, xr, xr - p)

	t = int(math.pow(2, tr))
	rs_t = rs // t
	x0_t = x0 // t

	# No wrap-around modulo p (mask below masked value) means "positive".
	if enable_neg:
		est_positive = rs_t < x0_t     # tie -> negative
	else:
		est_positive = rs_t <= x0_t    # tie -> positive

	# 0/1 mask in the default float dtype, as torch.ones/torch.zeros would give.
	sign = est_positive.to(torch.get_default_dtype())
	relu = x * sign

	is_negative = x < 0
	badsign_pos = (~is_negative & ~est_positive & (x != 0)).sum()
	badsign_neg = (is_negative & est_positive).sum()

	pos = torch.sum(x >= 0)
	neg = torch.sum(is_negative)

	return relu, badsign_pos, badsign_neg, pos, neg

# Model/realization of Stochastic ReLU for the fault model validation
def srelu(x, tr):
	"""Fault-model version of the stochastic ReLU (positive-to-zero errors only).

	``x`` is scaled by ``2**(2*14-8)``. Every scaled value ``v`` strictly
	between 0 and ``2**tr`` is then forced to zero with probability
	``(2**tr - v) / 2**tr``. A plain ReLU is applied and the result is scaled
	back. ``x`` itself is not modified.

	Intermediate tensors follow ``x.device``, so the function works on CPU as
	well as CUDA inputs. The uniform samples are still drawn from the CPU
	generator and then moved, so seeded runs are unchanged.

	Args:
		x: Input tensor.
		tr: Number of truncated bits; ``2**tr`` is the upper bound of the
			range in which sign errors are injected.

	Returns:
		A 3-tuple ``(relu, badsign_pos_sc, pos_sc)``: the ReLU output in the
		original scale, the number of in-range positive values that were
		forced to zero, and the number of elements with ``x >= 0``.
	"""
	device = x.device
	scale = math.pow(2, 2*14-8)
	bound = 2 ** tr

	xl = torch.mul(x, scale)

	## ----- PosZero case ----- ##
	# Positive activations that fall inside the truncation range.
	in_range = (xl > 0) & (xl < bound)
	pos_trunc_range = xl[in_range]

	# Error probability grows as the activation approaches zero.
	error_prob = torch.div(bound - pos_trunc_range, bound)

	# One uniform sample in [0, 1) per in-range activation.
	prob = torch.rand(torch.numel(pos_trunc_range)).to(device)

	# Flip to zero wherever the sample falls below the error probability,
	# then write the (partly zeroed) values back into the scaled tensor.
	pos_trunc_range[prob < error_prob] = 0
	xl[in_range] = pos_trunc_range

	badsign_pos_sc = torch.sum(pos_trunc_range == 0)
	pos_sc = torch.sum(x >= 0)

	# Ordinary ReLU on the perturbed values, then undo the scaling.
	relu = torch.where(xl >= 0, xl, torch.zeros_like(xl))
	relu = torch.div(relu, scale)

	return relu, badsign_pos_sc, pos_sc

class BasicBlockM(nn.Module):
	"""Residual block: two 3x3 conv + BN stages with a stochastic ReLU after each.

	Activations pass through ``qt.trunc(..., qt.qbits-1)`` after each
	convolution and after each ``arelu`` call. ``qt`` is the project's
	``utils_quant`` module; its exact semantics are not visible in this file.
	The block also accumulates the sign statistics returned by ``arelu``
	across calls in ``badsign_pos{1,2}``, ``badsign_neg{1,2}``, ``pos{1,2}``
	and ``neg{1,2}`` (suffix 1 = first activation, 2 = second).
	"""
	expansion = 1

	def __init__(self, in_planes, planes, stride=1):
		"""Create the block.

		Args:
			in_planes: Number of input channels.
			planes: Number of channels produced by both convolutions.
			stride: Stride of the first convolution (and of the projection
				shortcut, when one is needed).
		"""
		super().__init__()
		out_planes = self.expansion*planes

		# Main branch.
		self.conv1 = nn.Conv2d(
			in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
		self.bn1 = nn.BatchNorm2d(planes)
		self.conv2 = nn.Conv2d(
			planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
		self.bn2 = nn.BatchNorm2d(planes)

		# Remembered so forward() can tell which kind of shortcut is in use.
		self.stride = stride
		self.in_planes = in_planes
		self.planes = planes

		# Identity shortcut unless the shape changes; then 1x1 conv -> Truncate -> BN.
		shortcut_layers = []
		if stride != 1 or in_planes != out_planes:
			shortcut_layers = [
				nn.Conv2d(in_planes, out_planes,
					kernel_size=1, stride=stride, bias=False),
				Truncate(qt.qbits-1),
				nn.BatchNorm2d(out_planes),
			]
		self.shortcut = nn.Sequential(*shortcut_layers)

		# Running sign statistics for the two activations.
		self.badsign_pos1 = self.badsign_pos2 = 0
		self.badsign_neg1 = self.badsign_neg2 = 0
		self.pos1 = self.pos2 = 0
		self.neg1 = self.neg2 = 0

	@staticmethod
	def _truncate_data(t):
		"""Replace ``t.data`` with ``qt.trunc(t.data, qt.qbits-1)`` and return ``t``."""
		t.data = qt.trunc(t.data, qt.qbits-1)
		return t

	def forward(self, x):
		"""Run the block on ``x`` and update the sign statistics.

		With an identity shortcut, ``x.data`` is overwritten in place with
		``qt.scale(x.data, qt.qbits-1)`` before it is added to the main branch.
		"""
		# First stage: conv -> trunc -> BN -> stochastic ReLU -> trunc.
		out = self._truncate_data(self.conv1(x))
		out = self.bn1(out)
		out, bad_pos_a, bad_neg_a, n_pos_a, n_neg_a = arelu(out, trunc_bits, enable_neg=True)
		out = self._truncate_data(out)

		# Second stage: conv -> trunc -> BN.
		out = self._truncate_data(self.conv2(out))
		out = self.bn2(out)

		# Only the identity path rescales x; the projection path has its own Truncate.
		identity_shortcut = self.stride == 1 and self.in_planes == self.expansion*self.planes
		if identity_shortcut:
			x.data = qt.scale(x.data, qt.qbits-1)

		out += self.shortcut(x)

		out, bad_pos_b, bad_neg_b, n_pos_b, n_neg_b = arelu(out, trunc_bits, enable_neg=True)
		out = self._truncate_data(out)

		# Accumulate statistics of both activations.
		self.badsign_pos1 += bad_pos_a
		self.badsign_neg1 += bad_neg_a
		self.pos1 += n_pos_a
		self.neg1 += n_neg_a
		self.badsign_pos2 += bad_pos_b
		self.badsign_neg2 += bad_neg_b
		self.pos2 += n_pos_b
		self.neg2 += n_neg_b

		return out


class ResNet(nn.Module):
	"""ResNet backbone whose inference path is the quantized ``quantize()`` method.

	The network is a 3x3 stem convolution + BN, four stages built from
	``block``, a 4x4 average pool and a linear classifier. This class defines
	no ``forward()``; callers run the model through ``quantize()``.

	Attributes:
		layer_max: Seven running maxima of absolute activation values, updated
			by ``quantize()``. Index 0 is the scaled input, 1 the stem
			convolution, 2-5 the outputs of ``layer1``..``layer4``, and 6 the
			classifier output.
		badsign_pos_conv1, badsign_neg_conv1, pos_conv1, neg_conv1: Running
			sums of the statistics ``arelu`` returns for the stem activation.
	"""

	def __init__(self, block, num_blocks, num_classes=100):
		"""Build the network.

		Args:
			block: Residual block class. It is called as
				``block(in_planes, planes, stride)`` and must expose an
				``expansion`` class attribute.
			num_blocks: Sequence of four ints, the number of blocks per stage.
			num_classes: Size of the classifier output.
		"""
		super(ResNet, self).__init__()
		self.in_planes = 64

		self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
			stride=1, padding=1, bias=False)
		self.bn1 = nn.BatchNorm2d(64)
		# The stages use the caller-supplied block type, matching the
		# classifier width below, which is derived from block.expansion.
		self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
		self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
		self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
		self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
		self.linear = nn.Linear(512*block.expansion, num_classes)

		self.apply(_weights_init)

		self.layer_max = [None]*7
		self.badsign_pos_conv1 = 0
		self.badsign_neg_conv1 = 0
		self.pos_conv1 = 0
		self.neg_conv1 = 0

	def _make_layer(self, block, planes, num_blocks, stride):
		"""Return a ``nn.Sequential`` of ``num_blocks`` blocks forming one stage.

		Only the first block uses ``stride``; the remaining ones use stride 1.
		``self.in_planes`` is advanced to the stage's output channel count.
		"""
		layers = []
		for block_stride in [stride] + [1]*(num_blocks-1):
			layers.append(block(self.in_planes, planes, block_stride))
			self.in_planes = planes * block.expansion
		return nn.Sequential(*layers)

	def quantize(self, x):
		"""Run the quantized inference path and record activation statistics.

		Note that ``x.data`` is overwritten in place with
		``qt.scale(x.data, qt.qbits-6)``.

		Args:
			x: Input batch. The stem convolution expects 3 channels, and the
				4x4 average pool followed by the flatten presumably assumes
				32x32 inputs (the size used by ``test()``) -- TODO confirm.

		Returns:
			The classifier output after ``qt.trunc(..., qt.qbits-1)``.
		"""
		x.data = qt.scale(x.data, qt.qbits-6)

		self.layer_max[0] = update_max(self.layer_max[0], x)
		out = self.conv1(x)
		out.data = qt.trunc(out.data, qt.qbits-1)
		self.layer_max[1] = update_max(self.layer_max[1], out)
		out = self.bn1(out)

		out, badsign_pos_conv1, badsign_neg_conv1, pos_conv1, neg_conv1 = arelu(out, trunc_bits, enable_neg=True)
		out.data = qt.trunc(out.data, qt.qbits-1)
		out = self.layer1(out)
		self.layer_max[2] = update_max(self.layer_max[2], out)
		out = self.layer2(out)
		self.layer_max[3] = update_max(self.layer_max[3], out)
		out = self.layer3(out)
		self.layer_max[4] = update_max(self.layer_max[4], out)
		out = self.layer4(out)
		self.layer_max[5] = update_max(self.layer_max[5], out)
		out = F.avg_pool2d(out, 4)
		out = out.view(out.size(0), -1)
		out = self.linear(out)
		out.data = qt.trunc(out.data, qt.qbits-1)
		self.layer_max[6] = update_max(self.layer_max[6], out)

		# Accumulate the stem activation's sign statistics across calls.
		self.badsign_pos_conv1 += badsign_pos_conv1
		self.badsign_neg_conv1 += badsign_neg_conv1
		self.pos_conv1 += pos_conv1
		self.neg_conv1 += neg_conv1

		return out

	
def resnet18():
	"""Build a ``ResNet`` with two ``BasicBlockM`` blocks in each of its four stages."""
	blocks_per_stage = [2, 2, 2, 2]
	return ResNet(BasicBlockM, blocks_per_stage)


def resnet34():
	"""Build a ``ResNet`` with 3, 4, 6 and 3 ``BasicBlockM`` blocks in its four stages."""
	blocks_per_stage = [3, 4, 6, 3]
	return ResNet(BasicBlockM, blocks_per_stage)

def test(net=None):
	"""Smoke-test a network by pushing one random 32x32 RGB image through it.

	Args:
		net: Model to exercise. When omitted, a fresh ``resnet18()`` is built.
			The model is moved to CUDA when available, otherwise it stays on
			the CPU.

	The model is run through ``quantize()`` because ``ResNet`` defines no
	``forward()``. The size of the output is printed.
	"""
	if net is None:
		net = resnet18()
	# Keep model and input on the same device.
	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	net = net.to(device)
	y = net.quantize(torch.randn(1, 3, 32, 32, device=device))
	print(y.size())

if __name__ == "__main__":
	# Run the smoke test on every lower-case 'resnet*' factory listed in
	# __all__ (the 'ResNet' class itself does not match the prefix).
	for net_name in (name for name in __all__ if name.startswith('resnet')):
		print(net_name)
		factory = globals()[net_name]
		test(factory())
		print()

# test()
