---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.16.7
kernelspec:
  display_name: .venv
  language: python
  name: python3
---

# Using ReX interactively

+++

Using ReX requires three components to be set up: the input parameters, the model, and the input data. In this tutorial we will walk through how to set up each of these for a simple ReX analysis of an image classification model. We will use ReX to calculate a responsibility landscape and identify a minimal explanation for the classification of an image by the [ResNet50 model](https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet50.html#torchvision.models.resnet50). We will then plot this explanation and the responsibility landscape on the original image.

## Set up

First, we will set up the input parameters as a `CausalArgs` object.

You can create a `CausalArgs` object with default parameters and then modify it, or use [`load_config`](https://rex-xai.readthedocs.io/en/latest/autoapi/rex_xai/config/index.html#rex_xai.config.load_config) to read your desired set of input parameters from a `rex.toml` file.

```{code-cell} ipython3
from rex_xai.input.config import CausalArgs

args = CausalArgs()
print(args)
```

For the purpose of this tutorial we will set `gpu = False` and use 10 iterations. We will also set a seed to ensure reproducible outputs.

```{code-cell} ipython3
args.gpu = False
args.iters = 10
args.seed = 123
print(args)
```
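As a small illustration of why a seed gives reproducible outputs, the sketch below uses Python's standard `random` module rather than ReX itself. The helper `sample_with_seed` is hypothetical and only shows that a generator initialised with the same seed produces the same sequence.

```python
import random


def sample_with_seed(seed, n=5):
    """Draw n pseudo-random integers from a generator initialised with seed."""
    rng = random.Random(seed)
    return [rng.randint(0, 99) for _ in range(n)]


first = sample_with_seed(123)
second = sample_with_seed(123)

# The same seed always reproduces the same sequence of draws.
print(first == second)  # True
```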

Next, we will set up the model details. We need to provide ReX with three things:

- the shape the model expects the input data to be
- a preprocessing function that will apply the appropriate transforms to the input data
- a prediction function to be applied to the transformed input data

For this tutorial we will use the [ResNet50 model](https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet50.html#torchvision.models.resnet50) as provided by the `torchvision` library.

```{code-cell} ipython3
:tags: [hide-output]

from torchvision.models import resnet50

model = resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
model.to("cpu")
```

We need to define `model_shape`, a list giving the shape of the input data that the model expects. In this case the order is `[batch, channels, height, width]`. We use `"N"` for the batch size because this model accepts dynamic batch sizes. The actual batch size used can be set in the `CausalArgs` object.

```{code-cell} ipython3
model_shape = ["N", 3, 224, 224]
```
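To make the role of the symbolic batch dimension concrete, here is a minimal sketch of how a placeholder such as `"N"` could be swapped for a concrete batch size. The helper `resolve_shape` is hypothetical and not part of ReX; how ReX resolves the batch dimension internally may differ.

```python
def resolve_shape(shape, batch_size):
    """Replace any symbolic (string) dimension with a concrete batch size."""
    return [batch_size if isinstance(dim, str) else dim for dim in shape]


example_shape = ["N", 3, 224, 224]

print(resolve_shape(example_shape, 1))   # [1, 3, 224, 224]
print(resolve_shape(example_shape, 64))  # [64, 3, 224, 224]
```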

We also need to define a `preprocess` function, which will be applied to our input data file and return a ReX `Data` object that has been appropriately transformed.
The transformations (e.g. resizing, normalisation) should match those used when the model was trained. For this model, we use the following transformations:

```{code-cell} ipython3
from torchvision import transforms as T
from PIL import Image

from rex_xai.input.input_data import Data

def preprocess(path, shape, device, mode) -> Data:
    transform = T.Compose(
        [
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    img = Image.open(path).convert("RGB")
    data = Data(img, shape, device, mode=mode, process=False)
    data.data = transform(img).unsqueeze(0).to(device)  # add a batch dimension
    data.mode = "RGB"
    data.model_shape = shape
    data.model_height = 224
    data.model_width = 224
    data.model_channels = 3
    data.transposed = True
    data.model_order = "first"

    return data
```

Finally, we will define a `prediction_function` that will be applied to our input data and mutants:

```{code-cell} ipython3
import torch as tt
import torch.nn.functional as F

from rex_xai.responsibility.prediction import from_pytorch_tensor

def prediction_function(mutants, target=None, raw=False, binary_threshold=None):
    with tt.no_grad():
        tensor = model(mutants)
        if raw:
            return F.softmax(tensor, dim=1)
        return from_pytorch_tensor(tensor)
```
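The `raw` branch above applies a softmax to the model output. As a rough illustration of what that step does, the following sketch implements softmax in plain Python on made-up scores for three classes. It does not describe what `from_pytorch_tensor` returns; it only shows how raw scores become probabilities and how the most likely class can be read off.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]  # subtract max for stability
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]  # made-up scores, not real model output
probs = softmax(logits)

# The predicted class is the index of the largest probability.
predicted_class = max(range(len(probs)), key=probs.__getitem__)
print(predicted_class, round(probs[predicted_class], 3))  # 0 0.659
```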

Now we are almost ready to run ReX. The final step is to set up the input data.

```{code-cell} ipython3
from rex_xai.explanation.rex import load_and_preprocess_data, predict_target
from rex_xai.utils._utils import get_device

device = get_device(gpu=False)

args.path = '../../tests/test_data/ladybird.jpg'
data = load_and_preprocess_data(model_shape, device, args)
```

This is our input image:

```{code-cell} ipython3
Image.open(args.path)
```

We will now set the mask value to be used to mask the data when creating mutants, and predict the class of the original input image (the 'target' for the mutants).

ReX allows functions to be used to set the mask value (e.g. the 'min' or 'mean' of the normalised image), but the default mask value of 0 generally performs well enough for images.

```{code-cell} ipython3
data.set_mask_value(0)
data.target = predict_target(data, prediction_function)

print(data.target)
```
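One way to build intuition for the default mask value: because `preprocess` normalises each channel with the mean and standard deviation shown earlier, a value of 0 in normalised space corresponds to the mean colour rather than to black. The sketch below inverts the normalisation for a single pixel. It assumes the mask value is applied to the normalised tensor, which the text above suggests but which this example does not verify.

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalise(pixel):
    """Map an RGB pixel in [0, 1] to normalised space, channel by channel."""
    return [(p - m) / s for p, m, s in zip(pixel, MEAN, STD)]


def denormalise(values):
    """Invert the normalisation to recover an RGB pixel in [0, 1]."""
    return [v * s + m for v, m, s in zip(values, MEAN, STD)]


# A masked pixel (all zeros after normalisation) mapped back to 8-bit RGB.
masked_pixel = denormalise([0.0, 0.0, 0.0])
as_8bit = [round(c * 255) for c in masked_pixel]
print(as_8bit)  # [124, 116, 104]
```

In other words, under this assumption a mask value of 0 replaces pixels with a mid-grey close to the average training colour.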

## Running ReX

We are now ready to run ReX and identify a causal explanation for this classification.

The main function in ReX is `calculate_responsibility`, which returns a `ResponsibilityMaps` object and some statistics about the function execution.
We can then create an `Explanation` object containing the calculated responsibility landscape and `extract` an explanation from this object.
Here we will use the `Global` strategy. Additional strategies are available in the `rex_xai.utils._utils.Strategy` Enum.

```{code-cell} ipython3
:tags: [hide-output]

from rex_xai.explanation.rex import calculate_responsibility
from rex_xai.explanation.explanation import Explanation
from rex_xai.utils._utils import Strategy

resp_maps, stats = calculate_responsibility(data, args, prediction_function)
exp = Explanation(resp_maps, prediction_function, data, args, stats)
exp.extract(Strategy.Global)
```

## Examining the results

The mask corresponding to the final explanation is stored in the `Explanation` object and can be printed:

```{code-cell} ipython3
print(exp.final_mask)
```
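To illustrate how a mask of this kind relates to an image, the toy sketch below keeps the pixels selected by a boolean mask and replaces the rest with a mask value. It uses nested lists and a hypothetical helper `apply_mask`; the assumption that the mask is boolean with the same height and width as the image is made for illustration only and is not taken from the ReX documentation.

```python
def apply_mask(image, mask, mask_value=0):
    """Keep pixels where mask is True; replace the rest with mask_value."""
    return [
        [pixel if keep else mask_value for pixel, keep in zip(img_row, mask_row)]
        for img_row, mask_row in zip(image, mask)
    ]


toy_image = [[5, 6, 7], [8, 9, 10]]
toy_mask = [[False, True, True], [False, True, False]]

print(apply_mask(toy_image, toy_mask))  # [[0, 6, 7], [0, 9, 0]]

# Fraction of the image retained by the mask.
coverage = sum(map(sum, toy_mask)) / (len(toy_mask) * len(toy_mask[0]))
print(coverage)  # 0.5
```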

ReX also provides some plotting methods for easier visualisation of the explanation and the responsibility landscape.

```{code-cell} ipython3
display(exp.show())
```

The responsibility landscape can be plotted as a heatmap or 3D surface plot:

```{code-cell} ipython3
exp.heatmap_plot()
exp.surface_plot()
```

If you are not satisfied with the quality of the explanation, a good place to start is the run statistics returned by `calculate_responsibility`.
In particular, check the tree depth and the number of mutants examined, as a low value for either can lead to strange results. The easiest way to increase both and refine an explanation is to increase the number of iterations.

```{code-cell} ipython3
print(stats)
```

In this case, the tree depth is 9, almost 1000 mutants have been assessed, and the returned explanation matches well with what we would expect, so we are happy with the quality of the explanation and don’t need to increase the iterations.

Another way to check explanation quality would be to analyze the `Explanation` object and calculate some common metrics.

```{code-cell} ipython3
from rex_xai.explanation.rex import analyze

analyze(exp, data.mode)
```

## Multiple Explanations

There may be multiple regions of an image that are sufficient to explain its classification.
ReX provides a way to identify multiple explanations by using the `MultiExplanation` class and `MultiSpotlight` strategy.

We will use a different image to illustrate this approach, so we need to set up the `Data` object for this new image.
We will copy the `args` object and change the path to point to the new data.

```{code-cell} ipython3
import copy
peacock_args = copy.copy(args)
peacock_args.path = "../../tests/test_data/peacock.jpg"
data = load_and_preprocess_data(model_shape, device, peacock_args)
Image.open(peacock_args.path)
```
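A note on `copy.copy`: it makes a shallow copy, so rebinding an attribute on the copy (as done with `path` above) leaves the original untouched, but any mutable attribute is shared between the two objects. The sketch below demonstrates this with a hypothetical `DemoArgs` dataclass that stands in for a configuration object; it says nothing about which fields `CausalArgs` actually has.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class DemoArgs:
    """Hypothetical stand-in for a configuration object; not the ReX class."""
    path: str = ""
    tags: list = field(default_factory=list)


original = DemoArgs(path="a.jpg", tags=["first"])
shallow = copy.copy(original)
deep = copy.deepcopy(original)

shallow.path = "b.jpg"        # rebinding an attribute: original is untouched
shallow.tags.append("extra")  # mutating a shared list: original sees the change

print(original.path)  # a.jpg
print(original.tags)  # ['first', 'extra']
print(deep.tags)      # ['first']
```

If a configuration object holds lists or dictionaries that you intend to modify in place, `copy.deepcopy` is the safer choice.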

We set the mask value to zero again and predict the class of the new image.

```{code-cell} ipython3
data.set_mask_value(0)
data.target = predict_target(data, prediction_function)

print(data.target)
```

Next, we will run `calculate_responsibility` on the new image, create a `MultiExplanation` object, and extract explanations.

```{code-cell} ipython3
from rex_xai.explanation.multi_explanation import MultiExplanation

resp_maps, stats = calculate_responsibility(data, peacock_args, prediction_function)

multi_exp = MultiExplanation(resp_maps, prediction_function, data, peacock_args, stats)
multi_exp.extract(Strategy.MultiSpotlight)
```

Here we have used the default settings, which generate (up to) ten explanations.
The first explanation is always the explanation identified with `Strategy.Global`.

```{code-cell} ipython3
multi_exp.show(multi_style="separate")
```

We can use the `separate_by` method to identify sets of explanations that do not overlap with each other (or that have an overlap less than some maximum threshold).
Here we will use an overlap of zero to find explanations that have no overlap with each other.

```{code-cell} ipython3
clauses = multi_exp.separate_by(0)
print(clauses)
```

We have identified three groups of 6 explanations that have no overlap with each other.
The "composite" plotting style can be used to plot these explanations in a single plot.
Here we will just plot the first group.

```{code-cell} ipython3
multi_exp.show(multi_style="composite", clauses=[clauses[0]])
```
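The idea of grouping explanations by overlap can be sketched with plain Python sets of pixel coordinates. Everything here is illustrative: the helpers `overlap` and `compatible_groups` are hypothetical, and the overlap measure (shared pixels divided by the size of the smaller explanation) is an assumption, so ReX may define and compute overlap differently.

```python
from itertools import combinations


def overlap(a, b):
    """Fraction of the smaller pixel set that is shared with the other set."""
    return len(a & b) / min(len(a), len(b))


def compatible_groups(masks, max_overlap, size):
    """Index tuples of `size` masks whose pairwise overlap is within the limit."""
    groups = []
    for group in combinations(range(len(masks)), size):
        pairs = combinations(group, 2)
        if all(overlap(masks[i], masks[j]) <= max_overlap for i, j in pairs):
            groups.append(group)
    return groups


# Four toy explanations, each a set of (row, column) pixel coordinates.
toy_masks = [
    {(0, 0), (0, 1)},
    {(0, 1), (1, 1)},
    {(2, 2), (2, 3)},
    {(3, 0)},
]

# Explanations 0 and 1 share a pixel, so they never appear together.
print(compatible_groups(toy_masks, max_overlap=0, size=3))
# [(0, 2, 3), (1, 2, 3)]
```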

## Saving results

We may want to run ReX on multiple input datasets and save the results for further analysis later.
We can save results from an `Explanation` or `MultiExplanation` object in a SQLite database.
We first initialise the database and then save the ladybird explanation and the multiple explanations of the peacock image.

We did not measure the time taken to calculate the responsibility landscape and extract explanations in this tutorial, so here we pass `None` in place of the time.
If you wish to measure this, you can use the `time.time()` function to record timestamps before and after the steps you wish to time.
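A minimal timing sketch along these lines is shown below. The function `slow_step` is a hypothetical placeholder for whatever you want to time, such as calculating responsibility and extracting an explanation.

```python
import time


def slow_step():
    """Placeholder for the work being timed."""
    time.sleep(0.05)


start = time.time()
slow_step()
time_taken = time.time() - start  # elapsed wall-clock time in seconds

print(f"{time_taken:.2f} s")
```

The resulting value could then be passed as the `time_taken` argument of `update_database`. For short intervals, `time.perf_counter()` is a common alternative because it uses a monotonic, high-resolution clock.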

Before saving the results, we should update the `strategy` saved in the `Explanation` or `MultiExplanation` object to ensure it matches the strategy we used, as this will be saved in the database.

```{code-cell} ipython3
from rex_xai.output.database import initialise_rex_db, update_database

db = initialise_rex_db("rex.db")

exp.args.strategy = Strategy.Global
update_database(db, exp, time_taken=None)

multi_exp.args.strategy = Strategy.MultiSpotlight
update_database(db, multi_exp, time_taken=None, multi=True)
```

ReX provides a helper function to read the results from the database into a [Pandas](https://pandas.pydata.org/) dataframe for further analysis.

```{code-cell} ipython3
from rex_xai.output.database import db_to_pandas

df = db_to_pandas("rex.db")
print(df)
```
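If you prefer to inspect a database file directly, Python's standard `sqlite3` module can list its tables and columns without knowing the schema in advance. The sketch below runs against an in-memory database with a hypothetical `demo_results` table, which is a stand-in and does not reflect the tables or columns that ReX creates.

```python
import sqlite3

# In-memory database with a hypothetical table, standing in for a real file.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE demo_results (id INTEGER PRIMARY KEY, name TEXT, score REAL)")
conn.execute("INSERT INTO demo_results (name, score) VALUES (?, ?)", ("demo", 0.5))
conn.commit()

# List the tables, then the column names of each one.
query = "SELECT name FROM sqlite_master WHERE type='table'"
tables = [row[0] for row in conn.execute(query)]
columns = {
    table: [col[1] for col in conn.execute(f"PRAGMA table_info({table})")]
    for table in tables
}
row_count = conn.execute("SELECT COUNT(*) FROM demo_results").fetchone()[0]
conn.close()

print(columns)  # {'demo_results': ['id', 'name', 'score']}
```

To apply the same queries to a saved file, replace `":memory:"` with the path to that file.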

```{code-cell} ipython3
:tags: [hide-cell]

import os
os.remove("rex.db")
```

