trainer

Training loop
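import torchvision.transforms.functional as TF
from datasets import load_dataset

# Approximate Fashion-MNIST pixel mean and standard deviation, used for normalization
xmean,xstd = 0.28, 0.35
@inplace
def transformi(b): b['image'] = [(TF.to_tensor(o)-xmean)/xstd for o in b['image']]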

_dataset = load_dataset('fashion_mnist').with_transform(transformi)
_dataset = sample_dataset_dict(_dataset)
dls = DataLoaders.from_dataset_dict(_dataset, 1024, num_workers=4)
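The inplace decorator used by transformi is not defined in this excerpt. Here is a plausible minimal version, assumed and modeled on the fast.ai course helper of the same name. It calls a function for its side effect on the batch dict and then returns that dict. This matters because Hugging Face Dataset.with_transform uses the transform's return value as the batch. In the demo, transformi_demo is a hypothetical stand-in that works on plain floats instead of TF.to_tensor output, so it runs without torch:

```python
from functools import wraps

def inplace(f):
    "Hypothetical helper: call f(b) for its side effects, then return b itself."
    @wraps(f)
    def _f(b):
        f(b)
        return b
    return _f

xmean, xstd = 0.28, 0.35

@inplace
def transformi_demo(b):
    # Pixels are already floats in [0, 1] here, standing in for TF.to_tensor output
    b['image'] = [[(p - xmean) / xstd for p in img] for img in b['image']]

batch = {'image': [[0.28, 0.63]], 'label': [3]}
out = transformi_demo(batch)
print(out['image'])  # normalized values, roughly [[0.0, 1.0]]
```

Without the decorator, transformi would return None and with_transform would have no batch to hand back.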

Base Trainer


CancelEpochException

An Exception subclass used to cancel the current epoch.


CancelBatchException

An Exception subclass used to cancel (skip) the current batch.


CancelFitException

An Exception subclass used to cancel the whole fit.
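The three exception names match the callback-driven loop pattern used in fastai. In that pattern, each level of the loop catches its own exception, so a callback can abort just that level. The toy sketch below assumes the same pattern. Its fit function is hypothetical and only records which (epoch, batch) steps completed; the exception classes are local stand-ins:

```python
class CancelFitException(Exception): pass
class CancelEpochException(Exception): pass
class CancelBatchException(Exception): pass

EXC = {'batch': CancelBatchException, 'epoch': CancelEpochException, 'fit': CancelFitException}

def fit(n_epochs, n_batches, cancel=None):
    """Toy loop (hypothetical). cancel=(level, epoch, batch) raises the matching
    exception at that step; returns the list of completed (epoch, batch) steps."""
    done = []
    try:
        for epoch in range(n_epochs):
            try:
                for batch in range(n_batches):
                    try:
                        if cancel and cancel[1:] == (epoch, batch):
                            raise EXC[cancel[0]]()
                        done.append((epoch, batch))
                    except CancelBatchException:
                        pass  # skip only this batch
            except CancelEpochException:
                pass  # skip the rest of this epoch, continue with the next one
    except CancelFitException:
        pass  # stop training entirely
    return done
```

For example, fit(2, 3, cancel=('epoch', 0, 1)) completes only (0, 0) in the first epoch and then runs the whole second epoch. With ('fit', 0, 1), nothing after (0, 0) runs.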


Callback

 Callback ()

Base class for trainer callbacks; its constructor takes no arguments.
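Because Callback takes no constructor arguments, subclasses usually put their behaviour in hook methods. One common dispatch scheme is assumed here, borrowed from the fast.ai course. It sorts callbacks by an order attribute and calls a hook only on the callbacks that define it. The run_cbs function, the order attribute and the before_fit hook name are all illustrative, not confirmed for this library:

```python
from operator import attrgetter

class Callback:
    order = 0  # hypothetical: callbacks with a lower order run first

def run_cbs(cbs, method_nm, trainer=None):
    "Hypothetical dispatcher: call cb.<method_nm>(trainer) in order, skipping callbacks without it."
    for cb in sorted(cbs, key=attrgetter('order')):
        method = getattr(cb, method_nm, None)
        if method is not None:
            method(trainer)

calls = []

class LogCB(Callback):
    order = 1
    def before_fit(self, trainer): calls.append('log')

class SetupCB(Callback):
    def before_fit(self, trainer): calls.append('setup')

class NoHookCB(Callback):
    pass  # defines no hooks, so run_cbs skips it

run_cbs([LogCB(), NoHookCB(), SetupCB()], 'before_fit')
print(calls)  # ['setup', 'log']
```

Looking hooks up with getattr lets a callback implement only the hooks it cares about.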


summarize_callbacks

 summarize_callbacks (trainer)
import torchinfo

def summarize_model(model, batch, row_settings=("var_names",), verbose=0, depth=3,
                    col_names=("input_size", "output_size", "kernel_size")):
    "Summarize model with torchinfo, passing batch as the example input."
    # Other useful columns: "num_params", "mult_adds"
    return torchinfo.summary(model, input_data=batch, row_settings=row_settings,
                             verbose=verbose, depth=depth, col_names=col_names)

Trainer

 Trainer (dls, loss_func, opt_func, model, callbacks)

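Holds everything needed for training: the DataLoaders (dls), a loss function, an optimizer factory such as torch.optim.Adam (opt_func), the model, and a list of callbacks.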

trainer = Trainer(dls,
                  nn.CrossEntropyLoss(), 
                  torch.optim.Adam, 
                  get_model_conv(),
                  callbacks=[])
trainer.summarize_model()
===================================================================================================================
Layer (type (var_name))                  Input Shape               Output Shape              Kernel Shape
===================================================================================================================
Sequential (Sequential)                  [500, 1, 28, 28]          [500, 10]                 --
├─Sequential (0)                         [500, 1, 28, 28]          [500, 8, 14, 14]          --
│    └─Conv2d (0)                        [500, 1, 28, 28]          [500, 8, 14, 14]          [3, 3]
│    └─ReLU (1)                          [500, 8, 14, 14]          [500, 8, 14, 14]          --
├─Sequential (1)                         [500, 8, 14, 14]          [500, 16, 7, 7]           --
│    └─Conv2d (0)                        [500, 8, 14, 14]          [500, 16, 7, 7]           [3, 3]
│    └─ReLU (1)                          [500, 16, 7, 7]           [500, 16, 7, 7]           --
├─Sequential (2)                         [500, 16, 7, 7]           [500, 32, 4, 4]           --
│    └─Conv2d (0)                        [500, 16, 7, 7]           [500, 32, 4, 4]           [3, 3]
│    └─ReLU (1)                          [500, 32, 4, 4]           [500, 32, 4, 4]           --
├─Sequential (3)                         [500, 32, 4, 4]           [500, 64, 2, 2]           --
│    └─Conv2d (0)                        [500, 32, 4, 4]           [500, 64, 2, 2]           [3, 3]
│    └─ReLU (1)                          [500, 64, 2, 2]           [500, 64, 2, 2]           --
├─Sequential (4)                         [500, 64, 2, 2]           [500, 10, 1, 1]           --
│    └─Conv2d (0)                        [500, 64, 2, 2]           [500, 10, 1, 1]           [3, 3]
├─Flatten (5)                            [500, 10, 1, 1]           [500, 10]                 --
===================================================================================================================
Total params: 30,154
Trainable params: 30,154
Non-trainable params: 0
Total mult-adds (M): 113.45
===================================================================================================================
Input size (MB): 1.57
Forward/backward pass size (MB): 12.52
Params size (MB): 0.12
Estimated Total Size (MB): 14.21
===================================================================================================================
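The totals in this summary can be checked by hand. The sketch below assumes that every Conv2d is 3x3 with stride 2, padding 1 and a bias term. That is inferred from the halving output shapes; the source does not state it. Under that assumption it reproduces the spatial sizes, the parameter count, the mult-add total, the input size and the parameter size:

```python
# Assumption (inferred from the table, not stated in the source): every Conv2d
# is 3x3 with stride 2, padding 1 and a bias term.
channels = [1, 8, 16, 32, 64, 10]
kernel, stride, padding = 3, 2, 1
batch, size = 500, 28
bytes_per_float = 4  # float32

def conv_out(n):
    "Standard conv output-size formula for one spatial dimension."
    return (n + 2 * padding - kernel) // stride + 1

sizes, total_params, mult_adds = [], 0, 0
for c_in, c_out in zip(channels, channels[1:]):
    size = conv_out(size)
    sizes.append(size)
    params = c_out * c_in * kernel * kernel + c_out  # weights + bias
    total_params += params
    # Each parameter (bias included) is used once per output position per sample
    mult_adds += params * size * size * batch

input_mb = batch * 1 * 28 * 28 * bytes_per_float / 1e6
params_mb = total_params * bytes_per_float / 1e6

print(sizes)                      # [14, 7, 4, 2, 1]
print(total_params)               # 30154
print(round(mult_adds / 1e6, 2))  # 113.45
print(round(input_mb, 2), round(params_mb, 2))  # 1.57 0.12
```

The mult-add figure only matches when bias terms are counted per output position, which is how this torchinfo output appears to have been computed.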

Trainer Summaries