Table of contents

1. Introduction
2. Pre-Requisites
3. What is PyTorch?
4. What is PyTorch Lightning?
5. Basic Terminologies
6. Necessary Installations and Imports
7. Migrating from PyTorch to PyTorch Lightning
8. General Structure of PyTorch Lightning Code
9. Example
9.1. PyTorch Model
9.2. Training the Model
9.3. Visualizing The Model
9.4. PyTorch Lightning Model
9.5. DataModule in PyTorch Lightning
9.6. The Trainer Class
10. Benefits of PyTorch Lightning
11. Frequently Asked Questions
11.1. What is PyTorch?
11.2. What is PyTorch Lightning?
11.3. Why should we migrate from PyTorch to PyTorch Lightning?
11.4. What is the Trainer class in PyTorch Lightning?
11.5. What is a Learning rate Scheduler?
12. Conclusion
How to Migrate from PyTorch to PyTorch Lightning

Author: Aditya Gupta
Last Updated: Mar 27, 2024

Introduction

Hey Ninjas! In the world of Deep Learning, you may have struggled with the boilerplate of building neural networks and handling large datasets in PyTorch. If that's the case, then PyTorch Lightning is for you.


It simplifies many complex tasks and makes it easy for you to train models. In this blog, we will learn about PyTorch Lightning and how to migrate from PyTorch to PyTorch Lightning.

Pre-Requisites

Before diving into the blog, let us discuss the pre-requisites for a better understanding of what follows.

  • Understanding of Python Programming language.
     
  • Understanding of Deep Learning models.
     
  • Keen interest in Machine Learning.
     
  • Basic understanding of Neural Networks.
     
  • Core Concepts of PyTorch and PyTorch Lightning.
     
  • Understanding of basic terminologies in building models.

What is PyTorch?


PyTorch is an open-source, Python-based framework used to create and modify neural networks. It also ships with collections of pre-built neural network modules that can be combined into different architectures.

It also provides a tensor computation library that performs efficient calculations on large datasets, thanks to its support for GPU acceleration.
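To get a feel for this, here is a minimal sketch (assuming only that PyTorch is installed) of tensor computation with optional GPU acceleration:

import torch

# Two random tensors multiplied on the CPU
a = torch.randn(3, 4)
b = torch.randn(4, 2)
c = a @ b  # matrix multiplication, result has shape (3, 2)

# Move the same computation to the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
c_dev = a.to(device) @ b.to(device)
print(c.shape, c_dev.device)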

What is PyTorch Lightning?


So Ninjas, you must be wondering what PyTorch Lightning is. It is an open-source Python library for training deep learning models built with PyTorch. It takes care of the repetitive boilerplate of training so that we can focus on the parts of the model that actually need our attention.

We should not confuse PyTorch Lightning with a new framework. It is a high-level interface on top of PyTorch that organises code so that runs produce the same results every time, making experiments more reproducible.

Basic Terminologies

Before learning about migrating PyTorch models to PyTorch Lightning models, let us discuss the basic terminologies in PyTorch and PyTorch Lightning that we will use in this blog.

  • LightningModule: It encloses the Neural network architecture and defines the forward method, loss function, and any custom methods.
     
  • LightningDataModule: It is used to handle data loading and preprocessing. It gives us a consistent way to prepare data for training, validation, and testing.
     
  • Trainer: In PyTorch Lightning, the Trainer class manages the training process.
     
  • Callbacks: These are hooks that let us customise the training process, for example to save checkpoints or stop training early.
     
  • Forward Pass: The computation that takes an input through the network's layers to produce an output.
     
  • Optimizers: PyTorch contains various optimizers, which are used to train neural networks. We will use the Adam optimizer in this blog because it adapts the learning rate for each parameter using estimates of the gradient's momentum.
     
  • Learning rate schedulers: Learning rate schedulers adjust the learning rate (the step size of each weight update) during training, controlling how quickly or slowly the model learns.
     
  • Epoch: One complete pass in which every sample in the training dataset has been used once to update the model.
     
  • Cross-entropy: A loss function that measures the gap between the model's predicted class probabilities and the true labels. (A short sketch right after this list ties these terms together.)
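Here is that sketch: a toy one-layer network on dummy tensors (purely illustrative, not the model we build later) showing an optimizer, a learning rate scheduler, the forward pass, and the cross-entropy loss working together:

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(4, 3)  # a toy "network": 4 features in, 3 classes out
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

x = torch.randn(8, 4)          # a dummy batch of 8 samples
y = torch.randint(0, 3, (8,))  # dummy integer class labels

logits = model(x)                  # forward pass
loss = F.cross_entropy(logits, y)  # error between prediction and truth
loss.backward()                    # backward pass: compute gradients
optimizer.step()                   # update the weights
scheduler.step()                   # advance the learning rate schedule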

Necessary Installations and Imports

First, we install PyTorch Lightning and import the necessary libraries. We will import other libraries later in this blog as we need them.

Installing PyTorch Lightning

!pip install pytorch-lightning --quiet


While using the conda package manager, use the below command.

conda install pytorch-lightning -c conda-forge


Necessary Imports

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

Migrating from PyTorch to PyTorch Lightning

Let us look at the below steps to migrate from PyTorch to PyTorch Lightning.

Step 1: First, we take our existing PyTorch model and its forward pass; these can be reused in PyTorch Lightning unchanged.

Step 2: To build the training logic in PyTorch Lightning, we subclass LightningModule and implement its training_step method, which defines how the model processes a single batch of data (a compact before/after sketch appears after Step 6).

Step 3: In PyTorch, optimizers and learning rate schedulers live in the main training script; in PyTorch Lightning, we move them into the configure_optimizers method of the LightningModule.

Step 4: To validate the model during training, we implement the validation_step method in our LightningModule subclass.

Step 5: To test the model after training, we implement the test_step method.

Step 6: In PyTorch, we manually run the training loop, but in PyTorch Lightning, we will use the Trainer class to handle the training process.
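To make Steps 2 and 3 concrete before the full example, here is a rough before/after sketch (the class name and layer sizes here are illustrative only):

# --- Plain PyTorch: you drive the loop yourself ---
# for x, y in train_loader:
#     optimizer.zero_grad()
#     loss = F.cross_entropy(model(x), y)
#     loss.backward()
#     optimizer.step()

# --- PyTorch Lightning: the same logic, relocated ---
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(784, 10)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):  # Step 2: work on one batch
        x, y = batch
        # Lightning calls zero_grad(), backward() and step() for us
        return F.cross_entropy(self(x.view(-1, 784)), y)

    def configure_optimizers(self):  # Step 3: the optimizer moves here
        return torch.optim.Adam(self.parameters(), lr=0.001)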

So, Ninjas, these are the general steps to migrate PyTorch code into PyTorch Lightning code. If the steps still feel abstract, the general structure below and the worked example that follows will make them concrete.

General Structure of PyTorch Lightning Code

Let us look at the general structure of PyTorch Lightning code to have a gist of the code we write in PyTorch Lightning.

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Define the model architecture here
        ...

    def forward(self, x):
        # Define the forward pass of the model
        ...

    def training_step(self, batch, batch_idx):
        # Define the training step logic here
        ...

    def validation_step(self, batch, batch_idx):
        # Define the validation step logic here
        ...

    def test_step(self, batch, batch_idx):
        # Define the test step logic here
        ...

    def configure_optimizers(self):
        # Define your optimizer and learning rate scheduler here
        ...

# Training the model
trainer = pl.Trainer(max_epochs=10)
trainer.fit(MyModel(), train_data, val_data)

Example

Let us look at a simple PyTorch model and convert it into PyTorch Lightning.

PyTorch Model 

Let us create a PyTorch model using the MNIST (Modified National Institute of Standards and Technology) dataset of handwritten digits.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Define the model
class MNISTClassifier(nn.Module):
    def __init__(self):
        super(MNISTClassifier, self).__init__()
        self.fc1 = nn.Linear(784, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        return x

# Define the training data
train_data = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        root='./data',
        train=True,
        download=True,
        transform=transforms.ToTensor()
    ),
    batch_size=32,
    shuffle=True
)

# Define the validation data
val_data = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        root='./data',
        train=False,
        download=True,
        transform=transforms.ToTensor()
    ),
    batch_size=32,
    shuffle=False
)

# Create an instance of the model
model = MNISTClassifier()

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Define the learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)


Output

[Output screenshot]

Explanation

This is a simple neural network classifier for the MNIST handwritten digit dataset.

Here is a step-by-step explanation of the model.

  • First, the model takes as input a 28x28 pixel grayscale image of a handwritten digit.
     
  • We have a single fully connected layer with 784 input features and 10 output features in this model. These input features are the flattened version of the input image (28x28 = 784), and the output features are the probabilities of the input image belonging to digits 0 to 9.
     
  • In the forward pass, the input image is reshaped into a 2D tensor and passed through the fully connected layer.
     
  • The output of the model is a tensor of shape (batch_size, 10), where batch_size is the number of input images.
     
  • We use the Adam optimizer (optim.Adam) to update the parameters of the model during training and set the learning rate for the optimizer to 0.001.
     
  • The learning rate scheduler adjusts the learning rate during training. Here, StepLR multiplies the learning rate by 0.1 after every 10 epochs, where each scheduler.step() call counts as one step (a short sketch after this list shows the effect).
     
  • The training and validation data loaders load the MNIST dataset for training and validation.
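Here is that sketch: a hypothetical, stand-alone snippet (separate from our model) that shows how StepLR changes the learning rate over 20 epochs:

import torch

# A throwaway parameter just so the optimizer has something to manage
param = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([param], lr=0.001)
sch = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.1)

for epoch in range(1, 21):
    if epoch in (1, 10, 11, 20):
        print(f"epoch {epoch}: lr = {opt.param_groups[0]['lr']}")
    # ... one epoch of training would run here ...
    opt.step()  # the optimizer steps during the epoch
    sch.step()  # one scheduler step at the end of each epoch
# lr is 0.001 through epoch 10, then 0.0001 from epoch 11 onward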

Training the Model

After creating the PyTorch model, we train it and track the average loss for each epoch.

import torch.nn.functional as F

# Defining the number of training epochs
num_epochs = 10

# Training loop
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    total_loss = 0.0
    
    for inputs, labels in train_data:
        optimizer.zero_grad()  # Zero the gradients
        
        # Forward pass
        outputs = model(inputs)
        
        # Compute the loss
        loss = F.cross_entropy(outputs, labels)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    # Computing the average loss for the epoch
    avg_loss = total_loss / len(train_data)
    
    # Printing the average loss for the epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")
    
    # Update the learning rate
    scheduler.step()


Output

[Output screenshot: per-epoch average loss]

Explanation

In this code snippet, we iterate over the training dataset for 10 epochs. In each epoch, we set the model to training mode, zero the gradients for every batch, run the forward pass, compute the cross-entropy loss, backpropagate, and let the optimizer update the weights. We then print the average loss for the epoch and step the learning rate scheduler. Notice, however, that this loop never touches the val_data loader we built earlier; in plain PyTorch, the validation pass is also written by hand, as shown below.
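A minimal sketch of that validation pass, reusing model, val_data, and F from the code above:

# Validation pass in plain PyTorch: evaluation mode, no gradient tracking
model.eval()
val_loss, correct, total = 0.0, 0, 0
with torch.no_grad():
    for inputs, labels in val_data:
        outputs = model(inputs)
        val_loss += F.cross_entropy(outputs, labels).item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
print(f"Validation Loss: {val_loss / len(val_data):.4f}, Accuracy: {correct / total:.4f}")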

Visualizing The Model

We will visualize the model using the torchviz and graphviz libraries in Python. First, we install them if they are not already installed.

!apt-get install graphviz -y
!pip install torchviz
import torch
from torchviz import make_dot

# Creating an instance of the model
model = MNISTClassifier()

# Generating a dummy input
dummy_input = torch.zeros(1, 784)

# Passing the dummy input through the model
output = model(dummy_input)

# Visualizing the computation graph
dot = make_dot(output, params=dict(model.named_parameters()))
dot.render("model_graph")  # Save the graph to a file


This saves the computation graph to model_graph.pdf.

[Figure: computation graph of the model]

Explanation

To visualize the model, we made a dummy tensor of shape (1, 784), where 1 is the batch size and 784 is the number of input features.

Now let us see the above graph and understand what each term represents.

  • fc1.weight is the weight tensor of the first fully connected layer (fc1), of shape (10, 784), where 10 is the number of output neurons in fc1 and 784 is the number of input features we sent in the input tensor.
     
  • fc1.bias is the bias tensor of the fc1 layer, of shape (10,), one bias term per output neuron.
     
  • The AccumulateGrad node in this graph represents the collection of gradients during the backward pass in our model.
     
  • TBackward0 represents the transpose operation applied to the weight tensor (nn.Linear multiplies the input by the transpose of its weight matrix).
     
  • AddmmBackward0 represents the "add matrix-multiply" operation (addmm: bias + input @ weight.T) performed by the linear layer.
     
  • The output tensor has shape (1, 10), where 1 is the batch size and 10 is the number of output neurons.

PyTorch Lightning Model

Now let us convert the above PyTorch model into a PyTorch Lightning model by following the steps we mentioned earlier.

import torch.optim as optim
import pytorch_lightning as pl
from torch.optim.lr_scheduler import StepLR

# Define the LightningModule
class MNISTLightningModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MNISTClassifier()  # this is the PyTorch model we made above

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self(x)
        loss = F.cross_entropy(output, y)
        self.log('train_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        output = self(x)
        loss = F.cross_entropy(output, y)
        self.log('test_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self(x)
        loss = F.cross_entropy(output, y)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}


Explanation

  • In this code, we wrap the MNISTClassifier we built earlier inside a PyTorch Lightning module and train it on the same MNIST dataset.
     
  • We have an MNISTLightningModule class that inherits from pl.LightningModule.
     
  • We have included methods for the forward pass, training step, validation step, and configuring optimizers.
     
  • The training_step and validation_step methods calculate the loss and log it.
     
  • The configure_optimizers method creates the Adam optimizer and the StepLR learning rate scheduler and returns both to the Trainer (a variant with explicit scheduler options is sketched after this list).
     
  • The training loop is launched by calling trainer.fit(), which runs the training and validation data through these steps.
     
  • To test the model, we used the test_step function in MNISTLightningModule.
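As promised above: by default, Lightning steps a scheduler returned from configure_optimizers once per epoch. When you need finer control, the method can return a scheduler configuration dictionary instead. Here is a sketch of a drop-in replacement for the method above:

def configure_optimizers(self):
    optimizer = optim.Adam(self.parameters(), lr=0.001)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "epoch",  # change to "step" to update after every batch
        },
    }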

DataModule in PyTorch Lightning

A DataModule is a separate class that encapsulates data-related operations: loading, preprocessing, and splitting into training, validation, and test datasets.

from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import os
from torchvision import datasets, transforms

class MNISTDataModule(pl.LightningDataModule):

  def setup(self, stage):
    # transforms for images
    transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
      
    # prepare transforms standard to MNIST
    mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
    self.mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
    
    self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])

  def train_dataloader(self):
    return DataLoader(self.mnist_train, batch_size=64)

  def val_dataloader(self):
    return DataLoader(self.mnist_val, batch_size=64)

  def test_dataloader(self):
    return DataLoader(self.mnist_test, batch_size=64)


Explanation

  • In this code, we defined an MNISTDataModule class that inherits from pl.LightningDataModule.
     
  • We set up the MNIST dataset for training, validation, and testing inside setup() (a note on the related prepare_data() hook follows this list).
     
  • transforms.Compose chains the image transformations: conversion to a tensor, then normalization with the MNIST mean and standard deviation.
     
  • We used random_split() to split the 60,000 training images into 55,000 for training and 5,000 for validation.
     
  • We have created the train_dataloader(), val_dataloader(), and test_dataloader() with the proper datasets and batch sizes.
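One refinement worth knowing: LightningDataModule also provides a prepare_data() hook, which Lightning runs only once per machine, while setup() runs in every process when training on multiple GPUs. Downloads therefore usually go in prepare_data(). A sketch of how the class above could be split:

class MNISTDataModule(pl.LightningDataModule):

    def prepare_data(self):
        # Runs once per machine: a safe place for downloads
        MNIST(os.getcwd(), train=True, download=True)
        MNIST(os.getcwd(), train=False, download=True)

    # setup() would then only load and split the already-downloaded data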

The Trainer Class

Finally, we make use of the Trainer class in PyTorch Lightning to handle the training and validation loops of our model.

# Create the data module and the Lightning model
dm = MNISTDataModule()
model = MNISTLightningModule()

# Save model checkpoints during training
checkpoint_callback = pl.callbacks.ModelCheckpoint()

trainer = pl.Trainer(max_epochs=10, callbacks=[checkpoint_callback])

# Train the model (the Trainer calls dm.setup() for us)
trainer.fit(model, dm)

# Evaluate the trained model on the test set
trainer.test(model, datamodule=dm)


Output

[Output screenshots: training progress and test metrics]

Explanation

  • In this code, firstly, we instantiate the MNISTDataModule class (dm) that we created in the previous step.
     
  • We then instantiate the MNISTLightningModule model that we obtained by migrating from PyTorch to PyTorch Lightning.
     
  • We register a ModelCheckpoint callback so that model checkpoints are saved during training (reloading one is sketched after this list).
     
  • A Trainer object from the Trainer class in PyTorch Lightning is created with max_epochs=10 and the ModelCheckpoint callback.
     
  • The trainer.fit() method is called to train the model.
     
  • The trainer.test() method is used to test the trained model on the test dataset.
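As noted above, the ModelCheckpoint callback records where checkpoints are saved, and the model can be reloaded from one later. A minimal sketch using the objects created above:

# Reload the weights recorded by the ModelCheckpoint callback
best_path = checkpoint_callback.best_model_path
restored = MNISTLightningModule.load_from_checkpoint(best_path)
restored.eval()  # the restored model is ready for inference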

Benefits of PyTorch Lightning

Here are some of the benefits of using PyTorch Lightning in place of PyTorch.

  • The code structure is simpler and more organised than plain PyTorch, which makes the code easier to read and maintain.
     
  • The Trainer class in PyTorch Lightning takes care of the training loop, so we only need to write the model-specific logic.
     
  • PyTorch Lightning integrates easily with other Python libraries such as Pandas, NumPy, the wider PyTorch ecosystem, and TensorBoard.
     
  • It is a high-level interface over PyTorch that removes boilerplate code and structures runs so they give the same results each time, making experiments more reproducible.

Frequently Asked Questions

What is PyTorch?

It is an open-source framework based on Python that is used to create and modify Neural networks.

What is PyTorch Lightning?

It is an open-source Python library that takes care of the repetitive boilerplate of training deep learning models so that we can focus on the model-specific logic.

Why should we migrate from PyTorch to PyTorch Lightning?

Migrating from PyTorch to PyTorch Lightning has many benefits, such as a simpler code structure, improved readability, and better reproducibility.

What is the Trainer class in PyTorch Lightning?

In PyTorch, we manually run the training loop, but in PyTorch Lightning, we use the Trainer class to handle the training process.

What is a Learning rate Scheduler?

A learning rate scheduler adjusts the learning rate (the step size of weight updates) during training, controlling how quickly or slowly the model learns.

Conclusion

This article discussed how to migrate from PyTorch to PyTorch Lightning. In this blog, we covered PyTorch and PyTorch Lightning, including the benefits of PyTorch Lightning.

We hope this blog has helped you enhance your knowledge of how to migrate from PyTorch to PyTorch Lightning.

You can find many more articles on our platform, Coding Ninjas Studio.

Refer to our Guided Path to upskill yourself in DSA, Competitive Programming, JavaScript, System Design, and many more! If you want to test your coding ability, you may check out the mock test series and participate in the contests hosted on Coding Ninjas Studio!

But suppose you have just started your learning process and are looking for questions from tech giants like Amazon, Microsoft, Uber, etc. In that case, you must look at the problems, interview experiences, and interview bundles for placement preparations.

However, you may consider our paid courses to give your career an edge over others!

Happy Learning!
