import torch
import numpy as np

def eps_loss(x, output):
    """Return the batch-averaged ``loss_simple`` term from a model output dict.

    Args:
        x: Unused; accepted so this loss shares the ``(x, output)`` call
            signature of the other loss helpers in this module.
        output: Mapping that must contain a ``"loss"`` key and a nested
            ``output["loss_dict"]["loss_simple"]`` tensor.

    Returns:
        A 0-dim tensor: ``loss_simple`` summed over dims 1, 2 and 3, then
        averaged over all remaining elements.

    Raises:
        AssertionError: if ``"loss"`` is missing from ``output`` (only when
            assertions are enabled, i.e. not under ``python -O``).
        KeyError: if ``"loss_dict"`` or ``"loss_simple"`` is missing.
    """
    # NOTE(review): the guard checks "loss", yet the value below is read from
    # "loss_dict". The original check is kept for backward compatibility;
    # confirm whether it should guard "loss_dict" instead.
    assert "loss" in output, "eps_loss: model output has no 'loss' entry"
    loss_simple = output["loss_dict"]["loss_simple"]
    # Assumes loss_simple has at least 4 dims (e.g. batch, C, H, W) -- TODO confirm.
    return loss_simple.sum([1, 2, 3]).mean()

def pos_pair_loss(x, output):
    """Squared distance between ``output["pos"]`` and ``output["pival"]``.

    The element-wise squared difference is summed over dims 1, 2 and 3,
    and the result is averaged over all remaining elements. ``x`` is not
    used; it is accepted only to keep the ``(x, output)`` signature.
    """
    residual = output["pos"] - output["pival"]
    per_sample = residual.pow(2).sum([1, 2, 3])
    return per_sample.mean()

def neg_pair_loss(x, output):
    """Squared distance between ``output["neg"]`` and ``output["ori_neg"]``.

    Computes ``mean(sum_{dims 1,2,3} (neg - ori_neg)^2)``. The final mean
    runs over every element left after the reduction. ``x`` is ignored.
    """
    return torch.mean(
        torch.sum(torch.pow(output["neg"] - output["ori_neg"], 2), dim=(1, 2, 3))
    )

def uc_pair_loss(x, output):
    """Squared distance between ``output["uc"]`` and ``output["ori_uc"]``.

    Squared differences are summed over dims 1, 2 and 3, then averaged over
    the remaining elements. ``x`` is unused and present only for signature
    compatibility.
    """
    uc, ori_uc = output["uc"], output["ori_uc"]
    squared = (uc - ori_uc) ** 2
    return squared.sum(dim=[1, 2, 3]).mean()

def pival_loss(x, output):
    """Squared distance between ``output["new_pival"]`` and ``output["pival"]``.

    Args:
        x: Unused; accepted to keep the ``(x, output)`` signature shared by
            the other loss helpers in this module.
        output: Mapping containing same-shaped ``"new_pival"`` and
            ``"pival"`` tensors with at least 4 dims.

    Returns:
        A 0-dim tensor: the squared difference summed over dims 1, 2 and 3,
        then averaged over all remaining elements.
    """
    # Local renamed from the copy-pasted ``uc_pair`` so it matches what is computed.
    pival_dist = ((output["new_pival"] - output["pival"]) ** 2).sum([1, 2, 3]).mean()
    return pival_dist