Skip to content

traceability_util

FeInputSpec

A class to keep track of a model's input so that fake inputs can be generated.

This class is intentionally not @traceable.

Parameters:

Name Type Description Default
model_input Any

The input to the model.

required
model Model

The model which corresponds to the given model_input.

required
Source code in fastestimator/fastestimator/util/traceability_util.py
class FeInputSpec:
    """Remember the structure of a model's input so that placeholder inputs can be rebuilt later.

    The shape, dtype, and device of the provided input are recorded at construction time, along with the framework
    specific `ones` function which will be used to build replacement tensors.

    This class is intentionally not @traceable.

    Args:
        model_input: The input to the model.
        model: The model which corresponds to the given `model_input`.
    """
    def __init__(self, model_input: Any, model: Model):
        self.shape = to_shape(model_input)
        self.dtype = to_type(model_input)
        self.device = self._get_device(model_input)
        # Keras models get TensorFlow tensors, anything else is given PyTorch tensors
        if isinstance(model, tf.keras.Model):
            self.tensor_func = tf.ones
        else:
            self.tensor_func = torch.ones

    def _get_device(self, data: Any) -> Union[None, str, torch.device]:
        """Find the device of the first tensor encountered inside of `data`.

        Args:
            data: A tensor or collection of tensors.

        Returns:
            The `.device` of the first tensor found while walking `data`, or None if no tensor is found.
        """
        if tf.is_tensor(data) or isinstance(data, torch.Tensor):
            return data.device
        if isinstance(data, dict):
            # Only the values of a dictionary are inspected
            return self._get_device(list(data.values()))
        if isinstance(data, (list, tuple, set)):
            # Lazily walk the members and stop at the first one which reports a device
            candidates = (self._get_device(member) for member in data)
            return next((found for found in candidates if found is not None), None)
        return None

    def get_dummy_input(self) -> Any:
        """Build a placeholder input from the recorded shape and dtype.

        Returns:
            Input of the correct shape and dtype for the model.
        """
        return self._from_shape_and_type(self.shape, self.dtype)

    def _from_shape_and_type(self, shape: Any, dtype: Any) -> Any:
        """Recursively build tensor(s) matching the given shape(s) and dtype(s).

        It is assumed that the `shape` and `dtype` arguments have the same container structure. That is to say, if
        `shape` is a list of 5 elements, it is required that `dtype` also be a list of 5 elements.

        Args:
            shape: A shape or (possibly nested) container of shapes.
            dtype: A dtype or (possibly nested) container of dtypes.

        Returns:
            A tensor or collection of tensors corresponding to the shape and dtype arguments. The container type of
            the result is chosen based on the container type of `dtype`.
        """
        build = self._from_shape_and_type
        if isinstance(dtype, dict):
            return {key: build(shape[key], dtype[key]) for key in shape}
        if isinstance(dtype, list):
            return [build(sub_shape, dtype[idx]) for idx, sub_shape in enumerate(shape)]
        if isinstance(dtype, tuple):
            return tuple(build(sub_shape, dtype[idx]) for idx, sub_shape in enumerate(shape))
        if isinstance(dtype, set):
            return {build(sub_shape, sub_type) for sub_shape, sub_type in zip(shape, dtype)}
        # Leaf case: build one tensor, and relocate it if a torch device was recorded
        tensor = self.tensor_func(shape, dtype=dtype)
        if isinstance(self.device, torch.device):
            tensor = tensor.to(self.device)
        return tensor

get_dummy_input

Get fake input for the model.

Returns:

Type Description
Any

Input of the correct shape and dtype for the model.

Source code in fastestimator/fastestimator/util/traceability_util.py
def get_dummy_input(self) -> Any:
    """Build a placeholder input from the recorded shape and dtype.

    Returns:
        Input of the correct shape and dtype for the model.
    """
    recorded_shape, recorded_dtype = self.shape, self.dtype
    return self._from_shape_and_type(recorded_shape, recorded_dtype)

FeSplitSummary

Bases: LatexObject

A class to summarize splits performed on an FE Dataset.

This class is intentionally not @traceable.

Source code in fastestimator/fastestimator/util/traceability_util.py
class FeSplitSummary(LatexObject):
    """A LaTeX-renderable record of the splits which have been performed on an FE Dataset.

    This class is intentionally not @traceable.
    """
    def __init__(self):
        super().__init__()
        # One (parent, fraction, seed, stratify) tuple per recorded split, in insertion order
        self.data = []

    def add_split(self, parent: Union[FEID, str], fraction: str, seed: Optional[int], stratify: Optional[str]) -> None:
        """Append the details of one more split to this summary.

        Args:
            parent: The id of the parent involved in the split (or 'self' if you are the parent).
            fraction: The string representation of the split fraction that was used.
            seed: The random seed used during the split.
            stratify: The stratify key used during the split.
        """
        record = (parent, fraction, seed, stratify)
        self.data.append(record)

    def dumps(self) -> str:
        """Render the recorded splits as a LaTeX string.

        Each split is written as `parent(fraction[, seed=...][, stratify=...])`, and consecutive splits are joined
        with a right arrow.

        Returns:
            A LaTeX string representation of this object.
        """
        rendered = []
        for parent, fraction, seed, stratify in self.data:
            # FEID parents are rendered through HrefFEID, anything else is used as-is
            origin = HrefFEID(parent, name='').dumps() if isinstance(parent, FEID) else parent
            pieces = [f"{origin}({escape_latex(fraction)}"]
            if seed is not None:
                pieces.append(f", seed={seed}")
            if stratify is not None:
                pieces.append(f", stratify=`{escape_latex(stratify)}'")
            pieces.append(")")
            rendered.append("".join(pieces))
        return " $\\rightarrow$ ".join(rendered)

add_split

Record another split on this dataset.

Parameters:

Name Type Description Default
parent Union[FEID, str]

The id of the parent involved in the split (or 'self' if you are the parent).

required
fraction str

The string representation of the split fraction that was used.

required
seed Optional[int]

The random seed used during the split.

required
stratify Optional[str]

The stratify key used during the split.

required
Source code in fastestimator/fastestimator/util/traceability_util.py
def add_split(self, parent: Union[FEID, str], fraction: str, seed: Optional[int], stratify: Optional[str]) -> None:
    """Append the details of one more split to this summary.

    Args:
        parent: The id of the parent involved in the split (or 'self' if you are the parent).
        fraction: The string representation of the split fraction that was used.
        seed: The random seed used during the split.
        stratify: The stratify key used during the split.
    """
    record = (parent, fraction, seed, stratify)
    self.data.append(record)

dumps

Generate a LaTeX formatted representation of this object.

Returns:

Type Description
str

A LaTeX string representation of this object.

Source code in fastestimator/fastestimator/util/traceability_util.py
def dumps(self) -> str:
    """Render the recorded splits as a LaTeX string.

    Each split is written as `parent(fraction[, seed=...][, stratify=...])`, and consecutive splits are joined with a
    right arrow.

    Returns:
        A LaTeX string representation of this object.
    """
    rendered = []
    for parent, fraction, seed, stratify in self.data:
        # FEID parents are rendered through HrefFEID, anything else is used as-is
        origin = HrefFEID(parent, name='').dumps() if isinstance(parent, FEID) else parent
        pieces = [f"{origin}({escape_latex(fraction)}"]
        if seed is not None:
            pieces.append(f", seed={seed}")
        if stratify is not None:
            pieces.append(f", stratify=`{escape_latex(stratify)}'")
        pieces.append(")")
        rendered.append("".join(pieces))
    return " $\\rightarrow$ ".join(rendered)

FeSummaryTable

A class containing summaries of traceability information.

This class is intentionally not @traceable.

Parameters:

Name Type Description Default
name str

The string to be used as the title line in the summary table.

required
fe_id FEID

The id of this table, used for cross-referencing from other tables.

required
target_type Type

The type of the object being summarized.

required
path Union[None, str, LatexObject]

The import path of the object in question. Might be more complicated when methods/functions are involved.

None
kwargs Optional[Dict[str, Any]]

The keyword arguments used to instantiate the object being summarized.

None
**fields Any

Any other information about the summarized object / function.

{}
Source code in fastestimator/fastestimator/util/traceability_util.py
class FeSummaryTable:
    """A container for one table worth of traceability information, which can be rendered into LaTeX.

    This class is intentionally not @traceable.

    Args:
        name: The string to be used as the title line in the summary table.
        fe_id: The id of this table, used for cross-referencing from other tables.
        target_type: The type of the object being summarized.
        path: The import path of the object in question. Might be more complicated when methods/functions are involved.
        kwargs: The keyword arguments used to instantiate the object being summarized.
        **fields: Any other information about the summarized object / function. An entry named 'args' is removed
            from the fields and stored separately.
    """
    name: Union[str, LatexObject]
    fe_id: FEID
    fields: Dict[str, Any]

    def __init__(self,
                 name: str,
                 fe_id: FEID,
                 target_type: Type,
                 path: Union[None, str, LatexObject] = None,
                 kwargs: Optional[Dict[str, Any]] = None,
                 **fields: Any):
        self.name, self.fe_id = name, fe_id
        self.type, self.path = target_type, path
        # 'args' gets its own dedicated row, so it is pulled out of the generic fields
        self.args = fields.pop("args", None)
        self.fields = fields
        self.kwargs = kwargs or {}

    def render_table(self,
                     doc: Document,
                     name_override: Optional[LatexObject] = None,
                     toc_ref: Optional[str] = None,
                     extra_rows: Optional[List[Tuple[str, Any]]] = None) -> None:
        """Append a LaTeX rendering of this table to a document.

        Rows are emitted in the following order: title, type (plus path if available), the extra fields, the args,
        the `extra_rows`, and finally the kwargs (with alternating row colors).

        Args:
            doc: The LaTeX document to be appended to.
            name_override: An optional replacement for this table's name field.
            toc_ref: A reference to be added to the table of contents.
            extra_rows: Any extra rows to be added to the table before the kwargs.
        """
        with doc.create(Table(position='htp!')) as table:
            table.append(NoEscape(r'\refstepcounter{table}'))
            table.append(Label(Marker(name=str(self.fe_id), prefix="tbl")))
            if toc_ref:
                table.append(NoEscape(r'\addcontentsline{toc}{subsection}{' + escape_latex(toc_ref) + '}'))
            with doc.create(Tabularx('|lX|', booktabs=True)) as tabular:
                # Need to invoke a table color before invoking TextColor (bug?), hence xcolor with the table option
                for required in (Package('xcolor', options='table'), Package('seqsplit')):
                    if required not in tabular.packages:
                        tabular.packages.append(required)
                title_cell = name_override if name_override else bold(self.name)
                id_cell = MultiColumn(size=1, align='r|', data=TextColor('blue', self.fe_id))
                tabular.add_row((title_cell, id_cell))
                tabular.add_hline()
                # Reduce strings like "<class 'a.b.C'>" down to just the quoted portion
                type_repr = f"{self.type}"
                parsed = re.fullmatch(r'^<.* \'(?P<typ>.*)\'>$', type_repr)
                if parsed:
                    type_repr = parsed.group("typ")
                tabular.add_row(("Type: ", escape_latex(type_repr)))
                if self.path:
                    path_cell = self.path if isinstance(self.path, LatexObject) else escape_latex(self.path)
                    tabular.add_row(("", path_cell))
                for field_name, field_value in self.fields.items():
                    tabular.add_hline()
                    tabular.add_row((f"{field_name.capitalize()}: ", field_value))
                if self.args:
                    tabular.add_hline()
                    tabular.add_row(("Args: ", self.args))
                for (key, val) in extra_rows or ():
                    tabular.add_hline()
                    tabular.add_row(key, val)
                if self.kwargs:
                    tabular.add_hline()
                    # Even rows are lightly shaded, odd rows are white
                    stripes = ('black!5', 'white')
                    for idx, (kwarg, val) in enumerate(self.kwargs.items()):
                        tabular.add_row((italic(kwarg), val), color=stripes[idx % 2])

render_table

Write this table into a LaTeX document.

Parameters:

Name Type Description Default
doc Document

The LaTeX document to be appended to.

required
name_override Optional[LatexObject]

An optional replacement for this table's name field.

None
toc_ref Optional[str]

A reference to be added to the table of contents.

None
extra_rows Optional[List[Tuple[str, Any]]]

Any extra rows to be added to the table before the kwargs.

None
Source code in fastestimator/fastestimator/util/traceability_util.py
def render_table(self,
                 doc: Document,
                 name_override: Optional[LatexObject] = None,
                 toc_ref: Optional[str] = None,
                 extra_rows: Optional[List[Tuple[str, Any]]] = None) -> None:
    """Append a LaTeX rendering of this table to a document.

    Rows are emitted in the following order: title, type (plus path if available), the extra fields, the args, the
    `extra_rows`, and finally the kwargs (with alternating row colors).

    Args:
        doc: The LaTeX document to be appended to.
        name_override: An optional replacement for this table's name field.
        toc_ref: A reference to be added to the table of contents.
        extra_rows: Any extra rows to be added to the table before the kwargs.
    """
    with doc.create(Table(position='htp!')) as table:
        table.append(NoEscape(r'\refstepcounter{table}'))
        table.append(Label(Marker(name=str(self.fe_id), prefix="tbl")))
        if toc_ref:
            table.append(NoEscape(r'\addcontentsline{toc}{subsection}{' + escape_latex(toc_ref) + '}'))
        with doc.create(Tabularx('|lX|', booktabs=True)) as tabular:
            # Need to invoke a table color before invoking TextColor (bug?), hence xcolor with the table option
            for required in (Package('xcolor', options='table'), Package('seqsplit')):
                if required not in tabular.packages:
                    tabular.packages.append(required)
            title_cell = name_override if name_override else bold(self.name)
            id_cell = MultiColumn(size=1, align='r|', data=TextColor('blue', self.fe_id))
            tabular.add_row((title_cell, id_cell))
            tabular.add_hline()
            # Reduce strings like "<class 'a.b.C'>" down to just the quoted portion
            type_repr = f"{self.type}"
            parsed = re.fullmatch(r'^<.* \'(?P<typ>.*)\'>$', type_repr)
            if parsed:
                type_repr = parsed.group("typ")
            tabular.add_row(("Type: ", escape_latex(type_repr)))
            if self.path:
                path_cell = self.path if isinstance(self.path, LatexObject) else escape_latex(self.path)
                tabular.add_row(("", path_cell))
            for field_name, field_value in self.fields.items():
                tabular.add_hline()
                tabular.add_row((f"{field_name.capitalize()}: ", field_value))
            if self.args:
                tabular.add_hline()
                tabular.add_row(("Args: ", self.args))
            for (key, val) in extra_rows or ():
                tabular.add_hline()
                tabular.add_row(key, val)
            if self.kwargs:
                tabular.add_hline()
                # Even rows are lightly shaded, odd rows are white
                stripes = ('black!5', 'white')
                for idx, (kwarg, val) in enumerate(self.kwargs.items()):
                    tabular.add_row((italic(kwarg), val), color=stripes[idx % 2])

fe_summary

Return a summary of how this class was instantiated (for traceability).

Parameters:

Name Type Description Default
self

The bound class instance.

required

Returns:

Type Description
List[FeSummaryTable]

A summary of the instance.

Source code in fastestimator/fastestimator/util/traceability_util.py
def fe_summary(self) -> List[FeSummaryTable]:
    """Return the traceability tables describing how this instance was created, in display order.

    The tables stored in `self._fe_traceability_summary` are sorted by the kind of object they describe, and the FEID
    translation dictionary is updated so that ids are numbered ("@FE0", "@FE1", ...) according to that order.

    Args:
        self: The bound class instance.

    Returns:
        A summary of the instance.
    """
    # Delayed imports to avoid circular dependency
    from torch.utils.data import Dataset

    from fastestimator.estimator import Estimator
    from fastestimator.network import TFNetwork, TorchNetwork
    from fastestimator.op.op import Op
    from fastestimator.pipeline import Pipeline
    from fastestimator.schedule.schedule import Scheduler
    from fastestimator.slicer.slicer import Slicer
    from fastestimator.trace.trace import Trace

    # Earlier entries are displayed first. A table's rank is the index of the first entry which its type subclasses.
    display_order = (
        Estimator,
        (TFNetwork, TorchNetwork),
        Pipeline,
        Scheduler,
        Trace,
        Op,
        Slicer,
        (Dataset, tf.data.Dataset),
        (tf.keras.Model, torch.nn.Module),
        types.FunctionType,
        (np.ndarray, tf.Tensor, tf.Variable, torch.Tensor),
    )

    def rank(entry):
        table_type = entry[1].type
        for position, family in enumerate(display_order):
            if issubclass(table_type, family):
                return position
        # Anything unrecognized goes at the end
        return len(display_order)

    # re-number the references for nicer viewing
    ordered_items = sorted(self._fe_traceability_summary.items(), key=rank)
    key_mapping = {fe_id: f"@FE{idx}" for idx, (fe_id, _) in enumerate(ordered_items)}
    FEID.set_translation_dict(key_mapping)
    return [table for _, table in ordered_items]

is_restorable

Determine whether a given object can be restored easily via Pickle.

Parameters:

Name Type Description Default
data Any

The object in question.

required
memory_limit int

The maximum memory size (in bytes) to allow for an object (or 0 for no limit).

0

Returns:

Type Description
Tuple[bool, int]

(result, memory size) where result is True iff data is only comprised of 'simple' objects and does not exceed the memory_limit. If the result is False, then memory size will be <= the true memory size of the data.

Source code in fastestimator/fastestimator/util/traceability_util.py
def is_restorable(data: Any, memory_limit: int = 0) -> Tuple[bool, int]:
    """Check whether an object is built purely from 'simple' (easily pickled) pieces, and measure its size.

    Dictionaries (keys and values), lists, tuples, and sets are searched recursively. The search stops as soon as an
    unsupported member is found or the running size total exceeds the `memory_limit`.

    Args:
        data: The object in question.
        memory_limit: The maximum memory size (in bytes) to allow for an object (or 0 for no limit).

    Returns:
        (result, memory size) where result is True iff `data` is only comprised of 'simple' objects and does not exceed
        the `memory_limit`. If the result is False, then memory size will be <= the true memory size of the `data`.
    """
    if isinstance(data, _RestorableClasses):
        # NOTE(review): the memory_limit is not applied to a lone simple object, only to container totals - confirm
        # whether callers rely on this.
        if isinstance(data, tf.Tensor):
            return True, sys.getsizeof(data.numpy())
        if isinstance(data, torch.Tensor):
            return True, data.element_size() * data.nelement()
        return True, sys.getsizeof(data)

    if isinstance(data, dict):
        # Visit key, value, key, value, ... so that dictionaries can share the same loop as the other containers
        members = (part for pair in data.items() for part in pair)
    elif isinstance(data, (list, tuple, set)):
        members = data
    else:
        return False, 0

    total = 0
    for member in members:
        ok, member_size = is_restorable(member, memory_limit)
        if not ok:
            # The failing member's size is deliberately left out of the total
            return False, total
        total += member_size
        if 0 < memory_limit < total:
            return False, total
    return True, total

trace_model

A function to add traceability information to an FE-compiled model.

Parameters:

Name Type Description Default
model Model

The model to be made traceable.

required
model_idx int

Which of the return values from the model_fn is this model (or -1 if only a single return value).

required
model_fn Any

The function used to generate this model.

required
optimizer_fn Any

The thing used to define this model's optimizer.

required
weights_path Any

The path to the weights for this model.

required

Returns:

Type Description
Model

The model, but now with an fe_summary() method.

Source code in fastestimator/fastestimator/util/traceability_util.py
def trace_model(model: Model, model_idx: int, model_fn: Any, optimizer_fn: Any, weights_path: Any) -> Model:
    """A function to add traceability information to an FE-compiled model.

    A summary table describing the model is built and stored on the model as `_fe_traceability_summary`, and an
    `fe_summary` method is bound onto the model instance.

    Args:
        model: The model to be made traceable.
        model_idx: Which of the return values from the `model_fn` is this model (or -1 if only a single return value).
        model_fn: The function used to generate this model.
        optimizer_fn: The thing used to define this model's optimizer. If this is a list, then the entry at
            `model_idx` is the one which gets recorded. Falsy values (including an empty list) are not recorded.
        weights_path: The path to the weights for this model.

    Returns:
        The `model`, but now with an fe_summary() method.
    """
    tables = {}
    description = {'definition': _trace_value(model_fn, tables, ret_ref=Flag())}
    if model_idx != -1:
        description['index'] = model_idx
    # A plain truthiness test. The previous condition also had an `isinstance(...) and optimizer_fn[0] is not None`
    # clause, but due to operator precedence it could only ever be evaluated for an empty list, where it raised an
    # IndexError. For every other input the outcome is unchanged.
    # NOTE(review): a non-empty list whose selected entry is None is still recorded - confirm whether that is desired.
    if optimizer_fn:
        selected_optimizer = optimizer_fn[model_idx] if isinstance(optimizer_fn, list) else optimizer_fn
        description['optimizer'] = _trace_value(selected_optimizer, tables, ret_ref=Flag())
    if weights_path:
        description['weights'] = _trace_value(weights_path, tables, ret_ref=Flag())
    fe_id = FEID(id(model))
    tbl = FeSummaryTable(name=model.model_name, fe_id=fe_id, target_type=type(model), **description)
    tables[fe_id] = tbl
    # Have to put this in a ChainMap b/c dict gets put into model._layers automatically somehow
    model._fe_traceability_summary = ChainMap(tables)

    # Use MethodType to bind the method to the class instance
    setattr(model, 'fe_summary', types.MethodType(fe_summary, model))
    return model

traceable

A decorator to be placed on classes in order to make them traceable and to enable a deep restore.

Decorated classes will gain the .fe_summary() and .fe_state() methods.

Parameters:

Name Type Description Default
whitelist Union[str, Tuple[str, ...]]

Arguments which should be included in a deep restore of the decorated class.

()
blacklist Union[str, Tuple[str, ...]]

Arguments which should be excluded from a deep restore of the decorated class.

()

Returns:

Type Description
Callable

The decorated class.

Source code in fastestimator/fastestimator/util/traceability_util.py
def traceable(whitelist: Union[str, Tuple[str, ...]] = (), blacklist: Union[str, Tuple[str, ...]] = ()) -> Callable:
    """A decorator to be placed on classes in order to make them traceable and to enable a deep restore.

    The decorated class has its `__init__` wrapped so that the constructor arguments are recorded, and it gains an
    `fe_summary` method as well as `__getstate__` / `__setstate__` methods (each only if the class does not already
    provide its own).

    Args:
        whitelist: Arguments which should be included in a deep restore of the decorated class.
        blacklist: Arguments which should be excluded from a deep restore of the decorated class.

    Returns:
        The decorated class.

    Raises:
        ValueError: If both a `whitelist` and a `blacklist` are provided.
    """
    if isinstance(whitelist, str):
        whitelist = (whitelist, )
    if isinstance(blacklist, str):
        blacklist = (blacklist, )
    if whitelist and blacklist:
        raise ValueError("Traceable objects may specify a whitelist or a blacklist, but not both")

    def make_traceable(cls):
        base_init = getattr(cls, '__init__')
        if hasattr(base_init, '__module__') and base_init.__module__ != 'fastestimator.util.traceability_util':
            # We haven't already overridden this class' init method
            @functools.wraps(base_init)  # to preserve the original class signature
            def init(self, *args, **kwargs):
                # The attributes may already exist if a parent class' wrapped init is further down the call chain,
                # in which case the new entries are merged into the existing ones.
                if not hasattr(self, '_fe_state_whitelist'):
                    self._fe_state_whitelist = whitelist
                else:
                    self._fe_state_whitelist = tuple(set(self._fe_state_whitelist).union(set(whitelist)))
                if not hasattr(self, '_fe_state_blacklist'):
                    self._fe_state_blacklist = blacklist + (
                        '_fe_state_whitelist', '_fe_state_blacklist', '_fe_traceability_summary')
                else:
                    self._fe_state_blacklist = tuple(set(self._fe_state_blacklist).union(set(blacklist)))
                if not hasattr(self, '_fe_traceability_summary'):
                    # Only the outermost init records the arguments used to build the instance
                    bound_args = inspect.signature(base_init).bind(self, *args, **kwargs)
                    bound_args.apply_defaults()
                    tables = {}
                    _trace_value(_BoundFn(self, bound_args), tables, ret_ref=Flag())
                    self._fe_traceability_summary = tables
                base_init(self, *args, **kwargs)

            setattr(cls, '__init__', init)

        if getattr(cls, 'fe_summary', None) is None:
            setattr(cls, 'fe_summary', fe_summary)

        # Since Python 3.11 `object` supplies a default __getstate__, so every class appears to have one. The
        # inherited default is therefore treated the same as a missing method, otherwise the traceability-aware
        # version would never be installed on newer interpreters.
        current_getstate = getattr(cls, '__getstate__', None)
        if current_getstate is None or current_getstate is getattr(object, '__getstate__', None):
            setattr(cls, '__getstate__', __getstate__)

        if getattr(cls, '__setstate__', None) is None:
            setattr(cls, '__setstate__', __setstate__)

        return cls

    return make_traceable