Skip to content

gradient

GradientOp

Bases: TensorOp

Return the gradients of `finals` with respect to `inputs` (or with respect to a model's trainable parameters).

Parameters:

Name Type Description Default
finals Union[str, List[str]]

The tensor(s) to compute gradients from.

required
outputs Union[str, List[str]]

The key(s) under which to save the gradients.

required
inputs Union[None, str, List[str]]

The tensor(s) to compute gradients with respect to, mutually exclusive with model.

None
model Union[None, Model, Module]

The model instance to compute gradients with respect to, mutually exclusive with inputs.

None
mode Union[None, str, Iterable[str]]

What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train".

None
ds_id Union[None, str, Iterable[str]]

What dataset id(s) to execute this Op in. To execute regardless of ds_id, pass None. To execute in all ds_ids except for a particular one, you can pass an argument like "!ds1".

None
Source code in fastestimator/fastestimator/op/tensorop/gradient/gradient.py
@traceable()
class GradientOp(TensorOp):
    """Return the gradients of finals w.r.t. inputs.

    Args:
        finals: The tensor(s) to compute gradients from.
        outputs: The key(s) under which to save the gradients.
        inputs: The tensor(s) to compute gradients with respect to, mutually exclusive with `model`.
        model: The model instance to compute gradients with respect to, mutually exclusive with `inputs`.
        mode: What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
        ds_id: What dataset id(s) to execute this Op in. To execute regardless of ds_id, pass None. To execute in all
            ds_ids except for a particular one, you can pass an argument like "!ds1".

    Raises:
        AssertionError: If neither (or both) of `inputs` and `model` are provided, if `model` is not a recognized
            type, or if the lengths of `inputs`, `finals`, and `outputs` are inconsistent.
    """
    def __init__(self,
                 finals: Union[str, List[str]],
                 outputs: Union[str, List[str]],
                 inputs: Union[None, str, List[str]] = None,
                 model: Union[None, tf.keras.Model, torch.nn.Module] = None,
                 mode: Union[None, str, Iterable[str]] = None,
                 ds_id: Union[None, str, Iterable[str]] = None):
        inputs = to_list(inputs)
        finals = to_list(finals)
        outputs = to_list(outputs)
        assert bool(model) != bool(inputs), "Must provide either one of 'inputs' or 'model'"
        if model is None:
            # Gradients are taken pairwise, so the three lists must line up element-by-element.
            assert len(inputs) == len(finals) == len(outputs), \
                "GradientOp requires the same number of inputs, finals, and outputs"
        else:
            assert isinstance(model, (tf.keras.Model, torch.nn.Module)), "Unrecognized model format"
            assert len(finals) == len(outputs), "GradientOp requires the same number of finals, and outputs"
        # This op consumes both the differentiation sources (if any) and the targets; `forward` relies on
        # the finals being appended after the inputs in order to split `data` back into the two halves.
        inputs.extend(finals)
        super().__init__(inputs=inputs, outputs=outputs, mode=mode, ds_id=ds_id)
        self.model = model
        self.retain_graph = True

    def fe_retain_graph(self, retain: Optional[bool] = None) -> Optional[bool]:
        """Get or set whether the computation graph should be retained after computing gradients.

        Args:
            retain: If provided, the new retain_graph setting. If None, the current setting is left unchanged.

        Returns:
            The current retain_graph setting.
        """
        if retain is not None:
            self.retain_graph = retain
        return self.retain_graph

    def build(self, framework: str, device: Optional[torch.device] = None) -> None:
        """Record which backend this op will execute under.

        Args:
            framework: Which framework is in use ("tf" or "torch").
            device: Unused by this op; accepted to satisfy the TensorOp build interface.
        """
        self.framework = framework

    def forward(self, data: List[Tensor], state: Dict[str, Any]) -> List[Tensor]:
        """Compute the requested gradients.

        Args:
            data: The `inputs` tensors (if any) followed by the `finals` tensors, in the order declared in __init__.
            state: Information about the current execution context. The gradient tape (if any) is read from the
                'tape' key when differentiating w.r.t. explicit input tensors or under TensorFlow.

        Returns:
            The gradient of each final w.r.t. its paired input tensor, or w.r.t. the model's trainable parameters
            when a `model` was provided.

        Raises:
            ValueError: If `build` recorded a framework other than "tf" or "torch".
        """
        results = []
        if self.model is None:
            # __init__ appended finals after inputs, so the two halves of `data` pair up index-by-index.
            initials = data[:len(data) // 2]
            finals = data[len(data) // 2:]
            for idx, (initial, final) in enumerate(zip(initials, finals)):
                # Only the last backward pass may free the graph, and only if the user didn't ask to retain it.
                retain_graph = self.retain_graph or idx != len(finals) - 1
                results.append(get_gradient(final, initial, tape=state['tape'], retain_graph=retain_graph))
        else:
            finals = data
            if self.framework == "tf":
                trainable_params = self.model.trainable_variables
                for final in finals:
                    results.append(get_gradient(final, trainable_params, tape=state['tape']))
            elif self.framework == "torch":
                trainable_params = [p for p in self.model.parameters() if p.requires_grad]
                for idx, final in enumerate(finals):
                    # Only the last backward pass may free the graph, and only if the user didn't ask to retain it.
                    retain_graph = self.retain_graph or idx != len(finals) - 1
                    results.append(get_gradient(final, trainable_params, retain_graph=retain_graph))
            else:
                raise ValueError(f"Unrecognized framework {self.framework}")

        return results