#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import torch
from helpers import get_fshifts, get_fullsize
import time


def pred_loss(feat, neighbors, priorw, coupling, prec=1,
              device=None, **kwargs):
    """evaluates the prediction loss for a feature map and parameters for
    the neighborhood connections. This is used for training the
    parameters of the random field.

    For every neighbour offset, the un-shifted features of image ``i`` are
    paired with the shifted features of every image ``j`` in the batch. The
    resulting log-factors are accumulated per pixel. The loss contrasts the
    score of each image paired with itself against the log-sum-exp over all
    images supplying the un-shifted features.

    Parameters
    ----------
    feat : torch.Tensor
        Batch of featuremaps for which we want to optimize the prediction
        (indexed as batch x features x height x width below)
    neighbors : list or numpy.ndarray
        integer indices for the neighbors, one (row, col) offset per row
    priorw : torch.Tensor
        n_neigh long vector of log-prior ratios w=1 vs w=0 for each neighbor
    coupling : torch.Tensor
        n_neigh x n_features matrix of log-coupling strengths ("precisions")
    prec : torch.Tensor
        n_feat long vector of feature precisions
        (default=1, which is adequate for previous normalisation)
    device : torch.device or string, optional
        device to create intermediate Tensors on. The default is None,
        which means feat.device.

    Returns
    -------
    loss : torch.Tensor
        scalar loss for optimization, i.e. the negative mean log-ratio of the
        prediction from each image to itself vs. the summed prediction over
        images

    """
    if device is None:
        device = feat.device
    neighbors = np.array(neighbors)
    c = torch.exp(coupling)
    # log-prior of the coupled (index 0) and uncoupled (index 1) state.
    # logsigmoid(x) == x - log(1 + exp(x)) but does not overflow for large x.
    log_p = torch.nn.functional.logsigmoid(
        torch.stack([priorw, -priorw], -1))
    # selects log_f for the coupled state and 0 for the uncoupled state
    sel = torch.tensor([[[[1.0]]], [[[0.0]]]], device=device)
    # Derived on paper:
    # factor is sqrt(det(prec))
    # det(prec) of coupled Gaussian for one feature is (prec**2 + 2*prec*C)
    # det(prec) of uncoupled Gaussian is  prec**2
    # thus the log normalizing factor between the two is
    # 0.5 * sum_features log(1 + 2 * C/prec)
    log_n = torch.log(1 + 2 * c / prec).sum(dim=1) / 2
    c = c[..., None, None]

    def _edge_slice(full, part, at_end):
        # slice of length `part` taken from the end (True) or start of `full`
        return slice(full - part, None) if at_end else slice(None, part)

    n_img, _, height, width = feat.shape
    # fact[i, j] accumulates (over all neighbours) the log-factors from
    # pairing the un-shifted features of image i with the shifted features
    # of image j.
    fact = torch.zeros(n_img, n_img, height, width, device=device)
    for i_neigh, neigh in enumerate(neighbors):
        fsmall, fshiftsmall = get_fshifts(feat, neigh)
        log_ps = log_p[i_neigh][..., None, None, None]
        # The two crops of the full map covered by a pair at this offset.
        # Each factor is put into both slots (prediction in both
        # directions), so which crop holds which member is irrelevant here.
        h, w = fsmall.shape[-2], fsmall.shape[-1]
        rows_a = _edge_slice(height, h, neigh[0] >= 0)
        cols_a = _edge_slice(width, w, neigh[1] >= 0)
        rows_b = _edge_slice(height, h, neigh[0] < 0)
        cols_b = _edge_slice(width, w, neigh[1] < 0)
        for i_img in range(fsmall.shape[0]):
            # precision-weighted squared differences, summed over features
            dist = torch.sum(
                c[i_neigh] * (fshiftsmall - fsmall[i_img]) ** 2, dim=1)
            log_f = log_n[i_neigh] - (dist / 2)
            # stable form of log((1 - p) + p * normalizer * exp(-dist / 2))
            factor = torch.logsumexp((sel * log_f) + log_ps, 0)
            fact[i_img, :, rows_a, cols_a] += factor
            fact[i_img, :, rows_b, cols_b] += factor
    # A global shift cancels in the log-ratio below; it only keeps the
    # values in a comfortable numerical range.
    fact = fact - torch.max(fact)
    loss = -torch.mean(
        torch.einsum('iijk->ijk', fact) - torch.logsumexp(fact, 0))
    return loss


def pred_loss_shuffle(feat, neighbors, priorw, coupling, prec=1,
                      device=None, **kwargs):
    """evaluates the prediction loss for a feature map and parameters for
    the neighborhood connections. This is used for training the
    parameters of the random field.
    This version uses a shuffled version of the original image for
    normalisation per factor. This trains each factor individually.

    Parameters
    ----------
    feat : torch.Tensor
        Batch of featuremaps for which we want to optimize the prediction
    neighbors : list or numpy.ndarray
        integer indices for the neighbors
    priorw : torch.Tensor
        n_neigh long vector of log-prior ratios w=1 vs w=0 for each neighbor
    coupling : torch.Tensor
        n_neigh x n_features matrix of log-coupling strengths ("precisions")
    prec : torch.Tensor
        n_feat long vector of feature precisions
        (default=1, which is adequate for previous normalisation)
    device : torch.device or string, optional
        not used by this function (all intermediates follow the devices of
        the inputs); accepted for signature compatibility with the other
        losses. The default is None.

    Returns
    -------
    loss : torch.Tensor
        loss for optimization, i.e. log-ratio of predictions for the correct
        image vs. the shuffled image, summed over neighbors (a zero tensor
        when there are no neighbors).

    """
    neighbors = np.array(neighbors)
    c = torch.exp(coupling)
    # log-prior of the coupled (index 0) and uncoupled (index 1) state.
    # logsigmoid needs no constant tensor, so it works for inputs on any
    # device without knowing `device`.
    log_p = torch.nn.functional.logsigmoid(
        torch.stack([priorw, -priorw], -1))
    # Derived on paper:
    # factor is sqrt(det(prec))
    # det(prec) of coupled Gaussian for one feature is (prec**2 + 2*prec*C)
    # det(prec) of uncoupled Gaussian is  prec**2
    # thus the log normalizing factor between the two is
    # 0.5 * sum_features log(1 + 2 * C/prec)
    log_n = torch.log(1 + 2 * c / prec).sum(dim=1) / 2
    c = c[..., None, None]
    loss = torch.zeros((), dtype=feat.dtype, device=feat.device)
    for i_neigh, neigh in enumerate(neighbors):
        fsmall, fshiftsmall = get_fshifts(feat, neigh)
        n_b, n_c, h, w = fsmall.shape
        # permute all (image, row, col) positions jointly while keeping the
        # feature vector at each position intact
        perm = torch.randperm(n_b * h * w)
        fsmall_shuffle = (
            fsmall.permute(1, 0, 2, 3).reshape(n_c, -1)[:, perm]
            .view([n_c, n_b, h, w]).permute(1, 0, 2, 3))
        # precision-weighted squared differences, summed over features
        dist = torch.sum(c[i_neigh] * (fshiftsmall - fsmall) ** 2, dim=1)
        dist_shuffle = torch.sum(
            c[i_neigh] * (fshiftsmall - fsmall_shuffle) ** 2, dim=1)
        log_f = log_n[i_neigh] - (dist / 2)
        log_f_shuffle = log_n[i_neigh] - (dist_shuffle / 2)
        # stable form of log((1 - p) + p * normalizer * exp(-dist / 2))
        factor = torch.logaddexp(log_p[i_neigh][1],
                                 log_p[i_neigh][0] + log_f)
        factor_shuffle = torch.logaddexp(log_p[i_neigh][1],
                                         log_p[i_neigh][0] + log_f_shuffle)

        # InfoNCE per factor: every true pair is contrasted with the summed
        # factors of all shuffled pairs
        log_norm = torch.logsumexp(factor_shuffle.flatten(), 0)
        loss = loss + torch.mean(torch.logaddexp(factor, log_norm) - factor)
    return loss


def pred_loss_pos(feat, neighbors, priorw, coupling, prec=1,
                  device=None, n_pos=10, stop_grad=False, **kwargs):
    """evaluates the prediction loss for a feature map and parameters for
    the neighborhood connections. This is used for training the
    parameters of the random field.
    This version uses a random collection of positions with the possibility
    for a stop-gradient for the normalization

    Parameters
    ----------
    feat : torch.Tensor
        Batch of featuremaps for which we want to optimize the prediction
        (indexed as batch x features x height x width below)
    neighbors : list or numpy.ndarray
        integer indices for the neighbors
    priorw : torch.Tensor
        n_neigh long vector of log-prior ratios w=1 vs w=0 for each neighbor
    coupling : torch.Tensor
        n_neigh x n_features matrix of log-coupling strengths ("precisions")
    prec : torch.Tensor
        n_feat long vector of feature precisions
        (default=1, which is adequate for previous normalisation)
    device : torch.device or string, optional
        device to create intermediate Tensors on. The default is None,
        which means feat.device.
    n_pos : int
        number of random positions to use as negative samples
    stop_grad : bool
        if True, the feature vectors sampled as negatives are detached

    Returns
    -------
    loss : torch.Tensor
        loss for optimization, i.e. ratio of prediction factors at a location
        vs the prediction function for randomly chosen values

    """
    if device is None:
        device = feat.device
    neighbors = np.array(neighbors)
    c = torch.exp(coupling)
    # log-prior of the coupled (index 0) and uncoupled (index 1) state
    log_p = torch.nn.functional.logsigmoid(
        torch.stack([priorw, -priorw], -1))
    # Derived on paper:
    # factor is sqrt(det(prec))
    # det(prec) of coupled Gaussian for one feature is (prec**2 + 2*prec*C)
    # det(prec) of uncoupled Gaussian is  prec**2
    # thus the log normalizing factor between the two is
    # 0.5 * sum_features log(1 + 2 * C/prec)
    log_n = torch.log(1 + 2 * c / prec).sum(dim=1) / 2
    c = c[..., None, None]

    def _log_factor(dist, i_neigh):
        # stable form of log((1 - p) + p * normalizer * exp(-dist / 2))
        log_f = log_n[i_neigh] - (dist / 2)
        return torch.logaddexp(log_f + log_p[i_neigh, 0], log_p[i_neigh, 1])

    n_img, _, height, width = feat.shape
    fact = torch.zeros(n_img, height, width, device=device)
    fact_n = torch.zeros(n_pos, n_img, height, width, device=device)
    # choose randomly sampled normalization positions
    indices = np.random.randint(n_img * height * width, size=n_pos)
    pos = np.unravel_index(indices, (n_img, height, width))
    # advanced indices separated by a slice -> shape (n_pos, n_feat)
    norm_vects = feat[pos[0], :, pos[1], pos[2]]
    if stop_grad:
        norm_vects = norm_vects.detach()
    # (n_pos, 1, n_feat, 1, 1) so it broadcasts against the feature crops
    norm_b = norm_vects[:, None, :, None, None]
    for i_neigh, neigh in enumerate(neighbors):
        fsmall, fshiftsmall = get_fshifts(feat, neigh)
        c_i = c[i_neigh]
        # precision-weighted squared differences, summed over features
        dist = torch.sum(c_i * (fshiftsmall - fsmall) ** 2, dim=1)
        dist_n1 = torch.sum(c_i * (fshiftsmall - norm_b) ** 2, dim=2)
        dist_n2 = torch.sum(c_i * (fsmall - norm_b) ** 2, dim=2)
        factor = _log_factor(dist, i_neigh)
        factor_n1 = _log_factor(dist_n1, i_neigh)
        factor_n2 = _log_factor(dist_n2, i_neigh)
        # put the factors into the predicted place to allow normalizing;
        # each factor goes into two slots because prediction runs in both
        # directions
        fact += get_fullsize(factor, neigh)
        fact += get_fullsize(factor, -neigh)
        fact_n += get_fullsize(factor_n1, -neigh)
        fact_n += get_fullsize(factor_n2, neigh)
    fact_norm = torch.cat((fact_n, fact.unsqueeze(0)), 0)
    loss = -torch.mean(fact - torch.logsumexp(fact_norm, 0))
    return loss


def pred_loss_pos2(feat, neighbors, priorw, coupling, prec=1,
                   device=None, n_pos=10, stop_grad=True, **kwargs):
    """evaluates the prediction loss for a feature map and parameters for
    the neighborhood connections. This is used for training the
    parameters of the random field.
    This version uses a random collection of positions with a stop-gradient
    for the normalization

    Parameters
    ----------
    feat : torch.Tensor
        Batch of featuremaps for which we want to optimize the prediction
        (indexed as batch x features x height x width below)
    neighbors : list or numpy.ndarray
        integer indices for the neighbors, one (row, col) offset per row
    priorw : torch.Tensor
        n_neigh long vector of log-prior ratios w=1 vs w=0 for each neighbor
    coupling : torch.Tensor
        n_neigh x n_features matrix of log-coupling strengths ("precisions")
    prec : torch.Tensor
        n_feat long vector of feature precisions
        (default=1, which is adequate for previous normalisation)
    device : torch.device or string, optional
        device to create intermediate Tensors on. The default is None,
        which means feat.device.
    n_pos : int
        number of random positions to use as negative samples
    stop_grad : bool
        if True (default), the feature vectors sampled as negatives are
        detached

    Returns
    -------
    loss : torch.Tensor
        loss for optimization, i.e. ratio of prediction factors at a location
        vs the prediction function for randomly chosen values

    """
    if device is None:
        device = feat.device
    neighbors = np.array(neighbors)
    c = torch.exp(coupling)
    # log-prior of the coupled (index 0) and uncoupled (index 1) state.
    # logsigmoid(x) == x - log(1 + exp(x)) but does not overflow for large x.
    log_p = torch.nn.functional.logsigmoid(
        torch.stack([priorw, -priorw], -1))
    # selects log_f for the coupled state and 0 for the uncoupled state
    sel = torch.tensor([[[[1.0]]], [[[0.0]]]], device=device)
    # Derived on paper:
    # factor is sqrt(det(prec))
    # det(prec) of coupled Gaussian for one feature is (prec**2 + 2*prec*C)
    # det(prec) of uncoupled Gaussian is  prec**2
    # thus the log normalizing factor between the two is
    # 0.5 * sum_features log(1 + 2 * C/prec)
    log_n = torch.log(1 + 2 * c / prec).sum(dim=1) / 2
    c = c[..., None, None]

    def _edge_slice(full, part, at_end):
        # slice of length `part` taken from the end (True) or start of `full`
        return slice(full - part, None) if at_end else slice(None, part)

    n_img, _, height, width = feat.shape
    fact = torch.zeros(n_img, height, width, device=device)
    fact_n = torch.zeros(n_pos, n_img, height, width, device=device)
    # choose randomly sampled normalization positions
    indices = np.random.randint(n_img * height * width, size=n_pos)
    pos = np.unravel_index(indices, (n_img, height, width))
    # advanced indices separated by a slice -> shape (n_pos, n_feat)
    norm_vects = feat[pos[0], :, pos[1], pos[2]]
    if stop_grad:
        norm_vects = norm_vects.detach()
    # (n_pos, 1, n_feat, 1, 1) so it broadcasts against the feature crops
    norm_b = norm_vects[:, None, :, None, None]
    sel_n = sel.unsqueeze(1)
    for i_neigh, neigh in enumerate(neighbors):
        fsmall, fshiftsmall = get_fshifts(feat, neigh)
        log_ps = log_p[i_neigh][..., None, None, None]
        c_i = c[i_neigh]
        # precision-weighted squared differences, summed over features
        dist = torch.sum(c_i * (fshiftsmall - fsmall) ** 2, dim=1)
        dist_n1 = torch.sum(c_i * (fshiftsmall - norm_b) ** 2, dim=2)
        dist_n2 = torch.sum(c_i * (fsmall - norm_b) ** 2, dim=2)

        # stable form of log((1 - p) + p * normalizer * exp(-dist / 2))
        log_f = log_n[i_neigh] - (dist / 2)
        log_f_n1 = log_n[i_neigh] - (dist_n1 / 2)
        log_f_n2 = log_n[i_neigh] - (dist_n2 / 2)
        factor = torch.logsumexp((sel * log_f) + log_ps, 0)
        factor_n1 = torch.logsumexp((sel_n * log_f_n1) + log_ps.unsqueeze(1), 0)
        factor_n2 = torch.logsumexp((sel_n * log_f_n2) + log_ps.unsqueeze(1), 0)

        # NOTE(review): assumes get_fshifts returns fsmall at positions x and
        # fshiftsmall at x + neigh, both cropped to the valid overlap — TODO
        # confirm. Under that convention the crop of x + neigh lies at the
        # end of an axis for a non-negative offset and at the start for a
        # negative one; the crop of x lies on the opposite side.
        h, w = dist.shape[-2], dist.shape[-1]
        rows_shift = _edge_slice(height, h, neigh[0] >= 0)
        cols_shift = _edge_slice(width, w, neigh[1] >= 0)
        rows_base = _edge_slice(height, h, neigh[0] < 0)
        cols_base = _edge_slice(width, w, neigh[1] < 0)
        # each factor goes into both slots (prediction in both directions)
        fact[:, rows_shift, cols_shift] += factor
        fact[:, rows_base, cols_base] += factor
        # factor_n1 replaces the features at x by a normalisation vector, so
        # it is added at the base crop; factor_n2 replaces the features at
        # x + neigh and is added at the shifted crop. The same mapping holds
        # for every sign combination of the offset.
        fact_n[:, :, rows_base, cols_base] += factor_n1
        fact_n[:, :, rows_shift, cols_shift] += factor_n2
    fact_norm = torch.cat((fact_n, fact.unsqueeze(0)), 0)
    loss = -torch.mean(fact - torch.logsumexp(fact_norm, 0))
    return loss


def pred_loss_pos3(feat, neighbors, priorw, coupling, prec=1,
                   device=None, n_pos=100, stop_grad=True, **kwargs):
    """evaluates the prediction loss for a feature map and parameters for
    the neighborhood connections. This is used for training the
    parameters of the random field.
    This version uses a random collection of positions with a stop-gradient
    for the normalization.
    The normalisation term is accumulated one sampled position at a time;
    with the stop-gradient nothing of it is kept for backpropagation, which
    allows large n_pos without memory problems.

    Parameters
    ----------
    feat : torch.Tensor
        Batch of featuremaps for which we want to optimize the prediction
        (indexed as batch x features x height x width below)
    neighbors : list or numpy.ndarray
        integer indices for the neighbors
    priorw : torch.Tensor
        n_neigh long vector of log-prior ratios w=1 vs w=0 for each neighbor
    coupling : torch.Tensor
        n_neigh x n_features matrix of log-coupling strengths ("precisions")
    prec : torch.Tensor
        n_feat long vector of feature precisions
        (default=1, which is adequate for previous normalisation)
    device : torch.device or string, optional
        device to create intermediate Tensors on. The default is None,
        which means feat.device.
    n_pos : int
        number of random positions to use as negative samples
    stop_grad : bool
        if True (default), the whole normalisation term is detached

    Returns
    -------
    loss : torch.Tensor
        loss for optimization, i.e. ratio of prediction factors at a location
        vs the prediction function for randomly chosen values

    """
    if device is None:
        device = feat.device
    neighbors = np.array(neighbors)
    c = torch.exp(coupling)
    # log-prior of the coupled (index 0) and uncoupled (index 1) state.
    # logsigmoid(x) == x - log(1 + exp(x)) but does not overflow for large x.
    log_p = torch.nn.functional.logsigmoid(
        torch.stack([priorw, -priorw], -1))
    # selects log_f for the coupled state and 0 for the uncoupled state
    sel = torch.tensor([[[[1.0]]], [[[0.0]]]], device=device)
    # Derived on paper:
    # factor is sqrt(det(prec))
    # det(prec) of coupled Gaussian for one feature is (prec**2 + 2*prec*C)
    # det(prec) of uncoupled Gaussian is  prec**2
    # thus the log normalizing factor between the two is
    # 0.5 * sum_features log(1 + 2 * C/prec)
    log_n = torch.log(1 + 2 * c / prec).sum(dim=1) / 2
    c = c[..., None, None]

    def _log_factor(dist, log_n_i, log_p_i):
        # stable form of log((1 - p) + p * normalizer * exp(-dist / 2))
        log_f = log_n_i - (dist / 2)
        return torch.logsumexp((sel * log_f) + log_p_i[..., None, None, None], 0)

    n_img, _, height, width = feat.shape
    fact = torch.zeros(n_img, height, width, device=device)
    # choose randomly sampled normalization positions
    indices = np.random.randint(n_img * height * width, size=n_pos)
    pos = np.unravel_index(indices, (n_img, height, width))
    # advanced indices separated by a slice -> shape (n_pos, n_feat)
    norm_vects = feat[pos[0], :, pos[1], pos[2]]
    if stop_grad:
        norm_vects = norm_vects.detach()
    for i_neigh, neigh in enumerate(neighbors):
        fsmall, fshiftsmall = get_fshifts(feat, neigh)
        # precision-weighted squared differences, summed over features
        dist = torch.sum(c[i_neigh] * (fshiftsmall - fsmall) ** 2, dim=1)
        factor = _log_factor(dist, log_n[i_neigh], log_p[i_neigh])
        # put the factor into the predicted place; it goes into two slots
        # because prediction runs in both directions
        fact += get_fullsize(factor, neigh)
        fact += get_fullsize(factor, -neigh)

    # Parameters entering the normalisation term; detached under stop_grad
    # so that no graph is built for the n_pos loop below.
    if stop_grad:
        fact_norm = fact.detach()
        c_norm, log_n_norm, log_p_norm = c.detach(), log_n.detach(), log_p.detach()
    else:
        fact_norm = fact
        c_norm, log_n_norm, log_p_norm = c, log_n, log_p
    for norm_v in norm_vects:
        # (1, n_feat, 1, 1) so it broadcasts against the feature crops
        norm_b = norm_v[None, :, None, None]
        fact_n = torch.zeros(n_img, height, width, device=device)
        for i_neigh, neigh in enumerate(neighbors):
            fsmall, fshiftsmall = get_fshifts(feat, neigh)
            if stop_grad:
                fsmall = fsmall.detach()
                fshiftsmall = fshiftsmall.detach()
            dist_n1 = torch.sum(
                c_norm[i_neigh] * (fshiftsmall - norm_b) ** 2, dim=1)
            dist_n2 = torch.sum(
                c_norm[i_neigh] * (fsmall - norm_b) ** 2, dim=1)
            # each neighbour uses its own prior (log_p_norm[i_neigh])
            factor_n1 = _log_factor(dist_n1, log_n_norm[i_neigh],
                                    log_p_norm[i_neigh])
            factor_n2 = _log_factor(dist_n2, log_n_norm[i_neigh],
                                    log_p_norm[i_neigh])
            fact_n += get_fullsize(factor_n1, -neigh)
            fact_n += get_fullsize(factor_n2, neigh)
        fact_norm = torch.logaddexp(fact_norm, fact_n)
    loss = -torch.mean(fact - fact_norm)
    return loss
