Tensorflow中的共享变量机制

写在前面

本文总结了Tensorflow使用过程中易混淆的一些接口,比如get_variable()和Variable()。name_scope()和variable_scope()等。

tf.get_varibale()/ tf.Variable()

tensorflow中关于variable的op有tf.get_variable()和tf.Variable两个.

tf.get_variable()

1
2
3
4
5
6
7
8
9
10
11
12
tf.get_variable(name,
shape=None,
dtype=None,
initializer=None,
regularizer=None,
trainable=True,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
custom_getter=None)

tf.Variable()

1
2
3
4
5
6
7
8
9
10
tf.Variable(initial_value=None,
trainable=True,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
expected_shape=None,
import_scope=None)

tf.get_variable()/tf.Variable()的区别

先看下面的两个例子:

1
2
3
4
5
import tensorflow as tf 
a1 = tf.Variable(0, name="a1")
a2 = tf.Variable(1, name='a1')
print(a1.name)
print(a2.name)

输出结果

1
2
a1:0
a1_1:0
1
2
3
4
5
import tensorflow as tf
a1 = tf.get_variable("a1", 0)
a2 = tf.get_variable('a1', 0)
print(a1.name)
print(a2.name)

输出结果:

1
ValueError: Variable a1 already exists, disallowed. Did you mean to set reuse=True in VarScope?

使用tf.Varibale()定义变量的时候,如果检测到命名冲突,系统会自动解决,但是使用tf.get_varibale()时,系统不会解决冲突,并且会报错
所以如果需要共享变量则需要使用tf.get_variable()。在其他情况下两者的用法基本一样。

我们再来看一段代码:

1
2
3
4
5
6
7
8
9
10
import tensorflow as tf

with tf.variable_scope("scope1"):
w1 = tf.get_variable("w1", shape=[])
w2 = tf.Variable(0.0, name="w2")
with tf.variable_scope("scope1", reuse=True):
w1_p = tf.get_variable("w1", shape=[])
w2_p = tf.Variable(1.0, name="w2")

print(w1 is w1_p, w2 is w2_p)

输出结果:

1
(True, False)

从输出结果可以看出,对于get_variable(),来说,如果已经创建的变量对象,就把那个对象返回,如果没有创建变量对象的话,就创建一个新的。
而tf.Variable()每次都在创建新对象。
这里没有太多的提到共享变量的问题,

tf.name_scope()/ tf.variable_scope()

参考文献

tensorflow学习笔记(二十三):variable与get_variable
共享变量