# Source code for FishLeg.fishleg_likelihood

import torch
import numpy as np

#####
# TODO: Add an abstract base class here so that users can create their own custom likelihoods.
class FishLikelihood:
    """Common base class for the likelihoods defined in this module.

    It provides no behaviour of its own. Each concrete likelihood in this
    module implements ``init_theta()``, ``init_lam(scale)``,
    ``nll(theta, y_pred, y)`` and ``sample(theta, y_pred)``; some also
    implement ``ef(lam, u)``.
    """
#####
# NOTE: need to check that the reduction of the NLL is correct; the default reduction is mean.
class GaussianLikelihood(FishLikelihood):
    """Gaussian likelihood ``y ~ N(y_pred, sigma**2)``.

    ``sigma`` is the standard deviation. It is taken from ``theta`` unless
    ``sigma_fixed`` is truthy, in which case ``sigma_init`` is always used.
    """

    def __init__(self, sigma_init, sigma_fixed):
        """
        params:
            sigma_init : initial standard deviation (also the value used when fixed)
            sigma_fixed : if truthy, ignore ``theta`` and always use ``sigma_init``
        """
        self.sigma_init = sigma_init
        self.sigma_fixed = sigma_fixed

    def init_theta(self):
        """Return the initial theta, i.e. ``sigma_init``."""
        return self.sigma_init

    def init_lam(self, scale):
        """Return ``scale`` unchanged as the initial lambda."""
        return scale

    def _sigma(self, theta):
        """Return the standard deviation to use: ``sigma_init`` if fixed, else ``theta``."""
        return self.sigma_init if self.sigma_fixed else theta

    def nll(self, theta, y_pred, y):
        """Negative Log Likelihood Function, averaged over all elements.

        Per element the Gaussian NLL is
        ``0.5 * ((y - y_pred)**2 / sigma**2 + log(sigma**2))`` plus the
        constant ``0.5 * log(2 * pi)``, which is dropped. The squared error is
        mean-reduced (as ``MSELoss`` does by default), so the ``log(sigma**2)``
        term appears once per element and is not divided by the batch size.

        params:
            theta : sigma (standard deviation of the conditional density);
                ignored when ``sigma_fixed`` is set
            y_pred : y predicted by the model
            y : y_true
        """
        # torch ops (not numpy) so gradients can flow into a learnable sigma;
        # as_tensor also accepts a plain float sigma.
        sigma = torch.as_tensor(
            self._sigma(theta), dtype=y_pred.dtype, device=y_pred.device
        )
        mse = torch.nn.functional.mse_loss(y_pred, y)
        # Divide by the variance (sigma**2), not the standard deviation.
        return 0.5 * (mse / sigma**2 + torch.log(sigma**2))

    def sample(self, theta, y_pred):
        """Sample from model's conditional density.

        Returns ``y_pred`` plus zero-mean Gaussian noise with standard
        deviation ``sigma``.

        params:
            theta : sigma (standard deviation of the conditional density);
                ignored when ``sigma_fixed`` is set
            y_pred : y predicted by the model
        """
        sigma = self._sigma(theta)
        # randn_like matches y_pred's dtype/device and works for float or tensor sigma.
        return y_pred + sigma * torch.randn_like(y_pred)

    def ef(self, lam, u):
        """Return ``0.5 * (lam * u) ** 2``."""
        return 0.5 * ((lam * u) ** 2)
class BernoulliLikelihood(FishLikelihood):
    """Bernoulli likelihood where ``y_pred`` holds logits.

    ``theta`` is accepted for interface compatibility but is not used.
    """

    # NOTE(review): no ef() is defined here, unlike GaussianLikelihood and
    # SoftMaxLikelihood in this module; confirm whether callers need one.

    def __init__(self):
        pass

    def init_theta(self):
        """Return the (unused) initial theta."""
        return 1.0

    def init_lam(self, scale):
        """Return ``scale`` unchanged as the initial lambda."""
        return scale

    def nll(self, theta, y_pred, y):
        """Negative Log Likelihood: mean binary cross-entropy on logits.

        params:
            theta : unused
            y_pred : logits predicted by the model
            y : binary targets (cast to ``y_pred``'s dtype/device)
        """
        # The functional form takes (input_logits, target); the original called the
        # BCEWithLogitsLoss *constructor*, which returns a Module rather than a loss.
        target = torch.as_tensor(y, dtype=y_pred.dtype, device=y_pred.device)
        return torch.nn.functional.binary_cross_entropy_with_logits(y_pred, target)

    def sample(self, theta, y_pred):
        """Sample binary outcomes with success probability ``sigmoid(y_pred)``.

        params:
            theta : unused
            y_pred : logits predicted by the model
        """
        # torch.bernoulli needs probabilities in [0, 1]; y_pred are logits.
        probs = torch.sigmoid(y_pred)
        return torch.bernoulli(probs)
class SoftMaxLikelihood(FishLikelihood):
    """Categorical (softmax) likelihood where ``y_pred`` holds class logits.

    ``theta`` is accepted for interface compatibility but is not used.
    NOTE(review): assumes ``y_pred`` is (batch, num_classes), since ``nll``
    normalises over dim 1 and ``sample`` over the last dim; confirm with callers.
    """

    def __init__(self):
        pass

    def init_theta(self):
        """Return the (unused) initial theta."""
        return 0.0

    def init_lam(self, scale):
        """Return ``scale`` as a tensor that requires grad."""
        # Needs to be a torch tensor so gradients can be calculated.
        return torch.tensor(scale, requires_grad=True)

    def dense_to_one_hot(self, labels_dense, num_classes=None):
        """Convert class labels from scalars to one-hot vectors.

        params:
            labels_dense : sequence (or array/tensor) of integer class labels
            num_classes : number of possible classes; inferred as
                ``max(label) + 1`` when None

        returns:
            nested list of floats with shape (num_labels, num_classes)
        """
        labels = np.asarray(labels_dense, dtype=np.int64).ravel()
        if num_classes is None:
            num_classes = int(labels.max()) + 1 if labels.size else 0
        labels_one_hot = np.zeros((labels.shape[0], num_classes))
        labels_one_hot[np.arange(labels.shape[0]), labels] = 1
        return labels_one_hot.tolist()

    def nll(self, theta, y_pred, y):
        """Negative Log Likelihood: mean cross-entropy against one-hot targets.

        params:
            theta : unused
            y_pred : logits predicted by the model (classes along dim 1)
            y : one-hot (or soft) targets; tensors or nested lists are accepted
        """
        log_probs = torch.nn.functional.log_softmax(y_pred, dim=1)
        # Also accept the list returned by dense_to_one_hot.
        target = torch.as_tensor(y, dtype=log_probs.dtype, device=log_probs.device)
        return -(log_probs * target).sum(dim=1).mean()

    def sample(self, theta, y_pred):
        """Sample one-hot class vectors from ``softmax(y_pred)``.

        params:
            theta : unused
            y_pred : logits predicted by the model

        returns:
            tensor with the same shape and dtype as ``y_pred``
        """
        # y_pred are logits already; taking their log (as before) produced NaNs.
        labels = torch.distributions.Categorical(logits=y_pred).sample()
        one_hot = torch.nn.functional.one_hot(labels, num_classes=y_pred.shape[-1])
        return one_hot.to(y_pred.dtype)

    def ef(self, lam, u):
        """Return zero, shaped like ``lam * u``."""
        return 0.0 * (lam * u)