# Adapted from the cleanlab GitHub repository (https://github.com/cleanlab/cleanlab).
import cleanlab
import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score



def generate_noisy_labels(y, num_errors, num_classes=10):
    """Return a copy of ``y`` with exactly ``num_errors`` labels flipped.

    Args:
        y: sequence of integer class labels in ``[0, num_classes)``.
        num_errors: number of distinct positions to corrupt. ``np.random.choice``
            raises ``ValueError`` if this exceeds ``len(y)``.
        num_classes: number of classes that wrong labels are drawn from.
            Defaults to 10, the previously hard-coded value (e.g. the digits
            dataset).

    Returns:
        A new numpy array in which ``num_errors`` randomly chosen entries are
        replaced by a uniformly random class different from the original one.
        ``y`` itself is not modified.

    Uses numpy's global random state; seed ``np.random`` for reproducibility.
    """
    labels = np.array(y)  # copy, so the caller's labels stay untouched
    all_classes = np.arange(num_classes)
    error_indices = np.random.choice(len(y), num_errors, replace=False)
    for i in error_indices:
        # Switch to a wrong label: any class except the current one.
        labels[i] = np.random.choice(np.delete(all_classes, labels[i]))
    # Confirm that we indeed added num_errors label errors.
    assert np.sum(labels != y) == num_errors
    return labels


def find_label_error(labels, pred_probs, confident_joint):
    """Flag examples whose given label is likely wrong (STEP 2 of confident learning).

    For every off-diagonal entry (true class k, given class j) of the adjusted
    count matrix, that many examples labelled j with the largest margin
    ``p(k) - p(j)`` are flagged. Examples whose given label equals the model's
    top prediction are never flagged.

    Args:
        labels: 1-D array-like of non-negative integer (possibly noisy) labels.
        pred_probs: 2-D array-like of predicted class probabilities, one row per
            example; its column count defines the number of classes K.
        confident_joint: K x K count matrix where row = given label and
            column = suspected true label, as built by ``get_confident_joint``.

    Returns:
        Boolean numpy array of length ``len(labels)``; True marks a suspected
        label error.
    """
    labels = np.asarray(labels)
    pred_probs = np.asarray(pred_probs)

    # We arbitrarily choose at least 5 examples left in every class,
    # regardless of whether some of them might be label errors.
    MIN_NUM_PER_CLASS = 5
    # Leave at least MIN_NUM_PER_CLASS examples per class.
    # NOTE prune_count_matrix is transposed relative to confident_joint,
    # i.e. prune_count_matrix[true_label][given_label].
    prune_count_matrix = cleanlab.filter._keep_at_least_n_per_class(
        prune_count_matrix=confident_joint.T,
        n=MIN_NUM_PER_CLASS,
    )

    # Use the full class count (not the classes present in `labels`) so a class
    # with no labelled examples does not shift or drop the other class indices.
    K = pred_probs.shape[1]
    s_counts = np.bincount(labels, minlength=K)

    label_errors_bool = np.zeros(len(labels), dtype=bool)
    for j in range(K):  # j: given (noisy) label whose examples may be pruned
        # Don't prune from a class that has too few examples to begin with.
        if s_counts[j] <= MIN_NUM_PER_CLASS:
            continue
        s_filter = labels == j
        for k in range(K):  # k: suspected true label
            if k == j:
                continue  # diagonal entries are believed-correct examples
            # Clamp so np.partition never asks for more items than exist.
            num2prune = min(int(prune_count_matrix[k][j]), int(s_counts[j]))
            if num2prune <= 0:
                continue
            # num2prune'th largest p(class k) - p(class j) among examples labelled j.
            margin = pred_probs[:, k] - pred_probs[:, j]
            threshold = -np.partition(-margin[s_filter], num2prune - 1)[num2prune - 1]
            label_errors_bool |= s_filter & (margin >= threshold)

    # Remove label errors if given label == model prediction.
    label_errors_bool &= pred_probs.argmax(axis=1) != labels
    return label_errors_bool


def get_confident_joint(labels, pred_probs):
    """Compute the calibrated confident joint (STEP 1 of confident learning).

    Each example is counted at ``[given_label, guessed_true_label]`` when at
    least one class probability reaches that class's threshold. The guessed
    true label is the single confident class, or the highest-probability class
    when several are confident. The raw counts are then calibrated with
    ``cleanlab.count.calibrate_confident_joint``.

    Args:
        labels: 1-D array-like of integer (possibly noisy) labels.
        pred_probs: 2-D array-like of predicted probabilities, one row per
            example; its column count defines the number of classes K.

    Returns:
        The calibrated K x K confident joint returned by cleanlab.
    """
    labels = np.asarray(labels)
    pred_probs = np.asarray(pred_probs)
    K = pred_probs.shape[1]

    # Per-class threshold: average predicted probability of class k over the
    # examples labelled k (the default confident-learning choice; it could also
    # be tuned on a validation set). A class with no labelled examples gets an
    # unreachable threshold (inf), instead of the NaN + RuntimeWarning that the
    # mean of an empty slice would give. Both mean "never confident".
    thresholds = np.array(
        [
            pred_probs[labels == k, k].mean() if np.any(labels == k) else np.inf
            for k in range(K)
        ]
    )

    # Which classes each example is confidently predicted as.
    confident_bins = pred_probs >= thresholds - 1e-6
    num_confident_bins = confident_bins.sum(axis=1)
    # Exactly one confident class -> that class; more than one -> max-prob class.
    guessed_true = np.where(
        num_confident_bins == 1,
        confident_bins.argmax(axis=1),
        pred_probs.argmax(axis=1),
    )
    # Examples with no confident class are not counted at all.
    has_confident = num_confident_bins > 0

    confident_joint = np.zeros((K, K), dtype=int)
    # np.add.at accumulates correctly for repeated (row, col) index pairs.
    np.add.at(confident_joint, (labels[has_confident], guessed_true[has_confident]), 1)

    # Normalize confident joint (use cleanlab, trust me on this)
    confident_joint = cleanlab.count.calibrate_confident_joint(confident_joint, labels)
    return confident_joint


def get_val_confident_learning(is_mislabel, noisy_labels, pred_probs):
    """Run confident learning and score its predicted label errors.

    Args:
        is_mislabel: ground-truth mask, True where ``noisy_labels`` is wrong.
        noisy_labels: the (possibly corrupted) integer labels.
        pred_probs: predicted class probabilities, one row per example.

    Returns:
        Tuple ``(f1, accuracy, predicted_mask)``. Both scores compare the
        predicted label-error mask against ``is_mislabel``.
    """
    predicted_mask = find_label_error(
        noisy_labels,
        pred_probs,
        get_confident_joint(noisy_labels, pred_probs),
    )
    f1 = f1_score(is_mislabel, predicted_mask)
    acc = accuracy_score(is_mislabel, predicted_mask)
    return f1, acc, predicted_mask


