models

Model and architecture tooling

Imports

Setup

Load Data

# Mean and standard deviation used to standardize Fashion-MNIST pixels (scaled to [0,1])
xmean,xstd = 0.28, 0.35
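Constants like `xmean` and `xstd` are typically the mean and standard deviation of all training pixels after scaling them to [0,1]. The following is a minimal pure-Python sketch of that computation. It assumes images arrive as nested lists of 0 to 255 integers. The helper `pixel_stats` and the toy images are hypothetical and are not part of this library.

```python
from statistics import fmean, pstdev

def pixel_stats(images):
    """Hypothetical helper: mean and population std of all pixels, scaled to [0, 1]."""
    # Flatten every row of every image and scale uint8 values to [0, 1],
    # as TF.to_tensor does for 8-bit grayscale images
    pixels = [p / 255 for img in images for row in img for p in row]
    return fmean(pixels), pstdev(pixels)

# Two toy 2x2 "images" with made-up values, for illustration only
toy = [[[0, 255], [255, 0]], [[0, 0], [255, 255]]]
mean, std = pixel_stats(toy)
print(round(mean, 2), round(std, 2))  # 0.5 0.5
```

On real data you would pass the training images. The resulting pair plays the role of `xmean, xstd`.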

@inplace
def transformi(b):
    # Convert each image to a [0,1] float tensor, then standardize it with xmean/xstd
    b['image'] = [(TF.to_tensor(o)-xmean)/xstd for o in b['image']]
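Hugging Face `datasets` expects a function passed to `with_transform` to take a batch dict and *return* a batch. However, `transformi` only mutates `b` in place. A decorator like `inplace` can bridge that gap. The sketch below assumes that this is its purpose. The library's actual `inplace` is defined elsewhere and may differ, and `double_values` is a hypothetical transform used only for the demo.

```python
from functools import wraps

def inplace(f):
    """Sketch: call a function that mutates its argument, then return that same argument."""
    @wraps(f)
    def _f(b):
        f(b)      # mutate the batch dict in place
        return b  # hand the same dict back, as a transform function must
    return _f

@inplace
def double_values(b):
    # Hypothetical transform, used only to demonstrate the decorator
    b['x'] = [v * 2 for v in b['x']]

batch = {'x': [1, 2, 3]}
out = double_values(batch)
print(out)  # {'x': [2, 4, 6]}
```

Returning the same object avoids copying the batch. It also lets the transform body stay a simple assignment.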

_dataset = load_dataset('fashion_mnist').with_transform(transformi)
_dataset = sample_dataset_dict(_dataset)
dls = DataLoaders.from_dataset_dict(_dataset, 64, num_workers=4)  # batch size 64
xb = fc.first(dls.train)[0]  # images from the first training batch

Models



get_model_timm

 get_model_timm (model_name, pretrained=False, pretrained_cfg=None,
                 checkpoint_path='', scriptable=None, exportable=None,
                 no_jit=None)

Loads a model from timm. See `timm.list_models` for the available model names.

# in_chans=1 for grayscale input, num_classes=10 for the ten Fashion-MNIST classes
model = get_model_timm('resnet18', pretrained=True, num_classes=10, in_chans=1).to(def_device)
fc.test_eq(model(fc.first(dls.train)[0].to(def_device)).shape,(64,10))


get_model_conv

 get_model_conv (act=nn.ReLU, nfs=None, norm=None)
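The signature does not document `nfs`. One plausible reading, assumed here, is that `nfs` is a list of channel counts. Each consecutive pair `(ni, nf)` would become one `conv` layer, followed by a final layer that maps to the 10 classes. The helper `layer_plan` and the default `[1, 8, 16, 32, 64]` are illustrative assumptions, not this library's code.

```python
def layer_plan(nfs=None, n_classes=10):
    """Hypothetical: turn a list of channel counts into (ni, nf) pairs, one per conv layer."""
    if nfs is None:
        nfs = [1, 8, 16, 32, 64]  # assumed default; starts at 1 channel for grayscale input
    body = list(zip(nfs[:-1], nfs[1:]))  # consecutive pairs: (1, 8), (8, 16), ...
    head = (nfs[-1], n_classes)          # final layer producing one channel per class
    return body + [head]

print(layer_plan())
# [(1, 8), (8, 16), (16, 32), (32, 64), (64, 10)]
```

Under this reading, widening or deepening the network only requires changing the `nfs` list.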


conv

 conv (ni, nf, kernel_size=3, stride=2, act=nn.ReLU, norm=None, bias=None)
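With the defaults `kernel_size=3` and `stride=2`, each `conv` roughly halves the spatial size. The sketch below applies the standard convolution output-size formula, floor((n + 2·padding − kernel_size) / stride) + 1. It assumes the layer pads by `kernel_size//2`, which the signature does not state. The helper `conv_out_size` is hypothetical.

```python
def conv_out_size(n, kernel_size=3, stride=2, padding=None):
    """Output length along one spatial dim: floor((n + 2*padding - kernel_size) / stride) + 1."""
    if padding is None:
        padding = kernel_size // 2  # assumed padding; not documented in the signature
    return (n + 2 * padding - kernel_size) // stride + 1

# Trace a 28x28 Fashion-MNIST image through five stride-2 convs
sizes = [28]
for _ in range(5):
    sizes.append(conv_out_size(sizes[-1]))
print(sizes)  # [28, 14, 7, 4, 2, 1]
```

Five such layers shrink a 28x28 image to 1x1. Flattening a 1x1 map with 10 channels gives a `(64, 10)` output for a batch of 64, which would be consistent with the test below. How the model's head is actually built is an assumption here.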
model = get_model_conv()
fc.test_eq(model(fc.first(dls.train)[0].to(def_device)).shape,(64,10))