import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.utils.data as Data
import numpy as np

class Net(torch.nn.Module):
    """
    Network with a single hidden layer (ReLU activation by default).

    The forward pass computes ``predict(activation(hidden(z)))`` and, when the
    skip connection is enabled, adds a linear term ``skip(z)``.
    """
    def __init__(self, n_feature, n_hidden, n_output, init_scale=1, bias_hidden=True, bias_output=False, balanced=True, **kwargs):
        """
        n_feature: dimension of input
        n_hidden: number of hidden neurons
        n_output: dimension of output
        init_scale: standard deviation of the Gaussian initialisation, i.e. the
            weights (and biases, when used) are drawn ~ N(0, init_scale^2)
        bias_hidden: if True, use bias parameters in hidden layer. Use no bias otherwise
        bias_output: if True, use bias parameters in output layer. Use no bias otherwise
        balanced: if True, use a balanced initialisation: each output weight is
            +/- the Euclidean norm of the incoming parameters (weights and bias)
            of its hidden neuron, with a uniformly random sign

        Optional keyword arguments:
        activation: module/callable applied to the hidden layer (default: torch.nn.ReLU())
        zero_output: if True, make the network output exactly 0 at initialisation
            by duplicating the first half of the hidden neurons into the second
            half with opposite output weights. If n_hidden is odd, the outgoing
            weights of the remaining unpaired neuron are set to 0. The output
            bias (if any) is set to 0.
        skip_connection: if True, add a linear skip connection (with bias) from
            input to output, initialised at 0
        """
        super(Net, self).__init__()
        self.init_scale = init_scale

        # hidden layer with rescaled init
        self.hidden = torch.nn.Linear(n_feature, n_hidden, bias=bias_hidden)
        torch.nn.init.normal_(self.hidden.weight, std=init_scale)
        if bias_hidden:
            torch.nn.init.normal_(self.hidden.bias, std=init_scale)

        # output layer with rescaled init
        self.predict = torch.nn.Linear(n_hidden, n_output, bias=bias_output)
        if balanced:
            with torch.no_grad():
                # norm of all incoming parameters of each hidden neuron
                squared_norms = self.hidden.weight.norm(dim=1).square()
                if bias_hidden:
                    squared_norms = squared_norms + self.hidden.bias.square()
                neuron_norms = squared_norms.sqrt()
                # random signs in {-1, +1}, one per output weight
                signs = 2 * torch.bernoulli(0.5 * torch.ones_like(self.predict.weight)) - 1
                # neuron_norms has shape (n_hidden,) and broadcasts over the output rows
                self.predict.weight.copy_(signs * neuron_norms)
        else:
            torch.nn.init.normal_(self.predict.weight, std=init_scale)
        if bias_output:
            torch.nn.init.normal_(self.predict.bias, std=init_scale)

        # activation of hidden layer
        self.activation = kwargs.get('activation', torch.nn.ReLU())

        if kwargs.get('zero_output', False):
            # ensure that the estimated function is 0 at initialisation
            half_n = n_hidden // 2
            paired = 2 * half_n  # number of neurons that take part in a +/- pair
            with torch.no_grad():
                # second half replicates the first half of the hidden neurons...
                self.hidden.weight[half_n:paired] = self.hidden.weight[:half_n]
                if bias_hidden:
                    self.hidden.bias[half_n:paired] = self.hidden.bias[:half_n]
                # ...with opposite output weights, for every output dimension
                self.predict.weight[:, half_n:paired] = -self.predict.weight[:, :half_n]
                # an unpaired neuron (odd n_hidden) must not contribute to the output
                self.predict.weight[:, paired:] = 0
                if bias_output:
                    # a non-zero output bias would shift the output away from 0
                    self.predict.bias.zero_()

        self.skip_ = kwargs.get('skip_connection', False)
        if self.skip_:
            # add a free skip connection in the parameterisation
            self.skip = torch.nn.Linear(n_feature, n_output, bias=True)
            # initialise the skip connection as 0
            torch.nn.init.zeros_(self.skip.weight)
            torch.nn.init.zeros_(self.skip.bias)

    def forward(self, z):
        """
        Apply the network to the input z and return the prediction
        (hidden layer, activation, output layer, plus the skip term if enabled).
        """
        out = self.predict(self.activation(self.hidden(z)))
        if self.skip_:
            out = out + self.skip(z)
        return out


