Skip to content

fuse

Fuse

Bases: TensorOp

Run a sequence of TensorOps as a single Op.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `ops` | `Union[TensorOp, List[TensorOp]]` | A sequence of TensorOps to run. They must all share the same mode and ds_id. Scheduled ops are not currently supported, though the Fuse op itself may be scheduled. | *required* |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If `ops` are invalid. |

Source code in fastestimator/fastestimator/op/tensorop/meta/fuse.py
@traceable()
class Fuse(TensorOp):
    """Run a sequence of TensorOps as a single Op.

    The fused Op's `inputs` are the keys read by member ops that are not produced by an earlier member op, and its
    `outputs` are every key written by any member op, each listed once in first-seen order.

    Args:
        ops: A sequence of TensorOps to run. They must all share the same mode and ds_id. It also doesn't support
            scheduled ops at the moment, though the subnet itself may be scheduled.

    Raises:
        ValueError: If `ops` are invalid (empty, not TensorOps, or with mismatched `mode` / `ds_id`).
    """
    def __init__(self, ops: Union[TensorOp, List[TensorOp]]) -> None:
        ops = to_list(ops)
        if len(ops) < 1:
            raise ValueError("Fuse requires at least one op")
        # Validate every op before mutating any of them (fe_retain_graph(True) below has side effects), so that a
        # rejected Fuse leaves the caller's ops untouched.
        for op in ops:
            if not isinstance(op, TensorOp):
                # Non-TensorOp entries (presumably including scheduler wrappers, which the docstring excludes) would
                # otherwise fail later on a missing attribute instead of raising the documented ValueError.
                raise ValueError(
                    f"Fuse only supports TensorOps (scheduled ops are not supported), but got {type(op).__name__}")
        mode = ops[0].mode
        ds_id = ops[0].ds_id
        for op in ops:
            if op.mode != mode:
                raise ValueError(f"All Fuse ops must share the same mode, but got {mode} and {op.mode}")
            if op.ds_id != ds_id:
                raise ValueError(f"All Fuse ops must share the same ds_id, but got {ds_id} and {op.ds_id}")
        inputs = []
        outputs = []
        self.last_retain_idx = 0
        self.models = set()
        self.loss_keys = set()
        for idx, op in enumerate(ops):
            # Keys already produced by an earlier member op are internal to the Fuse, not external inputs.
            for inp in op.inputs:
                if inp not in inputs and inp not in outputs:
                    inputs.append(inp)
            for out in op.outputs:
                if out not in outputs:
                    outputs.append(out)
            # Set all of the internal ops to retain their graphs, and keep tabs on the last one that responds
            # (non-None), since fe_retain_graph() on the Fuse forwards to it and it might later be set to False.
            if op.fe_retain_graph(True) is not None:
                self.last_retain_idx = idx
            self.models |= op.get_fe_models()
            self.loss_keys |= op.get_fe_loss_keys()
        super().__init__(inputs=inputs, outputs=outputs, mode=mode, ds_id=ds_id)
        self.ops = ops

    def build(self, framework: str, device: Optional[torch.device] = None) -> None:
        """Build every member op for the given framework and device.

        Args:
            framework: The framework identifier, passed unchanged to each member op's `build`.
            device: The device, passed unchanged to each member op's `build`.
        """
        for op in self.ops:
            op.build(framework, device)

    def get_fe_models(self) -> Set[Model]:
        """Return the union of the models reported by the member ops when this Fuse was constructed."""
        return self.models

    def get_fe_loss_keys(self) -> Set[str]:
        """Return the union of the loss keys reported by the member ops when this Fuse was constructed."""
        return self.loss_keys

    def fe_retain_graph(self, retain: Optional[bool] = None) -> Optional[bool]:
        """Get or set graph retention by delegating to the last member op that supports it.

        Args:
            retain: The value to forward to that member op's `fe_retain_graph`.

        Returns:
            Whatever that member op's `fe_retain_graph` returns.
        """
        return self.ops[self.last_retain_idx].fe_retain_graph(retain)

    def __getstate__(self) -> Dict[str, List[Dict[Any, Any]]]:
        """Serialize this Fuse as the states of its member ops.

        Member ops without a `__getstate__` method contribute an empty dict.
        """
        return {'ops': [elem.__getstate__() if hasattr(elem, '__getstate__') else {} for elem in self.ops]}

    def forward(self, data: List[Tensor], state: Dict[str, Any]) -> List[Tensor]:
        """Run the member ops in order on a private batch dictionary.

        Args:
            data: Values for `self.inputs`, in the same order.
            state: Execution-context information, passed through to `BaseNetwork._forward_batch`.

        Returns:
            The values for `self.outputs`, in order, after all member ops have run.
        """
        batch = dict(zip(self.inputs, data))
        # _forward_batch populates `batch` with the member ops' outputs, which are read back out below.
        BaseNetwork._forward_batch(batch, state, self.ops)
        return [batch[key] for key in self.outputs]