Handling Imbalanced Classes with Weighted Loss in PyTorch

When it comes to real-world data collection, we rarely have the luxury of perfectly balanced labelled datasets for training models. Most machine learning algorithms are not robust to imbalanced classes, which leads to less accurate and biased models. There are many approaches we can follow to tackle the imbalanced data problem: we can choose an ML algorithm that is robust to imbalanced data, or we can generate synthetic data to balance the classes.

Neural networks are trained with backpropagation, and a standard loss function gives every sample the same weight when calculating the loss. If the data is imbalanced, the majority classes dominate the loss and the model becomes biased toward them.

Figure: classes A, B, C and D are imbalanced

I faced this issue while experimenting with a computer-vision-based multi-class classification problem. The data was heavily skewed, and some classes had very few samples compared to the majority class. The model was not performing well at all, and I needed to take action to tackle the class imbalance problem.

These were the solutions I considered trying:

  1. Creating synthetic data:
    Creating new synthetic data points is one of the main methods. It is used mostly for numerical data, and in some cases for imagery too, with the help of GANs and image augmentation. As a starting point, I decided not to go with synthetic data generation, since it may introduce abnormal characteristics into my dataset, so I kept that option for later.
  2. Sampling the dataset with balanced classes:
    In this approach, we sample the dataset so that each label has a similar number of samples. For example, say we have a dataset with 3 classes, A, B and C, with 100, 50 and 20 data points respectively. When sampling, we randomly select 20 samples from each of A, B and C and get a dataset of 60 data points.
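The undersampling example above can be sketched in plain Python. This is a minimal illustration, not the author's code: `undersample_indices` is a hypothetical helper, and the labels are simply the toy A/B/C counts from the example.

```python
import random
from collections import Counter, defaultdict


def undersample_indices(labels, seed=0):
    """Hypothetical helper: pick indices so that every class keeps
    as many samples as the smallest (minority) class."""
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)

    n_min = min(len(indices) for indices in by_class.values())
    rng = random.Random(seed)  # seeded for reproducibility

    picked = []
    for indices in by_class.values():
        # Sample without replacement from each class
        picked.extend(rng.sample(indices, n_min))
    return sorted(picked)


# Toy dataset from the example: 100 A, 50 B, 20 C
labels = ["A"] * 100 + ["B"] * 50 + ["C"] * 20
idx = undersample_indices(labels)
print(len(idx))                         # 60
print(Counter(labels[i] for i in idx))  # 20 samples of each class
```

In this toy case 110 of the 170 samples (about 65%) are thrown away, which shows how expensive undersampling becomes when the minority class is small.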

This approach can be a good option if we have very large amounts of data for every class (even the minority classes). In my case, I could not afford to lose a huge portion of my data by sampling every class down to the size of the minority class.

Since neither method suited my case, I used a weighted loss function to train my neural network. Because this is a multi-class classification problem, I used PyTorch's CrossEntropyLoss as my loss function. (You can follow a similar approach for binary classification too; there, BCEWithLogitsLoss offers a pos_weight parameter for weighting the positive class.)
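A minimal sketch of the binary case, using `nn.BCEWithLogitsLoss` with `pos_weight` rather than `nn.BCELoss` (whose `weight` argument rescales individual batch elements, not classes). The labels below are made up (300 negatives, 100 positives), and `pos_weight` is set to the negative/positive ratio, the heuristic suggested in the PyTorch documentation.

```python
import torch
import torch.nn as nn

# Hypothetical binary labels: 300 negatives (0) and 100 positives (1)
targets = torch.cat([torch.zeros(300), torch.ones(100)])

n_pos = targets.sum()
n_neg = targets.numel() - n_pos
# Weight positives by the negative/positive ratio (3.0 here)
pos_weight = (n_neg / n_pos).reshape(1)

# BCEWithLogitsLoss expects raw logits (no sigmoid applied beforehand)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

torch.manual_seed(0)
logits = torch.randn(400)  # stand-in for model outputs
loss = criterion(logits, targets)
print(pos_weight, loss.item())
```

Each positive sample's loss term is multiplied by 3, so the 100 positives contribute as much to the loss as the 300 negatives would at equal per-sample error.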

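import torch
import torch.nn as nn

# class weights for a 6-class multi-class classification problem
# (the weight argument must be a float Tensor with one entry per class)
class_weights = torch.tensor([0.5281, 0.8411, 0.9619, 0.8634, 0.8477, 0.9577], dtype=torch.float)

# loss function with class weights
criterion = nn.CrossEntropyLoss(weight=class_weights)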
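To see what the `weight` argument does, the sketch below compares the weighted loss with a manual computation on random logits and targets (purely illustrative data). With the default `reduction="mean"`, PyTorch divides the weighted sum of per-sample losses by the sum of the target classes' weights, not by the batch size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class_weights = torch.tensor([0.5281, 0.8411, 0.9619, 0.8634, 0.8477, 0.9577])
criterion = nn.CrossEntropyLoss(weight=class_weights)

torch.manual_seed(0)
logits = torch.randn(8, 6)           # batch of 8 samples, 6 classes
targets = torch.randint(0, 6, (8,))  # integer class indices

loss = criterion(logits, targets)

# Manual equivalent: unweighted per-sample loss, scaled by each
# sample's class weight, then normalised by the total weight
per_sample = F.cross_entropy(logits, targets, reduction="none")
w = class_weights[targets]
manual = (w * per_sample).sum() / w.sum()

print(loss.item(), manual.item())  # equal up to floating-point rounding
```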

How did I calculate the weight for each class?

This is quite simple. I calculated a manual rescaling weight for each class and passed it to the “weight” parameter of the loss function. Make sure the class weights are a Tensor whose size equals the number of classes (in simpler words, each class should have a weight).
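The post does not spell out the formula, but the six example weights appear consistent with w_c = 1 - n_c / N (class count divided by total count): the values 1 - w_c add up to about 1.0, and the largest class gets the smallest weight. The sketch below assumes that formula and also shows inverse-frequency weighting, N / (K · n_c), a common alternative that penalises minority-class mistakes more strongly. The helper names and labels are hypothetical.

```python
import torch


def one_minus_frequency(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Assumed scheme: w_c = 1 - n_c / N."""
    counts = torch.bincount(labels, minlength=num_classes).float()
    return 1.0 - counts / counts.sum()


def inverse_frequency(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Alternative scheme: w_c = N / (K * n_c); assumes every class occurs."""
    counts = torch.bincount(labels, minlength=num_classes).float()
    return counts.sum() / (num_classes * counts)


# Hypothetical labels for a 3-class problem: 50 / 30 / 20 samples
labels = torch.tensor([0] * 50 + [1] * 30 + [2] * 20)

print(one_minus_frequency(labels, 3))  # approx. [0.5, 0.7, 0.8]
print(inverse_frequency(labels, 3))    # approx. [0.667, 1.111, 1.667]
```

Either tensor can be passed directly as `weight=` to `nn.CrossEntropyLoss`. The first scheme keeps all weights below 1 and changes the balance only mildly; the second makes each class contribute roughly equally to the total loss.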

Hint: If you are using a GPU for model training, make sure to move your class weights tensor to the GPU too.
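A minimal sketch of that hint, picking the device at runtime. The tiny `nn.Linear` model and random batch are placeholders, not the author's network. One option is to create the weights directly on the device; alternatively, because the loss module registers `weight` as a buffer, calling `.to(device)` on the criterion moves the weights as well.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Option 1: create the class weights directly on the training device
class_weights = torch.tensor(
    [0.5281, 0.8411, 0.9619, 0.8634, 0.8477, 0.9577], device=device
)
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Option 2: build the loss with CPU weights, then move the whole module;
# `weight` is a registered buffer, so .to() moves it too
criterion_alt = nn.CrossEntropyLoss(weight=class_weights.cpu()).to(device)

# Placeholder model and batch, all on the same device as the weights
model = nn.Linear(10, 6).to(device)
x = torch.randn(4, 10, device=device)
y = torch.randint(0, 6, (4,), device=device)

loss = criterion(model(x), y)
loss.backward()
```

If the weights stay on the CPU while logits and targets are on the GPU, PyTorch raises a device-mismatch error when the loss is computed.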

Did it work? Hell yeah! With this simple trick I was able to train my model accurately, with less bias and without overfitting to a single class. Let me know about any other tricks you use for training neural network models on imbalanced data.

Happy coding 🙂
