estimator

EarlyStop

Bases: Exception

An exception raised when the system.stop_training flag is flipped by a Trace in order to abort the training.

This class is intentionally not @traceable.

Source code in fastestimator/fastestimator/estimator.py
class EarlyStop(Exception):
    """An exception raised when the system.stop_training flag is flipped by a Trace in order to abort the training.

    This class is intentionally not @traceable.
    """

Estimator

One class to rule them all.

Estimator is the highest-level class within FastEstimator. It is the class invoked to actually train (estimator.fit) or test (estimator.test) models. It wraps Pipeline, Network, and Trace objects together and defines the whole optimization process.

If the data fed into the pipeline is a TensorFlow Dataset, then the parameters train_steps_per_epoch and eval_steps_per_epoch can only reduce the number of steps per epoch. If these parameters are larger than the number of batches in the given Dataset, the whole Dataset will be used.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pipeline | Pipeline | An fe.Pipeline object that defines the data processing workflow. | required |
| network | BaseNetwork | An fe.Network object that contains models and other training graph definitions. | required |
| epochs | int | The number of epochs to run. | required |
| train_steps_per_epoch | Optional[int] | Training will be cut short or extended to complete N steps even if the loader is not yet exhausted. If None, all data will be used. | None |
| eval_steps_per_epoch | Optional[int] | Evaluation will be cut short or extended to complete N steps even if the loader is not yet exhausted. If None, all data will be used. | None |
| traces | Union[None, Trace, Scheduler[Trace], Sequence[Union[None, Trace, Scheduler[Trace]]]] | What Traces to run during training. If None, only the system's default Traces will be included. | None |
| log_steps | Optional[int] | Frequency (in steps) for printing log messages. 0 to disable all step-based printing (though epoch information will still print). None to completely disable printing. | 100 |
| eval_log_steps | Sequence[int] | The list of steps on which evaluation progress logs should be printed. | () |
| monitor_names | Union[None, str, Iterable[Optional[str]]] | Additional keys from the data dictionary to be written into the logs. | None |
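
For orientation, here is a minimal end-to-end sketch in the style of the standard FastEstimator MNIST demo. The dataset, ops, and LeNet architecture are assumptions borrowed from the library's tutorials, not part of this class definition:

```python
import fastestimator as fe
from fastestimator.architecture.tf import LeNet
from fastestimator.dataset.data import mnist
from fastestimator.op.numpyop.univariate import ExpandDims, Minmax
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp

# Build the data processing workflow.
train_data, eval_data = mnist.load_data()
pipeline = fe.Pipeline(train_data=train_data,
                       eval_data=eval_data,
                       batch_size=32,
                       ops=[ExpandDims(inputs="x", outputs="x"),
                            Minmax(inputs="x", outputs="x")])

# Build the model and training graph.
model = fe.build(model_fn=LeNet, optimizer_fn="adam")
network = fe.Network(ops=[ModelOp(model=model, inputs="x", outputs="y_pred"),
                          CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
                          UpdateOp(model=model, loss_name="ce")])

# Wrap everything together and train.
estimator = fe.Estimator(pipeline=pipeline,
                         network=network,
                         epochs=2,
                         log_steps=100)  # print a training log every 100 steps
estimator.fit()
```
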
Source code in fastestimator/fastestimator/estimator.py
@traceable()
class Estimator:
    """One class to rule them all.

    Estimator is the highest-level class within FastEstimator. It is the class invoked to actually train
    (estimator.fit) or test (estimator.test) models. It wraps `Pipeline`, `Network`, and `Trace` objects together
    and defines the whole optimization process.

    If the data fed into the pipeline is a TensorFlow Dataset, then the parameters `train_steps_per_epoch` and
    `eval_steps_per_epoch` can only reduce the number of steps per epoch. If these parameters are larger than the
    number of batches in the given Dataset, the whole Dataset will be used.


    Args:
        pipeline: An fe.Pipeline object that defines the data processing workflow.
        network: An fe.Network object that contains models and other training graph definitions.
        epochs: The number of epochs to run.
        train_steps_per_epoch: Training will be cut short or extended to complete N steps even if the loader is not
            yet exhausted. If None, all data will be used.
        eval_steps_per_epoch: Evaluation will be cut short or extended to complete N steps even if the loader is not
            yet exhausted. If None, all data will be used.
        traces: What Traces to run during training. If None, only the system's default Traces will be included.
        log_steps: Frequency (in steps) for printing log messages. 0 to disable all step-based printing (though epoch
            information will still print). None to completely disable printing.
        eval_log_steps: The list of steps on which evaluation progress logs need to be printed.
        monitor_names: Additional keys from the data dictionary to be written into the logs.
    """
    monitor_names: Set[str]
    traces_in_use: List[Union[Trace, Scheduler[Trace]]]
    system: System
    filepath: str

    def __init__(self,
                 pipeline: Pipeline,
                 network: BaseNetwork,
                 epochs: int,
                 train_steps_per_epoch: Optional[int] = None,
                 eval_steps_per_epoch: Optional[int] = None,
                 traces: Union[None, Trace, Scheduler[Trace], Sequence[Union[None, Trace, Scheduler[Trace]]]] = None,
                 log_steps: Optional[int] = 100,
                 eval_log_steps: Sequence[int] = (),
                 monitor_names: Union[None, str, Iterable[Optional[str]]] = None):
        self.traces_in_use = []
        self.filepath = os.path.realpath(inspect.stack()[2].filename)  # Record this for history tracking
        assert log_steps is None or log_steps >= 0, \
            "log_steps must be None or positive (or 0 to disable only train logging)"
        self.monitor_names = filter_nones(to_set(monitor_names)) | network.get_loss_keys()
        self.system = System(network=network,
                             pipeline=pipeline,
                             traces=filter_nones(to_list(traces)),
                             log_steps=log_steps,
                             total_epochs=epochs,
                             train_steps_per_epoch=train_steps_per_epoch,
                             eval_steps_per_epoch=eval_steps_per_epoch,
                             eval_log_steps=eval_log_steps,
                             system_config=self.fe_summary())

    @property
    def pipeline(self) -> Pipeline:
        return self.system.pipeline

    @property
    def network(self) -> BaseNetwork:
        return self.system.network

    @property
    def traces(self) -> List[Union[Trace, Scheduler[Trace]]]:
        return self.system.traces

    @overload
    def fit(self, summary: None = None, warmup: bool = True, eager: bool = False) -> None:
        ...

    @overload
    def fit(self, summary: str, warmup: bool = True, eager: bool = False) -> Summary:
        ...

    def fit(self, summary: Optional[str] = None, warmup: bool = True, eager: bool = False) -> Optional[Summary]:
        """Train the network for the number of epochs specified by the estimator's constructor.

        Args:
            summary: A name for the experiment. If provided, the log history will be recorded in-memory and returned as
                a summary object at the end of training.
            warmup: Whether to perform warmup before training begins. The warmup procedure will test one step at every
                epoch where schedulers cause the execution graph to change. This can take some time up front, but can
                also save significant heartache on epoch 300 when the training unexpectedly fails due to a tensor size
                mismatch.
            eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
                PyTorch by nature is always in eager mode.

        Returns:
            A summary object containing the training history for this session iff a `summary` name was provided.
        """
        _verify_dependency_versions()
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'  # Prevent tf from constantly printing useless information
        draw()
        self.system.reset(summary, self.fe_summary())
        self._prepare_traces(run_modes={"train", "eval"})
        if warmup:
            self._warmup(eager=eager)
        self._start(run_modes={"train", "eval"}, eager=eager)
        return self.system.summary or None

    def _prepare_traces(self, run_modes: Set[str]) -> None:
        """Prepare information about the traces for execution.

        Add default traces into the traces_in_use list, and print a warning if no model saver trace is detected.

        Args:
            run_modes: The current execution modes.
        """
        self.traces_in_use = [trace for trace in self.traces]
        if self.system.log_steps is not None:
            self.traces_in_use.append(Logger())
        # Look for any monitor names which should be automagically added.
        trace_outputs = set()
        extra_monitor_keys = set()
        for trace in sort_traces(get_current_items(self.traces_in_use, run_modes=run_modes), ds_ids=[]):
            trace_outputs.update(trace.get_outputs(ds_ids=[]))
            extra_monitor_keys.update(trace.fe_monitor_names - trace_outputs)
        # Add the essential traces
        if "train" in run_modes:
            self.traces_in_use.insert(0, TrainEssential(monitor_names=self.monitor_names.union(extra_monitor_keys)))
            no_save_warning = True
            for trace in get_current_items(self.traces_in_use, run_modes=run_modes):
                if isinstance(trace, (ModelSaver, BestModelSaver)):
                    no_save_warning = False
            if no_save_warning:
                warn("No ModelSaver Trace detected. Models will not be saved.")
        if "eval" in run_modes and "eval" in self.pipeline.get_modes():
            self.traces_in_use.insert(1, EvalEssential(monitor_names=self.monitor_names.union(extra_monitor_keys)))
        if "test" in run_modes and "test" in self.pipeline.get_modes():
            self.traces_in_use.insert(0, TestEssential(monitor_names=self.monitor_names.union(extra_monitor_keys)))
        # insert system instance to trace
        for trace in get_current_items(self.traces_in_use, run_modes=run_modes):
            trace.system = self.system

    @overload
    def test(self, summary: None = None, eager: bool = False) -> None:
        ...

    @overload
    def test(self, summary: str, eager: bool = False) -> Summary:
        ...

    def test(self, summary: Optional[str] = None, eager: bool = False) -> Optional[Summary]:
        """Run the pipeline / network in test mode for one epoch.

        Args:
            summary: A name for the experiment. If provided, the log history will be recorded in-memory and returned as
                a summary object at the end of training. If None, the default value will be whatever `summary` name was
                most recently provided to this Estimator's .fit() or .test() methods.
            eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
                PyTorch by nature is always in eager mode.

        Returns:
            A summary object containing the training history for this session iff the `summary` name is not None (after
            considering the default behavior above).
        """
        _verify_dependency_versions()
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'  # Prevent tf from constantly printing useless information
        self.system.reset_for_test(summary)
        self._prepare_traces(run_modes={"test"})
        self._start(run_modes={"test"}, eager=eager)
        return self.system.summary or None

    def _warmup(self, eager: bool = True) -> None:
        """Perform a test run of each pipeline and network signature epoch to make sure that training won't fail later.

        Traces are not executed in the warmup since they are likely to contain state variables which could become
        corrupted by running extra steps.

        Args:
            eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
                PyTorch by nature is always in eager mode.
        """
        all_traces = get_current_items(self.traces_in_use, run_modes={"train", "eval"})
        sort_traces(all_traces, ds_ids=[])  # This ensures that the traces can sort properly for on_begin and on_end
        monitor_names = self.monitor_names
        unmet_monitor_names = set(monitor_names)
        for mode in self.pipeline.get_modes() - {"test"}:
            scheduled_items = self.pipeline.get_scheduled_items(mode) + self.network.get_scheduled_items(
                mode) + self.get_scheduled_items(mode)
            signature_epochs = get_signature_epochs(scheduled_items, self.system.total_epochs, mode=mode)
            epochs_with_data = self.pipeline.get_epochs_with_data(total_epochs=self.system.total_epochs, mode=mode)
            for epoch in signature_epochs:
                if epoch not in epochs_with_data:
                    continue
                ds_ids = self.pipeline.get_ds_ids(epoch, mode)
                for ds_id in ds_ids:
                    trace_input_keys = set()
                    trace_output_keys = {"*"}
                    traces = get_current_items(self.traces_in_use, run_modes=mode, epoch=epoch, ds_id=ds_id)
                    for idx, trace in enumerate(traces):
                        if idx == 0:
                            # Trace 0 is either TrainEssential or EvalEssential. Their inputs are the keys which should
                            # be monitored, which is a union of self.monitor_names and potentially other keys which were
                            # found when looping through traces to look for fe_monitor_names.
                            monitor_names.update(trace.inputs)
                        else:
                            # We want to ignore monitor_names for unmet requirement checking
                            trace_input_keys.update(trace.inputs)
                        trace_output_keys.update(trace.get_outputs(ds_ids=ds_ids))

                    with self.network(mode=mode,
                                      epoch=epoch,
                                      ds_id=ds_id,
                                      desired_output_keys=trace_input_keys | monitor_names,
                                      warmup=True,
                                      eager=eager):

                        network_input_keys = self.network.ctx_inputs
                        network_output_keys = self.network.ctx_outputs

                        # key checking
                        with self.pipeline(
                                mode=mode,
                                epoch=epoch,
                                ds_id=ds_id,
                                steps_per_epoch=None,
                                output_keys=(trace_input_keys - network_output_keys)
                                | network_input_keys | monitor_names) as loader:
                            loader = self._configure_loader(loader)
                            if isinstance(loader, tf.data.Dataset):
                                batch = list(loader.take(1))[0]
                            else:
                                with Suppressor(allow_pyprint=True, show_if_exception=True):
                                    # TF multi-gpu print-spams here in version 2.11
                                    batch = next(iter(loader))
                            batch = self._configure_tensor(loader, batch)
                        assert isinstance(batch, dict), \
                            f"please make sure data output format is dictionary (got {type(batch)})"
                        pipeline_output_keys = to_set(batch.keys())

                        unmet_monitor_names = unmet_monitor_names - (pipeline_output_keys | network_output_keys)
                        unmet_requirements = trace_input_keys - (pipeline_output_keys | network_output_keys
                                                                 | trace_output_keys)
                        assert not unmet_requirements, \
                            "found missing key(s) during epoch {} mode {} ds_id {}: {}".format(
                                epoch, mode, ds_id, unmet_requirements)
                        sort_traces(traces, ds_ids=ds_ids, available_outputs=pipeline_output_keys | network_output_keys)
                        trace_input_keys.update(traces[0].inputs)
                        self.network.run_step(batch)
        assert not unmet_monitor_names, "found missing key(s): {}".format(unmet_monitor_names)

    def get_scheduled_items(self, mode: str) -> List[Any]:
        """Get a list of items considered for scheduling.

        Args:
            mode: Current execution mode.

        Returns:
            List of schedulable items in estimator.
        """
        return self.traces_in_use

    def _start(self, run_modes: Set[str], eager: bool) -> None:
        """The outer training loop.

        This method invokes the trace on_begin method, runs the necessary 'train' and 'eval' epochs, and then invokes
        the trace on_end method.

        Args:
            run_modes: The current execution modes.
            eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
                PyTorch by nature is always in eager mode.
        """
        with Suppressor():
            # TODO - remove this after updating to TF > 2.11
            from tensorflow.python.autograph.pyct.static_analysis.liveness import Analyzer
            Analyzer.lamba_check(None, None)  # type: ignore
        all_traces = sort_traces(get_current_items(self.traces_in_use, run_modes=run_modes), ds_ids=[])
        with NonContext() if fe.fe_history_path is False else HistoryRecorder(
                self.system, self.filepath, db_path=fe.fe_history_path):
            try:
                self._run_traces_on_begin(traces=all_traces)
                if "train" in run_modes or "eval" in run_modes:
                    # If the training is re-starting from a restore wizard, it should re-run the last eval epoch
                    if self.system.epoch_idx > 0 and "eval" in self.pipeline.get_modes(epoch=self.system.epoch_idx):
                        self.system.mode = "eval"
                        self._run_epoch(eager=eager)
                    for self.system.epoch_idx in range(self.system.epoch_idx + 1, self.system.total_epochs + 1):
                        if "train" in self.pipeline.get_modes(epoch=self.system.epoch_idx):
                            self.system.mode = "train"
                            self._run_epoch(eager=eager)
                        if "eval" in self.pipeline.get_modes(epoch=self.system.epoch_idx):
                            self.system.mode = "eval"
                            self._run_epoch(eager=eager)
                else:
                    self._run_epoch(eager=eager)
            except EarlyStop:
                pass  # On early stopping we still want to run the final traces and return results
            self._run_traces_on_end(traces=all_traces)

    def _run_epoch(self, eager: bool) -> None:
        """A method to perform an epoch of activity.

        This method requires that the current mode and epoch already be specified within the self.system object.

        Args:
            eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
                PyTorch by nature is always in eager mode.
        """
        ds_ids = self.pipeline.get_ds_ids(self.system.epoch_idx, self.system.mode)
        epoch_traces = sort_traces(
            get_current_items(self.traces_in_use, run_modes=self.system.mode, epoch=self.system.epoch_idx),
            ds_ids=ds_ids)
        self._run_traces_on_epoch_begin(traces=epoch_traces)
        self.system.batch_idx = None
        end_epoch_data = Data()  # We will aggregate data over on_ds_end and put it into on_epoch_end for printing
        # run for each dataset
        for self.system.ds_id in ds_ids:
            ds_traces = get_current_items(self.traces_in_use,
                                          run_modes=self.system.mode,
                                          epoch=self.system.epoch_idx,
                                          ds_id=self.system.ds_id)
            trace_input_keys = set()
            for ds_trace in ds_traces:
                trace_input_keys.update(ds_trace.inputs)
            # Note that monitor_names are included in the trace_inputs here, rather than being excluded and then
            # manually union-ed again later as was done in _warmup.

            with self.network(mode=self.system.mode,
                              epoch=self.system.epoch_idx,
                              ds_id=self.system.ds_id,
                              desired_output_keys=trace_input_keys,
                              eager=eager):

                network_input_keys = self.network.ctx_inputs
                network_output_keys = self.network.ctx_outputs

                with self.pipeline(mode=self.system.mode,
                                   epoch=self.system.epoch_idx,
                                   ds_id=self.system.ds_id,
                                   steps_per_epoch=self.system.steps_per_epoch,
                                   output_keys=(trace_input_keys - network_output_keys)
                                   | network_input_keys) as loader:

                    if self.system.mode == 'eval':
                        log_steps_per_epoch = math.ceil(
                            len(loader) /
                            loader.get_batch_size()) if not self.system.steps_per_epoch else self.system.steps_per_epoch
                        self.system.eval_log_steps = ([
                            1, log_steps_per_epoch // 3, (2 * log_steps_per_epoch) // 3, log_steps_per_epoch
                        ], log_steps_per_epoch) if not self.system.eval_log_steps_request else \
                            (self.system.eval_log_steps_request, log_steps_per_epoch)

                    loader = self._configure_loader(loader)
                    iterator = iter(loader)
                    with Suppressor(allow_pyprint=True, show_if_exception=True):
                        # multi-gpu tensorflow prints a ton of complaint messages here
                        batch = next(iterator)
                    ds_traces = sort_traces(ds_traces,
                                            available_outputs=to_set(batch.keys()) | network_output_keys,
                                            ds_ids=ds_ids)
                    per_ds_traces = [trace for trace in ds_traces if isinstance(trace, PerDSTrace)]
                    self._run_traces_on_ds_begin(traces=per_ds_traces)
                    while True:
                        try:
                            if self.system.mode == "train":
                                self.system.update_global_step()
                            self.system.update_batch_idx()
                            batch = self._configure_tensor(loader, batch)
                            self._run_traces_on_batch_begin(batch, traces=ds_traces)
                            batch = self.network.run_step(batch)
                            self._run_traces_on_batch_end(batch, traces=ds_traces)
                            if isinstance(loader,
                                          DataLoader) and ((self.system.batch_idx == self.system.train_steps_per_epoch
                                                            and self.system.mode == "train") or
                                                           (self.system.batch_idx == self.system.eval_steps_per_epoch
                                                            and self.system.mode == "eval")):
                                raise StopIteration
                            batch = next(iterator)
                        except StopIteration:
                            break
                    self._run_traces_on_ds_end(traces=per_ds_traces, data=end_epoch_data)
        self._run_traces_on_epoch_end(traces=epoch_traces, data=end_epoch_data)

    def _configure_loader(self, loader: Union[DataLoader, tf.data.Dataset]) -> Union[DataLoader, tf.data.Dataset]:
        """A method to configure a given dataloader for use with this Estimator's Network.

        This method will ensure that the `loader` returns the correct data type (tf.Tensor or torch.Tensor) depending
        on the requirements of the Network. It also handles issues with multi-gpu data sharding.

        Args:
            loader: A data loader to be modified.

        Returns:
            The potentially modified dataloader to be used for training.
        """

        new_loader = loader
        if isinstance(new_loader, DataLoader) and isinstance(self.network, TFNetwork):
            add_batch = bool(new_loader.batch_size)
            if hasattr(loader, 'fe_postprocess_fn') and loader.fe_postprocess_fn is not None:
                # The user is manually batching data and running ops on data batches. No reliable way to shortcut this
                # since ops might require specific batch composition.
                data_instance = next(iter(loader))
                add_batch = False
            else:
                # No batch-based ops so we can try and just use the OpDataset to more quickly get our data summary
                data_instance = loader.dataset[0]
                if isinstance(data_instance, list):
                    # This is a batched dataset
                    data_instance = data_instance[0]
                    add_batch = True
                if isinstance(data_instance, FilteredData):
                    # We got unlucky and drew filtered data as the zeroth element. Fall back to a slower but more robust
                    # analysis of the batch
                    data_instance = next(iter(loader))
                    add_batch = False
            data_instance = to_tensor(data_instance, target_type="tf")
            data_type = to_type(data_instance)
            data_shape = to_shape(data_instance, add_batch=add_batch, exact_shape=False)
            new_loader = tf.data.Dataset.from_generator(lambda: loader, data_type, output_shapes=data_shape)
            new_loader = new_loader.prefetch(1)
        if isinstance(new_loader, tf.data.Dataset):
            if self.system.train_steps_per_epoch and self.system.mode == "train":
                new_loader = new_loader.take(self.system.train_steps_per_epoch)
            if self.system.eval_steps_per_epoch and self.system.mode == "eval":
                new_loader = new_loader.take(self.system.eval_steps_per_epoch)
            if isinstance(tf.distribute.get_strategy(), tf.distribute.MirroredStrategy) and isinstance(
                    self.network, TFNetwork) and not isinstance(new_loader, DistributedDataset):
                # The default autoshard policy is file, changing it to data to avoid warning
                options = tf.data.Options()
                options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
                new_loader = new_loader.with_options(options)
                new_loader = tf.distribute.get_strategy().experimental_distribute_dataset(new_loader)
        return new_loader

    def _configure_tensor(self, loader: Union[DataLoader, tf.data.Dataset], batch: Dict[str, Any]) -> Dict[str, Any]:
        """A function to convert a batch of tf.Tensors to torch.Tensors if required.

        Returns:
            Either the original `batch`, or the `batch` converted to torch.Tensors if required.
        """
        # TODO - if user has torch loader but custom collate that doesn't return torch tensor, need to cast here
        if isinstance(loader, tf.data.Dataset) and isinstance(self.network, TorchNetwork):
            batch = to_tensor(batch, target_type="torch")
        return batch

    def _run_traces_on_begin(self, traces: Iterable[Trace]) -> None:
        """Invoke the on_begin methods of given traces.

        Args:
            traces: List of traces.
        """
        data = Data()
        restore = None
        for trace in traces:
            # Delay RestoreWizard until the end so that it can overwrite everyone's on_begin methods
            if isinstance(trace, RestoreWizard):
                restore = trace
                continue
            # Restore does need to run before the logger though
            if isinstance(trace, Logger) and restore:
                restore.on_begin(data)
                restore = None
            trace.on_begin(data)
        if restore:
            restore.on_begin(data)
        self._check_early_exit()

    def _run_traces_on_epoch_begin(self, traces: Iterable[Trace]) -> None:
        """Invoke the on_epoch_begin methods of given traces.

        Args:
            traces: List of traces.
        """
        data = Data()
        for trace in traces:
            trace.on_epoch_begin(data)
        self._check_early_exit()

    def _run_traces_on_ds_begin(self, traces: Iterable[PerDSTrace]) -> None:
        """Invoke the on_ds_begin methods of given traces.

        Args:
            traces: List of traces.
        """
        data = Data()
        for trace in traces:
            trace.on_ds_begin(data)
        self._check_early_exit()

    def _run_traces_on_batch_begin(self, batch: Dict[str, Any], traces: Iterable[Trace]) -> None:
        """Invoke the on_batch_begin methods of given traces.

        Args:
            batch: The batch data which was provided by the pipeline.
            traces: List of traces.
        """
        data = Data(batch)
        for trace in traces:
            trace.on_batch_begin(data)
        self._check_early_exit()

    def _run_traces_on_batch_end(self, batch: Dict[str, Any], traces: Iterable[Trace]) -> None:
        """Invoke the on_batch_end methods of given traces.

        Args:
            batch: The batch data which was provided by the pipeline.
            traces: List of traces.
        """
        data = Data(batch)
        for trace in traces:
            trace.on_batch_end(data)
        self._check_early_exit()

    def _run_traces_on_ds_end(self, traces: Iterable[PerDSTrace], data: Data) -> None:
        """Invoke the on_ds_begin methods of given traces.

        Args:
            traces: List of traces.
            data: Data into which to record results.
        """
        for trace in traces:
            trace.on_ds_end(data)
        self._check_early_exit()

    def _run_traces_on_epoch_end(self, traces: Iterable[Trace], data: Data) -> None:
        """Invoke the on_epoch_end methods of of given traces.

        Args:
            traces: List of traces.
            data: Data into which to record results.
        """
        for trace in traces:
            trace.on_epoch_end(data)
        self._check_early_exit()

    @staticmethod
    def _run_traces_on_end(traces: Iterable[Trace]) -> None:
        """Invoke the on_end methods of given traces.

        Args:
            traces: List of traces.
        """
        data = Data()
        traceability = None
        for trace in traces:
            if isinstance(trace, Traceability):
                # Delay traceability until the end so that it can capture all data including the total training time
                traceability = trace
                continue
            trace.on_end(data)
        if traceability:
            traceability.on_end(data)

    def _check_early_exit(self) -> None:
        """Determine whether training should be prematurely aborted.

        Raises:
            EarlyStop: If the system.stop_training flag has been set to True.
        """
        if self.system.stop_training:
            raise EarlyStop

fit

Train the network for the number of epochs specified by the estimator's constructor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| summary | Optional[str] | A name for the experiment. If provided, the log history will be recorded in-memory and returned as a summary object at the end of training. | None |
| warmup | bool | Whether to perform warmup before training begins. The warmup procedure will test one step at every epoch where schedulers cause the execution graph to change. This can take some time up front, but can also save significant heartache on epoch 300 when the training unexpectedly fails due to a tensor size mismatch. | True |
| eager | bool | Whether to run the training in eager mode. This is only related to TensorFlow training because PyTorch by nature is always in eager mode. | False |

Returns:

| Type | Description |
| --- | --- |
| Optional[Summary] | A summary object containing the training history for this session iff a summary name was provided. |

Source code in fastestimator/fastestimator/estimator.py
def fit(self, summary: Optional[str] = None, warmup: bool = True, eager: bool = False) -> Optional[Summary]:
    """Train the network for the number of epochs specified by the estimator's constructor.

    Args:
        summary: A name for the experiment. If provided, the log history will be recorded in-memory and returned as
            a summary object at the end of training.
        warmup: Whether to perform warmup before training begins. The warmup procedure will test one step at every
            epoch where schedulers cause the execution graph to change. This can take some time up front, but can
            also save significant heartache on epoch 300 when the training unexpectedly fails due to a tensor size
            mismatch.
        eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
            PyTorch by nature is always in eager mode.

    Returns:
        A summary object containing the training history for this session iff a `summary` name was provided.
    """
    _verify_dependency_versions()
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'  # Prevent tf from constantly printing useless information
    draw()
    self.system.reset(summary, self.fe_summary())
    self._prepare_traces(run_modes={"train", "eval"})
    if warmup:
        self._warmup(eager=eager)
    self._start(run_modes={"train", "eval"}, eager=eager)
    return self.system.summary or None
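
As a usage sketch (continuing the earlier hypothetical example; the experiment name is illustrative), the summary argument captures the training history in memory:

```python
# Capture the training history in memory by naming the experiment.
history = estimator.fit(summary="experiment1")

# The returned Summary can then be inspected or plotted, for example with
# the log visualization helper from fe.summary.logs.
from fastestimator.summary.logs import visualize_logs
visualize_logs([history])
```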

get_scheduled_items

Get a list of items considered for scheduling.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mode | str | Current execution mode. | required |

Returns:

| Type | Description |
| --- | --- |
| List[Any] | List of schedulable items in estimator. |

Source code in fastestimator/fastestimator/estimator.py
def get_scheduled_items(self, mode: str) -> List[Any]:
    """Get a list of items considered for scheduling.

    Args:
        mode: Current execution mode.

    Returns:
        List of schedulable items in estimator.
    """
    return self.traces_in_use

test

Run the pipeline / network in test mode for one epoch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| summary | Optional[str] | A name for the experiment. If provided, the log history will be recorded in-memory and returned as a summary object at the end of training. If None, the default value will be whatever summary name was most recently provided to this Estimator's .fit() or .test() methods. | None |
| eager | bool | Whether to run the training in eager mode. This is only related to TensorFlow training because PyTorch by nature is always in eager mode. | False |

Returns:

| Type | Description |
| --- | --- |
| Optional[Summary] | A summary object containing the training history for this session iff the summary name is not None (after considering the default behavior above). |

Source code in fastestimator/fastestimator/estimator.py
def test(self, summary: Optional[str] = None, eager: bool = False) -> Optional[Summary]:
    """Run the pipeline / network in test mode for one epoch.

    Args:
        summary: A name for the experiment. If provided, the log history will be recorded in-memory and returned as
            a summary object at the end of training. If None, the default value will be whatever `summary` name was
            most recently provided to this Estimator's .fit() or .test() methods.
        eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
            PyTorch by nature is always in eager mode.

    Returns:
        A summary object containing the training history for this session iff the `summary` name is not None (after
        considering the default behavior above).
    """
    _verify_dependency_versions()
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'  # Prevent tf from constantly printing useless information
    self.system.reset_for_test(summary)
    self._prepare_traces(run_modes={"test"})
    self._start(run_modes={"test"}, eager=eager)
    return self.system.summary or None
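
A short sketch continuing the earlier hypothetical example (assumes the Pipeline was also constructed with test_data):

```python
# Train first, recording the history under an experiment name.
estimator.fit(summary="experiment1")

# Run one epoch in test mode. With summary=None, the most recent name
# ("experiment1") is reused, so a Summary object is returned.
test_results = estimator.test()
```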

enable_deterministic

Invoke to set random seed for deterministic training.

Determinism only works for tensorflow >= 2.1 and pytorch >= 1.14, and some model layers do not support it.

Known failing layers:

* tf.keras.layers.UpSampling2D

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| seed | int | The random seed to use for training. | required |
Source code in fastestimator/fastestimator/estimator.py
def enable_deterministic(seed: int) -> None:
    """Invoke to set random seed for deterministic training.

    Determinism only works for tensorflow >= 2.1 and pytorch >= 1.14, and some model layers do not support it.

    Known failing layers:
    * tf.keras.layers.UpSampling2D

    Args:
        seed: The random seed to use for training.
    """
    fe.fe_deterministic_seed = seed
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = str(1)
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    torch.manual_seed(seed)
    tf.keras.utils.set_random_seed(seed)
    tf.config.experimental.enable_op_determinism()
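
A usage sketch; calling it once before any pipelines or models are built is an assumption based on how global seeding generally behaves, not a documented requirement:

```python
import fastestimator as fe

# Seed Python, NumPy, TensorFlow, and PyTorch RNGs in one call.
fe.enable_deterministic(42)

# ... then construct Pipeline / Network / Estimator as usual ...
```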

record_history

Change the default location for history tracking.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| path | Union[bool, str] | The path to save experiment histories. Pass True to use the default location of ~/fastestimator_data/history.db. Pass False to disable history tracking. | required |
Source code in fastestimator/fastestimator/estimator.py
def record_history(path: Union[bool, str]) -> None:
    """Change the default location for history tracking.

    Args:
        path: The path to save experiment histories. Pass True to use the default location of
            ~/fastestimator_data/history.db. Pass False to disable history tracking.
    """
    if path in (None, True):
        fe.fe_history_path = None
    else:
        fe.fe_history_path = path
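
A usage sketch (the file path shown is illustrative):

```python
import fastestimator as fe

fe.record_history("/tmp/fe_history.db")  # record history in a custom database file
fe.record_history(True)                  # revert to ~/fastestimator_data/history.db
fe.record_history(False)                 # disable history tracking entirely
```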