# Simple BNN implementation, for use with HMC code in toy example
import torch as t
import torch.nn as nn

# Device used when allocating the default log-noise tensor in BNN.__init__.
device = t.device("cpu")


class BNN(nn.Module):
    """Fully connected ReLU network evaluated on batches of flat parameter vectors.

    The network holds no ``nn.Parameter``s.  Every call receives a 2-D tensor
    ``weights`` whose rows are independent parameter samples (e.g. HMC
    chains).  Each row is laid out layer by layer as
    ``[W_1, b_1, W_2, b_2, ..., W_out, b_out]`` where ``W_l`` is stored
    row-major with shape ``(fan_in, fan_out)``.

    Attributes:
        Dx: input dimension.
        H: sequence of hidden-layer widths.
        Dy: output dimension.
        num_layers: number of hidden layers (``len(H)``).
        num_weights: total number of network weights and biases per sample.
        train_x: training inputs; ``potential`` expects shape ``(N, Dx)``.
        train_y: training targets; ``potential`` expects shape ``(N, Dy)``.
        log_noise_var: log of the observation-noise variance used by
            ``potential``.
    """

    def __init__(self, Dx, H, Dy, train_x, train_y, log_noise=None):
        """Store the architecture, the training data and the noise level.

        Args:
            Dx: input dimension.
            H: sequence of hidden-layer widths (an empty sequence gives a
                linear model).
            Dy: output dimension.
            train_x: training inputs used by ``potential``.
            train_y: training targets used by ``potential``.
            log_noise: optional log noise variance; when ``None`` a zero
                tensor of shape ``(1,)`` is used (unit noise variance).
        """
        super().__init__()

        self.Dx = Dx
        self.H = H
        self.Dy = Dy
        self.num_layers = len(H)

        # Each layer contributes fan_in*fan_out weights plus fan_out biases.
        sizes = [Dx, *H, Dy]
        self.num_weights = sum(
            (fan_in + 1) * fan_out for fan_in, fan_out in zip(sizes, sizes[1:])
        )

        self.train_x = train_x
        self.train_y = train_y

        if log_noise is not None:
            self.log_noise_var = log_noise
        else:
            self.log_noise_var = t.zeros(1, device=device)

    def _unpack_layers(self, weights):
        """Yield ``(fan_in, W, b)`` for every layer, first to last.

        ``W`` has shape ``(S, fan_in, fan_out)`` and ``b`` has shape
        ``(S, 1, fan_out)``, where ``S`` is the number of rows of ``weights``.
        Columns beyond ``num_weights`` are ignored.
        """
        sizes = [self.Dx, *self.H, self.Dy]
        offset = 0
        for fan_in, fan_out in zip(sizes, sizes[1:]):
            w_end = offset + fan_in * fan_out
            b_end = w_end + fan_out
            layer_w = weights[:, offset:w_end].view(-1, fan_in, fan_out)
            layer_b = weights[:, w_end:b_end].view(-1, 1, fan_out)
            yield fan_in, layer_w, layer_b
            offset = b_end

    def _network_outputs(self, inputs, layers):
        """Run ``inputs`` of shape ``(N, Dx)`` through ``layers``; returns ``(N, S, Dy)``."""
        # (N, Dx) -> (N, 1, 1, Dx) so matmul broadcasts over the S samples.
        outputs = inputs.unsqueeze(1).unsqueeze(2)
        last = len(layers) - 1
        for idx, (_, layer_w, layer_b) in enumerate(layers):
            outputs = outputs.matmul(layer_w) + layer_b
            if idx < last:
                # ReLU on every layer except the output layer.
                outputs = outputs.clamp(min=0)
        return outputs.squeeze(2)

    def forward(self, inputs, weights):
        """Evaluate the network for every parameter sample.

        Args:
            inputs: tensor of shape ``(N, Dx)``.
            weights: tensor of shape ``(S, P)`` with ``P >= num_weights``;
                only the first ``num_weights`` columns are used.

        Returns:
            Tensor of shape ``(N, S, Dy)``.
        """
        return self._network_outputs(inputs, list(self._unpack_layers(weights)))

    # Potential function for HMC
    def potential(self, weights):
        """Return the negative log posterior (up to a constant) per sample.

        If ``weights`` has more than ``num_weights`` columns, its last column
        is taken as the log noise variance of each sample, stored in
        ``self.log_noise_var`` (a side effect that persists after the call)
        and removed before the network weights are unpacked.

        The prior term for a layer is
        ``0.5 * (fan_in + 1) * (sum(W**2) + sum(b**2))``, i.e. a zero-mean
        Gaussian with precision ``fan_in + 1`` on that layer's weights and
        biases.  The likelihood term is a Gaussian with a noise variance
        shared by all outputs.

        Args:
            weights: tensor of shape ``(S, num_weights)`` or
                ``(S, num_weights + 1)``.

        Returns:
            Tensor of shape ``(S,)``.
        """
        if weights.shape[1] > self.num_weights:
            # Read the noise column BEFORE stripping it; doing it the other
            # way round would pick up the last output-layer bias instead.
            self.log_noise_var = weights[:, -1]
            weights = weights[:, :-1]

        layers = list(self._unpack_layers(weights))

        prior_term = 0
        for fan_in, layer_w, layer_b in layers:
            sq_norm = (layer_w ** 2).sum((1, 2)) + (layer_b ** 2).sum((1, 2))
            # Out-of-place accumulation keeps the autograd graph simple.
            prior_term = prior_term + 0.5 * (fan_in + 1) * sq_norm

        outputs = self._network_outputs(self.train_x, layers)  # (N, S, Dy)

        # The squared error is summed over data points and output
        # dimensions, so the log-variance normaliser must count N * Dy terms.
        num_terms = outputs.shape[0] * outputs.shape[2]
        sq_err = ((outputs - self.train_y.unsqueeze(1)) ** 2).sum((0, 2))
        likelihood_term = (0.5 * num_terms * self.log_noise_var
                           + 0.5 * sq_err / t.exp(self.log_noise_var))

        return prior_term + likelihood_term
