Code360 powered by Coding Ninjas X Naukri.com
Last Updated: Mar 27, 2024
Difficulty: Easy

Transfer Learning with Keras


Introduction

Everyone has heard the phrase "Those who cannot remember the past are condemned to repeat it." Humankind has evolved by building on available resources and past experience to make the best decisions, rather than reinventing the wheel again and again. Transfer learning is based on a similar idea.

Transfer learning is a popular deep learning approach and an active research area in machine learning. It focuses on solving a new problem using the knowledge gained while solving a previous, related one.

For example, knowledge gained while recognizing cats could be applied to identify dogs.

What is Transfer Learning? 

Transfer learning is about reusing the feature representations learned by an existing model, so you don't have to train a model from scratch. It is commonly used for image-related tasks and natural language processing. Pre-trained models are usually trained on very large datasets, which makes them robust and a standard starting point (and benchmark) for deep learning practitioners. They can be used directly to make predictions on new tasks or integrated into a new model. Using a pre-trained model reduces training time, and the resulting model often generalizes better than one trained from scratch on a small dataset.
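A minimal sketch of using a pre-trained network as a fixed feature extractor. To keep it runnable offline it builds ResNet-50 with weights=None (random weights); in practice you would pass weights='imagenet' so the features are the ones learned on ImageNet. The random batch named images is a stand-in for real photos, and the 64 × 64 input size is an arbitrary choice for speed.

```python
import numpy as np
import tensorflow as tf

# Assumption: weights=None keeps this runnable offline; use weights="imagenet"
# to get the pre-trained feature representations described above.
backbone = tf.keras.applications.ResNet50(
    include_top=False,       # drop the 1000-class ImageNet classifier
    weights=None,
    input_shape=(64, 64, 3),
    pooling="avg",           # global average pooling -> one vector per image
)
backbone.trainable = False   # used only for inference, never updated

# Stand-in batch of 4 RGB images with pixel values in [0, 255].
images = np.random.randint(0, 256, size=(4, 64, 64, 3)).astype("float32")

# ResNet-50 expects its own preprocessing (RGB -> BGR and mean subtraction).
inputs = tf.keras.applications.resnet50.preprocess_input(images)
features = backbone.predict(inputs, verbose=0)
print(features.shape)  # one 2048-dimensional feature vector per image
```

The resulting vectors can then be fed into any small classifier (for example a single Dense layer) trained on the new task.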

Transfer learning is an excellent technique when you don't have enough data to train a new model. The weights of the pre-trained model become the initial weights of the new model, which significantly reduces the time and data required.
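To make "pre-trained weights become the initial weights" concrete, here is a small sketch on synthetic data (all names, sizes, and the random "source" task are hypothetical). A source model is trained briefly, then a target model with the same body architecture but a new output head copies the body's weights before any training on the new task.

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)

def make_body():
    # Shared architecture for the feature-learning layers (hypothetical sizes).
    return tf.keras.Sequential([
        tf.keras.Input(shape=(20,)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dense(16, activation="relu"),
    ])

# 1) Source model: body + 3-class head, trained on a synthetic source task.
source_body = make_body()
source_model = tf.keras.Sequential(
    [source_body, tf.keras.layers.Dense(3, activation="softmax")])
source_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
x_src = rng.normal(size=(256, 20)).astype("float32")
y_src = rng.integers(0, 3, size=256)
source_model.fit(x_src, y_src, epochs=2, verbose=0)

# 2) Target model: same body architecture, new 2-class head.
target_body = make_body()
target_body.set_weights(source_body.get_weights())  # pre-trained weights as initial weights
target_model = tf.keras.Sequential(
    [target_body, tf.keras.layers.Dense(2, activation="softmax")])

# The body starts exactly where the source model left off; only the head is new.
same = all(np.array_equal(a, b)
           for a, b in zip(source_body.get_weights(), target_body.get_weights()))
print(same)
```

From here, target_model would be compiled and trained on the (smaller) target dataset.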

Why Transfer Learning?

We now know that transfer learning uses past knowledge to solve new but similar problems. But why do we need it? Deep learning models are data-hungry: the ImageNet dataset, on which many popular pre-trained models are trained, contains over 1 million labeled images. Gathering this much data is a mammoth task, and even if you could, training a model from scratch on such a dataset requires a lot of computing power. Transfer learning addresses both problems by reusing a model that has already been trained.

How to implement Transfer Learning?

  • Obtain a pre-trained model.
  • Create a base model.
  • Freeze layers to preserve previous training results.
  • Add new trainable layers.
  • Train new layers on the dataset. 
  • Fine-tune the model.
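The six steps above can be sketched end to end as follows. This is an illustration under stated assumptions, not the article's own code: a randomly initialized MobileNetV2 (weights=None) stands in for a real pre-trained backbone so it runs offline, and the 64 × 64 input size, the toy random data, and the learning rates are arbitrary choices.

```python
import numpy as np
import tensorflow as tf

IMG_SIZE = (64, 64, 3)
NUM_CLASSES = 5

# Steps 1-2: obtain a pre-trained model and use it (minus its classifier) as the base.
# Assumption: weights=None for an offline sketch; use weights="imagenet" in practice.
base_model = tf.keras.applications.MobileNetV2(
    include_top=False, weights=None, input_shape=IMG_SIZE, pooling="avg")

# Step 3: freeze the base so its weights are preserved.
base_model.trainable = False

# Step 4: add new trainable layers on top.
inputs = tf.keras.Input(shape=IMG_SIZE)
x = base_model(inputs, training=False)  # keep BatchNorm layers in inference mode
outputs = tf.keras.layers.Dense(NUM_CLASSES, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)
head_only_weights = len(model.trainable_weights)  # only the new Dense kernel and bias

# Toy data standing in for the new dataset.
x_train = np.random.rand(32, *IMG_SIZE).astype("float32")
y_train = np.random.randint(0, NUM_CLASSES, size=32)

# Step 5: train only the new head.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss="sparse_categorical_crossentropy", metrics=["accuracy"])
model.fit(x_train, y_train, epochs=1, verbose=0)

# Step 6: fine-tune. Unfreeze the base and retrain end to end with a much lower
# learning rate; recompiling is required for the trainable change to take effect.
base_model.trainable = True
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss="sparse_categorical_crossentropy", metrics=["accuracy"])
model.fit(x_train, y_train, epochs=1, verbose=0)
print(head_only_weights, len(model.trainable_weights))
```

The low learning rate in step 6 is the key design choice: large updates would quickly overwrite the features the base model learned on the source task.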

Code Implementation with Keras

We'll implement a ResNet-50 model for flower classification.

import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Dense, Flatten  # Dense and Flatten live in keras.layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

 

Next, we import the dataset. tf.keras.utils.get_file downloads the archive from the given URL and extracts it locally on your machine (by default under ~/.keras/datasets).

import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
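Before building the input pipeline, it can help to check what was downloaded. The extracted folder contains one sub-directory per class (daisy, dandelion, roses, sunflowers, tulips), which is the layout image_dataset_from_directory expects. The helper below, count_images_per_class, is a hypothetical name for illustration and not part of Keras.

```python
import pathlib

def count_images_per_class(data_dir):
    """Return {class_name: number_of_jpg_files} for a one-folder-per-class layout."""
    data_dir = pathlib.Path(data_dir)
    counts = {}
    # Only sub-directories are classes; loose files such as LICENSE.txt are skipped.
    for class_dir in sorted(p for p in data_dir.iterdir() if p.is_dir()):
        counts[class_dir.name] = len(list(class_dir.glob("*.jpg")))
    return counts

# Usage with the data_dir downloaded above:
# counts = count_images_per_class(data_dir)
# print(counts, sum(counts.values()))
```

A strongly unbalanced count would be worth knowing about before interpreting the accuracy numbers later on.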

 

Next, we split the dataset into training and validation subsets, holding out 20% of the images for validation (validation_split=0.2). Every image is resized to 180 × 180 pixels so that all samples have the same dimensions.

img_height,img_width=180,180
batch_size=32
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

 

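The validation subset uses the same validation_split and seed, so it contains exactly the 20% of images held out from the training subset.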

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)
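By default, image_dataset_from_directory returns integer labels (label_mode='int'), and the class names are taken from the sub-directory names. The sketch below builds a tiny throwaway directory of random PNG images (hypothetical class names daisy and roses) to show the difference between integer labels and one-hot labels (label_mode='categorical'). The label format has to match the loss function the model is compiled with: sparse categorical cross-entropy for integers, categorical cross-entropy for one-hot vectors.

```python
import pathlib
import tempfile

import numpy as np
import tensorflow as tf
from PIL import Image

def make_toy_image_dir(root, class_names=("daisy", "roses"), per_class=4, size=32):
    """Write tiny random PNGs into root/<class_name>/ (hypothetical toy data)."""
    rng = np.random.default_rng(0)
    for name in class_names:
        class_dir = pathlib.Path(root) / name
        class_dir.mkdir(parents=True, exist_ok=True)
        for i in range(per_class):
            pixels = rng.integers(0, 256, size=(size, size, 3), dtype=np.uint8)
            Image.fromarray(pixels).save(class_dir / f"img_{i}.png")
    return pathlib.Path(root)

root = str(make_toy_image_dir(tempfile.mkdtemp()))

# Default label_mode="int": one integer class index per image.
int_ds = tf.keras.utils.image_dataset_from_directory(
    root, image_size=(32, 32), batch_size=8, shuffle=False)
# label_mode="categorical": one-hot vectors, matching categorical_crossentropy.
onehot_ds = tf.keras.utils.image_dataset_from_directory(
    root, image_size=(32, 32), batch_size=8, shuffle=False, label_mode="categorical")

_, int_labels = next(iter(int_ds))
_, onehot_labels = next(iter(onehot_ds))
print(int_ds.class_names)                     # taken from the sub-directory names
print(int_labels.shape, onehot_labels.shape)  # (batch,) versus (batch, num_classes)
```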

 

Visualizing the data

class_names = train_ds.class_names  # taken from the folder names, e.g. 'daisy', 'roses'

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):  # take a single batch
  for i in range(6):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")
plt.show()

Next, we load the pre-trained model. You can choose any model from tf.keras.applications; we'll use ResNet-50.

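resnet_model = Sequential()

pretrained_model = tf.keras.applications.ResNet50(
    include_top=False,          # drop ResNet-50's 1000-class ImageNet classifier
    input_shape=(180, 180, 3),
    pooling='avg',              # global average pooling -> a 2048-dim feature vector
    weights='imagenet')         # start from the weights learned on ImageNet
# 'classes' is omitted: it is only used when include_top=True.

# Freeze every layer of the pre-trained model
for layer in pretrained_model.layers:
    layer.trainable = False

resnet_model.add(pretrained_model)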

 

weights='imagenet' makes the model start from the weights it learned while being trained on the ImageNet dataset. Setting layer.trainable = False freezes each layer, so those weights are not updated during training and the model isn't retrained from scratch. We'll now add the new layers on top, ending with the output layer.

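resnet_model.add(Flatten())                       # pooled output is already flat, so this is effectively a no-op
resnet_model.add(Dense(512, activation='relu'))   # new trainable hidden layer
resnet_model.add(Dense(5, activation='softmax'))  # one output per flower class
resnet_model.summary()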
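To see how much work freezing saves, the sketch below rebuilds the same architecture and counts trainable versus frozen parameters. It uses weights=None so it runs without downloading the ImageNet weights; the parameter counts do not depend on the weight values. Only the two new Dense layers remain trainable: 2048 × 512 + 512 parameters for the hidden layer plus 512 × 5 + 5 for the output layer, 1,051,653 in total.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential

# Assumption: weights=None avoids a download; parameter counts are identical either way.
base = tf.keras.applications.ResNet50(include_top=False, input_shape=(180, 180, 3),
                                      pooling="avg", weights=None)
base.trainable = False

model = Sequential([
    tf.keras.Input(shape=(180, 180, 3)),
    base,
    Flatten(),                       # no-op here: pooling="avg" already gives a flat vector
    Dense(512, activation="relu"),
    Dense(5, activation="softmax"),
])

def count_params(weights):
    """Total number of scalar values across a list of weight variables."""
    return int(sum(np.prod(tuple(w.shape)) for w in weights))

trainable = count_params(model.trainable_weights)
frozen = count_params(model.non_trainable_weights)
print(f"trainable: {trainable:,}  frozen: {frozen:,}")
```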

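Now that the model is ready, we compile it. Because image_dataset_from_directory returns integer class labels by default, the loss function is sparse categorical cross-entropy (plain categorical cross-entropy would require one-hot labels). We train the model for 10 epochs. That may seem low, but this is what makes transfer learning attractive: we aren't starting from scratch, so only the new layers have to learn.

resnet_model.compile(optimizer=Adam(learning_rate=0.001),
                     loss='sparse_categorical_crossentropy',
                     metrics=['accuracy'])

history = resnet_model.fit(train_ds, validation_data=val_ds, epochs=10)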

 

Evaluating the model. We'll plot the training and validation accuracy recorded after each epoch.

fig1 = plt.gcf()
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.axis(ymin=0.4,ymax=1)
plt.grid()
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epochs')
plt.legend(['train', 'validation'])
plt.show()

The validation accuracy comes out to be around 90%. 
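The plot above reads from history.history, a dict that maps each metric name to a list with one value per epoch (validation metrics carry a "val_" prefix). The sketch below uses a tiny hypothetical model on random data, so its numbers are meaningless; it only shows the structure of the History object and one way to find the epoch with the best validation accuracy.

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 8)).astype("float32")
y = rng.integers(0, 5, size=64)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(5, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
history = model.fit(x, y, validation_split=0.25, epochs=3, verbose=0)

# One key per metric (with "val_" versions for the validation data), one value per epoch.
print(sorted(history.history))
best_epoch = int(np.argmax(history.history["val_accuracy"])) + 1  # report epochs 1-indexed
print("best epoch:", best_epoch)
```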

Frequently Asked Questions

What are some disadvantages of transfer learning? 

Transfer learning works well only when the source and target problems are similar. When they differ substantially, the transferred features may not help and can even hurt performance, an effect known as negative transfer. It is also still an active area of research.

Mention some popular pre-trained models. 

Aside from ResNet-50, MobileNet, VGG16, and Xception are other popular pre-trained models used for transfer learning. 
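Because all of these models live in tf.keras.applications and share the include_top, weights, and pooling arguments, swapping the backbone is mostly a one-line change. The sketch builds each one with weights=None (to avoid downloads) and records the length of the pooled feature vector, which becomes the input size of the new head. Each family also has its own preprocess_input function (for example tf.keras.applications.mobilenet.preprocess_input), which should match the chosen backbone.

```python
import tensorflow as tf

# Assumption: weights=None keeps the sketch offline; pass weights="imagenet" for transfer learning.
BACKBONES = {
    "ResNet50": tf.keras.applications.ResNet50,
    "MobileNet": tf.keras.applications.MobileNet,
    "VGG16": tf.keras.applications.VGG16,
    "Xception": tf.keras.applications.Xception,
}

feature_sizes = {}
for name, constructor in BACKBONES.items():
    base = constructor(include_top=False, weights=None,
                       input_shape=(96, 96, 3), pooling="avg")
    feature_sizes[name] = int(base.output.shape[-1])  # length of the pooled feature vector

print(feature_sizes)
```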

In which situations do transfer learning models tend to overfit?

When the target dataset is very small, fine-tuning a large pre-trained model (especially with many unfrozen layers) can lead to overfitting. Using a lower learning rate and keeping most of the source model's layers frozen can help prevent it.
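One common way to apply this advice is to unfreeze only the top few layers of the base model and fine-tune them with a small learning rate, while leaving BatchNormalization layers frozen. The helper unfreeze_top_layers and the cut-off of 20 layers are illustrative choices (hypothetical, not from the original article), and weights=None is used only so the sketch runs offline.

```python
import tensorflow as tf

def unfreeze_top_layers(base_model, num_layers):
    """Make only the last num_layers layers trainable, keeping BatchNormalization frozen."""
    base_model.trainable = True          # re-enable the model as a whole first
    for layer in base_model.layers[:-num_layers]:
        layer.trainable = False          # lower layers hold generic features: keep frozen
    for layer in base_model.layers[-num_layers:]:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            # Frozen BN layers also run in inference mode, so their moving
            # statistics are not updated from the small target dataset.
            layer.trainable = False
    return base_model

# Assumption: weights=None for an offline sketch; use weights="imagenet" in practice.
base = tf.keras.applications.ResNet50(include_top=False, weights=None,
                                      input_shape=(96, 96, 3), pooling="avg")
unfreeze_top_layers(base, num_layers=20)

trainable_layers = [layer.name for layer in base.layers if layer.trainable]
print(len(trainable_layers), "of", len(base.layers), "layers are trainable")

# Recompile the full model with a low learning rate before fine-tuning, e.g.:
# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5), ...)
```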

Conclusion

This blog focused on implementing a model using transfer learning. Its objective was to make readers comfortable with transfer learning and with deep learning libraries such as Keras. We recommend that readers code along to grasp the finer details. If you are interested in machine learning and deep learning, check out our industry-oriented machine learning course curated by Stanford University alumni and industry experts.
