一 实例描述
为一个线性回归的模型添加“保存检查点”功能。通过该功能,可以生成载入检查点文件,并能够指定生成检测点的个数。

二 代码
import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
plotdata = { "batchsize":[], "loss":[] } 
def moving_average(a, w=10): 
    if len(a) < w: 
        return a[:]     
    return [val if idx < w else sum(a[(idx-w):idx])/w for idx, val in enumerate(a)] 
#生成模拟数据 
train_X = np.linspace(-1, 1, 100) 
train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.3 # y=2x,但是加入了噪声 
#图形显示 
plt.plot(train_X, train_Y, 'ro', label='Original data') 
plt.legend() 
plt.show() 
tf.reset_default_graph() 
# 创建模型 
# 占位符 
X = tf.placeholder("float") 
Y = tf.placeholder("float") 
# 模型参数 
W = tf.Variable(tf.random_normal([1]), name="weight") 
b = tf.Variable(tf.zeros([1]), name="bias") 
# 前向结构 
z = tf.multiply(X, W)+ b 
#反向优化 
cost =tf.reduce_mean( tf.square(Y - z)) 
learning_rate = 0.01 
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) #Gradient descent 
# 初始化变量 
init = tf.global_variables_initializer() 
#参数设置 
training_epochs = 20 
display_step = 2 
saver = tf.train.Saver(max_to_keep=1) # 生成saver 
savedir = "log/" 
# 启动session 
with tf.Session() as sess: 
    sess.run(init) 
    # Fit all training data 
    for epoch in range(training_epochs): 
        for (x, y) in zip(train_X, train_Y): 
            sess.run(optimizer, feed_dict={X: x, Y: y}) 
        #显示训练中的详细信息 
        if epoch % display_step == 0: 
            loss = sess.run(cost, feed_dict={X: train_X, Y:train_Y}) 
            print ("Epoch:", epoch+1, "cost=", loss,"W=", sess.run(W), "b=", sess.run(b)) 
            if not (loss == "NA" ): 
                plotdata["batchsize"].append(epoch) 
                plotdata["loss"].append(loss) 
            saver.save(sess, savedir+"linermodel.cpkt", global_step=epoch) 
                 
    print (" Finished!") 
     
    print ("cost=", sess.run(cost, feed_dict={X: train_X, Y: train_Y}), "W=", sess.run(W), "b=", sess.run(b)) 
    #print ("cost:",cost.eval({X: train_X, Y: train_Y})) 
    #图形显示 
    plt.plot(train_X, train_Y, 'ro', label='Original data') 
    plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line') 
    plt.legend() 
    plt.show() 
     
    plotdata["avgloss"] = moving_average(plotdata["loss"]) 
    plt.figure(1) 
    plt.subplot(211) 
    plt.plot(plotdata["batchsize"], plotdata["avgloss"], 'b--') 
    plt.xlabel('Minibatch number') 
    plt.ylabel('Loss') 
    plt.title('Minibatch run vs. Training loss') 
      
    plt.show() 
     
#重启一个session     
load_epoch=18     
with tf.Session() as sess2: 
    sess2.run(tf.global_variables_initializer())      
    saver.restore(sess2, savedir+"linermodel.cpkt-" + str(load_epoch)) 
    print ("x=0.2,z=", sess2.run(z, feed_dict={X: 0.2})) 
     
with tf.Session() as sess3: 
    sess3.run(tf.global_variables_initializer()) 
    ckpt = tf.train.get_checkpoint_state(savedir) 
    if ckpt and ckpt.model_checkpoint_path: 
        saver.restore(sess3, ckpt.model_checkpoint_path) 
        print ("x=0.2,z=", sess3.run(z, feed_dict={X: 0.2})) 
with tf.Session() as sess4: 
    sess4.run(tf.global_variables_initializer()) 
    kpt = tf.train.latest_checkpoint(savedir) 
    if kpt!=None: 
        saver.restore(sess4, kpt) 
        print ("x=0.2,z=", sess4.run(z, feed_dict={X: 0.2}))
三 运行结果
四 说明
1 保存模型并不限制在训练之后,在训练之中也需要保存,因为TensorFlow训练模型时难免出现中断的情况。我们自然希望能够将辛苦得出的中间参数保留下来,否则下次又会重新开始。在这种训练中保存模型,习惯上称为保存检查点。
2 该实例与保存模型的功能类似,只是保存的位置发生了些变化,我们希望在显示信息时将检查点保存起来,于是就将保存位置放在了迭代训练中的打印信息后面。
3 另外,本例用到了Saver的另外一个参数——max_to_keep=1,表明最多只保存一个检查点文件。在保存时候使用如下代码传入迭代次数。
saver.save(sess, savedir+"linermodel.cpkt", global_step=epoch)
TensorFlow会将迭代次数一起放在检查点的名字上,所以在载入时,同样要指定迭代次数。
saver.restore(sess2, savedir+"linermodel.cpkt-" + str(load_epoch))
4 上面代码运行后,会看到log文件夹下多了几个linermodel.cpkt-18*文件,就是检查点文件。
这里使用tf.train.Saver(max_to_keep=1)代码创建saver时传入的参数max_to_keep=1,表示在迭代过程中只保存一个文件。这样,在循环训练中,新生成的模型就会覆盖以前的模型。
5 在代码的最后,提供了两种简单方法快速获取检查点文件。

评论关闭
IT虾米网

微信公众号号:IT虾米 (左侧二维码扫一扫)欢迎添加!