TensorFlow 常用技巧收集

备份与恢复模型

基本操作

1
2
3
4
5
6
saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess, saved_model_path)

with tf.Session() as sess:
saver.restore(sess, saved_model_path)

部分恢复模型

首先通过 tf.trainable_variables () 获得新模型的所有可训练参数,然后把新增参数剔除。新建一个 saver 对象,并把要恢复的参数传进去。

1
2
3
4
5
6
7
restore_variable = tf.trainable_variables()

<do some modified jobs here>

saver = tf.train.Saver(modified_restore_variable)
with tf.Session() as sess:
saver.restore(sess, model_path)