0
点赞
收藏
分享

微信扫一扫

TensorFlow 模型保存和恢复示例


前言

在之前一篇文章里:​​使用CNN+ Auto-Encoder 实现无监督Sentence Embedding (代码基于Tensorflow)​​,训练完成后,encode的参数也就被训练好了,这个时候我们利用这些参数对数据进行编码处理,从而得到一个向量。

保存模型

如果回忆下,上次的模型基本是这样的:

Input(段落) -> encoder -> encoder -> decoder -> decoder -> lost function (consine夹角)

我需要用到的是第二个encoder,在Tensorflow里,所有的都是Tensor,因此给定输入,就可以通过tensor给出输出。训练的过程中,涉及到第二个encoder的代码如下:

....
flattened = tf.reshape(conv_out, [-1, 51 * 128]) if USE_CNN else tf.reshape(input_x,
[-1, SEQUENCE_LENGTH * VOCAB_SIZE])

encoder_op = encoder(flattened)
....

我们真个训练过程其实是在tunning encoder的参数。现在我需要把encoder_op保留下来,供下次使用,这可以通过add_collection方法

tf.add_to_collection('encoder_op', encoder_op)

在​​sess.run(tf.global_variables_initializer())​​ 之后,我们获取Saver对象:

saver = tf.train.Saver()

然后在迭代的过程中,比如每迭代五次就保存一次模型:

if i %5 = 0: 
saver.save(sess, MODEL_SAVE_DIR)

恢复模型

sess = tf.Session()
## 这里是恢复graph
saver = tf.train.import_meta_graph(MODEL_SAVE_DIR + '/' + MODEL_NAME + '.meta')
## 这里是恢复各个权重参数
saver.restore(sess, tf.train.latest_checkpoint(MODEL_SAVE_DIR))


sess.run(tf.global_variables_initializer())
## 获取输入的tensor
input_x = tf.get_default_graph().get_tensor_by_name("input_x:0")
......

x_in = result1[0:SEQUENCE_LENGTH]
## 获取到encoder_op
encoder_op = tf.get_collection("encoder_op")[0]
## 给定数据,运行encoder_op
s = sess.run(encoder_op, feed_dict={input_x: [x_in]})

具体的解释已经在代码中提及。这样我们就可以利用encoder_op对新数据进行编码了。

完整的恢复模型参看:​​tensorflow_restore.py​​

额外的话

参考资料:

​​A quick complete tutorial to save and restore Tensorflow models​​

在该参考资料中,你还可以看到多种保存和使用tensor的方式。另外除了保存模型以外,还有 ​​tf.summary.FileWriter​

train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

使用它可以让你通过tensorbord 查看训练和运行情况。

举报

相关推荐

0 条评论