import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

import math
from torch.autograd import Variable
import utils_quant as qt
import numpy as np


# Modulus for the modular (field) arithmetic in arelu(): 2**31 - 1, i.e. 31 bits.
p=2147483647 #31 bit
# Value handed to arelu() as its `tr` argument by the network code below;
# arelu() floor-divides by 2**tr before comparing, so 0 means no truncation.
trunc_bits = 0


def update_max(current_max, x):
	"""Return the running maximum of the absolute values seen so far.

	Args:
		current_max: previous running maximum, or None when nothing has been
			recorded yet.
		x: tensor whose largest-magnitude element is folded into the maximum.

	Returns:
		``current_max`` when it is already at least as large as the largest
		magnitude in ``x``; otherwise that magnitude (a 0-dim tensor).
	"""
	highest = x.data.max()
	lowest = x.data.min()
	candidate = max(abs(highest), abs(lowest))
	# Keep the stored value only when one exists and the candidate does not exceed it.
	if current_max is not None and not (current_max < candidate):
		return current_max
	return candidate


# Public API of this module. Only names that are actually defined in this file
# are listed: an undefined entry makes `from <module> import *` raise
# AttributeError and breaks the smoke-test loop in the `__main__` section.
__all__ = ['ResNet', 'resnet20', 'resnet32']

def _weights_init(m):
	"""Kaiming-normal initialise the weight of ``nn.Linear`` / ``nn.Conv2d`` modules.

	Meant to be passed to ``nn.Module.apply``; modules of any other type are
	left untouched.
	"""
	if not isinstance(m, (nn.Linear, nn.Conv2d)):
		return
	init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
	"""Module wrapper that lets an arbitrary callable sit inside an ``nn.Module`` tree."""

	def __init__(self, lambd):
		"""Store ``lambd``, the callable applied to the input on every forward pass."""
		super().__init__()
		self.lambd = lambd

	def forward(self, x):
		"""Return ``self.lambd(x)``."""
		fn = self.lambd
		return fn(x)

# Stochastic ReLU used in Circa
def arelu(x, tr, enable_neg):
	"""Stochastic ReLU: estimate the sign of ``x`` from a randomly masked copy.

	``x`` is cast to integers and mapped into ``[0, p)`` (negative values
	become ``p + x``). A random mask ``rs`` is added modulo ``p``, and the sign
	is then guessed by comparing the mask with the masked value after both
	have been floor-divided by ``2**tr``. The guess can be wrong, so the
	function also reports how often it disagreed with the true sign.

	Works on whatever device ``x`` lives on (CPU or any CUDA device).

	Args:
		x: tensor of activations.
		tr: number of low-order bits dropped (floor-division by ``2**tr``)
			before the comparison.
		enable_neg: tie-break used when the truncated mask equals the truncated
			masked value: truthy -> treated as negative, falsy -> treated as
			positive.

	Returns:
		Tuple ``(relu, badsign_pos, badsign_neg, pos, neg)``:
		``relu`` is ``x`` multiplied by the estimated 0/1 sign;
		``badsign_pos`` counts non-zero elements with true sign positive but
		estimated negative; ``badsign_neg`` counts the opposite mistake;
		``pos`` / ``neg`` count elements with ``x >= 0`` / ``x < 0``.
		All counts are 0-dim tensors.
	"""
	device = x.device
	xl = x.long()

	# The mask is drawn from the CPU generator and then moved, exactly like the
	# original `.to("cuda")` code path, so seeded runs keep the same random stream.
	# NOTE(review): torch.rand uses the default floating dtype, so the mask
	# presumably carries fewer than 31 random bits - confirm this is acceptable.
	rs = torch.floor(torch.rand(xl.shape) * p).long().to(device)

	# Map into the field [0, p) and add the mask modulo p.
	xf = torch.where(xl < 0, xl + p, xl)
	xr = xf + rs
	x0 = torch.where(xr < p, xr, xr - p)

	# Drop the `tr` low-order bits of both the masked value and the mask.
	t = int(math.pow(2, tr))
	x0t = x0 // t
	rst = rs // t

	ones = torch.ones(x.shape, device=device)
	zeros = torch.zeros(x.shape, device=device)

	# In `sign`, 0 means negative and 1 means positive.
	if enable_neg:
		# A tie (rst == x0t) is assigned negative.
		sign = torch.where(rst < x0t, ones, zeros)
	else:
		# A tie (rst == x0t) is assigned positive.
		sign = torch.where(rst > x0t, zeros, ones)

	relu = x * sign

	# Compare against the exact sign; zeros are excluded from the error counts.
	truesign = torch.where(x < 0, zeros, ones)
	nonzero = x != 0
	badsign_pos = ((truesign == 1) & (sign == 0) & nonzero).sum()
	badsign_neg = ((truesign == 0) & (sign == 1) & nonzero).sum()

	pos = torch.sum(x >= 0)
	neg = torch.sum(x < 0)

	return relu, badsign_pos, badsign_neg, pos, neg

# Model/realization of Stochastic ReLU for the fault model validation
def srelu(x, tr):
	"""Fault-model version of the stochastic ReLU.

	``x`` is multiplied by the fixed factor ``2**(2*14-8)``. Every scaled value
	``v`` with ``0 < v < 2**tr`` (the "truncation range") is zeroed with
	probability ``(2**tr - v) / 2**tr``; an ordinary ReLU is then applied and
	the result is scaled back. ``x`` itself is not modified.

	Works on whatever device ``x`` lives on (CPU or any CUDA device).

	Args:
		x: tensor of activations.
		tr: exponent of the truncation range; scaled values in ``(0, 2**tr)``
			are the ones that may be zeroed.

	Returns:
		Tuple ``(relu, badsign_pos_sc, pos_sc)``: the activation output, the
		number of truncation-range values that were zeroed, and the number of
		elements of ``x`` that are ``>= 0`` (the counts are 0-dim tensors).
	"""
	device = x.device
	scale = math.pow(2, 2*14-8)
	threshold = 2**tr

	xl = torch.mul(x, scale)

	## ----- For PosZero case -----------##

	# Select all the activations in the truncation range. The mask is computed
	# once and reused for the write-back; xl is not changed in between.
	in_range = (xl > 0) & (xl < threshold)
	pos_trunc_range = xl[in_range]

	# Error probability for each activation in the truncation range.
	error_prob = torch.div(torch.add(torch.mul(pos_trunc_range, -1), threshold), threshold)

	# Uniform samples in [0, 1), drawn from the CPU generator and then moved so
	# that seeded runs keep the random stream of the original CUDA-only code.
	prob = torch.rand(torch.numel(pos_trunc_range)).to(device)

	# Zero an activation when its sample falls below its error probability.
	pos_trunc_range[prob < error_prob] = 0

	# Write the (possibly zeroed) values back into the scaled tensor.
	xl[in_range] = pos_trunc_range

	badsign_pos_sc = torch.sum(pos_trunc_range == 0)
	pos_sc = torch.sum(x >= 0)

	# Plain ReLU on the scaled values, then undo the scaling.
	relu = torch.where(xl >= 0, xl, torch.zeros(xl.shape, device=device))
	relu = torch.div(relu, scale)

	return relu, badsign_pos_sc, pos_sc


class BasicBlockM(nn.Module):
	"""Residual block (conv-bn-act, conv-bn, add shortcut, act) using ``arelu``.

	Each convolution output and each activation output is passed through
	``qt.trunc`` and the block input through ``qt.scale`` (helpers from
	``utils_quant``, not visible in this file). The sign-error statistics
	returned by ``arelu`` are accumulated on the instance: attributes ending
	in ``1`` belong to the first activation, those ending in ``2`` to the
	second.
	"""
	expansion = 1

	def __init__(self, in_planes, planes, stride=1, option='A'):
		"""Build the two 3x3 convolutions, their batch norms and the shortcut.

		Args:
			in_planes: number of input channels.
			planes: number of output channels.
			stride: stride of the first convolution.
			option: shortcut type used when the stride is not 1 or the channel
				count changes: 'A' subsamples spatially by 2 and zero-pads
				``planes//4`` channels on each side, 'B' uses a 1x1 convolution
				followed by batch norm. Otherwise the shortcut is the identity.
		"""
		super().__init__()
		conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
		self.conv1 = nn.Conv2d(in_planes, planes, stride=stride, **conv_kwargs)
		self.bn1 = nn.BatchNorm2d(planes)
		self.conv2 = nn.Conv2d(planes, planes, stride=1, **conv_kwargs)
		self.bn2 = nn.BatchNorm2d(planes)

		# Identity unless the block changes the spatial size or channel count.
		shortcut = nn.Sequential()
		shape_changes = not (stride == 1 and in_planes == planes)
		if shape_changes and option == 'A':
			# For CIFAR10 the ResNet paper uses option A.
			pad = planes // 4
			shortcut = LambdaLayer(
				lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, pad, pad), "constant", 0)
			)
		elif shape_changes and option == 'B':
			out_planes = self.expansion * planes
			shortcut = nn.Sequential(
				nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
				nn.BatchNorm2d(out_planes),
			)
		self.shortcut = shortcut

		# Running totals of the statistics returned by arelu().
		self.badsign_pos1 = 0
		self.badsign_pos2 = 0
		self.badsign_neg1 = 0
		self.badsign_neg2 = 0
		self.pos1 = 0
		self.pos2 = 0
		self.neg1 = 0
		self.neg2 = 0

	def forward(self, x):
		"""Run the block on ``x`` and update the sign-error counters.

		Note that ``x.data`` is overwritten in place with
		``qt.scale(x.data, qt.qbits-2)`` before the shortcut is applied.
		"""
		# First conv -> bn -> stochastic ReLU.
		y = self.conv1(x)
		y.data = qt.trunc(y.data, qt.qbits-2)
		y = self.bn1(y)
		y, bad_pos_a, bad_neg_a, n_pos_a, n_neg_a = arelu(y, trunc_bits, enable_neg=False)
		y.data = qt.trunc(y.data, qt.qbits-2)

		# Second conv -> bn.
		y = self.conv2(y)
		y.data = qt.trunc(y.data, qt.qbits-2)
		y = self.bn2(y)

		# Residual connection on the rescaled input, then the second activation.
		x.data = qt.scale(x.data, qt.qbits-2)
		y += self.shortcut(x)
		y, bad_pos_b, bad_neg_b, n_pos_b, n_neg_b = arelu(y, trunc_bits, enable_neg=False)
		y.data = qt.trunc(y.data, qt.qbits-2)

		# Accumulate statistics only after the whole block has run.
		self.badsign_pos1 += bad_pos_a
		self.badsign_pos2 += bad_pos_b
		self.badsign_neg1 += bad_neg_a
		self.badsign_neg2 += bad_neg_b
		self.pos1 += n_pos_a
		self.pos2 += n_pos_b
		self.neg1 += n_neg_a
		self.neg2 += n_neg_b

		return y


class ResNet(nn.Module):
	"""Three-stage ResNet (16/32/64 channels) with quantized inference.

	The quantized pass is implemented by :meth:`quantize`.
	NOTE(review): no ``forward()`` is defined in this class, so callers
	presumably invoke ``quantize()`` directly - confirm against the scripts
	that use this model.

	Attributes set by the constructor:
		layer_max: six running max-abs values (see ``update_max``), one per
			tap point in :meth:`quantize`.
		badsign_pos_conv1, badsign_neg_conv1, pos_conv1, neg_conv1: running
			totals of the statistics returned by ``arelu`` for the activation
			that follows the first convolution.
	"""

	def __init__(self, block, num_blocks, num_classes=100):
		"""Build the network.

		Args:
			block: residual block class used for every stage. It must expose
				an ``expansion`` attribute and accept
				``(in_planes, planes, stride)``.
			num_blocks: sequence of three ints, the number of blocks per stage.
			num_classes: size of the final linear layer's output.
		"""
		super(ResNet, self).__init__()
		self.in_planes = 16

		self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
		self.bn1 = nn.BatchNorm2d(16)

		# Use the block class that was passed in rather than a hard-coded one,
		# so the `block` argument is actually honoured.
		self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
		self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
		self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
		self.linear = nn.Linear(64, num_classes)

		self.apply(_weights_init)

		self.layer_max = [None]*6
		self.badsign_pos_conv1 = 0
		self.badsign_neg_conv1 = 0
		self.pos_conv1 = 0
		self.neg_conv1 = 0

	def _make_layer(self, block, planes, num_blocks, stride):
		"""Stack ``num_blocks`` blocks; only the first one uses ``stride``.

		Updates ``self.in_planes`` so the next block/stage is wired to the
		channel count produced here.
		"""
		strides = [stride] + [1]*(num_blocks-1)
		layers = []

		for block_stride in strides:
			layers.append(block(self.in_planes, planes, block_stride))
			self.in_planes = planes * block.expansion

		return nn.Sequential(*layers)

	def quantize(self, x):
		"""Run the quantized pass on ``x`` and return the linear layer's output.

		Side effects: ``x.data`` is overwritten in place with
		``qt.scale(x.data, qt.qbits-6)``; ``self.layer_max`` and the
		``*_conv1`` counters are updated. ``qt.scale`` / ``qt.trunc`` come from
		``utils_quant`` and are not visible in this file.
		"""
		x.data = qt.scale(x.data, qt.qbits-6)

		# Tap 0: scaled network input.
		self.layer_max[0] = update_max(self.layer_max[0], x)

		out = self.conv1(x)
		out.data = qt.trunc(out.data, qt.qbits-2)

		# Tap 1: first convolution output, recorded before batch norm.
		self.layer_max[1] = update_max(self.layer_max[1], out)
		out = self.bn1(out)

		out, badsign_pos_conv1, badsign_neg_conv1, pos_conv1, neg_conv1 = arelu(out, trunc_bits, enable_neg=False)
		out.data = qt.trunc(out.data, qt.qbits-2)

		# Taps 2-4: output of each residual stage.
		out = self.layer1(out)
		self.layer_max[2] = update_max(self.layer_max[2], out)
		out = self.layer2(out)
		self.layer_max[3] = update_max(self.layer_max[3], out)
		out = self.layer3(out)
		self.layer_max[4] = update_max(self.layer_max[4], out)

		# Average-pool with a window equal to the feature-map width, then classify.
		out = F.avg_pool2d(out, out.size()[3])
		out = out.view(out.size(0), -1)
		out = self.linear(out)
		out.data = qt.trunc(out.data, qt.qbits-2)

		# Tap 5: truncated classifier output.
		self.layer_max[5] = update_max(self.layer_max[5], out)

		self.badsign_pos_conv1 += badsign_pos_conv1
		self.badsign_neg_conv1 += badsign_neg_conv1
		self.pos_conv1 += pos_conv1
		self.neg_conv1 += neg_conv1
		return out

	
def resnet20():
	"""Return a ``ResNet`` built from ``BasicBlockM`` with 3 blocks in each of its 3 stages."""
	blocks_per_stage = [3, 3, 3]
	return ResNet(BasicBlockM, blocks_per_stage)


def resnet32():
	"""Return a ``ResNet`` built from ``BasicBlockM`` with 5 blocks in each of its 3 stages."""
	blocks_per_stage = [5, 5, 5]
	return ResNet(BasicBlockM, blocks_per_stage)


def test(net):
	"""Print the trainable parameter count and "layer" count of ``net``.

	Only parameters with ``requires_grad`` set are considered. A "layer" is a
	trainable parameter tensor with more than one dimension.

	Element counts come from ``numel()``, so nothing is converted to NumPy and
	the function also works when the parameters live on a CUDA device.
	"""
	trainable = [param for param in net.parameters() if param.requires_grad]
	total_params = sum(param.numel() for param in trainable)
	multi_dim = [param for param in trainable if len(param.data.size()) > 1]

	print("Total number of params", total_params)
	print("Total layers", len(multi_dim))


if __name__ == "__main__":
	# Smoke test: build every `resnet*` factory advertised in __all__ and print
	# its parameter statistics.
	for net_name in __all__:
		if not net_name.startswith('resnet'):
			continue
		factory = globals().get(net_name)
		if factory is None:
			# A name can be advertised without being defined in this file;
			# report it instead of crashing with KeyError.
			print(net_name, "is not defined in this module; skipped")
			print()
			continue
		print(net_name)
		test(factory())
		print()
