一 实例
介绍一种更简便地保存检查点功能的方法——tf.train.MonitoredTrainingSession函数,该函数可以直接实现保存及载入检查点模型的文件。
演示使用MonitoredTrainingSession函数来自动管理检查点文件。
二 代码
import tensorflow as tf
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',save_checkpoint_secs = 2) as sess:
print(sess.run([global_step]))
while not sess.should_stop():
i = sess.run( step)
print( i)
三 运行结果
1 第一次运行后,会发现log文件夹下产生如下文件

2 第二次运行后,结果如下:
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from log/checkpoints\model.ckpt-15147
INFO:tensorflow:Saving checkpoints for 15147 into log/checkpoints\model.ckpt.
[15147]
15148
15149
15150
15151
15152
15153
15154
15155
15156
15157
15158
15159
四 说明
本例是按照训练时间来保存的。通过指定save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。
可见程序自动载入检查点是从第15147次开始运行的。
五 注意
1 如果不设置save_checkpoint_secs参数,默认的保存时间是10分钟,这种按照时间保存的模式更适合用于使用大型数据集来训练复杂模型的情况。
2 使用该方法,必须要定义global_step变量,否则会报错误。