Pseudo Labeling for Data Cleaning
Pseudo Labeling basics for Noise Reduction and Data Creation
- Intro
- What is pseudo labeling
- Why would I use pseudo labeling?
- How to use pseudo labeling (Noise Reduction)
What is pseudo labeling
Pseudo Labeling is the process of creating new labels for a piece of data.
The general idea can be broken into a few steps:
- Create a model
- Make predictions on some data with that model
- Pretend all (or some) of those predictions are ground truth labels
- Train a new model with those predictions
We will get into more of the details in the how-to section!
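To make those four steps concrete before we get there, here is a minimal sketch using scikit-learn on synthetic data. Everything here is illustrative: make_classification is just a stand-in for your own dataset, and the 0.9 confidence threshold is an arbitrary choice.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
import numpy as np

X, y = make_classification(n_samples=1000, random_state=0)
labeled, unlabeled = slice(0, 500), slice(500, 1000)   # pretend the second half is unlabeled

# 1. Create a model on the labeled portion
model = LogisticRegression(max_iter=1000).fit(X[labeled], y[labeled])

# 2. Make predictions on the unlabeled portion
probs = model.predict_proba(X[unlabeled])
pseudo = probs.argmax(axis=1)

# 3. Pretend the confident predictions are ground truth labels
confident = probs.max(axis=1) >= 0.9

# 4. Train a new model on the combined data
X_new = np.concatenate([X[labeled], X[unlabeled][confident]])
y_new = np.concatenate([y[labeled], pseudo[confident]])
model2 = LogisticRegression(max_iter=1000).fit(X_new, y_new)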
Why would I use pseudo labeling?
There are two main functions pseudo labeling can be used for. I will step through each and provide a general overview.
Data Cleaning & Noise Reduction
Imagine you have a dataset where all the samples have been hand labeled. You know they can't all be labeled correctly because the labeling was done by hand, and you want to improve your labels. You have a few options:
- Go through every datapoint again and manually verify them all
- Somehow identify the ones that are likely to be wrong and put more focus on those.
Pseudo labeling can help with option 2. By generating a prediction for each datapoint, you can see which labels the model disagrees with. Even better, you can look at the confidence the model has in each prediction. By focusing on datapoints where the model confidently disagrees with the existing label, you can narrow in on your problem areas quickly.
You can then fix these in two ways:
- Replace your labels with the predicted labels, using some threshold (e.g. a score of 0.9 or higher), as shown in the sketch after this list.
- Manually re-classify these labels if you have the time and domain expertise to do so.
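Here is a minimal sketch of the first option. The arrays y_noisy, y_pred, and p_max are made-up stand-ins for your existing labels, the model's predicted labels, and its confidence scores (we build real versions of these later in this post).
import numpy as np

y_noisy = np.array([3, 7, 1, 0, 9])                  # existing (possibly wrong) labels
y_pred  = np.array([3, 2, 1, 0, 4])                  # model's predicted labels
p_max   = np.array([0.99, 0.95, 0.60, 0.98, 0.97])   # model's confidence in each prediction

# replace a label only when the model disagrees with it and is confident
replace = (y_pred != y_noisy) & (p_max >= 0.9)
y_clean = np.where(replace, y_pred, y_noisy)          # -> [3, 2, 1, 0, 4]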
Data Augmentation
This approach can also be used on unlabeled data. Rather than trying to replace bad labels, this approach focuses on creating labels for unlabeled data. It can be used on a Kaggle test set, for example. The reason this works is that you are teaching the model the structure of the data. Even if not all of the labels are correct, a lot can still be learned.
Think about how you would learn what a new type of object looks like, maybe a type of furniture you had never heard of before. Doing a Google image search for its name and looking through the results is really helpful, even if not all of the images shown are correct.
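As a rough illustration of the unlabeled-data case, here is a hedged fastai sketch. It assumes learn is a Learner you have already trained on your labeled data and unlabeled_imgs is a list of image paths without labels; both names are illustrative and not defined anywhere else in this post.
test_dl = learn.dls.test_dl(unlabeled_imgs)                  # dataloader over the unlabeled images
probs, _, decoded = learn.get_preds(dl=test_dl, with_decoded=True)
pseudo_labels = [learn.dls.vocab[int(i)] for i in decoded]   # predicted class names
keep = probs.max(dim=1)[0] >= 0.9                            # keep only confident pseudo labels
The confident pseudo-labeled images can then be added to the training set and a new model trained on the combined data.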
I will use the validation set for this example because it is a little bit more involved. You can simplify this approach a bit if you are doing this on unlabeled data.
from fastai.vision.all import *
from sklearn.model_selection import StratifiedKFold
from numpy.random import default_rng

path = untar_data(URLs.MNIST, force_download=True)

x = get_image_files(path)                               # paths to every image
y = L(parent_label(o) for o in get_image_files(path))   # labels come from the parent folder name
Get 10% of the indexes to randomly change
n = len(x)
rng = default_rng()
noise_idxs = rng.choice(n, size=round(n*0.1), replace=False)
len(noise_idxs),noise_idxs[:5]
Randomly change these so we have some bad labels
import shutil

for i in noise_idxs:
    old_path = x[i]
    # MNIST paths look like .../training/<label>/<file>.png, so moving the file
    # into a random digit's folder changes its label
    new_path = old_path.parent.parent/f'{np.random.randint(0,10)}'/old_path.name
    shutil.move(str(old_path), str(new_path))
Some of our labels are now misclassified, but we don't know which ones. We could look at every image to find them, but that would take a ton of time. Let's try to find the misclassified images and correct them using a pseudo labeling approach.
mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
                  get_items=get_image_files,
                  splitter=RandomSplitter(),
                  get_y=parent_label)
dls = mnist.dataloaders(path,bs=16)
dls.show_batch(max_n=36,figsize=(6,6))
Create Crossfold, Train, and Predict
This step is much simpler if you are generating labels for the test set, as you would train your model as normal and predict as normal. The reason I am doing cross folds is to get predicted labels on the training set.
skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=1)
splits, preds, targs, preds_c, items = [],[],[],[], []
for _, val_idx in skf.split(x,y):
    # use this fold's validation indexes to define the validation set
    splitter = IndexSplitter(val_idx)
    splits.append(val_idx)
    mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
                      get_items=get_image_files,
                      splitter=splitter,
                      get_y=parent_label)
    dls = mnist.dataloaders(path, bs=16)
    learn = cnn_learner(dls, resnet18, metrics=accuracy)
    learn.fine_tune(2, reset_opt=True)
    # store probabilities, targets, decoded predictions, and items for this fold's validation set
    p, t, c = learn.get_preds(ds_idx=1, with_decoded=True)
    preds.append(p); targs.append(t); preds_c.append(c); items.append(dls.valid.items)
items_flat = L(itertools.chain.from_iterable(items))   # all validation items across the folds
imgs = L(o for o in items_flat)
y_true = L(int(parent_label(o)) for o in items_flat)   # labels from the dataset folders
y_targ = L(int(o) for o in torch.cat(targs))           # labels from our predictions' targets
y_pred = L(int(o) for o in torch.cat(preds_c))         # predicted labels, or "pseudo labels"
p_max = torch.cat(preds).max(dim=1)[0]                  # max model score for each row
We can double check we are matching things up correctly by checking that the labels line up between the predictions and the original data. Throwing in some simple assert statements is nice because it takes no time and it will let you know if you screw something up later as you are tinkering with things.
assert (y_true == y_targ) # test we matched these up correctly
Put it in a dataframe and see what we have.
res = pd.DataFrame({'imgs':imgs,'y_true':y_true,'y_pred':y_pred,'p_max':p_max})
res.head(5)
Perfect, so let's get a list of the images our model got 'wrong' and grab some random ones out of the top 5000 the model was most confident about. The theory is that many of these may be misclassified, and we can reclassify them either using the predicted 'pseudo' labels or with manual classification.
imgs = res[res.y_true != res.y_pred].sort_values('p_max',ascending=False)[:5000].sample(frac=1)
Then we plot them and see that our predicted labels for these are way better than the actual labels. This is a great way to identify bad labels.
%matplotlib inline
fig, ax = plt.subplots(5, 5, figsize=(10,10))
for row in range(0,5):
    for col in range(0,5):
        img_path = imgs.iloc[row*5+col, 0]               # path for this grid position
        img = np.array(Image.open(img_path))
        ax[row,col].imshow(img, cmap='Greys')
        ax[row,col].set_title(f'Label:{parent_label(img_path)} | Pred:{imgs.iloc[row*5+col,2]}')
        ax[row,col].get_xaxis().set_visible(False)
        ax[row,col].get_yaxis().set_visible(False)
Now that we have found the mislabeled data, we can fix it. For this problem, within the top 5000 most confident disagreements, our predicted labels are clearly much better than the existing ones.
So the next step would be to replace the labels with our predicted labels, then train our model on the newly cleaned labels!
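Here is one way that last step could look, as a sketch: trust the pseudo labels for the confident disagreements in the imgs dataframe above, move each of those files into the folder matching its predicted label (so parent_label picks up the new label), then rebuild the DataLoaders and retrain.
import shutil

# move each confidently-misclassified image into the folder of its pseudo label
for _, row in imgs.iterrows():
    old_path = Path(row['imgs'])
    new_path = old_path.parent.parent/str(row['y_pred'])/old_path.name
    shutil.move(str(old_path), str(new_path))

# rebuild the DataLoaders on the cleaned labels and train again
mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
                  get_items=get_image_files,
                  splitter=RandomSplitter(),
                  get_y=parent_label)
dls = mnist.dataloaders(path, bs=16)
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.fine_tune(2)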