一 实例描述
为一个线性回归的模型添加“保存检查点”功能。通过该功能,可以生成载入检查点文件,并能够指定生成检测点的个数。
二 代码
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
plotdata = { "batchsize":[], "loss":[] }
def moving_average(a, w=10):
if len(a) < w:
return a[:]
return [val if idx < w else sum(a[(idx-w):idx])/w for idx, val in enumerate(a)]
#生成模拟数据
train_X = np.linspace(-1, 1, 100)
train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.3 # y=2x,但是加入了噪声
#图形显示
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.legend()
plt.show()
tf.reset_default_graph()
# 创建模型
# 占位符
X = tf.placeholder("float")
Y = tf.placeholder("float")
# 模型参数
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")
# 前向结构
z = tf.multiply(X, W)+ b
#反向优化
cost =tf.reduce_mean( tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) #Gradient descent
# 初始化变量
init = tf.global_variables_initializer()
#参数设置
training_epochs = 20
display_step = 2
saver = tf.train.Saver(max_to_keep=1) # 生成saver
savedir = "log/"
# 启动session
with tf.Session() as sess:
sess.run(init)
# Fit all training data
for epoch in range(training_epochs):
for (x, y) in zip(train_X, train_Y):
sess.run(optimizer, feed_dict={X: x, Y: y})
#显示训练中的详细信息
if epoch % display_step == 0:
loss = sess.run(cost, feed_dict={X: train_X, Y:train_Y})
print ("Epoch:", epoch+1, "cost=", loss,"W=", sess.run(W), "b=", sess.run(b))
if not (loss == "NA" ):
plotdata["batchsize"].append(epoch)
plotdata["loss"].append(loss)
saver.save(sess, savedir+"linermodel.cpkt", global_step=epoch)
print (" Finished!")
print ("cost=", sess.run(cost, feed_dict={X: train_X, Y: train_Y}), "W=", sess.run(W), "b=", sess.run(b))
#print ("cost:",cost.eval({X: train_X, Y: train_Y}))
#图形显示
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line')
plt.legend()
plt.show()
plotdata["avgloss"] = moving_average(plotdata["loss"])
plt.figure(1)
plt.subplot(211)
plt.plot(plotdata["batchsize"], plotdata["avgloss"], 'b--')
plt.xlabel('Minibatch number')
plt.ylabel('Loss')
plt.title('Minibatch run vs. Training loss')
plt.show()
#重启一个session
load_epoch=18
with tf.Session() as sess2:
sess2.run(tf.global_variables_initializer())
saver.restore(sess2, savedir+"linermodel.cpkt-" + str(load_epoch))
print ("x=0.2,z=", sess2.run(z, feed_dict={X: 0.2}))
with tf.Session() as sess3:
sess3.run(tf.global_variables_initializer())
ckpt = tf.train.get_checkpoint_state(savedir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess3, ckpt.model_checkpoint_path)
print ("x=0.2,z=", sess3.run(z, feed_dict={X: 0.2}))
with tf.Session() as sess4:
sess4.run(tf.global_variables_initializer())
kpt = tf.train.latest_checkpoint(savedir)
if kpt!=None:
saver.restore(sess4, kpt)
print ("x=0.2,z=", sess4.run(z, feed_dict={X: 0.2}))
三 运行结果


四 说明
1 保存模型并不限制在训练之后,在训练之中也需要保存,因为TensorFlow训练模型时难免出现中断的情况。我们自然希望能够将辛苦得出的中间参数保留下来,否则下次又会重新开始。在这种训练中保存模型,习惯上称为保存检查点。
2 该实例与保存模型的功能类似,只是保存的位置发生了些变化,我们希望在显示信息时将检查点保存起来,于是就将保存位置放在了迭代训练中的打印信息后面。
3 另外,本例用到了Saver的另外一个参数——max_to_keep=1,表明最多只保存一个检查点文件。在保存时候使用如下代码传入迭代次数。
saver.save(sess, savedir+"linermodel.cpkt", global_step=epoch)
TensorFlow会将迭代次数一起放在检查点的名字上,所以在载入时,同样要指定迭代次数。
saver.restore(sess2, savedir+"linermodel.cpkt-" + str(load_epoch))
4 上面代码运行后,会看到log文件夹下多了几个linermodel.cpkt-18*文件,就是检查点文件。
这里使用tf.train.Saver(max_to_keep=1)代码创建saver时传入的参数max_to_keep=1,表示在迭代过程中只保存一个文件。这样,在循环训练中,新生成的模型就会覆盖以前的模型。
5 在代码的最后,提供了两种简单方法快速获取检查点文件。