mean_unslicer

MeanUnslicer

Bases: Slicer

A slicer which re-combines mini-batches via averaging.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| unslice | Union[str, Sequence[str]] | The input key(s) which this Slicer un-slices. | required |
| axis | | The axis along which to cut the data. This argument is listed in the docstring, but the `__init__` signature shown in the source code below does not accept it. | required |
| mode | Union[None, str, Iterable[str]] | What mode(s) to invoke this Slicer in. For example, "train", "eval", "test", or "infer". To invoke regardless of mode, pass None. To invoke in all modes except for a particular one, you can pass an argument like "!infer" or "!train". | None |
| ds_id | Union[None, str, Iterable[str]] | What dataset id(s) to invoke this Slicer in. To invoke regardless of ds_id, pass None. To invoke in all ds_ids except for a particular one, you can pass an argument like "!ds1". | None |
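
The `mode` and `ds_id` descriptions above imply a small matching rule: `None` matches everything, a plain name matches only itself, and a `"!name"` entry matches everything except that name. The sketch below illustrates one way such a rule could be interpreted. `mode_matches` is a hypothetical helper written for this page, not a FastEstimator function, and the handling of mixed positive and negated entries is an assumption.

```python
from typing import Iterable, Union


def mode_matches(spec: Union[None, str, Iterable[str]], current: str) -> bool:
    """Hypothetical helper: decide whether `current` is selected by `spec`."""
    if spec is None:
        # None means "invoke regardless of mode".
        return True
    entries = [spec] if isinstance(spec, str) else list(spec)
    excluded = {e[1:] for e in entries if e.startswith("!")}
    included = {e for e in entries if not e.startswith("!")}
    if current in excluded:
        return False
    # Assumption: with only negated entries, every non-excluded mode matches.
    return current in included if included else True


print(mode_matches("!infer", "train"))  # True
print(mode_matches("!infer", "infer"))  # False
```

The same rule would apply to `ds_id` values such as `"!ds1"`.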
Source code in fastestimator/fastestimator/slicer/mean_unslicer.py
@traceable()
class MeanUnslicer(Slicer):
    """A slicer which re-combines mini-batches via averaging.

    Args:
        unslice: The input key(s) which this Slicer un-slices.
        axis: The axis along which to cut the data
        mode: What mode(s) to invoke this Slicer in. For example, "train", "eval", "test", or "infer". To invoke
            regardless of mode, pass None. To invoke in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
        ds_id: What dataset id(s) to invoke this Slicer in. To invoke regardless of ds_id, pass None. To invoke in all
            ds_ids except for a particular one, you can pass an argument like "!ds1".
    """
    def __init__(self,
                 unslice: Union[str, Sequence[str]],
                 mode: Union[None, str, Iterable[str]] = None,
                 ds_id: Union[None, str, Iterable[str]] = None) -> None:
        super().__init__(slice=None, unslice=unslice, mode=mode, ds_id=ds_id)

    def _unslice_batch(self, slices: Tuple[Tensor, ...], key: str) -> Tensor:
        mean = zeros_like(slices[0])
        for minibatch in slices:
            mean += minibatch
        mean /= len(slices)
        return mean
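
The averaging performed by `_unslice_batch` above can be tried outside the framework. The following is a minimal sketch that assumes NumPy arrays stand in for `Tensor` and `np.zeros_like` stands in for the library's `zeros_like`. `mean_unslice` is a hypothetical name used only for illustration.

```python
import numpy as np


def mean_unslice(slices):
    """Average equally shaped mini-batch slices element-wise."""
    if not slices:
        raise ValueError("need at least one slice")
    # Accumulate in float64 so integer inputs also work (an assumption of
    # this sketch; the source above uses the dtype of the first slice).
    total = np.zeros_like(slices[0], dtype=np.float64)
    for minibatch in slices:
        total += minibatch
    return total / len(slices)


slices = (np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([5.0, 6.0]))
result = mean_unslice(slices)
print(result)  # [3. 4.]
```

The result has the shape of a single slice rather than of the concatenated batch, because each slice contributes equally to the mean.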