import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet

def standardize(x, bn_stats):
	"""Standardize ``x`` channel-wise with precomputed statistics.

	Args:
		x: Tensor whose dimension 1 is the channel axis (needs at least 2 dims).
		bn_stats: ``(mean, var)`` pair of per-channel tensors, or ``None``.

	Returns:
		``x`` itself when ``bn_stats`` is ``None``. Otherwise a new tensor
		``(x - mean) / sqrt(var + 1e-5)``, with every channel whose variance is
		exactly zero set to 0. ``x`` is never modified.
	"""
	if bn_stats is None:
		return x

	mean, var = bn_stats

	# Shape (1, C, 1, ..., 1) so the per-channel stats broadcast along dim 1.
	bshape = [1] * x.dim()
	bshape[1] = -1
	mean_b = mean.view(bshape)
	var_b = var.view(bshape)

	out = (x - mean_b) / torch.sqrt(var_b + 1e-5)

	# Degenerate (zero-variance) channels carry no information: zero them.
	keep = (var_b != 0).float()
	out *= keep
	return out


def clip_data(data, max_norm):
	"""Rescale each sample so that its L2 norm is at most ``max_norm``.

	Dimension 0 is treated as the batch dimension. All remaining dimensions
	are flattened when computing each sample's norm. Samples whose norm is
	already ``<= max_norm`` get a scale factor of exactly 1. For
	``max_norm > 0`` this includes all-zero samples, because
	``max_norm / 0 = inf`` is clamped to 1.

	Note: ``data`` is modified **in place** and is also returned.

	Args:
		data: Tensor of shape ``(batch, ...)`` with any number of trailing dims.
		max_norm: Maximum allowed per-sample L2 norm.

	Returns:
		The same tensor object ``data`` after in-place rescaling.
	"""
	norms = torch.norm(data.reshape(data.shape[0], -1), dim=-1)
	# Per-sample factor in (0, 1]; only samples above the bound shrink.
	scale = (max_norm / norms).clamp(max=1.0)
	# One factor per sample, broadcast over every trailing dim of any rank
	# (the previous hard-coded (-1, 1, 1, 1) only worked for 4-D inputs).
	data *= scale.reshape((-1,) + (1,) * (data.dim() - 1))
	return data


def get_num_params(model):
	"""Return the number of scalar parameters in ``model`` that require gradients."""
	total = 0
	for param in model.parameters():
		# Frozen parameters are excluded from the count.
		if param.requires_grad:
			total += param.numel()
	return total


class StandardizeLayer(nn.Module):
	"""``nn.Module`` wrapper that applies ``standardize`` with fixed statistics.

	Args:
		bn_stats: ``(mean, var)`` pair passed to ``standardize``, or ``None``
			for a pass-through. It is stored as a plain attribute rather than
			a registered buffer, so it is not part of ``state_dict``. If it
			holds tensors, ``.to(device)`` does not move them.
	"""

	def __init__(self, bn_stats):
		super().__init__()
		self.bn_stats = bn_stats

	def forward(self, x):
		stats = self.bn_stats
		return standardize(x, stats)


class ClipLayer(nn.Module):
	"""``nn.Module`` wrapper that applies ``clip_data`` with a fixed bound.

	Args:
		max_norm: Per-sample L2 norm bound handed to ``clip_data``.
	"""

	def __init__(self, max_norm):
		super().__init__()
		self.max_norm = max_norm

	def forward(self, x):
		bound = self.max_norm
		return clip_data(x, bound)


class CIFAR10_CNN(nn.Module):
	"""VGG-style CNN with six 3x3 conv layers and two linear layers, 10 outputs.

	Layout:
		[conv(3->32), conv(32->32), maxpool] ->
		[conv(32->64), conv(64->64), maxpool] ->
		[conv(64->128), conv(128->128), maxpool] ->
		Linear(128*4*4, 128) -> Linear(128, 10)

	Every conv and the hidden linear layer is followed by the activation
	selected by ``act_type``. The convs keep spatial size (padding=1) and each
	pool halves it. The ``128 * 4 * 4`` flatten size is therefore consistent
	with 32x32 inputs (32 -> 16 -> 8 -> 4).

	Args:
		in_channels: Stored on the instance. When it is not 3, ``forward``
			first reshapes the input to ``(-1, in_channels, 8, 8)``. However,
			``conv1`` is hard-wired to 3 input channels and the classifier
			expects 4x4 maps, so only ``in_channels=3`` works with these layers.
		act_type: ``'ReLU'`` or ``'Tanh'``.
		input_norm: Forwarded to :meth:`build`; ignored.
		**kwargs: Forwarded to :meth:`build` (``num_groups``, ``bn_stats``,
			``size``); ignored.

	Raises:
		ValueError: If ``act_type`` is neither ``'ReLU'`` nor ``'Tanh'``.
	"""

	def __init__(self, in_channels=3, act_type=None, input_norm=None,  **kwargs):
		super(CIFAR10_CNN, self).__init__()
		self.in_channels = in_channels
		# Not used by this implementation; kept as attributes for compatibility.
		self.features = None
		self.classifier = None
		self.norm = None
		self.act_type = act_type
		self.build(input_norm, **kwargs)

	def build(self, input_norm=None, num_groups=None,
			  bn_stats=None, size=None):
		"""Create all layers.

		``input_norm``, ``num_groups``, ``bn_stats`` and ``size`` are accepted
		for signature compatibility but ignored; ``self.norm`` is always
		``nn.Identity``.

		Raises:
			ValueError: If ``self.act_type`` is neither ``'ReLU'`` nor ``'Tanh'``.
		"""
		self.norm = nn.Identity()

		# Fail at construction time. Without this, ``self.act`` would never be
		# set and ``forward`` would die later with an opaque AttributeError.
		if self.act_type == 'ReLU':
			self.act = nn.ReLU()
		elif self.act_type == 'Tanh':
			self.act = nn.Tanh()
		else:
			raise ValueError(
				"Unsupported act_type {!r}; expected 'ReLU' or 'Tanh'".format(self.act_type))

		# One pooling module, reused after every conv pair; halves H and W.
		self.MP = nn.MaxPool2d(kernel_size=2, stride=2)

		self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
		self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)

		self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
		self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

		self.conv5 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
		self.conv6 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

		# 128 channels at 4x4 after three poolings, flattened.
		self.Emb = nn.Linear(128 * 4 * 4, 128)

		self.Logits = nn.Linear(128, 10)

	def forward(self, x):
		"""Run the network.

		Returns:
			A tuple ``(logits, activations)``:

			- ``logits`` has shape ``(batch, 10)``.
			- ``activations`` is the list of *pre-activation* outputs
			  ``[conv1, conv2, conv3, conv4, conv5, conv6, Emb]``.
		"""
		if self.in_channels != 3:
			x = self.norm(x.view(-1, self.in_channels, 8, 8))

		conv1 = self.conv1(x)
		conv2 = self.conv2(self.act(conv1))
		pool1 = self.MP(self.act(conv2))

		conv3 = self.conv3(pool1)
		conv4 = self.conv4(self.act(conv3))
		pool2 = self.MP(self.act(conv4))

		conv5 = self.conv5(pool2)
		conv6 = self.conv6(self.act(conv5))
		pool3 = self.MP(self.act(conv6))

		flat = pool3.view(pool3.size(0), -1)

		emb = self.Emb(flat)
		logits = self.Logits(self.act(emb))

		return logits, [conv1, conv2, conv3, conv4, conv5, conv6, emb]

# class CIFAR10_CNN(nn.Module):
# 	def __init__(self, in_channels=3, act_type=None, input_norm=None,  **kwargs):
# 		super(CIFAR10_CNN, self).__init__()
# 		self.in_channels = in_channels
# 		self.features = None
# 		self.classifier = None
# 		self.norm = None
# 		self.act_type=act_type
# 		self.build(input_norm, **kwargs)

# 	def build(self, input_norm=None, num_groups=None,
# 			  bn_stats=None, size=None):

# 		if self.in_channels == 3:
# 			if size == "small":
# 				cfg = [16, 16, 'M', 32, 32, 'M', 64, 'M']
# 			else:
# 				cfg = [32, 32, 'M', 64, 64, 'M', 128, 128, 'M']

# 			self.norm = nn.Identity()
# 		else:
# 			if size == "small":
# 				cfg = [16, 16, 'M', 32, 32]
# 			else:
# 				cfg = [64, 'M', 64]
# 			if input_norm is None:
# 				self.norm = nn.Identity()
# 			elif input_norm == "GroupNorm":
# 				self.norm = nn.GroupNorm(num_groups, self.in_channels, affine=False)
# 			else:
# 				self.norm = lambda x: standardize(x, bn_stats)

# 		layers = []
# 		if self.act_type == 'ReLU':
# 			act = nn.ReLU
# 		elif self.act_type == 'Tanh':
# 			act = nn.Tanh

# 		c = self.in_channels
# 		for v in cfg:
# 			if v == 'M':
# 				layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
# 			else:
# 				conv2d = nn.Conv2d(c, v, kernel_size=3, stride=1, padding=1)

# 				layers += [conv2d, act()]
# 				c = v

# 		self.features = nn.Sequential(*layers)

# 		if self.in_channels == 3:
# 			hidden = 128
# 			self.classifier = nn.Sequential(nn.Linear(c * 4 * 4, hidden), act(), nn.Linear(hidden, 10))
# 		else:
# 			self.classifier = nn.Linear(c * 4 * 4, 10)
# 	def forward(self, x):
# 		if self.in_channels != 3:
# 			x = self.norm(x.view(-1, self.in_channels, 8, 8))
# 		x_f = self.features(x)
# 		x = x_f.view(x_f.size(0), -1)
# 		x = self.classifier(x)
# 		return x,x_f




class MNIST_CNN(nn.Module):
	"""Small two-conv CNN with a 32-unit hidden layer and 10 outputs.

	Layout:
		Conv(in_channels->16, k=8, s=2, p=2) -> act -> MaxPool(k=2, s=1) ->
		Conv(16->32, k=4, s=2) -> act -> MaxPool(k=2, s=1) ->
		Linear(32*4*4, 32) -> act -> Linear(32, 10)

	The ``32 * 4 * 4`` flatten size is consistent with 28x28 inputs
	(28 -> 13 -> 12 -> 5 -> 4).

	Args:
		in_channels: Number of input channels of the first convolution
			(default 1).
		act_type: ``'ReLU'`` or ``'Tanh'``.
		input_norm: Forwarded to :meth:`build`; ignored.
		**kwargs: Forwarded to :meth:`build` (``num_groups``, ``bn_stats``,
			``size``); ignored.

	Raises:
		ValueError: If ``act_type`` is neither ``'ReLU'`` nor ``'Tanh'``.
	"""

	def __init__(self, in_channels=1, act_type=None, input_norm=None,  **kwargs):
		super(MNIST_CNN, self).__init__()
		self.in_channels = in_channels
		# Not used by this implementation; kept as attributes for compatibility.
		self.features = None
		self.classifier = None
		self.norm = None
		self.act_type = act_type

		self.build(input_norm, **kwargs)

	def build(self, input_norm=None, num_groups=None,
			  bn_stats=None, size=None):
		"""Create all layers.

		``input_norm``, ``num_groups``, ``bn_stats`` and ``size`` are accepted
		for signature compatibility but ignored. ``self.norm`` is set to
		``nn.Identity`` and is not applied in ``forward``.

		Raises:
			ValueError: If ``self.act_type`` is neither ``'ReLU'`` nor ``'Tanh'``.
		"""
		ch1, ch2 = 16, 32
		hidden = 32
		self.norm = nn.Identity()

		# Fail at construction time. Without this, ``self.act`` would never be
		# set and ``forward`` would die later with an opaque AttributeError.
		if self.act_type == 'ReLU':
			self.act = nn.ReLU()
		elif self.act_type == 'Tanh':
			self.act = nn.Tanh()
		else:
			raise ValueError(
				"Unsupported act_type {!r}; expected 'ReLU' or 'Tanh'".format(self.act_type))

		# Honor ``in_channels``. It was previously hard-wired to 1, which is
		# still the default.
		self.Conv1 = nn.Conv2d(self.in_channels, ch1, kernel_size=8, stride=2, padding=2)
		self.Conv2 = nn.Conv2d(ch1, ch2, kernel_size=4, stride=2, padding=0)
		# stride=1: each pooling shrinks H and W by just 1 pixel.
		self.MP = nn.MaxPool2d(kernel_size=2, stride=1)

		self.Emb = nn.Linear(ch2 * 4 * 4, hidden)
		self.Logits = nn.Linear(hidden, 10)

	def forward(self, x):
		"""Run the network.

		Returns:
			A tuple ``(logits, activations)``:

			- ``logits`` has shape ``(batch, 10)``.
			- ``activations`` is the list of *pre-activation* outputs
			  ``[conv1, conv2, Emb]``.
		"""
		conv1 = self.Conv1(x)
		pool1 = self.MP(self.act(conv1))

		conv2 = self.Conv2(pool1)
		pool2 = self.MP(self.act(conv2))

		flat = pool2.view(pool2.size(0), -1)
		emb = self.Emb(flat)
		logits = self.Logits(self.act(emb))
		return logits, [conv1, conv2, emb]
# class MNIST_CNN(nn.Module):
# 	def __init__(self, in_channels=1, act_type=None, input_norm=None,  **kwargs):
# 		super(MNIST_CNN, self).__init__()
# 		self.in_channels = in_channels
# 		self.features = None
# 		self.classifier = None
# 		self.norm = None
# 		self.act_type=act_type

# 		self.build(input_norm, **kwargs)

# 	def build(self, input_norm=None, num_groups=None,
# 			  bn_stats=None, size=None):
# 		if self.in_channels == 1:
# 			ch1, ch2 = (16, 32) if size is None else (32, 64)
# 			cfg = [(ch1, 8, 2, 2), 'M', (ch2, 4, 2, 0), 'M']
# 			self.norm = nn.Identity()
# 		else:
# 			ch1, ch2 = (16, 32) if size is None else (32, 64)
# 			cfg = [(ch1, 3, 2, 1), (ch2, 3, 1, 1)]
# 			if input_norm == "GroupNorm":
# 				self.norm = nn.GroupNorm(num_groups, self.in_channels, affine=False)
# 			elif input_norm == "BN":
# 				self.norm = lambda x: standardize(x, bn_stats)
# 			else:
# 				self.norm = nn.Identity()

# 		layers = []

# 		if self.act_type == 'ReLU':
# 			act = nn.ReLU
# 		elif self.act_type == 'Tanh':
# 			act = nn.Tanh
# 		c = self.in_channels
# 		for v in cfg:
# 			if v == 'M':
# 				layers += [nn.MaxPool2d(kernel_size=2, stride=1)]
# 			else:
# 				filters, k_size, stride, pad = v
# 				conv2d = nn.Conv2d(c, filters, kernel_size=k_size, stride=stride, padding=pad)

# 				layers += [conv2d, act()]
# 				c = filters

# 		self.features = nn.Sequential(*layers)

# 		hidden = 32
# 		self.classifier = nn.Sequential(nn.Linear(c * 4 * 4, hidden),
# 										act(),
# 										nn.Linear(hidden, 10))

# 	def forward(self, x):
# 		if self.in_channels != 1:
# 			x = self.norm(x.view(-1, self.in_channels, 7, 7))
# 		x_f = self.features(x)
# 		x = x_f.view(x_f.size(0), -1)
# 		x = self.classifier(x)
# 		return x, x_f


class ScatterLinear(nn.Module):
	"""Linear classifier over optionally normalized and clipped feature maps.

	The input is reshaped to ``(-1, in_channels, h, w)``. It is then passed
	through an input normalization and an optional per-sample norm clip,
	flattened, and fed to a single ``nn.Linear`` layer.

	Args:
		in_channels: Number of feature channels ``K``.
		hw_dims: ``(h, w)`` spatial size of each feature map.
		input_norm: Selects the input normalization:
			- ``None``: no normalization.
			- ``"GroupNorm"``: non-affine ``nn.GroupNorm(num_groups, K)``.
			- any other value: per-channel standardization with the fixed
			  ``bn_stats``.
		classes: Number of output logits.
		clip_norm: If not ``None``, each sample is rescaled to an L2 norm of at
			most ``clip_norm`` after normalization.
		**kwargs: Forwarded to :meth:`build` (``num_groups``, ``bn_stats``).
	"""

	def __init__(self, in_channels, hw_dims, input_norm=None, classes=10, clip_norm=None, **kwargs):
		super(ScatterLinear, self).__init__()
		self.K = in_channels
		self.h = hw_dims[0]
		self.w = hw_dims[1]
		self.fc = None
		self.norm = None
		self.clip = None
		self.build(input_norm, classes=classes, clip_norm=clip_norm, **kwargs)

	def build(self, input_norm=None, num_groups=None, bn_stats=None, clip_norm=None, classes=10):
		"""Create the linear layer and the normalization / clipping stages."""
		self.fc = nn.Linear(self.K * self.h * self.w, classes)

		if input_norm is None:
			self.norm = nn.Identity()
		elif input_norm == "GroupNorm":
			self.norm = nn.GroupNorm(num_groups, self.K, affine=False)
		else:
			# A module instead of a ``lambda`` closure. It applies the same
			# ``standardize(x, bn_stats)`` call, but keeps the model picklable
			# (e.g. ``torch.save(model)``). It adds no parameters or buffers,
			# so ``state_dict`` is unchanged.
			self.norm = StandardizeLayer(bn_stats)

		if clip_norm is None:
			self.clip = nn.Identity()
		else:
			self.clip = ClipLayer(clip_norm)

	def forward(self, x):
		"""Return logits of shape ``(batch, classes)``."""
		# ``reshape`` instead of ``view`` so non-contiguous inputs also work.
		# NOTE(review): with no normalization this may alias the caller's
		# tensor, and clipping then rescales it in place. Confirm whether
		# callers rely on or mind that.
		x = self.norm(x.reshape(-1, self.K, self.h, self.w))
		x = self.clip(x)
		x = x.reshape(x.size(0), -1)
		x = self.fc(x)
		return x


class German_FC(nn.Module):
	"""Fully connected network mapping 61 input features to 2 logits.

	Layout:
		features:   Linear(61, 124) -> act
		classifier: Linear(124, 124) -> act -> Linear(124, 2)

	Args:
		in_channels: Stored on the instance. Only the default value 1 is usable:
			for any other value ``forward`` calls ``self.norm``, which is never
			set here (it stays ``None``).
		act_type: ``'ReLU'`` or ``'Tanh'``.
		input_norm: Forwarded to :meth:`build`; ignored.
		**kwargs: Forwarded to :meth:`build` (``num_groups``, ``bn_stats``,
			``size``); ignored.

	Raises:
		ValueError: If ``act_type`` is neither ``'ReLU'`` nor ``'Tanh'``.
	"""

	def __init__(self, in_channels=1, act_type=None, input_norm=None,  **kwargs):
		super(German_FC, self).__init__()
		self.in_channels = in_channels
		self.features = None
		self.classifier = None
		self.norm = None
		self.act_type = act_type

		self.build(input_norm, **kwargs)

	def build(self, input_norm=None, num_groups=None,
			  bn_stats=None, size=None):
		"""Create ``self.features`` and ``self.classifier``.

		``input_norm``, ``num_groups``, ``bn_stats`` and ``size`` are accepted
		for signature compatibility but ignored.

		Raises:
			ValueError: If ``self.act_type`` is neither ``'ReLU'`` nor ``'Tanh'``.
		"""
		# Explicit error instead of an UnboundLocalError on ``act`` below.
		if self.act_type == 'ReLU':
			act = nn.ReLU
		elif self.act_type == 'Tanh':
			act = nn.Tanh
		else:
			raise ValueError(
				"Unsupported act_type {!r}; expected 'ReLU' or 'Tanh'".format(self.act_type))

		self.features = nn.Sequential(
			nn.Linear(61, 124),
			act(),
		)

		self.classifier = nn.Sequential(
			nn.Linear(124, 124),
			act(),
			nn.Linear(124, 2),
		)

	def forward(self, x):
		"""Return ``(logits, features)``.

		``logits`` has shape ``(batch, 2)``. ``features`` is the 124-dim output
		of ``self.features`` (after the activation).
		"""
		# NOTE(review): this branch cannot work as written. ``self.norm`` is
		# None, and a 7x7 view would not match the Linear(61, ...) input anyway.
		if self.in_channels != 1:
			x = self.norm(x.view(-1, self.in_channels, 7, 7))
		x_f = self.features(x)
		x = x_f.view(x_f.size(0), -1)
		x = self.classifier(x)
		return x, x_f


class Efficient_MLP(nn.Module):
	"""Three-layer MLP producing 100 logits from 62720-dim input vectors.

	Layout:
		features:   Linear(62720, 256) -> act
		classifier: Linear(256, 256) -> act -> Linear(256, 100)

	NOTE(review): the input width presumably corresponds to flattened
	EfficientNet feature maps (62720 = 1280 * 7 * 7). Confirm against the
	feature extractor.

	Args:
		in_channels: Stored on the instance; not used by the layers.
		act_type: ``'ReLU'`` or ``'Tanh'``.
		input_norm: Forwarded to :meth:`build`; ignored.
		**kwargs: Forwarded to :meth:`build` (``num_groups``, ``bn_stats``,
			``size``); ignored.

	Raises:
		ValueError: If ``act_type`` is neither ``'ReLU'`` nor ``'Tanh'``.
	"""

	def __init__(self, in_channels=1, act_type=None, input_norm=None,  **kwargs):
		super(Efficient_MLP, self).__init__()
		self.in_channels = in_channels
		self.features = None
		self.classifier = None
		self.norm = None
		self.act_type = act_type

		self.build(input_norm, **kwargs)

	def build(self, input_norm=None, num_groups=None,
			  bn_stats=None, size=None):
		"""Create ``self.features`` and ``self.classifier``.

		``input_norm``, ``num_groups``, ``bn_stats`` and ``size`` are accepted
		for signature compatibility but ignored.

		Raises:
			ValueError: If ``self.act_type`` is neither ``'ReLU'`` nor ``'Tanh'``.
		"""
		# Explicit error instead of an UnboundLocalError on ``act`` below.
		if self.act_type == 'ReLU':
			act = nn.ReLU
		elif self.act_type == 'Tanh':
			act = nn.Tanh
		else:
			raise ValueError(
				"Unsupported act_type {!r}; expected 'ReLU' or 'Tanh'".format(self.act_type))

		self.features = nn.Sequential(
			nn.Linear(62720, 256),
			act(),
		)

		self.classifier = nn.Sequential(
			nn.Linear(256, 256),
			act(),
			nn.Linear(256, 100),
		)

	def forward(self, x):
		"""Return ``(logits, features)``.

		``logits`` has shape ``(batch, 100)``. ``features`` is the 256-dim
		output of ``self.features`` (after the activation).
		"""
		x_f = self.features(x)
		x = self.classifier(x_f)
		return x, x_f

# Lookup table from a string key (these look like dataset names) to the model
# class constructed for it. "fmnist" and "mnist" share MNIST_CNN.
CNNS = dict(
	cifar10=CIFAR10_CNN,
	cifar100=Efficient_MLP,
	fmnist=MNIST_CNN,
	mnist=MNIST_CNN,
	german=German_FC,
)




