Handling corrupted data in PyTorch DataLoader

Recently, while working on a video dataset, I noticed that some of the videos contained a few corrupted frames. During training, my dataloader returned tensors of the wrong shape whenever it could not read a corrupted frame, and batching then failed because tensors of different shapes cannot be stacked together. In this post, I will walk you through using PyTorch’s collate_fn to overcome this issue. I came across this solution in a GitHub comment posted by Tsveti Iko.

First, return None from the dataset’s __getitem__ method for any corrupted item. For example, in my video dataset I check the shape of the frames tensor and return None if it doesn’t match the expected shape of (30, 3, 224, 224): 30 frames, 3 colour channels and 224x224 pixels.

from torch.utils.data import Dataset
import torch
import torchvision

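class VideoDataset(Dataset):

    ...  # __init__ is assumed to set self.video_list (file paths) and self.labels

    def __getitem__(self, idx):
        label = torch.tensor(self.labels[idx])

        # read_video returns (video_frames, audio_frames, info);
        # the frames tensor is laid out as (T, H, W, C) by default
        vid, _, _ = torchvision.io.read_video(filename=self.video_list[idx])
        vid = vid.permute(0, 3, 1, 2)  # reorder to (T, C, H, W)

        # A corrupted frame leaves the clip with an unexpected shape
        if vid.shape != (30, 3, 224, 224):
            return None

        return vid, label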
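To try the None-returning pattern without real video files, here is a minimal sketch using a hypothetical SyntheticVideoDataset (not part of the original post). It fabricates all-zero clips and simulates a corrupted clip by giving it 29 frames instead of 30, so the same shape check rejects it.

```python
import torch
from torch.utils.data import Dataset


class SyntheticVideoDataset(Dataset):
    """Hypothetical stand-in for VideoDataset: indices in `corrupted` are broken clips."""

    EXPECTED_SHAPE = (30, 3, 224, 224)  # (frames, channels, height, width)

    def __init__(self, num_items=6, corrupted=(2,)):
        self.num_items = num_items
        self.corrupted = set(corrupted)

    def __len__(self):
        return self.num_items

    def __getitem__(self, idx):
        # Simulate a corrupted clip by dropping one frame (29 instead of 30)
        num_frames = 29 if idx in self.corrupted else 30
        vid = torch.zeros(num_frames, 3, 224, 224, dtype=torch.uint8)
        label = torch.tensor(idx % 2)

        # Same check as in the post: signal a bad sample with None
        if vid.shape != self.EXPECTED_SHAPE:
            return None
        return vid, label


ds = SyntheticVideoDataset()
print(ds[2])           # None
print(ds[0][0].shape)  # torch.Size([30, 3, 224, 224])
```

torch.Size is a subclass of tuple, which is why comparing vid.shape against a plain tuple works here.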

Next, define a collate function that filters out the None records and passes the remaining samples to PyTorch’s default_collate.

import torch

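def collate_fn(batch):
    # DataLoader calls collate_fn with one argument: the list of samples
    # Drop the samples that __getitem__ returned as None (corrupted items)
    batch = list(filter(lambda x: x is not None, batch))
    # Collate the remaining samples exactly as the default DataLoader would
    return torch.utils.data.dataloader.default_collate(batch)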
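The filtering step can be checked in isolation by calling the collate function directly on a hand-built batch that contains one None entry. This sketch makes one assumption that differs from the version above: if every sample in a batch is corrupted, default_collate would raise on the empty list (it reads the first element to decide how to collate), so here the function returns None and leaves it to the training loop to skip such batches.

```python
import torch
from torch.utils.data.dataloader import default_collate


def collate_fn(batch):
    # Drop the samples that __getitem__ flagged as corrupted
    batch = [item for item in batch if item is not None]
    # Assumption for this sketch: return None for an all-corrupted batch,
    # because default_collate cannot collate an empty list
    if not batch:
        return None
    return default_collate(batch)


# A hand-built batch of (frames, label) pairs with one corrupted entry
batch = [
    (torch.zeros(30, 3, 224, 224, dtype=torch.uint8), torch.tensor(0)),
    None,
    (torch.ones(30, 3, 224, 224, dtype=torch.uint8), torch.tensor(1)),
]
frames, labels = collate_fn(batch)
print(frames.shape)              # torch.Size([2, 30, 3, 224, 224])
print(labels)                    # tensor([0, 1])
print(collate_fn([None, None]))  # None
```

default_collate stacks each field of the surviving tuples along a new leading batch dimension, so the two valid clips become one tensor of shape (2, 30, 3, 224, 224).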

Finally, use the collate_fn while defining the dataloader.

from torch.utils.data import DataLoader
import os

dataset = VideoDataset()

dataloader = DataLoader(dataset, 
    batch_size=4, 
    shuffle=True, 
    num_workers=os.cpu_count() - 1, 
    pin_memory=True,
    collate_fn=collate_fn)
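Putting the pieces together, the sketch below runs a DataLoader over a hypothetical synthetic dataset in which items 1, 5 and 6 are corrupted. It illustrates the main side effect of this approach: a batch simply shrinks by the number of dropped items, so some batches are smaller than batch_size. shuffle=False and num_workers=0 are chosen only to make the output reproducible; they are not the settings from the post.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate


class SyntheticVideoDataset(Dataset):
    """Hypothetical dataset: every index in `corrupted` returns None."""

    def __init__(self, num_items=8, corrupted=(1, 5, 6)):
        self.num_items = num_items
        self.corrupted = set(corrupted)

    def __len__(self):
        return self.num_items

    def __getitem__(self, idx):
        if idx in self.corrupted:
            return None
        return torch.zeros(30, 3, 224, 224, dtype=torch.uint8), torch.tensor(idx)


def collate_fn(batch):
    # Filter out corrupted samples; None if nothing usable is left
    batch = [item for item in batch if item is not None]
    return default_collate(batch) if batch else None


loader = DataLoader(
    SyntheticVideoDataset(),
    batch_size=4,
    shuffle=False,  # deterministic order for the demo
    num_workers=0,  # load in the main process for the demo
    collate_fn=collate_fn,
)

batch_sizes = []
for batch in loader:
    if batch is None:  # every sample in this batch was corrupted
        continue
    frames, labels = batch
    batch_sizes.append(frames.shape[0])
    print(frames.shape[0], labels.tolist())
```

This prints `3 [0, 2, 3]` and then `2 [4, 7]`. Keeping collate_fn as a module-level function (rather than a lambda or a nested function) also matters once num_workers > 0: with the spawn start method (the default on Windows and macOS), the collate function is pickled and sent to each worker process, and lambdas cannot be pickled.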

Vivek Maskara
SDE @ Remitly | Graduated from MS CS @ ASU | Ex-Morgan, Amazon, Zeta