Callback System

Basics

Ivory implements a simple but powerful callback system.

Here is the list of callback functions in the order of invocation:

import ivory.core.base

ivory.core.base.Callback.METHODS

['on_init_begin',
 'on_init_end',
 'on_fit_begin',
 'on_epoch_begin',
 'on_train_begin',
 'on_train_end',
 'on_val_begin',
 'on_val_end',
 'on_epoch_end',
 'on_fit_end',
 'on_test_begin',
 'on_test_end']
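The order follows a typical workflow: initialization, then a fit phase whose epochs each contain a training and a validation phase, then testing. The sketch below shows how these names could nest for a given number of epochs. The function `invocation_order` is hypothetical and not part of Ivory, and the loop structure is an assumption inferred from the names.

```python
from typing import List


def invocation_order(num_epochs: int) -> List[str]:
    """Hypothetical sketch: callback names in the order a fit-then-test run might fire them."""
    order = ["on_init_begin", "on_init_end", "on_fit_begin"]
    for _ in range(num_epochs):
        # Assumed: each epoch runs a training phase followed by a validation phase.
        order += [
            "on_epoch_begin",
            "on_train_begin",
            "on_train_end",
            "on_val_begin",
            "on_val_end",
            "on_epoch_end",
        ]
    order.append("on_fit_end")
    # Assumed: testing happens once, after fitting has finished.
    order += ["on_test_begin", "on_test_end"]
    return order


# With a single epoch, each name appears exactly once, in the order listed above.
print(invocation_order(1))
```

With more epochs, only the six per-epoch names repeat.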

Any class that defines one or more of these functions can be a callback.

class SimpleCallback:  # No base class is needed.
    # You don't have to define all of the callback functions.
    def on_fit_begin(self, run):  # Must take a single `run` argument.
        print(f'on_fit_begin is called from id={id(run)}')
        # Do something with `run`.

To invoke callback functions, create a CallbackCaller instance, passing each callback instance as a keyword argument.

caller = ivory.core.base.CallbackCaller(simple=SimpleCallback())
caller

CallbackCaller(num_instances=1)

The number of registered instances is 1.

list(caller)

['simple']

Then call CallbackCaller.create_callbacks() to build a callback network.

caller.create_callbacks()
caller

CallbackCaller(num_instances=13)

The number of instances has increased to 13: the simple instance plus one Callback object for each of the 12 callback functions.

list(caller)

['simple',
 'on_init_begin',
 'on_init_end',
 'on_fit_begin',
 'on_epoch_begin',
 'on_train_begin',
 'on_train_end',
 'on_val_begin',
 'on_val_end',
 'on_epoch_end',
 'on_fit_end',
 'on_test_begin',
 'on_test_end']

A Callback object has been added to the caller for each callback function. Let's inspect some of them.

caller.on_init_begin

Callback([])

This is an empty callback because none of the caller's instances defines on_init_begin(). On the other hand:

caller.on_fit_begin

Callback(['simple'])

The simple instance is registered as a receiver of on_fit_begin(), so we can call it:

caller.on_fit_begin()

on_fit_begin is called from id=1375756371208

Compare this with the id of the caller:

id(caller)

1375756371208

The two ids match: the caller passes itself to each receiver as the `run` argument. This caller-receiver network among an arbitrary collection of instances lets you build complex machine learning workflows.
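To make the mechanism concrete, here is a minimal pure-Python sketch of such a network. `MiniCaller` and `MiniCallback` are hypothetical names and their internals are assumptions, not Ivory's implementation; they only reproduce the behavior shown above: instances are registered by keyword, `create_callbacks()` adds one callback object per function name, each callback forwards the call to the instances that define that function, and the caller passes itself as `run`. SimpleCallback is redefined so the block runs on its own.

```python
from typing import Any, Dict, Iterator, List

# Callback names copied from the METHODS list printed earlier.
METHODS = [
    "on_init_begin", "on_init_end", "on_fit_begin", "on_epoch_begin",
    "on_train_begin", "on_train_end", "on_val_begin", "on_val_end",
    "on_epoch_end", "on_fit_end", "on_test_begin", "on_test_end",
]


class MiniCallback:
    """Hypothetical: forwards one callback name to the receivers that define it."""

    def __init__(self, caller: "MiniCaller", method: str, receivers: List[str]) -> None:
        self.caller = caller
        self.method = method
        self.receivers = receivers

    def __call__(self) -> None:
        # With no receivers this loop does nothing, like Callback([]) above.
        for name in self.receivers:
            receiver = self.caller.instances[name]
            # The caller passes itself as the single `run` argument.
            getattr(receiver, self.method)(self.caller)

    def __repr__(self) -> str:
        return f"Callback({self.receivers})"


class MiniCaller:
    """Hypothetical stand-in for CallbackCaller, for illustration only."""

    def __init__(self, **instances: Any) -> None:
        self.instances: Dict[str, Any] = dict(instances)

    def __iter__(self) -> Iterator[str]:
        return iter(self.instances)

    def __repr__(self) -> str:
        return f"MiniCaller(num_instances={len(self.instances)})"

    def create_callbacks(self) -> None:
        for method in METHODS:
            # Duck typing: any instance with a callable attribute of this name receives it.
            receivers = [
                name
                for name, obj in self.instances.items()
                if callable(getattr(obj, method, None))
            ]
            callback = MiniCallback(self, method, receivers)
            self.instances[method] = callback  # Registered like any other instance.
            setattr(self, method, callback)  # Reachable as caller.<method>.


class SimpleCallback:
    def on_fit_begin(self, run):
        print(f"on_fit_begin is called from id={id(run)}")


caller = MiniCaller(simple=SimpleCallback())
caller.create_callbacks()
print(caller)  # MiniCaller(num_instances=13)
print(caller.on_init_begin)  # Callback([])
print(caller.on_fit_begin)  # Callback(['simple'])
caller.on_fit_begin()  # Prints the same id as id(caller).
```

Because receivers are found by name only, no base class is involved, which matches the "No base class is needed" comment in SimpleCallback.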

The Run class is a subclass of CallbackCaller that performs more library-specific processing. We use Run below.

Example Callback: Results

To work with the Results callback, we create a set of data and a model. For more details about the following code, see the Creating Instance section.

import yaml
from ivory.core.instance import create_instance

# A helper function.
def create(doc, name, **kwargs):
    params = yaml.safe_load(doc)
    return create_instance(params, name, **kwargs)

doc = """
library: torch
datasets:
  data:
    class: rectangle.data.Data
    n_splits: 5
  dataset:
  fold: 0
model:
  class: rectangle.torch.Model
  hidden_sizes: [3, 4, 5]
"""
datasets = create(doc, 'datasets')
model = create(doc, 'model')
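As a reference for the helper above, `yaml.safe_load(doc)` turns the YAML string into nested Python dictionaries. The literal below shows the structure it should produce for this `doc`; note that the bare `dataset:` key loads as `None`. How create_instance treats that `None` is not covered in this section.

```python
# What yaml.safe_load(doc) should return for the YAML string above.
params = {
    "library": "torch",
    "datasets": {
        "data": {"class": "rectangle.data.Data", "n_splits": 5},
        "dataset": None,  # A key with no value loads as YAML null, i.e. None.
        "fold": 0,
    },
    "model": {"class": "rectangle.torch.Model", "hidden_sizes": [3, 4, 5]},
}

print(sorted(params))  # ['datasets', 'library', 'model']
```

With PyYAML installed, `yaml.safe_load(doc) == params` should evaluate to True.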

The Results callback stores index, output, and target data. To save memory, a Results instance does not store input data.

# import ivory.callbacks.results  # For Scikit-learn or TensorFlow.
import ivory.torch.results

results = ivory.torch.results.Results()
results

Results([])

Next, pass the datasets, model, and results to a Run instance and build its callbacks:

import ivory.core.run

run = ivory.core.run.Run(
    datasets=datasets,
    model=model,
    results=results
)
run.create_callbacks()
run

Run(num_instances=15)
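
The 15 instances are the three passed to Run (datasets, model, and results) plus one Callback object for each of the 12 callback functions. To see which callback functions Results defines, use a small helper:

# A helper function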
def print_callbacks(obj):
    for func in ivory.core.base.Callback.METHODS:
        if hasattr(obj, func) and callable(getattr(obj, func)):
            print(func)

print_callbacks(results)  

on_train_begin
on_train_end
on_val_end
on_test_begin
on_test_end
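The list shows that Results defines on_train_end but not on_val_begin. One plausible design consistent with these hooks, sketched below in plain Python, is for on_train_end to store the training batches and leave an empty buffer for validation, so no on_val_begin is needed. `MiniResults` is a hypothetical class and its internals are assumptions, not Ivory's code.

```python
from typing import Dict, Iterable, List


class MiniResults:
    """Hypothetical results callback sketch; not Ivory's implementation."""

    def __init__(self) -> None:
        self.stored: Dict[str, Dict[str, list]] = {}
        self._reset()

    def _reset(self) -> None:
        # One list of batches per stored field; input data is deliberately not kept.
        self._buffer: Dict[str, List[list]] = {"index": [], "output": [], "target": []}

    def _store(self, mode: str) -> None:
        # Flatten the batches of each field into one list, then start a new buffer.
        self.stored[mode] = {
            key: [item for batch in batches for item in batch]
            for key, batches in self._buffer.items()
        }
        self._reset()

    def step(self, index: Iterable, output: Iterable, target: Iterable) -> None:
        self._buffer["index"].append(list(index))
        self._buffer["output"].append(list(output))
        self._buffer["target"].append(list(target))

    def on_train_begin(self, run) -> None:
        self._reset()

    def on_train_end(self, run) -> None:
        self._store("train")  # Leaves an empty buffer for the validation phase.

    def on_val_end(self, run) -> None:
        self._store("val")

    def on_test_begin(self, run) -> None:
        self._reset()

    def on_test_end(self, run) -> None:
        self._store("test")


sketch = MiniResults()
sketch.on_train_begin(None)
for k in range(3):
    sketch.step(range(4 * k, 4 * (k + 1)), [0.0] * 4, [1.0] * 4)
sketch.on_train_end(None)
sketch.step([12, 13], [0.0, 0.0], [1.0, 1.0])  # A validation batch.
sketch.on_val_end(None)
print(sorted(sketch.stored))  # ['train', 'val']
print(len(sketch.stored["train"]["index"]))  # 12
```

The sketch calls the hooks directly and passes None for `run`, since it never uses that argument.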

Let's play with the Results callback. Results.step() records the current index, output, and target.

import torch

# For simplicity, run just one epoch with a few batches.
run.on_train_begin()
dataset = run.datasets.train
for k in range(3):
    index, input, target = dataset[4 * k : 4 * (k + 1)]
    input, target = torch.tensor(input), torch.tensor(target)
    output = run.model(input)
    run.results.step(index, output, target)
    # Do something here, for example, a parameter update or early stopping.
run.on_train_end()
run.on_val_begin()  # Can be called even if no instance defines on_val_begin().
dataset = run.datasets.val
for k in range(2):
    index, input, target = dataset[4 * k : 4 * (k + 1)]
    input, target = torch.tensor(input), torch.tensor(target)
    output = run.model(input)
    run.results.step(index, output, target)
run.on_val_end()
run.on_epoch_end()

results

Results(['train', 'val'])

We performed a training loop and a validation loop, so the Results instance has train and val data but no test data.

results.train

Dict(['index', 'output', 'target'])

results.train.index  # The length is 4 x 3.

array([ 0,  2,  3,  4,  5,  6,  7, 10, 12, 13, 15, 16])

results.val.index  # The length is 4 x 2.

array([ 1,  8, 14, 27, 30, 31, 34, 45])
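The data are split into folds (n_splits: 5, fold: 0), so assuming a standard k-fold split, the train and val indices should not overlap. A quick plain-Python check on the arrays printed above confirms both the batch arithmetic and that the two index sets are disjoint:

```python
# Index arrays copied from the outputs above.
train_index = [0, 2, 3, 4, 5, 6, 7, 10, 12, 13, 15, 16]
val_index = [1, 8, 14, 27, 30, 31, 34, 45]

# 3 training batches of 4 samples and 2 validation batches of 4 samples.
assert len(train_index) == 4 * 3
assert len(val_index) == 4 * 2

# Samples held out for validation should never appear in training.
overlap = set(train_index) & set(val_index)
print(overlap)  # set()
```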

results.val.output

array([[-0.19664028],
       [-0.18804704],
       [-0.18011682],
       [-0.1660739 ],
       [-0.19168824],
       [-0.1673036 ],
       [-0.20202819],
       [-0.18093993]], dtype=float32)

results.val.target

array([[ 9.582911 ],
       [ 8.156037 ],
       [ 5.7045836],
       [ 3.401937 ],
       [ 7.614189 ],
       [ 3.2392535],
       [15.3450165],
       [ 4.47484  ]], dtype=float32)
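Because output and target are stored side by side, a metric can be computed from them directly. As an illustration only (this is plain Python, not Ivory's Metrics callback), here is the mean squared error of the still-untrained model on these 8 validation samples, using the values printed above:

```python
# Validation outputs and targets copied from the arrays printed above.
output = [-0.19664028, -0.18804704, -0.18011682, -0.1660739,
          -0.19168824, -0.1673036, -0.20202819, -0.18093993]
target = [9.582911, 8.156037, 5.7045836, 3.401937,
          7.614189, 3.2392535, 15.3450165, 4.47484]

# Mean squared error: the average of the squared differences.
mse = sum((t - o) ** 2 for o, t in zip(output, target)) / len(target)
print(f"MSE = {mse:.2f}")
```

The large value is expected: the model has not been trained, so its outputs stay close to zero while the targets range from about 3 to 15.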

Other Callbacks

There are several other callbacks, such as Metrics and Monitor. We will learn about them in the next tutorial, Training a Model.