# This file has been modified from https://github.com/imirzadeh/stable-continual-learning (MIT licence).
import torch.nn as nn
from torch.nn.functional import relu, avg_pool2d
import RBFLayer


def conv3x3(in_planes, out_planes, stride=1):
	"""Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

	Args:
		in_planes: number of input channels.
		out_planes: number of output channels.
		stride: convolution stride (default 1).
	"""
	conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
	return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class BasicBlock(nn.Module):
	"""Residual block made of two 3x3 convolutions.

	Each convolution is followed by ``InstanceNorm2d`` and ``Dropout``
	(``IC1`` / ``IC2``). The block input is added back through ``shortcut``
	before the final ReLU.

	Args:
		in_planes: number of input channels.
		planes: number of channels produced by both convolutions.
		stride: stride of the first convolution (and of the shortcut
			projection when one is needed).
		config: optional mapping; ``config['dropout']`` is the dropout
			probability. When ``config`` is ``None`` or has no ``'dropout'``
			key, 0.0 (no dropout) is used.
	"""
	# Channel multiplier applied to ``planes`` for the block's output width.
	expansion = 1

	def __init__(self, in_planes, planes, stride=1, config=None):
		super(BasicBlock, self).__init__()
		if config is None:
			config = {}
		# The signature makes config optional, so a missing key must not
		# raise; fall back to a dropout probability of 0.
		dropout_p = config['dropout'] if 'dropout' in config else 0.0

		self.conv1 = conv3x3(in_planes, planes, stride)
		self.conv2 = conv3x3(planes, planes)

		# Empty Sequential acts as identity; a 1x1 projection is used only
		# when the spatial size or the channel count changes.
		self.shortcut = nn.Sequential()
		if stride != 1 or in_planes != self.expansion * planes:
			self.shortcut = nn.Sequential(
				nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
			)

		self.IC1 = nn.Sequential(
			nn.InstanceNorm2d(planes),
			nn.Dropout(p=dropout_p),
		)
		self.IC2 = nn.Sequential(
			nn.InstanceNorm2d(planes),
			nn.Dropout(p=dropout_p),
		)

	def forward(self, x):
		"""Apply conv -> IC -> ReLU -> conv -> IC, add the shortcut, then ReLU."""
		out = relu(self.IC1(self.conv1(x)))
		out = self.IC2(self.conv2(out))
		out = out + self.shortcut(x)
		return relu(out)


class ResNet(nn.Module):
	"""Four-stage residual network using InstanceNorm2d in the stem.

	Args:
		block: block class; called as ``block(in_planes, planes, stride,
			config=config)`` and must expose an ``expansion`` attribute.
		num_blocks: sequence of four ints, the number of blocks per stage.
		num_classes: output size of the final linear layer.
		nf: channel count of the stem and first stage; later stages use
			``2 * nf``, ``4 * nf`` and ``8 * nf``.
		config: optional mapping forwarded to every block. If it contains
			``'classification_head'`` with a falsy value, ``forward`` returns
			the pooled features instead of the linear layer's output.
	"""

	def __init__(self, block, num_blocks, num_classes, nf, config=None):
		super(ResNet, self).__init__()
		if config is None:
			config = {}
		self.in_planes = nf

		# Attribute is named ``bn1`` but holds an InstanceNorm2d layer.
		self.conv1 = conv3x3(3, nf * 1)
		self.bn1 = nn.InstanceNorm2d(nf * 1)
		self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1, config=config)
		self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2, config=config)
		self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2, config=config)
		self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2, config=config)
		self.only_embedding = 'classification_head' in config and not config['classification_head']
		# The linear head is always created, even when only_embedding is set.
		self.linear = nn.Linear(nf * 8 * block.expansion, num_classes)

	def _make_layer(self, block, planes, num_blocks, stride, config):
		"""Build one stage: the first block uses ``stride``, the rest stride 1.

		Updates ``self.in_planes`` so the next block/stage sees the right
		input width.
		"""
		layers = []
		for block_stride in [stride] + [1] * (num_blocks - 1):
			layers.append(block(self.in_planes, planes, block_stride, config=config))
			self.in_planes = planes * block.expansion
		return nn.Sequential(*layers)

	def calc_representation(self, x):
		"""Return the flattened pooled features (before the linear head).

		``x`` is reshaped to ``(batch, 3, 32, 32)`` before the stem.
		"""
		bsz = x.size(0)
		out = relu(self.bn1(self.conv1(x.view(bsz, 3, 32, 32))))
		out = self.layer1(out)
		out = self.layer2(out)
		out = self.layer3(out)
		out = self.layer4(out)
		out = avg_pool2d(out, 4)
		return out.view(out.size(0), -1)

	def forward(self, x):
		"""Return the linear head's output, or the features if ``only_embedding``."""
		# Reuse the feature path so both methods cannot drift apart.
		out = self.calc_representation(x)
		if not self.only_embedding:
			out = self.linear(out)
		return out


class MNISTModel(nn.Module):
	"""Fully connected model for flattened 1x28x28 inputs.

	Two 256-unit linear layers produce the representation; an optional
	linear head maps it to ``nclasses`` outputs.

	Args:
		nclasses: output size of the final linear layer.
		config: optional mapping. If it contains ``'classification_head'``
			with a falsy value, ``forward`` returns the 256-d representation
			instead of the linear layer's output.
	"""

	def __init__(self, nclasses, config=None):
		super(MNISTModel, self).__init__()
		if config is None:
			config = {}
		self.layer1 = nn.Linear(1*28*28, 256)
		self.layer2 = nn.Linear(256, 256)
		self.only_embedding = 'classification_head' in config and not config['classification_head']
		# The linear head is always created, even when only_embedding is set.
		self.linear = nn.Linear(256, nclasses)

	def calc_representation(self, x):
		"""Flatten ``x`` per sample and return the output of ``layer2``."""
		bsz = x.size(0)
		# NOTE(review): no activation is applied between layer1 and layer2;
		# confirm this is intentional.
		out = self.layer1(x.view(bsz, -1))
		return self.layer2(out)

	def forward(self, x):
		"""Return the linear head's output, or the representation if ``only_embedding``."""
		# Reuse the feature path so both methods cannot drift apart.
		out = self.calc_representation(x)
		if not self.only_embedding:
			out = self.linear(out)
		return out


def ResNet18(nclasses=100, nf=20, config={}):
	"""Build a ``ResNet`` of ``BasicBlock``s with two blocks in each of its four stages.

	Args:
		nclasses: number of outputs of the final linear layer.
		nf: base channel width passed to ``ResNet``.
		config: mapping forwarded unchanged to ``ResNet``.
	"""
	blocks_per_stage = [2, 2, 2, 2]
	return ResNet(BasicBlock, blocks_per_stage, nclasses, nf, config=config)
