pbm_calibrator

PBMCalibrator

Bases: Trace

A trace to generate a PlattBinnerMarginalCalibrator given a set of predictions.

Unlike many common calibration error correction algorithms, this one comes with theoretical bounds on the quality of its output (see https://arxiv.org/pdf/1909.10155v1.pdf). This trace is commonly used together with the Calibrate NumpyOp for postprocessing. It collects ground truth labels and predictions from whichever mode it is set to run in, then fits the calibrator on that data at the end of each epoch. The calibrated predictions are written out on epoch end under output_name. The trained calibration function is also saved if save_path is provided and, when save_if_key is given, the value under that key is zero.
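The collect-then-fit lifecycle described above can be sketched without FastEstimator. The following is a minimal illustration under stated assumptions: `HistogramCalibratorSketch` and `CollectThenCalibrate` are hypothetical names, and the stand-in calibrator is plain histogram binning for binary scores, not the Platt-binner marginal calibrator that this trace actually trains. Only the order of operations (reset on epoch begin, accumulate per batch, fit and emit on epoch end) is meant to mirror the trace.

```python
from typing import List, Sequence


class HistogramCalibratorSketch:
    """Stand-in calibrator (hypothetical): maps a score to the positive rate of its bin.

    This is simple histogram binning, NOT the PlattBinnerMarginalCalibrator.
    """
    def __init__(self, num_bins: int = 10) -> None:
        self.num_bins = num_bins
        self.bin_values: List[float] = [0.0] * num_bins

    def _bin(self, score: float) -> int:
        # Clamp so that a score of exactly 1.0 lands in the last bin.
        return min(int(score * self.num_bins), self.num_bins - 1)

    def fit(self, scores: Sequence[float], labels: Sequence[int]) -> None:
        totals = [0] * self.num_bins
        positives = [0] * self.num_bins
        for score, label in zip(scores, labels):
            idx = self._bin(score)
            totals[idx] += 1
            positives[idx] += label
        for idx in range(self.num_bins):
            if totals[idx]:
                self.bin_values[idx] = positives[idx] / totals[idx]
            else:
                # Empty bin: fall back to the bin midpoint.
                self.bin_values[idx] = (idx + 0.5) / self.num_bins

    def calibrate(self, scores: Sequence[float]) -> List[float]:
        return [self.bin_values[self._bin(s)] for s in scores]


class CollectThenCalibrate:
    """Hypothetical mirror of the trace lifecycle: reset, accumulate, fit."""
    def __init__(self) -> None:
        self.y_true: List[int] = []
        self.y_pred: List[float] = []

    def on_epoch_begin(self) -> None:
        # Drop anything collected during a previous epoch.
        self.y_true, self.y_pred = [], []

    def on_batch_end(self, y_true: Sequence[int], y_pred: Sequence[float]) -> None:
        assert len(y_true) == len(y_pred)
        self.y_true.extend(y_true)
        self.y_pred.extend(y_pred)

    def on_epoch_end(self) -> List[float]:
        # Fit on everything seen this epoch, then calibrate those same scores.
        calibrator = HistogramCalibratorSketch(num_bins=10)
        calibrator.fit(self.y_pred, self.y_true)
        return calibrator.calibrate(self.y_pred)


trace = CollectThenCalibrate()
trace.on_epoch_begin()
trace.on_batch_end([1, 0], [0.95, 0.92])
trace.on_batch_end([1, 1], [0.97, 0.15])
calibrated = trace.on_epoch_end()
```

In this toy run the three scores above 0.9 share a bin in which two of three labels are positive, so the overconfident scores are all mapped to the same lower value.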

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| true_key | str | Name of the key that corresponds to ground truth in the batch dictionary. | required |
| pred_key | str | Name of the key that corresponds to predicted score in the batch dictionary. | required |
| mode | Union[str, Set[str]] | What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train". | 'eval' |
| ds_id | Union[None, str, Iterable[str]] | What dataset id(s) to execute this Trace in. To execute regardless of ds_id, pass None. To execute in all ds_ids except for a particular one, you can pass an argument like "!ds1". | None |
| output_name | Optional[str] | What to call the output from this trace. If None, the default will be `'<pred_key>_calibrated'`. | None |
| save_path | Optional[str] | Where to save the calibrator generated by this Trace. If None, then no saving will be performed. | None |
| save_if_key | Optional[str] | Name of a key to control whether to save the calibrator. For example "since_best_acc". If provided, then the calibrator will only be saved when the save_if_key value is zero. | None |
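To make the `mode` and `output_name` semantics concrete, here is a small sketch of how these arguments could be interpreted. `resolve_modes`, `default_output_name`, and `ALL_MODES` are hypothetical names introduced for illustration and are not FastEstimator functions. Rejecting a mix of negated and plain entries is a choice made by this sketch, not documented behavior.

```python
from typing import Iterable, Optional, Set, Union

# The four modes named in the parameter description.
ALL_MODES = {"train", "eval", "test", "infer"}


def resolve_modes(mode: Union[None, str, Iterable[str]]) -> Set[str]:
    """Hypothetical helper: expand a mode argument into the set of modes it selects."""
    if mode is None:
        # None means "run regardless of mode".
        return set(ALL_MODES)
    entries = [mode] if isinstance(mode, str) else list(mode)
    negated = {m[1:] for m in entries if m.startswith("!")}
    selected = {m for m in entries if not m.startswith("!")}
    unknown = (negated | selected) - ALL_MODES
    if unknown:
        raise ValueError(f"Unknown mode(s): {sorted(unknown)}")
    if negated and selected:
        # Assumption of this sketch: mixing both forms is treated as an error.
        raise ValueError("Use either plain modes or '!'-negated modes, not both.")
    # "!infer" selects every mode except "infer".
    return ALL_MODES - negated if negated else selected


def default_output_name(pred_key: str, output_name: Optional[str] = None) -> str:
    """Mirror of the documented default: '<pred_key>_calibrated' when no name is given."""
    return f"{pred_key}_calibrated" if output_name is None else output_name
```

For example, `resolve_modes("!infer")` selects train, eval, and test, and `default_output_name("y_pred")` yields `"y_pred_calibrated"`.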

Raises:

| Type | Description |
| --- | --- |
| ValueError | If 'save_if_key' is provided but no 'save_path' is given. |
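The interaction between `save_path` and `save_if_key` can be summarized as two small functions. This is a sketch that follows the checks in the source listing on this page. `validate_save_args` and `should_save` are hypothetical names, and a plain dict stands in for the data object the trace receives at epoch end.

```python
from typing import Any, Dict, Optional


def validate_save_args(save_path: Optional[str], save_if_key: Optional[str]) -> None:
    """Constructor-time check: a save condition without a destination is an error."""
    if save_if_key is not None and save_path is None:
        raise ValueError("If 'save_if_key' is provided, then a 'save_path' must also be provided.")


def should_save(save_path: Optional[str], save_if_key: Optional[str], data: Dict[str, Any]) -> bool:
    """Epoch-end decision: save only if a path exists and the key (if any) reads zero."""
    if not save_path:
        return False
    if not save_if_key:
        # No condition given: save on every epoch end.
        return True
    return data[save_if_key] == 0
```

With `save_if_key="since_best_acc"`, the calibrator file is rewritten only on epochs where that value is zero, so the file on disk tracks the best epoch seen so far rather than the latest one.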

Source code in fastestimator/fastestimator/trace/adapt/pbm_calibrator.py
@traceable()
class PBMCalibrator(Trace):
    """A trace to generate a PlattBinnerMarginalCalibrator given a set of predictions.

    Unlike many common calibration error correction algorithms, this one has actual theoretical bounds on the quality
    of its output: https://arxiv.org/pdf/1909.10155v1.pdf. This trace is commonly used together with the Calibrate
    NumpyOp for postprocessing. This trace will collect data from whichever `mode` it is set to run on in order to
    perform empirical probability calibration. The calibrated predictions will be output on epoch end. The trained
    calibration function will also be saved if `save_path` is provided.

    Args:
        true_key: Name of the key that corresponds to ground truth in the batch dictionary.
        pred_key: Name of the key that corresponds to predicted score in the batch dictionary.
        mode: What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
        ds_id: What dataset id(s) to execute this Trace in. To execute regardless of ds_id, pass None. To execute in all
            ds_ids except for a particular one, you can pass an argument like "!ds1".
        output_name: What to call the output from this trace. If None, the default will be '<pred_key>_calibrated'.
        save_path: Where to save the calibrator generated by this Trace. If None, then no saving will be performed.
        save_if_key: Name of a key to control whether to save the calibrator. For example "since_best_acc". If provided,
            then the calibrator will only be saved when the save_if_key value is zero.

    Raises:
        ValueError: If 'save_if_key' is provided but no 'save_path' is given.
    """
    system: System

    def __init__(self,
                 true_key: str,
                 pred_key: str,
                 output_name: Optional[str] = None,
                 save_path: Optional[str] = None,
                 save_if_key: Optional[str] = None,
                 mode: Union[str, Set[str]] = "eval",
                 ds_id: Union[None, str, Iterable[str]] = None) -> None:
        if output_name is None:
            output_name = f"{pred_key}_calibrated"
        if save_if_key is not None and save_path is None:
            raise ValueError("If 'save_if_key' is provided, then a 'save_path' must also be provided.")
        super().__init__(inputs=[true_key, pred_key] + to_list(save_if_key),
                         outputs=output_name,
                         mode=mode,
                         ds_id=ds_id)
        self.y_true = []
        self.y_pred = []
        if save_path is not None:
            save_path = os.path.abspath(os.path.normpath(save_path))
        self.save_path = save_path

    @property
    def true_key(self) -> str:
        return self.inputs[0]

    @property
    def pred_key(self) -> str:
        return self.inputs[1]

    @property
    def save_key(self) -> Optional[str]:
        if len(self.inputs) == 3:
            return self.inputs[2]
        return None

    def on_epoch_begin(self, data: Data) -> None:
        self.y_true = []
        self.y_pred = []

    def on_batch_end(self, data: Data) -> None:
        y_true, y_pred = to_number(data[self.true_key]), to_number(data[self.pred_key])
        if y_true.shape[-1] > 1 and y_true.ndim > 1:
            y_true = np.argmax(y_true, axis=-1)
        assert y_pred.shape[0] == y_true.shape[0]
        self.y_true.extend(y_true)
        self.y_pred.extend(y_pred)

    def on_epoch_end(self, data: Data) -> None:
        self.y_true = np.squeeze(np.stack(self.y_true))
        self.y_pred = np.stack(self.y_pred)
        calibrator = cal.PlattBinnerMarginalCalibrator(num_calibration=len(self.y_true), num_bins=10)
        calibrator.train_calibration(probs=self.y_pred, labels=self.y_true)
        if self.save_path:
            if not self.save_key or (self.save_key and to_number(data[self.save_key]) == 0):
                with open(self.save_path, 'wb') as f:
                    dill.dump(calibrator.calibrate, file=f)
                print(f"FastEstimator-PBMCalibrator: Calibrator written to {self.save_path}")
        data.write_without_log(self.outputs[0], calibrator.calibrate(self.y_pred))
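The label handling in on_batch_end and on_epoch_end accepts ground truth as one-hot rows, single-element rows, or bare class indices. The sketch below restates that normalization in pure Python so it can be read without NumPy. `to_class_indices` is a hypothetical helper, and the real code operates on arrays using np.argmax and np.squeeze.

```python
from typing import List, Sequence, Union

Label = Union[int, Sequence[float]]


def to_class_indices(batch_labels: Sequence[Label]) -> List[int]:
    """Hypothetical helper: reduce each label in a batch to a single class index."""
    indices: List[int] = []
    for label in batch_labels:
        if isinstance(label, (list, tuple)):
            if len(label) > 1:
                # One-hot (or score) vector: take the position of the largest entry.
                # Ties resolve to the first maximum, as np.argmax does.
                indices.append(max(range(len(label)), key=label.__getitem__))
            else:
                # Single-element row such as [1]: unwrap it.
                indices.append(int(label[0]))
        else:
            # Already a bare class index.
            indices.append(int(label))
    return indices
```

For instance, `to_class_indices([[0, 0, 1], [1, 0, 0]])` gives `[2, 0]`, while `to_class_indices([[1], [0]])` and `to_class_indices([1, 0])` both give `[1, 0]`.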