Introduction
PyTorch is a popular deep learning framework that researchers favor for its flexible, dynamic computation graph. However, users still face problems when PyTorch models are deployed in real-world applications and high-performance environments. Hence we have TorchScript, a solution that overcomes these problems.
In this article, we will be discussing PyTorch and TorchScript in detail, followed by the importance of converting PyTorch to TorchScript.
Introduction to PyTorch
PyTorch is an open-source deep learning framework developed by Facebook's AI Research Lab (FAIR). It is used to build and train neural networks for many machine-learning tasks. PyTorch provides a flexible, dynamic computation graph, making it easier to work with than static-graph frameworks such as early versions of TensorFlow.
PyTorch has two modes: eager mode and script mode. While eager mode is great for trying new ideas and testing models quickly, script mode focuses on real-world deployment. Script mode consists of two components: PyTorch JIT and TorchScript.
Script Mode
Script mode allows users to convert PyTorch code into a format suited for production. This matters because it lets us take our trained models and make them run faster and more reliably in real-world settings such as websites and mobile apps, even on devices with limited resources.
Importance of Script Mode
The Script mode in PyTorch has two key benefits:
Portability: Script mode allows us to convert the model code into a format that is not tied to Python. That is, developers can run the model without needing a Python runtime, making it independent and easier to use across different platforms and devices.
Performance: PyTorch JIT (part of script mode) uses information about how the model runs during execution to make it faster and more efficient. It optimizes the model, for example by fusing layers, thus improving performance when the model is applied to real-world uses. A rough sketch of these optimizations appears below.
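As an illustration (the small model here is a hypothetical example, not one from this article), we can script a model and then apply the JIT's freezing and inference-optimization passes, torch.jit.freeze and torch.jit.optimize_for_inference:

import torch

# Hypothetical example model; any eval-mode ScriptModule would work here
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 2),
)
model.eval()

scripted = torch.jit.script(model)
frozen = torch.jit.freeze(scripted)  # inlines parameters and removes Python-level attributes
optimized = torch.jit.optimize_for_inference(frozen)  # applies fusion and other inference passes

print(optimized(torch.rand(1, 4)))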
Understanding TorchScript
TorchScript is a statically typed subset of Python designed to work with the PyTorch framework's machine learning models and neural networks. It improves the performance of PyTorch models by converting the Python code into a more efficient form.
Thus, TorchScript is a way of writing PyTorch models that allows them to run faster and work on various devices without needing the full Python language. This makes it easier to use the models in real-world settings like mobile apps or systems where Python may be unavailable or too slow.
Importance of TorchScript
TorchScript is important for several reasons:
TorchScript allows models to be deployed on different platforms and devices without relying on Python, thus increasing their flexibility.
TorchScript enables us to move models from Python to other high-performance environments like C++, making deployment smooth and easy.
A TorchScript model can serve many requests at the same time without the performance issues caused by Python's multi-threading limitations.
As Python programs may cause performance and multi-threading issues in production setups, TorchScript lets us avoid these issues.
TorchScript provides tools to capture the definition of the model, giving more adaptability in how models are represented.
TorchScript makes export and import easy. Models can be trained and developed in Python, then exported using TorchScript for use in other environments, as sketched below.
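The export/import workflow is short in practice. Below is a minimal sketch, assuming a simple hypothetical model; the same saved file can also be loaded from C++ with torch::jit::load:

import torch

model = torch.nn.Linear(4, 2)  # hypothetical model
scripted = torch.jit.script(model)

# Save a self-contained TorchScript archive; loading it later does not
# require the original Python class definition
scripted.save("model.pt")

loaded = torch.jit.load("model.pt")
print(loaded(torch.rand(1, 4)))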
Understanding PyTorch JIT
The PyTorch JIT (Just-In-Time) compiler is a tool that takes our model's TorchScript code and compiles it into an optimized form, making it run faster. PyTorch JIT gives a speed boost to our machine-learning models.
There are two methods by which we can make our PyTorch models work with the JIT compiler:
Tracing: This method involves using the torch.jit.trace API.
Scripting: This method involves using the torch.jit.script API.
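As a quick preview (with a hypothetical model and example input), the two calls look like this; both are covered in detail later in this article:

import torch

model = torch.nn.Linear(4, 4)        # hypothetical model
example_input = torch.rand(1, 4)

traced = torch.jit.trace(model, example_input)  # records the operations run for this input
scripted = torch.jit.script(model)              # compiles the model's source code directly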
Features of PyTorch JIT
The features of PyTorch JIT are:
PyTorch JIT acts as a thread-safe interpreter that can handle many requests at the same time, making our models run faster.
By writing simple code transformations, you can improve your models to suit specific needs.
Besides inference, PyTorch JIT is also used to train models. It automatically computes how to adjust the model's parameters while training, a capability known as "auto-differentiation." A minimal training sketch follows.
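For illustration, here is a minimal training sketch (with a hypothetical model and random toy data) showing that gradients flow through a scripted module just as they do through a regular one:

import torch
import torch.nn.functional as F

model = torch.jit.script(torch.nn.Linear(4, 1))  # hypothetical scripted model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

x, target = torch.rand(8, 4), torch.rand(8, 1)   # random toy data
for _ in range(5):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), target)
    loss.backward()   # gradients are computed through the TorchScript graph
    optimizer.step()
print(loss.item())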
Converting PyTorch to TorchScript
Converting PyTorch to TorchScript can be done in two ways:
Tracing, and
Scripting
Let us understand each of them one by one.
Tracing
Tracing is a method for converting PyTorch models to the TorchScript format. It makes a step-by-step record of how a model behaves when given certain inputs: all tensor operations are recorded in a computational graph, and this graph then defines the behavior of the model.
We can perform tracing using 'torch.jit.trace(model, input)'. With its help, a PyTorch model is converted into an optimized format that can run without Python.
Let us understand tracing better with an example where we create a custom model, give it some example inputs, and record how it behaves during tracing.
Code
import torch

class CustomModel(torch.nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.fc_layer = torch.nn.Linear(4, 4)

    def forward(self, x_input, hidden):
        new_hidden = torch.tanh(self.fc_layer(x_input) + hidden)
        return new_hidden, new_hidden

# Creating an instance of the custom model
net = CustomModel()

# Creating example inputs
x_input, hidden_state = torch.rand(3, 4), torch.rand(3, 4)

# Tracing the custom model
traced_net = torch.jit.trace(net, (x_input, hidden_state))

# Printing the traced model
print(traced_net)

# Performing inference with the traced model
output_hidden, output = traced_net(x_input, hidden_state)
print(output_hidden, output)
In the above example, the PyTorch code is converted into a simpler form called TorchScript. Only the tensor operations are recorded; the rest of the Python code is ignored. The 'torch.jit.trace' function runs the model with the example inputs and creates a TorchScript version of it.
TorchScript keeps the model definition in a graph format, which can be accessed through the '.graph' property of the traced model, as shown below.
Code
print(traced_net.graph)
The '.graph' output is a low-level representation of the model's computation, which can be hard to read.
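For a shorter, more readable, Python-like view of the same model, TorchScript also provides the '.code' property:

# Prints a readable, Python-syntax interpretation of the traced graph
print(traced_net.code)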
Scripting
Scripting is another way of converting PyTorch to TorchScript. It is done using the 'torch.jit.script' API. With scripting, we write the model code as ordinary Python, and the compiler translates it directly into TorchScript. This gives users access to a wider range of Python features in the model code.
Scripting offers more flexibility than tracing. It is generally used when tracing cannot capture all the complexities of the model.
Let us understand this with a simple example.
Code
import torch

# Defining the MyDecisionGate module with control flow in its forward pass
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

# Defining the custom model module that uses MyDecisionGate
class MyModel(torch.nn.Module):
    def __init__(self, dg):
        super(MyModel, self).__init__()
        self.my_dg = dg
        self.my_linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_hidden = torch.tanh(self.my_dg(self.my_linear(x)) + h)
        return new_hidden, new_hidden

# Creating an instance of MyModel with MyDecisionGate as the argument
net = MyModel(MyDecisionGate())

# Creating random inputs 'x' and 'h'
x, h = torch.rand(3, 4), torch.rand(3, 4)

# Tracing MyModel using torch.jit.trace
traced_net = torch.jit.trace(net, (x, h))

# Printing the TorchScript code of MyDecisionGate and MyModel after tracing
print(traced_net.my_dg.code)
print(traced_net.code)
The code provided above tried to trace the model using torch.jit.trace, but it produced a warning because tracing dropped part of the model's behavior, namely the if-else statement. This happened because tracing records only the operations that occur during one specific run of the model and constructs a script accordingly, so it can miss control-flow logic.
To overcome this issue, we can use torch.jit.script to script the code instead.
Code
# Scripting the MyDecisionGate and MyModel using torch.jit.script
scripted_gate = torch.jit.script(MyDecisionGate())
net = MyModel(scripted_gate)
scripted_net = torch.jit.script(net)
# Printing the TorchScript code of MyDecisionGate and MyModel after scripting
print(scripted_gate.code)
print(scripted_net.code)
torch.jit.script avoids the warning that occurred earlier. Scripting directly analyzes the model's Python source code and converts it into TorchScript. This way, scripting captures the entire model graph correctly, including control flow, without losing any important parts of the model's behavior.
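Tracing and scripting can also be mixed. As a small sketch building on the example above, a scripted submodule keeps its control flow even when the surrounding module is traced:

# Embedding the scripted gate in a module that we then trace;
# the scripted control flow is preserved inside the traced module
mixed_net = MyModel(scripted_gate)
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_mixed = torch.jit.trace(mixed_net, (x, h))
print(traced_mixed.my_dg.code)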
Difference Between Tracing and Scripting
The key differences between tracing and scripting are:
Tracing | Scripting
--- | ---
Does not support control flow; it records tensor operations only. | Supports Python control flow (if-else, loops) and preserves it.
Does not support data structures like lists or dictionaries. | Supports a wide range of data structures, including lists and dictionaries.
Requires us to verify explicitly that the traced model handles the full range of inputs. | Scripted models work out of the box for every feature the JIT compiler supports, with no extra usability checks.
Does not capture advanced Python features such as classes or built-ins (e.g., range, zip). | Supports more advanced Python features and abstract types.
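To make the data-structure row above concrete, here is a small sketch of a scripted function that builds and returns a dictionary, something tracing handles poorly:

import torch
from typing import Dict

@torch.jit.script
def tensor_stats(x: torch.Tensor) -> Dict[str, torch.Tensor]:
    # TorchScript supports typed containers such as Dict and List
    stats = {"mean": x.mean(), "max": x.max()}
    return stats

print(tensor_stats(torch.rand(3, 4)))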
Frequently Asked Questions
Why do you need to convert PyTorch to TorchScript?
Converting PyTorch to TorchScript makes our models independent of Python, thus making their deployment possible in high-performance environments like C++. This boosts the performance of our models, increases portability, and avoids multi-threading issues.
What are the two methods for converting PyTorch to TorchScript?
The two methods for converting PyTorch to TorchScript are "tracing" and "scripting." While tracing records how the model behaves on example inputs, scripting compiles the model's source code directly so it can run independently.
Can TorchScript be used in production deployment?
Yes, TorchScript can be used in production deployment. It improves the performance and efficiency of PyTorch models by transforming the Python code into an optimized form. It boosts loading and execution speed, making it a good choice for big projects.
Can you combine TorchScript with other deployment tools?
Yes, we can combine TorchScript with other deployment tools like TorchServe, ONNX Runtime, LibTorch, and the PyTorch JIT compiler. This allows us to deploy our models across different platforms, from web servers to cloud services.
Conclusion
Kudos on finishing this article! We have discussed the importance of converting PyTorch models into the TorchScript format. TorchScript makes our models run faster and easier to deploy, and hence it is preferred by researchers and developers for production use.
We hope this blog has helped you expand your knowledge on converting PyTorch to TorchScript.
Keep learning! We suggest you read some of our other articles related to Python and Deep Learning.