# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/05aa_methods.uncertain.ipynb (unless otherwise specified).

# Public API of this module; the cells marked "Internal Cell" are private helpers.
__all__ = ['UncertainCF']

# Cell
from ..import_essentials import *
from ..utils import cat_normalize
from ..training_module import CounterfactualTrainingModule
from torch import Tensor

# Internal Cell
class TargetedLossFunction:
    """Objective that rewards moving a prediction to the opposite class.

    `loss` returns the *negated* binary cross-entropy between the predicted
    probability of class 1 (`outputs[:, 1]`) and the flipped original labels,
    so maximising it pushes predictions towards the counterfactual class.
    Labels are assumed to be 0/1 (they are flipped with `1 - label`).
    """

    def loss(self, outputs: Tensor, original_labels: Tensor) -> Tensor:
        n_rows = outputs.size(0)
        flipped_labels = 1. - original_labels
        assert flipped_labels.size() == (n_rows, )
        prob_class_one = outputs[:, 1]
        bce = F.binary_cross_entropy(prob_class_one, flipped_labels)
        return bce.neg()


# Internal Cell
def _get_prediction_and_grad(
    pred_fn: Callable[[Tensor], Tensor],
    loss_function,
    examples: Tensor, original_labels: Tensor
) ->Tuple[Tensor, Tensor, Tensor]:
    """Evaluate `pred_fn` on `examples` and differentiate the loss w.r.t. the inputs.

    Args:
        pred_fn: maps a batch of inputs to per-class scores; the max is taken
            over dim 1, so a (batch_size, n_classes) output is expected.
        loss_function: object exposing `.loss(outputs, original_labels)` that
            returns a scalar tensor.
        examples: batch of inputs. It is copied, so the caller's tensor is
            never marked as requiring grad.
        original_labels: forwarded unchanged to `loss_function.loss`.

    Returns:
        (output, confidence, grad): the detached predictions, the per-example
        maximum score along dim 1, and d(loss)/d(examples) with the same shape
        as `examples`.
    """
    # Leaf copy of the inputs: the only tensor we differentiate against.
    inputs = examples.clone().detach().requires_grad_(True)
    output = pred_fn(inputs)
    loss = loss_function.loss(output, original_labels)
    # Unlike `loss.backward()`, `autograd.grad` does not accumulate into the
    # `.grad` of any trainable parameters used inside `pred_fn`.
    (grad,) = torch.autograd.grad(loss, inputs)
    output = output.detach()
    confidence, _ = torch.max(output, dim=1)
    # Clone so callers can safely modify the gradient in place (autograd may
    # hand back expanded/non-writable layouts for some graphs).
    return output, confidence, grad.clone()

# Internal Cell
def _uncertaincf(
    originals: Tensor,
    pred_fn: Callable[[Tensor], Tensor],
    loss_function,
    n_steps: int,
    n_changes: int, # n_changes = 10
    confidence_threshold: float,
    cat_arrays: List[List[str]],
    cat_idx: int
):
    """Greedily perturb `originals` one feature at a time to produce counterfactuals.

    At each step, for every example that still needs work, the feature(s) with
    the largest absolute loss gradient are moved by `1 / n_changes` in the
    direction of the gradient's sign (sign-gradient ascent on the loss). A
    feature is moved at most `n_changes` times. After every step the features
    are clamped to [0, 1] and passed through `cat_normalize(..., hard=False)`.

    An example "still needs work" while it has not changed class, or has
    changed class but its maximum predicted score is below
    `confidence_threshold`. The search stops once no example needs work or
    after `n_steps` steps.

    Args:
        originals: batch of inputs; flattened to (batch_size, -1) internally.
        pred_fn: maps a batch to per-class scores; argmax/max are taken over dim 1.
        loss_function: object with `.loss(outputs, original_labels)`.
        n_steps: maximum number of perturbation steps.
        n_changes: max times one feature may be changed; also sets the step size.
        confidence_threshold: minimum max-score required before an example
            that changed class is left alone.
        cat_arrays: forwarded to `cat_normalize`.
        cat_idx: forwarded to `cat_normalize` (index of the first categorical column).

    Returns:
        Counterfactuals with the same shape as `originals`.
    """
    batch_size = originals.size(0)

    examples = originals.clone().detach().view(batch_size, -1)

    # The original labels are fixed targets; no autograd graph is needed.
    with torch.no_grad():
        original_labels = pred_fn(originals).argmax(dim=1)
    assert original_labels.shape == (batch_size, )

    # Per-feature step size. Allocate it alongside `examples` so the masked
    # indexing/arithmetic below also works for GPU (or non-float32) inputs.
    batch_perturbations = torch.full(
        examples.shape, 1.0 / n_changes, device=examples.device, dtype=examples.dtype
    )

    # How many times each feature of each example has been changed so far.
    altered_pixels = torch.zeros(size=examples.shape, device=examples.device, dtype=torch.int)

    for _ in range(n_steps):
        prediction, confidence, grad = _get_prediction_and_grad(
            pred_fn, loss_function, examples, original_labels
        )

        have_changed_class = torch.argmax(prediction, -1) != original_labels
        # Keep perturbing examples that either have not flipped yet, or have
        # flipped but are not yet confident enough.
        needs_change = ~have_changed_class | (confidence < confidence_threshold)
        if not needs_change.any():
            break

        # Features already changed n_changes times are frozen.
        grad[altered_pixels >= n_changes] = 0.0

        # Change the feature(s) with the largest absolute gradient (ties all move).
        max_mask = grad.abs() == grad.abs().max(dim=1, keepdim=True)[0]
        to_change_mask = max_mask & needs_change.view(batch_size, 1) & (grad != 0.0)

        grad_sign = grad[to_change_mask].sign()
        perturbation_size = batch_perturbations[to_change_mask]
        examples[to_change_mask] += perturbation_size * grad_sign
        altered_pixels[to_change_mask] += 1

        examples = torch.clamp(examples, 0.0, 1.0)
        examples = cat_normalize(examples, cat_arrays, cat_idx, hard=False)
    return examples.view(originals.size())

# Cell
class UncertainCF(CounterfactualTrainingModule):
    """Counterfactual generator that perturbs the inputs of a frozen black-box
    model one feature at a time (via `_uncertaincf`).
    """
    def __init__(self, config, model: pl.LightningModule, n_steps: int = 500, n_changes: int = 10, confidence_threshold: float = 0.99):
        """
        config: basic configs
        model: the black-box model to be explained (frozen here)
        n_steps: maximum number of perturbation steps, forwarded to `_uncertaincf`
        n_changes: per-feature change budget, forwarded to `_uncertaincf`
        confidence_threshold: confidence threshold, forwarded to `_uncertaincf`
        """
        super().__init__(config)
        # Hyper-parameters of the perturbation search.
        self.n_steps, self.n_changes, self.confidence_threshold = (
            n_steps, n_changes, confidence_threshold
        )
        # The explained model is frozen; generate_cf only perturbs inputs.
        self.model = model
        self.model.freeze()
        self.prepare_data()

    def predict(self, x):
        """Delegate prediction to the wrapped black-box model."""
        return self.model.predict(x)

    def generate_cf(self, x):
        """Return counterfactuals for the batch `x`."""
        first_cat_col = len(self.continous_cols)

        def two_class_scores(inputs):
            # Expand the model output p into the columns [1 - p, p];
            # presumably p has shape (batch, 1) — TODO confirm.
            p = self.model(inputs)
            return torch.cat([1. - p, p], dim=1)

        return _uncertaincf(
            x,
            two_class_scores,
            TargetedLossFunction(),
            self.n_steps,
            self.n_changes,
            self.confidence_threshold,
            cat_arrays=self.cat_array,
            cat_idx=first_cat_col,
        )
