from fastai.vision.all import *
from sklearn.model_selection import StratifiedKFold
from numpy.random import default_rng
import itertools

path = untar_data(URLs.MNIST, force_download=True)
1 Intro
Goal: Provide an understanding of what pseudo labeling is, why you might use it, and how you would go about using it.
What’s included in this post: The information needed to get started with pseudo labeling on your own project.
2 What is pseudo labeling?
Pseudo labeling is the process of creating new labels for a piece of data.
The general idea can be broken into a few steps:
- Create a model
- Make predictions on some data with that model
- Pretend all (or some) of those predictions are ground truth labels
- Train a new model with those predictions
We will get into more of the details in the how-to section!
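To make those steps concrete before we get there, here is a minimal, self-contained sketch of the loop using scikit-learn on toy data. The dataset, model, and 0.9 threshold are just placeholders; the rest of this post does the same thing with fastai on MNIST.
# Minimal sketch of the pseudo labeling loop on toy data (placeholder dataset and model)
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=1000, random_state=0)
X_labeled, y_labeled, X_unlabeled = X[:200], y[:200], X[200:]

model = LogisticRegression().fit(X_labeled, y_labeled)    # 1. create a model
probs = model.predict_proba(X_unlabeled)                  # 2. make predictions on some data
confident = probs.max(axis=1) >= 0.9                      # 3. keep only confident predictions...
pseudo_y = probs.argmax(axis=1)[confident]                #    ...and pretend they are ground truth
X_new = np.concatenate([X_labeled, X_unlabeled[confident]])
y_new = np.concatenate([y_labeled, pseudo_y])
new_model = LogisticRegression().fit(X_new, y_new)        # 4. train a new model with those predictions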
3 Why would I use pseudo labeling?
There are two main uses for pseudo labeling. I will step through each and provide a general overview of when and why it helps.
3.1 Data Cleaning & Noise Reduction
Imagine you have a dataset where all the samples have been hand labeled. You know some of them must be mislabeled because the labeling was manual, and you want to improve your labels. You have a few options:
- Go through every datapoint again and manually verify them all
- Somehow identify the ones that are likely to be wrong and focus your effort on those
Pseudo labeling can help with option number 2. By making a prediction for each datapoint, you can see which labels the model disagrees with. Even better, you can look at how confident the model is in each prediction. By looking at the datapoints where a confident model disagrees with the existing label, you can quickly narrow your focus to the problem areas.
You can then fix them in two ways (a small sketch of both follows this list):
- Replace your labels with the predicted labels above some confidence threshold (e.g. a score of 0.9 or higher).
- Manually re-classify these labels, if you have the time and domain expertise to do so.
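A small sketch of both options, assuming you have a table with the current label, the predicted label, and the model's confidence for each row (the column names and 0.9 threshold here are hypothetical; we build a very similar dataframe in the how-to section below):
import pandas as pd

# Hypothetical example: current label, predicted label, and model confidence per row
df = pd.DataFrame({'y_true': [3, 8, 1], 'y_pred': [3, 0, 1], 'p_max': [0.98, 0.95, 0.61]})

# Rows where a confident model disagrees with the existing label
suspect = df[(df.y_true != df.y_pred) & (df.p_max >= 0.9)]

# Option 1: trust the model above the threshold and overwrite those labels
df.loc[suspect.index, 'y_true'] = df.loc[suspect.index, 'y_pred']

# Option 2: export just these rows for manual review instead of re-checking everything
suspect.to_csv('rows_to_review.csv', index=False)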
3.2 Data Augmentation
This approach can also be used on unlabeled data. Rather than trying to replace bad labels, here the focus is on creating labels for data that has none, such as a Kaggle test set. The reason this can work is that you are teaching the model the structure of the data: even if not all of the pseudo labels are correct, a lot can still be learned.
Think about how you would learn what a new type of object looks like, maybe a type of furniture you had never heard of before. Doing a Google image search for its name and looking at the results is really helpful, even if not all of the images shown are correct.
4 How to use pseudo labeling (Noise Reduction)
I will use predictions on the validation folds of a cross-fold split for this example because it is a little bit more involved. You can simplify this approach a bit if you are doing it on unlabeled data.
4.1 Imports
The imports and dataset download at the top of this post are everything we need: fastai for the data and model, StratifiedKFold from scikit-learn for the cross-fold split, and default_rng from numpy for introducing label noise.
4.2 Introduce Noise to Data
x = get_image_files(path)
y = L(parent_label(o) for o in get_image_files(path))
Get 10% of the indexes to randomly change
n = len(x)
rng = default_rng()

noise_idxs = rng.choice(n, size=round(n*0.1), replace=False)
len(noise_idxs), noise_idxs[:5]
(7000, array([17419, 48844, 61590, 49810, 26348]))
Randomly change these so we have some bad labels
for i in range(0,len(noise_idxs)):
    old_path = str(x[noise_idxs[i]])

    # The hard-coded indices point at the class folder in the path
    # (e.g. .../mnist_png/training/3/12345.png), which we swap for a random digit
    if 'training' in old_path:
        new_path = str(x[noise_idxs[i]])[:49]+f'{np.random.randint(0,10)}'+str(x[noise_idxs[i]])[50:]
    elif 'testing' in old_path:
        new_path = str(x[noise_idxs[i]])[:48]+f'{np.random.randint(0,10)}'+str(x[noise_idxs[i]])[49:]

    os.system(f'mv {old_path} {new_path}')
4.3 Look at Data
Some of our images are now mislabeled, but we don’t know which ones. We could look at every image to find them, but that would take a ton of time. Let’s try to find the mislabeled images and correct them using a pseudo labeling approach.
mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
                  get_items=get_image_files,
                  splitter=RandomSplitter(),
                  get_y=parent_label)
dls = mnist.dataloaders(path,bs=16)
dls.show_batch(max_n=36,figsize=(6,6))
4.4 Create Crossfold, Train, and Predict
This step is much simpler if you are generating labels for the test set, as you would train your model as normal and predict as normal. The reason I am doing cross folds is to get predicted labels on the training set.
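For reference, that simpler variant is just normal training followed by predicting on a test dataloader. A rough sketch, assuming `learn` has already been trained and using `path/'testing'` as a stand-in for whatever unlabeled files you have:
# Sketch of the unlabeled-data variant (assumes a trained `learn`; not part of the cross-fold run below)
test_files = get_image_files(path/'testing')
test_dl = learn.dls.test_dl(test_files)
probs, _ = learn.get_preds(dl=test_dl)
pseudo_labels = L(learn.dls.vocab[int(i)] for i in probs.argmax(dim=1))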
I am doing this with 2 folds, but you may want to use 5 or more folds.
This cross-fold code was mostly supplied by Zach Mueller, with minor modifications by me for this dataset and tutorial. There is also a tutorial he wrote with more details here.
skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=1)
splits, preds, targs, preds_c, items = [],[],[],[],[]

for _, val_idx in skf.split(x,y):
    splitter = IndexSplitter(val_idx)
    splits.append(val_idx)
    mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
                      get_items=get_image_files,
                      splitter=splitter,
                      get_y=parent_label)

    dls = mnist.dataloaders(path,bs=16)
    learn = cnn_learner(dls,resnet18,metrics=accuracy)
    learn.fine_tune(2,reset_opt=True)

    # store predictions
    p, t, c = learn.get_preds(ds_idx=1,with_decoded=True)
    preds.append(p); targs.append(t); preds_c.append(c); items.append(dls.valid.items)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 1.310182 | 1.102497 | 0.722114 | 01:01 |
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.696935 | 0.617736 | 0.886114 | 01:20 |
1 | 0.631840 | 0.570121 | 0.895343 | 01:21 |
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 1.379348 | 1.098212 | 0.715343 | 01:04 |
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.725246 | 0.607686 | 0.888714 | 01:21 |
1 | 0.625833 | 0.557313 | 0.897943 | 01:21 |
4.5 Look at Predictions
Let’s throw it all in a dataframe so we can look at what we have a little more easily. First, let’s break out our different pieces of information.
items_flat = L(itertools.chain.from_iterable(items))
imgs = L(o for o in items_flat)
y_true = L(int(parent_label(o)) for o in items_flat)  # Labels from dataset
y_targ = L(int(o) for o in torch.cat(targs))          # Labels from our predictions
y_pred = L(int(o) for o in torch.cat(preds_c))        # predicted labels or "pseudo labels"
p_max  = torch.cat(preds).max(dim=1)[0]               # max model score for each row
We can double check we are matching things up correctly by checking that the labels line up from the predictions and the original data. Throwing some simple assert statements in is nice because it takes no time and it will let you know if you screw something up later as you are tinkering with things.
assert (y_true == y_targ) # check we matched these up correctly
Put it in a dataframe and see what we have.
res = pd.DataFrame({'imgs':imgs,'y_true':y_true,'y_pred':y_pred,'p_max':p_max})
res.head(5)
| imgs | y_true | y_pred | p_max |
---|---|---|---|---|
0 | /home/isaacflath/.fastai/data/mnist_png/testing/1/8418.png | 1 | 1 | 0.864995 |
1 | /home/isaacflath/.fastai/data/mnist_png/testing/1/2888.png | 1 | 7 | 0.900654 |
2 | /home/isaacflath/.fastai/data/mnist_png/testing/1/6482.png | 1 | 1 | 0.906335 |
3 | /home/isaacflath/.fastai/data/mnist_png/testing/1/7582.png | 1 | 1 | 0.902999 |
4 | /home/isaacflath/.fastai/data/mnist_png/testing/1/4232.png | 1 | 1 | 0.925955 |
Perfect. So let’s get a list of the images our model got ‘wrong’ and grab some random ones out of the 5,000 it was most confident about. The theory is that many of these may be mislabeled, and we can reclassify them either using the predicted ‘pseudo’ labels or with manual classification.
imgs = res[res.y_true != res.y_pred].sort_values('p_max',ascending=False)[:5000].sample(frac=1)
Then we plot them and see that, for these images, our predicted labels are WAY better than the actual labels. This is a great way to identify some bad labels.
%matplotlib inline
fig, ax = plt.subplots(5,5,figsize=(10,10))

for row in range(0,5):
    for col in range(0,5):
        img_path1 = imgs.iloc[row*5+col,0]
        img = np.array(Image.open(img_path1))
        ax[row,col].imshow(img,cmap='Greys')
        ax[row,col].set_title(f'Label:{parent_label(imgs.iloc[row*5+col,0])} | Pred:{imgs.iloc[row*5+col,2]}')
        ax[row,col].get_xaxis().set_visible(False)
        ax[row,col].get_yaxis().set_visible(False)
4.6 What Next?
Now that we have found mislabeled data, we can fix it. We saw that, for this problem, the predicted labels are much better than the original labels among the 5,000 most confident ‘wrong’ answers.
So the next step would be to replace those labels with our predicted labels, then train our model on the newly cleaned labels!
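One way to do that replacement for this dataset, sketched but not run in this post: since the labels live in the folder names, replacing a label just means moving the file into the folder for the predicted class (mirroring the mv trick we used to introduce the noise), then rebuilding the dataloaders and training as before. The 0.9 threshold is a choice you would tune.
# Overwrite labels where a confident model disagrees with the folder label (sketch, not run here)
to_fix = res[(res.y_true != res.y_pred) & (res.p_max >= 0.9)]
for old_path, pred in zip(to_fix.imgs, to_fix.y_pred):
    old_path = Path(old_path)
    new_path = old_path.parent.parent/str(pred)/old_path.name   # same file, predicted class's folder
    os.system(f'mv {old_path} {new_path}')

# Rebuild the dataloaders from the corrected folders and train as usual
mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
                  get_items=get_image_files,
                  splitter=RandomSplitter(),
                  get_y=parent_label)
dls = mnist.dataloaders(path, bs=16)
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.fine_tune(2)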
Note: This same approach can be used on unlabeled data, adding the datapoints the model is confident about to the training data.