import torch


def noise_estimation_loss(model,
                          x0: torch.Tensor,
                          t: torch.LongTensor,
                          e: torch.Tensor,
                          b: torch.Tensor, keepdim=False):
    """Squared-error loss between the injected noise ``e`` and the model output.

    The noised input is ``x = sqrt(a_t) * x0 + sqrt(1 - a_t) * e``, where
    ``a_t`` is the cumulative product of ``(1 - b)`` taken at index ``t`` of
    each sample. ``model(x, t.float())`` is expected to predict ``e``. The
    squared error is summed over every non-batch dimension of the input.

    Args:
        model: Callable invoked as ``model(x, t.float())``; its output is
            subtracted from ``e``, so it must broadcast with ``e``.
        x0: Clean inputs, batch dimension first (``(B, ...)``).
        t: Integer index per sample (shape ``(B,)``) into the cumulative
            product of ``1 - b``.
        e: Noise tensor with the same shape as ``x0``.
        b: 1-D tensor of per-step values; presumably the beta (noise)
            schedule — TODO confirm against caller.
        keepdim: If True, return the per-sample losses; otherwise return
            their mean over the batch.

    Returns:
        Tensor of shape ``(B,)`` when ``keepdim`` is True, else a scalar tensor.
    """
    # Shape a_t as (B, 1, ..., 1) so it broadcasts per-sample against x0 of
    # any rank; a fixed view(-1, 1, 1, 1) only lines up with 4-D inputs.
    bcast_shape = (-1,) + (1,) * (x0.dim() - 1)
    a = (1 - b).cumprod(dim=0).index_select(0, t).view(bcast_shape)
    x = x0 * a.sqrt() + e * (1.0 - a).sqrt()
    output = model(x, t.float())

    sq_err = (e - output).square()
    # Reduce every non-batch dim. An empty dim tuple would make torch sum over
    # *all* dims, so 1-D inputs (already per-sample) are left untouched.
    reduce_dims = tuple(range(1, sq_err.dim()))
    per_sample = sq_err.sum(dim=reduce_dims) if reduce_dims else sq_err

    if keepdim:
        return per_sample
    else:
        return per_sample.mean(dim=0)

def noise_deepcache_loss(model,
                          x0: torch.Tensor,
                          t: torch.LongTensor,
                          e: torch.Tensor,
                          b: torch.Tensor, 
                          prv_f: torch.Tensor=None, keepdim=False):
    """Noise-prediction squared-error loss for a model that also returns a feature.

    Identical noising and loss to the plain noise-estimation loss:
    ``x = sqrt(a_t) * x0 + sqrt(1 - a_t) * e``, where ``a_t`` is the cumulative
    product of ``(1 - b)`` at index ``t``. The model must return a pair
    ``(output, feature)``. The squared error between ``e`` and ``output`` is
    summed over every non-batch dimension.

    Args:
        model: Callable returning ``(output, feature)``. It is called as
            ``model(x, t.float(), prv_f=prv_f)`` when ``prv_f`` is given and as
            ``model(x, t.float())`` otherwise, so models without a ``prv_f``
            keyword still work when it is omitted.
        x0: Clean inputs, batch dimension first (``(B, ...)``).
        t: Integer index per sample (shape ``(B,)``) into the cumulative
            product of ``1 - b``.
        e: Noise tensor with the same shape as ``x0``.
        b: 1-D tensor of per-step values; presumably the beta (noise)
            schedule — TODO confirm against caller.
        prv_f: Optional value forwarded to the model. Presumably a feature
            from a previous call, reused for caching — confirm with the model.
        keepdim: If True, return per-sample losses; otherwise their batch mean.

    Returns:
        ``(loss, feature)``. ``loss`` has shape ``(B,)`` when ``keepdim`` is
        True, else it is a scalar tensor. ``feature`` is whatever the model
        returned, passed through unchanged.
    """
    # Shape a_t as (B, 1, ..., 1) so it broadcasts per-sample against x0 of
    # any rank; a fixed view(-1, 1, 1, 1) only lines up with 4-D inputs.
    bcast_shape = (-1,) + (1,) * (x0.dim() - 1)
    a = (1 - b).cumprod(dim=0).index_select(0, t).view(bcast_shape)
    x = x0 * a.sqrt() + e * (1.0 - a).sqrt()
    if prv_f is not None:
        output, feature = model(x, t.float(), prv_f = prv_f)
    else:
        output, feature = model(x, t.float())

    sq_err = (e - output).square()
    # Reduce every non-batch dim. An empty dim tuple would make torch sum over
    # *all* dims, so 1-D inputs (already per-sample) are left untouched.
    reduce_dims = tuple(range(1, sq_err.dim()))
    per_sample = sq_err.sum(dim=reduce_dims) if reduce_dims else sq_err

    if keepdim:
        return per_sample, feature
    else:
        return per_sample.mean(dim=0), feature


# Lookup table from loss name to the loss function implementing it.
loss_registry = dict(
    simple=noise_estimation_loss,
    deepcache=noise_deepcache_loss,
)
