Visualizations

Visualizations to help understand what’s going on
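# Approximate mean and standard deviation of the Fashion-MNIST pixel values, used to normalize the images
xmean,xstd = 0.28, 0.35
@inplace
def transformi(b): b['image'] = [(TF.to_tensor(o)-xmean)/xstd for o in b['image']]

# Work on a small sample of the dataset so the experiments below run quickly
_dataset = sample_dataset_dict(load_dataset('fashion_mnist').with_transform(transformi),(2000,500))
dls = DataLoaders.from_dataset_dict(_dataset, 256, num_workers=4)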

Learning Rate Finder


LRFinderCB

 LRFinderCB (lr_mult=1.3)

trainer = Trainer(dls,
                  nn.CrossEntropyLoss(), 
                  torch.optim.SGD, 
                  get_model_conv(), 
                  callbacks=[BasicTrainCB(),DeviceCB(),MetricsCB(Accuracy=MulticlassAccuracy())])
trainer.fit(callbacks=[LRFinderCB()])
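
To make the idea behind a learning rate finder concrete, here is a minimal, framework-free sketch of the usual range test: start from a tiny learning rate, multiply it by a constant factor after every step, record the loss, and stop once the loss blows up. The toy quadratic loss, the helper name `lr_range_test`, the starting rate, and the stop rule (loss above 3x the best loss seen) are all assumptions made for illustration. Only the `lr_mult=1.3` default comes from the `LRFinderCB` signature above, and this is not the library's implementation.

```python
import math

def lr_range_test(lr_start=1e-5, lr_mult=1.3, max_steps=200, stop_factor=3.0):
    """Sweep the learning rate geometrically on a toy loss f(w) = w**2 (hypothetical helper)."""
    w = 1.0                        # single toy parameter
    lr = lr_start
    lrs, losses = [], []
    best = math.inf
    for _ in range(max_steps):
        loss = w * w               # toy loss, stands in for a real training loss
        lrs.append(lr)
        losses.append(loss)
        best = min(best, loss)
        # Assumed stop rule: the loss is non-finite or far above the best loss so far
        if not math.isfinite(loss) or loss > stop_factor * best:
            break
        grad = 2 * w               # gradient of w**2
        w -= lr * grad             # plain SGD step
        lr *= lr_mult              # geometric learning rate schedule
    return lrs, losses

lrs, losses = lr_range_test()
print(f"steps: {len(lrs)}, final lr: {lrs[-1]:.3g}")
```

Plotting `losses` against `lrs` on a logarithmic x-axis gives the familiar curve; a learning rate somewhat below the point where the loss starts to climb is a common choice.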

Activation Stats


HooksCallback

 HooksCallback (hookfunc, module_filter=<function noop>, on_train=True,
                on_valid=False, modules=None)


get_hist

 get_hist (h)
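
The source does not show what `get_hist` returns. One plausible reading, sketched below in plain Python, is that per-batch histograms are stacked into a grid with one row per bin and one column per batch, with counts passed through `log1p` so that sparse bins remain visible. The helper name `hist_image` and the `log1p` scaling are assumptions, not confirmed behavior.

```python
import math

def hist_image(hists):
    """Turn per-batch histograms into a bins x batches grid of log1p counts (hypothetical helper)."""
    # zip(*hists) transposes the data: each row follows one bin across batches
    return [[math.log1p(c) for c in row] for row in zip(*hists)]

# Two batches, each summarized by a 3-bin histogram
grid = hist_image([[8, 1, 0], [5, 3, 2]])
print(len(grid), len(grid[0]))  # 3 2
```

A grid like this can be handed to an image plotting function to show how the distribution of activations shifts over the course of training.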

append_stats

 append_stats (hook, module, inp, outp)
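
As a rough illustration of the kind of per-batch statistics an activation hook such as `append_stats` might record, the sketch below computes the mean, the standard deviation, and a histogram of absolute values for a flat list of activations. It uses plain Python rather than tensors. The helper name `summarize_activations`, the 40 bins, and the 0 to 10 range are assumptions, not values confirmed by the source.

```python
import statistics

def summarize_activations(acts, bins=40, lo=0.0, hi=10.0):
    """Return (mean, std, histogram of |acts|) for one batch of activations (hypothetical helper)."""
    mean = statistics.fmean(acts)
    std = statistics.pstdev(acts)           # population standard deviation
    hist = [0] * bins
    width = (hi - lo) / bins
    for a in acts:
        v = abs(a)
        if lo <= v <= hi:                   # values outside the range are ignored
            idx = min(int((v - lo) / width), bins - 1)  # clamp v == hi into the last bin
            hist[idx] += 1
    return mean, std, hist

mean, std, hist = summarize_activations([0.0, 0.1, -0.2, 0.5, 3.0, -9.9])
print(hist[0], sum(hist))  # 3 6
```

Collecting one such tuple per batch and per layer yields the raw material for plots of means, standard deviations, and activation histograms over time.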

ActivationStatsCB

 ActivationStatsCB (module_filter=<function noop>)


get_min

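 get_min (h)

model = get_model_conv()
model.to(def_device)
# Grab the inputs of the first training batch and use them to run LSUV initialization on each module
xb = fc.first(dls.train)[0]
for module in model.modules(): lsuv_init(model, module, xb.to(def_device))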
trainer = Trainer(dls,
                  nn.CrossEntropyLoss(), 
                  torch.optim.Adam, 
                  model, 
                  callbacks=[BasicTrainCB(),
                             MetricsCB(Accuracy=MulticlassAccuracy()), 
                             DeviceCB(), 
                             ProgressCB(), 
                             ActivationStatsCB(fc.risinstance(nn.Conv2d))])
trainer.fit()
   train     valid     Accuracy
0  1.617802  1.350127  0.56
1  1.143669  1.029318  0.676000
2  0.912712  0.920224  0.684000
trainer.ActivationStatsCB.color_dim()

trainer.ActivationStatsCB.dead_chart()
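
One common way to quantify "dead" activations, and plausibly what a chart like this one shows, is the share of activations in each batch that fall into the lowest histogram bin, meaning values at or near zero. The sketch below computes that share from per-batch histograms in plain Python. The helper name `dead_fraction` and the choice of the first bin as the "dead" bin are assumptions rather than the library's confirmed definition.

```python
def dead_fraction(hists):
    """For each per-batch histogram, return the share of counts in the lowest bin (hypothetical helper)."""
    fractions = []
    for hist in hists:
        total = sum(hist)
        # Guard against empty histograms to avoid division by zero
        fractions.append(hist[0] / total if total else 0.0)
    return fractions

# Three batches with 4-bin histograms of absolute activation values
batches = [[8, 1, 1, 0], [5, 3, 1, 1], [2, 4, 3, 1]]
print(dead_fraction(batches))  # [0.8, 0.5, 0.2]
```

A falling curve, as in this toy data, would indicate that fewer activations are stuck near zero as training proceeds.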

trainer.ActivationStatsCB.plot_stats()