import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.stats import truncnorm

class BasicBlock(nn.Module):
	"""Residual block built from two BN -> ReLU -> 3x3 conv stages.

	When ``in_planes`` differs from ``out_planes`` a 1x1 convolution (using the
	same ``stride`` as the first 3x3 conv) projects the shortcut. Otherwise the
	block input is added back untouched, so ``stride`` is expected to be 1 in
	that case for the shapes to match.

	Args:
		in_planes: number of input channels.
		out_planes: number of output channels.
		stride: stride of the first 3x3 conv and of the shortcut projection.
		dropRate: dropout probability applied before the second conv; values
			<= 0 disable dropout.
	"""

	def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
		super().__init__()
		# First stage: normalise and activate the input, then (possibly strided) conv.
		self.bn1 = nn.BatchNorm2d(in_planes)
		self.relu1 = nn.ReLU(inplace=True)
		self.conv1 = nn.Conv2d(
			in_planes, out_planes,
			kernel_size=3, stride=stride, padding=1, bias=False,
		)
		# Second stage: always stride 1, channel count unchanged.
		self.bn2 = nn.BatchNorm2d(out_planes)
		self.relu2 = nn.ReLU(inplace=True)
		self.conv2 = nn.Conv2d(
			out_planes, out_planes,
			kernel_size=3, stride=1, padding=1, bias=False,
		)
		self.droprate = dropRate
		self.equalInOut = in_planes == out_planes
		if self.equalInOut:
			# Identity shortcut: nothing to learn.
			self.convShortcut = None
		else:
			self.convShortcut = nn.Conv2d(
				in_planes, out_planes,
				kernel_size=1, stride=stride, padding=0, bias=False,
			)

	def forward(self, x):
		"""Return the shortcut plus the two-stage convolutional branch of ``x``."""
		activated = self.relu1(self.bn1(x))
		branch = self.relu2(self.bn2(self.conv1(activated)))
		if self.droprate > 0:
			branch = F.dropout(branch, p=self.droprate, training=self.training)
		branch = self.conv2(branch)

		if self.equalInOut:
			# Identity shortcut uses the raw, un-normalised input.
			shortcut = x
		else:
			# The projection shortcut consumes the activated input instead.
			shortcut = self.convShortcut(activated)
		return torch.add(shortcut, branch)

class BasicBlock_rotatedrelu_maam(nn.Module):
	"""Residual block whose ReLUs are scaled by learnable per-channel slopes.

	The layout mirrors a BN -> activation -> 3x3 conv block applied twice, but
	each activation is ``slope * relu(.)`` with one trainable slope per
	channel (``self.a`` for the first activation, ``self.c`` for the second).

	Slopes are initialised with random sign: a fair coin picks, per channel,
	either a positive slope in ``[tan(35 deg), tan(55 deg)]`` or a negative
	slope in ``[-tan(55 deg), -tan(35 deg)]``. Magnitudes come from a normal
	distribution with location 1 and scale sqrt(3), truncated to that range.

	When ``in_planes`` differs from ``out_planes`` a 1x1 convolution projects
	the shortcut; otherwise the block input is added back untouched, so
	``stride`` is expected to be 1 in that case for the shapes to match.

	Args:
		in_planes: number of input channels.
		out_planes: number of output channels.
		stride: stride of the first 3x3 conv and of the shortcut projection.
		dropRate: dropout probability applied before the second conv; values
			<= 0 disable dropout.
	"""

	def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
		super().__init__()
		self.bn1 = nn.BatchNorm2d(in_planes)

		# Both sign coins are drawn before any truncated-normal sample so the
		# sequence of NumPy RNG calls (and thus seeded results) is unchanged.
		aa = np.random.uniform(0.0, 1.0, in_planes)
		cc = np.random.uniform(0.0, 1.0, out_planes)
		self.a = nn.Parameter(self._sample_slopes(aa > 0.5), requires_grad=True)
		self.c = nn.Parameter(self._sample_slopes(cc > 0.5), requires_grad=True)

		self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
							   padding=1, bias=False)
		self.bn2 = nn.BatchNorm2d(out_planes)
		self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
							   padding=1, bias=False)
		self.droprate = dropRate
		self.equalInOut = (in_planes == out_planes)
		if self.equalInOut:
			self.convShortcut = None
		else:
			self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
										  padding=0, bias=False)

	@staticmethod
	def _sample_slopes(pick_positive):
		"""Sample one initial slope per entry of the boolean mask ``pick_positive``.

		Entries that are True get a slope from the positive band, the others
		from the mirrored negative band. Returns a 1-D float32 tensor.
		"""
		count = pick_positive.shape[0]
		scale = np.sqrt(3.0)
		# truncnorm takes its bounds in standardised units: (bound - loc) / scale.
		positive = truncnorm.rvs(
			(np.tan(35.0 / 180 * np.pi) - 1.0) / scale,
			(np.tan(55.0 / 180 * np.pi) - 1.0) / scale,
			size=count,
		) * scale + 1.0
		negative = truncnorm.rvs(
			(np.tan(-55.0 / 180 * np.pi) + 1.0) / scale,
			(np.tan(-35.0 / 180 * np.pi) + 1.0) / scale,
			size=count,
		) * scale - 1.0
		return torch.tensor(np.where(pick_positive, positive, negative)).float()

	def forward(self, x):
		"""Apply the block to a 4-D input ``x`` (batch, channels, height, width)."""
		# A (1, C, 1, 1) view broadcasts over batch and spatial dims, avoiding
		# a materialised N x C x H x W copy of the slopes.
		activated = torch.mul(self.a.view(1, -1, 1, 1), torch.relu(self.bn1(x)))

		# conv1/bn2 are evaluated exactly once per call: running them a second
		# time would double the cost and update bn2's running statistics twice
		# per training step.
		hidden = self.bn2(self.conv1(activated))
		out = torch.mul(self.c.view(1, -1, 1, 1), torch.relu(hidden))

		if self.droprate > 0:
			out = F.dropout(out, p=self.droprate, training=self.training)
		out = self.conv2(out)

		# Identity shortcut uses the raw input; the projection shortcut uses
		# the activated input.
		shortcut = x if self.equalInOut else self.convShortcut(activated)
		return torch.add(shortcut, out)


class NetworkBlock(nn.Module):
	"""A sequential stack of ``nb_layers`` residual blocks.

	Only the first block sees ``in_planes`` input channels and the requested
	``stride``; every following block maps ``out_planes`` to ``out_planes``
	with stride 1.

	Args:
		nb_layers: number of blocks to stack (converted with ``int``).
		in_planes: input channels of the first block.
		out_planes: output channels of every block.
		block: callable invoked as ``block(in_planes, out_planes, stride, dropRate)``
			that returns an ``nn.Module``.
		stride: stride of the first block.
		dropRate: dropout probability forwarded to every block.
	"""

	def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
		super().__init__()
		self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

	def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
		"""Build the ``nn.Sequential`` holding the stacked blocks."""
		layers = []
		for i in range(int(nb_layers)):
			is_first = i == 0
			# Conditional expressions forward the caller's values as given;
			# the `cond and a or b` idiom would silently replace a falsy
			# argument (e.g. stride=0) with the fallback.
			layers.append(block(
				in_planes if is_first else out_planes,
				out_planes,
				stride if is_first else 1,
				dropRate,
			))
		return nn.Sequential(*layers)

	def forward(self, x):
		"""Run ``x`` through every block in order."""
		return self.layer(x)

class WideResNet(nn.Module):
	"""Wide residual network for 3-channel image classification.

	The network is a 3x3 stem convolution, three groups of residual blocks
	(16, 32 and 64 base channels, each multiplied by ``widen_factor``; the
	second and third groups downsample with stride 2), a final BN + ReLU,
	global average pooling and a linear classifier.

	Args:
		depth: total depth; must satisfy ``(depth - 4) % 6 == 0``. Each of the
			three groups holds ``(depth - 4) // 6`` blocks.
		typer: activation flavour, ``"ReLU"`` for ``BasicBlock`` or ``"RReLU"``
			for ``BasicBlock_rotatedrelu_maam``.
		num_classes: size of the classifier output.
		widen_factor: channel multiplier for the three groups.
		dropRate: dropout probability forwarded to every block.

	Raises:
		AssertionError: if ``depth`` is not of the form ``6 * n + 4``.
		ValueError: if ``typer`` is not one of the supported names.
	"""

	def __init__(self, depth, typer, num_classes, widen_factor=1, dropRate=0.0):
		super().__init__()
		nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
		assert (depth - 4) % 6 == 0, "depth must be of the form 6*n + 4"
		# Floor division keeps the per-group block count an integer.
		n = (depth - 4) // 6
		if typer == "ReLU":
			block = BasicBlock
		elif typer == "RReLU":
			block = BasicBlock_rotatedrelu_maam
		else:
			# Fail here with a clear message rather than with an unbound
			# local variable further down.
			raise ValueError(
				"unknown typer {!r}; expected 'ReLU' or 'RReLU'".format(typer)
			)
		# 1st conv before any network block
		self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
							   padding=1, bias=False)
		# 1st block
		self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
		# 2nd block
		self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
		# 3rd block
		self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
		# global average pooling and classifier
		self.bn1 = nn.BatchNorm2d(nChannels[3])
		self.relu = nn.ReLU(inplace=True)
		self.fc = nn.Linear(nChannels[3], num_classes)
		self.nChannels = nChannels[3]

		# Re-initialise weights: Kaiming-normal for convolutions, unit scale and
		# zero shift for batch norm, zero bias for the classifier (its weight
		# keeps the default nn.Linear initialisation).
		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
			elif isinstance(m, nn.BatchNorm2d):
				m.weight.data.fill_(1)
				m.bias.data.zero_()
			elif isinstance(m, nn.Linear):
				m.bias.data.zero_()

	def forward(self, x):
		"""Return class logits of shape (batch, num_classes) for image batch ``x``."""
		out = self.conv1(x)
		out = self.block1(out)
		out = self.block2(out)
		out = self.block3(out)
		out = self.relu(self.bn1(out))
		# Pool over whatever spatial size is left. For 32x32 inputs this equals
		# the former fixed 8x8 average pool; for other sizes it keeps one row
		# per sample instead of folding spatial positions into the batch dim.
		out = F.adaptive_avg_pool2d(out, 1)
		out = out.view(-1, self.nChannels)
		return self.fc(out)
