Skip to content



A class to generate saliency masks from a given model.


Name Type Description Default
model Model

The model, compiled with `fe.build`, which is to be inspected.

model_inputs Union[str, Sequence[str]]

The key(s) corresponding to the model inputs within the data dictionary.

model_outputs Union[str, Sequence[str]]

The key(s) corresponding to the model outputs which are written into the data dictionary.

outputs Union[str, List[str]]

The key(s) under which to write the generated saliency images.

Source code in fastestimator/fastestimator/xai/
class SaliencyNet:
    """A class to generate saliency masks from a given model.

    Args:
        model: The model, compiled with fe.build, which is to be inspected.
        model_inputs: The key(s) corresponding to the model inputs within the data dictionary.
        model_outputs: The key(s) corresponding to the model outputs which are written into the data dictionary.
        outputs: The key(s) under which to write the generated saliency images.
    """
    def __init__(self,
                 model: Model,
                 model_inputs: Union[str, Sequence[str]],
                 model_outputs: Union[str, Sequence[str]],
                 outputs: Union[str, List[str]] = "saliency"):
        mode = "test"
        self.model_op = ModelOp(model=model, mode=mode, inputs=model_inputs, outputs=model_outputs, trainable=False)
        self.outputs = to_list(outputs)
        self.mode = mode
        # One target-index key per model output; these select which class the gradient is taken with respect to.
        self.gather_keys = ["SaliencyNet_Target_Index_{}".format(key) for key in self.model_outputs]
        # NOTE(review): the `self.model_op`, `Gather(` and `GradientOp(` lines (and their trailing kwargs) were
        # reconstructed from the surviving keyword arguments and their alignment - confirm against the upstream file
        # and make sure `Gather` and `GradientOp` are imported at module level.
        self.network = Network(ops=[
            Watch(inputs=self.model_inputs, mode=mode),
            self.model_op,
            Gather(inputs=self.model_outputs,
                   indices=self.gather_keys,
                   outputs=["SaliencyNet_Intermediate_{}".format(key) for key in self.model_outputs],
                   mode=mode),
            GradientOp(inputs=self.model_inputs,
                       finals=["SaliencyNet_Intermediate_{}".format(key) for key in self.model_outputs],
                       outputs=deepcopy(self.outputs),
                       mode=mode)
        ])

    @property
    def model_inputs(self):
        """A copy of the input key(s) of the wrapped model op (copied so callers cannot mutate the op)."""
        return deepcopy(self.model_op.inputs)

    @property
    def model_outputs(self):
        """A copy of the output key(s) of the wrapped model op (copied so callers cannot mutate the op)."""
        return deepcopy(self.model_op.outputs)

    @staticmethod
    def _convert_for_visualization(tensor: Tensor, tile: int = 99) -> np.ndarray:
        """Modify the range of data in a given input `tensor` to be appropriate for visualization.

        Args:
            tensor: Input masks, whose channel values are to be reduced by absolute value summation.
            tile: The percentile [0-100] used to set the max value of the image.

        Returns:
            A (batch X width X height) image after visualization clipping is applied.
        """
        # torch tensors are treated as channel-first, everything else as channel-last
        if isinstance(tensor, torch.Tensor):
            channel_axis = 1
        else:
            channel_axis = -1
        flattened_mask = reduce_sum(abs(tensor), axis=channel_axis, keepdims=True)

        # Every axis except the leading (batch) axis, so that each sample is normalized independently
        non_batch_axes = list(range(len(flattened_mask.shape)))[1:]

        vmax = percentile(flattened_mask, tile, axis=non_batch_axes, keepdims=True)
        vmin = reduce_min(flattened_mask, axis=non_batch_axes, keepdims=True)

        # NOTE(review): if vmax == vmin for a sample (e.g. an all-zero gradient) this divides by zero - confirm how
        # the backend ops handle that case.
        return clip_by_value((flattened_mask - vmin) / (vmax - vmin), 0, 1)

    def get_masks(self, batch: Dict[str, Any]) -> Dict[str, Union[Tensor, np.ndarray]]:
        """Generates greyscale saliency mask(s) from a given `batch` of data.

        Args:
            batch: A batch of input data to be fed to the model.

        Returns:
            The model's classification decisions and greyscale saliency mask(s) for the given `batch` of data.
        """
        # Shallow copy batch since we're going to modify its contents later
        batch = {key: val for key, val in batch.items()}
        grads_and_preds = self._get_mask(batch)
        for key in self.outputs:
            grads_and_preds[key] = self._convert_for_visualization(grads_and_preds[key])
        return grads_and_preds

    def _get_mask(self, batch: Dict[str, Any]) -> Dict[str, Tensor]:
        """Generates raw saliency mask(s) from a given `batch` of data.

        This method assumes that the Network is already loaded.

        Args:
            batch: A batch of input data to be fed to the model. Missing target-index keys are added to it in place.

        Returns:
            The model outputs and the raw saliency mask(s) for the given `batch` of data. Model predictions are reduced
            via argmax.
        """
        for key in self.gather_keys:
            # If there's no target key, use an empty array which will cause the max-likelihood class to be selected
            batch.setdefault(key, [])
        prediction = self.network.transform(data=batch, mode=self.mode)
        for key in self.model_outputs:
            prediction[key] = argmax(prediction[key], axis=1)
        return prediction

    def _get_integrated_masks(self, batch: Dict[str, Any], nsamples: int = 25) -> Dict[str, Tensor]:
        """Generates raw integrated saliency mask(s) from a given `batch` of data.

        This method assumes that the Network is already loaded.

        Args:
            batch: A batch of input data to be fed to the model. It must already contain the target-index keys.
            nsamples: How many samples to consider during integration.

        Returns:
            The raw integrated saliency mask(s) for the given `batch` of data.
        """
        model_inputs = [batch[ins] for ins in self.model_inputs]

        # Use a random uniform baseline (spanning each input's own value range) rather than a constant one
        input_baselines = [
            random_uniform_like(ins, minval=reduce_min(ins), maxval=reduce_max(ins)) for ins in model_inputs
        ]
        input_diffs = [
            model_input - input_baseline for model_input, input_baseline in zip(model_inputs, input_baselines)
        ]

        response = {}

        # Walk the straight-line path from the baseline (alpha=0) to the real input (alpha=1), summing gradients
        for alpha in np.linspace(0.0, 1.0, nsamples):
            noisy_batch = {key: batch[key] for key in self.gather_keys}
            for idx, input_name in enumerate(self.model_inputs):
                x_step = input_baselines[idx] + alpha * input_diffs[idx]
                noisy_batch[input_name] = x_step
            grads_and_preds = self._get_mask(noisy_batch)
            for key in self.outputs:
                if key in response:
                    response[key] += grads_and_preds[key]
                else:
                    response[key] = grads_and_preds[key]

        # Scale the accumulated gradients by the (input - baseline) difference
        for key in self.outputs:
            grad = response[key]
            for diff in input_diffs:
                grad = grad * diff
            response[key] = grad

        return response

    def get_smoothed_masks(self,
                           batch: Dict[str, Any],
                           stdev_spread: float = .15,
                           nsamples: int = 25,
                           nintegration: Optional[int] = None,
                           magnitude: bool = True) -> Dict[str, Union[Tensor, np.ndarray]]:
        """Generates smoothed greyscale saliency mask(s) from a given `batch` of data.

        Args:
            batch: An input batch of data.
            stdev_spread: Amount of noise to add to the input, as fraction of the total spread (x_max - x_min).
            nsamples: Number of samples to average across to get the smooth gradient.
            nintegration: Number of samples to compute when integrating (None to disable).
            magnitude: If true, computes the sum of squares of gradients instead of just the sum.

        Returns:
            Greyscale saliency mask(s) smoothed via the SmoothGrad method.
        """
        # Shallow copy batch since we're going to modify its contents later
        batch = {key: val for key, val in batch.items()}
        model_inputs = [batch[ins] for ins in self.model_inputs]
        stdevs = [to_number(stdev_spread * (reduce_max(ins) - reduce_min(ins))).item() for ins in model_inputs]

        # Adding noise to the image might cause the max likelihood class value to change, so need to keep track of
        # which class we're comparing to
        response = self._get_mask(batch)
        for gather_key, output_key in zip(self.gather_keys, self.model_outputs):
            batch[gather_key] = response[output_key]

        if magnitude:
            for key in self.outputs:
                response[key] = response[key] * response[key]

        # The un-noised pass above counts as the first sample, hence nsamples - 1 further passes
        for _ in range(nsamples - 1):
            noisy_batch = {key: batch[key] for key in self.gather_keys}
            clean_batch = {key: val for key, val in noisy_batch.items()}
            for idx, input_name in enumerate(self.model_inputs):
                noise = random_normal_like(model_inputs[idx], std=stdevs[idx])
                x_plus_noise = model_inputs[idx] + noise
                clean_batch[input_name] = model_inputs[idx]
                noisy_batch[input_name] = x_plus_noise
            grads_and_preds = self._get_mask(noisy_batch) if not nintegration else self._get_integrated_masks(
                clean_batch, nsamples=nintegration)  # Integration introduces its own noise pattern
            for name in self.outputs:
                grad = grads_and_preds[name]
                if magnitude:
                    response[name] += grad * grad
                else:
                    response[name] += grad
        for key in self.outputs:
            grad = response[key]
            response[key] = self._convert_for_visualization(grad / nsamples)
        return response

    def get_integrated_masks(self, batch: Dict[str, Any], nsamples: int = 25) -> Dict[str, Union[Tensor, np.ndarray]]:
        """Generates integrated greyscale saliency mask(s) from a given `batch` of data.

        See https://arxiv.org/abs/1703.01365 for background on the IntegratedGradient method.

        Args:
            batch: An input batch of data.
            nsamples: Number of samples to average across to get the integrated gradient.

        Returns:
            Greyscale saliency masks smoothed via the IntegratedGradient method.
        """
        # Shallow copy batch since we're going to modify its contents later
        batch = {key: val for key, val in batch.items()}

        # Performing integration might cause the max likelihood class value to change, so need to keep track of
        # which class we're comparing to
        response = self._get_mask(batch)
        for gather_key, output_key in zip(self.gather_keys, self.model_outputs):
            batch[gather_key] = response[output_key]

        response.update(self._get_integrated_masks(batch, nsamples=nsamples))
        for key in self.outputs:
            response[key] = self._convert_for_visualization(response[key])

        return response


Generates integrated greyscale saliency mask(s) from a given batch of data.

See https://arxiv.org/abs/1703.01365 for background on the IntegratedGradient method.


Name Type Description Default
batch Dict[str, Any]

An input batch of data.

nsamples int

Number of samples to average across to get the integrated gradient.



Type Description
Dict[str, Union[Tensor, ndarray]]

Greyscale saliency masks smoothed via the IntegratedGradient method.

Source code in fastestimator/fastestimator/xai/
def get_integrated_masks(self, batch: Dict[str, Any], nsamples: int = 25) -> Dict[str, Union[Tensor, np.ndarray]]:
    """Generates integrated greyscale saliency mask(s) from a given `batch` of data.

    See https://arxiv.org/abs/1703.01365 for background on the IntegratedGradient method.

    Args:
        batch: An input batch of data.
        nsamples: Number of samples to average across to get the integrated gradient.

    Returns:
        Greyscale saliency masks smoothed via the IntegratedGradient method.
    """
    # Shallow copy batch since we're going to modify its contents later
    batch = {key: val for key, val in batch.items()}

    # Performing integration might cause the max likelihood class value to change, so need to keep track of
    # which class we're comparing to
    response = self._get_mask(batch)
    for gather_key, output_key in zip(self.gather_keys, self.model_outputs):
        batch[gather_key] = response[output_key]

    # Overwrite the plain-gradient masks with the integrated ones, keeping the model predictions
    response.update(self._get_integrated_masks(batch, nsamples=nsamples))
    for key in self.outputs:
        response[key] = self._convert_for_visualization(response[key])

    return response


Generates greyscale saliency mask(s) from a given batch of data.


Name Type Description Default
batch Dict[str, Any]

A batch of input data to be fed to the model.



Type Description
Dict[str, Union[Tensor, ndarray]]

The model's classification decisions and greyscale saliency mask(s) for the given batch of data.

Source code in fastestimator/fastestimator/xai/
def get_masks(self, batch: Dict[str, Any]) -> Dict[str, Union[Tensor, np.ndarray]]:
    """Generates greyscale saliency mask(s) from a given `batch` of data.

    Args:
        batch: A batch of input data to be fed to the model.

    Returns:
        The model's classification decisions and greyscale saliency mask(s) for the given `batch` of data.
    """
    # Shallow copy batch since we're going to modify its contents later
    batch = {key: val for key, val in batch.items()}
    grads_and_preds = self._get_mask(batch)
    # Only the saliency entries are rescaled; the prediction entries are returned as-is
    for key in self.outputs:
        grads_and_preds[key] = self._convert_for_visualization(grads_and_preds[key])
    return grads_and_preds


Generates smoothed greyscale saliency mask(s) from a given batch of data.


Name Type Description Default
batch Dict[str, Any]

An input batch of data.

stdev_spread float

Amount of noise to add to the input, as fraction of the total spread (x_max - x_min).

nsamples int

Number of samples to average across to get the smooth gradient.

nintegration Optional[int]

Number of samples to compute when integrating (None to disable).

magnitude bool

If true, computes the sum of squares of gradients instead of just the sum.



Type Description
Dict[str, Union[Tensor, ndarray]]

Greyscale saliency mask(s) smoothed via the SmoothGrad method.

Source code in fastestimator/fastestimator/xai/
def get_smoothed_masks(self,
                       batch: Dict[str, Any],
                       stdev_spread: float = .15,
                       nsamples: int = 25,
                       nintegration: Optional[int] = None,
                       magnitude: bool = True) -> Dict[str, Union[Tensor, np.ndarray]]:
    """Generates smoothed greyscale saliency mask(s) from a given `batch` of data.

    Args:
        batch: An input batch of data.
        stdev_spread: Amount of noise to add to the input, as fraction of the total spread (x_max - x_min).
        nsamples: Number of samples to average across to get the smooth gradient.
        nintegration: Number of samples to compute when integrating (None to disable).
        magnitude: If true, computes the sum of squares of gradients instead of just the sum.

    Returns:
        Greyscale saliency mask(s) smoothed via the SmoothGrad method.
    """
    # Shallow copy batch since we're going to modify its contents later
    batch = {key: val for key, val in batch.items()}
    model_inputs = [batch[ins] for ins in self.model_inputs]
    stdevs = [to_number(stdev_spread * (reduce_max(ins) - reduce_min(ins))).item() for ins in model_inputs]

    # Adding noise to the image might cause the max likelihood class value to change, so need to keep track of
    # which class we're comparing to
    response = self._get_mask(batch)
    for gather_key, output_key in zip(self.gather_keys, self.model_outputs):
        batch[gather_key] = response[output_key]

    if magnitude:
        for key in self.outputs:
            response[key] = response[key] * response[key]

    # The un-noised pass above counts as the first sample, hence nsamples - 1 further passes
    for _ in range(nsamples - 1):
        noisy_batch = {key: batch[key] for key in self.gather_keys}
        clean_batch = {key: val for key, val in noisy_batch.items()}
        for idx, input_name in enumerate(self.model_inputs):
            noise = random_normal_like(model_inputs[idx], std=stdevs[idx])
            x_plus_noise = model_inputs[idx] + noise
            clean_batch[input_name] = model_inputs[idx]
            noisy_batch[input_name] = x_plus_noise
        grads_and_preds = self._get_mask(noisy_batch) if not nintegration else self._get_integrated_masks(
            clean_batch, nsamples=nintegration)  # Integration introduces its own noise pattern
        for name in self.outputs:
            grad = grads_and_preds[name]
            if magnitude:
                response[name] += grad * grad
            else:
                response[name] += grad
    for key in self.outputs:
        grad = response[key]
        response[key] = self._convert_for_visualization(grad / nsamples)
    return response