Introduction
Everyone has heard the phrase "Those who cannot remember the past are condemned to repeat it." This is how humankind has evolved: by using the available resources and past experience to make the best decisions. There's no point in reinventing the wheel again and again. Transfer learning is based on a similar idea.
It is a popular deep learning approach. Transfer learning is a research problem in machine learning that concentrates on solving a new task using the knowledge gained from solving a previous, related one.
For example, knowledge gained while recognizing cats could be applied to identify dogs.
What is Transfer Learning?
Transfer learning is about leveraging the feature representations of an existing model, so you don't have to train a model from scratch. It is generally used for image classification and natural language processing tasks. Pre-trained models are usually trained on vast datasets, which makes them highly robust and a standard benchmark among deep learning practitioners. These models can be used to make predictions on new tasks directly or be integrated into a new model. Using pre-trained models lowers training time, and the resulting models generally have lower generalization error.
Transfer learning is an excellent technique to incorporate when you don't have enough data to train a new model from scratch. The weights of the pre-trained model become the initial weights of the new model, which significantly reduces the time and data requirements.
We now have a basic idea that transfer learning uses past knowledge to solve new but similar problems. But why do we need transfer learning? There are a couple of challenges to clear before you can train a deep learning model. Deep learning models are data-hungry: the ImageNet dataset contains over 1 million images, and gathering that much data is a mammoth task. Even if you could, training a model from scratch on such a vast dataset still requires a lot of computing power. Transfer learning addresses both problems well. A typical workflow looks like this:
Freeze layers to preserve previous training results.
Add new trainable layers.
Train new layers on the dataset.
Fine-tune the model.
Code Implementation with Keras
We'll be implementing a ResNet-50 model for flower classification.
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
Importing the dataset. You may download the dataset locally on your machine from the URL shown in the sketch below.
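A minimal sketch of the download step, assuming the standard TensorFlow flowers dataset (5 classes); swap in your own URL or local directory if you have one:

import pathlib

# Assumption: the TensorFlow example flowers dataset (daisy, dandelion,
# roses, sunflowers, tulips). Replace with your own dataset if needed.
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)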
Splitting the dataset into training and validation sets. We resize every image to 180×180 pixels, which ensures uniformity across the samples. The validation split is 0.2, so 20% of the images are held out for validation.
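A sketch of the split, assuming the directory layout produced by the download above. The label_mode='categorical' setting is an assumption on our part: it yields one-hot labels, which matches the categorical cross-entropy loss used when compiling the model later.

img_height, img_width = 180, 180
batch_size = 32  # assumption: a typical default batch size

train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    label_mode='categorical',
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    label_mode='categorical',
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_names = train_ds.class_names

Visualizing a few of the training images: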
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(6):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        # Labels are one-hot encoded, so take the argmax to index class_names
        plt.title(class_names[np.argmax(labels[i])])
        plt.axis("off")
Next, we import the pre-trained model. You may choose any model of your choice from keras.applications; we'll be using ResNet-50.
resnet_model = Sequential()

# Load ResNet-50 without its ImageNet classification head; global average
# pooling turns the convolutional output into a flat feature vector
pretrained_model = tf.keras.applications.ResNet50(include_top=False,
                                                  input_shape=(180, 180, 3),
                                                  pooling='avg',
                                                  weights='imagenet')

# Freeze the pre-trained layers so their ImageNet weights are preserved
for layer in pretrained_model.layers:
    layer.trainable = False

resnet_model.add(pretrained_model)
weights='imagenet' makes the model use the weights it learned while being trained on the ImageNet dataset. Setting layer.trainable = False ensures the pre-trained layers aren't retrained, so the model doesn't learn those weights all over again. We'll now add the output layers.
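The exact head is not specified above, so the following is a minimal sketch: a Flatten layer, a hidden Dense layer (the 512-unit width is an assumption), and a 5-way softmax output, one unit per flower class.

resnet_model.add(Flatten())
resnet_model.add(Dense(512, activation='relu'))   # assumption: 512 hidden units
resnet_model.add(Dense(5, activation='softmax'))  # one output unit per class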
Now that the model is ready, we compile it. The loss function is categorical cross-entropy, which matches the one-hot labels. We train for 10 epochs. That may seem low, but it's what makes transfer learning unique: we aren't starting from scratch.
resnet_model.compile(optimizer=Adam(learning_rate=0.001),
                     loss='categorical_crossentropy',
                     metrics=['accuracy'])
history = resnet_model.fit(train_ds, validation_data=val_ds, epochs=10)
Evaluating the model. We'll plot the accuracy after each epoch.
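A minimal sketch of the accuracy plot, using the history object returned by fit() ('accuracy' and 'val_accuracy' are the standard Keras keys for the metric we compiled with):

# Plot training and validation accuracy per epoch
plt.figure(figsize=(8, 5))
plt.plot(history.history['accuracy'], label='training accuracy')
plt.plot(history.history['val_accuracy'], label='validation accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()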
The validation accuracy comes out to be around 90%.
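The fine-tuning step from the workflow above isn't shown in this walkthrough; a minimal sketch would unfreeze the backbone and continue training briefly with a much lower learning rate (the 1e-5 value and 3 extra epochs are assumptions), so the ImageNet weights are only nudged rather than overwritten:

# Unfreeze the pre-trained backbone for fine-tuning
pretrained_model.trainable = True

# Recompile with a much lower learning rate so the weights change slowly
resnet_model.compile(optimizer=Adam(learning_rate=1e-5),
                     loss='categorical_crossentropy',
                     metrics=['accuracy'])

# A few additional epochs are usually enough
resnet_model.fit(train_ds, validation_data=val_ds, epochs=3)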
Frequently Asked Questions
What are some disadvantages of transfer learning?
Transfer learning is still a young practice with a broad scope of research, so its limitations are not fully mapped out. The main known drawback is that it only works for similar problems: if the source and target tasks are too different, the transferred features can hurt performance rather than help (this is known as negative transfer).
Mention some popular pre-trained models.
Aside from ResNet-50, MobileNet, VGG16, and Xception are other popular pre-trained models used for transfer learning.
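All of these are available in tf.keras.applications with similar constructors, so swapping one in is a small change. A minimal sketch with MobileNet, reusing the same settings as the ResNet-50 example above:

# Drop-in replacement for the ResNet50 backbone used earlier
pretrained_model = tf.keras.applications.MobileNet(include_top=False,
                                                   input_shape=(180, 180, 3),
                                                   pooling='avg',
                                                   weights='imagenet')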
In which situations do transfer learning models tend to overfit?
When the dataset for the target task is very small, the model may overfit, especially if too many of the pre-trained layers are fine-tuned. Lowering the learning rate and freezing most of the source model's layers can prevent it.
Conclusion
This blog focused on implementing a model using transfer learning. Its objective was to make readers comfortable with transfer learning and various deep learning libraries. We recommend readers code along to grasp the finer details. If you are interested in machine learning and deep learning, check out our industry-oriented Machine Learning course curated by Stanford University alumni and industry experts.