# Fashion-MNIST

from AIsaac.all import *
import torch
from datasets import load_dataset
from torch import nn
from torcheval.metrics import MulticlassAccuracy
import torchvision.transforms.functional as TF
# Compact tensor printing: 2 decimals, 140-char lines, no scientific notation.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
# Seed torch's global RNG directly.
torch.manual_seed(1)
# NOTE(review): set_seed (from AIsaac) runs right after with a different seed.
# If it also seeds torch, the manual_seed(1) call above is overridden and
# redundant — confirm and keep only one.
set_seed(42)
# Normalization constants used by transformi below. Presumably the
# Fashion-MNIST pixel mean/std on the [0, 1] scale — TODO confirm.
xmean,xstd = 0.28, 0.35
@inplace
def transformi(b):
    """Replace ``b['image']`` with a list of standardised image tensors.

    Each entry is converted with ``TF.to_tensor``, then shifted by the
    module-level ``xmean`` and scaled by ``xstd``. ``b`` is mutated in
    place; the AIsaac ``inplace`` decorator presumably returns it to the
    caller afterwards (verify against AIsaac).
    """
    def _standardise(img):
        return (TF.to_tensor(img) - xmean) / xstd

    b['image'] = list(map(_standardise, b['image']))

# Load Fashion-MNIST through Hugging Face `datasets` and attach transformi as
# an on-the-fly transform, so normalization happens when items are accessed.
_dataset = (
    load_dataset('fashion_mnist')
    .with_transform(transformi)
)
# Wrap the splits in AIsaac DataLoaders (256 is presumably the batch size;
# 4 worker processes), then show one batch as a quick visual sanity check.
dls = DataLoaders.from_dataset_dict(_dataset, 256, num_workers=4)
dls.show_batch()

# Basic Trainer (convolutional model with BatchNorm)

# Baseline run: seed first so the freshly built conv model is reproducible,
# then train it with cross-entropy loss and Adam. CoreCBs receives a
# MulticlassAccuracy metric; OneCycleSchedulerCB presumably applies a
# one-cycle LR schedule.
set_seed(1, True)
trainer = Trainer(
    dls,
    nn.CrossEntropyLoss(),
    torch.optim.Adam,
    get_model_conv(norm=nn.BatchNorm2d),
    callbacks=[
        CoreCBs(Accuracy=MulticlassAccuracy()),
        OneCycleSchedulerCB(),
    ],
)
trainer.fit(5, lr=.01)

# timm Model (pretrained ResNet-18)

# Seed BEFORE building the model. The pretrained backbone weights are fixed,
# but any parts initialised from the RNG are not: presumably the new
# 10-class head, and possibly the 1-channel input adaptation, inside
# get_model_timm. Seeding afterwards would leave those unseeded and the run
# non-reproducible. The Basic Trainer cell already seeds before building
# its model.
set_seed(1,True)
model = get_model_timm('resnet18', pretrained=True,num_classes=10,in_chans=1)
trainer = Trainer(dls,
                  nn.CrossEntropyLoss(),
                  torch.optim.Adam,
                  model,
                  callbacks=[CoreCBs(Accuracy=MulticlassAccuracy()),OneCycleSchedulerCB()])
# NOTE(review): lr=.01 is reused from the from-scratch run; fine-tuning a
# pretrained backbone often wants a smaller LR — worth checking.
trainer.fit(5,lr=.01)

# Inspecting the Trainer's callbacks

# Summarize the callbacks registered on `trainer`. At this point `trainer` is
# the timm ResNet-18 trainer, since it was reassigned in the previous cell.
trainer.summarize_callbacks()

# Inspecting the Model

# Summarize the model held by `trainer`, which is the timm ResNet-18 from the
# previous cell (not the basic conv model, whose trainer was overwritten).
trainer.summarize_model()