TensorFlow Tip: Pretrain and Retrain

I recently ran into a situation where I had to train a neural network on one dataset, save it, and then load it up later to train it on a different dataset (or with a different training procedure). I implemented this in TensorFlow and thought I'd share a stripped-down version of the script here, as it could serve as an instructive example of how TensorFlow sessions are used. Note that this is not necessarily the best way of doing this; it might well be simpler to load the original graph and continue training that graph directly by making its parameters trainable, or to take some other approach altogether.

The script can be found here. In the first stage of the script (the pre-training stage) there is only a single graph, which contains the model, randomly initialised and then trained. One could equally skip defining this graph explicitly, since TensorFlow's default graph would then be used in its place. This model (together with its parameters) is saved to a file and then loaded for the second, re-training stage.

In the second stage there are two graphs. The first is loaded from the saved file and contains the pre-trained model, whose parameter values we wish to assign to the second model before training the latter on a different dataset. The parameters of the second model are randomly initialised prior to this assignment step. Because a variable belongs to a single graph, it cannot be used directly in an assignment op built in another graph, so I found it necessary to transfer the parameters across graphs by hand. I did this by extracting the parameter values of the first model as NumPy arrays and assigning the values of these arrays to the corresponding parameters of the second model.
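The cross-graph assignment boils down to two steps: take a snapshot of the pre-trained parameter values that is detached from the first graph, then write those values into the second model's parameters, matched by name. The sketch below shows only that logic, under my own assumptions: plain dicts of nested lists stand in for each graph's variables so that it runs without TensorFlow, and the function names and the optional `name_map` (for when the two graphs use different variable scopes) are hypothetical. In TF 1.x the snapshot step would correspond to running the first graph's variables in its own session, which returns NumPy arrays, and the write step to running assign ops (e.g. `tf.assign`) built in the second graph.

```python
import copy


def snapshot_parameters(params):
    """Return a deep copy of a {name: value} mapping, detached from its source.

    Hypothetical stand-in for evaluating the pre-trained model's variables in
    its own session (in TF 1.x, sess.run(variables) returns NumPy arrays).
    """
    return {name: copy.deepcopy(value) for name, value in params.items()}


def shape_of(value):
    """Shape of a nested-list 'tensor' (assumes rectangular nesting)."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


def assign_parameters(target, source_values, name_map=None):
    """Copy snapshot values into the target model's parameters, matched by name.

    name_map (hypothetical) renames source parameters to target names.
    Returns the list of target parameter names that were assigned.
    """
    name_map = name_map or {}
    assigned = []
    for src_name, value in source_values.items():
        dst_name = name_map.get(src_name, src_name)
        if dst_name not in target:
            continue  # parameter exists only in the pre-trained model
        if shape_of(target[dst_name]) != shape_of(value):
            raise ValueError(
                f"shape mismatch for {dst_name}: "
                f"{shape_of(target[dst_name])} vs {shape_of(value)}"
            )
        target[dst_name] = copy.deepcopy(value)
        assigned.append(dst_name)
    return assigned


# Pre-trained model (first graph) and a freshly initialised model (second graph).
pretrained = {"lstm/kernel": [[0.1, 0.2], [0.3, 0.4]], "lstm/bias": [0.0, 0.5],
              "old_head/w": [[1.0]]}
fresh = {"lstm/kernel": [[9.0, 9.0], [9.0, 9.0]], "lstm/bias": [9.0, 9.0],
         "new_head/w": [[7.0, 7.0]]}

values = snapshot_parameters(pretrained)
print(sorted(assign_parameters(fresh, values)))  # ['lstm/bias', 'lstm/kernel']
print(fresh["new_head/w"])  # [[7.0, 7.0]], left at its initial value
```

Parameters that exist only in the pre-trained model are skipped, parameters that exist only in the new model keep their initial values, and a shape mismatch raises an error instead of silently copying incompatible values.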

The TensorFlow Datasets API for Sequence Data (Code Examples)

This post was originally meant to be a full tutorial (with a link to the GitHub repository) on how to use the TensorFlow Datasets API (tf.data) and how it contrasts with the more widely used placeholder approach to passing data into TensorFlow graphs. Unfortunately, I haven't been able to set aside the time to write about it in as much detail as I had originally intended, so I'm sharing the code with a few notes to help readers make use of it.

First off, here is the link to the GitHub repository. It contains two main scripts: placeholder_vs_iterators.py and generator_vs_tfrecord.py. The first script implements three ways of passing data into the TensorFlow graph; in all cases the data are sequences. The first is the standard placeholder approach that most readers will be familiar with. The second uses iterators and the third uses feedable iterators to input data to the graph. As far as I could gather, these latter two are the new ways of passing data into graphs introduced by the Datasets API. The script is invoked with an integer command-line argument (1, 2 or 3) that selects one of the three approaches.
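The post doesn't show how the scripts read that integer argument. A minimal sketch of one way to do it with the standard library's `argparse` is below; the three `run_*` functions are hypothetical placeholders for the real input pipelines, and `main` would be called from the script's entry point.

```python
import argparse


# Hypothetical placeholders for the three input pipelines; the repository's
# real function names are not given in the post.
def run_placeholder():
    return "placeholder"


def run_iterator():
    return "iterator"


def run_feedable_iterator():
    return "feedable iterator"


APPROACHES = {1: run_placeholder, 2: run_iterator, 3: run_feedable_iterator}


def main(argv=None):
    """Parse the approach number from argv and run the matching pipeline."""
    parser = argparse.ArgumentParser(description="Choose a data input approach.")
    parser.add_argument(
        "approach",
        type=int,
        choices=sorted(APPROACHES),
        help="1 = placeholder, 2 = iterator, 3 = feedable iterator",
    )
    args = parser.parse_args(argv)
    return APPROACHES[args.approach]()
```

Restricting `choices` makes argparse reject anything other than 1, 2 or 3 with a usage message, rather than failing later inside the training code. The same pattern would apply to the second script with its own three functions.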

The second script is generator_vs_tfrecord.py. Having played around a bit with the two new data input approaches (iterators and feedable iterators), I decided to stick with the former and use it to examine three different ways of iterating through data while passing it into the graph during training. The first takes unbatched sequences from a generator function and applies standard preprocessing steps to them (zero-padding, batching, etc.) before using the data to train the model. The second begins with data that has already been zero-padded and batched, and passes it to the model via a generator function during training. The third first creates a TFRecord file of SequenceExample protocol buffers, then reads sequences from this file, zero-pads and batches them before passing them to the graph during training. The third approach is, I would say, the most dependent on the Datasets API, whereas the other two rely to a greater extent on NumPy. This script is also invoked with an integer command-line argument (1, 2 or 3) that selects one of the three approaches.
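To make the difference between the first two approaches concrete, here is a plain-Python sketch (no TensorFlow or NumPy) of the padding and batching they involve. It is an illustration under my own assumptions, not the repository's code: the function names are hypothetical, sequences are right-padded with zeros to the longest sequence in each batch, and the true lengths are returned alongside each batch. In the first approach this work happens on the fly as sequences come out of the generator (roughly what tf.data's `Dataset.from_generator` followed by `Dataset.padded_batch` does inside the graph); in the second it is done once up front and the generator simply replays the ready-made batches.

```python
def sequence_generator(sequences):
    """Approach 1 input: yield unbatched, variable-length sequences one at a time."""
    for seq in sequences:
        yield list(seq)


def _pad(batch, pad_value):
    """Right-pad every sequence in a batch to the batch's longest length."""
    lengths = [len(s) for s in batch]
    max_len = max(lengths)
    padded = [s + [pad_value] * (max_len - len(s)) for s in batch]
    return padded, lengths


def pad_and_batch(seqs, batch_size, pad_value=0):
    """Group sequences into batches, zero-padding each batch to its longest member.

    Yields (padded_batch, lengths); the original lengths are kept because an
    LSTM usually needs them to ignore the padding. The last batch may be
    smaller than batch_size.
    """
    batch = []
    for seq in seqs:
        batch.append(seq)
        if len(batch) == batch_size:
            yield _pad(batch, pad_value)
            batch = []
    if batch:
        yield _pad(batch, pad_value)


def prebatched_generator(batches):
    """Approach 2: the data was padded and batched ahead of time; just replay it."""
    for batch in batches:
        yield batch


data = [[1, 2, 3], [4], [5, 6], [7, 8, 9, 10], [11]]

# Approach 1: pad and batch on the fly while iterating.
for padded, lengths in pad_and_batch(sequence_generator(data), batch_size=2):
    print(padded, lengths)

# Approach 2: do the same work once up front, then iterate over the result.
precomputed = list(pad_and_batch(sequence_generator(data), batch_size=2))
for padded, lengths in prebatched_generator(precomputed):
    pass  # each training step would feed `padded` and `lengths` to the model
```

Padding per batch rather than to the longest sequence in the whole dataset keeps short batches cheap; the trade-off of the second approach is that the padded data must be stored, but no preprocessing is repeated across epochs.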

So there you have it! The rest of the code consists of global constants for batch size, sequence length, etc. that can be changed as needed, basic training loops, and a simple LSTM model to get the above examples working. When I first started using this API, around the time TensorFlow 1.4 was released, I found few fully working examples of it (and a working example is often what one is looking for when getting started), so I decided to share this code. One can build more complex data pipelines with this handy API on top of these examples.

The following links are useful for understanding more about the API and Google Protocol Buffers while going through the code:

The TensorFlow Datasets API blog post

The official TensorFlow documentation

A blog post that helped me understand the SequenceExample protocol buffer format better

A useful Stack Overflow post on fixed- and variable-length features

The definitions of the Example and SequenceExample protobufs
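As a quick illustration of the SequenceExample layout used in the third approach: features that describe the whole sequence go in `context` (a map from names to `Feature`s), while per-time-step features go in `feature_lists`, where each `FeatureList` holds one `Feature` per step. The sketch below builds the protobuf text format for a single sequence using only strings, so it runs without TensorFlow; the feature names `length`, `label` and `tokens` and the function name are hypothetical choices of mine, not taken from the repository.

```python
def sequence_example_text(tokens, label):
    """Build a SequenceExample in protobuf text format for one sequence.

    Per-sequence features ("length" and "label", hypothetical names) go in
    `context`; the per-time-step feature ("tokens", also hypothetical) goes
    in `feature_lists`, with one Feature per time step.
    """
    lines = ["context {"]
    for key, value in (("length", len(tokens)), ("label", label)):
        lines += [
            "  feature {",
            '    key: "%s"' % key,
            "    value { int64_list { value: %d } }" % value,
            "  }",
        ]
    lines += ["}", "feature_lists {", "  feature_list {", '    key: "tokens"', "    value {"]
    for token in tokens:
        lines.append("      feature { int64_list { value: %d } }" % token)
    lines += ["    }", "  }", "}"]
    return "\n".join(lines)


print(sequence_example_text([5, 9, 2], label=1))
```

With TensorFlow installed, such text could be parsed into a real message with `google.protobuf.text_format.Parse(text, tf.train.SequenceExample())` and serialised with `SerializeToString()` before being written to a TFRecord file, although in practice one would usually construct the message directly in Python rather than going through text.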