In the previous two posts, we learned how to use pre-trained models and how to extract features from them for training a model for a different task. In this tutorial, we will learn how to fine-tune a pre-trained model for a different task than it was originally trained for.

We will try to improve on the problem of classifying pumpkin, watermelon, and tomato discussed in the previous post. We will be using the same data for this tutorial.

What is fine-tuning of a network?

We have already explained the importance of using pre-trained networks in our previous article. Just to recap, when we train a network from scratch, we encounter the following two limitations:

Huge data required – Since the network has millions of parameters, to get an optimal set of parameters, we need to have a lot of data.

Huge computing power required – Even if we have a lot of data, training generally requires multiple iterations and it takes a toll on the computing resources.

The task of fine-tuning a network is to tweak the parameters of an already trained network so that it adapts to the new task at hand. As explained here, the initial layers learn very general features and as we go higher up the network, the layers tend to learn patterns more specific to the task the network is being trained on. Thus, for fine-tuning, we want to keep the initial layers intact ( or freeze them ) and retrain the later layers for our task.

Thus, fine-tuning avoids both the limitations discussed above.

The amount of data required for training is not much because of two reasons. First, we are not training the entire network. Second, the part that is being trained is not trained from scratch.

Since fewer parameters need to be updated, the training time is also shorter.

Fine-tuning in Keras

Let us directly dive into the code without much ado. We will be using the same data which we used in the previous post. You can choose to use a larger dataset if you have a GPU; training on a large dataset will take much longer on a CPU. We will use the VGG model for fine-tuning.


Freeze the required layers

In Keras, each layer has a parameter called “trainable”. To freeze the weights of a particular layer, we set this parameter to False, indicating that this layer should not be trained. That’s it! We go over each layer and select which layers we want to train.

# Load the VGG16 convolutional base with ImageNet weights, without the top classifier
from tensorflow.keras.applications import VGG16
vgg_conv = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Freeze the layers except the last 4 layers
for layer in vgg_conv.layers[:-4]:
    layer.trainable = False

# Check the trainable status of the individual layers
for layer in vgg_conv.layers:
    print(layer, layer.trainable)

Create a new model

Now that we have set the trainable parameters of our base network, we would like to add a classifier on top of the convolutional base. We will simply add a fully connected layer followed by a softmax layer with 3 outputs. This is done as given below.
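A minimal sketch of this step. The size of the fully connected layer (1024) and the dropout rate (0.5) are illustrative assumptions, not values prescribed by the post; the frozen VGG16 base is rebuilt here so the sketch stands on its own.

```python
from tensorflow.keras import layers, models
from tensorflow.keras.applications import VGG16

# The frozen VGG16 convolutional base from the previous step
vgg_conv = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
for layer in vgg_conv.layers[:-4]:
    layer.trainable = False

# Stack a small classifier on top of the convolutional base
model = models.Sequential()
model.add(vgg_conv)
model.add(layers.Flatten())
model.add(layers.Dense(1024, activation='relu'))   # fully connected layer (size is an assumption)
model.add(layers.Dropout(0.5))
model.add(layers.Dense(3, activation='softmax'))   # 3 outputs: pumpkin, watermelon, tomato
model.summary()
```

Note that the convolutional base appears as a single layer in the summary; only its last four layers and the new classifier layers will be updated during training.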

Set up the data generators

We have already separated the data into train and validation and kept it in the “train” and “validation” folders. We can use ImageDataGenerator available in Keras to read images in batches directly from these folders and optionally perform data augmentation. We will use two different data generators for train and validation folders.
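A sketch of the two generators, under assumed settings: the folder names come from the post, but the batch size, target size, and the specific augmentations are illustrative choices. The first loop only creates a tiny dummy folder tree so the sketch runs standalone; with the real dataset it is unnecessary.

```python
import os
import numpy as np
from PIL import Image
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Create a tiny dummy folder tree so this sketch runs standalone;
# with the real dataset, the train/validation folders already exist.
for split in ('train', 'validation'):
    for cls in ('pumpkin', 'watermelon', 'tomato'):
        folder = os.path.join(split, cls)
        os.makedirs(folder, exist_ok=True)
        Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8)).save(
            os.path.join(folder, 'dummy.jpg'))

# Augment the training images; only rescale the validation images
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)
validation_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
    'train', target_size=(224, 224), batch_size=20, class_mode='categorical')
validation_generator = validation_datagen.flow_from_directory(
    'validation', target_size=(224, 224), batch_size=20, class_mode='categorical')
```

Augmenting only the training set is deliberate: the validation set should reflect the real images the model will see, so it is only rescaled.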

Train the model

So far, we have created the model and set up the data for training. So, we should proceed with the training and check out the performance. We will have to specify the optimizer and the learning rate and start training using the model.fit() function. After the training is over, we will save the model.
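The steps above can be sketched as follows. The optimizer choice (RMSprop), learning rate, and epoch count are illustrative assumptions. To keep the sketch self-contained it trains a tiny stand-in model on random data; with the real setup you would pass train_generator and validation_generator to model.fit on the VGG-based model built earlier.

```python
import numpy as np
from tensorflow.keras import layers, models, optimizers

# Tiny stand-in model and random data so the sketch runs standalone;
# in the tutorial, `model` is the VGG-based model built above.
model = models.Sequential([
    layers.Flatten(input_shape=(8, 8, 3)),
    layers.Dense(16, activation='relu'),
    layers.Dense(3, activation='softmax'),
])

# Use a small learning rate so the unfrozen layers are only tweaked, not overwritten
model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.RMSprop(learning_rate=1e-4),
              metrics=['accuracy'])

x = np.random.rand(12, 8, 8, 3).astype('float32')
y = np.eye(3)[np.random.randint(0, 3, 12)]

# With the real data:
# model.fit(train_generator, validation_data=validation_generator, epochs=...)
history = model.fit(x, y, epochs=1, batch_size=4, verbose=0)

# Save the trained model after training
model.save('fine_tuned_model.h5')
```

A small learning rate matters here: the unfrozen VGG layers already hold useful ImageNet features, and large updates would destroy them.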



Comments

We talk about fine-tuning the full network when the new dataset is large and similar to the original dataset. Here I am confused about what exactly fine-tuning means. Do we use the weights of the pre-trained model for fine-tuning, or do we fine-tune the network starting from random initialization?

Fine-tuning always means starting with the pre-trained weights ( not random initialization ) and tuning these weights. You can choose which layers to tune ( by keeping them trainable ) or freeze ( by setting trainable=False ).

We use random initialization when we don’t have the weights for a layer or variable. For example, if we have replaced the last layer to suit our problem, its weights get initialized randomly.

Great tutorial, but I have one question. Suppose I have a model architecture to which I applied transfer learning + fine-tuning to adapt it to my task at hand, and after a reasonable number of epochs it gets stuck overfitting. Then I apply more regularization techniques to get rid of the overfitting. When doing this, should I use the weights learned after those epochs ( which led to overfitting ) as the starting point, or should I redo the whole transfer learning + fine-tuning process from the beginning ( ImageNet weights )?

In that case, you should save the model every M epochs ( M can be 10, 15, or 20 ) and then use the model from just before overfitting starts. Suppose overfitting starts at epoch 20; start the next cycle with the model from epoch 15.

Thanks for this great tutorial. Just one doubt: you use rescale in ImageDataGenerator. Why not use the VGG16 function preprocess_input? I have seen that ImageDataGenerator accepts a parameter for specifying a preprocessing function.
