import torch
import numpy as np
#####
# TODO: Add an abstract base class here so users can create their own custom likelihoods.
class FishLikelihood:
    pass
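# A minimal sketch of what that abstract interface might look like, inferred
# from the concrete likelihoods below (an assumption, not the settled design):
#
#     class FishLikelihood:
#         def init_theta(self): raise NotImplementedError
#         def init_lam(self, scale): raise NotImplementedError
#         def nll(self, theta, y_pred, y): raise NotImplementedError
#         def sample(self, theta, y_pred): raise NotImplementedError
#         def ef(self, lam, u): raise NotImplementedError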
#####
# Note: need to check that the reduction of the nll is correct; the default reduction is mean.
class GaussianLikelihood(FishLikelihood):
    def __init__(self, sigma_init, sigma_fixed):
        self.sigma_init = sigma_init
        self.sigma_fixed = sigma_fixed

    def init_theta(self):
        return self.sigma_init

    def init_lam(self, scale):
        return scale
    def nll(self, theta, y_pred, y):
        """
        Negative log likelihood function.
        params:
            theta : sigma (standard deviation of the conditional density)
            y_pred : y predicted by the model
            y : y_true
        """
        sigma = theta
        if self.sigma_fixed:
            sigma = self.sigma_init
        # Mean Gaussian NLL per sample, up to an additive constant:
        # 0.5 * (E[(y - y_pred)^2] / sigma^2 + log(sigma^2)).
        # MSELoss already defaults to mean reduction, so no extra division
        # by the batch size is needed. torch.log (rather than np.log) keeps
        # the term differentiable when sigma is a learnable tensor.
        return 0.5 * (
            torch.nn.MSELoss()(y_pred, y) / sigma**2
            + torch.log(torch.as_tensor(sigma) ** 2)
        )
    def sample(self, theta, y_pred):
        """
        Sample from the model's conditional density.
        params:
            theta : sigma (standard deviation of the conditional density)
            y_pred : y predicted by the model
        """
        sigma = theta
        if self.sigma_fixed:
            sigma = self.sigma_init
        # randn_like works whether sigma is a float or a scalar tensor.
        return y_pred + sigma * torch.randn_like(y_pred)
    def ef(self, lam, u):
        return 0.5 * ((lam * u) ** 2)
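
# Usage sketch for GaussianLikelihood (assumption: y_pred and y are float
# tensors of matching shape; the variable names are illustrative only):
#
#     lik = GaussianLikelihood(sigma_init=1.0, sigma_fixed=False)
#     theta = lik.init_theta()
#     loss = lik.nll(theta, y_pred, y)
#     draws = lik.sample(theta, y_pred)
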
class BernoulliLikelihood(FishLikelihood):
    def __init__(self):
        pass

    def init_theta(self):
        return 1.0

    def init_lam(self, scale):
        return scale

    def nll(self, theta, y_pred, y):
        # y_pred is expected to hold logits; the loss module has to be
        # instantiated before being called with (input, target).
        return torch.nn.BCEWithLogitsLoss()(y_pred, y)
    def sample(self, theta, y_pred):
        # y_pred holds logits (consistent with BCEWithLogitsLoss above), so
        # map them through a sigmoid to get Bernoulli probabilities before
        # sampling.
        probs = torch.sigmoid(y_pred)
        return torch.bernoulli(probs)
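
# Example use (assumption: logits is a float tensor of raw model outputs and
# targets is a matching tensor with values in {0., 1.}):
#
#     lik = BernoulliLikelihood()
#     loss = lik.nll(lik.init_theta(), logits, targets)
#     draws = lik.sample(lik.init_theta(), logits)   # tensor of 0./1. samples
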
class SoftMaxLikelihood(FishLikelihood):
    def __init__(self):
        pass

    def init_theta(self):
        return 0.0

    def init_lam(self, scale):
        # scale needs to be a torch tensor so gradients can be calculated
        return torch.tensor(scale, requires_grad=True)
    def dense_to_one_hot(self, labels_dense, num_classes=None):
        """Convert class labels from scalars to one-hot vectors.
        labels_dense is a list of labels
        num_classes is the number of possible classes
        """
        num_labels = np.array(labels_dense).shape[0]
        index_offset = np.arange(num_labels) * num_classes
        labels_one_hot = np.zeros((num_labels, num_classes))
        labels_one_hot.flat[index_offset + np.array(labels_dense).ravel()] = 1
        return labels_one_hot.tolist()
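    # For example, dense_to_one_hot([0, 2], num_classes=3) returns
    # [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]: each label sets the flat index
    # row * num_classes + label in the zero matrix.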
    def nll(self, theta, y_pred, y):
        # y is expected to be one-hot encoded; stay in torch (not numpy) so
        # gradients can flow through the loss.
        log_probs = torch.nn.functional.log_softmax(y_pred, dim=1)
        return -torch.mean(torch.sum(log_probs * y, dim=1))
    def sample(self, theta, y_pred):
        # Treat y_pred as class probabilities, draw class indices from the
        # corresponding categorical distribution, then one-hot encode them.
        dist = torch.distributions.Categorical(probs=y_pred)
        labels = dist.sample()
        return self.dense_to_one_hot(labels.tolist(), num_classes=y_pred.shape[1])
    def ef(self, lam, u):
        return 0.0 * (lam * u)
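
# A minimal smoke test, runnable as a script. The shapes and values below are
# illustrative assumptions, not part of the library's API.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Gaussian: regression targets and predictions of matching shape.
    gauss = GaussianLikelihood(sigma_init=1.0, sigma_fixed=False)
    y_pred = torch.randn(8, 1)
    y = torch.randn(8, 1)
    print("gaussian nll:", gauss.nll(gauss.init_theta(), y_pred, y).item())
    print("gaussian sample shape:", gauss.sample(gauss.init_theta(), y_pred).shape)

    # Bernoulli: raw logits against 0/1 targets.
    bern = BernoulliLikelihood()
    logits = torch.randn(8)
    targets = torch.randint(0, 2, (8,)).float()
    print("bernoulli nll:", bern.nll(bern.init_theta(), logits, targets).item())
    print("bernoulli sample:", bern.sample(bern.init_theta(), logits))

    # Softmax: unnormalised scores against one-hot targets; sample expects
    # probabilities, so the scores are pushed through a softmax first.
    soft = SoftMaxLikelihood()
    scores = torch.randn(4, 3)
    one_hot = torch.tensor(soft.dense_to_one_hot([0, 1, 2, 0], num_classes=3)).float()
    print("softmax nll:", soft.nll(soft.init_theta(), scores, one_hot).item())
    probs = torch.softmax(scores, dim=1)
    print("softmax sample:", soft.sample(soft.init_theta(), probs))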