Save and restore TensorFlow session

This blog post was written when TensorFlow was at version r0.12.

Stack Overflow and the TensorFlow documentation are pretty clear that Saver is what you want, but it is less clear how to use it in your code. It boils down to this:

  1. Define model
  2. Create session
  3. Initialize variables
  4. If restoring, call saver.restore, passing in the session and the checkpoint path (look it up with tf.train.get_checkpoint_state on the directory containing the checkpoint files)
  5. When saving, call saver.save, passing in the session and the path to save the checkpoint to
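
import os
import tensorflow as tf

# ... code to define model omitted ...
# `restore` (bool) and `save_path` (checkpoint directory) are assumed to be set above

# Initializing the variables (tf.initialize_all_variables() is deprecated as of r0.12)
init = tf.global_variables_initializer()
# Create the Saver after the model's variables have been defined
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)
    if restore:
        # Read the "checkpoint" file in save_path to find the latest checkpoint
        ckpt = tf.train.get_checkpoint_state(save_path)
        if ckpt and ckpt.model_checkpoint_path:
            # Overwrite the freshly initialized values with the saved ones
            saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        # ... training code omitted ...
        # saver.save expects a file prefix rather than a bare directory;
        # "model.ckpt" is an arbitrary example name
        saver.save(sess, os.path.join(save_path, "model.ckpt"))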

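What if you want to train on one computer and restore on a different computer? If the checkpoint directory ends up at a different location on the second machine, you will get an error, because the file named “checkpoint” still lists the paths that were used when saving. To fix it, open that file with a text editor and change the paths to the new location.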
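
If you would rather not edit the file by hand, the same fix can be scripted. The sketch below is only an illustration: repoint_checkpoint is a hypothetical helper (not part of TensorFlow), the example paths are made up, and it assumes the “checkpoint” file contains lines of the form model_checkpoint_path: "..." and all_model_checkpoint_paths: "...", with the new directory written using forward slashes.

```python
import os
import posixpath
import re


def repoint_checkpoint(checkpoint_dir, new_dir=None):
    """Rewrite the paths listed in <checkpoint_dir>/checkpoint.

    Hypothetical helper, not a TensorFlow API. Only the directory part of
    each quoted path is replaced; the checkpoint file names are kept.
    """
    # By default, point the entries at the directory the file now lives in.
    new_dir = new_dir or checkpoint_dir
    index_file = os.path.join(checkpoint_dir, "checkpoint")

    with open(index_file) as f:
        lines = f.readlines()

    def swap_dir(match):
        # Keep only the file name after the last "/" or "\" so that paths
        # saved on either Unix or Windows are handled.
        filename = re.split(r"[\\/]", match.group(2))[-1]
        return '%s: "%s"' % (match.group(1), posixpath.join(new_dir, filename))

    # Assumed line format:  some_key: "some/path"
    pattern = re.compile(r'^(\w+): "(.*)"\s*$')
    with open(index_file, "w") as f:
        for line in lines:
            f.write(pattern.sub(swap_dir, line.rstrip("\n")) + "\n")
```

For example, calling repoint_checkpoint(save_path) on the new computer before the restore step should let tf.train.get_checkpoint_state(save_path) return paths that exist on that machine, provided the checkpoint files were copied over together with the “checkpoint” file.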

Here is a working example. Credit goes to the original author, Aymeric Damien. I forked the code and added my edits for save/restore.