import torch
import torch.nn.functional as F
from metrics import es_sentiment
from utils import gather_log_probs, mask_hf_labels, masked_mean
import sys


def balanced_bce(log_probs, labels, eps=torch.finfo(torch.float32).eps):
    """Class-balanced binary cross-entropy computed from log-probabilities.

    Elements labelled 1 contribute ``-log_probs``; elements labelled 0
    contribute ``-log(1 - exp(log_probs))``. Each group is averaged on its
    own and the two averages are summed, so the rarer label is not swamped
    by the more common one.

    Args:
        log_probs: log-probability of label 1 for each element.
        labels: 0/1 labels, indexable against ``log_probs``.
        eps: added to complement probabilities that are exactly zero so
            their log stays finite.

    Returns:
        Sum of the two per-group means. A group with no elements contributes
        the integer 0.
    """
    assert labels.max() <= 1
    assert labels.min() >= 0

    is_pos = labels == 1
    is_neg = labels == 0

    pos_terms = -log_probs[is_pos]

    complement = 1 - log_probs.exp()
    # Bump exact zeros by eps (numerical stability) before taking the log.
    complement = torch.where(complement == 0, complement + eps, complement)
    neg_terms = -complement.log()[is_neg]

    def _mean_or_zero(terms):
        # Empty selections would give NaN from .mean(); fall back to 0.
        return terms.mean() if terms.numel() > 0 else 0

    return _mean_or_zero(pos_terms) + _mean_or_zero(neg_terms)


def kl_loc_loss(pre, post, mask=None):
    """KL divergence KL(pre || post) between two sets of logits.

    Two layouts are supported:

    * non-sequence logits (not 3-D) whose last dimension is 1: the Bernoulli
      KL of the sigmoid distributions, averaged over elements;
    * 3-D (sequence) logits whose last dimension is > 1: the categorical KL
      of the softmax distributions at each position, averaged with ``mask``
      as weights (a 0/1 mask gives the mean over unmasked positions).

    Args:
        pre: reference logits; both inputs are cast to float32.
        post: logits compared against ``pre``; must have as many elements as
            ``pre``.
        mask: required for the sequence case; must hold one entry per
            flattened position.

    Raises:
        NotImplementedError: for any other layout.
    """
    pre = pre.to(torch.float32).contiguous()
    post = post.to(torch.float32).contiguous()

    is_sequence = pre.dim() == 3
    n_classes = pre.shape[-1]
    flat_pre = pre.view(-1, n_classes)
    flat_post = post.view(flat_pre.shape)
    assert flat_pre.shape[0] == flat_post.shape[0]

    if is_sequence and n_classes > 1:
        # Per-position categorical KL; padded positions are weighted out.
        assert mask is not None
        flat_mask = mask.view(flat_pre.shape[0])
        log_ratio = flat_pre.log_softmax(-1) - flat_post.log_softmax(-1)
        per_token_kl = (flat_pre.softmax(-1) * log_ratio).sum(-1)
        return (per_token_kl * flat_mask).sum() / flat_mask.sum()

    if not is_sequence and n_classes == 1:
        # Bernoulli KL written as its label-1 and label-0 terms. No masking
        # is needed for binary classification.
        pos_term = pre.sigmoid() * (F.logsigmoid(pre) - F.logsigmoid(post))
        neg_term = (-pre).sigmoid() * (F.logsigmoid(-pre) - F.logsigmoid(-post))
        return pos_term.mean() + neg_term.mean()

    raise NotImplementedError


def binary_log_probs(pred, targ, should_reduce=True):
    """Log-likelihood statistics for single-logit (binary) predictions.

    Args:
        pred: logits, one per prediction, giving the log-odds of label 1.
        targ: 0/1 labels, usable as a boolean index into ``pred`` (presumably
            ``pred`` without its trailing singleton dim; confirm with callers).
        should_reduce: when True ``acc`` is averaged to a scalar; otherwise
            the element-wise accuracy (same shape as ``pred``) is returned.

    Returns:
        dict with ``acc``, ``log_prob`` (mean log-probability of the correct
        label), ``prob`` (mean probability of the correct label), ``nll``
        (negated ``log_prob``) and ``n_tokens`` (number of scored elements).
    """
    assert targ.max() <= 1
    assert targ.min() >= 0
    # Flip the logit's sign wherever the label is 0, so logsigmoid yields the
    # log-probability of the *correct* label at every position.
    neg_mask = torch.ones_like(pred)
    neg_mask[targ == 0] *= -1
    log_probs = F.logsigmoid(pred * neg_mask)
    probs = log_probs.exp()
    acc = (probs > 0.5).float()
    if should_reduce:
        acc = acc.mean()
    return {
        "acc": acc,
        "log_prob": log_probs.mean(),
        "prob": probs.mean(),
        "nll": -log_probs.mean(),
        # Count every scored element, matching the means above. shape[0] would
        # count only the batch dimension when pred is (batch, seq, 1).
        "n_tokens": log_probs.numel(),
    }


def multiclass_log_probs(
    pred,
    raw_targets,
    shift=True,
    eps=torch.finfo(torch.float32).eps,
    should_reduce=True,
    **kwargs,
):
    """Token-level likelihood statistics for multiclass (e.g. LM) logits.

    Args:
        pred: logits over classes in the last dimension. 3-D input is treated
            as a sequence (presumably (batch, seq, vocab); confirm w/ callers).
        raw_targets: target ids; ignored positions are handled by
            ``mask_hf_labels`` (presumably marked -100, HF-style).
        shift: for 3-D input, drop the last prediction and the first target
            so the prediction at position t is scored against target t + 1.
        eps: added inside the log of the unlikelihood term.
        should_reduce: when True ``acc`` is averaged to a scalar.
        **kwargs: if ``inner_sent`` is present the sentiment-editing path is
            taken; it reads ``outer_sent``, ``unlikelihood``,
            ``pre_edit_logits`` and ``post_edit_logits``, and optionally
            ``q_mask``. Other keys are ignored.

    Returns:
        dict with ``acc``, ``log_prob``, ``prob`` (exp of ``log_prob``),
        ``n_tokens`` and ``nll``. On the sentiment path it also holds the
        entries returned by ``es_sentiment``.
    """
    NULL_TOKEN = 0  # a placeholder used for masked target locations

    pred = pred.clone()
    mask, targ = mask_hf_labels(raw_targets)
    if shift and pred.dim() == 3:  # Dealing with sequences
        pred = pred[:, :-1]  # Remove last prediction in sequence
        targ = targ[:, 1:]  # Shift to align predictions and targets
        # Shift the validity mask together with the targets. Recomputing it as
        # ``targ != -100`` is wrong: mask_hf_labels has (presumably) already
        # replaced ignored targets with a null id, since gather_log_probs could
        # not index with -100. That test would mark every position valid.
        mask = mask[:, 1:]

    unmasked_log_probs = gather_log_probs(pred, targ)

    # Masked positions are forced to NULL_TOKEN so they compare equal to the
    # (presumably nulled) targets and do not break whole-sequence accuracy.
    pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN)
    correct = pred_ids == targ
    if pred.dim() == 3:
        correct = correct.all(-1)  # We want to get the whole sequence right
    acc = correct.float()
    if should_reduce:
        acc = acc.mean()

    if "inner_sent" in kwargs:
        # Only use outer samples with the same sentiment as the inner sample
        same_sent_mask = torch.tensor(
            [i == o for i, o in zip(kwargs["inner_sent"], kwargs["outer_sent"])],
            device=pred.device,
        )
        good_mask = mask * same_sent_mask.unsqueeze(-1)
        bad_mask = mask * (~same_sent_mask.unsqueeze(-1))

        good_log_prob = masked_mean(unmasked_log_probs, good_mask)
        # Unlikelihood term: log(1 - p + eps) on opposite-sentiment samples.
        bad_log_prob = masked_mean((1 - unmasked_log_probs.exp() + eps).log(), bad_mask)

        n_tokens = good_mask.float().sum()
        avg_log_prob = good_log_prob

        if kwargs["unlikelihood"]:
            nll = -good_log_prob - bad_log_prob
        else:
            nll = -good_log_prob
    else:
        n_tokens = mask.float().sum()
        avg_log_prob = (unmasked_log_probs * mask.float()).sum() / n_tokens
        nll = -avg_log_prob

    info_dict = {
        "acc": acc,
        "log_prob": avg_log_prob,
        "prob": avg_log_prob.exp(),
        "n_tokens": n_tokens,
        "nll": nll,
    }

    if "inner_sent" in kwargs:
        q_mask = (
            torch.tensor(kwargs["q_mask"], device=pred.device)
            if "q_mask" in kwargs
            else None
        )
        info_dict.update(
            es_sentiment(
                kwargs["pre_edit_logits"],
                kwargs["post_edit_logits"],
                raw_targets,
                same_sent_mask,
                q_mask=q_mask,
            )
        )

    return info_dict


def masked_log_probs(pred, targ, shift=True, **kwargs):
    """Dispatch to binary or multiclass likelihood statistics.

    ``pred`` is cast to float32 and must be 2-D or 3-D. A trailing dimension
    of size 1 selects ``binary_log_probs``, which receives only the optional
    ``should_reduce`` kwarg (default True). Any other width selects
    ``multiclass_log_probs``, which receives ``shift`` and all kwargs.

    Raises:
        RuntimeError: if ``pred`` is neither 2-D nor 3-D.
    """
    logits = pred.to(torch.float32)

    if logits.dim() not in (2, 3):
        raise RuntimeError(f"Expected pred to have 2 or 3 dimensions, got {logits.shape}")

    if logits.shape[-1] != 1:
        return multiclass_log_probs(logits, targ, shift=shift, **kwargs)

    reduce_flag = kwargs.get("should_reduce", True)
    return binary_log_probs(logits, targ, should_reduce=reduce_flag)


def test_masked_log_probs():
    """Smoke-test masked_log_probs on multiclass and binary inputs.

    Prints the metric dicts for random, near-perfect and half near-perfect
    multiclass logits, and checks that masked_log_probs does not modify its
    inputs in place. It then prints the metrics for a binary (single-logit)
    batch. Metric values are printed for inspection, not asserted.
    """
    print()
    N = 10000
    pred = torch.randn(10, 15, N)
    targ = torch.randint(0, N, (10, 15))
    # Put a large logit on every target id, then roll left one step so the
    # logits at position t point at targ[:, t + 1] (the shift=True alignment).
    true_pred = pred.clone()
    true_pred.scatter_(2, targ.unsqueeze(-1), 5)
    true_pred = true_pred.roll(-1, 1)

    # Even-indexed rows fall back to the random logits.
    half_pred = true_pred.clone()
    mask = torch.arange(10) % 2 == 0
    half_pred[mask] = pred[mask]

    # Snapshots used to verify the inputs are not mutated.
    pred_ = pred.clone()
    true_pred_ = true_pred.clone()
    half_pred_ = half_pred.clone()
    targ_ = targ.clone()

    print(masked_log_probs(pred, targ))
    print(masked_log_probs(true_pred, targ))
    print(masked_log_probs(half_pred, targ))

    assert (pred == pred_).all()
    assert (targ == targ_).all()
    assert (half_pred == half_pred_).all()
    assert (true_pred == true_pred_).all()

    # Binary case: one logit per position with 0/1 targets.
    pred = torch.randn(1000, 15, 1)
    targ = torch.randint(0, 2, (1000, 15))

    print(masked_log_probs(pred, targ))


if __name__ == "__main__":
    # Seed the global RNG so repeated runs print identical numbers.
    torch.manual_seed(seed=0)
    test_masked_log_probs()
