import numpy as np


def is_fading_exp_run(args):
    """
    Tells whether ``args`` describes a fading-in experiment.

    Only instance attributes (``args.__dict__``) are inspected, so a ``fading_in_exp`` class attribute is ignored.

    :param args: Namespace-like object (e.g. parsed command line arguments).
    :return: ``False`` if ``args`` has no instance attribute ``fading_in_exp``; otherwise the value of
        ``args.fading_in_exp`` itself (truthy for fading-in runs).
    """
    if "fading_in_exp" not in args.__dict__:
        return False
    return args.fading_in_exp


def construct_corrupted_dataset(label_noise, opacity, features, labels, seed):
    """
    Constructs a dataset with images corrupted by fading in the other class' digit.

    Each image is selected for corruption with probability ``label_noise``. A selected image is blended with a
    uniformly drawn instance whose label differs from its own: ``new = (1 - o) * original + o * donor``.
    Labels are returned unchanged.

    :param label_noise: Fraction of images that will be corrupted (per-image probability in [0, 1]).
    :param opacity: The degree of fading in the other class' digit (must be <= 1). A negative value means the
        opacity is drawn uniformly from [0, 1) for every corrupted image.
    :param features: Original dataset features; must support ``.detach().clone()`` (e.g. a torch tensor).
    :param labels: Dataset labels (will be retained), one per instance (NumPy array or torch tensor).
    :param seed: Seed used for randomized operations (e.g., determining the instances to be corrupted).
        NOTE: this re-seeds NumPy's *global* RNG, which also affects later ``np.random`` calls.
    :return: Returns the adjusted images, the original labels and a corruption mask storing the fading in opacity of
        each image (0 if no corruption has been done).
    :raises ValueError: If an image is selected for corruption but no instance with a different label exists.
    """
    assert 0 <= label_noise <= 1, "Label_noise must be between 0 and 1."
    assert opacity <= 1, "Opacity must be lower equals 1 (or negative for uniform sampling)."

    assert features is not None, "Features must be provided."
    assert labels is not None, "Labels must be provided."

    assert len(features) == len(labels), "Features and labels must have the same length."

    new_features = features.detach().clone()

    opacity_mask = np.zeros(len(features))

    # Plain NumPy copy of the labels, used only to look up donor indices (works for CPU/GPU tensors and arrays).
    if hasattr(labels, "detach"):
        label_array = labels.detach().cpu().numpy()
    else:
        label_array = np.asarray(labels)

    # Indices of instances with a label different from the key, computed once per label value. Previously the full
    # feature tensor was masked (and copied) for every corrupted instance just to pick a single donor row.
    donor_indices = {}

    np.random.seed(seed)
    for i in range(features.shape[0]):
        if np.random.random() < label_noise:
            label = label_array[i]
            if label not in donor_indices:
                donor_indices[label] = np.flatnonzero(label_array != label)
            candidates = donor_indices[label]
            if len(candidates) == 0:
                raise ValueError("Cannot corrupt an instance: no instance with a different label exists.")

            # Same draw as selecting the `selected_idx`-th row of features[labels != labels[i]].
            selected_idx = np.random.randint(0, len(candidates))
            selected_instance = features[int(candidates[selected_idx])]

            if opacity < 0:
                instance_opacity = np.random.random()
            else:
                instance_opacity = opacity

            new_features[i] = (1. - instance_opacity) * features[i] + instance_opacity * selected_instance
            opacity_mask[i] = instance_opacity

    return new_features, labels, opacity_mask


def determine_fading_memorization(args, train_features, train_labels, target_means, opacities):
    """
    Computes, per class, the opacity-weighted mean distance of corrupted instances to a per-class target point.

    Only instances with ``opacities > 0`` are considered. For class ``c`` the result is
    ``sum_j(opacity_j * ||x_j - target_means[c]||) / n_c`` over the ``n_c`` corrupted instances whose
    ``train_labels[:, 1]`` equals ``c``. Note that it is normalized by the instance count, not by the sum of weights.

    :param args: Object providing ``classes``, the number of classes.
    :param train_features: Training features, indexable by a boolean mask over instances.
    :param train_labels: 2-D labels; column 1 is used as the class index (presumably the original label -- confirm
        against the caller).
    :param target_means: Per-class reference points, indexed by class index.
    :param opacities: Per-instance fading-in opacity; values <= 0 mean "not corrupted".
    :return: NumPy float array of length ``args.classes``; NaN for classes without any corrupted instance.
    """
    corr_mask = opacities > 0

    # Restrict to corrupted instances once, instead of re-masking all inputs twice for every class.
    corrupted_features = train_features[corr_mask]
    corrupted_classes = train_labels[corr_mask][:, 1]
    corrupted_opacities = opacities[corr_mask]

    pointwise_distances = np.zeros(args.classes, dtype=float)
    for i in range(args.classes):
        class_mask = corrupted_classes == i
        class_instances = corrupted_features[class_mask]
        class_weights = corrupted_opacities[class_mask]
        n_instances = class_instances.shape[0]
        if n_instances == 0:
            # The mean over zero instances is undefined. Report NaN explicitly instead of computing 0/0, which gives
            # the same NaN but also emits a RuntimeWarning.
            pointwise_distances[i] = np.nan
            continue
        for weight, instance in zip(class_weights, class_instances):
            pointwise_distances[i] += weight * np.linalg.norm(instance - target_means[i])
        pointwise_distances[i] /= n_instances
    return pointwise_distances
