#!/usr/bin/env python
# -*-coding:utf-8 -*-
import torch.nn.functional as F 
import torch.nn as nn 

def reconstruction_loss(x, x_recon, distribution="gaussian"):
    """Return the per-sample summed squared error between ``x_recon`` and ``x``.

    The squared differences are summed over every element and then divided
    by the size of the first dimension of ``x`` (treated as the batch size).
    This gives the mean, over the batch, of each sample's total squared error.

    Args:
        x: Target tensor; dimension 0 is treated as the batch dimension.
        x_recon: Reconstruction of ``x``. It is expected to have the same
            shape as ``x``; mismatched shapes follow ``F.mse_loss``
            broadcasting rules.
        distribution: Currently ignored. The squared-error (Gaussian-style)
            loss is always used. The parameter is kept for call-site
            compatibility.
            NOTE(review): presumably meant to select another likelihood,
            e.g. "bernoulli". Confirm against callers before acting on it.

    Returns:
        A scalar tensor holding the batch-averaged summed squared error.

    Raises:
        ValueError: If ``x`` has an empty batch dimension. The result would
            otherwise be 0/0 = nan.
    """
    batch_size = x.size(0)
    # An explicit check instead of `assert`, which is stripped under `python -O`.
    if batch_size == 0:
        raise ValueError("reconstruction_loss: batch size must be non-zero")
    # Sum over all elements, then normalise by batch size only (not by the
    # per-sample element count), matching the original reduction.
    return F.mse_loss(x_recon, x, reduction="sum").div(batch_size)
