写在前面
本文总结了Tensorflow使用过程中易混淆的一些接口,比如get_variable()和Variable()。name_scope()和variable_scope()等。
tf.get_varibale()/ tf.Variable()
tensorflow中关于variable的op有tf.get_variable()和tf.Variable两个.
tf.get_variable()
1 2 3 4 5 6 7 8 9 10 11 12
| tf.get_variable(name, shape=None, dtype=None, initializer=None, regularizer=None, trainable=True, collections=None, caching_device=None, partitioner=None, validate_shape=True, use_resource=None, custom_getter=None)
|
tf.Variable()
1 2 3 4 5 6 7 8 9 10
| tf.Variable(initial_value=None, trainable=True, collections=None, validate_shape=True, caching_device=None, name=None, variable_def=None, dtype=None, expected_shape=None, import_scope=None)
|
tf.get_variable()/tf.Variable()的区别
先看下面的两个例子:
1 2 3 4 5
| import tensorflow as tf a1 = tf.Variable(0, name="a1") a2 = tf.Variable(1, name='a1') print(a1.name) print(a2.name)
|
输出结果
1 2 3 4 5
| import tensorflow as tf a1 = tf.get_variable("a1", 0) a2 = tf.get_variable('a1', 0) print(a1.name) print(a2.name)
|
输出结果:
1
| ValueError: Variable a1 already exists, disallowed. Did you mean to set reuse=True in VarScope?
|
使用tf.Varibale()定义变量的时候,如果检测到命名冲突,系统会自动解决,但是使用tf.get_varibale()时,系统不会解决冲突,并且会报错。
所以如果需要共享变量则需要使用tf.get_variable()。在其他情况下两者的用法基本一样。
我们再来看一段代码:
1 2 3 4 5 6 7 8 9 10
| import tensorflow as tf
with tf.variable_scope("scope1"): w1 = tf.get_variable("w1", shape=[]) w2 = tf.Variable(0.0, name="w2") with tf.variable_scope("scope1", reuse=True): w1_p = tf.get_variable("w1", shape=[]) w2_p = tf.Variable(1.0, name="w2")
print(w1 is w1_p, w2 is w2_p)
|
输出结果:
从输出结果可以看出,对于get_variable(),来说,如果已经创建的变量对象,就把那个对象返回,如果没有创建变量对象的话,就创建一个新的。
而tf.Variable()每次都在创建新对象。
这里没有太多的提到共享变量的问题,
tf.name_scope()/ tf.variable_scope()
参考文献
tensorflow学习笔记(二十三):variable与get_variable
共享变量