一 实例
将模型里的内容打印出来,同时演示将指定的内容打印出来。

二 代码
import tensorflow as tf 
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file 
savedir = "log/" 
print_tensors_in_checkpoint_file(savedir+"linermodel.cpkt", None, True) 
W = tf.Variable(1.0, name="weight") 
b = tf.Variable(2.0, name="bias") 
# 放到一个字典里: 
saver = tf.train.Saver({'weight': b, 'bias': W}) 
with tf.Session() as sess: 
    tf.global_variables_initializer().run() 
    saver.save(sess, savedir+"linermodel.cpkt") 
print_tensors_in_checkpoint_file(savedir+"linermodel.cpkt", None, True)
三 运行结果
tensor_name:  bias
[ 0.06552324]
tensor_name:  weight
[ 2.04334879]
tensor_name:  bias
1.0
tensor_name:  weight
2.0

四 运行说明
可以看到,tensor_name:后面跟的就是创建的变量名,接着是它的数值。
tf.train.Saver函数里还可以放参数来实现更高级的功能,可以指定存储变量与变量的对应关系。

例子中,W的值设置为1.0,b的值设置为2.0。在创建saver时,将它们颠倒,保存的模型打印出来之后可以看到,bias变成了1.0,而weight变成了2.0。

 

评论关闭
IT虾米网

微信公众号号:IT虾米 (左侧二维码扫一扫)欢迎添加!