from scipy.integrate import quad
import torch


def grad_loss(y, a, y0, y1: torch.Tensor, return_grad=False):
    """Mean of ``(y1 - y) * (2a - 1) * 1[y <= threshold]`` over all elements.

    The threshold mixes the two candidate thresholds elementwise as
    ``y1*a + (1-a)*y0``; ``a`` is presumably a 0/1 treatment indicator
    (confirm against callers).

    Args:
        y: observed values compared against the threshold.
        a: mixing weights / indicator selecting ``y1`` vs ``y0``.
        y0: threshold used where ``a == 0``.
        y1: threshold used where ``a == 1``; the optional gradient is taken
            with respect to this tensor.
        return_grad: if True, ``y1`` is switched to ``requires_grad`` in
            place, ``backward()`` is run on the loss, and ``y1.grad`` is
            returned as well.

    Returns:
        The scalar mean, or ``(loss, y1.grad)`` when ``return_grad`` is set.
        ``backward()`` accumulates, so ``y1.grad`` may also include gradient
        already stored on ``y1`` before this call.
    """
    if return_grad:
        # Mutates the caller's tensor so autograd tracks it and keeps .grad.
        y1.requires_grad_(True)
        y1.retain_grad()

    threshold = y1*a + (1-a)*y0
    # A comparison mask: it is constant for autograd purposes.
    below = (y <= threshold).float()
    direction = 2*a - 1
    loss = torch.mean((y1 - y)*direction*below)

    if not return_grad:
        return loss
    loss.backward()
    return loss, y1.grad


def ipw_loss(y, a, y0, y1: torch.Tensor, propensities, return_grad=False):
    """Inverse-propensity-weighted mean of ``(y1 - y) * 1[y <= threshold]``.

    Each element is weighted by ``a/propensities - (1-a)/(1-propensities)``,
    and the threshold is ``y1*a + (1-a)*y0``. ``a`` is presumably a 0/1
    treatment indicator and ``propensities`` values in (0, 1); confirm
    against callers.

    Args:
        y: observed values compared against the threshold.
        a: mixing weights / indicator selecting ``y1`` vs ``y0``.
        y0: threshold used where ``a == 0``.
        y1: threshold used where ``a == 1``; gradient target.
        propensities: denominators of the IPW weights.
        return_grad: if True, enable grad on ``y1`` in place, backpropagate,
            and return ``y1.grad`` too.

    Returns:
        The scalar mean, or ``(loss, y1.grad)`` when ``return_grad`` is set
        (``y1.grad`` accumulates across ``backward()`` calls).
    """
    if return_grad:
        # Mutates the caller's tensor so autograd tracks it and keeps .grad.
        y1.requires_grad_(True)
        y1.retain_grad()

    threshold = y1*a + (1-a)*y0
    ipw = a/propensities - (1-a)/(1-propensities)
    below = (y <= threshold).float()
    loss = torch.mean((y1 - y)*ipw*below)

    if not return_grad:
        return loss
    loss.backward()
    return loss, y1.grad


# Note: this is not the actual loss, just a surrogate function whose gradient with respect to y1 is the desired one.
def dr_loss(y, a, y0, y1: torch.Tensor, propensities, cdf_0s, cdf_1s, return_grad=False):
    """Surrogate whose gradient w.r.t. ``y1`` is the intended DR update.

    Returns the mean of ``(y1 - y) * (w * (v - m) + cdf_1s - cdf_0s)`` where
    ``w = a/propensities - (1-a)/(1-propensities)``,
    ``v = 1[y <= y1*a + (1-a)*y0]`` and ``m = a*cdf_1s + (1-a)*cdf_0s``.
    The CDF tensors are detached, so no gradient flows into them.
    ``a`` is presumably a 0/1 treatment indicator (confirm against callers).

    Args:
        y: observed values compared against the threshold.
        a: mixing weights / indicator selecting ``y1`` vs ``y0``.
        y0: threshold used where ``a == 0``.
        y1: threshold used where ``a == 1``; gradient target.
        propensities: denominators of the IPW weights.
        cdf_0s: control-arm CDF values (detached here).
        cdf_1s: treated-arm CDF values (detached here).
        return_grad: if True, enable grad on ``y1`` in place, backpropagate,
            and return ``y1.grad`` too.

    Returns:
        The scalar mean, or ``(value, y1.grad)`` when ``return_grad`` is set
        (``y1.grad`` accumulates across ``backward()`` calls).
    """
    cdf_0s, cdf_1s = cdf_0s.detach(), cdf_1s.detach()
    if return_grad:
        # Mutates the caller's tensor so autograd tracks it and keeps .grad.
        y1.requires_grad_(True)
        y1.retain_grad()

    threshold = y1*a + (1-a)*y0
    model_cdf = a*cdf_1s + (1-a)*cdf_0s
    ipw = a/propensities - (1-a)/(1-propensities)
    below = (y <= threshold).float()
    # IPW residual of the indicator plus the model-based CDF difference.
    correction = ipw*(below - model_cdf) + cdf_1s - cdf_0s
    value = torch.mean((y1 - y)*correction)

    if not return_grad:
        return value
    value.backward()
    return value, y1.grad


def true_dr_loss(y, x, a, y0, y1: torch.Tensor, propensities, cdf_0s, cdf_1_func, w1=None):
    """Per-sample DR loss with the treated-arm CDF term integrated numerically.

    The result is ``out_sample + cdf_part * (propensities - a) / propensities``.
    For each sample ``i``:

    * ``out_sample[i] = (y1[i]-y[i]) * (w[i]*(v[i] - (1-a[i])*cdf_0s[i]) - cdf_0s[i])``
      with ``w = a/propensities - (1-a)/(1-propensities)`` and
      ``v = 1[y <= y1*a + (1-a)*y0]``.
    * ``cdf_part[i]`` is the integral of ``cdf_1_func(t, x[i], w1[i])`` over
      ``t`` from ``y[i]`` to ``y1[i]``, computed with ``scipy.integrate.quad``
      (``epsabs=1e-3``, ``limit=1000``). It is negative when ``y1[i] < y[i]``.

    ``cdf_1_func`` is presumably the treated-arm CDF; confirm against callers.

    Args:
        y: observed values, one per sample.
        x: per-sample extra arguments forwarded to ``cdf_1_func``.
        a: mixing weights / indicator selecting ``y1`` vs ``y0``
            (presumably 0/1).
        y0: threshold used where ``a == 0``.
        y1: threshold used where ``a == 1``; the integral's upper limit.
        propensities: denominators of the IPW weights.
        cdf_0s: control-arm CDF values (detached here).
        cdf_1_func: callable ``f(t, x_i, w1_i) -> float`` integrated by ``quad``.
        w1: optional per-sample arguments forwarded to ``cdf_1_func``.
            Defaults to ``None`` for every sample.

    Returns:
        A tensor of per-sample values; it is not averaged. ``cdf_part`` is
        filled from Python floats, so it carries no autograd gradient.
    """
    cdf_0s = cdf_0s.detach()
    if w1 is None:
        w1 = [None]*len(y0)
    # Closed-form part: only the control-arm CDF appears here; the treated-arm
    # CDF contribution enters through the numerical integral below.
    mixed_threshold = y1*a+(1-a)*y0
    mixed_cdfs = (1-a)*cdf_0s
    weights = a/propensities-(1-a)/(1-propensities)
    values = (y <= mixed_threshold).float()
    out_sample = (y1-y)*(weights*(values-mixed_cdfs)-cdf_0s)
    # Numerical part: integrate cdf_1_func from y_i to y1_i for every sample.
    cdf_part = torch.zeros_like(y1)
    for i, (single_y1, single_y, single_x, single_w1) in enumerate(zip(y1, y, x, w1)):
        cdf_part[i] = quad(cdf_1_func, single_y.item(), single_y1.item(), args=(single_x, single_w1),
                           epsabs=1e-3, limit=1000)[0]
    # The return must sit after the loop. Placed inside it, only the first
    # sample was integrated (the rest stayed 0), and empty input returned None.
    return out_sample+cdf_part*(propensities-a)/propensities


def doubled_dr_loss(y, a, y0, y1: torch.Tensor, propensities, cdf_0s, cdf_1s, return_grad=False):
    """Signed-indicator ("doubled") variant of the DR surrogate.

    Returns the mean of
    ``(y1 - y) * (w * (s + 1 - 2m) + 2 * (cdf_1s - cdf_0s))`` where
    ``w = a/propensities - (1-a)/(1-propensities)``,
    ``m = a*cdf_1s + (1-a)*cdf_0s`` and ``s`` is +1 where
    ``y <= y1*a + (1-a)*y0`` and -1 where ``y`` is above that threshold.
    The CDF tensors are detached, so no gradient flows into them.

    Args:
        y: observed values compared against the threshold.
        a: mixing weights / indicator selecting ``y1`` vs ``y0``
            (presumably 0/1).
        y0: threshold used where ``a == 0``.
        y1: threshold used where ``a == 1``; gradient target.
        propensities: denominators of the IPW weights.
        cdf_0s: control-arm CDF values (detached here).
        cdf_1s: treated-arm CDF values (detached here).
        return_grad: if True, enable grad on ``y1`` in place, backpropagate,
            and return ``y1.grad`` too.

    Returns:
        The scalar mean, or ``(value, y1.grad)`` when ``return_grad`` is set
        (``y1.grad`` accumulates across ``backward()`` calls).
    """
    cdf_0s, cdf_1s = cdf_0s.detach(), cdf_1s.detach()
    if return_grad:
        # Mutates the caller's tensor so autograd tracks it and keeps .grad.
        y1.requires_grad_(True)
        y1.retain_grad()

    threshold = y1*a + (1-a)*y0
    model_cdf = a*cdf_1s + (1-a)*cdf_0s
    ipw = a/propensities - (1-a)/(1-propensities)
    below = (y <= threshold).float()
    above = (y > threshold).float()
    sign = below - above
    correction = ipw*(sign + 1 - 2*model_cdf) + 2*(cdf_1s - cdf_0s)
    value = torch.mean((y1 - y)*correction)

    if not return_grad:
        return value
    value.backward()
    return value, y1.grad


def doubled_grad_loss(y, a, y0, y1: torch.Tensor, return_grad=False):
    """Signed-indicator ("doubled") variant of ``(y1 - y) * (2a - 1)`` loss.

    Elements with ``y <= y1*a + (1-a)*y0`` contribute
    ``(y1 - y) * (2a - 1)``; the remaining elements contribute
    ``(y - y1) * (2a - 1)``. The mean over all elements is returned.

    Args:
        y: observed values compared against the threshold.
        a: mixing weights / indicator selecting ``y1`` vs ``y0``
            (presumably 0/1).
        y0: threshold used where ``a == 0``.
        y1: threshold used where ``a == 1``; gradient target.
        return_grad: if True, enable grad on ``y1`` in place, backpropagate,
            and return ``y1.grad`` too.

    Returns:
        The scalar mean, or ``(loss, y1.grad)`` when ``return_grad`` is set
        (``y1.grad`` accumulates across ``backward()`` calls).
    """
    if return_grad:
        # Mutates the caller's tensor so autograd tracks it and keeps .grad.
        y1.requires_grad_(True)
        y1.retain_grad()

    threshold = y1*a + (1-a)*y0
    below = (y <= threshold).float()
    direction = 2*a - 1
    under_part = (y1 - y)*direction*below
    over_part = (y - y1)*direction*(1 - below)
    loss = torch.mean(under_part + over_part)

    if not return_grad:
        return loss
    loss.backward()
    return loss, y1.grad


def doubled_ipw_loss(y, a, y0, y1: torch.Tensor, propensities, return_grad=False):
    """Signed-indicator ("doubled") variant of the IPW loss.

    Each element is weighted by ``a/propensities - (1-a)/(1-propensities)``.
    It contributes ``(y1 - y)`` when ``y <= y1*a + (1-a)*y0`` and
    ``(y - y1)`` otherwise. The mean over all elements is returned.

    Args:
        y: observed values compared against the threshold.
        a: mixing weights / indicator selecting ``y1`` vs ``y0``
            (presumably 0/1).
        y0: threshold used where ``a == 0``.
        y1: threshold used where ``a == 1``; gradient target.
        propensities: denominators of the IPW weights.
        return_grad: if True, enable grad on ``y1`` in place, backpropagate,
            and return ``y1.grad`` too.

    Returns:
        The scalar mean, or ``(loss, y1.grad)`` when ``return_grad`` is set
        (``y1.grad`` accumulates across ``backward()`` calls).
    """
    if return_grad:
        # Mutates the caller's tensor so autograd tracks it and keeps .grad.
        y1.requires_grad_(True)
        y1.retain_grad()

    threshold = y1*a + (1-a)*y0
    ipw = a/propensities - (1-a)/(1-propensities)
    below = (y <= threshold).float()
    signed_gap = below*(y1 - y) + (1 - below)*(y - y1)
    loss = torch.mean(ipw*signed_gap)

    if not return_grad:
        return loss
    loss.backward()
    return loss, y1.grad
