import torch
import torch.nn as nn
# Device on which module parameters are created: the GPU when CUDA is
# available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class DBNN_Base(nn.Module):
    """Base module that propagates a ``(mean, variance)`` pair through ``self.model``.

    This class does not create ``self.model``. Subclasses are expected to
    assign it, presumably an ``nn.Sequential`` whose layers accept and return
    ``(mean, var)`` pairs (TODO confirm against subclasses).

    Args:
        likelihood: Stored on ``self.likelihood``. It is not used by the code
            in this class; presumably subclasses consume it (verify).
        scale_init: Initial value of the learnable scalar ``scale_factor``.
        alpha: Constant that, together with ``scale_factor``, divides the
            output variance in :meth:`forward_latent`. Defaults to 1.
    """

    def __init__(self, likelihood, scale_init, alpha=1):
        super().__init__()

        self.likelihood = likelihood
        self.alpha = alpha
        # Learnable 1-element tensor in the default dtype (matching the legacy
        # ``torch.Tensor`` constructor). It is built directly on the
        # module-level ``device`` instead of on the CPU followed by a copy.
        self.scale_factor = nn.Parameter(
            torch.tensor([scale_init], dtype=torch.get_default_dtype(), device=device)
        )

    def forward_latent(self, data, out_var=None):
        """Propagate ``data`` (and optionally its variance) through ``self.model``.

        Each layer of ``self.model`` is called as ``layer(mean, var)`` and must
        return an updated ``(mean, var)`` pair. The final variance is then
        divided by ``alpha * scale_factor``.

        Args:
            data: Input mean fed to the first layer.
            out_var: Optional input variance fed to the first layer. How
                ``None`` is interpreted is up to the layers (not visible here).

        Returns:
            Tuple ``(out_mean, out_var)`` after all layers and the variance
            rescaling.

        Raises:
            TypeError: If no variance is available to rescale. This happens,
                for example, when ``self.model`` is not an ``nn.Sequential``
                and ``out_var`` was not supplied.
        """
        out_mean = data

        # NOTE(review): only nn.Sequential models are traversed. Any other model
        # leaves ``data``/``out_var`` untouched; that behavior is kept as-is.
        if isinstance(self.model, nn.Sequential):
            for layer in self.model:
                # Invoke the module itself (not ``layer.forward``) so that
                # registered forward hooks and pre-hooks are run.
                out_mean, out_var = layer(out_mean, out_var)

        if out_var is None:
            raise TypeError(
                "forward_latent: no variance to rescale (out_var is None); "
                "expected self.model to be an nn.Sequential of (mean, var) layers "
                "or an explicit out_var argument"
            )

        # Scale the propagated variance by 1 / (alpha * scale_factor).
        out_var = out_var / (self.alpha * self.scale_factor)

        return out_mean, out_var