Skip to content

gather

Gather

Bases: TensorOp

Gather values from an input tensor.

If indices are not provided (or the supplied index tensor is empty), the maximum value within each batch element (a reduction over axis 1) will be collected instead.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `inputs` | `Union[str, List[str]]` | The tensor(s) to gather values from. | *required* |
| `indices` | `Union[None, str, List[str]]` | A tensor containing target indices to gather. | `None` |
| `outputs` | `Union[str, List[str]]` | The key(s) under which to save the output. | *required* |
| `mode` | `Union[None, str, Iterable[str]]` | What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train". | `None` |
| `ds_id` | `Union[None, str, Iterable[str]]` | What dataset id(s) to execute this Op in. To execute regardless of ds_id, pass None. To execute in all ds_ids except for a particular one, you can pass an argument like "!ds1". | `None` |
Source code in fastestimator/fastestimator/op/tensorop/gather.py
@traceable()
class Gather(TensorOp):
    """Gather values from an input tensor.

    Each input is paired with the index tensor at the same position in `indices`. If exactly one index key is given,
    that index tensor is shared by every input. If indices are not provided, there is no index tensor for a given
    input, or the paired index tensor is empty, the maximum value within each batch element (a reduction over axis 1)
    is collected instead.

    Args:
        inputs: The tensor(s) to gather values from.
        indices: A tensor containing target indices to gather.
        outputs: The key(s) under which to save the output.
        mode: What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
        ds_id: What dataset id(s) to execute this Op in. To execute regardless of ds_id, pass None. To execute in all
            ds_ids except for a particular one, you can pass an argument like "!ds1".
    """
    def __init__(self,
                 inputs: Union[str, List[str]],
                 outputs: Union[str, List[str]],
                 indices: Union[None, str, List[str]] = None,
                 mode: Union[None, str, Iterable[str]] = None,
                 ds_id: Union[None, str, Iterable[str]] = None):
        indices = to_list(indices)
        # Index keys are placed before the input keys so forward() can split them apart by count.
        self.num_indices = len(indices)
        # Build a fresh list rather than extending `indices` in place: if to_list hands back the caller's own list,
        # an in-place extend would silently append the input keys to it.
        combined_inputs = [*indices, *to_list(inputs)]
        super().__init__(inputs=combined_inputs, outputs=outputs, mode=mode, ds_id=ds_id)
        self.in_list, self.out_list = True, True

    @staticmethod
    def _num_elements(index: Any) -> int:
        """Return the size of the first dimension of a TF/torch tensor, or `len(index)` for anything else."""
        if tf.is_tensor(index) or isinstance(index, torch.Tensor):
            return index.shape[0]
        return len(index)

    def forward(self, data: List[Tensor], state: Dict[str, Any]) -> List[Tensor]:
        """Gather from each input tensor, falling back to a per-element max when no usable indices exist.

        Args:
            data: The index tensors (the first `self.num_indices` entries) followed by the tensors to gather from.
            state: Unused by this op.

        Returns:
            One result per input tensor, in the same order as the inputs.
        """
        indices = data[:self.num_indices]
        inputs = data[self.num_indices:]
        results = []
        for idx, tensor in enumerate(inputs):
            if idx < len(indices):
                # Each input has its own index tensor.
                index = indices[idx]
            elif len(indices) == 1:
                # One set of indices for all outputs
                index = indices[0]
            else:
                # No index tensor is available for this input (including the indices=None case).
                index = None
            # An empty index tensor is used to deliberately trigger the max fallback, so check the one actually used.
            if index is not None and self._num_elements(index) > 0:
                results.append(gather_from_batch(tensor, indices=index))
            else:
                results.append(reduce_max(tensor, 1))  # The maximum value within each batch element
        return results