from AIsaac.all import *
import fastcore.all as fc
import matplotlib.pyplot as plt,matplotlib as mpl
import torch
from datasets import load_dataset
from torch import nn
from torcheval.metrics import MulticlassAccuracy
import torchvision.transforms.functional as TF
Callbacks
The AIsaac library is an extremely flexible framework that relies heavily on callbacks; they are probably more widely used here than in any other framework. It is heavily influenced by the callback system in the miniai library developed as part of the fastai course, but goes a bit further in that direction in a couple of respects. Because of this it’s very important to understand how AIsaac uses callbacks and how you can leverage them.
Setup
Here I will set up the needed pieces for the tutorial. This includes imports and loading a small subset of the fashion MNIST dataset.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray'
set_seed(42)

xmean,xstd = 0.28, 0.35

@inplace
def transformi(b): b['image'] = [(TF.to_tensor(o)-xmean)/xstd for o in b['image']]

_dataset = load_dataset('fashion_mnist').with_transform(transformi)
dls = DataLoaders.from_dataset_dict(_dataset, 64, num_workers=4)
Found cached dataset fashion_mnist (/home/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/8d6c32399aa01613d96e2cbc9b13638f359ef62bb33612b077b4c247f6ef99c1)
Basic Trainer
trainer = Trainer(dls,
                  nn.CrossEntropyLoss(),
                  torch.optim.Adam,
                  model = get_model_conv(),
                  callbacks = [BasicTrainCB(), MetricsCB(Accuracy=MulticlassAccuracy()), DeviceCB(), ProgressCB()])
trainer.fit()
| epoch | train | valid | Accuracy |
|---|---|---|---|
| 0 | 0.601941 | 0.436749 | 0.8439 |
| 1 | 0.389781 | 0.381396 | 0.858400 |
| 2 | 0.339881 | 0.348775 | 0.870900 |
So we passed in a DataLoaders, a PyTorch loss, a PyTorch optimizer, a PyTorch model, and some callbacks. As you can see, running Trainer.fit ran a full training loop. The training loop is defined entirely in the callbacks. For this tutorial we are focusing on the callbacks; please refer to the PyTorch documentation for the PyTorch pieces.
One batch
Let’s see how a batch is processed. The source code for the batch trainer is very small, and there are two things we need to understand about it: the decorator and the run_callbacks method.
@with_cbs('batch', CancelBatchException)
def one_batch(self):
    self.run_callbacks(['predict','get_loss'])
    if self.training: self.run_callbacks(['before_backward','backward','step','zero_grad'])
run_callbacks
The run_callbacks method is what actually executes the callback code. As you can see, processing a batch is just a sequence of callback calls.
The first run_callbacks does the following (a minimal sketch of this logic is shown after the list):

- Sorts all callbacks according to the "order" attribute (defaults to 0)
- Loops through ['predict','get_loss']
  - Loops through the ordered callbacks:
    - If a "predict" method exists for that callback, run it
  - Loops through the ordered callbacks again:
    - If a "get_loss" method exists for that callback, run it
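Conceptually, run_callbacks behaves something like the sketch below. This is only an illustration based on the description above, not the library’s actual source; the attribute names (callbacks, order) follow the conventions used in this tutorial.

def run_callbacks(trainer, method_names):
    # Minimal sketch of run_callbacks: sort by the optional "order" attribute,
    # then call each named method on every callback that defines it.
    if isinstance(method_names, str): method_names = [method_names]
    callbacks = sorted(trainer.callbacks, key=lambda cb: getattr(cb, 'order', 0))
    for name in method_names:
        for cb in callbacks:
            method = getattr(cb, name, None)
            if method is not None: method(trainer)  # skip callbacks that don't define this step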
Let’s look at the BasicTrainCB code. Each element needed to process the batch is here. "before_backward" is not defined, so this callback won’t do anything in that step. We could, however, define a callback that runs before the backward pass if we want to add functionality there to our training loop (an illustrative sketch follows the source below).
view_source_code(BasicTrainCB)
class BasicTrainCB:
'''Callback for basic pytorch training loop'''
def predict(self,trainer): trainer.preds = trainer.model(trainer.batch[0])
def get_loss(self,trainer): trainer.loss = trainer.loss_func(trainer.preds,trainer.batch[1])
def backward(self,trainer): trainer.loss.backward()
def step(self,trainer): trainer.opt.step()
def zero_grad(self,trainer): trainer.opt.zero_grad()
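For instance, a hypothetical callback (not part of the library) that hooks into the before_backward step could be as simple as:

class PrintLossCB:
    '''Hypothetical callback: log the loss right before the backward pass runs'''
    def before_backward(self, trainer): print(f'loss before backward: {trainer.loss.item():.4f}')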
with_cbs
with_cbs adds two pieces of functionality:

- The ability to exit and skip the rest of the function (ie the batch). This is similar to how you can use continue in a for loop, and is done by raising the particular exception passed to the decorator.
- Before, after, and cleanup callbacks around the function. Before and after run before and after the function. Cleanup will always run, even if an exception is thrown.
one_batch example
- To run a callback before or after every batch, you would use before_batch and after_batch.
- To skip a batch, you would raise a CancelBatchException in a callback, as that is the exception passed to the decorator (a sketch of such a callback appears after the source below).
- The cleanup_batch callback will always run if one exists, even if you skipped the batch. after_batch will be skipped once the CancelBatchException is raised.
view_source_code(with_cbs)
class with_cbs:
def __init__(self, nm, exception): fc.store_attr()
def __call__(self, f):
def _f(o, *args, **kwargs):
try:
o.run_callbacks(f'before_{self.nm}')
f(o, *args, **kwargs)
o.run_callbacks(f'after_{self.nm}')
except self.exception: pass
finally: o.run_callbacks(f'cleanup_{self.nm}')
return _f
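For example, a hypothetical callback (not part of the library) could skip batches by raising CancelBatchException; after_batch would then be skipped for those batches, but cleanup_batch would still run.

class SkipEveryNthBatchCB:
    '''Hypothetical callback: skip every n-th batch by raising CancelBatchException'''
    def __init__(self, n=10): self.n, self.count = n, 0
    def before_batch(self, trainer):
        self.count += 1
        if self.count % self.n == 0: raise CancelBatchException()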
Other Callbacks
Another example of a callback is the device callback, which puts things onto whatever device we want (ie a GPU).
view_source_code(DeviceCB)
class DeviceCB:
'''Callback to train on specific device'''
def __init__(self, device=def_device): self.device=device
def before_fit(self, trainer):
'''Moves model to device'''
if hasattr(trainer.model, 'to'): trainer.model.to(self.device)
def before_batch(self, trainer):
'''moves batch to device'''
trainer.batch = to_device(trainer.batch, device=self.device)
view_source_code(MetricsCB)
class MetricsCB:
'''Callback to track train/valid loss + metrics'''
def __init__(self, **metrics):
self.metrics = metrics
self.losses = {'train':Mean(),'valid':Mean()}
self.metrics_epoch,self.losses_epoch,self.losses_batch = [],[],[]
def after_batch(self,trainer):
'''stores losses and metrics for batch'''
self.losses[f"{'train' if trainer.training else 'valid'}"].update(to_cpu(trainer.loss),weight=len(trainer.batch[1]))
if not trainer.training:
preds,batch = map(to_cpu,[trainer.preds,trainer.batch[1]])
for k in self.metrics: self.metrics[k].update(preds,batch)
self.losses_batch.append({'training':trainer.training,'loss':to_cpu(trainer.loss)})
def cleanup_epoch(self,trainer):
'''compute metrics and append to epoch stats and display'''
if not trainer.training:
self.metrics_epoch.append({name:float(metric.compute()) for name, metric in self.metrics.items()})
self.losses_epoch.append({name:float(metric.compute()) for name, metric in self.losses.items()})
for metric in self.metrics.values(): metric.reset()
for metric in self.losses.values(): metric.reset()
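MetricsCB is initialized with metrics as keyword arguments, where the keyword becomes the display name. For example, we could track two torcheval metrics at once (a sketch; any torcheval metric objects would work here):

from torcheval.metrics import MulticlassAccuracy, MulticlassF1Score
metrics_cb = MetricsCB(Accuracy=MulticlassAccuracy(), F1=MulticlassF1Score())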
In addition, the MetricsCB in the example above is responsible for calculating and tracking the losses and the metrics it is initialized with, and for logging them every epoch.
All functionality in the training loop is managed through callbacks.
Epochs work similarly to batches, and there is also a fit method which executes callbacks in the same way. The available callback events and cancel exceptions are listed below, followed by a rough sketch of how they nest.
- Batch callbacks
  - before_batch
  - predict
  - get_loss
  - before_backward
  - backward
  - step
  - zero_grad
  - after_batch
  - cleanup_batch
- Epoch callbacks
  - before_epoch
  - after_epoch
  - cleanup_epoch
- Fit callbacks
  - before_fit
  - after_fit
  - cleanup_fit
- Cancel exceptions
  - CancelBatchException
  - CancelEpochException
  - CancelFitException
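To make the nesting concrete, here is a rough sketch of how the fit, epoch, and batch levels could fit together. This is illustrative only; the attribute and method names are assumptions, not the library’s exact source.

class SketchTrainer:
    '''Illustrative sketch of how callback levels nest; not the actual Trainer source'''
    @with_cbs('fit', CancelFitException)
    def fit(self):
        for self.epoch in range(self.n_epochs):
            self.training = True;  self.one_epoch()
            self.training = False; self.one_epoch()

    @with_cbs('epoch', CancelEpochException)
    def one_epoch(self):
        self.dl = self.dls.train if self.training else self.dls.valid
        for self.batch in self.dl:
            self.one_batch()  # decorated with with_cbs('batch', CancelBatchException), as shown earlier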
Callback Subclassing/Inheritance
We can inherit from BasicTrainCB because training with momentum is mostly the same as a normal training loop, with one small tweak that allows previous gradients to be accounted for. In this way we can build callbacks from other similar callbacks.
Rather than subclassing the Trainer, we subclass callbacks (a usage sketch follows the source below).
view_source_code(MomentumTrainCB)
class MomentumTrainCB(BasicTrainCB):
def __init__(self,momentum): self.momentum = momentum
def zero_grad(self,trainer):
'''Multiply grads by momentum (instead of zero)'''
with torch.no_grad():
for p in trainer.model.parameters(): p.grad *= self.momentum
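To use it, we simply swap it in for BasicTrainCB when constructing the Trainer. Here is a sketch reusing the setup from earlier; the momentum value is just an example, and plain SGD is used since the callback itself handles momentum.

trainer = Trainer(dls,
                  nn.CrossEntropyLoss(),
                  torch.optim.SGD,
                  model = get_model_conv(),
                  callbacks = [MomentumTrainCB(momentum=0.9), MetricsCB(Accuracy=MulticlassAccuracy()),
                               DeviceCB(), ProgressCB()])
trainer.fit()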
Multiple Callbacks
Now that we know how to modify and extend the training loop with individual callbacks, the next logical question is how we create abstractions with this. For example, we probably don’t want to add DeviceCB, MetricsCB, and BasicTrainCB to every Trainer we create, as lots of Trainers will use those. As we build more complex models we may also want commonly used combinations of callbacks, rather than having to memorize lots of callback recipes.
To do this, we create a recursive call when we add callbacks that allows us to group callbacks together. Instead of subclassing the Trainer
in a way that may be more common in other frameworks, we group the callbacks together. We create these callbacks that are a combination of other callbacks, by defining the callbacks
attribute in a callback.
Adding callbacks works recursively. Once a callback is added, it will check for a callbacks
attribute and add those callbacks
. Those in turn could have callbacks
attributes of their own.
Here is an example of a group of callbacks that will likely go together. This simple class will add all of these callbacks when used. While this class is not a callback itself, because it does not define a callback method (ie before_batch), you have the flexibility to add those methods to this class to add behavior to your trainer specific to this grouping of callbacks. When passed as a callback it will add the 5 callbacks stored in self.callbacks. It could also be a callback itself (a sketch of that follows the source below).
view_source_code(CoreCBs)
class CoreCBs:
def __init__(self,device=def_device,module_filter=fc.noop,**metrics):
self.callbacks = [DeviceCB(device=device),
BasicTrainCB(),
MetricsCB(**metrics),
ProgressCB(),
ActivationStatsCB(module_filter)]
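For instance, a hypothetical variant of this grouping class (not part of the library) could also act as a callback in its own right, simply by defining a callback method alongside the callbacks attribute:

class VerboseCoreCBs(CoreCBs):
    '''Hypothetical: same grouping as CoreCBs, but also a callback itself via before_fit'''
    def before_fit(self, trainer): print('Starting training with the core callback group')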
trainer = Trainer(dls,
                  nn.CrossEntropyLoss(),
                  torch.optim.Adam,
                  get_model_conv(),
                  callbacks=[CoreCBs(Accuracy=MulticlassAccuracy())])
trainer.fit()
| epoch | train | valid | Accuracy |
|---|---|---|---|
| 0 | 0.599093 | 0.452553 | 0.8334 |
| 1 | 0.381516 | 0.377897 | 0.862000 |
| 2 | 0.332677 | 0.343139 | 0.876300 |