Skip to content

dice_loss

DiceLoss

Bases: LossOp

Calculate Dice Loss.

Parameters:

Name Type Description Default
inputs Union[Tuple[str, str], List[str]]

A tuple or list of keys representing prediction and ground truth, like: ("y_pred", "y_true").

required
outputs str

The key under which to save the output.

required
mode Union[None, str, Iterable[str]]

What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train".

'!infer'
ds_id Union[None, str, Iterable[str]]

What dataset id(s) to execute this Op in. To execute regardless of ds_id, pass None. To execute in all ds_ids except for a particular one, you can pass an argument like "!ds1".

None
soft_dice bool

Whether to square elements in the denominator.

False
average_loss bool

Whether to average the element-wise loss after the Loss Op.

True
channel_average bool

Whether to average the dice score along the channel dimension.

True
channel_weights Optional[Dict[int, float]]

Dictionary mapping channel indices to a weight for weighting the loss function. Useful when you need to pay more attention to a particular channel.

None
epsilon float

A small value to prevent numeric instability in the division.

1e-06

Returns:

Type Description

The dice loss between y_pred and y_true (the negated dice score). A scalar if average_loss and channel_average are True, otherwise a tensor.

Raises:

Type Description
AssertionError

If y_true or y_pred are unacceptable data types.

Source code in fastestimator/fastestimator/op/tensorop/loss/dice_loss.py
class DiceLoss(LossOp):
    """Calculate Dice Loss.

    The value produced is the negated dice score (`-dice`), so minimizing it maximizes the dice overlap between the
    prediction and the ground truth.

    Args:
        inputs: A tuple or list of keys representing prediction and ground truth, like: ("y_pred", "y_true").
        outputs: The key under which to save the output.
        mode: What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
        ds_id: What dataset id(s) to execute this Op in. To execute regardless of ds_id, pass None. To execute in all
            ds_ids except for a particular one, you can pass an argument like "!ds1".
        soft_dice: Whether to square elements in the denominator.
        average_loss: Whether to average the element-wise loss after the Loss Op. This value is forwarded to
            `dice_score` as its `sample_average` argument.
        channel_average: Whether to average the dice score along the channel dimension.
        channel_weights: Dictionary mapping non-negative channel indices to a weight (int or float) for weighting the
            loss function. Useful when you need to pay more attention to a particular channel. Channels up to the
            largest listed index that are not themselves listed get a weight of 1.0. None or an empty dictionary
            disables channel weighting.
        epsilon: A small value to prevent numeric instability in the division.

    Returns:
        The dice loss between `y_pred` and `y_true`. A scalar if `average_loss` and `channel_average` are True,
        otherwise a tensor.

    Raises:
        AssertionError: If `channel_weights` is malformed, or if `y_true` or `y_pred` are unacceptable data types.
    """

    def __init__(self,
                 inputs: Union[Tuple[str, str], List[str]],
                 outputs: str,
                 mode: Union[None, str, Iterable[str]] = "!infer",
                 ds_id: Union[None, str, Iterable[str]] = None,
                 soft_dice: bool = False,
                 average_loss: bool = True,
                 channel_average: bool = True,
                 channel_weights: Optional[Dict[int, float]] = None,
                 epsilon: float = 1e-6):
        super().__init__(inputs=inputs, outputs=outputs, mode=mode, ds_id=ds_id, average_loss=average_loss)
        self.channel_average = channel_average
        self.soft_dice = soft_dice
        self.epsilon = epsilon
        # None means "no channel weighting"; otherwise a (1, max_channel + 1) float32 row of per-channel weights.
        self.weights = None
        if channel_weights is not None:
            assert isinstance(channel_weights, dict), \
                "channel_weights should be a dictionary or have None value, got {}".format(type(channel_weights))
            # Negative keys would silently alias other columns through numpy negative indexing, so reject them.
            assert all(isinstance(key, int) and key >= 0 for key in channel_weights), \
                "Please ensure that the keys of the channel_weights dictionary are non-negative integers"
            assert all(isinstance(value, (int, float)) for value in channel_weights.values()), \
                "Please ensure that the values of the channel_weights dictionary are of type: int or float"
        # An empty dict carries no weights; treat it like None instead of crashing on max() of an empty sequence.
        if channel_weights:
            self.weights = np.ones((1, max(channel_weights) + 1), dtype='float32')
            for channel, weight in channel_weights.items():
                self.weights[0, channel] = weight

    def build(self, framework: str, device: Optional[torch.device] = None) -> None:
        """Convert the numpy channel weights (if any) into a tensor for the given framework.

        Args:
            framework: Either 'tf' or 'torch'.
            device: For 'torch', the device the weight tensor should be placed on. Ignored for 'tf'.

        Raises:
            ValueError: If `framework` is not 'tf' or 'torch'.
        """
        if framework == 'tf':
            if self.weights is not None:
                self.weights = convert_tensor_precision(to_tensor(self.weights, 'tf'))
        elif framework == 'torch':
            if self.weights is not None:
                weights = convert_tensor_precision(to_tensor(self.weights, 'torch'))
                # torch.Tensor.to() is not in-place: it returns a (possibly new) tensor, so the result must be kept.
                self.weights = weights.to(device) if device is not None else weights
        else:
            raise ValueError("unrecognized framework: {}".format(framework))

    def forward(self, data: List[Tensor], state: Dict[str, Any]) -> Tensor:
        """Compute the negated dice score between the prediction and the ground truth.

        Args:
            data: A pair `(y_pred, y_true)` selected by `inputs`.
            state: Information about the current execution context (unused here).

        Returns:
            `-dice`, where `dice` is the score returned by `dice_score` with this Op's configuration.
        """
        y_pred, y_true = data
        dice = dice_score(y_pred=y_pred,
                          y_true=y_true,
                          soft_dice=self.soft_dice,
                          sample_average=self.average_loss,
                          channel_average=self.channel_average,
                          channel_weights=self.weights,
                          epsilon=self.epsilon)
        return -dice