import torch
from torch.distributions import Normal


# Every loss function below (except the `huber` helper) takes the same arguments,
# so that `train.py` can call them interchangeably.


def huber(losses, percentile=1):
    """Element-wise Huber-style penalty with a data-dependent threshold.

    The threshold is the ``percentile`` quantile of ``|losses|`` (taken over
    all elements). Each element is mapped to the smaller of its square and
    ``c**2 + 2 * c * max(|x| - c, 0)``, i.e. the penalty grows linearly
    instead of quadratically once ``|x|`` exceeds the threshold ``c``.

    Args:
        losses: tensor of residuals.
        percentile: quantile level forwarded to ``torch.quantile``. With the
            default of 1 the threshold is ``max(|losses|)`` and the result
            equals ``losses ** 2``.

    Returns:
        Tensor with the same shape as ``losses``.
    """
    magnitude = losses.abs()
    threshold = torch.quantile(magnitude, percentile)
    quadratic = losses.pow(2)
    excess = (magnitude - threshold).clamp_min(0)
    linear_tail = threshold.pow(2) + 2 * threshold * excess
    return torch.minimum(quadratic, linear_tail)


def no_update(log_pfs, log_pbs, log_fs, log_r, energy, states, coef_matrix, percentile=1):
    """Placeholder loss: a constant scalar zero on the device of ``log_pfs``.

    Every argument other than ``log_pfs`` is ignored; ``log_pfs`` is only
    consulted for its device.
    """
    target_device = log_pfs.device
    return torch.zeros((), device=target_device)


def pis(log_pfs, log_pbs, log_fs, log_r, energy, states, coef_matrix, percentile=1):
    """Mean of ``(sum log_pfs - sum log_pbs - log_r) / energy.ndim``.

    ``log_pfs`` and ``log_pbs`` are summed over their last axis before being
    combined with ``log_r``; the result is scaled by ``1 / energy.ndim`` and
    averaged over all remaining elements. ``log_fs``, ``states``,
    ``coef_matrix`` and ``percentile`` are unused.
    """
    scale = float(1 / energy.ndim)
    forward_total = log_pfs.sum(-1)
    backward_total = log_pbs.sum(-1)
    per_sample = scale * (forward_total - backward_total - log_r)
    return per_sample.mean()


def tlm(log_pfs, log_pbs, log_fs, log_r, energy, states, coef_matrix, percentile=1):
    """Mean negative sum of ``log_pbs`` over its last axis.

    Only ``log_pbs`` is used; all other arguments are ignored.
    """
    return log_pbs.sum(-1).neg().mean()


def tb(log_pfs, log_pbs, log_fs, log_r, energy, states, coef_matrix, percentile=1):
    """Trajectory-level balance loss using the first column of ``log_fs``.

    The per-sample residual is
    ``0.5 * (sum log_pfs + log_fs[:, 0] - sum log_pbs - log_r)``, with sums
    taken over the last axis. It is passed through ``huber`` with the given
    ``percentile`` and averaged. ``energy``, ``states`` and ``coef_matrix``
    are unused.
    """
    forward_total = log_pfs.sum(-1)
    backward_total = log_pbs.sum(-1)
    initial_log_flow = log_fs[:, 0]
    residual = 0.5 * (forward_total + initial_log_flow - backward_total - log_r)
    penalised = huber(residual, percentile)
    return penalised.mean()


def tb_avg(log_pfs, log_pbs, log_fs, log_r, energy, states, coef_matrix, percentile=1):
    """Trajectory-level balance loss with a batch-averaged log-normaliser.

    Instead of reading a normaliser from ``log_fs``, it is estimated as the
    mean over dim 0 of ``log_r + sum log_pbs - sum log_pfs``. The residual
    ``0.5 * (log_Z + sum log_pfs - log_r - sum log_pbs)`` is passed through
    ``huber`` and averaged. ``log_fs``, ``energy``, ``states`` and
    ``coef_matrix`` are unused.
    """
    forward_total = log_pfs.sum(-1)
    backward_total = log_pbs.sum(-1)

    # Batch estimate of the log-normaliser; keepdim lets it broadcast back.
    log_Z = (log_r + backward_total - forward_total).mean(dim=0, keepdim=True)

    residual = 0.5 * (log_Z + (forward_total - log_r - backward_total))
    return huber(residual, percentile).mean()


def db(log_pfs, log_pbs, log_fs, log_r, energy, states, coef_matrix, percentile=1):
    """Per-transition balance loss between consecutive columns of ``log_fs``.

    Side effect: the last column of ``log_fs`` is overwritten in place (under
    ``torch.no_grad``) with ``energy.log_reward(states[:, -1])``.

    The per-step residual ``log_pfs + log_fs[:, :-1] - log_pbs - log_fs[:, 1:]``
    is passed through ``huber``, summed over the last axis, halved, and then
    averaged. ``log_r`` and ``coef_matrix`` are unused.
    """
    # Pin the final flow to the reward of the final state, without tracking gradients.
    with torch.no_grad():
        final_log_reward = energy.log_reward(states[:, -1])
        log_fs[:, -1] = final_log_reward.detach()

    flow_in = log_fs[:, :-1]
    flow_out = log_fs[:, 1:]
    step_residual = log_pfs + flow_in - log_pbs - flow_out
    per_trajectory = 0.5 * huber(step_residual, percentile).sum(-1)
    return per_trajectory.mean()


def subtb(log_pfs, log_pbs, log_fs, log_r, energy, states, coef_matrix, percentile=1):
    """Sub-trajectory balance loss, summed over sub-trajectories and samples.

    For every pair of indices ``i < j`` the residual is
    ``log_fs[:, i] - log_fs[:, j] + sum_{t=i}^{j-1} (log_pfs[:, t] - log_pbs[:, t])``.
    Each residual is passed through ``huber``, weighted by
    ``coef_matrix[i, j]``, and everything is summed (not averaged) over the
    batch.

    Args:
        log_pfs, log_pbs: per-step log-probabilities; their difference is
            accumulated along the last axis.
        log_fs: log-flows; must broadcast against the pairwise matrix, so it
            is expected to have one more column than ``log_pfs``.
        coef_matrix: pairwise weights; only the strict upper triangle
            contributes.
        percentile: quantile level forwarded to ``huber``.
        log_r, energy, states: unused; kept for the shared loss signature.

    Returns:
        Scalar tensor.
    """
    diff_logp = log_pfs - log_pbs

    # Prefix sums with a leading zero column, so that
    # prefix[:, j] - prefix[:, i] is the sum of diff_logp over steps i..j-1.
    # new_zeros allocates directly with diff_logp's dtype and device.
    leading_zero = diff_logp.new_zeros((diff_logp.shape[0], 1))
    prefix = torch.cat((leading_zero, diff_logp.cumsum(dim=-1)), dim=1)

    # pairwise[b, i, j] = prefix[b, j] - prefix[b, i]
    pairwise = prefix.unsqueeze(1) - prefix.unsqueeze(2)
    residual = log_fs[:, :, None] - log_fs[:, None, :] + pairwise
    penalty = huber(residual, percentile)

    # torch.triu operates on the last two dims of a batched tensor, so a single
    # call replaces the per-sample Python loop; diagonal=1 keeps only i < j.
    return torch.triu(penalty * coef_matrix, diagonal=1).sum()

