import numpy as np
import torch


def flip_binary_labels(targets, p=0.05):
    """Randomly flip binary (0/1) labels, each independently with probability ``p``.

    Args:
        targets: Tensor of any shape whose entries are all 0 or 1.
        p: Probability in [0, 1] that any single label is flipped.

    Returns:
        A ``torch.long`` tensor with the same shape and device as ``targets``
        in which each entry equals ``1 - target`` with probability ``p`` and
        ``target`` otherwise.

    Raises:
        ValueError: If ``targets`` contains a value other than 0 or 1, or if
            ``p`` lies outside [0, 1].
    """
    # Checked with tensor ops so it works for any shape and does not copy the
    # data to the host; explicit raises survive `python -O` (asserts do not).
    if not torch.all((targets == 0) | (targets == 1)):
        raise ValueError("targets must contain only 0 and 1 values")
    if not 0 <= p <= 1:
        raise ValueError(f"p must be in [0, 1], got {p}")

    # torch.bernoulli requires a floating-point probability tensor; building
    # it with an explicit float value keeps integer p (e.g. 0 or 1) working
    # even when targets is an integer tensor.
    probs = torch.full(targets.shape, float(p), device=targets.device)
    flip = torch.bernoulli(probs).bool()

    # 1 - t maps 0 -> 1 and 1 -> 0; keep the original label where not flipped.
    return torch.where(flip, 1 - targets, targets).long()


if __name__ == "__main__":
    # Smoke test for flip_binary_labels. Run from the repository root:
    #     python -m src.dl.losses.noisy_labels
    test_num = 0

    if test_num == 0:
        # Construct roughly balanced random 0/1 labels.
        curr_labels = np.random.binomial(n=1, p=0.5, size=(1000000,))
        curr_labels = torch.tensor(curr_labels).long()

        # Get the flipped labels.
        p = 0.05
        new_labels = flip_binary_labels(targets=curr_labels, p=p)

        # The empirical fraction of changed labels should be close to p.
        observed = (curr_labels != new_labels).float().mean().item()
        print(f"Expected: {p}, observed: {observed:.4f}")
