
import contextlib
import enum
import logging
import math
import pdb

import numpy as np
import torch
import tqdm
import ujson as json

import sklearn.cluster

import batchbald
import influence_functions
import greedy_coreset
from utils_multiple_choice import processors, load_and_cache_examples_from_file
from augment_datafiles_with_cls import embed_examples_clsfrozen, embed_examples_dense_surprisal
from utils_common import get_task_type

# Module-level logger named after this module; used for the warnings and info
# messages emitted by the learners below.
logger = logging.getLogger(__name__)

class NullContext(contextlib.AbstractContextManager):
    """A do-nothing context manager.

    Entering yields ``None`` and exiting never suppresses an exception, so it
    can stand in wherever a ``with`` block needs an optional context.
    """

    def __init__(self):
        super().__init__()

    def __enter__(self, *args, **kwargs):
        # Nothing to acquire; the ``as`` target is simply None.
        return None

    def __exit__(self, *args, **kwargs):
        # A falsy return value lets any in-flight exception propagate.
        return None

def entropy(probs):
    """Return the Shannon entropy (natural log) of a discrete distribution.

    Arguments:
        probs (iterable of float): the probabilities of each outcome.

    Outcomes with probability exactly zero contribute nothing, following the
    usual convention that 0 * log(0) == 0 (math.log(0) would otherwise raise
    ValueError).  Negative values still raise ValueError from math.log.
    """
    return -sum(prob * math.log(prob) for prob in probs if prob != 0)

def softmax(np_arr):
    """Softmax over *all* elements of np_arr (no axis is specified).

    The maximum is subtracted before exponentiating so that large inputs do
    not overflow.
    """
    peak = np.max(np_arr)
    exps = np.exp(np_arr - peak)
    total = np.sum(exps)
    return exps / total

def intdiv_ceil(dividend, divisor):
    """Divide dividend by divisor using integer arithmetic, rounding up.

    Staying in integer arithmetic (rather than ceil() of a float quotient)
    keeps the result exact even when the operands are too large to be
    represented accurately as floats.
    """
    # Bumping the numerator by (divisor - 1) makes floor division round up.
    bumped = dividend + (divisor - 1)
    return bumped // divisor


class ActiveLearner:
    r"""Base class for active-learning samplers, which treat every newly
    labeled example as expensive to obtain.

    Subclasses implement select_new_examples().

    Arguments:
        config (dict): options read with .get(); the keys used here are
                       'refill_increment', 'seed_dataset_size', 'max_labels'
                       and 'output_predictions_path'
        model (Module): the model being trained (used to judge candidate
                        examples)
        data_source (Dataset): dataset to sample from
        args: run arguments, stored for use by subclasses
        run_model_fn (function): A function that takes (args, model, batch) as input and returns (loss, logits)
    """

    def __init__(self, config, model, data_source, args, run_model_fn):
        """Store the collaborators, read the config and draw the seed set."""
        self.config = config
        self.model = model
        self.data_source = data_source
        self.args = args
        self.run_model_fn = run_model_fn

        cfg = self.config
        self.refill_increment = cfg.get('refill_increment', 500)
        self.seed_dataset_size = cfg.get('seed_dataset_size', self.refill_increment)
        self._max_labels = cfg.get('max_labels', 5000)
        self.output_predictions_path = cfg.get('output_predictions_path', None)

        # Two retired options: any non-default value is an error, and merely
        # mentioning them earns a warning.
        normalization = cfg.get('score_normalization_method', 'softmax')
        if normalization != 'softmax' or cfg.get('us_temperature', {'type': 'manual', 'value': 0.0})['value'] != 0.0:
            raise ValueError("The us_temperature and score_normalization_method config options are no longer supported")
        if any(key in cfg for key in ('score_normalization_method', 'us_temperature')):
            logger.warning('The score_normalization_method and us_temperature config options are no longer supported and should not be included.')

        # The private generator is seeded from the global numpy RNG because
        # other code is expected to call np.random.seed for reproducibility.
        derived_seed = np.random.randint(1, 2e9)
        self._rng = np.random.default_rng(derived_seed)

        self._human_labeled_indices = self._init_seed_set()

    def _init_seed_set(self):
        """Draw seed_dataset_size distinct indices of data_source at random."""
        population = range(len(self.data_source))
        drawn = self._rng.choice(population, self.seed_dataset_size, replace=False)
        return set(drawn)

    def get_human_labeled_dataset(self):
        """Get a torch Dataset of all the human-labeled examples"""
        examples = (self.data_source[idx] for idx in self._human_labeled_indices)
        # Transpose the per-example tuples into per-field columns, then stack
        # every column into one tensor.
        stacked_fields = [torch.stack(field) for field in zip(*examples)]
        return torch.utils.data.TensorDataset(*stacked_fields)

    def get_human_labeled_indices(self):
        """Get a List of the indices (in the data source) of human-labeled examples"""
        return list(self._human_labeled_indices)

    def get_max_labels(self):
        """Return the configured 'max_labels' value."""
        return self._max_labels

    def get_unlabeled_indices(self):
        """Get a List of the data source indices that are not yet labeled."""
        every_index = set(range(len(self.data_source)))
        return list(every_index - self._human_labeled_indices)

    def select_new_examples(self, size):
        """Choose up to `size` indices to label next; subclasses override."""
        raise NotImplementedError

    def acquire_batch(self, size):
        """Select new examples, mark them as labeled and return them."""
        chosen = self.select_new_examples(size)
        self._human_labeled_indices |= set(chosen)
        return chosen

    def __len__(self):
        # The length is the label budget, not the number of labels so far.
        return self._max_labels


class RandomActiveLearner(ActiveLearner):
    """Baseline learner that acquires unlabeled examples uniformly at random."""

    def select_new_examples(self, size):
        """Draw up to `size` distinct unlabeled indices without replacement."""
        pool = self.get_unlabeled_indices()
        # Never ask for more examples than the pool still holds.
        num_to_draw = min(len(pool), size)
        drawn = self._rng.choice(pool, num_to_draw, replace=False)
        return list(drawn)


class ScoreActiveLearner(ActiveLearner):
    """Base class for learners that score every unlabeled example and acquire
    the highest-scoring ones.

    Subclasses implement score_model_outputs().  A subclass whose scores
    require backpropagation should set ``self.needs_grads = True`` so that
    score_dataset() does not run the model under torch.no_grad().
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.needs_grads = False

    def select_new_examples(self, size):
        """Return the (up to) `size` unlabeled indices with the highest scores.

        Raises:
            ValueError: if score_dataset() does not return exactly one score
                        per candidate.
        """
        candidate_indices = self.get_unlabeled_indices()

        result_scores = self.score_dataset(self.data_source, candidate_indices)

        if len(result_scores) != len(candidate_indices):
            raise ValueError("len(result_scores) != len(candidate_indices)")

        # Sort (score, index) pairs so ties are broken by the larger index.
        ranked = sorted(zip(result_scores, candidate_indices), reverse=True)
        return [idx for _, idx in ranked[:size]]

    def score_dataset(self, data_source, candidate_indices):
        """Score the candidate indices, which are assumed to be indices into
        data_source.  Returns a List of scores corresponding to
        candidate_indices (i.e., a List the same length as candidate_indices
        where result[0] is the score corresponding to candidate_indices[0],
        etc).

        data_source is indexed with a list of indices per batch.  If
        output_predictions_path is configured, one JSON line containing the
        logits, labels and candidate indices is appended to that file.
        """

        # NOTE: We always set model.eval() here. Even though we _do_ want
        # gradients for the expected_grad_norm score method, since we aren't
        # doing updates we don't necessarily want the randomness that comes
        # with the train-mode of things like dropout.
        was_training = self.model.training
        self.model.eval()

        # Restore the model's mode even if scoring fails part-way through.
        try:
            batch_size = self.args.per_gpu_eval_batch_size * max(1, self.args.n_gpu)

            if self.needs_grads:
                # Although it might seem like a higher batch size would be more
                # efficient because the forward computation can be
                # parallelized, the backward pass is significantly slower. This
                # is almost certainly because we individually backward each
                # example in the batch. When backproping just one example from
                # the batch, the input gradients for all other examples along
                # the batch dimension will be zero. However, autograd doesn't
                # know this and will still do the computations to get the zeros
                # rather than trying to compute the gradient in a sparse way.
                # Therefore, a batch size of 1 is unfortunately the most
                # efficient.
                batch_size = 1

            save_predictions = self.output_predictions_path is not None
            result_scores = []
            predictions_json_obj = {
                'logits': [],
                'labels': [],
                'candidate_indices': candidate_indices,
                }
            num_batches = intdiv_ceil(len(candidate_indices), batch_size)
            for i in tqdm.trange(num_batches, desc="AL-batch"):
                # Slicing clamps to the end of the list for the final batch.
                batch = data_source[candidate_indices[i*batch_size:(i+1)*batch_size]]
                batch = tuple(t.to(self.args.device) for t in batch)

                grad_context = NullContext() if self.needs_grads else torch.no_grad()
                with grad_context:
                    # The fourth element of each batch is treated as the labels.
                    labels = batch[3]

                    logits = self.run_model_fn(self.args, self.model, batch)[1]
                    probs = torch.softmax(logits, dim=1)

                    if save_predictions:
                        predictions_json_obj['logits'].extend(logits.clone().detach().cpu().numpy().tolist())
                        predictions_json_obj['labels'].extend(labels.clone().detach().cpu().numpy().tolist())

                    result_scores.extend(self.score_model_outputs(labels, logits, probs))

            if save_predictions:
                with open(self.output_predictions_path, 'a') as f:
                    f.write(json.dumps(predictions_json_obj) + '\n')
        finally:
            self.model.train(mode=was_training)

        return result_scores

    def score_model_outputs(self, labels, logits, probs):
        """Return one score per example in the batch (higher = acquired first).

        Arguments:
            labels: batch[3] of the current batch
            logits: second element of run_model_fn's return value
            probs: softmax of logits over dim 1
        """
        raise NotImplementedError()

class EntropyActiveLearner(ScoreActiveLearner):
    """Scores each example by the entropy of its predicted class
    distribution, so the most uncertain examples are acquired first."""

    def score_model_outputs(self, labels, logits, probs):
        """Return a list with the entropy of each row of probs."""
        # torch.xlogy(p, p) is p * log(p) with 0 * log(0) defined as 0.  The
        # naive product yields NaN whenever a softmax probability underflows
        # to exactly zero, and a NaN score corrupts the ranking.
        entropies = -torch.sum(torch.xlogy(probs, probs), dim=1)
        return entropies.tolist()

class LeastConfActiveLearner(ScoreActiveLearner):
    """Scores each example by the negated probability of its most likely
    class, so the least confident predictions rank highest."""

    def score_model_outputs(self, labels, logits, probs):
        """Return a list with minus the top class probability of each row."""
        top_probs, _ = torch.max(probs, dim=1)
        return (-top_probs).tolist()

class ExpectedGradientNormActiveLearner(ScoreActiveLearner):
    """Scores each example by its expected gradient norm: the norm of the
    loss gradient for every possible label, weighted by the probability the
    model assigns to that label."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The scores are computed with torch.autograd.grad, so the forward
        # pass in score_dataset() must not run under torch.no_grad().
        self.needs_grads = True

    def score_model_outputs(self, labels, logits, probs):
        """Return a list with the expected gradient norm of each example.

        logits must still be attached to the autograd graph.
        """
        batch_size = int(labels.shape[0])
        loss_fn = torch.nn.CrossEntropyLoss()
        num_classes = logits.shape[-1]
        # Frozen parameters cannot be differentiated, so leave them out.
        params = [p for p in self.model.parameters() if p.requires_grad]
        scores = []
        for ex_num in range(batch_size):
            example_logits = logits[ex_num:ex_num+1, :]
            expected_grad_norm = 0
            for label_idx in range(num_classes):
                # torch.tensor (lowercase) is the factory that accepts dtype
                # and device; the torch.Tensor constructor does not.
                target = torch.tensor([label_idx], dtype=torch.long, device=logits.device)
                loss = loss_fn(example_logits, target)

                # Retain graph except on the very last backprop
                retain_graph = not (ex_num == batch_size-1 and label_idx == num_classes-1)
                # Parameters that did not take part in the forward pass come
                # back as None; they contribute nothing to the norm.
                grads = torch.autograd.grad(loss, params, retain_graph=retain_graph, allow_unused=True)

                grad_norm = np.sqrt(sum((g*g).sum().item() for g in grads if g is not None))
                expected_grad_norm += probs[ex_num, label_idx].item() * grad_norm
            scores.append(expected_grad_norm)
        return scores

class InfluenceActiveLearner(ScoreActiveLearner):
    """Scores each example by the negated influence of its loss gradient, as
    computed by an influence function selected with the 'influence_type'
    config option ('koh' or 'dev_grad')."""

    def __init__(self, *args, tokenizer=None, train_dataset=None, dev_dataset=None, **kwargs):
        """
        Arguments:
            tokenizer: passed to load_and_cache_examples_from_file when the
                       dev set is loaded from the 'aux_dev_path' config option
            train_dataset: passed to the 'koh' influence function
            dev_dataset: dev set to use when no 'aux_dev_path' is configured

        Raises:
            ValueError: if 'influence_type' is not a known influence function.
        """
        super().__init__(*args, **kwargs)
        # The scores are computed with torch.autograd.grad, so the forward
        # pass in score_dataset() must not run under torch.no_grad().
        self.needs_grads = True

        aux_dev_path = self.config.get('aux_dev_path', None)

        # A configured aux_dev_path takes precedence (as it always did); the
        # dev_dataset argument is only used when no path is configured.
        if aux_dev_path is not None or dev_dataset is None:
            dev_dataset = load_and_cache_examples_from_file(self.args, aux_dev_path, processors[self.args.task_name](), tokenizer, datasplit='dev')
        influence_type = self.config.get('influence_type', 'koh')
        if influence_type == 'koh':
            self.inf_func = influence_functions.KohInfluenceFunction(self._model_loss_fn, list(self.model.parameters()), train_dataset, dev_dataset, config=self.config)
        elif influence_type == 'dev_grad':
            self.inf_func = influence_functions.DevgradInfluenceFunction(self._model_loss_fn, list(self.model.parameters()), dev_dataset)
        else:
            raise ValueError("Unknown influence function {}".format(influence_type))

    def _model_loss_fn(self, batch):
        """Returns the loss for a given batch"""
        loss = self.run_model_fn(self.args, self.model, batch)[0]

        return loss

    def score_model_outputs(self, labels, logits, probs):
        """Return a list with the negated influence of each example.

        logits must still be attached to the autograd graph.
        """
        scores = []
        loss_fn = torch.nn.CrossEntropyLoss()
        for ex_num in range(labels.shape[0]):
            # Slice (rather than index) to keep a batch dimension of 1; the
            # torch.Tensor constructor used previously does not accept dtype
            # or device arguments.
            target = labels[ex_num:ex_num+1].to(device=logits.device, dtype=torch.long)
            loss = loss_fn(logits[ex_num:ex_num+1, :], target)

            # The parameter order matches the list handed to the influence
            # function in the constructor.
            grad = torch.autograd.grad(loss, self.model.parameters(), retain_graph=True)
            # Multiply influence by -1 since lower influence is
            # better (i.e., influence is the change in loss, so
            # negative influence means decreasing loss)
            influence = -1 * self.inf_func.calc_influence_from_grad(grad)
            scores.append(influence)
        return scores

class BaldActiveLearner(ScoreActiveLearner):
    """Scores each example with BALD, estimated from several stochastic
    forward passes with dropout enabled (MC dropout).

    Config options:
        mcdropout_samples (required): number of stochastic forward passes
        mcdropout_dropout_override (optional): dropout probability to use
            while sampling instead of each module's own value
    """

    def __init__(self, *args, **kwargs):
        """
        Raises:
            ValueError: if the 'mcdropout_samples' config option is missing.
        """
        super().__init__(*args, **kwargs)
        if 'mcdropout_samples' not in self.config:
            raise ValueError("Need mcdropout_samples config option for BALD")
        self.num_mc_samples = self.config['mcdropout_samples']

        self.dropout_override = None
        if 'mcdropout_dropout_override' in self.config:
            self.dropout_override = self.config['mcdropout_dropout_override']
        else:
            logger.warning('Using original dropout weight for BALD.  Specify a mcdropout_dropout_override value to override dropout to particular value for BALD.')

    def score_dataset(self, data_source, candidate_indices):
        """Return a list with the BALD score of each candidate index.

        The model is put in train mode so that dropout is active, and its
        original mode and dropout probabilities are restored afterwards, even
        if scoring raises.
        """
        batch_size = self.args.per_gpu_eval_batch_size * max(1, self.args.n_gpu)

        was_training = self.model.training
        dropout_modules = [m for m in self.model.modules() if isinstance(m, torch.nn.Dropout)]
        orig_dropout_weights = [m.p for m in dropout_modules]

        self.model.train()
        try:
            if self.dropout_override is not None:
                for m in dropout_modules:
                    m.p = self.dropout_override

            result_scores = []
            num_batches = intdiv_ceil(len(candidate_indices), batch_size)
            for step_num in tqdm.trange(num_batches, desc="AL-batch"):
                # Slicing clamps to the end of the list for the final batch.
                batch = data_source[candidate_indices[step_num*batch_size:(step_num+1)*batch_size]]
                batch = tuple(t.to(self.args.device) for t in batch)

                with torch.no_grad():
                    # Every pass draws fresh dropout masks for the same batch.
                    probs_group = []
                    for _ in range(self.num_mc_samples):
                        logits = self.run_model_fn(self.args, self.model, batch)[1]
                        probs_group.append(torch.softmax(logits, dim=1))

                    bald_scores = self._score_outputs_bald(probs_group)
                    result_scores.extend(bald_scores.tolist())
        finally:
            if self.dropout_override is not None:
                for orig_p, m in zip(orig_dropout_weights, dropout_modules):
                    m.p = orig_p

            self.model.train(mode=was_training)

        return result_scores

    def _score_outputs_bald(self, probs_group):
        """Return the entropy of the mean prediction minus the mean entropy of
        the individual predictions, for each example in the batch.

        Arguments:
            probs_group: list with one probability tensor per MC sample; each
                         tensor is indexed (example, class)
        """
        probs_group = torch.stack(probs_group)
        # torch.xlogy(p, p) is p * log(p) with 0 * log(0) defined as 0, which
        # avoids NaN scores when a probability underflows to exactly zero.
        entropies = -torch.sum(torch.xlogy(probs_group, probs_group), dim=2)
        avgs_of_entropies = torch.mean(entropies, dim=0)
        avg_probs = torch.mean(probs_group, dim=0)
        entropies_of_avgs = -torch.sum(torch.xlogy(avg_probs, avg_probs), dim=1)
        return entropies_of_avgs - avgs_of_entropies

class OracleMaxLossActiveLearner(ScoreActiveLearner):
    """Oracle learner: scores each example by its cross-entropy loss against
    the labels, so the highest-loss examples are acquired first."""

    def score_model_outputs(self, labels, logits, probs):
        """Return a list with the unreduced cross-entropy loss per example."""
        per_example_loss = torch.nn.functional.cross_entropy(
            logits, labels, reduction='none')
        return per_example_loss.tolist()

class OracleRandomMispredictionsActiveLearner(ScoreActiveLearner):
    """Oracle learner: acquires a random sample of the examples the model
    currently predicts incorrectly, topping the batch up with other random
    examples when there are not enough mispredictions."""

    def select_new_examples(self, size):
        """Return up to `size` unlabeled indices, preferring mispredictions.

        Raises:
            ValueError: if score_dataset() does not return exactly one score
                        per candidate.
        """
        candidate_indices = self.get_unlabeled_indices()

        result_scores = self.score_dataset(self.data_source, candidate_indices)

        if len(result_scores) != len(candidate_indices):
            raise ValueError("len(result_scores) != len(candidate_indices)")

        # score_model_outputs() yields 0 for a wrong prediction, 1 otherwise.
        mispredicted = [idx for score, idx in zip(result_scores, candidate_indices) if score == 0]
        if len(mispredicted) >= size:
            return list(self._rng.choice(mispredicted, min(len(mispredicted), size), replace=False))

        # Too few mispredictions: take them all and fill the rest of the
        # batch at random from the correctly predicted candidates.  The set
        # makes the membership test O(1) per candidate.
        mispredicted_set = set(mispredicted)
        remaining = [idx for idx in candidate_indices if idx not in mispredicted_set]
        shortfall = size - len(mispredicted)
        filler = self._rng.choice(remaining, min(len(remaining), shortfall), replace=False)
        return mispredicted + list(filler)

    def score_model_outputs(self, labels, logits, probs):
        """Return a list with 1 where the argmax prediction equals the label
        and 0 where it does not."""
        return (logits.argmax(dim=-1) == labels).type(torch.int64).tolist()

class CoresetActiveLearner(ActiveLearner):
    """Acquires examples with a greedy coreset builder over example
    embeddings produced by embed_examples_clsfrozen.

    Config options:
        coreset_retraining (default False): re-embed the data source with the
            current model before every acquisition
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.do_embedding_retraining = self.config.get('coreset_retraining', False)
        embeddings = self._embed_data_source()
        if self.seed_dataset_size == 0:
            # No random seed set: let the coreset builder choose the first
            # refill_increment examples itself (1 initial point plus
            # refill_increment - 1 acquired points).
            logger.info('initializing coreset with {} coreset-selected examples'.format(self.refill_increment))
            self._coreset_builder = greedy_coreset.CoresetBuilder(embeddings, greedy_coreset.l2_distances, initial_size=1, rng=self._rng)
            self._coreset_builder.acquire_points(n_points=(self.refill_increment-1))
            self._human_labeled_indices = set(self._coreset_builder.get_coreset_indices())
        else:
            # Start from the randomly drawn seed set created by the base class.
            self._coreset_builder = greedy_coreset.CoresetBuilder(embeddings, greedy_coreset.l2_distances, initial_size=0, rng=self._rng)
            for idx in self._human_labeled_indices:
                self._coreset_builder.add_to_coreset(idx)

    def _embed_data_source(self):
        """Embed every example of the data source with the current model and
        stack the results into a single tensor."""
        embeddings = embed_examples_clsfrozen(self.model, self.data_source, get_task_type(self.args.task_name), batch_size=self.args.per_gpu_eval_batch_size, device=self.args.device, model_type=self.args.model_type)
        return torch.stack(embeddings)

    def select_new_examples(self, size):
        """Acquire `size` points from the coreset builder and return them in
        shuffled order."""
        if self.do_embedding_retraining:
            self._coreset_builder.set_embeddings(self._embed_data_source())
        selected = list(self._coreset_builder.acquire_points(size))
        self._rng.shuffle(selected)
        return selected


class FreezableDropout(torch.nn.Module):
    """Like dropout, but the mask can be frozen to be re-used across multiple inputs.

    While frozen, one mask (shaped like a single example, i.e. the input
    without its first dimension) is created on the first forward pass and then
    applied to every example of every batch until the mask is unfrozen.
    Freezing enables dropout even when the module is in eval mode.

    Arguments:
        p (float): probability of zeroing an element
    """

    def __init__(self, p):
        super().__init__()
        self.p = p
        self.mask = None
        self.frozen = False

    def extra_repr(self):
        return "p={}".format(self.p)

    def set_mask(self, mask):
        """Set the mask used while frozen (None means: create one lazily)."""
        self.mask = mask

    def _create_mask(self, shape, device):
        """Return a random boolean keep-mask of the given shape."""
        # Invert p since bernoulli_ wants the fraction of 1s, not 0s
        return torch.empty(shape, dtype=torch.bool, device=device).bernoulli_(1 - self.p)

    def set_mask_frozen(self, frozen):
        """Freeze or unfreeze the mask; unfreezing discards the current mask."""
        if not frozen:
            self.mask = None
        self.frozen = frozen

    def forward(self, inp):
        # Dropout is a no-op when disabled, or when neither frozen nor training.
        if self.p == 0.0 or not (self.frozen or self.training):
            return inp

        if self.p >= 1.0:
            # Everything is dropped.  Rescaling by 1 / (1 - p) would compute
            # 0 / 0 = NaN, so return zeros directly (as torch.nn.Dropout
            # does); multiplying keeps the result in the autograd graph.
            return inp * 0

        if self.frozen:
            if self.mask is None:
                self.set_mask(self._create_mask(inp.shape[1:], inp.device))
            # Use the same mask for each input in a batch
            mask = self.mask.unsqueeze(0).expand(inp.shape)
        else:
            # self.training == True: a fresh mask for every element
            mask = self._create_mask(inp.shape, inp.device)

        # Rescale the kept elements by 1 / (1 - p).
        return (inp*mask) / (1 - self.p)


def set_submodule_by_name(module, submodule_name, submodule):
    """
    Replace the submodule of `module` found at the dotted path
    `submodule_name` (e.g. `encoder.layer.5.dropout`) with `submodule`.

    Path components made only of digits are used as integer indices; all
    other components are used as attribute names.
    """
    *parent_path, leaf = submodule_name.split('.')

    # Walk down to the object that directly holds the target.
    parent = module
    for part in parent_path:
        parent = parent[int(part)] if part.isdigit() else getattr(parent, part)

    if leaf.isdigit():
        parent[int(leaf)] = submodule
    else:
        setattr(parent, leaf, submodule)

def replace_dropout_with_freezabledropout(model):
    """Swap every torch.nn.Dropout reported by model.named_modules() for a
    FreezableDropout with the same p, modifying `model` in place.

    Raises:
        ValueError: if a Dropout module has inplace=True.
    """
    # Snapshot the matches first so the module tree is not modified while
    # named_modules() is still walking it.
    targets = [(name, mod) for name, mod in model.named_modules()
               if isinstance(mod, torch.nn.Dropout)]
    for name, dropout in targets:
        if dropout.inplace:
            raise ValueError("Cannot use inplace with FreezableDropout")
        replacement = FreezableDropout(dropout.p)
        set_submodule_by_name(model, name, replacement)

class BatchBaldActiveLearner(ActiveLearner):
    """Acquires a batch with batchbald.compute_multi_bald_batch, using MC
    dropout samples in which each sample shares one frozen dropout mask across
    all candidates.

    Constructing this learner replaces the model's Dropout modules with
    FreezableDropout in place.

    Config options:
        mcdropout_samples (required): number of MC dropout samples
        mcdropout_dropout_override (optional): dropout probability to use
            while sampling instead of each module's own value
    """

    def __init__(self, *args, **kwargs):
        """
        Raises:
            ValueError: if the 'mcdropout_samples' config option is missing.
        """
        super().__init__(*args, **kwargs)
        if 'mcdropout_samples' not in self.config:
            raise ValueError("Need mcdropout_samples config option for BALD")
        self.num_mc_samples = self.config['mcdropout_samples']

        self.dropout_override = None
        if 'mcdropout_dropout_override' in self.config:
            self.dropout_override = self.config['mcdropout_dropout_override']
        else:
            logger.warning('Using original dropout weight for BALD.  Specify a mcdropout_dropout_override value to override dropout to particular value for BALD.')

        logger.debug('Replacing Dropout with FreezableDropout in model')
        replace_dropout_with_freezabledropout(self.model)

    def generate_mc_samples(self, data_source, candidate_indices):
        """Return the predicted probabilities of every candidate under every
        MC dropout sample, as a CPU tensor indexed (candidate, sample, class).

        The model's mode, dropout probabilities and mask state are restored
        afterwards, even if sampling raises.

        Raises:
            ValueError: if there are no candidates or no samples to draw.
        """
        if len(candidate_indices) == 0:
            raise ValueError("Cannot generate MC samples without candidates")
        if self.num_mc_samples < 1:
            raise ValueError("mcdropout_samples must be at least 1")

        batch_size = self.args.per_gpu_eval_batch_size * max(1, self.args.n_gpu)
        num_batches = intdiv_ceil(len(candidate_indices), batch_size)

        was_training = self.model.training
        dropout_modules = [m for m in self.model.modules() if isinstance(m, FreezableDropout)]
        orig_dropout_weights = [m.p for m in dropout_modules]

        # eval mode: only the frozen FreezableDropout masks add randomness.
        self.model.eval()

        mc_samples = []
        try:
            if self.dropout_override is not None:
                for m in dropout_modules:
                    m.p = self.dropout_override

            with torch.no_grad():
                for _ in tqdm.trange(self.num_mc_samples, desc='BatchBALD-MC-samples'):
                    # Discard the previous sample's masks; each module then
                    # creates a new one on its first forward pass and reuses
                    # it for every batch of this sample.
                    for m in dropout_modules:
                        m.set_mask(None)
                        m.set_mask_frozen(True)

                    # Collect per-batch results and concatenate once, rather
                    # than re-copying a growing tensor on every batch.
                    batch_probs = []
                    for step_num in tqdm.trange(num_batches, desc="AL-BatchBALD"):
                        batch = data_source[candidate_indices[step_num*batch_size:(step_num+1)*batch_size]]
                        batch = tuple(t.to(self.args.device) for t in batch)

                        logits = self.run_model_fn(self.args, self.model, batch)[1]
                        batch_probs.append(torch.softmax(logits, dim=1).cpu())
                    mc_samples.append(torch.cat(batch_probs, dim=0))
        finally:
            for m in dropout_modules:
                m.set_mask_frozen(False)

            if self.dropout_override is not None:
                for orig_p, m in zip(orig_dropout_weights, dropout_modules):
                    m.p = orig_p

            self.model.train(mode=was_training)

        # (sample, candidate, class) -> (candidate, sample, class)
        return torch.stack(mc_samples).transpose(0, 1)

    def select_new_examples(self, size):
        """Return the data source indices of the batch chosen by BatchBALD."""
        candidate_indices = self.get_unlabeled_indices()
        if not candidate_indices:
            # Nothing left to acquire.
            return []

        probs = self.generate_mc_samples(self.data_source, candidate_indices)
        probs = probs.to(self.args.device)
        selected_candidate_idxs = batchbald.compute_multi_bald_batch(probs, size, device=self.args.device)
        # Map positions in the candidate list back to data source indices.
        batch_idxs = [candidate_indices[x] for x in selected_candidate_idxs]

        return batch_idxs

class AlpsActiveLearner(ActiveLearner):
    """
    ALPS method described in "Cold-start Active Learning through
    Self-Supervised Language Modeling" (EMNLP 2020):
    https://arxiv.org/pdf/2010.09535.pdf

    NOTE: There is another method that used to also be called ALPS and was
    later changed to SoCal, but that is different from this.
    """

    def __init__(self, *args, **kwargs):
        """The model passed to this constructor should be a pretrained masked
        language modeling (MLM) model, *not* an end task model.

        The dense surprisal embeddings of the whole data source are computed
        once here; the model reference is dropped afterwards."""
        super().__init__(*args, **kwargs)
        self.embeddings, self.seq_lengths = embed_examples_dense_surprisal(self.model, self.data_source, get_task_type(self.args.task_name), batch_size=self.args.per_gpu_eval_batch_size, device=self.args.device, model_type=self.args.model_type)
        self.embeddings = self.embeddings.numpy()
        self.seq_lengths = np.array(self.seq_lengths, dtype=np.int32)

        # We're done with this now
        del self.model

        # compute loss for 15% of tokens
        self.token_sample_proportion = 0.15

        # Acquire seed set using ALPS instead of random sampling
        self.acquire_batch(self.seed_dataset_size)

    def _init_seed_set(self):
        # ALPS doesn't use random initialization, so we override this from
        # superclass and do initialization manually
        return set()

    def select_new_examples(self, size, candidate_indices=None):
        """Return up to `size` data source indices chosen by clustering the
        surprisal embeddings of the candidates.

        Arguments:
            size: number of examples (and k-means clusters) to select
            candidate_indices: data source indices to choose from; defaults
                               to all unlabeled indices
        """
        if candidate_indices is None:
            candidate_indices = self.get_unlabeled_indices()

        # k-means cannot form more clusters than there are candidates.
        num_clusters = min(size, len(candidate_indices))
        if num_clusters <= 0:
            return []
        candidate_indices = np.asarray(candidate_indices)

        # Step 1: Build the "surprisal embeddings".  The dense surprisals are
        #         pre-computed, so all that is left is to keep a random 15% of
        #         each example's tokens (the rest count as zero) and normalize.
        #         (TODO: double-check against the paper that zero-padding to
        #         max length is correct.)
        # Indexing with an index array copies, so self.embeddings is left
        # untouched by the in-place operations below.
        candidate_embeds = self.embeddings[candidate_indices]
        candidate_lens = self.seq_lengths[candidate_indices]
        mask = np.zeros_like(candidate_embeds, dtype=np.int32)
        for idx, _len in enumerate(candidate_lens):
            num_to_keep = int(np.ceil(_len * self.token_sample_proportion))
            mask[idx, self._rng.choice(_len, size=num_to_keep, replace=False)] = 1
        candidate_embeds *= mask

        # L2 norm.  An all-zero row has norm 0; dividing by 1 instead leaves
        # it as zeros rather than turning it into NaNs.
        norms = np.sqrt((candidate_embeds * candidate_embeds).sum(axis=-1, keepdims=True))
        norms[norms == 0] = 1
        candidate_embeds /= norms

        # Step 2: Cluster the surprisal embeddings with k-means, where K is the
        #         acquisition batch size.
        # NOTE(review): algorithm='full' appears to have been renamed to
        # 'lloyd' in newer scikit-learn releases -- confirm against the
        # version this project pins.
        clusterer = sklearn.cluster.KMeans(n_clusters=num_clusters, init='random', algorithm='full')
        cluster_ids = clusterer.fit_predict(candidate_embeds)
        clusters = {cluster_id: [] for cluster_id in range(num_clusters)}
        for batch_idx, cluster_id in enumerate(cluster_ids):
            clusters[cluster_id].append(batch_idx)

        # Step 3: For each k-means cluster, pick the example with the surprisal
        #         embedding closest to the cluster center.  This results in
        #         a set of K selected examples, which is why we used the
        #         acquisition batch size for K.
        selected_batch_idxs = []
        for cluster_id, batch_idxs in clusters.items():
            if not batch_idxs:
                # A cluster without members has no example to contribute.
                continue
            dists2 = candidate_embeds[np.array(batch_idxs)] - clusterer.cluster_centers_[cluster_id]
            dists2 = (dists2 * dists2).sum(axis=1)
            selected_batch_idxs.append(batch_idxs[dists2.argmin()])

        # Finally, map batch idxs back to dataset idxs
        return [candidate_indices[batch_idx] for batch_idx in selected_batch_idxs]


# Registry mapping each active learning method name to the class that
# implements it.
ACTIVE_LEARNING_CLASSES = {
    'alps': AlpsActiveLearner,
    'bald': BaldActiveLearner,
    'batchbald': BatchBaldActiveLearner,
    'entropy': EntropyActiveLearner,
    'expected_grad_norm': ExpectedGradientNormActiveLearner,
    'greedy_coreset': CoresetActiveLearner,
    'influence_function': InfluenceActiveLearner,
    'least_confident': LeastConfActiveLearner,
    'oracle_max_loss': OracleMaxLossActiveLearner,
    'oracle_random_mispredictions': OracleRandomMispredictionsActiveLearner,
    'random': RandomActiveLearner,
}
