
_l2_regularization

l2_regularization

Calculate the L2 regularization loss of a model's trainable weights.

l2_reg = beta * sum(parameter**2)/2

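This method can be used with TensorFlow and PyTorch models. Only trainable parameters contribute to the sum (those with `requires_grad` set in PyTorch, or listed in `trainable_variables` in TensorFlow).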
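As a worked instance of the formula, the minimal sketch below computes the penalty in plain Python for a toy set of parameters, then applies the default `beta = 0.01` that the function multiplies in before returning. The parameter values are made up for illustration and stand in for flattened model tensors.

```python
# Toy "model": two parameter tensors flattened into plain lists (illustrative values only).
params = [[1.0, 2.0], [3.0]]
beta = 0.01  # default weighting factor of l2_regularization

# Half the sum of squared entries over all parameters: (1 + 4 + 9) / 2 = 7.0
l2_reg = sum(sum(x * x for x in p) for p in params) / 2

# Scale by beta, as the function does before returning.
weighted = beta * l2_reg

print(l2_reg, round(weighted, 6))
```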

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Union[Model, Module]` | A TensorFlow or PyTorch model. | required |
| `beta` | `float` | Multiplicative factor that weights the L2 regularization loss relative to the input loss. | `0.01` |
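The returned value is meant to be added to the input (task) loss, with `beta` setting the trade-off. Because the penalty contributes `beta * p**2 / 2` per weight, its gradient is `beta * p`, so one plain SGD step on the combined loss is the same as first shrinking each weight by `(1 - lr * beta)` and then applying the data gradient (classic weight decay). The sketch below checks this for a single scalar weight; `p`, `data_grad`, and `lr` are illustrative assumptions, not values from FastEstimator, and the equivalence is specific to plain SGD.

```python
# Illustrative values (assumptions, not taken from FastEstimator).
p = 2.0          # a single trainable weight
data_grad = 0.5  # gradient of the task loss with respect to p
lr, beta = 0.1, 0.01

# Gradient of the penalty beta * p**2 / 2 with respect to p is beta * p.
penalty_grad = beta * p

# One SGD step on (task loss + penalty)...
p_regularized = p - lr * (data_grad + penalty_grad)

# ...matches shrinking p by (1 - lr * beta), then applying the data gradient.
p_decayed = p * (1 - lr * beta) - lr * data_grad

print(p_regularized, p_decayed)
```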

Returns:

| Type | Description |
|---|---|
| `Tensor` | The L2 regularization loss of the model's trainable parameters, multiplied by `beta`. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If `model` is neither a `tf.keras.Model` nor a `torch.nn.Module`. |

Source code in fastestimator/fastestimator/backend/_l2_regularization.py
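from typing import TypeVar, Union

import tensorflow as tf
import torch

Tensor = TypeVar('Tensor', tf.Tensor, torch.Tensor)


def l2_regularization(model: Union[tf.keras.Model, torch.nn.Module], beta: float = 0.01) -> Tensor: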
    """Calculate the L2 regularization loss of a model's trainable weights.

    l2_reg = beta * sum(parameter**2)/2

    This method can be used with TensorFlow and PyTorch models. Only trainable parameters contribute.

    Args:
        model: A TensorFlow or PyTorch model.
        beta: Multiplicative factor that weights the L2 regularization loss relative to the input loss.

    Returns:
        The L2 regularization loss of the model's trainable parameters, multiplied by `beta`.

    Raises:
        ValueError: If `model` belongs to an unacceptable framework.
    """
    if isinstance(model, torch.nn.Module):
        l2_loss = torch.sum(torch.stack([torch.sum(p**2) / 2 for p in model.parameters() if p.requires_grad]))
    elif isinstance(model, tf.keras.Model):
        l2_loss = tf.reduce_sum([tf.nn.l2_loss(p) for p in model.trainable_variables])
    else:
        raise ValueError("Unrecognized model framework: Please make sure to pass either torch or tensorflow models")
    return beta * l2_loss
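Both branches compute the same quantity: `tf.nn.l2_loss(p)` returns `sum(p ** 2) / 2`, matching the PyTorch expression, and each branch sums only trainable parameters. The framework-free sketch below mirrors that logic so it can be read or tested without either library; the `Param` container and the `reference_l2_regularization` name are hypothetical. One difference is deliberate: it returns `0.0` for a model with no trainable parameters, whereas the PyTorch branch above would pass an empty list to `torch.stack`, which raises an error.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class Param:
    """Hypothetical stand-in for a framework parameter: flat values plus a trainable flag."""
    values: List[float]
    trainable: bool = True


def reference_l2_regularization(params: Iterable[Param], beta: float = 0.01) -> float:
    """Framework-free sketch of beta * sum(p ** 2) / 2 over trainable parameters only."""
    # Per-parameter half sum of squares, like torch.sum(p**2) / 2 or tf.nn.l2_loss(p).
    per_param = [sum(v * v for v in p.values) / 2 for p in params if p.trainable]
    # sum() of an empty list is 0, so a fully frozen model yields 0.0 instead of an error.
    return beta * sum(per_param)


params = [Param([1.0, 2.0]), Param([3.0]), Param([10.0], trainable=False)]
print(reference_l2_regularization(params))  # the frozen [10.0] parameter is ignored
print(reference_l2_regularization([Param([5.0], trainable=False)]))
```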