
super_loss

SuperLoss

Bases: LossOp

Loss class to compute a 'super loss' (automatic curriculum learning) based on a regular loss.

This class adds automatic curriculum learning on top of any other loss metric. It is especially useful for noisy datasets. See https://papers.nips.cc/paper/2020/file/2cfa8f9e50e0f510ede9d12338a5f564-Paper.pdf for details.
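Concretely (mirroring the forward pass in the source code below), each per-sample base loss $\ell$ is re-weighted by a closed-form confidence $\sigma^*$ obtained from the Lambert W function $W$:

$$\sigma^* = \exp\!\left(-W\!\left(\tfrac{1}{2}\max\!\left(-\tfrac{2}{e},\ \tfrac{\ell - \tau}{\lambda}\right)\right)\right), \qquad \mathcal{L}_{\text{super}} = (\ell - \tau)\,\sigma^* + \lambda\,(\log \sigma^*)^2,$$

where $\tau$ is the threshold and $\lambda$ the regularization parameter. Samples whose loss falls below $\tau$ receive confidence greater than 1 and are emphasized, while harder samples are down-weighted.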

Parameters:

loss (LossOp, required): A loss object which we use to calculate the underlying regular loss. This should be an object of type fe.op.tensorop.loss.loss.LossOp.

threshold (Union[float, str], default 'exp'): Either a constant value corresponding to an average expected loss (for example log(n_classes) for cross-entropy classification), or 'exp' to use an exponential moving average loss.

regularization (float, default 1.0): The regularization parameter to use for the super loss (must be > 0; as regularization approaches infinity the SuperLoss converges to the regular loss value).

average_loss (bool, default True): Whether the final loss should be averaged or not.

output_confidence (Optional[str], default None): If not None then the confidence scores for each sample will be written into the specified key. This can be useful for finding difficult or mislabeled data.

Raises:

ValueError: If the provided loss has multiple outputs or the regularization / threshold parameters are invalid.
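A minimal usage sketch (not taken from the library documentation), assuming SuperLoss and CrossEntropy are importable from fastestimator.op.tensorop.loss, and using the placeholder keys "y_pred", "y", "ce", and "confidence" in place of your own pipeline keys:

import math

from fastestimator.op.tensorop.loss import CrossEntropy, SuperLoss

# Wrap a regular cross-entropy loss with automatic curriculum learning.
loss_op = SuperLoss(
    loss=CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
    threshold=math.log(10),          # expected average loss for 10-class classification; or 'exp' for a moving average
    regularization=1.0,
    output_confidence="confidence")  # per-sample confidences, useful for spotting mislabeled data
# Then include loss_op in your fe.Network ops alongside your ModelOp and UpdateOp as usual.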

Source code in fastestimator/fastestimator/op/tensorop/loss/super_loss.py
class SuperLoss(LossOp):
    """Loss class to compute a 'super loss' (automatic curriculum learning) based on a regular loss.

    This class adds automatic curriculum learning on top of any other loss metric. It is especially useful for noisy
    datasets. See https://papers.nips.cc/paper/2020/file/2cfa8f9e50e0f510ede9d12338a5f564-Paper.pdf for details.

    Args:
        loss: A loss object which we use to calculate the underlying regular loss. This should be an object of type
            fe.op.tensorop.loss.loss.LossOp.
        threshold: Either a constant value corresponding to an average expected loss (for example log(n_classes) for
            cross-entropy classification), or 'exp' to use an exponential moving average loss.
        regularization: The regularization parameter to use for the super loss (must be > 0; as regularization approaches
            infinity the SuperLoss converges to the regular loss value).
        average_loss: Whether the final loss should be averaged or not.
        output_confidence: If not None then the confidence scores for each sample will be written into the specified
            key. This can be useful for finding difficult or mislabeled data.

    Raises:
        ValueError: If the provided `loss` has multiple outputs or the `regularization` / `threshold` parameters are
            invalid.
    """
    def __init__(self,
                 loss: LossOp,
                 threshold: Union[float, str] = 'exp',
                 regularization: float = 1.0,
                 average_loss: bool = True,
                 output_confidence: Optional[str] = None):
        if len(loss.outputs) != 1 or loss.out_list:
            raise ValueError("SuperLoss only supports lossOps which have a single output.")
        self.loss = loss
        self.loss.average_loss = False
        super().__init__(inputs=loss.inputs,
                         outputs=loss.outputs[0] if not output_confidence else (loss.outputs[0], output_confidence),
                         mode=loss.mode,
                         ds_id=loss.ds_id,
                         average_loss=average_loss)
        if not isinstance(threshold, str):
            threshold = to_number(threshold).item()
        if not isinstance(threshold, float) and threshold != 'exp':
            raise ValueError(f'SuperLoss threshold parameter must be "exp" or a float, but got {threshold}')
        self.tau_method = threshold
        if regularization <= 0:
            raise ValueError(f"SuperLoss regularization parameter must be greater than 0, but got {regularization}")
        self.lam = regularization
        self.cap = -1.9999998 / e  # Slightly more than -2 / e for numerical stability
        self.initialized = {}
        self.tau = {}

    def build(self, framework: str, device: Optional[torch.device] = None) -> None:
        self.loss.build(framework, device)
        if framework == 'tf':
            self.initialized = {
                'train': tf.Variable(False, trainable=False),
                'eval': tf.Variable(False, trainable=False),
                'test': tf.Variable(False, trainable=False),
                'infer': tf.Variable(False, trainable=False)
            }
            if self.tau_method == 'exp':
                self.tau = {
                    'train': tf.Variable(0.0, trainable=False),
                    'eval': tf.Variable(0.0, trainable=False),
                    'test': tf.Variable(0.0, trainable=False),
                    'infer': tf.Variable(0.0, trainable=False)
                }
            else:
                self.tau = {
                    'train': tf.Variable(self.tau_method, trainable=False),
                    'eval': tf.Variable(self.tau_method, trainable=False),
                    'test': tf.Variable(self.tau_method, trainable=False),
                    'infer': tf.Variable(self.tau_method, trainable=False)
                }
            self.cap = tf.constant(self.cap)
        elif framework == 'torch':
            self.initialized = {
                'train': torch.tensor(False).to(device),
                'eval': torch.tensor(False).to(device),
                'test': torch.tensor(False).to(device),
                'infer': torch.tensor(False).to(device)
            }
            if self.tau_method == 'exp':
                self.tau = {
                    'train': torch.tensor(0.0).to(device),
                    'eval': torch.tensor(0.0).to(device),
                    'test': torch.tensor(0.0).to(device),
                    'infer': torch.tensor(0.0).to(device)
                }
            else:
                self.tau = {
                    'train': torch.tensor(self.tau_method).to(device),
                    'eval': torch.tensor(self.tau_method).to(device),
                    'test': torch.tensor(self.tau_method).to(device),
                    'infer': torch.tensor(self.tau_method).to(device)
                }
            self.cap = torch.tensor(self.cap).to(device)
            self.lam = torch.tensor(self.lam).to(device)
        else:
            raise ValueError("unrecognized framework: {}".format(framework))

    def forward(self, data: List[Tensor], state: Dict[str, Any]) -> Union[Tensor, List[Tensor]]:
        base_loss = self.loss.forward(data, state)
        tau = self._accumulate_tau(base_loss, state['mode'], state['warmup'])
        beta = (base_loss - tau) / self.lam
        # TODO The authors say to remove the gradients. Need to check whether this is necessary (speed or metrics)
        ln_sigma = -lambertw(0.5 * maximum(self.cap, beta))
        super_loss = (base_loss - tau) * exp(ln_sigma) + self.lam * pow(ln_sigma, 2)

        if self.average_loss:
            super_loss = reduce_mean(super_loss)

        if len(self.outputs) == 2:
            # User requested that the confidence score be returned
            return [super_loss, exp(ln_sigma)]

        return super_loss

    def _accumulate_tau(self, loss: Tensor, mode: str, warmup: bool) -> Tensor:
        """Determine an average loss value based on a particular method chosen during __init__.

        Right now this only supports constant values or exponential averaging. The original paper also proposed global
        averaging, but they didn't find much difference between the three methods and global averaging would have more
        complicated memory requirements.

        Args:
            loss: The current step loss.
            mode: The current step mode.
            warmup: Whether running in warmup mode or not.

        Returns:
            Either the static value provided at __init__, or an exponential moving average of the loss over time.
        """
        if self.tau_method == 'exp':
            if _read_variable(self.initialized[mode]):
                _assign(self.tau[mode], self.tau[mode] - 0.1 * (self.tau[mode] - reduce_mean(loss)))
            else:
                _assign(self.tau[mode], reduce_mean(loss))
                if not warmup:
                    _assign(self.initialized[mode], ones_like(self.initialized[mode]))
        return self.tau[mode]
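For intuition, here is a small standalone sketch of the same math (not part of FastEstimator), assuming numpy and scipy are available; it reproduces the forward computation above for a fixed tau:

import numpy as np
from scipy.special import lambertw

def super_loss_reference(base_loss, tau, lam=1.0):
    """Per-sample super loss values, mirroring SuperLoss.forward above."""
    beta = (base_loss - tau) / lam
    cap = -1.9999998 / np.e  # same numerical-stability cap as the source above
    ln_sigma = -lambertw(0.5 * np.maximum(cap, beta)).real
    return (base_loss - tau) * np.exp(ln_sigma) + lam * ln_sigma ** 2

losses = np.array([0.1, np.log(10), 4.0])  # an easy, an average, and a hard sample
print(super_loss_reference(losses, tau=np.log(10)))
# Easy samples yield a negative (rewarding) value, average samples are near 0, and hard samples are damped relative to (loss - tau).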