import  torch as tr
import torch.nn as nn

class quadexp(nn.Module):
	"""Elementwise Gaussian bump ``exp(-x**2 / sigma**2)``.

	Args:
		sigma: width parameter; the output equals exp(-1) wherever |x| == sigma.
	"""

	def __init__(self, sigma = 2.):
		super().__init__()
		self.sigma = sigma

	def forward(self, x):
		"""Apply the bump elementwise to ``x``."""
		denom = self.sigma ** 2
		return tr.exp(-(x ** 2) / denom)

class OneHiddenLayer(nn.Module):
	"""Network computing ``non_linearity(linear2(relu(linear1(x))))``.

	Args:
		d_int: input dimension of ``linear1``.
		H: hidden width (output of ``linear1``, input of ``linear2``).
		d_out: output dimension of ``linear2``.
		non_linearity: callable/module applied to the output of ``linear2``.
			When omitted (or ``None``), a fresh ``quadexp()`` is created
			for this instance.
		bias: whether both linear layers carry a bias term.
	"""

	def __init__(self, d_int, H, d_out, non_linearity=None, bias=False):
		super(OneHiddenLayer, self).__init__()
		# Build the default per instance. A ``quadexp()`` default in the signature
		# would be evaluated once at class-definition time and shared by every network.
		if non_linearity is None:
			non_linearity = quadexp()
		self.linear1 = tr.nn.Linear(d_int, H, bias=bias)
		self.linear2 = tr.nn.Linear(H, d_out, bias=bias)
		self.non_linearity = non_linearity
		self.d_int = d_int
		self.d_out = d_out

	def weights_init(self, center, std):
		"""Redraw the weight matrices of both linear layers in place.

		Each entry is sampled from a normal distribution with mean ``center``
		and standard deviation ``std``. Bias terms, if present, are left unchanged.
		"""
		# nn.Linear has no ``weights_init`` method, so delegating to it raised
		# AttributeError. Initialise the weight tensors directly instead.
		# NOTE(review): a normal distribution is assumed from the (center, std)
		# parameter names; confirm this matches the intended scheme.
		for layer in (self.linear1, self.linear2):
			nn.init.normal_(layer.weight, mean=center, std=std)

	def forward(self, x):
		"""Return ``non_linearity(linear2(relu(linear1(x))))``."""
		hidden = self.linear1(x).clamp(min=0)  # ReLU
		return self.non_linearity(self.linear2(hidden))

class SphericalTeacher(tr.utils.data.Dataset):
	"""Dataset of unit-norm inputs labelled by a fixed teacher network.

	Inputs are draws from a zero-mean, identity-covariance Gaussian in
	``network.d_int`` dimensions. Each draw is rescaled to unit Euclidean norm,
	which places it uniformly on the unit sphere. Targets are ``network(X)``,
	computed once under ``no_grad``.

	Args:
		network: module exposing ``d_int``; its output is indexed as
			``Y[index, :]``, so it is presumably 2-D (verify against callers).
		N_samples: number of points to generate.
		dtype: dtype of the sampled inputs.
		device: device on which the inputs are sampled.
	"""

	def __init__(self, network, N_samples, dtype, device):
		dim = network.d_int
		self.device = device
		self.network = network
		self.total_size = N_samples

		mean = tr.zeros(dim, dtype=dtype, device=device)
		cov = tr.eye(dim, dtype=dtype, device=device)
		self.source = tr.distributions.multivariate_normal.MultivariateNormal(mean, cov)

		raw = self.source.sample([N_samples])
		# Scale each row by the reciprocal of its norm so every sample has length 1.
		scale = 1. / tr.norm(raw, dim=1)
		self.X = raw * scale.unsqueeze(1)

		with tr.no_grad():
			self.Y = self.network(self.X)

	def __len__(self):
		"""Number of generated samples."""
		return self.total_size

	def __getitem__(self, index):
		"""Return the ``(input, target)`` pair stored at ``index``."""
		inputs = self.X[index, :]
		targets = self.Y[index, :]
		return inputs, targets
