Skip to content

cross_entropy

CrossEntropy

Bases: LossOp

Calculate Element-Wise CrossEntropy (binary, categorical or sparse categorical).

Parameters:

Name Type Description Default
inputs Union[Tuple[str, str], List[str]]

A tuple or list like: [<y_pred>, <y_true>].

required
outputs str

String key under which to store the computed loss value.

required
mode Union[None, str, Iterable[str]]

What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train".

'!infer'
ds_id Union[None, str, Iterable[str]]

What dataset id(s) to execute this Op in. To execute regardless of ds_id, pass None. To execute in all ds_ids except for a particular one, you can pass an argument like "!ds1".

None
from_logits bool

Whether y_pred is logits (without softmax).

False
average_loss bool

Whether to average the element-wise loss after the Loss Op.

True
form Optional[str]

What form of cross entropy should be performed ('binary', 'categorical', 'sparse', or None). None will automatically infer the correct form based on tensor shape: if both y_pred and y_true are rank-2 tensors then 'categorical' will be used; if y_pred is a rank-2 tensor but y_true is a rank-1 tensor, then 'sparse' will be chosen; otherwise 'binary' will be applied.

None
class_weights Optional[Dict[int, float]]

Dictionary mapping class indices to a weight for weighting the loss function. Useful when you need to pay more attention to samples from an under-represented class.

None

Raises:

Type Description
AssertionError

If class_weights or its keys and values are of unacceptable data types.

Source code in fastestimator/fastestimator/op/tensorop/loss/cross_entropy.py
@traceable()
class CrossEntropy(LossOp):
    """Element-wise cross entropy loss in binary, categorical, or sparse categorical form.

    Args:
        inputs: A tuple or list like: [<y_pred>, <y_true>].
        outputs: String key under which to store the computed loss value.
        mode: What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
        ds_id: What dataset id(s) to execute this Op in. To execute regardless of ds_id, pass None. To execute in all
            ds_ids except for a particular one, you can pass an argument like "!ds1".
        from_logits: Whether y_pred is logits (without softmax).
        average_loss: Whether to average the element-wise loss after the Loss Op.
        form: Which cross entropy variant to apply ('binary', 'categorical', 'sparse', or None). With None the variant
            is inferred from the tensor shapes on every forward pass: when y_pred is rank-2 with a last dimension
            larger than 1, 'categorical' is used if y_true is also rank-2 with a last dimension larger than 1 and
            'sparse' otherwise; in every other case 'binary' is used.
        class_weights: Dictionary mapping class indices to a weight for weighting the loss function. Useful when you
            need to pay more attention to samples from an under-represented class.

    Raises:
        AssertionError: If `class_weights` or its keys and values are of unacceptable data types.
    """
    def __init__(self,
                 inputs: Union[Tuple[str, str], List[str]],
                 outputs: str,
                 mode: Union[None, str, Iterable[str]] = "!infer",
                 ds_id: Union[None, str, Iterable[str]] = None,
                 from_logits: bool = False,
                 average_loss: bool = True,
                 form: Optional[str] = None,
                 class_weights: Optional[Dict[int, float]] = None):
        super().__init__(inputs=inputs, outputs=outputs, mode=mode, ds_id=ds_id, average_loss=average_loss)
        self.from_logits = from_logits
        self.form = form
        # Dispatch table from the `form` name to the backend loss implementation.
        self.cross_entropy_fn = dict(binary=binary_crossentropy,
                                     categorical=categorical_crossentropy,
                                     sparse=sparse_categorical_crossentropy)

        # Validation is skipped for None and for any other falsy value (e.g. an empty dict).
        if class_weights:
            assert isinstance(class_weights, dict), \
                "class_weights should be a dictionary or have None value, got {}".format(type(class_weights))
            # All keys are checked before any value so that a key problem is always reported first.
            for class_index in class_weights.keys():
                assert isinstance(class_index, int), \
                    "Please ensure that the keys of the class_weight dictionary are of type: int"
            for weight in class_weights.values():
                assert isinstance(weight, float), \
                    "Please ensure that the values of the class_weight dictionary are of type: float"

        self.class_weights = class_weights
        # Framework-specific view of `class_weights`; populated by `build` when weights were supplied.
        self.class_dict = None

    def build(self, framework: str, device: Optional[torch.device] = None) -> None:
        """Prepare the class-weight lookup for the given framework.

        Nothing happens when no (or empty) class weights were provided.

        Args:
            framework: Either 'tf' or 'torch'.
            device: Not used by this Op.

        Raises:
            ValueError: If class weights were provided and `framework` is neither 'tf' nor 'torch'.
        """
        weights = self.class_weights
        if not weights:
            return

        if framework == 'torch':
            # The plain python dictionary is handed to the loss function unchanged.
            self.class_dict = weights
        elif framework == 'tf':
            # Wrap the mapping in a static hash table; class indices missing from it look up as 1.0.
            initializer = tf.lookup.KeyValueTensorInitializer(tf.constant(list(weights.keys())),
                                                              tf.constant(list(weights.values())))
            self.class_dict = tf.lookup.StaticHashTable(initializer, default_value=1.0)
        else:
            raise ValueError("unrecognized framework: {}".format(framework))

    def forward(self, data: List[Tensor], state: Dict[str, Any]) -> Tensor:
        """Compute the cross entropy between the prediction and the target.

        Args:
            data: The pair [y_pred, y_true].
            state: Not used by this Op.

        Returns:
            The value returned by the selected cross entropy function.
        """
        y_pred, y_true = data
        form = self.form
        if form is None:
            # Infer the variant from shapes: a rank-2 prediction with several columns is multi-class.
            if not (len(y_pred.shape) == 2 and y_pred.shape[-1] > 1):
                form = "binary"
            elif len(y_true.shape) == 2 and y_true.shape[-1] > 1:
                form = "categorical"
            else:
                form = "sparse"

        loss_fn = self.cross_entropy_fn[form]
        return loss_fn(y_pred,
                       y_true,
                       from_logits=self.from_logits,
                       average_loss=self.average_loss,
                       class_weights=self.class_dict)