This post walks through how to use soft labeling in fastai, and demonstrates how it helps with noisy labels to improve training and your metrics.

This post was inspired by a 1st place Kaggle submission (not mine), so we know it's a good idea! The repo for that submission is here, and it is written in PyTorch Lightning. This post uses fastai.

Let's get started!


from fastai.vision.all import *
path = untar_data(URLs.IMAGEWOOF)
from sklearn.model_selection import StratifiedKFold
from numpy.random import default_rng

Get Noisy Data

To get noisy labels for the Imagewoof dataset, I am using the noisy datasets repo, which was heavily inspired by the noisy Imagenette repository.

First we get the noisy imagewoof csv, then use that to build the dataloaders.

#this code is taken from the noisy imagenette github repo linked above with slight modifications
def get_dls(size, woof, pct_noise, bs, splitter=ColSplitter()):
    path = untar_data(URLs.IMAGEWOOF)
    df = pd.read_csv('')
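    df = df.loc[df.is_valid==False]  # keep only the original training images
    batch_tfms = [Normalize.from_stats(*imagenet_stats)]
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       splitter=splitter,
                       get_x=ColReader('path', pref=path),
                       get_y=ColReader(f'noisy_labels_{pct_noise}'),  # train on the noisy labels
                       item_tfms=[RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)],
                       batch_tfms=batch_tfms)
    return dblock.dataloaders(df, bs=bs)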

dls = get_dls(224,woof=True,pct_noise=5,bs=16)
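The `noisy_labels_N` columns in the CSV are versions of the `truth` labels with roughly N% of them changed. I don't reproduce the repo's exact procedure here, but as a rough illustration, here is a minimal sketch that assumes noise is injected by reassigning a fixed percentage of labels to a different, randomly chosen class. `add_label_noise` is a hypothetical helper, not part of fastai or the noisy datasets repo.

```python
import random

def add_label_noise(labels, pct_noise, seed=0):
    """Hypothetical helper: return a copy of `labels` in which `pct_noise` percent
    of entries are replaced by a different class drawn uniformly at random."""
    rng = random.Random(seed)
    classes = sorted(set(labels))
    noisy = list(labels)
    n_flip = round(len(labels) * pct_noise / 100)
    # pick which rows to corrupt, without replacement
    for i in rng.sample(range(len(labels)), n_flip):
        # choose any class except the original one, so every flip is a real error
        noisy[i] = rng.choice([c for c in classes if c != labels[i]])
    return noisy

truth = ['n02086240', 'n02093754', 'n02111889', 'n02115641'] * 25  # 100 toy labels
noisy = add_label_noise(truth, pct_noise=5)
noise_rate = sum(t != n for t, n in zip(truth, noisy)) / len(truth)
print(noise_rate)  # 0.05
```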

Create Crossfold, Train, and Predict

I use cross-folds to get predicted labels for the training set. Each image's predicted label comes from a model that was not trained on that image.

Note: I am using 2 folds here, but you may want to use 5 or more.
This cross-fold code was mostly supplied by Zach Mueller, with minor modifications by me for this dataset and tutorial. He also wrote a tutorial with more details here.
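The property that matters is that every training image lands in exactly one validation fold, so it gets exactly one prediction from a model that never saw it. A minimal plain-Python sketch of that bookkeeping follows. The fold assignment is a simplified, deterministic stand-in for `StratifiedKFold(shuffle=True)`: it deals indices into folds round-robin within each class but does not shuffle. `fit` and `predict` are hypothetical placeholders for training and inference, not fastai calls; the toy "model" just remembers which items it trained on, so any `True` prediction would mean an item leaked into its own model's training data.

```python
from collections import defaultdict

def stratified_folds(labels, n_splits):
    """Assign each index to a fold, cycling through folds separately within
    each class so every fold gets a similar class mix (no shuffling)."""
    fold_of = [None] * len(labels)
    seen = defaultdict(int)
    for i, y in enumerate(labels):
        fold_of[i] = seen[y] % n_splits
        seen[y] += 1
    return fold_of

def out_of_fold_predictions(items, labels, n_splits, fit, predict):
    """Return one prediction per item, each made by a model that did not
    train on that item. `fit` and `predict` are hypothetical callables."""
    fold_of = stratified_folds(labels, n_splits)
    oof = [None] * len(items)
    for k in range(n_splits):
        train_idx = [i for i, f in enumerate(fold_of) if f != k]
        val_idx = [i for i, f in enumerate(fold_of) if f == k]
        model = fit([items[i] for i in train_idx], [labels[i] for i in train_idx])
        for i in val_idx:
            oof[i] = predict(model, items[i])
    return oof

def fit(xs, ys):
    # toy "model": remember which items were used for training
    return set(xs)

def predict(model, x):
    # True would mean the item was in this model's training data
    return x in model

items = list(range(8))
labels = ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'a']
oof = out_of_fold_predictions(items, labels, n_splits=2, fit=fit, predict=predict)
print(oof)  # [False, False, False, False, False, False, False, False]
```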
df = pd.read_csv('')
train_df = df.loc[df.is_valid==False]
train_df.head(3)
| | path | truth | noisy_labels_1 | noisy_labels_5 | noisy_labels_25 | noisy_labels_50 | is_valid |
|---|---|---|---|---|---|---|---|
| 0 | train/n02111889/n02111889_5826.JPEG | n02111889 | n02111889 | n02111889 | n02111889 | n02111889 | False |
| 1 | train/n02111889/n02111889_1944.JPEG | n02111889 | n02111889 | n02111889 | n02111889 | n02086240 | False |
| 2 | train/n02111889/n02111889_17657.JPEG | n02111889 | n02111889 | n02111889 | n02111889 | n02111889 | False |
skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=1)
preds, targs, preds_c = [],[],[]
items = pd.DataFrame(columns = ['path', 'noisy_labels_1', 'noisy_labels_5', 'noisy_labels_25','noisy_labels_50', 'is_valid'])

# each iteration holds out one fold as the validation set
for _, val_idx in skf.split(train_df.path,train_df.noisy_labels_5):
    splitter = IndexSplitter(val_idx)

    dls = get_dls(224,woof=True,pct_noise=5,bs=16,splitter=splitter)

    learn = cnn_learner(dls,resnet18,metrics=[accuracy,RocAuc()])
    learn.fine_tune(10)
    # store out-of-fold predictions: probabilities, targets, and decoded class indices
    p, t, c = learn.get_preds(ds_idx=1,with_decoded=True)
    preds.append(p); targs.append(t); preds_c.append(c)
    # keep the matching validation rows so predictions can be lined up with image paths
    items = pd.concat([items,dls.valid.items])
| epoch | train_loss | valid_loss | accuracy | roc_auc_score | time |
|---|---|---|---|---|---|
| 0 | 1.036113 | 0.780265 | 0.841569 | 0.963737 | 00:49 |

| epoch | train_loss | valid_loss | accuracy | roc_auc_score | time |
|---|---|---|---|---|---|
| 0 | 0.820277 | 0.695468 | 0.868823 | 0.967281 | 00:58 |
| 1 | 0.912045 | 0.729069 | 0.846444 | 0.965334 | 01:01 |
| 2 | 0.728870 | 0.716678 | 0.848659 | 0.964523 | 00:59 |
| 3 | 0.640925 | 0.717469 | 0.848659 | 0.964562 | 00:59 |
| 4 | 0.620054 | 0.712924 | 0.839575 | 0.963071 | 01:00 |
| 5 | 0.502611 | 0.703821 | 0.850210 | 0.964302 | 00:59 |
| 6 | 0.352605 | 0.727077 | 0.858631 | 0.965233 | 00:59 |
| 7 | 0.304586 | 0.729460 | 0.864392 | 0.964620 | 00:59 |
| 8 | 0.249052 | 0.723166 | 0.858631 | 0.965410 | 00:56 |
| 9 | 0.180612 | 0.732588 | 0.859739 | 0.965431 | 00:52 |

| epoch | train_loss | valid_loss | accuracy | roc_auc_score | time |
|---|---|---|---|---|---|
| 0 | 1.166885 | 0.712075 | 0.841090 | 0.969210 | 00:32 |

| epoch | train_loss | valid_loss | accuracy | roc_auc_score | time |
|---|---|---|---|---|---|
| 0 | 0.853930 | 0.630505 | 0.868351 | 0.971000 | 00:42 |
| 1 | 0.841145 | 0.655725 | 0.860372 | 0.968649 | 00:42 |
| 2 | 0.834104 | 0.692984 | 0.839539 | 0.967488 | 00:42 |
| 3 | 0.658782 | 0.686378 | 0.854167 | 0.968132 | 00:42 |
| 4 | 0.606825 | 0.703417 | 0.846853 | 0.966847 | 00:42 |
| 5 | 0.503965 | 0.687867 | 0.843528 | 0.966903 | 00:42 |
| 6 | 0.409177 | 0.686660 | 0.857713 | 0.967467 | 00:41 |
| 7 | 0.340657 | 0.690113 | 0.858821 | 0.967676 | 00:42 |
| 8 | 0.237057 | 0.685794 | 0.866356 | 0.967809 | 00:42 |
| 9 | 0.222322 | 0.681753 | 0.865470 | 0.968221 | 00:42 |

Look at Predictions

Let's put everything in a dataframe so it's easier to look at. First, let's break out the individual pieces of information.

imgs = L(o for o in items.path.values)
y_true = L(o for o in items.noisy_labels_5.values)           # labels from the dataset
y_targ = L(dls.vocab[int(o)] for o in torch.cat(targs))      # targets returned by get_preds
y_pred = L(dls.vocab[int(o)] for o in torch.cat(preds_c))    # predicted labels, or "pseudo labels"
p_max = torch.cat(preds).max(dim=1)[0]                       # max model score for each row

We can double-check that everything lines up by confirming that the targets returned alongside the predictions match the labels from the original data. Simple assert statements like this are worth adding: they take no time to write, and they will tell you if you break something later while tinkering.

assert y_true == y_targ  # check that the predictions line up with the right images

Put it in a dataframe and see what we have.

res = pd.DataFrame({'imgs':imgs,'y_true':y_true,'y_pred':y_pred}).set_index('imgs')
print(res.shape)
print(df.shape)
res.head()
(9025, 2)
(12954, 7)

| imgs | y_true | y_pred |
|---|---|---|
| train/n02086240/n02086240_6323.JPEG | n02086240 | n02086240 |
| train/n02093754/n02093754_696.JPEG | n02093754 | n02093754 |
| train/n02089973/n02089973_12157.JPEG | n02089973 | n02089973 |
| train/n02096294/n02096294_4188.JPEG | n02096294 | n02096294 |
| train/n02086240/n02086240_6595.JPEG | n02086240 | n02115641 |

Soft Labeling Setup

Now we have all the data we need to train a model with soft labels. To recap, we have:

  1. Dataloaders with noisy labels
  2. Dataframe with img path, y_true, and y_pred (pseudo labels we generated in the cross-fold above)

Next, we need one-hot encodings, so let's convert the label columns of our dataframe.

res = pd.get_dummies(res,columns=['y_true','y_pred'],dtype=float)  # 0.0/1.0 columns rather than bool
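`pd.get_dummies` names the new columns `y_true_<class>` and `y_pred_<class>` and orders them by category, which lines up with the model's outputs because fastai's `CategoryBlock` vocab is sorted by default. It does rely on every class appearing in both columns, though: a class that is never predicted gets no `y_pred_` column, and the true and predicted arrays would no longer have matching shapes. A minimal plain-Python sketch of an encoder that avoids this by always emitting one column per vocab entry (`one_hot` is a hypothetical helper):

```python
def one_hot(labels, vocab):
    """Encode each label as a list with a 1.0 at its vocab position.
    Every row has len(vocab) entries, even for classes that never occur."""
    index = {c: i for i, c in enumerate(vocab)}
    rows = []
    for y in labels:
        row = [0.0] * len(vocab)
        row[index[y]] = 1.0
        rows.append(row)
    return rows

vocab = ['n02086240', 'n02093754', 'n02115641']  # sorted, like a fastai vocab
y_pred = ['n02086240', 'n02086240']               # only one of the three classes is predicted
print(one_hot(y_pred, vocab))  # [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
```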

Next, let's change the loss function and metric to support one-hot encoded targets.

class CrossEntropyLossOneHot(nn.Module):
    "Cross-entropy for one-hot or soft targets of shape bs * n_classes"
    def __init__(self):
        super().__init__()
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, preds, labels):
        # per-example loss is -sum(labels * log_softmax(preds)); average over the batch
        return torch.mean(torch.sum(-labels * self.log_softmax(preds), -1))

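def accuracy(inp, targ, axis=-1):
    "Compute accuracy when both `inp` and the one-hot (or soft) `targ` are bs * n_classes"
    # compare the predicted class with the class that has the largest target weight
    pred,targ = flatten_check(inp.argmax(dim=axis), targ.argmax(dim=axis))
    return (pred == targ).float().mean()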
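To see what `CrossEntropyLossOneHot` computes, here is the same formula, `-sum(labels * log_softmax(preds))` for a single row, written out in plain Python. With a one-hot target it reduces to ordinary cross-entropy, `-log p(true class)`; with a soft target it is a weighted mix of the per-class terms. The logits are made-up numbers for illustration.

```python
import math

def log_softmax(logits):
    # subtract the max for numerical stability before exponentiating
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]

def soft_cross_entropy(logits, target):
    """-sum(target * log_softmax(logits)) for one example."""
    return -sum(t * lp for t, lp in zip(target, log_softmax(logits)))

def argmax(xs):
    return max(range(len(xs)), key=xs.__getitem__)

logits = [2.0, 1.0, 0.1]            # illustrative model outputs
one_hot_target = [1.0, 0.0, 0.0]
soft_target = [0.7, 0.3, 0.0]

lp = log_softmax(logits)
# one-hot target: identical to standard cross-entropy, -log p(class 0)
assert math.isclose(soft_cross_entropy(logits, one_hot_target), -lp[0])
# soft target: 0.7 * (-log p0) + 0.3 * (-log p1)
assert math.isclose(soft_cross_entropy(logits, soft_target), 0.7 * -lp[0] + 0.3 * -lp[1])
# the accuracy metric compares argmaxes, so this soft target still counts as class 0
print(argmax(logits) == argmax(soft_target))  # True
```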

Soft Labeling Callback

Finally, let's write the callback that does the soft labeling.

There are a few components to this. In plain English, here is what each section below does:

  • before_train and before_validate: These grab the list of images for the entire dataloader. We don't need to do this every batch, so these events are a good fit.
  • before_batch: This narrows that list down to only the images in the current batch and looks up their one-hot encoded labels. If it's a training batch, it blends them as y_true * 0.7 + y_pred * 0.3. We don't smooth the validation set, because we want the validation metrics to represent what we would see on a separate test set. This blend is the core of soft labeling.
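The batch lookup in `before_batch` relies on the dataloader's index order for the epoch: for batch number `self.iter`, the batch contains positions `iter*bs` up to `(iter+1)*bs` of that order. A plain-Python sketch of the same bookkeeping, with a made-up shuffled order standing in for the private `_DataLoader__idxs` attribute the callback reads; `batch_items` is a hypothetical helper:

```python
import random

def batch_items(items, idxs, batch_num, bs):
    """Return the items in batch `batch_num`, given the epoch's index order `idxs`."""
    start = batch_num * bs
    return [items[i] for i in idxs[start:start + bs]]

items = [f'img_{i}.JPEG' for i in range(10)]   # stand-in for the image paths
idxs = list(range(len(items)))
random.Random(42).shuffle(idxs)                 # stand-in for the shuffled epoch order
bs = 4
batches = [batch_items(items, idxs, b, bs) for b in range(3)]
# every image appears exactly once across the epoch; the final batch is partial
assert sorted(sum(batches, [])) == sorted(items)
print([len(b) for b in batches])  # [4, 4, 2]
```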

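The intuition is that images whose cross-fold prediction disagrees with the dataset label have a higher chance of simply being mislabeled. Blending the predicted label into the target means the model is penalized less for predicting that class, so a possibly incorrect label carries less weight. When the prediction agrees with the label, the blend leaves the one-hot target unchanged.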
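A worked example of the blend with `y_true_weight = 0.7` and three classes: when the cross-fold prediction agrees with the dataset label the target is unchanged, and when it disagrees 30% of the target moves to the predicted class. The probabilities `p` are made-up numbers, and `blend` and `cross_entropy` are illustrative helpers, not part of the callback.

```python
import math

def blend(y_true, y_pred, y_true_weight=0.7):
    """Soft label = y_true * w + y_pred * (1 - w), element-wise."""
    w = y_true_weight
    return [t * w + p * (1 - w) for t, p in zip(y_true, y_pred)]

def cross_entropy(probs, target):
    # skip zero-weight classes; they contribute nothing to the sum
    return -sum(t * math.log(p) for t, p in zip(target, probs) if t > 0)

agree = blend([1, 0, 0], [1, 0, 0])
disagree = blend([1, 0, 0], [0, 1, 0])
print(agree)                            # [1.0, 0.0, 0.0]
print([round(x, 2) for x in disagree])  # [0.7, 0.3, 0.0]

# A model that confidently predicts class 1 (perhaps the real class) is
# penalised less under the soft label than under the hard, possibly wrong, label.
p = [0.1, 0.85, 0.05]
assert cross_entropy(p, disagree) < cross_entropy(p, [1, 0, 0])
```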

Note: You can set thresholds for soft labeling to smooth more or less depending on how confident the predicted labels are. I don’t have that built into this callback, but it is something you can experiment with!
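A minimal sketch of one way to do that, assuming you keep each image's maximum cross-fold probability (like the `p_max` computed earlier) next to its pseudo label: only blend when the cross-fold model was confident, otherwise keep the original label. `per_image_true_weight`, `threshold=0.8`, and the 0.7 weight are hypothetical choices to tune, not part of the callback in this post.

```python
def per_image_true_weight(p_max, threshold=0.8, confident_weight=0.7):
    """Hypothetical rule: weight on the dataset label for each image.
    Confident pseudo labels get blended in; unconfident ones are ignored."""
    return [confident_weight if p >= threshold else 1.0 for p in p_max]

def blend_row(y_true, y_pred, w):
    # same blend as the callback, but with a per-image weight
    return [t * w + p * (1 - w) for t, p in zip(y_true, y_pred)]

p_max = [0.95, 0.55]                 # illustrative confidences
weights = per_image_true_weight(p_max)
print(weights)  # [0.7, 1.0]
rows = [blend_row([1, 0], [0, 1], w) for w in weights]
```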
This callback was a collaboration:
  • Zach Mueller got me started with the callback system in fastai, particularly around dataloader batch indexing.
  • Kevin H. and I developed it together for a joint project, and we both ran experiments to get it working correctly and performing well.
class SoftLabelCB(Callback):
    def __init__(self, df_preds, y_true_weight=0.7):
        '''df_preds is a pandas dataframe whose index is the image paths.
           It must have one-hot encoded y_true and y_pred columns (e.g. y_true_0, y_true_1).'''
        self.y_true_weight = y_true_weight
        self.y_pred_weight = 1 - y_true_weight
        self.df = df_preds

    def _get_imgs_list(self):
        # grab the image paths for the whole dataloader once, rather than every batch
        if isinstance(self.dl.items, pd.DataFrame): self.imgs_list = L(o for o in self.dl.items.iloc[:,0].values)
        elif is_listy(self.dl.items): self.imgs_list = L(self.dl.items)

    def before_train(self): self._get_imgs_list()
    def before_validate(self): self._get_imgs_list()

    def before_batch(self):
        # get the images' names for the current batch from the dataloader's index order
        bs = self.dl.bs
        imgs = list(self.imgs_list[self.dl._DataLoader__idxs[self.iter*bs:self.iter*bs+bs]])
        # start from the one-hot dataset labels
        df = self.df
        soft_labels = df.loc[imgs,df.columns.str.startswith('y_true')].values
        # blend in the pseudo labels on training batches only; validation keeps the original labels
        if self.training:
            soft_labels = soft_labels*self.y_true_weight + df.loc[imgs,df.columns.str.startswith('y_pred')].values*self.y_pred_weight
        # replace the targets, on the same device as the original targets
        self.learn.yb = (torch.tensor(soft_labels.astype('float32'), device=self.yb[0].device),)

Train the Model and Results

Then we pass the callback, our one-hot metric, and our one-hot loss function to a Learner and fine-tune it. As you can see, we get a small bump in both accuracy and ROC AUC.

This trains on the same split as the last of the cross-folds, so it's a fair comparison.

  • Without soft labeling: max accuracy was 86.8%, reached in the first unfrozen epoch, with no improvement over the remaining 9 epochs.
  • With soft labeling: max accuracy was 88.1%, about 1.2 percentage points higher than without soft labeling. Accuracy also improved in each of the last three epochs, and the final epoch had both the highest accuracy and the lowest validation loss, so training longer may well bring further gains.
learn = cnn_learner(dls,resnet18,metrics=[accuracy,RocAuc()],loss_func=CrossEntropyLossOneHot(),cbs=SoftLabelCB(res))
learn.fine_tune(10)
| epoch | train_loss | valid_loss | accuracy | roc_auc_score | time |
|---|---|---|---|---|---|
| 0 | 1.045338 | 0.761228 | 0.843528 | 0.968725 | 00:33 |

| epoch | train_loss | valid_loss | accuracy | roc_auc_score | time |
|---|---|---|---|---|---|
| 0 | 0.781899 | 0.655288 | 0.869016 | 0.970987 | 00:42 |
| 1 | 0.777672 | 0.666892 | 0.860594 | 0.970698 | 00:43 |
| 2 | 0.734413 | 0.635256 | 0.852837 | 0.969718 | 00:43 |
| 3 | 0.611115 | 0.629129 | 0.860816 | 0.969812 | 00:43 |
| 4 | 0.575180 | 0.618686 | 0.863475 | 0.970406 | 00:43 |
| 5 | 0.464586 | 0.602176 | 0.871897 | 0.969895 | 00:43 |
| 6 | 0.433750 | 0.608785 | 0.867021 | 0.971009 | 00:43 |
| 7 | 0.414037 | 0.597265 | 0.873005 | 0.970817 | 00:45 |
| 8 | 0.344089 | 0.597751 | 0.875665 | 0.971109 | 00:45 |
| 9 | 0.314052 | 0.582850 | 0.880541 | 0.971416 | 00:45 |
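The comparison in the bullets above comes straight from two logs: the last cross-fold's fine-tuning run (without soft labels) and the soft-label run. Here are their accuracy columns, unfrozen epochs only, copied from the tables, with a few lines that pull out each peak and the difference:

```python
# accuracy per unfrozen epoch, copied from the logs above
baseline = [0.868351, 0.860372, 0.839539, 0.854167, 0.846853,
            0.843528, 0.857713, 0.858821, 0.866356, 0.865470]
soft_label = [0.869016, 0.860594, 0.852837, 0.860816, 0.863475,
              0.871897, 0.867021, 0.873005, 0.875665, 0.880541]

def peak(accs):
    """Return (best accuracy, epoch at which it occurred)."""
    best = max(accs)
    return best, accs.index(best)

print(peak(baseline))    # (0.868351, 0)
print(peak(soft_label))  # (0.880541, 9)
gain = peak(soft_label)[0] - peak(baseline)[0]
print(f'{gain * 100:.2f} percentage points')  # 1.22 percentage points
```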