# coding: utf-8

import torch

def signed_logaddexp(x1, x2, sign1, sign2, return_sign=False, eps=1e-8):
	"""Return ``log|sign1*exp(x1) + sign2*exp(x2)|`` computed stably.

	``x1`` and ``x2`` are log-magnitudes of the two terms and ``sign1`` /
	``sign2`` are their signs (scalars or tensors broadcastable against
	``x1`` / ``x2``). The exponentials are shifted by ``max(x1, x2)``
	(log-sum-exp trick), so the larger term becomes ``sign * 1``. This
	avoids needless overflow and underflow.

	Args:
		x1, x2: tensors of log-magnitudes (broadcastable).
		sign1, sign2: signs of the two terms.
		return_sign: if True, also return the elementwise sign of the sum.
		eps: lower bound applied to the magnitude of the shifted sum before
			the log. An exact cancellation therefore yields
			``shift + log(eps)`` rather than ``-inf``.

	Returns:
		``out``. If ``return_sign`` is True, returns the tuple
		``(out, sign)``, where ``sign`` is ``scaled_sum.sign()``
		(-1, 0 or +1 elementwise).
	"""
	# Shift by the larger log-magnitude, not the larger absolute value.
	# Shifting by |x| makes very negative inputs underflow to 0.
	shift = torch.maximum(x1, x2)
	# A non-finite shift would produce inf - inf = nan below. Following
	# scipy.special.logsumexp, skip the shift for those entries.
	shift = torch.where(torch.isfinite(shift), shift, torch.zeros_like(shift))
	scaled_sum = sign1 * (x1 - shift).exp() + sign2 * (x2 - shift).exp()
	# Take the magnitude *before* clamping. Clamping first would map every
	# negative sum to eps and discard its magnitude.
	out = shift + scaled_sum.abs().clamp_min(eps).log()
	if return_sign:
		return out, scaled_sum.sign()
	return out