备份与恢复模型
基本操作
1 2 3 4 5 6
| saver = tf.train.Saver() with tf.Session() as sess: saver.save(sess, saved_model_path)
with tf.Session() as sess: saver.restore(sess, saved_model_path)
|
部分恢复模型
首先通过 tf.trainable_variables () 获得新模型的所有可训练参数,然后把新增参数剔除。新建一个 saver 对象,并把要恢复的参数传进去。
1 2 3 4 5 6 7
| restore_variable = tf.trainable_variables()
<do some modified jobs here>
saver = tf.train.Saver(modified_restore_variable) with tf.Session() as sess: saver.restore(sess, model_path)
|