Artificial Intelligence is the technology that teaches machines to learn the hidden patterns enclosed in the training data. These machines are trained to produce the most accurate predictions possible from the provided data, and they can be evaluated using metrics such as accuracy, recall, and precision. These metrics describe how well the model performs, while the loss function quantifies how far the model's predictions are from the correct answers.
Table of Contents
This guide explains the following sections:
- What is Focal Loss in PyTorch
- Difference Between Focal Loss and Cross Entropy Loss
- How to Calculate Focal Loss in PyTorch
- Prerequisites
- Example 1: Binary Classification
- Example 2: Multi-Class Classification
- Example 3: Multi-Class Sequence Classification
- Conclusion
What is Focal Loss in PyTorch
Many loss functions are used to optimize Machine Learning and Deep Learning (DL) models and to address different problems. The focal loss function is designed to rectify the class imbalance problem, which arises when the dataset is dominated by one class while the other classes are heavily underrepresented. A model trained on such data tends to lean toward the dominant class and learn its patterns while leaving the others unattended, so the resulting model does not generalize well and gives a distorted picture of the data.
The mathematical representation of the focal loss in deep learning models is as follows:

FL(p_t) = -(1 - p_t)^γ * log(p_t)

- (1 - p_t)^γ: the modulating factor that shifts the focus toward the hard examples of the non-dominant classes
- γ (gamma): the focusing parameter that determines how strongly easy samples are down-weighted
- log(p_t): the cross-entropy term giving the logarithmic loss based on the probability of the true class
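To make the formula concrete, here is a minimal from-scratch sketch in plain PyTorch (not the focal_loss_torch module used later in this guide); the probability values and the helper name focal_loss_term are illustrative only:

import torch

def focal_loss_term(p_t, gamma):
    # FL(p_t) = -(1 - p_t)^gamma * log(p_t)
    return -((1 - p_t) ** gamma) * torch.log(p_t)

gamma = 2.0
easy = torch.tensor(0.9)  # confidently correct prediction
hard = torch.tensor(0.1)  # badly misclassified prediction
print(focal_loss_term(easy, gamma))  # ~0.001: the easy sample is heavily down-weighted
print(focal_loss_term(hard, gamma))  # ~1.87: the hard sample keeps almost its full loss

With gamma set to 2, the easy example contributes roughly a thousand times less to the total loss than the hard one, which is exactly the behavior that counters class imbalance.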
Difference Between Focal Loss and Cross Entropy Loss
The focal loss is focused on minimizing the errors caused by class imbalance in the dataset while training DL models. It takes both dominant and non-dominant classes into account by multiplying the cross-entropy term by the modulating factor (1 - p_t)^γ. Plain cross-entropy treats easy and hard samples equally, whereas the focal loss down-weights the easy samples so that training concentrates on the difficult ones.
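One consequence of the formula is that setting gamma to 0 makes the modulating factor equal to 1 and collapses the focal loss to plain cross-entropy. The following sketch (plain PyTorch with made-up probability values) verifies this numerically for the binary case:

import torch
import torch.nn.functional as F

p = torch.tensor([0.8, 0.3, 0.6])        # predicted probabilities for class 1
target = torch.tensor([1.0, 0.0, 1.0])   # ground-truth labels
p_t = torch.where(target == 1, p, 1 - p) # probability assigned to the true class

gamma = 0.0
focal = (-(1 - p_t) ** gamma * torch.log(p_t)).mean()
bce = F.binary_cross_entropy(p, target)
print(focal, bce)  # both print the same value when gamma = 0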
How to Calculate Focal Loss in PyTorch
To calculate the focal loss in PyTorch and optimize a deep learning model, go through the following examples:
Note: The Python code used to implement these examples is attached here
Prerequisites
Before getting through the examples, it is required to complete the following prerequisites:
- Access Python Notebook
- Install Modules
- Import Libraries
Access Python Notebook
The first requirement for implementing the Python code is access to a notebook in which to write it. This guide uses a Google Colab notebook, which can be created by clicking the “New Notebook” button on the official Colab website.
Install Modules
Once the notebook is created, install the focal_loss_torch module using the pip command to get the functions to calculate the focal loss:
pip install focal_loss_torch
Import Libraries
The last prerequisite is importing the torch library by executing the following code:
import torch
Confirm the installation by displaying the installed version, which in our case is 2.1.0+cu118:
print(torch.__version__)
Example 1: Binary Classification
The first example calculates the focal loss for binary classification, which categorizes new observations into one of two classes. Import the FocalLoss class from the installed module and try different values of gamma together with other parameters such as weights and ignore_index. Store the FocalLoss() instance in the criterion variable, which is then used to calculate the focal loss in PyTorch (the last assignment below, with gamma=0, is the one used by the following code):
from focal_loss.focal_loss import FocalLoss
# Basic focal loss with the focusing parameter gamma set to 0.7
criterion = FocalLoss(gamma=0.7)
# Optional per-class weights to further counter class imbalance
weights = torch.FloatTensor([2, 3.2, 0.7])
criterion = FocalLoss(gamma=0.7, weights=weights)
# Optionally skip a given target index when computing the loss
criterion = FocalLoss(gamma=0.7, ignore_index=0)
# gamma=0 removes the modulating factor, reducing focal loss to cross-entropy
criterion = FocalLoss(gamma=0)
Create a tensor of random values sized by the batch variable defined in the first line of the code, along with a tensor of random binary targets. Store the Sigmoid() method in the s variable to map the raw values to probabilities, then call the criterion with these tensors to calculate the value of the loss. At the end, simply print the value of the loss on the screen by calling the name of the variable:
batch = 10
log = torch.randn(batch)                     # random raw model outputs (logits)
target = torch.randint(0, 2, size=(batch,))  # random binary labels
s = torch.nn.Sigmoid()                       # maps logits to probabilities
loss = criterion(s(log), target)
loss
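In a real workflow, the computed loss drives the optimizer. The following sketch shows how the same criterion and sigmoid would fit into a single training step; the one-layer model, optimizer, and feature dimension are hypothetical additions for illustration:

model = torch.nn.Linear(4, 1)  # hypothetical binary classifier over 4 features
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
features = torch.randn(batch, 4)             # reuses batch = 10 from above
labels = torch.randint(0, 2, size=(batch,))

optimizer.zero_grad()
probs = s(model(features).squeeze(-1))  # sigmoid from the example above
loss = criterion(probs, labels)
loss.backward()                         # gradients flow back through the focal loss
optimizer.step()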
Example 2: Multi-Class Classification
The next example gets the loss for a dataset with multiple classes, where the number of classes can be any value greater than 2. Configure the size of the dataset using the batch variable and the number of classes to which the values belong. Store the Softmax() method in the s variable to squash the raw values into probabilities between 0 and 1:
classes = 5
batch = 10
log = torch.randn(batch, classes)                  # random raw outputs per class
target = torch.randint(0, classes, size=(batch,))  # random class labels
s = torch.nn.Softmax(dim=-1)                       # converts outputs to class probabilities
criterion(s(log), target)
Call the criterion variable to compute the value of the loss, which is then displayed as the cell's output.
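If some of the five classes are much rarer than others, the weights parameter shown in Example 1 can also be combined with this multi-class setup; the weight values below are arbitrary and would normally be derived from the class frequencies:

class_weights = torch.FloatTensor([1.0, 2.5, 1.0, 4.0, 1.5])  # one illustrative weight per class
weighted_criterion = FocalLoss(gamma=0.7, weights=class_weights)
weighted_criterion(s(log), target)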
Example 3: Multi-Class Sequence Classification
This example uses multiple classes with sequences in the data, as a max_length dimension is added while creating the tensors in PyTorch. After that, call the criterion variable to invoke the FocalLoss() method and get the value of the loss:
batch_size = 10
max_length = 20
n_class = 5
m = torch.nn.Softmax(dim=-1)  # probabilities over the class dimension
logits = torch.randn(batch_size, max_length, n_class)              # raw outputs per sequence position
target = torch.randint(0, n_class, size=(batch_size, max_length))  # a label for every position
criterion(m(logits), target)
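Sequence batches are usually padded to max_length, and the padded positions should not contribute to the loss. Assuming the label 0 is reserved as the padding index (an illustrative choice), the ignore_index option from Example 1 applies here as well:

pad_criterion = FocalLoss(gamma=0.7, ignore_index=0)  # skip targets labeled 0 (assumed padding id)
padded_target = target.clone()
padded_target[:, 15:] = 0  # pretend the last five positions of every sequence are padding
pad_criterion(m(logits), padded_target)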
That’s all about the process of calculating the focal loss in PyTorch.
Conclusion
To calculate the focal loss in PyTorch, install the focal_loss_torch module and import the FocalLoss class from it, alongside the torch library. The FocalLoss() method calculates the focal loss and can be configured with different values of gamma and other parameters such as weights and ignore_index. The focal loss addresses the class imbalance problem in deep learning models by down-weighting the easy, dominant-class examples. This article has elaborated on how to calculate the focal loss in the PyTorch environment.