Table of contents
1. Introduction
2. Understanding Transfer Learning
3. Scenarios in Transfer Learning
3.1. Feature extractor using a pre-trained model
3.2. Full Fine-Tuning
3.3. Transfer Learning with Domain Adaptation
4. Example to Understand Transfer Learning
4.1. Import Libraries
4.2. Training and Test Datasets
4.3. Loading the Pre-trained ResNet-18 Model
4.4. Defining Loss Function and Optimizer
4.5. Training the Model
4.6. Evaluation of Model
5. Frequently Asked Questions
5.1. What is transfer learning, and why is it important in Deep Learning?
5.2. How does transfer learning work in PyTorch with pre-trained models?
5.3. What are the benefits of utilizing PyTorch for transfer learning?
5.4. What is the importance of data preparation and augmentation in transfer learning?
5.5. What is fine-tuning?
6. Conclusion
Last Updated: Mar 27, 2024
Medium

Transfer Learning using PyTorch

Author Ayush Mishra

Introduction

PyTorch has become a favorite among researchers, students, and practitioners because of its ease of use, versatility, and expressive syntax. With PyTorch, you can build and train complex deep learning models and apply them to real-world problems across many domains.


In this blog, we will discuss transfer learning using PyTorch. Let's get started!

Understanding Transfer Learning

Transfer learning refers to using a pre-trained neural network model as a starting point for a separate but related task. The pre-trained model is usually trained on a large dataset for a specific goal (for instance, image classification on ImageNet), and the features it has learned can be reused as a starting point for solving other problems.

The torchvision.models module in PyTorch offers a variety of pre-trained models, making it simple to access cutting-edge architectures that have been trained on massive datasets. 

Scenarios in Transfer Learning

The three major transfer learning scenarios are as follows:

Feature extractor using a pre-trained model

In this scenario, the pre-trained model serves as a fixed feature extractor, and only its final layers (such as the classifier) are replaced and trained for the target task.

Only the weights of the new layers are changed during training; the weights of the pre-trained model are frozen.

This method works well when the target task's dataset is small and the pre-trained model has learned general features that transfer to the new task.

Full Fine-Tuning

In this scenario, a pre-trained model serves as the starting point, and the entire model, including both the pre-trained layers and any newly added layers, is fine-tuned on the target task.

The pre-trained model's weights are modified throughout training, enabling the model to adjust to the unique characteristics of the current task.

This method works best when the target task differs significantly from the source task or when the target task's dataset is moderate to large.

Transfer Learning with Domain Adaptation

This scenario applies when the source and target tasks are closely related but the data distributions in the source and target domains differ (a domain shift).

The goal is to align the feature distributions of the two domains so that the model generalizes better to the target domain.

Domain adaptation can be achieved with methods such as domain adversarial training or domain-specific layers.

This technique is helpful when the target task is related to the source task but its data comes from a different domain.
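Domain adversarial training can be sketched with a gradient reversal layer, the approach used by DANN (domain-adversarial neural networks). This is a minimal, illustrative sketch rather than a complete method: the small linear feature extractor, the 20-dimensional random inputs, and the reversal strength of 1.0 are hypothetical stand-ins; in practice the feature extractor would be a pre-trained backbone such as ResNet-18.

```python
import torch
import torch.nn as nn


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by -lambd in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient is needed for lambd, hence the trailing None
        return -ctx.lambd * grad_output, None


torch.manual_seed(0)
features = nn.Sequential(nn.Linear(20, 32), nn.ReLU())  # stand-in for a pre-trained backbone
label_head = nn.Linear(32, 10)   # task classifier, trained on labeled source data
domain_head = nn.Linear(32, 2)   # domain classifier: source (0) vs. target (1)
params = [*features.parameters(), *label_head.parameters(), *domain_head.parameters()]
optimizer = torch.optim.SGD(params, lr=0.01)
criterion = nn.CrossEntropyLoss()

# Hypothetical batches: labeled source samples and unlabeled target samples
xs, ys = torch.randn(8, 20), torch.randint(0, 10, (8,))
xt = torch.randn(8, 20)

fs, ft = features(xs), features(xt)
task_loss = criterion(label_head(fs), ys)  # only the source domain has labels

# The domain head sees reversed gradients flowing back into the feature extractor
domain_logits = domain_head(GradReverse.apply(torch.cat([fs, ft]), 1.0))
domain_labels = torch.cat([torch.zeros(8, dtype=torch.long), torch.ones(8, dtype=torch.long)])
domain_loss = criterion(domain_logits, domain_labels)

# One training step on the combined objective
optimizer.zero_grad()
(task_loss + domain_loss).backward()
optimizer.step()
print(f"task loss {task_loss.item():.4f}, domain loss {domain_loss.item():.4f}")
```

The domain classifier learns to tell the two domains apart, while the reversed gradient pushes the feature extractor toward features that make the domains hard to distinguish.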

In the next section, we will understand transfer learning in PyTorch through an example.

Example to Understand Transfer Learning

In this section, we will walk through an example of transfer learning using PyTorch. We'll use a ResNet-18 model already trained on the ImageNet dataset and fine-tune it to classify images from the CIFAR-10 dataset.

Import Libraries

In this "Transfer learning using PyTorch" section, we will import the necessary libraries and set the device to the GPU if one is available, otherwise to the CPU.

Code

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

# Set device (GPU if available, otherwise CPU); device type strings are lowercase
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



Training and Test Datasets

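In this section of "Transfer learning using PyTorch," we will define the data transformations: random cropping and horizontal flipping as data augmentation for the training images, and tensor conversion plus normalization (with the ImageNet statistics the pre-trained weights expect) for both training and test images. After that, we will load the CIFAR-10 training and test sets.

Code

# Normalize with the ImageNet mean/std that the pre-trained ResNet-18 weights expect
normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

# Training images: random crop (with 4-pixel padding) and horizontal flip as augmentation
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
# Test images: no augmentation, only tensor conversion and normalization
test_transform = transforms.Compose([transforms.ToTensor(), normalize])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)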



Loading the Pre-trained ResNet-18 Model

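In this section of "Transfer Learning using PyTorch," we will load the pre-trained ResNet-18 model and modify it for the CIFAR-10 classification task by replacing its final fully connected layer (fc), which outputs 1,000 ImageNet classes, with a new layer that outputs 10 classes.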

Code

# Load ResNet-18 with ImageNet weights (torchvision >= 0.13 weights API; replaces the deprecated pretrained=True)
pretrained_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Replace the final fully connected layer for CIFAR-10 classification (10 classes)
num_classes = 10
in_features = pretrained_model.fc.in_features
pretrained_model.fc = nn.Linear(in_features, num_classes)

# Send the model to the device (GPU or CPU)
pretrained_model = pretrained_model.to(device)



Defining Loss Function and Optimizer

In this section of "Transfer Learning using PyTorch," we will use cross-entropy loss for classification and stochastic gradient descent (SGD) with momentum as the optimizer.
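For reference, the PyTorch documentation for torch.optim.SGD describes the momentum update as follows, assuming the default dampening of 0 and no Nesterov momentum or weight decay. Here g_t is the gradient of the loss at step t, μ is the momentum (0.9 in this example), γ is the learning rate (0.001 in this example), and the buffer starts as b_1 = g_1.

```latex
\begin{aligned}
g_t &= \nabla_{\theta} \mathcal{L}(\theta_{t-1}) \\
b_t &= \mu \, b_{t-1} + g_t \\
\theta_t &= \theta_{t-1} - \gamma \, b_t
\end{aligned}
```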

Code

# Define loss function (CrossEntropyLoss for classification)
criterion = nn.CrossEntropyLoss()

# Define optimizer (stochastic gradient descent with momentum) over all model parameters
optimizer = optim.SGD(pretrained_model.parameters(), lr=0.001, momentum=0.9)
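To make the loss concrete: nn.CrossEntropyLoss takes raw, unnormalized scores (logits) and applies log-softmax internally, which is why the ResNet-18 model has no softmax layer at the end. The small worked example below uses made-up logits for three classes and checks the library value against the formula -log(softmax(logits)[target]).

```python
import math

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# One sample with three class scores (logits); the true class is index 0
logits = torch.tensor([[2.0, 0.0, 0.0]])
target = torch.tensor([0])
loss = criterion(logits, target)

# Same value by hand: -log(e^2 / (e^2 + e^0 + e^0))
manual = -math.log(math.exp(2.0) / (math.exp(2.0) + 2.0))
print(round(loss.item(), 4), round(manual, 4))  # 0.2395 0.2395
```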

Training the Model

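In this "Transfer Learning using PyTorch" section, we will train the modified model for two epochs. Because the optimizer receives all of the model's parameters, every layer is updated, which corresponds to the full fine-tuning scenario described earlier.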

Code

# Training loop
num_epochs = 2

for epoch in range(num_epochs):
    pretrained_model.train()  # training mode: batch norm uses per-batch statistics
    running_loss = 0.0

    for i, (inputs, labels) in enumerate(trainloader, 0):
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = pretrained_model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Print the average loss over every 100 mini-batches
        running_loss += loss.item()
        if i % 100 == 99:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Batch [{i + 1}/{len(trainloader)}], Loss: {running_loss / 100:.4f}")
            running_loss = 0.0


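The call to optimizer.zero_grad() at the start of each iteration matters because PyTorch accumulates gradients: every backward() call adds to the .grad already stored on each parameter. This tiny sketch with a single scalar parameter shows the accumulation and the reset. In PyTorch 2.0 and later, zero_grad() resets gradients by setting them to None by default, which the last step imitates by hand.

```python
import torch

w = torch.tensor(3.0, requires_grad=True)

(w * 2).backward()   # d(2w)/dw = 2
print(w.grad)        # tensor(2.)
(w * 2).backward()   # a second backward pass adds to the stored gradient
print(w.grad)        # tensor(4.)

# Clearing the gradient, as optimizer.zero_grad() does for every parameter it manages
w.grad = None
(w * 2).backward()
print(w.grad)        # tensor(2.)
```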

Evaluation of Model

In this "Transfer Learning using PyTorch" section, we will evaluate the fine-tuned model on the test data and report its classification accuracy.

Code

# Evaluation
pretrained_model.eval()  # inference mode: batch norm uses its running statistics
correct = 0
total = 0

with torch.no_grad():  # gradients are not needed for evaluation
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = pretrained_model(inputs)
        _, predicted = torch.max(outputs, 1)  # index of the highest logit = predicted class
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy: {100 * correct / total:.2f}%")


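The accuracy computation relies on torch.max(outputs, 1), which returns both the maximum value and its index along the class dimension; the index is the predicted class. The sketch below runs the same bookkeeping on a made-up batch of three images and four classes.

```python
import torch

# Hypothetical logits for a batch of 3 images and 4 classes
outputs = torch.tensor([[0.1, 2.0, -1.0, 0.3],
                        [1.5, 0.3, 0.2, 0.0],
                        [-0.5, 0.1, 0.4, 2.2]])
labels = torch.tensor([1, 2, 3])

# Maximum value and its index along the class dimension (dim=1)
values, predicted = torch.max(outputs, 1)
print(predicted)  # tensor([1, 0, 3])

correct = (predicted == labels).sum().item()
total = labels.size(0)
print(f"Accuracy: {100 * correct / total:.2f}%")  # Accuracy: 66.67%
```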

Frequently Asked Questions

What is transfer learning, and why is it important in Deep Learning?

Transfer learning is an important deep learning technique that reuses knowledge gained on one task to improve performance on a related task. Starting from a pre-trained model usually makes training faster and more accurate, especially when the labeled dataset for the new task is small.

How does transfer learning work in PyTorch with pre-trained models?

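Through its torchvision library, PyTorch provides models such as ResNet, VGG, and MobileNet that have been pre-trained on large datasets like ImageNet and have already learned general patterns from them. These models can be customized by replacing their last layer(s) so that the output matches the number of classes in the target dataset.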

What are the benefits of utilizing PyTorch for transfer learning?

PyTorch is a popular choice for transfer learning because of its flexibility, ease of use, and strong community support. It enables researchers and developers to fine-tune pre-trained models efficiently for a wide variety of tasks.

What is the importance of data preparation and augmentation in transfer learning?

Data preprocessing is crucial in transfer learning because it puts the inputs into the format the pre-trained model expects (for example, tensors normalized with the same statistics used during pre-training). Data augmentation increases the effective size and diversity of the training dataset, which improves generalization and model performance.

What is fine-tuning?

Fine-tuning continues training a pre-trained model on the target dataset and updates its weights, allowing the model to adapt to the specifics of the new data while retaining much of the knowledge gained during pre-training. It is especially useful when the target task is similar to the original one and labeled data is limited.

Conclusion

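Transfer learning is an approach in which the knowledge a model gains from training on one task is applied to another task. It starts from a pre-trained model that was trained on a large dataset, such as ImageNet, and adapts it to the new task, either as a fixed feature extractor or through fine-tuning.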

We hope this blog has helped you learn about transfer learning using PyTorch. Do not stop learning! We recommend reading some of our related articles on PyTorch.

 


We wish you Good Luck! 

Happy Learning!
