Callbacks

The AIsaac library is an extremely flexible framework built around callbacks, which it uses more extensively than most other frameworks. It is heavily influenced by the callback system in the miniai library developed as part of the fastai course, but pushes that idea a bit further in a couple of ways. Because of this, it's important to understand how AIsaac uses callbacks and how you can leverage them.

Setup

Here I will set up the needed pieces for the tutorial. This includes imports and loading a small subset of the fashion MNIST dataset.

from AIsaac.all import *
import fastcore.all as fc
import matplotlib.pyplot as plt,matplotlib as mpl
import torch
from datasets import load_dataset
from torch import nn
from torcheval.metrics import MulticlassAccuracy
import torchvision.transforms.functional as TF
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray'
set_seed(42)
xmean,xstd = 0.28, 0.35
@inplace
def transformi(b): b['image'] = [(TF.to_tensor(o)-xmean)/xstd for o in b['image']]

_dataset = load_dataset('fashion_mnist').with_transform(transformi)
dls = DataLoaders.from_dataset_dict(_dataset, 64, num_workers=4)
Found cached dataset fashion_mnist (/home/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/8d6c32399aa01613d96e2cbc9b13638f359ef62bb33612b077b4c247f6ef99c1)

Basic Trainer

trainer = Trainer(dls,
                  nn.CrossEntropyLoss(), 
                  torch.optim.Adam, 
                  model = get_model_conv(),
                  callbacks=[BasicTrainCB(),MetricsCB(Accuracy=MulticlassAccuracy()), DeviceCB(),ProgressCB()])
trainer.fit()
  train    valid    Accuracy
0 0.601941 0.436749 0.8439
1 0.389781 0.381396 0.858400
2 0.339881 0.348775 0.870900

So we passed in a DataLoaders, a pytorch loss, a pytorch optimizer, a pytorch model, and some callbacks. Calling Trainer.fit then ran a full training loop, and that loop is defined entirely in the callbacks. This tutorial focuses on the callbacks; please refer to the pytorch documentation for the pytorch pieces.

One batch

Let’s see how a batch is processed. The source code for processing a batch is very small, and there are two things we need to understand about it: the decorator and the run_callbacks method.

@with_cbs('batch', CancelBatchException)
def one_batch(self):
    self.run_callbacks(['predict','get_loss'])
    if self.training: self.run_callbacks(['before_backward','backward','step','zero_grad'])

run_callbacks

The run_callbacks method is what actually executes the callback code. As you can see, processing a batch is just a matter of running callbacks.

The first run_callbacks does the following:

run_callbacks pseudo code
  • Sorts all callbacks according to the “order” attribute (defaults to 0)
  • Loops through ['predict','get_loss']
    • Loops through ordered callbacks:
      • If “predict” method exists for that callback then run it
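Concretely, a rough sketch of that dispatch might look something like this (illustrative only; the actual implementation may differ in details):

def run_callbacks(self, method_names):
    '''Illustrative sketch: call each named method on every callback that defines it'''
    if isinstance(method_names, str): method_names = [method_names]
    # Sort callbacks by their "order" attribute, defaulting to 0
    ordered = sorted(self.callbacks, key=lambda cb: getattr(cb, 'order', 0))
    for name in method_names:                      # e.g. ['predict','get_loss']
        for cb in ordered:
            method = getattr(cb, name, None)
            if method is not None: method(self)    # each callback receives the trainer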

Let’s look at the BasicTrainCB code. Each element needed to process the batch is here. “before_backward” is not defined, so this callback does nothing at that step. We could, however, define a callback with a before_backward method if we want to add functionality there to our training loop (a sketch of one follows the source below).

view_source_code(BasicTrainCB)

class BasicTrainCB:
    '''Callback for basic pytorch training loop'''
    def predict(self,trainer): trainer.preds = trainer.model(trainer.batch[0])
    def get_loss(self,trainer): trainer.loss = trainer.loss_func(trainer.preds,trainer.batch[1])
    def backward(self,trainer): trainer.loss.backward()
    def step(self,trainer): trainer.opt.step()
    def zero_grad(self,trainer): trainer.opt.zero_grad()
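
For instance, a small hypothetical callback (not part of AIsaac) could use before_backward to add an L2 penalty to the loss before gradients are computed:

class L2RegCB:
    '''Hypothetical callback: add an L2 penalty to the loss before the backward pass'''
    def __init__(self, wd=1e-4): self.wd = wd
    def before_backward(self, trainer):
        # Sum of squared weights, scaled by wd, added onto the existing loss
        trainer.loss = trainer.loss + self.wd * sum(p.pow(2).sum() for p in trainer.model.parameters())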

with_cbs

with_cbs adds two pieces of functionality.

  • The ability to exit and skip the rest of the function (i.e. the batch), similar to using continue in a for loop. This is done by raising the exception passed to the decorator.
  • Adds before, after, and cleanup callbacks to the function. Before and after run before and after the function. Cleanup will always run, even if an exception is thrown.
one_batch example
  • To run a callback before or after every batch, you would use before_batch and after_batch
  • To skip a batch, you would raise a CancelBatchException in a callback as that’s what is passed to the decorator.
  • The cleanup_batch callback will always run if one exists, even if you skipped the batch. after_batch will be skipped once the CancelBatchException is raised.
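As an illustration, a hypothetical callback (not from the library) could skip every second batch by raising that exception in before_batch; cleanup_batch would still run for the skipped batches:

class SkipOddBatchesCB:
    '''Hypothetical callback: skip every second batch by raising CancelBatchException'''
    def __init__(self): self.count = 0
    def before_batch(self, trainer):
        self.count += 1
        if self.count % 2 == 0: raise CancelBatchException()
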
view_source_code(with_cbs)

class with_cbs:
    def __init__(self, nm, exception): fc.store_attr()
    def __call__(self, f):
        def _f(o, *args, **kwargs):
            try:
                o.run_callbacks(f'before_{self.nm}')
                f(o, *args, **kwargs)
                o.run_callbacks(f'after_{self.nm}')
            except self.exception: pass
            finally: o.run_callbacks(f'cleanup_{self.nm}')
        return _f

Other Callbacks

Another example is the device callback, which moves the model and each batch onto whatever device we want (e.g. a GPU).

view_source_code(DeviceCB)

class DeviceCB:
    '''Callback to train on specific device'''
    def __init__(self, device=def_device): self.device=device
    def before_fit(self, trainer):
        '''Moves model to device'''
        if hasattr(trainer.model, 'to'): trainer.model.to(self.device)
    def before_batch(self, trainer): 
        '''moves batch to device'''
        trainer.batch = to_device(trainer.batch, device=self.device)
view_source_code(MetricsCB)

class MetricsCB:
    '''Callback to track train/valid loss + metrics'''
    def __init__(self, **metrics):
        self.metrics = metrics
        self.losses = {'train':Mean(),'valid':Mean()}
        self.metrics_epoch,self.losses_epoch,self.losses_batch = [],[],[]
            
    def after_batch(self,trainer):
        '''stores losses and metrics for batch'''        
        self.losses[f"{'train' if trainer.training else 'valid'}"].update(to_cpu(trainer.loss),weight=len(trainer.batch[1]))
        if not trainer.training:
            preds,batch = map(to_cpu,[trainer.preds,trainer.batch[1]])    
            for k in self.metrics: self.metrics[k].update(preds,batch)
        self.losses_batch.append({'training':trainer.training,'loss':to_cpu(trainer.loss)})
            
    def cleanup_epoch(self,trainer):
        '''compute metrics and append to epoch stats and display'''
        if not trainer.training:
            self.metrics_epoch.append({name:float(metric.compute()) for name, metric in self.metrics.items()})
            self.losses_epoch.append({name:float(metric.compute()) for name, metric in self.losses.items()})

            for metric in self.metrics.values(): metric.reset()
            for metric in self.losses.values(): metric.reset()

In addition, the MetricsCB shown above is responsible for calculating and tracking the losses and the metrics it’s initialized with, and for logging them every epoch.

All functionality that is done in the training loop is managed through callbacks.

Epochs work similarly to batches, and the fit method executes callbacks in the same way.
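
By analogy with one_batch, you can picture the epoch and fit levels being wrapped the same way (this is only a rough sketch to illustrate the pattern, not the actual AIsaac source):

# Rough sketch only -- the real source may differ in details
@with_cbs('epoch', CancelEpochException)
def one_epoch(self):
    for self.batch in self.dl: self.one_batch()            # batch callbacks fire inside one_batch

@with_cbs('fit', CancelFitException)
def fit(self):
    for self.epoch in range(self.n_epochs): self.one_epoch()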

Available Callback List
  • Batch callbacks
    • before_batch
    • predict
    • get_loss
    • before_backward
    • backward
    • step
    • zero_grad
    • after_batch
    • cleanup_batch
  • Epoch callbacks
    • before_epoch
    • after_epoch
    • cleanup_epoch
  • Fit callbacks
    • before_fit
    • after_fit
    • cleanup_fit
Cancel Exceptions
  • CancelBatchException
  • CancelEpochException
  • CancelFitException
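
These cancel exceptions work at the epoch and fit level exactly as CancelBatchException does at the batch level. For example, a hypothetical callback (not part of the library) could stop the whole fit once the training loss drops below a threshold; the cleanup callbacks at each level would still run on the way out:

class StopEarlyCB:
    '''Hypothetical callback: end training once the training loss falls below a threshold'''
    def __init__(self, threshold=0.35): self.threshold = threshold
    def after_batch(self, trainer):
        if trainer.training and to_cpu(trainer.loss) < self.threshold: raise CancelFitException()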

Callback Subclassing/Inheritance

We can inherit from BasicTrainCB because a momentum training loop is mostly the same as a normal one, with one small tweak that lets previous gradients be accounted for. In this way we can build callbacks from other, similar callbacks.

Rather than subclassing the Trainer, we subclass callbacks.

view_source_code(MomentumTrainCB)

class MomentumTrainCB(BasicTrainCB):
    def __init__(self,momentum): self.momentum = momentum
    def zero_grad(self,trainer): 
        '''Multiply grads by momentum (instead of zero)'''
        with torch.no_grad():
            for p in trainer.model.parameters(): p.grad *= self.momentum
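
Using it is just a matter of swapping it in for BasicTrainCB when constructing the Trainer (the momentum value and the choice of plain SGD here are illustrative):

trainer = Trainer(dls,
                  nn.CrossEntropyLoss(),
                  torch.optim.SGD,   # plain SGD, since momentum is now handled by the callback
                  model = get_model_conv(),
                  callbacks=[MomentumTrainCB(momentum=0.85), MetricsCB(Accuracy=MulticlassAccuracy()),
                             DeviceCB(), ProgressCB()])
trainer.fit()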

Multiple Callbacks

Now that we know how to modify and extend the training loop with individual callbacks, the next logical question is how we create abstractions with this. For example, we probably don’t want to add DeviceCB, MetricsCB, and BasicTrainCB to every Trainer we create, since most Trainers will use them. As we build more complex models we may also want combinations of callbacks that are commonly used together, rather than having to memorize lots of callback recipes.

To do this, adding callbacks is made recursive, which allows us to group callbacks together. Instead of subclassing the Trainer, as may be more common in other frameworks, we group callbacks: a callback can be a combination of other callbacks, defined via a callbacks attribute.

Callbacks Attribute

Adding callbacks works recursively. Once a callback is added, it will check for a callbacks attribute and add those callbacks. Those in turn could have callbacks attributes of their own.
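
A rough sketch of what that recursive registration might look like (illustrative only, not the actual AIsaac source):

def add_callbacks(self, callbacks):
    '''Illustrative sketch: register callbacks, recursing into any nested `callbacks` attribute'''
    for cb in fc.L(callbacks):
        self.callbacks.append(cb)
        # A grouping object carries its own list of callbacks; add those too (they may nest further)
        if hasattr(cb, 'callbacks'): self.add_callbacks(cb.callbacks)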

Here is an example of a group of callbacks that will often go together. This simple class adds all of them when used. While it is not a callback itself, because it does not define any callback method (e.g. before_batch), you are free to add such methods to it to give your trainer behavior specific to this grouping. When passed as a callback it will add the five callbacks stored in self.callbacks, and it could also act as a callback itself if it defined callback methods.

view_source_code(CoreCBs)

class CoreCBs:
    def __init__(self,device=def_device,module_filter=fc.noop,**metrics):
        self.callbacks = [DeviceCB(device=device),
                          BasicTrainCB(),
                          MetricsCB(**metrics),
                          ProgressCB(),
                          ActivationStatsCB(module_filter)]
trainer = Trainer(dls,
                  nn.CrossEntropyLoss(), 
                  torch.optim.Adam, 
                  get_model_conv(),
                  callbacks=[CoreCBs(Accuracy=MulticlassAccuracy())])
trainer.fit()
  train    valid    Accuracy
0 0.599093 0.452553 0.8334
1 0.381516 0.377897 0.862000
2 0.332677 0.343139 0.876300