from AIsaac.all import *
import torch
from datasets import load_dataset
from torch import nn
from torcheval.metrics import MulticlassAccuracy
import torchvision.transforms.functional as TF
# Fashion MNIST
# Compact tensor printing: 2 decimals, 140-char lines, no scientific notation.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
# Seed torch's global RNG, then call AIsaac's `set_seed` helper.
# NOTE(review): if `set_seed` also seeds torch, the manual_seed(1) call is
# overridden by seed 42 — confirm which seed is intended.
torch.manual_seed(1)
set_seed(42)
# Normalization constants (mean, std) applied to every image in `transformi`.
xmean, xstd = 0.28, 0.35


@inplace
def transformi(b):
    """Convert every image in batch `b` to a tensor and normalize it.

    Replaces `b['image']` with a list of tensors computed as
    ``(TF.to_tensor(img) - xmean) / xstd``. Each image must be a type accepted by
    `TF.to_tensor` (PIL image or ndarray).

    NOTE(review): `@inplace` comes from AIsaac. It presumably makes this mutating
    function return `b`, which `Dataset.with_transform` requires — TODO confirm.
    """
    b['image'] = [(TF.to_tensor(o) - xmean) / xstd for o in b['image']]
# Load Fashion-MNIST via Hugging Face `datasets`. `with_transform` applies
# `transformi` on the fly whenever items are accessed, not eagerly.
_dataset = load_dataset('fashion_mnist').with_transform(transformi)
# Wrap the dataset splits in AIsaac DataLoaders. 256 is presumably the batch
# size (TODO confirm); data is loaded with 4 worker processes.
dls = DataLoaders.from_dataset_dict(_dataset, 256, num_workers=4)
dls.show_batch()
# Basic Trainer
# Reseed before building the model so runs are comparable.
# NOTE(review): the second argument is presumably a "deterministic" flag — confirm.
set_seed(1, True)
# Train a conv model with BatchNorm, using cross-entropy loss and Adam.
# Callbacks: CoreCBs tracking a torcheval MulticlassAccuracy metric under the
# name "Accuracy", plus OneCycleSchedulerCB (presumably one-cycle LR scheduling).
trainer = Trainer(
    dls,
    nn.CrossEntropyLoss(),
    torch.optim.Adam,
    get_model_conv(norm=nn.BatchNorm2d),
    callbacks=[CoreCBs(Accuracy=MulticlassAccuracy()), OneCycleSchedulerCB()],
)
# 5 is presumably the number of epochs; lr is the (peak) learning rate.
trainer.fit(5, lr=.01)
# timm Model
# Pretrained ResNet-18 via AIsaac's timm wrapper. The head is sized for the
# 10 Fashion-MNIST classes, and the input stem takes a single (grayscale) channel.
model = get_model_timm('resnet18', pretrained=True, num_classes=10, in_chans=1)
# Reseed so this run is comparable with the basic-trainer run.
# NOTE(review): the second argument is presumably a "deterministic" flag — confirm.
set_seed(1, True)
# Same loss, optimizer, and callbacks as the basic trainer; only the model differs.
trainer = Trainer(
    dls,
    nn.CrossEntropyLoss(),
    torch.optim.Adam,
    model,
    callbacks=[CoreCBs(Accuracy=MulticlassAccuracy()), OneCycleSchedulerCB()],
)
# 5 is presumably the number of epochs; lr is the (peak) learning rate.
trainer.fit(5, lr=.01)
# Looking at the Trainer
# Summarize the callbacks attached to `trainer`. The output format is defined
# by AIsaac's Trainer, which is not visible here.
trainer.summarize_callbacks()
# Looking at the Model
# Summarize the model held by `trainer`. When run top to bottom, this is the
# timm ResNet-18 trainer. The output format is defined by AIsaac's Trainer.
trainer.summarize_model()