from torch import nn
import torch
from torch import nn as nn
from torch.autograd import Variable
import torch.autograd as autograd
import numpy as np
import torch.nn.functional as F
import random
from torch.autograd import grad

class EnsembleLinear(nn.Linear):
	"""A bank of independent linear maps evaluated as one batched matmul.

	Ensemble member ``e`` computes ``x[e] @ weight[e] + bias[e]``. The weight
	is stored per member as ``(in_features, out_features)``, i.e. transposed
	relative to ``nn.Linear``'s ``(out_features, in_features)`` layout.

	Args:
		ensemble_size: number of independent linear maps (E).
		in_features: input feature size of each member.
		out_features: output feature size of each member.
		bias: if False, no additive bias is learned. Defaults to True.
	"""

	def __init__(self, ensemble_size, in_features, out_features, bias=True):
		# nn.Linear.__init__ is skipped on purpose: it would allocate a 2-D
		# weight. Only reset_parameters() is reused from the parent class.
		nn.Module.__init__(self)
		self.ensemble_size = ensemble_size
		self.in_features = in_features
		self.out_features = out_features
		self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features))
		if bias:
			# The singleton middle dim lets the bias broadcast over batch rows.
			self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))
		else:
			self.register_parameter('bias', None)
		# Inherited from nn.Linear; it skips the bias when it is None.
		# NOTE(review): torch.nn.init derives fan-in from the 3-D weight
		# (in_features * out_features), so the init scale likely differs from a
		# plain nn.Linear(in_features, out_features) — confirm this is intended.
		self.reset_parameters()

	def forward(self, x):
		"""Apply every ensemble member to its own slice of ``x``.

		Args:
			x: 3-D tensor ``(ensemble_size, batch, in_features)``, as required
				by ``torch.baddbmm`` / ``torch.bmm``.

		Returns:
			Tensor of shape ``(ensemble_size, batch, out_features)``.
		"""
		if self.bias is None:
			return torch.bmm(x, self.weight)
		return torch.baddbmm(self.bias, x, self.weight)

	def extra_repr(self):
		return 'ensemble_size={}, in_features={}, out_features={}, bias={}'.format(
			self.ensemble_size, self.in_features, self.out_features, self.bias is not None)


class SupspaceIndicator(nn.Module):
	"""Per-slot gate scores from an ensemble of small linear probes.

	Each of the ``dims`` slots along axis 1 of the input gets its own
	``k -> dim_out`` linear map, implemented as an ``EnsembleLinear`` with
	``dims`` members.

	Args:
		dims: number of slots (ensemble members).
		k: feature size of each slot. Defaults to 49.
		dim_out: output size per slot. Defaults to 1.
	"""

	def __init__(self, dims, k=49, dim_out=1):
		super(SupspaceIndicator, self).__init__()
		self.k = k
		self.mask = EnsembleLinear(dims, self.k, dim_out)

	def forward(self, x):
		"""Return ``(sigmoid(logits), logits)`` for every slot.

		Args:
			x: 3-D tensor with ``x.shape[1] == dims`` and ``x.shape[2] == k``
				(axis 0 is presumably the batch — confirm with callers).

		Returns:
			Tuple ``(probabilities, logits)``. Each has shape
			``(x.shape[0], dims)`` when ``dim_out == 1``, and
			``(x.shape[0], dims, dim_out)`` otherwise.
		"""
		# Move the slot axis to the front so it lines up with the ensemble
		# axis of EnsembleLinear, then move it back.
		logits = self.mask(x.transpose(0, 1)).transpose(0, 1).squeeze(-1)
		return torch.sigmoid(logits), logits

class SupspaceIndicatorFC(nn.Module):
	"""Fully-connected indicator: one shared ``dims -> dims`` linear map.

	Args:
		dims: input and output feature size.
	"""

	def __init__(self, dims):
		super(SupspaceIndicatorFC, self).__init__()
		self.model = nn.Linear(dims, dims)

	def forward(self, x):
		"""Return ``(sigmoid(logits), logits)`` with ``logits = self.model(x)``.

		Args:
			x: tensor whose last dimension is ``dims``.

		Returns:
			Tuple ``(probabilities, logits)``, both shaped like ``x``.
		"""
		logits = self.model(x)
		return torch.sigmoid(logits), logits


class Classifier(nn.Module):
	"""Single linear layer mapping features to unnormalized class scores.

	Args:
		dim: input feature size.
		classes: number of output classes.
	"""

	def __init__(self, dim, classes):
		super(Classifier, self).__init__()
		self.model = nn.Linear(dim, classes)

	def forward(self, x):
		"""Return raw logits of shape ``(..., classes)`` for input ``(..., dim)``."""
		return self.model(x)

class Sigmoid_Classifier(nn.Module):
	"""Linear layer followed by an element-wise sigmoid.

	Each output is an independent score in (0, 1), as in multi-label
	classification.

	Args:
		dim: input feature size.
		classes: number of output units.
	"""

	def __init__(self, dim, classes):
		super(Sigmoid_Classifier, self).__init__()
		self.model = nn.Linear(dim, classes)

	def forward(self, x):
		"""Return ``sigmoid(self.model(x))`` with shape ``(..., classes)``."""
		return torch.sigmoid(self.model(x))

class DSigmoid_Classifier(nn.Module):
	"""Linear layer, then dropout on its outputs, then an element-wise sigmoid.

	Args:
		dim: input feature size.
		classes: number of output units.
		drop_rate: dropout probability applied to the linear outputs.
	"""

	def __init__(self, dim, classes, drop_rate):
		super(DSigmoid_Classifier, self).__init__()
		# Kept as a Sequential so state_dict keys stay 'model.0.*'.
		# NOTE(review): dropout acts on the logits, after the linear layer. In
		# training mode a dropped logit becomes 0 (probability 0.5) and the
		# survivors are scaled by 1 / (1 - drop_rate). Confirm this is intended
		# rather than dropout on the input features.
		self.model = nn.Sequential(
			nn.Linear(dim, classes),
			nn.Dropout(drop_rate)
			)

	def forward(self, x):
		"""Return ``sigmoid(dropout(linear(x)))`` with shape ``(..., classes)``."""
		return torch.sigmoid(self.model(x))

